In [1]:
import collections
import numpy as np
import pandas as pd
import itertools

import matplotlib
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
from torch_geometric.nn import GCNConv

import networkx as nx

# data processing

In [2]:
df_raw = pd.read_csv('/scratch-shared/martin/003_SPECS1K_ML/001_data/Specs935_ImageMeans_AfterQC_AnnotatedWithMOA.csv', sep=';')
df_raw.rename(columns={'Compound ID':'Compound_ID'}, inplace = True)

In [3]:
df = df_raw.copy()

In [4]:
df.dropna(subset=['Compound_ID'], inplace=True)
df = df[df['selected_mechanism'].str.contains('dmso')==False] # actually not dropping anything, since dropna already drop all dmso

In [5]:
# avg seems not good for GNN?

# averaging the image data based on Compound Id and mech
df_avg = df.groupby(['selected_mechanism','Compound_ID']).mean() # merge the data based on Compound_ID, avoiding reading "duplicated data"
df_avg.reset_index(inplace=True)
df_avg.drop('Compound_ID', axis=1, inplace=True)

In [13]:
a = list('ABCDE')
bins = pd.IntervalIndex.from_tuples([(0, 40), (40, 80), (80, 120), (120, 160), (160, np.max(df["Count_nuclei"]) + 1)], closed='left')

df_avg['new'] = np.array(a)[pd.cut(df_avg["Count_nuclei"], bins = bins).cat.codes]
df_avg

Unnamed: 0,selected_mechanism,ImageNumber_nuclei,ObjectNumber_nuclei,Metadata_Site_nuclei,AreaShape_Area_nuclei,AreaShape_BoundingBoxArea_nuclei,AreaShape_BoundingBoxMaximum_X_nuclei,AreaShape_BoundingBoxMaximum_Y_nuclei,AreaShape_BoundingBoxMinimum_X_nuclei,AreaShape_BoundingBoxMinimum_Y_nuclei,...,RadialDistribution_ZernikePhase_illumSYTO_9_1_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_3_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_5_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_7_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_9_cytoplasm,Site,cmpd_conc,Flag,Count_nuclei,new
0,ATPase inhibitor,1.0,71.245173,5.239130,2594.052840,3641.637710,1115.659189,1105.359249,1055.397388,1044.748298,...,0.037613,-0.004677,0.029404,-0.004017,-0.006743,5.239130,10.0,0.0,129.934783,D
1,ATPase inhibitor,1.0,75.987102,4.909091,2575.014391,3595.923493,1103.946315,1113.647177,1044.400607,1053.059617,...,0.017658,0.006239,0.016753,0.010671,-0.045862,4.909091,10.0,0.0,137.136364,D
2,ATPase inhibitor,1.0,76.035368,5.187500,2520.257557,3542.050967,1091.662087,1114.037379,1031.724602,1054.629479,...,-0.062818,-0.000146,0.000251,0.025794,-0.043205,5.187500,10.0,0.0,137.479167,D
3,ATPase inhibitor,1.0,87.332931,5.022222,2590.518971,3655.140277,1089.390643,1122.687979,1028.800183,1062.081933,...,0.030789,-0.005525,0.026131,0.017409,0.006531,5.022222,10.0,0.0,159.022222,D
4,ATPase inhibitor,1.0,10.026111,4.968750,1715.445263,2698.307079,1042.539294,1133.979970,993.095621,1082.576196,...,-0.065412,-0.053279,-0.157138,-0.003663,-0.152837,4.968750,10.0,0.0,20.500000,A
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
909,tubulin polymerization inhibitor,1.0,20.402987,5.000000,2082.156880,3014.069150,1076.408519,1068.403089,1025.965967,1017.558366,...,-0.138100,-0.064953,0.028538,-0.088900,0.061930,5.000000,10.0,0.0,30.416667,A
910,tubulin polymerization inhibitor,1.0,11.223901,4.790698,2383.460366,3494.812003,1085.155299,1068.350095,1030.907400,1012.764174,...,-0.170362,0.252728,0.184233,0.153065,0.110785,4.790698,10.0,0.0,19.302326,A
911,tubulin polymerization inhibitor,1.0,13.496681,5.000000,1996.259680,2951.947594,1027.813396,1103.537110,978.226385,1052.836261,...,-0.031864,-0.069233,0.052349,0.003676,-0.093969,5.000000,10.0,0.0,22.541667,A
912,tubulin polymerization inhibitor,1.0,13.776276,4.903846,2145.525287,3180.830545,1094.460819,1059.192757,1042.688779,1005.678258,...,-0.037189,-0.199157,0.027529,-0.168356,-0.187250,4.903846,10.0,0.0,24.596154,A


In [None]:
#df.drop(["Plate", "Plate_Well", "batch_id", "pertType", "Batch nr", "Compound_ID", "PlateID", "Well"], axis=1, inplace=True)

# Data loading

In [22]:
df_avg

