In [23]:
import dgl
from sklearn.neighbors import kneighbors_graph
import numpy as np
import pandas as pd
import torch

In [24]:
pt_data = pd.read_excel('./CRC_TMAs_patient_annotations.xlsx')
cell_data = pd.read_csv('./CRC_master.csv')

In [25]:
LABEL = "label"
CENTROID = "centroid"
FEATURES = "feat"
def data_to_dgl_graph(pt_data, cell_data, k=5, thresh=50, mode="distance", normalize=False):
    labs, ids = np.unique(cell_data.loc[:,'ClusterName'].to_numpy(), return_inverse=True)
    cell_data['ClusterID'] = ids
    graphs = []
    patients = []
    targets = []
    for spot in cell_data["spots"].unique():
        subset = cell_data.loc[cell_data.loc[:,'spots'] == spot,:]
        features = subset.loc[:,'size':'Treg-PD-1+'].to_numpy() # TODO: This step depends on column order
        if normalize:
            features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
        centroids = subset.loc[:,'X':'Y'].to_numpy()
        annotation = subset.loc[:,'ClusterID'].to_numpy()
        num_nodes = features.shape[0]
        graph = dgl.DGLGraph()
        graph.add_nodes(num_nodes)
        graph.ndata[CENTROID] = torch.FloatTensor(centroids)
        graph.ndata[FEATURES] = torch.FloatTensor(features)
        if annotation is not None:
            graph.ndata[LABEL] = torch.FloatTensor(annotation.astype(float))
        adj = kneighbors_graph(
            centroids,
            k,
            mode=mode,
            include_self=False,
            metric="euclidean").toarray()
        if thresh is not None:
            adj[adj > thresh] = 0
        edge_list = np.nonzero(adj)
        graph.add_edges(list(edge_list[0]), list(edge_list[1]))
        graphs.append(graph)
        assert(len(subset['patients'].unique()) == 1)
        patients.append(subset['patients'].unique()[0])
    for pt in patients:
        cp = pt_data.loc[pt_data.loc[:,'Patient'] == pt,'cp_TNM_Simple'].values[0]
        assert(cp == 3.0 or cp == 4.0)
        t = 0 if cp == 3.0 else 1
        targets.append(t)
    return list(zip(graphs,targets))

In [42]:
import torch
from tqdm import trange
from histocartography.ml import CellGraphModel
import random
from torch.utils.data import DataLoader

def collate(batch):
    g = dgl.batch([example[0] for example in batch])
    l = torch.LongTensor([example[1] for example in batch])
    return g, l

def dataset_split(data, val_prop):
    if val_prop == 0:
        return data, data
    random.shuffle(data)
    train_data = data[:int(len(data)*val_prop)]
    val_data = data[int(len(data)*val_prop):]
    return train_data, val_data

#TODO: More sophisticated oversample
def oversample_positive(data, oversample_factor=2):
    negative = []
    positive = []
    for item in data:
        if item[1] == 0:
            negative.append(item)
        else:
            positive.append(item)
    positive = oversample_factor*positive
    return positive+negative

