In [23]:
import dgl
from sklearn.neighbors import kneighbors_graph
import numpy as np
import pandas as pd
import torch

In [24]:
pt_data = pd.read_excel('./CRC_TMAs_patient_annotations.xlsx')
cell_data = pd.read_csv('./CRC_master.csv')

In [25]:
LABEL = "label"
CENTROID = "centroid"
FEATURES = "feat"
def data_to_dgl_graph(pt_data, cell_data, k=5, thresh=50, mode="distance", normalize=False):
    labs, ids = np.unique(cell_data.loc[:,'ClusterName'].to_numpy(), return_inverse=True)
    cell_data['ClusterID'] = ids
    graphs = []
    patients = []
    targets = []
    for spot in cell_data["spots"].unique():
        subset = cell_data.loc[cell_data.loc[:,'spots'] == spot,:]
        features = subset.loc[:,'size':'Treg-PD-1+'].to_numpy() # TODO: This step depends on column order
        if normalize:
            features = np.nan_to_num((features-features.mean(axis=0))/features.std(axis=0))
        centroids = subset.loc[:,'X':'Y'].to_numpy()
        annotation = subset.loc[:,'ClusterID'].to_numpy()
        num_nodes = features.shape[0]
        graph = dgl.DGLGraph()
        graph.add_nodes(num_nodes)
        graph.ndata[CENTROID] = torch.FloatTensor(centroids)
        graph.ndata[FEATURES] = torch.FloatTensor(features)
        if annotation is not None:
            graph.ndata[LABEL] = torch.FloatTensor(annotation.astype(float))
        adj = kneighbors_graph(
            centroids,
            k,
            mode=mode,
            include_self=False,
            metric="euclidean").toarray()
        if thresh is not None:
            adj[adj > thresh] = 0
        edge_list = np.nonzero(adj)
        graph.add_edges(list(edge_list[0]), list(edge_list[1]))
        graphs.append(graph)
        assert(len(subset['patients'].unique()) == 1)
        patients.append(subset['patients'].unique()[0])
    for pt in patients:
        cp = pt_data.loc[pt_data.loc[:,'Patient'] == pt,'cp_TNM_Simple'].values[0]
        assert(cp == 3.0 or cp == 4.0)
        t = 0 if cp == 3.0 else 1
        targets.append(t)
    return list(zip(graphs,targets))

In [26]:
data_distance = data_to_dgl_graph(pt_data, cell_data)

In [27]:
data_connectivity = data_to_dgl_graph(pt_data, cell_data, mode="connectivity")

In [6]:
import torch
from tqdm import trange
from histocartography.ml import CellGraphModel
import random
from torch.utils.data import DataLoader

def collate(batch):
    g = dgl.batch([example[0] for example in batch])
    l = torch.LongTensor([example[1] for example in batch])
    return g, l

def dataset_split(data, val_prop):
    if val_prop == 0:
        return data, data
    random.shuffle(data)
    train_data = data[:int(len(data)*val_prop)]
    val_data = data[int(len(data)*val_prop):]
    return train_data, val_data

#TODO: More sophisticated oversample
def oversample_positive(data, oversample_factor=2):
    negative = []
    positive = []
    for item in data:
        if item[1] == 0:
            negative.append(item)
        else:
            positive.append(item)
    positive = oversample_factor*positive
    return positive+negative

class CGModel():
    def __init__(self, gnn_params, classification_params, node_dim, num_classes=2, lr=10e-3, weight_decay=5e-4, num_epochs=50, batch_size=8):
        self.gnn_params = gnn_params
        self.classification_params = classification_params
        self.node_dim = node_dim
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.cgm = CellGraphModel(gnn_params, classification_params, node_dim, num_classes=2)

    def train(self, data, val_prop=0, oversample_factor=1):
        optimizer = torch.optim.Adam(
            self.cgm.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        # define loss function
        loss_fn = torch.nn.CrossEntropyLoss()
        # training
        loss = 10e5
        val_accuracy = 0.
        train_dataloader = None
        val_dataloader = None
        loss_list = []
        val_accuracy_list = []
        train_data, val_data = dataset_split(data, val_prop)
        train_data = oversample_positive(train_data, oversample_factor=oversample_factor)
        train_dataloader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate)
        val_dataloader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate)
        with trange(self.num_epochs) as t:
            for epoch in t:
                t.set_description('Loss={} | Val Accuracy={}'.format(loss, val_accuracy))
                self.cgm.train()
                for graphs, labels in train_dataloader:
                    logits = self.cgm(graphs)
                    loss = loss_fn(logits, labels)
                    loss_list.append(loss.item())
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                self.cgm.eval()
                all_val_logits = []
                all_val_labels = []
                for graphs, labels in val_dataloader:
                    with torch.no_grad():
                        logits = self.cgm(graphs)
                    all_val_logits.append(logits)
                    all_val_labels.append(labels)
                all_val_logits = torch.cat(all_val_logits).cpu()
                all_val_labels = torch.cat(all_val_labels).cpu()
                with torch.no_grad():
                    _, predictions = torch.max(all_val_logits, dim=1)
                    correct = torch.sum(predictions.to(int) == all_val_labels.to(int))
                    val_accuracy = round(correct.item() * 1.0 / len(all_val_labels), 2)
                    val_accuracy_list.append(val_accuracy)
        return loss_list, val_accuracy_list
    
    def infer(self, graphs):
        with torch.no_grad():
            logits = self.cgm(graphs)
            _, predictions = torch.max(logits, dim=1)
            return predictions
    
    def test(self, graphs, labels):
        predictions = self.infer(graphs)
        with torch.no_grad():
            correct = torch.sum(predictions.to(int) == labels.to(int))
            accuracy = round(correct.item() * 1.0 / len(labels), 2)
            return accuracy
    
    def save(self, model_path):
        #TODO: Save model and params to model path
        pass

    def load(self, model_path):
        #TODO: Load model cpt and params from model path
        pass

