In [1]:
import os.path as op
import random
import time

import matplotlib.pyplot as plt
import numpy
import sklearn
import torch
import torch.nn.functional as nn_func
from sklearn import preprocessing
from sklearn.metrics import adjusted_rand_score
from torch.nn import Linear
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GraphConv, global_mean_pool

random.seed = 88888888

In [2]:
node_features_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_validation/input/node_features.txt'
graph_targets_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_validation/input/graph_targets.txt'
edges_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/edges.txt'
model_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/GNN/trained_pytorch_model_rewired10_fold_full_dataset.pt'
output_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_validation/output/gnn_predictions.tsv'
sampleID_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_validation/input/sample_id.txt'

In [3]:
features_exist = op.exists(node_features_fn)
targets_exist = op.exists(graph_targets_fn)
edges_exist = op.exists(edges_fn)
model_exists = op.exists(model_fn)

print(f'features exist: {features_exist},'
      f' targets exist: {targets_exist},'
      f' edges exist: {edges_exist}',
      f' model exists: {model_exists}')
assert features_exist
assert targets_exist
assert edges_exist
assert model_exists

features exist: True, targets exist: True, edges exist: True  model exists: True


In [4]:
INPUT_CHANNELS = 1
OUTPUT_CHANNELS = 51
NEW_CHANNELS = 13
HIDDEN_CHANNELS = 64
BATCH_SIZE = 64
BENCHMARKING = False
EPOCHS = 500

In [5]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()

        self.conv1 = GraphConv(INPUT_CHANNELS, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, OUTPUT_CHANNELS)

    def forward(self, x, edge_index, batch, edge_weight=None):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        #x = nn_func.dropout(x, training=self.training)
        x = self.lin(x)

        return x


def read_reactome_graph(e_fn):
    e_v1 = []
    e_v2 = []

    for line in open(e_fn, 'r'):
        dt = line.split()
        node1 = int(dt[0]) - 1  # subtracting to convert R idx to python idx
        node2 = int(dt[1]) - 1  # " "
        e_v1.append(node1)
        e_v2.append(node2)

    return e_v1, e_v2


def build_reactome_graph_datalist(e_v1, e_v2, n_fn, g_fn, s_fn):
    edge_index = torch.tensor([e_v1, e_v2], dtype=torch.long)
    feature_v = numpy.loadtxt(n_fn)
    target_v = numpy.loadtxt(g_fn, dtype=str, delimiter=",")
    sampleID_v = numpy.loadtxt(s_fn, dtype=str, delimiter=",")

    target_encoder = sklearn.preprocessing.LabelEncoder()
    target_v = target_encoder.fit_transform(target_v)
    sampleID_v = target_encoder.fit_transform(sampleID_v)

    d_list = []
    for row_idx in range(len(feature_v)):
        features = feature_v[row_idx, :]
        x = torch.tensor(features, dtype=torch.float)
        x = x.unsqueeze(1)
        y = torch.tensor([target_v[row_idx]])
        sid = torch.tensor([sampleID_v[row_idx]])
        d_list.append(Data(x=x, y=y, sid=sid, edge_index=edge_index))

    return d_list


def build_reactome_graph_loader(d_list, batch_size):
    loader = DataLoader(d_list, batch_size=batch_size, shuffle=False)#True)

    return loader


