In [8]:
import os.path as op
import random
import time

import matplotlib.pyplot as plt
import numpy
import sklearn
import torch
import torch.nn.functional as nn_func
from sklearn import preprocessing
from sklearn.metrics import adjusted_rand_score
from torch.nn import Linear
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GraphConv, global_mean_pool

In [9]:
node_features_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/input/node_features10vs0_2_time_cross.txt'
graph_targets_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/input/graph_targets10vs0_time_cross.txt'
edges_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/edges.txt'
model_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/GNN/trained_pytorch_model_rewired10_fold_full_dataset.pt'
output_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/predictions10vs0_time_cross.tsv'
sample_id_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/tcdd_sample_id10vs0_time_cross.txt'
project_id_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/tcdd_project_id10vs0_time_cross.txt'
gender_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/tcdd_gender10vs0_time_cross.txt'
dose_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/tcdd_dose10vs0_time_cross.txt'
# test graph_targets.txt, node_features.txt and edges.txt
features_exist = op.exists(node_features_fn)
targets_exist = op.exists(graph_targets_fn)
edges_exist = op.exists(edges_fn)
model_exists = op.exists(model_fn)

print(f'features exist: {features_exist},'
      f' targets exist: {targets_exist},'
      f' edges exist: {edges_exist}',
      f' model exists: {model_exists}')
assert features_exist
assert targets_exist
assert edges_exist
assert model_exists


features exist: True, targets exist: True, edges exist: True  model exists: True


In [10]:
INPUT_CHANNELS = 1
OUTPUT_CHANNELS = 26
NEW_CHANNELS = 2
HIDDEN_CHANNELS = 64
BATCH_SIZE = 64
BENCHMARKING = False
EPOCHS = 500

In [11]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()

        self.conv1 = GraphConv(INPUT_CHANNELS, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, OUTPUT_CHANNELS)

    def forward(self, x, edge_index, batch, edge_weight=None):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        #x = nn_func.dropout(x, training=self.training)
        x = self.lin(x)

        return x


def read_reactome_graph(e_fn):
    e_v1 = []
    e_v2 = []

    for line in open(e_fn, 'r'):
        dt = line.split()
        node1 = int(dt[0]) - 1  # subtracting to convert R idx to python idx
        node2 = int(dt[1]) - 1  # " "
        e_v1.append(node1)
        e_v2.append(node2)

    return e_v1, e_v2


def build_reactome_graph_datalist(e_v1, e_v2, n_fn, g_fn, pid_fn, sid_fn, gen_fn, dose_fn):
    edge_index = torch.tensor([e_v1, e_v2], dtype=torch.long)
    feature_v = numpy.loadtxt(n_fn)
    target_v = numpy.loadtxt(g_fn, dtype=float, delimiter=",")
    projectID_v = numpy.loadtxt(pid_fn, dtype=str, delimiter="\t")
    sampleID_v = numpy.loadtxt(sid_fn, dtype=str, delimiter="\t")
    gender_v = numpy.loadtxt(gen_fn, dtype=str, delimiter="\t")
    dose_v = numpy.loadtxt(dose_fn, dtype=str, delimiter="\t")
    
    binary_labels = (target_v > 0).astype(int)
    
    print("labels check:")
    for dose, label in zip(target_v[:10], binary_labels[:10]): 
        print(f"dose: {dose}, label: {label}")


    d_list = []
    for row_idx in range(len(feature_v)):
        features = feature_v[row_idx, :]
        x = torch.tensor(features, dtype=torch.float)
        x = x.unsqueeze(1)
#         y = torch.tensor([target_v[row_idx]])
        y = torch.tensor([binary_labels[row_idx]], dtype=torch.long)
        
        pid = projectID_v[row_idx]
        sid = sampleID_v[row_idx]
        gen = gender_v[row_idx]
        dose = dose_v[row_idx]
        
        d_list.append(Data(x=x, y=y, pid=pid, sid=sid, gen=gen, dose=dose, edge_index=edge_index))

    return d_list


