In [2]:
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.nn as geo_nn
from torch_geometric.data import Data

In [4]:
from torchvision import datasets, transforms
from torch.nn.functional import binary_cross_entropy
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

from util.Classifier import Classifier
from util.CustomDatasets import ImageFolderWithPathsEncoder as IFWPE
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
ENCODER_PATH = './models/encoder.pth'
TRANSFORM_PATH = './models/transform.pth'
encoder = torch.load(ENCODER_PATH).to(device)
transform = torch.load(TRANSFORM_PATH)

In [4]:
MAIN_PATH = '/groups/francescavitali/eb2/NewsubSubImages4/H&E'

# Full dataset
dataset = IFWPE(MAIN_PATH, transform=transform, encoder=encoder, device=device)

In [5]:
loader = torch.utils.data.DataLoader(dataset = dataset,
                                            batch_size = 1)

In [40]:
all_patients_dict = {}
i = 0
prev_id = None
label = None

# For each encoding in the loader
for enc, label, _ in loader:
    
    label = label.to(device)
    
    # Extract the id
    img_id = dataset.get_main_img_id(i)
    
    # Add to dictionary if its not existing
    if not img_id in all_patients_dict:
        
        # Reset the previous id to a torch tensor
        if prev_id:
            all_patients_dict[prev_id] = {'enc': torch.cat(all_patients_dict[prev_id]), 'label': label.unsqueeze(1)}
        all_patients_dict[img_id] = []
    
    # Add the encoded img to the relevent dictionary position
    all_patients_dict[img_id].append(enc.squeeze(0))
    
    # Increment
    i += 1
    
    # Break early
    if(i == -1):
        break
        
    # Update previous img
    prev_id = img_id
    
all_patients_dict[prev_id] = {'enc': torch.cat(all_patients_dict[prev_id]), 'label': label.unsqueeze(1)}
print(f'Amt of imgs: {i}')

Amt of imgs: 69709


In [41]:
all_patients_dict