def train(loader, dv):
    model.train()

    correct = 0
    for batch in loader:  # Iterate in batches over the training dataset.
        x = batch.x.to(dv)
        e = batch.edge_index.to(dv)
        b = batch.batch.to(dv)
        y = batch.y.to(dv)
        out = model(x, e, b)  # Perform a single forward pass.
        loss = criterion(out, y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


def test(loader, dv):
    model.eval()

    targets = []
    predictions = []
    sample_ids = []
    tissues = []
    confidences = []
    for batch in loader:  # Iterate in batches over the test dataset.
        x = batch.x.to(dv)
        e = batch.edge_index.to(dv)
        b = batch.batch.to(dv)
        y = batch.y.to(dv)
        targets += torch.Tensor.tolist(y)
#         sample_ids += torch.Tensor.tolist(sid)
        sample_ids += batch.sid
        tissues += batch.tissue
        
        out = model(x, e, b)  # Perform a single forward pass.
        prob = torch.softmax(out, dim=1) # Apply softmax to get probabilities.

        pred = out.argmax(dim=1)  # Use the class with highest probability.
        predictions += torch.Tensor.tolist(pred)
        confidences += torch.Tensor.tolist(prob)  # Get the confidence score of the predicted class.
#         print(confidences)
    
    # Save targets, predictions, and confidences to a file
    num_classes = len(confidences[0])
    # Flatten confidences and create data for saving
    data_to_save = []
    for i in range(len(targets)):
        row = [sample_ids[i], tissues[i], targets[i], predictions[i]] + confidences[i]
        data_to_save.append(row)
    data_to_save = numpy.array(data_to_save)
    print(data_to_save)
    
    fmt = ['%s', '%s', '%s', '%s'] + ['%s' for _ in range(num_classes)]
    
    headers = ['sample_ids', 'tissues', 'target', 'prediction'] + [f'confidence_class_{i}' for i in range(num_classes)]
    numpy.savetxt(output_fn, data_to_save, fmt='\t'.join(fmt), delimiter='\t', header='\t'.join(headers), comments='')
        
    ari = adjusted_rand_score(targets, predictions)
    print(f'ari: {ari}')
    return ari

def change_key(self, old, new):
    for _ in range(len(self)):
        k, v = self.popitem(False)
        self[new if old == k else k] = v

In [6]:
(edge_v1, edge_v2) = read_reactome_graph(edges_fn)
model = GNN(hidden_channels=HIDDEN_CHANNELS)
device = cpu = torch.device('cpu')

# replace final layer with new shape matching new dataset
model.lin = Linear(HIDDEN_CHANNELS, NEW_CHANNELS)

sd = torch.load(model_fn, map_location=device)
change_key(sd, 'conv1.lin_l.weight', 'conv1.lin_rel.weight')
change_key(sd, 'conv1.lin_l.bias', 'conv1.lin_rel.bias')
change_key(sd, 'conv1.lin_r.weight', 'conv1.lin_root.weight')
change_key(sd, 'conv2.lin_l.weight', 'conv2.lin_rel.weight')
change_key(sd, 'conv2.lin_l.bias', 'conv2.lin_rel.bias')
change_key(sd, 'conv2.lin_r.weight', 'conv2.lin_root.weight')
change_key(sd, 'conv3.lin_l.weight', 'conv3.lin_rel.weight')
change_key(sd, 'conv3.lin_l.bias', 'conv3.lin_rel.bias')
change_key(sd, 'conv3.lin_r.weight', 'conv3.lin_root.weight')
change_key(sd, 'lin.weight', 'lin.weight')
change_key(sd, 'lin.bias', 'lin.bias')
sd.pop('lin.weight',None)
sd.pop('lin.bias',None)

tensor([ 1.2330, -1.6261, -0.2991, -0.6928,  0.4243, -0.0082,  0.9096,  0.0233,
         0.5126, -0.5584, -0.0224, -1.0940, -1.1144,  0.5656, -1.1386, -1.2715,
        -0.2404,  0.5158,  0.8607, -0.6714, -1.4077, -0.4892, -1.0589, -0.6149,
         0.4025, -0.9548])

In [7]:
model.load_state_dict(sd, strict=False)
model.eval()

GNN(
  (conv1): GraphConv(1, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=13, bias=True)
)

In [8]:
model.conv1.lin_rel.weight.requires_grad = False
model.conv1.lin_rel.bias.requires_grad = False
model.conv1.lin_root.weight.requires_grad = False
model.conv2.lin_rel.weight.requires_grad = False
model.conv2.lin_rel.bias.requires_grad = False
model.conv2.lin_root.weight.requires_grad = False
model.conv3.lin_rel.weight.requires_grad = False
model.conv3.lin_rel.bias.requires_grad = False
model.conv3.lin_root.weight.requires_grad = False
model.lin.weight.requires_grad = True
model.lin.bias.requires_grad = True

optimizer = torch.optim.AdamW(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [9]:
def build_reactome_graph_datalist(e_v1, e_v2, n_fn, g_fn, s_fn):
    edge_index = torch.tensor([e_v1, e_v2], dtype=torch.long)
    feature_v = numpy.loadtxt(n_fn)
    target_v = numpy.loadtxt(g_fn, dtype=str, delimiter=",")
    sampleID_v = numpy.loadtxt(s_fn, dtype=str, delimiter=",")

    target_encoder = sklearn.preprocessing.LabelEncoder()
    target_v = target_encoder.fit_transform(target_v)
    label_mapping = dict(zip(target_encoder.classes_, target_encoder.transform(target_encoder.classes_)))
    label_mapping2 = dict(zip(target_encoder.transform(target_encoder.classes_), target_encoder.classes_))
    print(label_mapping2)
    print(feature_v)
    print(target_v)

    d_list = []
    for row_idx in range(len(feature_v)):
        features = feature_v[row_idx, :]
        x = torch.tensor(features, dtype=torch.float)
        x = x.unsqueeze(1)
        y = torch.tensor([target_v[row_idx]])
        sample_id = sampleID_v[row_idx]
        tissue = label_mapping2[target_v[row_idx]]
        d_list.append(Data(x=x, y=y, edge_index=edge_index, sid=sample_id, tissue=tissue))

    return d_list

data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn, sampleID_fn)
print(data_list)

{0: '"Adipose"', 1: '"Brain"', 2: '"Eye"', 3: '"Heart"', 4: '"Intestine"', 5: '"Kidney"', 6: '"Liver"', 7: '"Lung"', 8: '"MOE"', 9: '"Muscle"', 10: '"Pancreas"', 11: '"Skin"', 12: '"Testes"'}
[[-0.16366222 -2.23275087 -2.23275087 ...  0.009564   -0.40321174
  -0.40321174]
 [-2.34776549 -1.54426281 -1.54426281 ...  0.29459949 -0.61086254
  -0.61086254]
 [ 1.55680552  1.02611109  1.02611109 ...  0.01396748  1.28198877
   1.28198877]
 ...
 [-0.05676369  2.80255259  2.80255259 ... -0.68373741 -0.11325952
  -0.11325952]
 [ 0.78743626 -0.29484588 -0.29484588 ... -0.01483183  0.47247653
   0.47247653]
 [-0.47722647 -0.88451594 -0.88451594 ...  1.23666415  1.26752322
   1.26752322]]
[ 1  6  1 ... 11  0 12]
[Data(x=[7856, 1], edge_index=[2, 6514], y=[1], sid='"SRR3087157"', tissue='"Brain"'), Data(x=[7856, 1], edge_index=[2, 6514], y=[1], sid='"SRR1810239"', tissue='"Liver"'), Data(x=[7856, 1], edge_index=[2, 6514], y=[1], sid='"SRR7790225"', tissue='"Brain"'), Data(x=[7856, 1], edge_index=[2, 

In [10]:
data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn, sampleID_fn)
print(len(data_list))
# retrain model for fine tuning transfer learning
train_data_list = data_list[0::2]
print(len(train_data_list))
print(f'Number of training graphs: {len(train_data_list)}')
train_data_loader = build_reactome_graph_loader(train_data_list, BATCH_SIZE)
for epoch in range(EPOCHS):
    train(train_data_loader, device)
    train_acc = train(train_data_loader, device)
    print(f'Epoch: {epoch}, Train Acc: {train_acc}')
    if train_acc == 1.0:
        break

test_data_list = data_list[1::2]
print(len(test_data_list))
print(f'Number of test graphs: {len(test_data_list)}')

test_data_loader = build_reactome_graph_loader(test_data_list, BATCH_SIZE)
test_ari = test(test_data_loader, device)
print(f'test_ari: {test_ari}')

model_save_name = f'tuned_pytorch_GEO_model_validation_model.pt'
path = f'/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_validation/GNN/{model_save_name}'
torch.save(model.state_dict(), path)
print(f'model saved as {path}')

{0: '"Adipose"', 1: '"Brain"', 2: '"Eye"', 3: '"Heart"', 4: '"Intestine"', 5: '"Kidney"', 6: '"Liver"', 7: '"Lung"', 8: '"MOE"', 9: '"Muscle"', 10: '"Pancreas"', 11: '"Skin"', 12: '"Testes"'}
[[-0.16366222 -2.23275087 -2.23275087 ...  0.009564   -0.40321174
  -0.40321174]
 [-2.34776549 -1.54426281 -1.54426281 ...  0.29459949 -0.61086254
  -0.61086254]
 [ 1.55680552  1.02611109  1.02611109 ...  0.01396748  1.28198877
   1.28198877]
 ...
 [-0.05676369  2.80255259  2.80255259 ... -0.68373741 -0.11325952
  -0.11325952]
 [ 0.78743626 -0.29484588 -0.29484588 ... -0.01483183  0.47247653
   0.47247653]
 [-0.47722647 -0.88451594 -0.88451594 ...  1.23666415  1.26752322
   1.26752322]]
[ 1  6  1 ... 11  0 12]
1445
723
Number of training graphs: 723




Epoch: 0, Train Acc: 0.36376210235131395
Epoch: 1, Train Acc: 0.5795297372060858
Epoch: 2, Train Acc: 0.6127247579529738
Epoch: 3, Train Acc: 0.6293222683264177
Epoch: 4, Train Acc: 0.6376210235131397
Epoch: 5, Train Acc: 0.65283540802213
Epoch: 6, Train Acc: 0.6583679114799447
Epoch: 7, Train Acc: 0.665283540802213
Epoch: 8, Train Acc: 0.673582295988935
Epoch: 9, Train Acc: 0.677731673582296
Epoch: 10, Train Acc: 0.6915629322268326
Epoch: 11, Train Acc: 0.706777316735823
Epoch: 12, Train Acc: 0.7095435684647303
Epoch: 13, Train Acc: 0.7206085753803596
Epoch: 14, Train Acc: 0.7275242047026279
Epoch: 15, Train Acc: 0.7372060857538036
Epoch: 16, Train Acc: 0.7413554633471646
Epoch: 17, Train Acc: 0.7524204702627939
Epoch: 18, Train Acc: 0.7579529737206085
Epoch: 19, Train Acc: 0.7621023513139695
Epoch: 20, Train Acc: 0.7634854771784232
Epoch: 21, Train Acc: 0.7634854771784232
Epoch: 22, Train Acc: 0.7690179806362379
Epoch: 23, Train Acc: 0.7800829875518672
Epoch: 24, Train Acc: 0.7800829

Epoch: 198, Train Acc: 0.9059474412171508
Epoch: 199, Train Acc: 0.9059474412171508
Epoch: 200, Train Acc: 0.9059474412171508
Epoch: 201, Train Acc: 0.9059474412171508
Epoch: 202, Train Acc: 0.9059474412171508
Epoch: 203, Train Acc: 0.9059474412171508
Epoch: 204, Train Acc: 0.9059474412171508
Epoch: 205, Train Acc: 0.9073305670816044
Epoch: 206, Train Acc: 0.9073305670816044
Epoch: 207, Train Acc: 0.9073305670816044
Epoch: 208, Train Acc: 0.9087136929460581
Epoch: 209, Train Acc: 0.9100968188105117
Epoch: 210, Train Acc: 0.9100968188105117
Epoch: 211, Train Acc: 0.9100968188105117
Epoch: 212, Train Acc: 0.9100968188105117
Epoch: 213, Train Acc: 0.9100968188105117
Epoch: 214, Train Acc: 0.9100968188105117
Epoch: 215, Train Acc: 0.9100968188105117
Epoch: 216, Train Acc: 0.9100968188105117
Epoch: 217, Train Acc: 0.9100968188105117
Epoch: 218, Train Acc: 0.9100968188105117
Epoch: 219, Train Acc: 0.9100968188105117
Epoch: 220, Train Acc: 0.9114799446749654
Epoch: 221, Train Acc: 0.911479944

Epoch: 395, Train Acc: 0.9239280774550485
Epoch: 396, Train Acc: 0.9239280774550485
Epoch: 397, Train Acc: 0.9239280774550485
Epoch: 398, Train Acc: 0.9239280774550485
Epoch: 399, Train Acc: 0.9239280774550485
Epoch: 400, Train Acc: 0.9239280774550485
Epoch: 401, Train Acc: 0.9239280774550485
Epoch: 402, Train Acc: 0.9239280774550485
Epoch: 403, Train Acc: 0.9239280774550485
Epoch: 404, Train Acc: 0.9239280774550485
Epoch: 405, Train Acc: 0.9239280774550485
Epoch: 406, Train Acc: 0.9239280774550485
Epoch: 407, Train Acc: 0.9239280774550485
Epoch: 408, Train Acc: 0.9239280774550485
Epoch: 409, Train Acc: 0.9239280774550485
Epoch: 410, Train Acc: 0.9280774550484094
Epoch: 411, Train Acc: 0.9280774550484094
Epoch: 412, Train Acc: 0.9280774550484094
Epoch: 413, Train Acc: 0.9280774550484094
Epoch: 414, Train Acc: 0.9280774550484094
Epoch: 415, Train Acc: 0.9280774550484094
Epoch: 416, Train Acc: 0.9280774550484094
Epoch: 417, Train Acc: 0.9280774550484094
Epoch: 418, Train Acc: 0.928077455