def build_reactome_graph_loader(d_list, batch_size):
    loader = DataLoader(d_list, batch_size=batch_size, shuffle=False)  # True)

    return loader


def train(loader, dv):
    model.train()

    correct = 0
    for batch in loader:  # Iterate in batches over the training dataset.
        batch.validate()
        x = batch.x.to(dv)
        e = batch.edge_index.to(dv)
        b = batch.batch.to(dv)
        y = batch.y.to(dv)

        out = model(x, e, b)  # Perform a single forward pass.
        loss = criterion(out, y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


def test(loader, dv):
    model.eval()

    targets = []
    predictions = []
    project_ids = []
    sample_ids = []
    genders = []
    doses = []
    confidences = []
    for batch in loader:  # Iterate in batches over the test dataset.
        x = batch.x.to(dv)
        e = batch.edge_index.to(dv)
        b = batch.batch.to(dv)
        y = batch.y.to(dv)
        targets += torch.Tensor.tolist(y)
        
        project_ids += batch.pid
        sample_ids += batch.sid
        genders += batch.gen
        doses += batch.dose
        
        out = model(x, e, b)  # Perform a single forward pass.
        prob = torch.softmax(out, dim=1)
        
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        predictions += torch.Tensor.tolist(pred)
        confidences += torch.Tensor.tolist(prob)
        
    num_classes = len(confidences[0])

    data_to_save = []
    for i in range(len(targets)):
        row = [project_ids[i], sample_ids[i], genders[i], doses[i], targets[i], predictions[i]] + confidences[i]
        data_to_save.append(row)
    data_to_save = numpy.array(data_to_save)
    print(data_to_save)
    
    fmt = ['%s', '%s', '%s', '%s', '%s', '%s'] + ['%s' for _ in range(num_classes)]
    
    headers = ['project_ids', 'sample_ids', 'genders', 'doses', 'target', 'prediction'] + [f'confidence_class_{i}' for i in range(num_classes)]
    numpy.savetxt(output_fn, data_to_save, fmt='\t'.join(fmt), delimiter='\t', header='\t'.join(headers), comments='')
        
    ari = adjusted_rand_score(targets, predictions)
    print(f'ari: {ari}')
    return ari

In [12]:
def change_key(self, old, new):
    for _ in range(len(self)):
        k, v = self.popitem(False)
        self[new if old == k else k] = v


(edge_v1, edge_v2) = read_reactome_graph(edges_fn)
model = GNN(hidden_channels=HIDDEN_CHANNELS)
device = cpu = torch.device('cpu')

sd = torch.load(model_fn, map_location=device)
change_key(sd, 'conv1.lin_l.weight', 'conv1.lin_rel.weight')
change_key(sd, 'conv1.lin_l.bias', 'conv1.lin_rel.bias')
change_key(sd, 'conv1.lin_r.weight', 'conv1.lin_root.weight')
change_key(sd, 'conv2.lin_l.weight', 'conv2.lin_rel.weight')
change_key(sd, 'conv2.lin_l.bias', 'conv2.lin_rel.bias')
change_key(sd, 'conv2.lin_r.weight', 'conv2.lin_root.weight')
change_key(sd, 'conv3.lin_l.weight', 'conv3.lin_rel.weight')
change_key(sd, 'conv3.lin_l.bias', 'conv3.lin_rel.bias')
change_key(sd, 'conv3.lin_r.weight', 'conv3.lin_root.weight')
change_key(sd, 'lin.weight', 'lin.weight')
change_key(sd, 'lin.bias', 'lin.bias')

model.load_state_dict(sd)
model.eval()

GNN(
  (conv1): GraphConv(1, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=26, bias=True)
)

In [13]:
# replace final layer with new shape matching new dataset
model.lin = Linear(HIDDEN_CHANNELS, NEW_CHANNELS)

model.conv1.lin_rel.weight.requires_grad = False
model.conv1.lin_rel.bias.requires_grad = False
model.conv1.lin_root.weight.requires_grad = False
model.conv2.lin_rel.weight.requires_grad = False
model.conv2.lin_rel.bias.requires_grad = False
model.conv2.lin_root.weight.requires_grad = False
model.conv3.lin_rel.weight.requires_grad = False
model.conv3.lin_rel.bias.requires_grad = False
model.conv3.lin_root.weight.requires_grad = False
model.lin.weight.requires_grad = True
model.lin.bias.requires_grad = True

# for name, param in model.named_parameters(): print(name, param)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()))
criterion = torch.nn.CrossEntropyLoss()

In [14]:
data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn, project_id_fn, sample_id_fn, gender_fn, dose_fn)
print(len(data_list))
# retrain model for fine tuning transfer learning
train_data_list = data_list  # all data
print(len(train_data_list))
print(f'Number of training graphs: {len(train_data_list)}')
train_data_loader = build_reactome_graph_loader(train_data_list, BATCH_SIZE)
for epoch in range(EPOCHS):
    train(train_data_loader, device)
    train_acc = train(train_data_loader, device)
    print(f'Epoch: {epoch}, Train Acc: {train_acc}')
    if train_acc == 1.0:
        break

