In [None]:
import pandas as pd
import numpy as np

import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchmetrics import ConfusionMatrix

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import tensorflow as tf

from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestCentroid
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from scipy.stats import kurtosis, skew
from scipy.stats import wilcoxon

import scipy
import matplotlib.pyplot as plt
import shap

from operator import itemgetter
from tqdm.notebook import trange, tqdm
from google.colab import files

import copy
import random
import time

In [None]:
# Expression matrix with cluster labels (tab-separated).
scrna_df = pd.read_csv('/content/drive/MyDrive/Data/scRNA_gb2_inputs_labels.tsv', delimiter='\t')
scrna_df.head()

In [None]:
# Lambda values per cell: drop the first data row, renumber rows, and align the id column name with scrna_df.
sup_df = (
    pd.read_csv('/content/drive/MyDrive/Data/Lambda_values.csv')
    .drop(index=0)
    .reset_index(drop=True)
    .rename(columns={'cell id': 'cell'})
)
sup_df.head()

In [None]:
def create_subpops(df):
    """Number each distinct 'Cell specific barcode' by order of first appearance.

    Returns a list with one integer per row of ``df``: the first barcode seen gets 0,
    the next previously-unseen barcode gets 1, and so on; repeated barcodes reuse their id.
    """
    barcode_to_id = dict()
    ids = list()
    for barcode in df['Cell specific barcode']:
        if barcode not in barcode_to_id:
            # The next free id equals the number of barcodes registered so far.
            barcode_to_id[barcode] = len(barcode_to_id)
        ids.append(barcode_to_id[barcode])
    return ids

# Integer sub-population id per row, numbered by first appearance of each barcode in sup_df.
# NOTE(review): assigned positionally, so sup_df rows are assumed to line up with scrna_df rows — confirm.
subpop_ids = create_subpops(sup_df)
scrna_df['Sub-Populations'] = subpop_ids
print(scrna_df['Sub-Populations'].max())
scrna_df.head()

In [None]:
# Mesenchymal / Proneural scores per cell; the unnamed index column holds the cell id.
mesenchymal_proneural = (
    pd.read_csv('/content/drive/MyDrive/Data/mesenchymal_proneural.csv')
    .rename(columns={"Unnamed: 0": "cell"})
)
mesenchymal_proneural.head()

In [None]:
# Copy both score columns onto the expression frame (pandas aligns them on the row index).
for score_col in ('Mesenchymal', 'Proneural'):
    scrna_df[score_col] = mesenchymal_proneural[score_col]

In [None]:
def combine_data(df, n_features=3000, random_state=None):
    """Select the most variable genes, package every cell as a record, and split train/test/valid.

    Parameters
    ----------
    df : expression DataFrame that also carries 'cell', 'cluster', 'Sub-Populations',
        'Proneural' and 'Mesenchymal' columns. Gene columns are taken from the original file
        layout: after removing the appended annotation columns, everything except the first two
        columns and the last one. NOTE(review): assumes those three are non-gene columns — confirm
        against the tsv header.
    n_features : number of highest-variance genes to keep (default 3000, as before).
    random_state : seed forwarded to the shuffle and to both splits; None keeps them unseeded.

    Returns
    -------
    cols : names of the selected gene columns, in feature order.
    all_df, train_df, test_df, valid_df : DataFrames with columns 'Data', 'Clusters', 'Sub_pops',
        'Cell_type', 'Proneural', 'Mesenchymal' and a fresh RangeIndex. 'Data' holds tuples
        (float32 feature tensor, cluster, sub-population, proneural, mesenchymal, cell type).
        The test split is 15% of all cells and validation is 10% of the remainder; both are
        stratified by cluster.
    """
    # Annotation columns appended by earlier cells must not compete with genes for selection;
    # dropping them first restores the column layout the `2:-1` slice was written for.
    annotation_cols = [c for c in ('Sub-Populations', 'Mesenchymal', 'Proneural') if c in df.columns]
    gene_block = df.drop(columns=annotation_cols).iloc[:, 2:-1]
    gene_block = gene_block.drop(columns=[c for c in ('cell', 'cluster') if c in gene_block.columns])
    if gene_block.shape[1] == 0:
        raise ValueError('combine_data: no gene columns found')

    # Per-column population variance (ddof=0, as np.var used). argpartition indices refer to
    # gene_block's columns, so select from gene_block itself — the old code applied them to the
    # full frame, shifting every selected column by two.
    gene_var = gene_block.var(axis=0, ddof=0)
    n_keep = min(n_features, gene_block.shape[1])
    top_idx = np.argpartition(gene_var.to_numpy(), -n_keep)[-n_keep:]
    X = gene_block.iloc[:, top_idx]
    cols = list(X.columns)

    samples = torch.tensor(X.to_numpy(dtype=np.float32))
    # Positional arrays so row i always refers to the same cell regardless of df's index.
    clusters = df['cluster'].to_numpy()
    subpops = df['Sub-Populations'].to_numpy()
    proneural = df['Proneural'].to_numpy()
    mesenchymal = df['Mesenchymal'].to_numpy()
    cell_type = ['Control' if str(cell).endswith('_1') else 'Treatment' for cell in df['cell']]

    all_data = [
        (samples[i], clusters[i], subpops[i], proneural[i], mesenchymal[i], cell_type[i])
        for i in range(len(df))
    ]

    all_df = pd.DataFrame()
    all_df['Data'] = all_data
    all_df['Clusters'] = clusters
    all_df['Sub_pops'] = subpops
    all_df['Cell_type'] = cell_type
    all_df['Proneural'] = proneural
    all_df['Mesenchymal'] = mesenchymal

    all_df = all_df.sample(frac=1, random_state=random_state)
    # Stratify on the shuffled frame's own labels. sklearn pairs `stratify` with rows by
    # position, so passing the unshuffled `clusters` (as before) scrambled the strata.
    train_inter, test_df = train_test_split(all_df, stratify=all_df['Clusters'], test_size=0.15,
                                            random_state=random_state)
    train_df, valid_df = train_test_split(train_inter, stratify=train_inter['Clusters'], test_size=0.1,
                                          random_state=random_state)

    all_df, train_df, test_df, valid_df = [
        frame.reset_index(drop=True) for frame in (all_df, train_df, test_df, valid_df)
    ]

    print("Length of training df:", len(train_df))
    print("Length of testing df:", len(test_df))
    print("Length of validation df:", len(valid_df))

    return cols, all_df, train_df, test_df, valid_df