class CGModel():
    def __init__(self, gnn_params, classification_params, node_dim, num_classes=2, lr=10e-3, weight_decay=5e-4, num_epochs=50, batch_size=8):
        self.gnn_params = gnn_params
        self.classification_params = classification_params
        self.node_dim = node_dim
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.cgm = CellGraphModel(gnn_params, classification_params, node_dim, num_classes=2)

    def train(self, data, val_prop=0, oversample_factor=1):
        optimizer = torch.optim.Adam(
            self.cgm.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        # define loss function
        loss_fn = torch.nn.CrossEntropyLoss()
        # training
        loss = 10e5
        val_accuracy = 0.
        train_dataloader = None
        val_dataloader = None
        loss_list = []
        val_accuracy_list = []
        train_data, val_data = dataset_split(data, val_prop)
        train_data = oversample_positive(train_data, oversample_factor=oversample_factor)
        train_dataloader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate)
        val_dataloader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate)
        with trange(self.num_epochs) as t:
            for epoch in t:
                t.set_description('Loss={} | Val Accuracy={}'.format(loss, val_accuracy))
                self.cgm.train()
                for graphs, labels in train_dataloader:
                    logits = self.cgm(graphs)
                    loss = loss_fn(logits, labels)
                    loss_list.append(loss.item())
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                self.cgm.eval()
                all_val_logits = []
                all_val_labels = []
                for graphs, labels in val_dataloader:
                    with torch.no_grad():
                        logits = self.cgm(graphs)
                    all_val_logits.append(logits)
                    all_val_labels.append(labels)
                all_val_logits = torch.cat(all_val_logits).cpu()
                all_val_labels = torch.cat(all_val_labels).cpu()
                with torch.no_grad():
                    _, predictions = torch.max(all_val_logits, dim=1)
                    correct = torch.sum(predictions.to(int) == all_val_labels.to(int))
                    val_accuracy = round(correct.item() * 1.0 / len(all_val_labels), 2)
                    val_accuracy_list.append(val_accuracy)
        return loss_list, val_accuracy_list
    
    def infer(self, graphs):
        with torch.no_grad():
            logits = self.cgm(graphs)
            _, predictions = torch.max(logits, dim=1)
            return predictions
    
    def test(self, graphs, labels):
        predictions = self.infer(graphs)
        with torch.no_grad():
            correct = torch.sum(predictions.to(int) == labels.to(int))
            accuracy = round(correct.item() * 1.0 / len(labels), 2)
            return accuracy
    
    def save(self, model_path):
        #TODO: Save model and params to model path
        pass

    def load(self, model_path):
        #TODO: Load model cpt and params from model path
        pass

In [28]:
gnn_params = {
    'readout_op': 'concat',
    'layer_type': 'gin_layer',
    'output_dim': 32,
    'num_layers': 2,
    'readout_type': 'mean'
}
classification_params = {
    'hidden_dim': 20,
    'num_layers': 2
}

In [81]:
data = data_to_dgl_graph(pt_data, cell_data)

In [36]:
data_connectivity = data_to_dgl_graph(pt_data, cell_data, mode='connectivity')

In [37]:
data_normalized = data_to_dgl_graph(pt_data, cell_data, normalize=True)

  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.

In [38]:
data_connectivity_normalized = data_to_dgl_graph(pt_data, cell_data, mode='connectivity', normalize=True)

  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
  features = np.

In [69]:
data[0][0]