final_ari = test(train_data_loader, device)
print(f'test_ari: {final_ari}')

model_save_name = f'tuned_pytorch_tcdd_model10vs0_time_cross.pt'
path = f'/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/tcdd/output/{model_save_name}'
torch.save(model.state_dict(), path)
print(f'model saved as {path}')

labels check:
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
dose: 0.0, label: 0
90
90
Number of training graphs: 90
Epoch: 0, Train Acc: 0.6555555555555556
Epoch: 1, Train Acc: 0.7333333333333333
Epoch: 2, Train Acc: 0.7777777777777778
Epoch: 3, Train Acc: 0.8333333333333334
Epoch: 4, Train Acc: 0.8666666666666667
Epoch: 5, Train Acc: 0.8666666666666667
Epoch: 6, Train Acc: 0.8666666666666667
Epoch: 7, Train Acc: 0.8777777777777778
Epoch: 8, Train Acc: 0.8777777777777778
Epoch: 9, Train Acc: 0.8777777777777778
Epoch: 10, Train Acc: 0.8777777777777778
Epoch: 11, Train Acc: 0.8777777777777778
Epoch: 12, Train Acc: 0.8777777777777778
Epoch: 13, Train Acc: 0.8777777777777778
Epoch: 14, Train Acc: 0.8777777777777778
Epoch: 15, Train Acc: 0.8777777777777778
Epoch: 16, Train Acc: 0.8777777777777778
Epoch: 17, Train Acc: 0.8888888888888888
Epoch: 18, Train Acc: 

Epoch: 193, Train Acc: 0.9777777777777777
Epoch: 194, Train Acc: 0.9777777777777777
Epoch: 195, Train Acc: 0.9777777777777777
Epoch: 196, Train Acc: 0.9777777777777777
Epoch: 197, Train Acc: 0.9777777777777777
Epoch: 198, Train Acc: 0.9777777777777777
Epoch: 199, Train Acc: 0.9777777777777777
Epoch: 200, Train Acc: 0.9777777777777777
Epoch: 201, Train Acc: 0.9777777777777777
Epoch: 202, Train Acc: 0.9777777777777777
Epoch: 203, Train Acc: 0.9777777777777777
Epoch: 204, Train Acc: 0.9777777777777777
Epoch: 205, Train Acc: 0.9777777777777777
Epoch: 206, Train Acc: 0.9777777777777777
Epoch: 207, Train Acc: 0.9777777777777777
Epoch: 208, Train Acc: 0.9777777777777777
Epoch: 209, Train Acc: 0.9777777777777777
Epoch: 210, Train Acc: 0.9777777777777777
Epoch: 211, Train Acc: 0.9777777777777777
Epoch: 212, Train Acc: 0.9777777777777777
Epoch: 213, Train Acc: 0.9777777777777777
Epoch: 214, Train Acc: 0.9777777777777777
Epoch: 215, Train Acc: 0.9777777777777777
Epoch: 216, Train Acc: 0.977777777