In [1]:
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.nn as geo_nn
from torch_geometric.data import Data

In [2]:
from torchvision import datasets, transforms
from torch.nn.functional import binary_cross_entropy
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

from util.Classifier import Classifier
from util.ImageFolderWithPathsEncoder import ImageFolderWithPathsEncoder as IFWPE
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
ENCODER_PATH = './models/encoder.pth'
TRANSFORM_PATH = './models/transform.pth'
encoder = torch.load(ENCODER_PATH).to(device)
transform = torch.load(TRANSFORM_PATH)

In [4]:
MAIN_PATH = '/groups/francescavitali/eb2/NewsubSubImages4/H&E'

# Full dataset
dataset = IFWPE(MAIN_PATH, transform=transform, encoder=encoder, device=device)

In [5]:
loader = torch.utils.data.DataLoader(dataset = dataset,
                                            batch_size = 1)

In [None]:
all_patients_dict = {}
i = 0
prev_id = None
label = None

# For each encoding in the loader
for enc, label, _ in loader:
    
    label = label.to(device)
    
    # Extract the id
    img_id = dataset.get_main_img_id(i)
    
    # Add to dictionary if its not existing
    if not img_id in all_patients_dict:
        
        # Reset the previous id to a torch tensor
        if prev_id:
            all_patients_dict[prev_id] = {'enc': torch.cat(all_patients_dict[prev_id]), 'label': label.unsqueeze(1)}
        all_patients_dict[img_id] = []
    
    # Add the encoded img to the relevent dictionary position
    all_patients_dict[img_id].append(enc.squeeze(0))
    
    # Increment
    i += 1
    
    # Break early
    if(i == -1):
        break
        
    # Update previous img
    prev_id = img_id
    
all_patients_dict[prev_id] = {'enc': torch.cat(all_patients_dict[prev_id]), 'label': label.unsqueeze(1)}
print(f'Amt of imgs: {i}')

In [36]:
all_patients_dict

{'01-001_01': {'enc': tensor([[ 0.4025,  0.5792, -0.1650,  ...,  0.7040, -0.0979, -0.4558],
          [ 0.5433,  0.4107, -0.1429,  ...,  0.5656, -0.2066,  0.0991],
          [ 0.5123,  0.0709, -0.0997,  ...,  0.3540,  0.0235,  0.1223],
          ...,
          [ 0.3152,  0.4982, -0.5094,  ...,  0.0213, -0.1469,  0.2509],
          [-0.0339,  1.1082,  0.2197,  ..., -0.3705,  0.1844, -0.0451],
          [ 0.2415,  0.2128,  0.0727,  ..., -0.2213,  0.2826,  0.1858]],
         device='cuda:0'),
  'label': tensor([[0]], device='cuda:0')}}

In [8]:
all_patients_dict.keys()

dict_keys(['01-001_01', '02-001_02', '03-001_06', '04-001_03', '01-002_01', '02-002_03', '03-002_02', '04-002_04', '01-004_04', '02-004_01', '03-004_03', '04-004_02', '01-005_01', '02-005_03', '03-005_04', '04-005_02', '01-006_03', '02-006_01', '03-006_05', '04-006_02', '01-007_04', '02-007_03', '03-007_01', '04-007_02', '01-008_04', '02-008_03', '03-008_05', '04-008_01', '01-009_04', '02-009_05', '03-009_01', '04-009_03', '01-010_02', '02-010_04', '03-010_03', '04-010_05', '01-011_02', '02-011_03', '03-011_01', '04-011_04', '01-012_01', '02-012_02', '03-012_04', '04-012_03', '01-013_02', '02-013_04', '03-013_03', '04-013_05', '01-014_03', '02-014_04', '03-014_01', '04-014_06', '01-015_02', '02-015_01', '03-015_05', '04-015_06', '01-016_07', '02-016_01', '03-016_06', '04-016_05', '01-017_06', '02-017_04', '03-017_02', '04-017_05', '01-018_08', '02-018_04', '03-018_07', '04-018_06', '01-019_07', '02-019_06', '03-019_04', '04-019_01', '01-020_07', '02-020_06', '03-020_09', '04-020_04', '

In [22]:
torch.save(all_patients_dict, './data/img_dict.pth')

In [3]:
all_patients_dict = torch.load('./data/img_dict.pth')

In [4]:
all_edges = {}
for key, all_imgs in all_patients_dict.items():
    edges = []

    # Make a K_n graph
    for i in range(all_imgs.shape[0]):
        for j in range(all_imgs.shape[0]):
            if(i != j):
                edges.append([i, j])
    all_edges[key] = edges
print('done')

done


In [6]:
torch.save(all_edges, './data/edges.pth')

In [11]:
edge_index = torch.tensor(edges, dtype=torch.long)
x = all_imgs

data = Data(x=x, edge_index=edge_index.t().contiguous()).to(device)

In [12]:
data

Data(x=[157, 1536], edge_index=[2, 24492])

In [13]:
data.validate(raise_on_error=True)

True

In [14]:
print('Data:')
print(data.keys())
print(data['x'])

Data:
['edge_index', 'x']
tensor([[ 0.5633,  0.1086,  0.9908,  ...,  0.2653, -0.5071,  0.6718],
        [ 0.5297, -0.0064, -0.0167,  ...,  0.1558, -0.9215,  0.0542],
        [ 0.5419,  0.1699, -0.3524,  ..., -0.2456, -0.6731,  0.1582],
        ...,
        [ 0.1839,  0.2731,  0.1263,  ..., -0.4063, -0.4424,  0.0840],
        [-0.4735, -0.0241,  0.3040,  ..., -0.6489, -0.4606,  0.1246],
        [-0.0177,  0.2882,  0.1080,  ..., -0.9171, -0.4901,  0.3214]],
       device='cuda:0')


In [15]:
print('Data:')
print(f'{data.num_nodes=}')
print(f'{data.num_edges=}')
print(f'{data.num_node_features=}')

Data:
data.num_nodes=157
data.num_edges=24492
data.num_node_features=1536


In [16]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = geo_nn.GCNConv(1536, 200)  # First layer: data.num_node_features input features, some amt of output features y
        self.conv2 = geo_nn.GCNConv(200, 2)  # Second layer: y input features, num of classes

    def forward(self, data):
        # Extract Relevant Info
        x, edge_index = data.x, data.edge_index
        
        # Pass through GCN
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        x= nn.functional.softmax(x, dim=1)
        return x

In [17]:
model = GCN().to(device)

In [18]:
model

GCN(
  (conv1): GCNConv(1536, 200)
  (conv2): GCNConv(200, 2)
)

In [19]:
y = model(data)

In [20]:
y.shape

torch.Size([157, 2])