# Fixed seed so the shuffle and splits are reproducible across runs.
cols, all_df, train_df, test_df, valid_df = combine_data(scrna_df, random_state=42)

In [None]:
# Showing that the ratio of classes is maintained: cluster proportions in the full data and in each split.
y = scrna_df['cluster']
for split_labels in (y, train_df['Clusters'], test_df['Clusters'], valid_df['Clusters']):
    print(split_labels.value_counts().sort_index() / len(split_labels))

In [None]:
def create_iterator(data_list, BATCH_SIZE = 64):
    """Wrap ``data_list`` in a shuffling DataLoader yielding batches of ``BATCH_SIZE`` items."""
    return data.DataLoader(data_list, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def biggest_subpops(df, p):
    """Keep only cells whose sub-population makes up more than a fraction ``p`` of ``df``.

    The surviving sub-population ids are re-labelled 0..k-1 in ascending order of their
    original id.

    Returns
    -------
    final_df : copy of the kept rows with a fresh RangeIndex and two extra columns:
        'Subpopulations' (re-labelled id) and 'Data Max Subpops' (the 'Data' tuple with its
        third element replaced by the re-labelled id).
    count_subpops : dict mapping original sub-population id -> re-labelled id.
    """
    fractions = df['Sub_pops'].value_counts().sort_index() / len(df['Sub_pops'])
    kept_ids = list(fractions[fractions > p].index)
    count_subpops = {old_id: new_id for new_id, old_id in enumerate(kept_ids)}

    final_df = df[df['Sub_pops'].isin(kept_ids)].copy()

    # The sub-population is read from position 2 of each 'Data' record.
    relabelled = [count_subpops[record[2]] for record in final_df['Data']]
    final_df['Subpopulations'] = relabelled
    final_df['Data Max Subpops'] = [
        (record[0], record[1], new_id, record[3], record[4], record[5])
        for record, new_id in zip(final_df['Data'], relabelled)
    ]
    final_df.reset_index(drop=True, inplace=True)

    return final_df, count_subpops

# Keep sub-populations that make up more than 1% of all cells.
SUBPOP_MIN_FRACTION = 0.01
all_subpops, count_subpops = biggest_subpops(all_df, p=SUBPOP_MIN_FRACTION)

In [None]:
def subpop_in_cluster(df):
    """Count cells per (cluster, sub-population).

    Parameters
    ----------
    df : DataFrame with 'Clusters' and 'Subpopulations' columns (any index).

    Returns
    -------
    Nested dict ``{cluster: {subpop: count}}``, keys in order of first appearance.
    """
    counts = dict()
    # Pair the two columns positionally; the old code mixed a positional counter with a
    # label-based lookup and failed on frames whose index is not 0..n-1.
    for cluster, subpop in zip(df['Clusters'], df['Subpopulations']):
        per_cluster = counts.setdefault(cluster, dict())
        per_cluster[subpop] = per_cluster.get(subpop, 0) + 1
    return counts

def all_iterators(df, p):
    """Build DataLoaders and per-cluster sub-population counts for the data and its cell types.

    Parameters
    ----------
    df : DataFrame produced by ``combine_data`` (uses 'Data', 'Sub_pops', 'Cell_type').
    p : minimum fraction of cells a sub-population must exceed to be kept (see ``biggest_subpops``).

    Returns
    -------
    list_df : [kept-subpop frame, its Control rows, its Treatment rows], each with a fresh RangeIndex.
    all_pie, control_pie, treatment_pie : ``subpop_in_cluster`` counts for those three frames.
    all_iterator : DataLoader over every record in ``df['Data']``.
    shortened_iterator : DataLoader over the kept sub-populations' records (re-labelled ids).
    control_iterator, treatment_iterator : the same restricted to Control / Treatment cells.
    """
    all_iterator = create_iterator(list(df['Data']))

    max_subpops_df, _ = biggest_subpops(df, p)
    shortened_iterator = create_iterator(list(max_subpops_df['Data Max Subpops']))
    all_pie = subpop_in_cluster(max_subpops_df)

    list_df = [max_subpops_df]
    type_pies = list()
    type_iterators = list()
    for cell_type in ('Control', 'Treatment'):
        split_data = max_subpops_df[max_subpops_df['Cell_type'] == cell_type].reset_index(drop=True)
        list_df.append(split_data)
        type_pies.append(subpop_in_cluster(split_data))
        type_iterators.append(create_iterator(list(split_data['Data Max Subpops'])))

    return (list_df, all_pie, type_pies[0], type_pies[1],
            all_iterator, shortened_iterator, type_iterators[0], type_iterators[1])

# Materialise the iterators and pies. Later cells read list_df (previously never assigned,
# because this call was commented out and unpacked one value too few).
(list_df, all_pie, control_pie, treatment_pie,
 all_iterator, shortened_iterator, control_iterator, treatment_iterator) = all_iterators(all_df, 0.01)

# all_iterator = iterator with all data
# shortened_iterator = iterator over cells whose sub-population makes up more than p of all cells
# control_iterator = the shortened data restricted to control cells
# treatment_iterator = the shortened data restricted to treatment cells

In [None]:
def make_box_plot(df, my_type, title):
    """Box plot of ``my_type`` per sub-population, boxes ordered by decreasing median.

    Parameters
    ----------
    df : DataFrame with a 'Subpopulations' column and a numeric ``my_type`` column.
    my_type : column to plot, e.g. 'Proneural' or 'Mesenchymal'.
    title : figure title.

    Returns the plotly figure (displayed when it is the cell's last expression).
    ``df`` is not modified.
    """
    order = (df.groupby('Subpopulations')[my_type].median()
             .sort_values(ascending=False).index)
    # Plot the per-cell values: previously only the per-subpop medians were boxed and the
    # figure was neither shown nor returned. String ids make plotly treat x as categorical,
    # so the median ordering given in category_orders is honoured.
    plot_df = pd.DataFrame({'Subpopulations': df['Subpopulations'].astype(str), my_type: df[my_type]})
    return px.box(plot_df, x='Subpopulations', y=my_type, title=title,
                  category_orders={'Subpopulations': [str(s) for s in order]})

make_box_plot(list_df[0], 'Proneural', 'All')

In [None]:
def plot_subpop(df, n_cols=7):
    """Scatter Proneural against Mesenchymal score, one subplot per sub-population.

    Parameters
    ----------
    df : DataFrame with 'Subpopulations', 'Mesenchymal' and 'Proneural' columns.
    n_cols : subplots per row (default 7, the previous fixed layout); rows are added as needed.

    Returns the plotly figure.
    """
    # Use the ids actually present. The old fixed range(15) on a 2x7 grid wrapped the 15th
    # subpop back onto row 2, column 1, drawing it on top of subpop 7.
    subpops = sorted(df['Subpopulations'].unique())
    n_rows = max(1, (len(subpops) + n_cols - 1) // n_cols)
    fig = make_subplots(rows=n_rows, cols=n_cols, shared_yaxes=True, shared_xaxes=True)
    for panel, subpop in enumerate(subpops):
        temp_df = df[df['Subpopulations'] == subpop]
        fig.add_trace(
            go.Scatter(x=temp_df['Mesenchymal'], y=temp_df['Proneural'], mode="markers",
                       name='Subpopulation ' + str(subpop)),
            row=panel // n_cols + 1, col=panel % n_cols + 1)
    return fig

new_fig = plot_subpop(list_df[2])
new_fig.show()

In [None]:
# Seven colours repeated twice: one per sub-population panel (14 entries).
l1 = ['blue', 'red', 'green', 'purple', 'orange', 'pink', 'grey'] * 2

def centeroidnp(arr):
    """Return the mean of column 0 and the mean of column 1 of a 2-D array as a tuple."""
    n_points = arr.shape[0]
    return tuple(np.sum(arr[:, axis]) / n_points for axis in (0, 1))

def matplot_subpop(df, n_subpops=14):
    """Scatter Proneural vs Mesenchymal for each sub-population and mark its centroid in black.

    Parameters
    ----------
    df : DataFrame with 'Subpopulations', 'Mesenchymal' and 'Proneural' columns.
    n_subpops : sub-population ids 0..n_subpops-1 to plot (default 14, as before).

    Returns
    -------
    List of (mean Mesenchymal, mean Proneural) centroids indexed by sub-population id;
    (nan, nan) for a sub-population with no rows in ``df``.
    """
    n_cols = 4
    n_rows = max(1, (n_subpops + n_cols - 1) // n_cols)
    centroid_coords = list()
    plt.figure(figsize=(40,40))
    for subpop in range(n_subpops):
        temp_df = df[df["Subpopulations"] == subpop]
        my_x = list(temp_df['Mesenchymal'])
        my_y = list(temp_df['Proneural'])
        if my_x:
            centroid_coord = centeroidnp(np.column_stack([my_x, my_y]))
        else:
            # An empty selection used to crash: np.array([]) is 1-D, so arr[:, 0] raised.
            centroid_coord = (np.nan, np.nan)
        centroid_coords.append(centroid_coord)
        plt.subplot(n_rows, n_cols, subpop + 1)
        plt.xlim([0.05, 0.2])
        plt.ylim([0.05, 0.2])
        plt.scatter(my_x, my_y, c = l1[subpop % len(l1)])
        plt.scatter(centroid_coord[0], centroid_coord[1], c = 'black')
        plt.title('Subpopulation ' + str(subpop))
    plt.show()
    return centroid_coords

def vector_subplot(control, treatment):
    """Per sub-population, draw an arrow from its Control centroid to its Treatment centroid.

    Parameters
    ----------
    control, treatment : sequences of (x, y) centroids indexed by sub-population, such as the
        lists returned by ``matplot_subpop``. Only the first ``min(len(control), len(treatment))``
        entries are drawn.
    """
    n_subpops = min(len(control), len(treatment))
    n_cols = 4
    n_rows = max(1, (n_subpops + n_cols - 1) // n_cols)
    plt.figure(figsize=(40,40))
    for i in range(n_subpops):
        x0, y0 = control[i][0], control[i][1]
        dx, dy = treatment[i][0] - x0, treatment[i][1] - y0
        plt.subplot(n_rows, n_cols, i + 1)
        # quiver takes a displacement (U, V), not an end point. angles/scale_units='xy' with
        # scale=1 make the arrow end exactly at the treatment centroid in data coordinates.
        plt.quiver(x0, y0, dx, dy, color=l1[i % len(l1)], units='xy',
                   angles='xy', scale_units='xy', scale=1)
        plt.xlim([0, 0.5])
        plt.ylim([0, 0.5])
        plt.title('Subpopulation ' + str(i))
    plt.show()

In [None]:
def histo_graph(df):
    """Histogram the Mesenchymal score of every sub-population and return each one's skewness.

    Sub-populations are visited in the order ``df['Subpopulations'].unique()`` yields them;
    the returned list of ``scipy.stats.skew`` values follows that same order.
    """
    subpops = df['Subpopulations'].unique()
    n_cols = 3
    n_rows = max(1, (len(subpops) + n_cols - 1) // n_cols)
    skews = list()
    plt.figure(figsize=(40,40))
    for panel, subpop in enumerate(subpops):
        values = df.loc[df['Subpopulations'] == subpop, 'Mesenchymal'].values
        subpop_skew = skew(values)
        plt.subplot(n_rows, n_cols, panel + 1)
        plt.hist(values, bins='auto')
        # Title with the real sub-population id (the enumerate position was shown before).
        plt.title('Subpopulation ' + str(subpop) + ', Skew = ' + str(np.round(subpop_skew, 3)))
        skews.append(subpop_skew)
    # Show once after all panels exist; showing inside the loop closed the figure every time.
    plt.show()
    return skews

In [None]:
# Making pie plots

def no_split(my_dict):
    """Collapse ``{cluster: {subpop: count}}`` into ``{subpop: total count over all clusters}``.

    Keys appear in order of first occurrence while walking clusters in dict order.
    """
    totals = dict()
    for per_cluster in my_dict.values():
        for subpop, count in per_cluster.items():
            totals[subpop] = totals[subpop] + count if subpop in totals else count
    return totals

def make_split_pie(i, pie_data):
    """Pie chart of the sub-population counts stored for cluster ``i`` in ``pie_data``.

    ``pie_data`` is a nested ``{cluster: {subpop: count}}`` dict.
    Returns ``(figure, DataFrame with 'Subpopulations' and 'Num' columns)``.
    """
    return make_combined_pie(pie_data[i])

def make_combined_pie(pie_data):
    """Pie chart of a flat ``{subpop: count}`` dict.

    Returns ``(figure, DataFrame with 'Subpopulations' and 'Num' columns)``.
    """
    df = pd.DataFrame()
    df['Subpopulations'] = list(pie_data.keys())
    df['Num'] = list(pie_data.values())
    fig = px.pie(df, values='Num', names='Subpopulations')
    return fig, df

In [None]:
def subpop_change(df_control, df_treatment):
    """Long-format table of each sub-population's share of Control vs Treatment cells.

    Parameters
    ----------
    df_control, df_treatment : DataFrames with 'Subpopulations' and 'Num' (cell count) columns.

    Returns
    -------
    DataFrame with columns 'Subpops', 'Cell Type' ('Control' rows then 'Treatment' rows) and
    'Percent' (count divided by that group's total count).
    """
    # Outer join so a sub-population missing from one group counts as 0 there. The old inner
    # join silently dropped it and renormalised over the shared sub-populations only.
    merged_df = (
        df_control.merge(df_treatment, on='Subpopulations', how='outer')
        .fillna({'Num_x': 0, 'Num_y': 0})
    )
    control_pct = merged_df['Num_x'] / merged_df['Num_x'].sum()
    treatment_pct = merged_df['Num_y'] / merged_df['Num_y'].sum()
    n = len(merged_df)
    df2 = pd.DataFrame({
        'Subpops': 2 * list(merged_df['Subpopulations']),
        'Cell Type': n * ['Control'] + n * ['Treatment'],
        'Percent': list(control_pct) + list(treatment_pct),
    })
    return df2

In [None]:
def horizontal_bar_graph(df, df_type):
    """Return a copy of ``df`` prepared for a horizontal bar chart.

    Adds 'Percent' (each 'Num' divided by the column total), casts 'Subpopulations' to str
    so plotting libraries treat it as categorical, sorts rows by ascending 'Percent' (original
    index labels kept), and adds a constant 'Type' column equal to ``df_type``.
    The input frame is not modified.
    """
    out = df.copy()
    # These statements used to sit inside the per-row loop, so 'Percent' was assigned a partial
    # list and raised a length-mismatch error for any frame with more than one row.
    out['Percent'] = out['Num'] / out['Num'].sum()
    out['Subpopulations'] = out['Subpopulations'].astype(str)
    out = out.sort_values(by=['Percent'])
    out['Type'] = df_type
    return out

In [None]:
# Creating an alluvial plot

def alluvial_change(pie_df_control, pie_df_treatment):
    """Show a Sankey ("alluvial") diagram of sub-population proportions, Control -> Treatment.

    Parameters
    ----------
    pie_df_control, pie_df_treatment : DataFrames with 'Subpopulations' and 'Num' columns
        (the DataFrames returned by ``make_split_pie`` / ``make_combined_pie``).

    Layout: one node per sub-population in a Control column (x=0.1), one in a Treatment column
    (x=0.5), and one invisible helper node per sub-population (x=0.25). The shared fraction
    min(control, treatment) flows straight across. A decrease flows into the helper node and an
    increase flows out of it, drawn with a fainter link colour. Nothing is returned.
    """
    # Uses the treatment argument (was the undefined global `treatment_pie_df`); an outer join
    # keeps sub-populations present in only one group, with a count of 0 in the other.
    merged_df = (
        pie_df_control
        .merge(pie_df_treatment, on='Subpopulations', how='outer')
        .fillna({'Num_x': 0, 'Num_y': 0})
    )
    subpops = list(merged_df['Subpopulations'])
    control_pct = list(merged_df['Num_x'] / merged_df['Num_x'].sum())
    treatment_pct = list(merged_df['Num_y'] / merged_df['Num_y'].sum())
    n = len(subpops)

    solid = 'rgba(255,0,255,0.7)'
    faint = 'rgba(255,0,255,0.1)'
    hidden = 'rgba(255,0,255,0)'

    # Node order: control nodes 0..n-1, treatment nodes n..2n-1, helper nodes 2n..3n-1.
    labels = subpops + subpops + n * [' ']
    node_colors = 2 * n * [solid] + n * [hidden]
    x_position = n * [0.1] + n * [0.5] + n * [0.25]
    # Same vertical slot for a sub-population in all three columns, evenly spaced inside (0, 1).
    # The old loop tripled the list on every iteration, so its length grew exponentially.
    y_position = 3 * [(i + 1) / (n + 1) for i in range(n)]

    sources, targets, values, link_color = [], [], [], []

    def add_link(source, target, value, color):
        """Append one Sankey link to the parallel link lists."""
        sources.append(source)
        targets.append(target)
        values.append(value)
        link_color.append(color)

    for i, (c, t) in enumerate(zip(control_pct, treatment_pct)):
        control_node, treatment_node, helper_node = i, i + n, i + 2 * n
        shared = min(c, t)
        # Equal proportions previously produced no link at all.
        if shared > 0:
            add_link(control_node, treatment_node, shared, solid)
        if c > t:
            add_link(control_node, helper_node, c - t, faint)   # share lost under treatment
        elif t > c:
            add_link(helper_node, treatment_node, t - c, faint)  # share gained under treatment

    fig = go.Figure(go.Sankey(
        node = dict(
            pad = 15,
            thickness = 10,
            x = x_position,
            y = y_position,
            line = dict(color = "black", width = 0),
            label = labels,
            color = node_colors
        ),
        link = dict(
            source = sources,
            target = targets,
            value = values,
            color = link_color
        )))

    fig.update_layout(title_text="My Sankey Diagram", font_size=10)
    fig.show()

In [None]:
class MLP(nn.Module):
    """Fully connected classifier for gene-expression vectors.

    Architecture: input -> h1 -> h2 -> dropout -> h3 -> output, with tanh after every hidden
    layer; dropout is applied only to the h2 activations. ``forward`` returns both the logits
    and the last hidden activations (an embedding of size ``hidden_dims[2]``).

    Parameters
    ----------
    input_dim : number of input features per sample.
    output_dim : number of classes.
    hidden_dims : sizes of the three hidden layers (default (1000, 300, 50), as before).
    dropout : dropout probability after the second hidden layer (default 0.2, as before).
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(1000, 300, 50), dropout=0.2):
        super().__init__()
        h1, h2, h3 = hidden_dims
        # Attribute names are unchanged so previously saved state_dicts still load.
        self.input_fc = nn.Linear(input_dim, h1)
        self.hidden_1 = nn.Linear(h1, h2)
        self.drop = nn.Dropout(p=dropout)
        self.hidden_2 = nn.Linear(h2, h3)
        self.output_fc = nn.Linear(h3, output_dim)

    def forward(self, x):
        """Return ``(logits, embedding)``.

        ``x`` is flattened to ``(batch, -1)`` first, so any trailing shape whose total size equals
        ``input_dim`` is accepted. logits: (batch, output_dim); embedding: (batch, hidden_dims[2]).
        """
        x = x.view(x.shape[0], -1)
        h = torch.tanh(self.input_fc(x))
        h = self.drop(torch.tanh(self.hidden_1(h)))
        embedding = torch.tanh(self.hidden_2(h))
        return self.output_fc(embedding), embedding

INPUT_DIM = 3000   # matches the number of high-variance genes combine_data keeps by default
OUTPUT_DIM = 15    # NOTE(review): must equal the number of distinct cluster labels — confirm

model = MLP(INPUT_DIM, OUTPUT_DIM)

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print('The model has {:,} trainable parameters'.format(count_parameters(model)))

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is Adam's default learning rate

In [None]:
# Run on the GPU when one is available; the loss and the model live on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss().to(device)
model = model.to(device)

In [None]:
def calculate_accuracy(y_pred, y):
    """Fraction of rows whose arg-max prediction equals the label, as a 0-dim float tensor."""
    predicted = y_pred.argmax(1, keepdim=True)
    n_correct = (predicted == y.view_as(predicted)).sum()
    return n_correct.float() / y.shape[0]

In [None]:
def train(model, iterator, optimizer, criterion, device):
    """Run one training epoch and return ``(mean batch loss, mean batch accuracy)``.

    Each batch must be a sequence whose first two items are the inputs and the integer class
    labels; any further items (sub-population, scores, cell type, ...) are ignored, so both the
    6-field 'Data' records and plain (x, y) pairs work. Returns ``(nan, nan)`` for an empty
    iterator instead of dividing by zero.
    """
    epoch_loss = 0
    epoch_acc = 0
    n_batches = 0

    model.train()

    for batch in tqdm(iterator, desc="Training", leave=False):
        x = batch[0].to(device)
        y = batch[1].to(device)

        optimizer.zero_grad()

        y_pred, _ = model(x)

        loss = criterion(y_pred, y)
        acc = calculate_accuracy(y_pred, y)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()
        n_batches += 1

    if n_batches == 0:
        return float('nan'), float('nan')
    return epoch_loss / n_batches, epoch_acc / n_batches

In [None]:
def evaluate(model, iterator, criterion, device):
    """Evaluate ``model`` on ``iterator`` in eval mode without gradient tracking.

    Each batch must be a sequence whose first two items are the inputs and the integer class
    labels; extra items are ignored.

    Returns
    -------
    (mean batch loss, mean batch accuracy, labels, probs), where labels/probs are the
    concatenated CPU tensors of true labels and softmax probabilities. For an empty iterator
    the means are nan and labels/probs are empty tensors (torch.cat([]) used to raise).
    """
    epoch_loss = 0
    epoch_acc = 0
    n_batches = 0

    model.eval()

    labels = []
    probs = []

    with torch.no_grad():
        for batch in tqdm(iterator, desc="Evaluating", leave=False):
            x = batch[0].to(device)
            y = batch[1].to(device)

            y_pred, _ = model(x)
            y_prob = F.softmax(y_pred, dim=-1)

            loss = criterion(y_pred, y)
            acc = calculate_accuracy(y_pred, y)

            labels.append(y.cpu())
            probs.append(y_prob.cpu())

            epoch_loss += loss.item()
            epoch_acc += acc.item()
            n_batches += 1

    if n_batches == 0:
        return float('nan'), float('nan'), torch.empty(0, dtype=torch.long), torch.empty(0)

    labels = torch.cat(labels, dim=0)
    probs = torch.cat(probs, dim=0)

    return epoch_loss / n_batches, epoch_acc / n_batches, labels, probs

In [None]:
def epoch_time(start_time, end_time):
    """Split the seconds elapsed between two timestamps into whole ``(minutes, seconds)``."""
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    return minutes, int(elapsed - minutes * 60)

In [None]:
def run_model(train_iterator, valid_iterator, test_iterator, epochs=10, restore_best=True,
              return_history=False):
    """Train the module-level ``model`` and return its test accuracy.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Parameters
    ----------
    train_iterator, valid_iterator, test_iterator : DataLoaders accepted by ``train``/``evaluate``.
    epochs : number of training epochs (default 10, as before).
    restore_best : if True, reload the weights of the epoch with the lowest validation loss
        before testing. Previously that loss was tracked but the weights were never kept.
        This leaves ``model`` holding those weights after the call.
    return_history : if True, return ``(test_acc, history)`` instead of just ``test_acc``.

    Returns
    -------
    test_acc, or (test_acc, history) where history maps 'Train'/'Validation'/'Test' to
    {'Accuracy': [...], 'Loss': [...]}.
    """
    best_valid_loss = float('inf')
    best_state = None

    history = {'Train': {'Accuracy': [], 'Loss': []}, 'Test': {'Accuracy': [], 'Loss': []}, 'Validation': {'Accuracy': [], 'Loss': []}}

    for epoch in trange(epochs):

        start_time = time.monotonic()

        train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
        valid_loss, valid_acc, _, _ = evaluate(model, valid_iterator, criterion, device)
        history['Train']['Loss'].append(train_loss)
        history['Train']['Accuracy'].append(train_acc)
        history['Validation']['Loss'].append(valid_loss)
        history['Validation']['Accuracy'].append(valid_acc)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # Snapshot the best epoch's weights (deep copy: state_dict tensors are live views).
            best_state = copy.deepcopy(model.state_dict())

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    test_loss, test_acc, _, _ = evaluate(model, test_iterator, criterion, device)
    history['Test']['Loss'].append(test_loss)
    history['Test']['Accuracy'].append(test_acc)

    if return_history:
        return test_acc, history
    return test_acc

In [None]:
def tsne_individual_subpops(mlp_output, mlp_inter, subpop_dict, sub_pops):
    """Embed two per-cell arrays in 2-D with t-SNE and split each embedding by sub-population.

    Parameters
    ----------
    mlp_output, mlp_inter : array-likes with one row per cell — presumably the MLP logits and
        hidden activations (TODO confirm with the caller).
    subpop_dict : dict whose *values* are the sub-population ids to split on (for example the
        mapping returned by ``biggest_subpops``).
    sub_pops : per-cell sub-population ids, row-aligned with the two arrays (list, Series or array).

    Returns
    -------
    tsne_out, tsne_inter : 2-D embeddings of ``mlp_output`` and ``mlp_inter``.
    out_subpops, inter_subpops : dicts mapping id -> rows of the corresponding embedding.
    """
    tsne = manifold.TSNE(n_components=2)
    tsne_out = tsne.fit_transform(mlp_output)
    tsne_inter = tsne.fit_transform(mlp_inter)

    # np.asarray so a plain list also gives an element-wise boolean mask (list == int is a scalar).
    sub_pops = np.asarray(sub_pops)
    out_subpops = dict()
    inter_subpops = dict()
    for subpop_id in subpop_dict.values():
        mask = sub_pops == subpop_id
        out_subpops[subpop_id] = tsne_out[mask]
        inter_subpops[subpop_id] = tsne_inter[mask]

    # This return was mis-indented by two spaces (IndentationError).
    return tsne_out, tsne_inter, out_subpops, inter_subpops

def plot_subpops_individual(my_dict):
    """Scatter each sub-population's 2-D points in its own panel of a 5x4 grid.

    ``my_dict`` maps id -> array of shape (n, 2); ids are looked up as 0..len(my_dict)-1,
    so they are assumed to be contiguous from 0 — TODO confirm with the caller.
    """
    plt.figure(figsize=(40,40))
    for panel in range(len(my_dict)):
        points = my_dict[panel]
        plt.subplot(5, 4, panel + 1)
        plt.scatter(points[:, 0], points[:, 1])
        plt.title('Subpopulation ' + str(panel))
    plt.show()

def sihloutte(i, my_dict, k_values=range(2, 10), random_state=None):
    """Print and return silhouette scores of K-means clusterings of ``my_dict[i]``.

    Parameters
    ----------
    i : key of the point set to cluster.
    my_dict : mapping key -> array of points (one row per point).
    k_values : cluster counts to try (default 2..9, as before).
    random_state : seed forwarded to KMeans; None keeps it unseeded.

    Returns
    -------
    dict mapping k -> silhouette score. k values that are not smaller than the number of points
    are skipped, because silhouette_score needs fewer labels than samples.
    """
    points = my_dict[i]
    scores = dict()
    for k in k_values:
        if k >= len(points):
            continue
        # Named `kmeans` so it does not read like the global neural-network `model`.
        kmeans = KMeans(n_clusters=k, random_state=random_state)
        labels = kmeans.fit_predict(points)
        score = silhouette_score(points, labels)
        print('Silhouette Score for k = {}: {:<.3f}'.format(k, score))
        scores[k] = score
    return scores

In [None]:
class LogitsOnly(nn.Module):
    """Wrap a model whose forward returns ``(logits, embedding)`` so it returns only the logits.

    DeepExplainer needs a model that outputs a single tensor.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        return self.base_model(x)[0]

N_SHAP_BACKGROUND = 100  # background cells for DeepExplainer; its cost grows with this
N_SHAP_EXPLAIN = 10      # test cells to explain; only the first one is plotted below

# x_train / x_test were never defined; build them from the 'Data' records (feature tensor first).
model.eval()
x_train = torch.stack([record[0] for record in train_df['Data']]).to(device)
x_test = torch.stack([record[0] for record in test_df['Data']]).to(device)
background = x_train[torch.randperm(len(x_train))[:N_SHAP_BACKGROUND]]

explainer = shap.DeepExplainer(LogitsOnly(model).eval(), background)
shap_values = explainer.shap_values(x_test[:N_SHAP_EXPLAIN])
# NOTE(review): `shap_values[0][0]` assumes shap returns one (n_samples, n_features) array per class
# (older shap releases); newer releases return a single (n_samples, n_features, n_classes) array — confirm.
first_cell = x_test[0].cpu().numpy()  # feature *values*; names go in feature_names
shap.decision_plot(explainer.expected_value[0], shap_values[0][0], features = first_cell, feature_names = cols)
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[0], shap_values[0][0], feature_names = cols)
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values[0][0], features = first_cell, feature_names = cols)

In [None]:
def perturbe_single_genes(list_important_cols, train_data3, valid_data3, test_data3):
    """Retrain the MLP with one gene removed at a time and collect the test accuracies.

    For each gene, its column (position in the global ``cols``) is deleted from the three feature
    matrices. Each remaining row is paired with the other fields of the matching 'Data' record
    of ``train_df``/``valid_df``/``test_df``, so batches look like the original ones. A freshly
    initialised MLP sized for the reduced input is then trained with ``run_model``.

    Parameters
    ----------
    list_important_cols : gene names, each present in ``cols``.
    train_data3, valid_data3, test_data3 : array-likes, presumably (n_cells, len(cols)) feature
        matrices row-aligned with ``train_df``/``valid_df``/``test_df`` — TODO confirm.

    Returns
    -------
    List of test accuracies, one per gene, in input order.

    Side effects: rebinds the module-level ``model`` and ``optimizer``, which ``run_model`` reads.
    """
    global model, optimizer
    single_accuracies = list()
    splits = ((train_data3, train_df), (valid_data3, valid_df), (test_data3, test_df))

    for gene in list_important_cols:
        drop_idx = cols.index(gene)
        loaders = list()
        for matrix, split_df in splits:
            reduced = np.delete(np.asarray(matrix), drop_idx, axis=1)
            # float32 tensors: float64 numpy rows collate to double tensors that nn.Linear rejects.
            records = [(torch.tensor(row, dtype=torch.float32),) + tuple(record[1:])
                       for row, record in zip(reduced, split_df['Data'])]
            loaders.append(create_iterator(records))
        n_inputs = reduced.shape[1]  # identical for all three splits

        # The global model expects the full feature count; train a new one on the reduced input.
        model = MLP(n_inputs, OUTPUT_DIM).to(device)
        optimizer = optim.Adam(model.parameters())
        single_accuracies.append(run_model(*loaders))
    return single_accuracies

In [None]:
def plot_gene_pairs(pairs, accuracies, baseline_acc=0.81):
    """Horizontal bar chart of |accuracy - baseline_acc| for each perturbed gene pair.

    Parameters
    ----------
    pairs : sequence of (gene1, gene2) tuples.
    accuracies : test accuracies aligned with ``pairs``.
    baseline_acc : accuracy the perturbed runs are compared against. The default 0.81 was
        previously hard-coded — NOTE(review): presumably the unperturbed model's accuracy; confirm.
    """
    pred_actuals = pd.DataFrame(list(zip(pairs, accuracies)), columns=['Gene Pairs', 'Accuracies'])
    pred_actuals['Pairs'] = [', '.join(map(str, pair)) for pair in pred_actuals['Gene Pairs']]
    pred_actuals['|Difference in Accuracy|'] = (pred_actuals['Accuracies'] - baseline_acc).abs()
    fd = pred_actuals.sort_values(by='|Difference in Accuracy|', ascending = True)
    fig = px.bar(fd, x="|Difference in Accuracy|", y='Pairs', orientation='h')
    fig.show()

In [None]:
def perturbe_gene_pairs(list_important_cols, train_data3, valid_data3, test_data3):
    """Retrain the MLP with each unordered pair of genes removed and collect test accuracies.

    For each pair, both columns (positions in the global ``cols``) are deleted from the three
    feature matrices. Rows are paired with the other fields of the matching 'Data' record of
    ``train_df``/``valid_df``/``test_df``, and a freshly initialised MLP sized for the reduced
    input is trained with ``run_model``. The |accuracy - baseline| chart is drawn once at the end.

    Parameters
    ----------
    list_important_cols : gene names, each present in ``cols``.
    train_data3, valid_data3, test_data3 : array-likes, presumably (n_cells, len(cols)) feature
        matrices row-aligned with ``train_df``/``valid_df``/``test_df`` — TODO confirm.

    Returns
    -------
    (pairs, accuracies) : aligned lists of (gene1, gene2) tuples and test accuracies.

    Side effects: rebinds the module-level ``model`` and ``optimizer``, which ``run_model`` reads.
    """
    global model, optimizer
    accuracies = list()
    pairs = list()
    genes = list(list_important_cols)
    splits = ((train_data3, train_df), (valid_data3, valid_df), (test_data3, test_df))

    for i, gene1 in enumerate(genes):
        # Iterate the argument itself (previously the undefined global `list_imp_cols`).
        for gene2 in genes[i + 1:]:
            pairs.append((gene1, gene2))
            drop_idx = [cols.index(gene1), cols.index(gene2)]
            loaders = list()
            for matrix, split_df in splits:
                reduced = np.delete(np.asarray(matrix), drop_idx, axis=1)
                # float32 tensors: float64 numpy rows collate to double tensors that nn.Linear rejects.
                records = [(torch.tensor(row, dtype=torch.float32),) + tuple(record[1:])
                           for row, record in zip(reduced, split_df['Data'])]
                loaders.append(create_iterator(records))
            n_inputs = reduced.shape[1]  # identical for all three splits

            # The global model expects the full feature count; train a new one on the reduced input.
            model = MLP(n_inputs, OUTPUT_DIM).to(device)
            optimizer = optim.Adam(model.parameters())
            accuracies.append(run_model(*loaders))

    # Plot once at the end instead of re-plotting after every pair.
    if pairs:
        plot_gene_pairs(pairs, accuracies)

    return pairs, accuracies

In [None]:
def pertube_gene_triples():
    triple_accuracies = list()
    triples = list()

    for i,x in enumerate(list_imp_cols):
        for j, x2 in list_imp_cols[1+i:]:
            for x3 in list_imp_cols[1+j]:
                triples.append((x,x2))
                indx1 = cols.index(x)
                indx2 = cols.index(x2)

                C = np.delete(train_data3, [indx1, indx2], axis=1)
                print(C.shape)
                print(train_data3.shape)
                W = np.split(C, len(C))
                Z = zip(W, list(train_df['Clusters']))

                D = np.delete(valid_data3, [indx1, indx2], axis=1)
                W1 = np.split(D, len(D))
                Z1 = zip(W1, list(valid_df['Clusters']))

                E = np.delete(test_data3, [indx1, indx2], axis=1)
                W2 = np.split(E, len(E))
                Z2 = zip(W2, list(test_df['Clusters']))

                new_train_iterator = create_iterator(list(Z))
                new_valid_iterator = create_iterator(list(Z1))
                new_test_iterator = create_iterator(list(Z2))
                accuracies.append(run(new_train_iterator, new_valid_iterator, new_test_iterator))