In [7]:
gnn_params = {
    'readout_op': 'concat',
    'layer_type': 'gin_layer',
    'output_dim': 32,
    'num_layers': 2,
    'readout_type': 'mean'
}
classification_params = {
    'hidden_dim': 20,
    'num_layers': 2
}

In [8]:
features = cell_data.loc[:,'size':'Treg-PD-1+'].to_numpy()

In [9]:
std = features.std(axis=0)

In [10]:
mean = features.mean(axis=0)

In [11]:
norm = (features-mean)/std

In [12]:
norm

array([[ 0.85247023, -0.81969919, -0.53165024, ..., -0.06755994,
        -0.04130121, -0.04153562],
       [-0.97254549, -0.73549292, -0.53007786, ..., -0.06755994,
        -0.04130121, -0.04153562],
       [ 0.08099259, -0.41216812, -0.15772312, ..., -0.06755994,
        -0.04130121, -0.04153562],
       ...,
       [-0.8197874 ,  0.33839227,  0.01888798, ..., -0.06755994,
        -0.04130121, -0.04153562],
       [-0.46519151, -0.39686416, -0.24040344, ..., -0.06755994,
        -0.04130121, -0.04153562],
       [ 0.24695686, -0.05728709, -0.55965065, ..., -0.06755994,
        -0.04130121, -0.04153562]])

In [14]:
subset = cell_data.loc[cell_data.loc[:,'spots'] == "1_A",:]

In [16]:
features = subset.loc[:,'size':'Treg-PD-1+'].to_numpy()

In [18]:
features.std(axis=0)

array([2.80316299e+03, 7.11740685e+01, 1.02455552e+02, 3.49946084e+02,
       5.65603834e+02, 2.62496114e+01, 1.48557536e+02, 1.15134739e+02,
       1.76485240e+02, 1.76445713e+01, 1.06407223e+02, 8.14510483e+00,
       1.27827952e+02, 2.39994481e+02, 1.15089871e+01, 8.77918635e+00,
       1.77025247e+02, 3.22486178e+01, 3.62872757e+01, 8.87892531e+01,
       1.57795074e+01, 5.85449549e+00, 1.40188555e+02, 4.11823975e+01,
       3.11412825e+00, 4.09298862e+02, 1.18326770e+02, 3.72497793e+01,
       1.57028497e+02, 2.65340507e+01, 3.58463394e+01, 4.81728326e+01,
       1.01512201e+01, 1.41964474e+02, 6.57819254e+01, 1.96262912e+01,
       5.53325022e+00, 2.84098839e+02, 2.76688754e+01, 4.35980413e+01,
       2.32691978e+02, 6.82050953e+00, 1.91373544e+03, 3.01795052e+02,
       3.49912562e+02, 3.62520359e+01, 2.54792174e+01, 5.81637926e+02,
       2.06596242e+02, 2.53188475e+00, 7.41416515e+02, 9.75567688e+01,
       9.19289297e+00, 5.71194626e+02, 7.97517047e+00, 6.63206233e+02,
      

In [20]:
norm = (features-features.mean(axis=0))/features.std(axis=0)

  norm = (features-features.mean(axis=0))/features.std(axis=0)


In [22]:
np.nan_to_num(norm)

array([[ 1.81810758, -1.09596514, -0.78998727, ..., -0.02932312,
         0.        ,  0.        ],
       [-1.48494687, -0.69636828, -0.78045111, ..., -0.02932312,
         0.        ,  0.        ],
       [ 0.42182773,  0.83795437,  1.47780854, ..., -0.02932312,
         0.        ,  0.        ],
       ...,
       [ 0.33335625, -0.74788513, -0.18652341, ..., -0.02932312,
         0.        ,  0.        ],
       [-1.56235942, -0.4196701 , -0.18126653, ..., -0.02932312,
         0.        ,  0.        ],
       [-0.67265019,  0.0446426 , -0.49388344, ..., -0.02932312,
         0.        ,  0.        ]])