In [61]:
import os, time
import torch
import torch_geometric
from datasets.BatchWSI import BatchWSI
from models.model_graph_mil import *
device = torch.device('cuda:0')

dataroot = './data/TCGA/BRCA/'
large_graph_pt = 'TCGA-BH-A0DV-01Z-00-DX1.2F0B5FB3-40F0-4D27-BFAC-390FB9A42B39.pt' # example input

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

### Graph Data Structure
- `N`: number of patches
- `M`: number of edges
- `centroid`: [N x 2] matrix containing centroids for each patch
- `edge_index`: [2 x M] matrix containing edges between patches (connected via adjacent spatial coordinates)
- `edge_latent`: [2 x M] matric containing edges between patches (connected via latent space)
- `x`: [N x 1024] matrix which uses 1024-dim extracted ResNet features for each iamge patch (features saved for simplicity)

In [34]:
data = torch.load(os.path.join(dataroot, large_graph_pt))
data

Data(centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], x=[23049, 1024])

In PyTorch Geometric, inference on large graphs is very tractable. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension

This procedure has some crucial advantages over other batching procedures:

- GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs.

- There is no computational or memory overhead. For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, i.e., the edges. 
- For more details, see the advanced mini-batching FAQ in: https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html

In [36]:
data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), 
                                torch.load(os.path.join(dataroot, large_graph_pt))])
data

BatchWSI(batch=[46098], centroid=[46098, 2], edge_index=[2, 322686], edge_latent=[4, 161343], ptr=[3], x=[46098, 1024])

### Inference + Backprop using 23K patches

In [66]:
data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt))])
data = data.to(device)
data

BatchWSI(batch=[23049], centroid=[23049, 2], edge_index=[2, 161343], edge_latent=[2, 161343], ptr=[2], x=[23049, 1024])

In [67]:
model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}
model = PatchGCN_Surv(**model_dict).to(device)
print("Number of Parameters:", count_parameters(model))

### Example Forward Paas + Gradient Backprop
start = time.time()
out = model(x_path=data)
out[0].backward()
print('Time Elapsed: %0.5f seconds' % (time.time() - start))

Number of Parameters: 1382917
Time Elapsed: 0.06325 seconds


### Inference + Backprop using 92K patches

In [50]:
### Simulating a very large graph (containing 4 subgraphs of 23K patches each)
data = BatchWSI.from_data_list([torch.load(os.path.join(dataroot, large_graph_pt)), 
                                torch.load(os.path.join(dataroot, large_graph_pt)),
                                torch.load(os.path.join(dataroot, large_graph_pt)),
                                torch.load(os.path.join(dataroot, large_graph_pt))])
data = data.to(device)
data

BatchWSI(batch=[92196], centroid=[92196, 2], edge_index=[2, 645372], edge_latent=[8, 161343], ptr=[5], x=[92196, 1024])

In [55]:
model_dict = {'num_layers': 4, 'edge_agg': 'spatial', 'resample': 0, 'n_classes': 1}
model = PatchGCN_Surv(**model_dict).to(device)
print("Number of Parameters:", count_parameters(model))

### Example Forward Paas + Gradient Backprop
start = time.time()
out = model(x_path=data)
out[0].backward()
print('Time Elapsed: %0.5f seconds' % (time.time() - start))

Number of Parameters: 1382917
Time Elapsed: 0.20629 seconds


Assuming worst case scenario that every graph has ~100K patches, for a dataset of 1000 WSIs, an epoch would take 3.43 minutes, with 20 epochs taking ~ 1 hour.