In [1]:
import sys
sys.path.append('../')

import meta_dataloader.TCGA

import models.mlp, models.gcn
import numpy as np
import data.gene_graphs
import collections
import sklearn.metrics
import sklearn.model_selection
import pandas as pd
%load_ext autoreload
%autoreload 2

In [29]:
tasks = meta_dataloader.TCGA.TCGAMeta(download=True, 
                                      min_samples_per_class=10, 
                                      gene_symbol_map_file="../genenames_code_map_Feb2019.txt")
#task = tasks[113]

Downloading or checking for TCGA_HiSeqV2 using Academic Torrents
Torrent name: HiSeqV2.gz, Size: 513.04MB


In [28]:
for taskid in tasks.task_ids:
    if "BRCA" in taskid:
        print(taskid)

('gender', 'BRCA')
('_EVENT', 'BRCA')
('oct_embedded', 'BRCA')
('menopause_status', 'BRCA')
('PAM50Call_RNAseq', 'BRCA')
('_PANCAN_DNAMethyl_BRCA', 'BRCA')
('Node_nature2012', 'BRCA')
('Metastasis_nature2012', 'BRCA')
('_PANCAN_mirna_BRCA', 'BRCA')
('metastatic_breast_carcinoma_estrogen_receptor_status', 'BRCA')
('metastatic_breast_carcinoma_progesterone_receptor_status', 'BRCA')


In [3]:
# for i, task in enumerate(tasks):
#     print (i, task.id, collections.Counter(task._labels))

In [78]:
# clinical_M  PAM50Call_RNAseq
task = meta_dataloader.TCGA.TCGATask(('PAM50Call_RNAseq', 'BRCA'), 
                                     gene_symbol_map_file="../genenames_code_map_Feb2019.txt")


In [79]:
print(task.id)
print(task._samples.shape)
print(np.asarray(task._labels).shape)
print(collections.Counter(task._labels))

('PAM50Call_RNAseq', 'BRCA')
(956, 20530)
(956,)
Counter({2: 434, 3: 194, 0: 142, 4: 119, 1: 67})


In [89]:
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(task._samples, 
                                                                            task._labels, 
                                                                            stratify=task._labels,
                                                                            train_size=30,
                                                                            test_size=100,
                                                                            shuffle=True,
                                                                            random_state=0
                                                                             )
X_test, X_valid, y_test, y_valid = sklearn.model_selection.train_test_split(X_test, 
                                                                            y_test, 
                                                                            stratify=y_test,
                                                                            train_size=50,
                                                                            test_size=50,
                                                                            shuffle=True,
                                                                            random_state=0
                                                                           )

In [90]:
collections.Counter(y_train)

Counter({0: 4, 3: 6, 2: 14, 4: 4, 1: 2})

In [93]:
for lr in [0.000001,0.00001,0.0001,0.001,0.01]:
    for seed in [0,1]:
        model = models.mlp.MLP(name="MLP_lay2_chan512",
                               num_layer=1, 
                               channels=256, 
                               lr=lr,
                               patience=50,
                               cuda=True,
                               metric=sklearn.metrics.accuracy_score,
                               verbose=False,
                               seed=seed)

        model.fit(X_train, y_train)

        y_valid_pred = model.predict(X_valid)
        print(seed, lr, sklearn.metrics.accuracy_score(y_valid, np.argmax(y_valid_pred,axis=1)))


0 1e-06 0.72
1 1e-06 0.48
0 1e-05 0.78
1 1e-05 0.8
0 0.0001 0.74
1 0.0001 0.78
0 0.001 0.76
1 0.001 0.78
0 0.01 0.44
1 0.01 0.48


In [None]:
# y_pred = model.predict(X_test)
# y_pred = np.argmax(y_pred,axis=1)
# print(sklearn.metrics.accuracy_score(y_test, y_pred))


In [94]:
graph = data.gene_graphs.GeneManiaGraph()
adj = graph.adj()

Torrent name: genemania.pkl, Size: 9.61MB


In [95]:
# import gc
# gc.collect()

In [None]:
for nl in [1]:
    for lr in [0.0001]:
        model = models.gcn.GCN(name="GCN_lay3_chan64_emb32_dropout_agg_hierarchy", 
                               dropout=False, 
                               cuda=True,
                               num_layer=nl,
                               prepool_extralayers=2,
                               channels=64, 
                               embedding=32, 
                               aggregation="hierarchy",
                               lr=lr,
                               num_epochs=200,
                               patience=100,
                               verbose=True
                              )
        model.fit(X_train, y_train, adj)

        y_valid_pred = model.predict(X_valid)
        print("###",lr,sklearn.metrics.accuracy_score(y_valid, np.argmax(y_valid_pred,axis=1)))


Early stopping metric is accuracy_score


  self[i, j] = values


Reducing graph by a factor of 2 to 10265 nodes
Found cache for /network/tmp1/cohenjos/workspace/gene-graph-conv/.cache/hierarchical1dcdcbfce4653f65a26c8ed5e3d27bf8a50714909cc99c7329856ec4c66b9dcf10265.npy


In [58]:
y_pred = model.predict(X_test)
print(y_pred)
y_pred = np.argmax(y_pred,axis=1)
# print(y_pred)
# print(y_test)
print(sklearn.metrics.accuracy_score(y_test, y_pred))

tensor([[0.2050, 0.7950],
        [0.0973, 0.9027],
        [0.4638, 0.5362],
        [0.0148, 0.9852],
        [0.4317, 0.5683],
        [0.3772, 0.6228],
        [0.0479, 0.9521],
        [0.1708, 0.8292],
        [0.1488, 0.8512],
        [0.3514, 0.6486],
        [0.2116, 0.7884],
        [0.1141, 0.8859],
        [0.0137, 0.9863],
        [0.0993, 0.9007],
        [0.0296, 0.9704],
        [0.3073, 0.6927],
        [0.1613, 0.8387],
        [0.0485, 0.9515],
        [0.3388, 0.6612],
        [0.0369, 0.9631],
        [0.3893, 0.6107],
        [0.0872, 0.9128],
        [0.2450, 0.7550],
        [0.2745, 0.7255],
        [0.1240, 0.8760],
        [0.6442, 0.3558],
        [0.4622, 0.5378],
        [0.4259, 0.5741],
        [0.4604, 0.5396],
        [0.3750, 0.6250],
        [0.0498, 0.9502],
        [0.0970, 0.9030],
        [0.3438, 0.6562],
        [0.0727, 0.9273],
        [0.1216, 0.8784],
        [0.4283, 0.5717],
        [0.0854, 0.9146],
        [0.1055, 0.8945],
        [0.0