Unnamed: 0,selected_mechanism,ImageNumber_nuclei,ObjectNumber_nuclei,Metadata_Site_nuclei,AreaShape_Area_nuclei,AreaShape_BoundingBoxArea_nuclei,AreaShape_BoundingBoxMaximum_X_nuclei,AreaShape_BoundingBoxMaximum_Y_nuclei,AreaShape_BoundingBoxMinimum_X_nuclei,AreaShape_BoundingBoxMinimum_Y_nuclei,...,RadialDistribution_ZernikePhase_illumSYTO_9_1_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_3_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_5_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_7_cytoplasm,RadialDistribution_ZernikePhase_illumSYTO_9_9_cytoplasm,Site,cmpd_conc,Flag,Count_nuclei,new
0,0,1.0,71.245173,5.239130,2594.052840,3641.637710,1115.659189,1105.359249,1055.397388,1044.748298,...,0.037613,-0.004677,0.029404,-0.004017,-0.006743,5.239130,10.0,0.0,129.934783,D
1,0,1.0,75.987102,4.909091,2575.014391,3595.923493,1103.946315,1113.647177,1044.400607,1053.059617,...,0.017658,0.006239,0.016753,0.010671,-0.045862,4.909091,10.0,0.0,137.136364,D
2,0,1.0,76.035368,5.187500,2520.257557,3542.050967,1091.662087,1114.037379,1031.724602,1054.629479,...,-0.062818,-0.000146,0.000251,0.025794,-0.043205,5.187500,10.0,0.0,137.479167,D
3,0,1.0,87.332931,5.022222,2590.518971,3655.140277,1089.390643,1122.687979,1028.800183,1062.081933,...,0.030789,-0.005525,0.026131,0.017409,0.006531,5.022222,10.0,0.0,159.022222,D
4,0,1.0,10.026111,4.968750,1715.445263,2698.307079,1042.539294,1133.979970,993.095621,1082.576196,...,-0.065412,-0.053279,-0.157138,-0.003663,-0.152837,4.968750,10.0,0.0,20.500000,A
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
909,29,1.0,20.402987,5.000000,2082.156880,3014.069150,1076.408519,1068.403089,1025.965967,1017.558366,...,-0.138100,-0.064953,0.028538,-0.088900,0.061930,5.000000,10.0,0.0,30.416667,A
910,29,1.0,11.223901,4.790698,2383.460366,3494.812003,1085.155299,1068.350095,1030.907400,1012.764174,...,-0.170362,0.252728,0.184233,0.153065,0.110785,4.790698,10.0,0.0,19.302326,A
911,29,1.0,13.496681,5.000000,1996.259680,2951.947594,1027.813396,1103.537110,978.226385,1052.836261,...,-0.031864,-0.069233,0.052349,0.003676,-0.093969,5.000000,10.0,0.0,22.541667,A
912,29,1.0,13.776276,4.903846,2145.525287,3180.830545,1094.460819,1059.192757,1042.688779,1005.678258,...,-0.037189,-0.199157,0.027529,-0.168356,-0.187250,4.903846,10.0,0.0,24.596154,A


In [20]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
df_avg["selected_mechanism"] = le.fit_transform(df_avg.iloc[:,0])

## Graph formation

In [30]:
def tabular2graph(df, num_graphs=1000, num_nodes=500):
    #https://colab.research.google.com/drive/1_eR7DXBF3V4EwH946dDPOxeclDBeKNMD?usp=sharing
    data_lst = []
    
    for i in range(num_graphs):
        df_sub = df.sample(n=num_nodes)
        df_sub.reset_index(drop=True, inplace=True)
        teams = df_sub["new"].unique()
        all_edges = np.array([], dtype=np.int32).reshape((0, 2))

        for team in teams:
            team_df = df_sub[df_sub["new"] == team]
            players = team_df.index
            # Build all combinations, as all players are connected
            permutations = list(itertools.combinations(players, 2))
            edges_source = [e[0] for e in permutations]
            edges_target = [e[1] for e in permutations]
            team_edges = np.column_stack([edges_source, edges_target])
            all_edges = np.vstack([all_edges, team_edges])
        # Convert to Pytorch Geometric format
        edge_index = all_edges.transpose()
    
        node_features = df_sub.iloc[:,1:-1]
        labels = df_sub.iloc[:,0]
    
        x = torch.tensor(node_features.values)    # node features
        y = torch.tensor(labels.values) # label in label endcoder form
        data = Data(x=x, edge_index=edge_index, y=y) # making graph in PyG
    
        data_lst.append(data)
    return data_lst

In [31]:
val_pct = 0.2

train_df = df_avg.sample(frac = 1-val_pct)
valid_df = df_avg.drop(train_df.index)

train_lst = tabular2graph(train_df, num_graphs=1000, num_nodes=100)
valid_lst = tabular2graph(valid_df, num_graphs=1000, num_nodes=100)

In [32]:
from torch_geometric.loader import DataLoader

trainloader = DataLoader(train_lst, batch_size=64)
validloader = DataLoader(valid_lst, batch_size=64)

# GNN model