DGLGraph(num_nodes=1164, num_edges=4376,
         ndata_schemes={'centroid': Scheme(shape=(2,), dtype=torch.float32), 'feat': Scheme(shape=(74,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32)}
         edata_schemes={})

In [70]:
model1 = CGModel(gnn_params=gnn_params, classification_params=classification_params, node_dim=74)

In [71]:
model2 = CGModel(gnn_params=gnn_params, classification_params=classification_params, node_dim=74)

In [72]:
model3 = CGModel(gnn_params=gnn_params, classification_params=classification_params, node_dim=74)

In [73]:
model4 = CGModel(gnn_params=gnn_params, classification_params=classification_params, node_dim=74)

In [95]:
index = [i for i in range(140)]
random.shuffle(index)

In [96]:
data = [data[i] for i in index]

In [97]:
data_normalized = [data_normalized[i] for i in index]

In [98]:
train_data = data[:int(0.7*len(data))]

In [99]:
val_data = data[int(0.7*len(data)):]

In [100]:
train_data_n = data_normalized[:int(0.7*len(data_normalized))]

In [101]:
val_data_n = data_normalized[int(0.7*len(data_normalized)):]

In [104]:
sum([i[1] for i in val_data])

9

In [105]:
sum([i[1] for i in train_data])

23

In [106]:
model1.train(train_data)

Loss=0.0742134377360344 | Val Accuracy=0.81: 100%|██████████| 50/50 [02:43<00:00,  3.27s/it]


([0.7690851092338562,
  0.6768689155578613,
  0.6784200072288513,
  0.6837761402130127,
  0.6159120798110962,
  0.6042351126670837,
  0.4625478982925415,
  0.6881422996520996,
  0.671578586101532,
  0.4934389293193817,
  0.6558221578598022,
  0.36913788318634033,
  0.311375230550766,
  0.543473482131958,
  0.3475882411003113,
  0.8856189846992493,
  0.2745167016983032,
  0.6793084144592285,
  0.44746026396751404,
  0.7713050842285156,
  0.38245290517807007,
  0.4539937376976013,
  0.5816279053688049,
  0.4315859079360962,
  0.32880041003227234,
  0.9481697082519531,
  0.38149407505989075,
  0.4808424711227417,
  0.25085386633872986,
  0.46043655276298523,
  0.8204863667488098,
  0.45188987255096436,
  0.5232120156288147,
  0.8172386884689331,
  0.3242320716381073,
  0.6124542355537415,
  0.46108993887901306,
  0.39130762219429016,
  0.24776208400726318,
  0.3441516160964966,
  0.44966983795166016,
  0.4313836395740509,
  0.8623785972595215,
  0.8348249197006226,
  0.26569539308547974,


In [107]:
model2.train(train_data, oversample_factor=3)

Loss=0.24526788294315338 | Val Accuracy=0.95: 100%|██████████| 50/50 [03:27<00:00,  4.14s/it]


([0.750299870967865,
  0.7750956416130066,
  0.6777176260948181,
  0.744937539100647,
  0.7255219221115112,
  0.7526204586029053,
  0.7503975033760071,
  0.6845007538795471,
  0.7701263427734375,
  0.7293089628219604,
  0.6695225238800049,
  0.6923144459724426,
  0.599391758441925,
  0.6311220526695251,
  0.8115617036819458,
  0.5265838503837585,
  0.6732302904129028,
  0.5546417832374573,
  0.626400887966156,
  0.5129075050354004,
  0.6075285077095032,
  0.5520399212837219,
  0.7383511662483215,
  0.594693660736084,
  0.7145194411277771,
  0.6170514822006226,
  0.7018414735794067,
  0.637908399105072,
  0.7786120176315308,
  0.46498069167137146,
  0.9260078072547913,
  0.7250997424125671,
  0.8283491134643555,
  0.5325964093208313,
  0.43606582283973694,
  0.5661289691925049,
  0.5587089657783508,
  0.7459912896156311,
  0.6792371273040771,
  1.2567111253738403,
  0.7070127129554749,
  0.7839927673339844,
  0.4762979745864868,
  0.5648226141929626,
  0.5050246119499207,
  0.5127388238

In [108]:
model3.train(train_data_n, oversample_factor=3)

Loss=0.45734965801239014 | Val Accuracy=1.0: 100%|██████████| 50/50 [03:19<00:00,  3.98s/it]


([0.7090036869049072,
  0.7152204513549805,
  0.7157958149909973,
  0.6982258558273315,
  0.6748698353767395,
  0.6964111924171448,
  0.6617436408996582,
  0.7180818915367126,
  0.6953784823417664,
  0.6406789422035217,
  0.6210229396820068,
  0.6950288414955139,
  0.5819463133811951,
  0.5076451897621155,
  0.6478569507598877,
  0.5704215168952942,
  0.43701741099357605,
  0.4971492886543274,
  0.38572749495506287,
  0.4349347651004791,
  0.8031170964241028,
  1.1639970541000366,
  0.6568372845649719,
  0.5631312727928162,
  0.5447168350219727,
  0.5871116518974304,
  0.5890299677848816,
  0.3608038127422333,
  0.3576294183731079,
  0.42346641421318054,
  0.4311385750770569,
  0.48179712891578674,
  0.6508729457855225,
  0.3970390260219574,
  0.5332486629486084,
  0.8880676627159119,
  0.30363941192626953,
  0.5143499970436096,
  0.28199416399002075,
  0.2519696354866028,
  0.5213218331336975,
  0.45842650532722473,
  0.40298694372177124,
  0.19821491837501526,
  0.17408360540866852,


In [126]:
val_dataloader = DataLoader(val_data_n, batch_size=42, shuffle=True, collate_fn=collate)

In [127]:
graphs, labels = next(iter(val_dataloader))

In [128]:
preds = model3.infer(graphs)

In [130]:
sum(labels == preds)

tensor(30)

In [131]:
len(preds)

42

In [132]:
30/42

0.7142857142857143