{'01-001_01': {'enc': tensor([[ 0.4025,  0.5792, -0.1650,  ...,  0.7040, -0.0979, -0.4558],
          [ 0.5433,  0.4107, -0.1429,  ...,  0.5656, -0.2066,  0.0991],
          [ 0.5123,  0.0709, -0.0997,  ...,  0.3540,  0.0235,  0.1223],
          ...,
          [ 0.0508,  0.0082,  0.1061,  ..., -0.2907, -0.3126, -0.1858],
          [-0.0865,  0.4048, -0.3115,  ..., -0.3605, -0.3349,  0.0182],
          [-0.1220,  0.1882,  0.1595,  ..., -0.3329, -0.0069,  0.0459]],
         device='cuda:0'),
  'label': tensor([[0]], device='cuda:0')},
 '02-001_02': {'enc': tensor([[-0.6209, -0.3550,  0.2807,  ..., -0.7395,  0.3017,  0.3689],
          [-0.0622,  0.3283,  0.3166,  ..., -0.5390,  0.4285, -0.2437],
          [ 0.2917,  0.1726,  0.1924,  ...,  0.0203, -0.4217,  0.0728],
          ...,
          [-0.6083, -0.1742,  0.3662,  ..., -0.7483,  0.2430,  0.4428],
          [-0.7956, -0.1573,  0.1873,  ..., -0.7796,  0.1353,  0.0455],
          [-0.8194, -0.2087,  0.2807,  ..., -0.5360,  0.2942,  0.1

In [44]:
all_patients_dict.keys()

dict_keys(['01-001_01', '02-001_02', '03-001_06', '04-001_03', '01-002_01', '02-002_03', '03-002_02', '04-002_04', '01-004_04', '02-004_01', '03-004_03', '04-004_02', '01-005_01', '02-005_03', '03-005_04', '04-005_02', '01-006_03', '02-006_01', '03-006_05', '04-006_02', '01-007_04', '02-007_03', '03-007_01', '04-007_02', '01-008_04', '02-008_03', '03-008_05', '04-008_01', '01-009_04', '02-009_05', '03-009_01', '04-009_03', '01-010_02', '02-010_04', '03-010_03', '04-010_05', '01-011_02', '02-011_03', '03-011_01', '04-011_04', '01-012_01', '02-012_02', '03-012_04', '04-012_03', '01-013_02', '02-013_04', '03-013_03', '04-013_05', '01-014_03', '02-014_04', '03-014_01', '04-014_06', '01-015_02', '02-015_01', '03-015_05', '04-015_06', '01-016_07', '02-016_01', '03-016_06', '04-016_05', '01-017_06', '02-017_04', '03-017_02', '04-017_05', '01-018_08', '02-018_04', '03-018_07', '04-018_06', '01-019_07', '02-019_06', '03-019_04', '04-019_01', '01-020_07', '02-020_06', '03-020_09', '04-020_04', '

In [45]:
torch.save(all_patients_dict, './data/img_dict.pth')

In [46]:
all_patients_dict = torch.load('./data/img_dict.pth')

In [4]:
all_edges = {}
for key, all_imgs in all_patients_dict.items():
    edges = []

    # Make a K_n graph
    for i in range(all_imgs.shape[0]):
        for j in range(all_imgs.shape[0]):
            if(i != j):
                edges.append([i, j])
    all_edges[key] = edges
print('done')

done


In [6]:
torch.save(all_edges, './data/edges.pth')

In [7]:
all_edges = torch.load('./data/K_graph/edges.pth')

In [8]:
all_edges['01-001_01']

[[0, 1],
 [0, 2],
 [0, 3],
 [0, 4],
 [0, 5],
 [0, 6],
 [0, 7],
 [0, 8],
 [0, 9],
 [0, 10],
 [0, 11],
 [0, 12],
 [0, 13],
 [0, 14],
 [0, 15],
 [0, 16],
 [0, 17],
 [0, 18],
 [0, 19],
 [0, 20],
 [0, 21],
 [0, 22],
 [0, 23],
 [0, 24],
 [0, 25],
 [0, 26],
 [0, 27],
 [0, 28],
 [0, 29],
 [0, 30],
 [0, 31],
 [0, 32],
 [0, 33],
 [0, 34],
 [0, 35],
 [0, 36],
 [0, 37],
 [0, 38],
 [0, 39],
 [0, 40],
 [0, 41],
 [0, 42],
 [0, 43],
 [0, 44],
 [0, 45],
 [0, 46],
 [0, 47],
 [0, 48],
 [0, 49],
 [0, 50],
 [0, 51],
 [0, 52],
 [0, 53],
 [0, 54],
 [0, 55],
 [0, 56],
 [0, 57],
 [0, 58],
 [0, 59],
 [0, 60],
 [0, 61],
 [0, 62],
 [0, 63],
 [0, 64],
 [0, 65],
 [0, 66],
 [0, 67],
 [0, 68],
 [0, 69],
 [0, 70],
 [0, 71],
 [0, 72],
 [0, 73],
 [0, 74],
 [0, 75],
 [0, 76],
 [0, 77],
 [0, 78],
 [0, 79],
 [0, 80],
 [0, 81],
 [0, 82],
 [0, 83],
 [0, 84],
 [0, 85],
 [0, 86],
 [0, 87],
 [0, 88],
 [0, 89],
 [0, 90],
 [0, 91],
 [0, 92],
 [0, 93],
 [0, 94],
 [0, 95],
 [0, 96],
 [0, 97],
 [0, 98],
 [0, 99],
 [0, 100],
 [0, 101

In [53]:
all_patients_dict[patient]['enc']

tensor([[ 0.4025,  0.5792, -0.1650,  ...,  0.7040, -0.0979, -0.4558],
        [ 0.5433,  0.4107, -0.1429,  ...,  0.5656, -0.2066,  0.0991],
        [ 0.5123,  0.0709, -0.0997,  ...,  0.3540,  0.0235,  0.1223],
        ...,
        [ 0.0508,  0.0082,  0.1061,  ..., -0.2907, -0.3126, -0.1858],
        [-0.0865,  0.4048, -0.3115,  ..., -0.3605, -0.3349,  0.0182],
        [-0.1220,  0.1882,  0.1595,  ..., -0.3329, -0.0069,  0.0459]],
       device='cuda:0')

In [1]:
patient = list(all_patients_dict.keys())[0]
edges = all_edges[patient]
all_imgs = all_patients_dict[patient]['enc']

edge_index = torch.tensor(edges, dtype=torch.long)
x = all_imgs

data = Data(x=x, edge_index=edge_index.t().contiguous()).to(device)

NameError: name 'all_patients_dict' is not defined

In [55]:
data

Data(x=[183, 1536], edge_index=[2, 33306])

In [56]:
data.validate(raise_on_error=True)

True

In [57]:
print('Data:')
print(data.keys())
print(data['x'])

Data:
['x', 'edge_index']
tensor([[ 0.4025,  0.5792, -0.1650,  ...,  0.7040, -0.0979, -0.4558],
        [ 0.5433,  0.4107, -0.1429,  ...,  0.5656, -0.2066,  0.0991],
        [ 0.5123,  0.0709, -0.0997,  ...,  0.3540,  0.0235,  0.1223],
        ...,
        [ 0.0508,  0.0082,  0.1061,  ..., -0.2907, -0.3126, -0.1858],
        [-0.0865,  0.4048, -0.3115,  ..., -0.3605, -0.3349,  0.0182],
        [-0.1220,  0.1882,  0.1595,  ..., -0.3329, -0.0069,  0.0459]],
       device='cuda:0')


In [58]:
print('Data:')
print(f'{data.num_nodes=}')
print(f'{data.num_edges=}')
print(f'{data.num_node_features=}')

Data:
data.num_nodes=183
data.num_edges=33306
data.num_node_features=1536


In [59]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = geo_nn.GCNConv(1536, 200)  # First layer: data.num_node_features input features, some amt of output features y
        self.conv2 = geo_nn.GCNConv(200, 2)  # Second layer: y input features, num of classes

    def forward(self, data):
        # Extract Relevant Info
        x, edge_index = data.x, data.edge_index
        
        # Pass through GCN
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x= nn.functional.softmax(x, dim=1)
        return x

In [60]:
model = GCN().to(device)

In [61]:
model

GCN(
  (conv1): GCNConv(1536, 200)
  (conv2): GCNConv(200, 2)
)

In [62]:
y = model(data)

In [63]:
y.shape

torch.Size([183, 2])