In [36]:
from torch_geometric.nn import GCNConv, Sequential, global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2130, 128, aggr='add')
        self.conv2 = GCNConv(128, 128, aggr='add')
        self.conv3 = GCNConv(128, 30, aggr='add')
        self.fc = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, 30)
        )
             
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.fc(x)
        return x
    
    
        """
        self.num_features = 2130
        self.hidden = 128
        self.num_classes = 30
        
        
        self.gnn = Sequential("x, edge_index", [                
                (GCNConv(self.num_features, self.hidden), "x, edge_index -> x1"),
                (nn.ReLU(), "x1 -> x1a"),                                         
                (nn.Dropout(p=0.5), "x1a -> x1d"),                         
                (GCNConv(self.hidden, self.hidden), "x1d, edge_index -> x2"), 
                (nn.ReLU(), "x2 -> x2a"),                                      
                (nn.Dropout(p=0.5), "x2a -> x2d"),                           
                (GCNConv(self.hidden, self.hidden), "x2d, edge_index -> x3"), 
                (nn.ReLU(), "x3 -> x3a"),                                       
                (nn.Dropout(p=0.5), "x3a -> x3d"),                         
                (GCNConv(self.hidden, self.hidden), "x3d, edge_index -> x4"), 
                (nn.ReLU(), "x4 -> x4a"),                                       
                (nn.Dropout(p=0.5), "x4a -> x4d"),                          
                (GCNConv(self.hidden, self.hidden), "x4d, edge_index -> x5"), 
                (nn.ReLU(), "x5 -> x5a"),                                         
                (nn.Dropout(p=0.5), "x5a -> x5d"),                               
                #(global_mean_pool, "x5d, batch_index -> x6"),                 
                (nn.Linear(self.hidden, self.num_classes), "x5d -> x_out")]) 
        
        self.fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 30)
        )
        
        
        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = x.relu()
            # = F.dropout(x, p=0.5)
            x = self.conv2(x, edge_index)
            x = x.relu()
            #x = F.dropout(x, p=0.5)
            x = self.conv2(x, edge_index)
            x = x.relu()
            #x = F.dropout(x, p=0.5)
            #x = self.conv3(x, edge_index)
            x = self.fc(x)
            return x
        
        """

# trainning

In [37]:
device = "cuda" if torch.cuda.is_available() else "cpu"

#device = "cpu"

model = GCN().to(device)
n_epochs = 401
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)  # weight_decay is for L2 regularization

In [None]:
best_acc = -1

for epoch in range(n_epochs):
    #----Training----#
    model.train()
    train_loss = []
    train_accs = []
    for batch in trainloader:
        batch.to(device)
        x, edge_index = batch.x, batch.edge_index
        logits = model(x.float().to(device), torch.tensor(edge_index[0]).long().to(device))
        loss = criterion(logits, torch.tensor(batch.y))
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()

        acc = (logits.argmax(dim=-1) == batch.y).float().mean()

        train_loss.append(loss.item())
        train_accs.append(acc)
        
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)
    
    #----Validation----#
    model.eval()
    valid_loss = []
    valid_accs = []
    
    for batch in validloader:
        batch.to(device)
        x, edge_index = batch.x, batch.edge_index
        with torch.no_grad():
            logits = model(x.float().to(device), torch.tensor(edge_index[0]).long().to(device))
        loss = criterion(logits, torch.tensor(batch.y))
        acc = (logits.argmax(dim=-1) == batch.y).float().mean()
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    
    # Print the information.
    if epoch%20==0:
        print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
        print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
        
    
    if valid_acc > best_acc:
        torch.save(model.state_dict(), './model_averaged_data.ckpt')
        print(f'model saved at {epoch} epochs with acc {valid_acc}')
        best_acc = valid_acc

  loss = criterion(logits, torch.tensor(batch.y))
  loss = criterion(logits, torch.tensor(batch.y))


[ Train | 001/401 ] loss = 1.19886, acc = 0.56995
[ Valid | 001/401 ] loss = 1.19252, acc = 0.64548
model saved at 0 epochs with acc 0.6454766392707825
model saved at 11 epochs with acc 0.660144567489624
[ Train | 021/401 ] loss = 1.19880, acc = 0.58226
[ Valid | 021/401 ] loss = 1.64590, acc = 0.50210
[ Train | 041/401 ] loss = 1.19918, acc = 0.57368
[ Valid | 041/401 ] loss = 1.33637, acc = 0.59457
model saved at 46 epochs with acc 0.6778261661529541
[ Train | 061/401 ] loss = 1.37491, acc = 0.52499
[ Valid | 061/401 ] loss = 1.66801, acc = 0.48042
model saved at 68 epochs with acc 0.6904823780059814
model saved at 77 epochs with acc 0.6951797008514404
[ Train | 081/401 ] loss = 1.11076, acc = 0.60137
[ Valid | 081/401 ] loss = 1.51814, acc = 0.55599
model saved at 98 epochs with acc 0.7305722832679749
[ Train | 101/401 ] loss = 1.37157, acc = 0.54947
[ Valid | 101/401 ] loss = 1.52464, acc = 0.52478
[ Train | 121/401 ] loss = 1.25790, acc = 0.57846
[ Valid | 121/401 ] loss = 1.27141

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(model)