In [1]:
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.nn as geo_nn
from torch_geometric.data import Data

In [2]:
from torchvision import datasets, transforms
from torch.nn.functional import binary_cross_entropy

from sklearn.metrics.pairwise import cosine_similarity as cos_sim
from skimage import color

import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import math

from util.image import get_patches
from util.graph import get_adj
from util.CustomDatasets import ImageFolderWithPaths as IFWP

# Very large images can trip Pillow's decompression-bomb check; setting the
# pixel limit to None disables that check.
# The submodule is imported explicitly: a bare `import PIL` does not guarantee
# that `PIL.Image` is loaded (it only works if another package imported it
# first).
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None

# Set device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
ENCODER_PATH = './models/encoder.pth'
TRANSFORM_PATH = './models/transform.pth'

# NOTE(security): both files are pickled Python objects, and unpickling can
# execute arbitrary code -- only load checkpoints from a trusted source.
# * `weights_only=False` is needed to unpickle whole objects on PyTorch
#   versions where `torch.load` defaults to weights-only loading.
# * `map_location=device` remaps tensors that were saved from a GPU, so the
#   checkpoint also loads when `device` is the CPU.
encoder = torch.load(ENCODER_PATH, map_location=device, weights_only=False).to(device)

# The encoder is only used for inference in this notebook, so put layers such
# as dropout / batch-norm into evaluation mode.
encoder.eval()

saved_transform = torch.load(TRANSFORM_PATH, weights_only=False)

# Patches are passed in as numpy arrays, so convert to a PIL image before
# applying the saved preprocessing pipeline.
transform = transforms.Compose([
    transforms.ToPILImage(), # Convert numpy array to PIL image
    saved_transform,
])
transform

Compose(
    ToPILImage()
    Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)
)

In [4]:
# Root directory handed to the image-folder dataset.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
MAIN_PATH = '/groups/francescavitali/eb2/subImages1/H&E/'

# Full dataset
# (the loop further down unpacks each loader item as (img, label, path))
dataset = IFWP(MAIN_PATH, device=device)

In [5]:
# One image per batch; the graph-building loop strips the batch dimension again
loader = torch.utils.data.DataLoader(dataset, batch_size=1)

In [6]:
# Grayscale version of a 224x224 RGB image whose every channel value is 255;
# used as the "white" reference that patches are compared against.
white_rgb = np.zeros((224, 224, 3)) + 255
WHITE_IMG = color.rgb2gray(white_rgb)

In [None]:
# Build one patch graph per image and collect them in `all_patients_dict`:
#   img_id -> {'enc': stacked patch encodings,
#              'label': label tensor,
#              'edge_list': [src, dest] index lists in COO layout}
PROGRESS_FILE = 'progress.txt'

THRESHOLD = 1 - 1e-3 # by inspection seems good ._.


def log_progress(msg):
    """Print `msg` and append it to the progress file.

    The file entry always ends with a newline: one is added when `msg` does
    not already end with one, so consecutive entries never run together.
    """
    print(msg)
    with open(PROGRESS_FILE, 'a') as file:
        file.write(msg if msg.endswith('\n') else msg + '\n')


# Start a fresh progress log for this run
with open(PROGRESS_FILE, 'w') as file:
    file.write("Starting graph creation:\n")

all_patients_dict = {}
num_imgs = len(loader)

# Flattened white reference, reshaped once instead of once per patch
white_vec = WHITE_IMG.reshape(1, -1)

# For each image in the loader
for img_idx, (img, label, path) in enumerate(loader):

    # Drop the batch dimension added by the loader
    img = img.squeeze(0)
    label = label.to(device)

    # Extract the id (the last 4 characters are dropped -- presumably the
    # file extension; confirm against `get_main_img_id`)
    img_id = dataset.get_main_img_id(img_idx)[:-4]

    log_progress(f'\n{img_id} ({img_idx+1}/{num_imgs}):\n')

    patches, num_w, num_h = get_patches(img)

    log_progress(f'{patches.shape=} , {num_w=}, {num_h=}')

    enc_list = []
    valid_idxs = []

    # Patches are indexed row-major: idx = row * num_w + col
    for idx in range(num_h * num_w):
        gray_scale = color.rgb2gray(patches[idx]) * 255

        # NOTE(review): cosine similarity is scale-invariant, so any patch with
        # uniform intensity scores ~1 against the white reference, not only
        # white ones -- confirm this is the intended filter.
        cos_sim_val = cos_sim(gray_scale.reshape(1, -1), white_vec).item()

        # Ignore patches that are similar to the white image
        if cos_sim_val > THRESHOLD:
            continue

        valid_idxs.append(idx)

        with torch.no_grad():
            enc = encoder(transform(patches[idx]).unsqueeze(dim=0).to(device))
        # Keep encodings on the CPU so they do not accumulate in GPU memory
        enc_list.append(enc.squeeze(0).cpu())

    # torch.stack() raises on an empty list, so an image whose patches were
    # all filtered out is reported and skipped instead of aborting the run.
    if not enc_list:
        log_progress(f'{img_id} has no valid patches, skipping')
        continue

    # Mapping orig -> new to lookup new idxs
    adjusted_idxs = {orig: new for new, orig in enumerate(valid_idxs)}

    # Create the COO format for the edges: one (src, dest) pair per kept
    # neighbour of each kept patch, expressed in the new (compacted) indices.
    # `valid_idxs` is already in ascending order because of how it was filled.
    src = []
    dest = []
    for orig in valid_idxs:
        node = adjusted_idxs[orig]
        for neighbor in get_adj(orig, num_w, num_h):
            if neighbor in adjusted_idxs:
                src.append(node)
                dest.append(adjusted_idxs[neighbor])

    # Progress debugging: a repeated id overwrites the earlier entry below
    if img_id in all_patients_dict:
        log_progress(f'{img_id} already in patients dict')

    # Save the information in the dict
    all_patients_dict[img_id] = {'enc': torch.stack(enc_list), 'label': label, 'edge_list': [src, dest]}

    # The CUDA memory helpers are only called when running on a GPU;
    # reset_peak_memory_stats() raises on a machine without CUDA.
    if device.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

print(f'Amt of imgs: {len(dataset)}')


001_01 (1/506):

Original shape: torch.Size([6761, 11420, 3])
New shape: (6944, 11424, 3)
patches.shape=(1581, 224, 224, 3) , num_w=51, num_h=31

001_02 (2/506):

Original shape: torch.Size([6706, 12679, 3])
New shape: (6720, 12768, 3)
patches.shape=(1710, 224, 224, 3) , num_w=57, num_h=30

001_03 (3/506):

Original shape: torch.Size([4672, 8125, 3])
New shape: (4704, 8288, 3)
patches.shape=(777, 224, 224, 3) , num_w=37, num_h=21

001_04 (4/506):

Original shape: torch.Size([5625, 5342, 3])
New shape: (5824, 5376, 3)
patches.shape=(624, 224, 224, 3) , num_w=24, num_h=26

001_05 (5/506):

Original shape: torch.Size([4844, 6923, 3])
New shape: (4928, 6944, 3)
patches.shape=(682, 224, 224, 3) , num_w=31, num_h=22

001_06 (6/506):

Original shape: torch.Size([6103, 7000, 3])
New shape: (6272, 7168, 3)
patches.shape=(896, 224, 224, 3) , num_w=32, num_h=28

002_01 (7/506):

Original shape: torch.Size([6445, 15704, 3])
New shape: (6496, 15904, 3)
patches.shape=(2059, 224, 224, 3) , num_w=71,

patches.shape=(2205, 224, 224, 3) , num_w=45, num_h=49

013_04 (58/506):

Original shape: torch.Size([10334, 11664, 3])
New shape: (10528, 11872, 3)
patches.shape=(2491, 224, 224, 3) , num_w=53, num_h=47

013_05 (59/506):

Original shape: torch.Size([11410, 12648, 3])
New shape: (11424, 12768, 3)
patches.shape=(2907, 224, 224, 3) , num_w=57, num_h=51

014_01 (60/506):

Original shape: torch.Size([8526, 7003, 3])
New shape: (8736, 7168, 3)
patches.shape=(1248, 224, 224, 3) , num_w=32, num_h=39

014_02 (61/506):

Original shape: torch.Size([7893, 6192, 3])
New shape: (8064, 6272, 3)
patches.shape=(1008, 224, 224, 3) , num_w=28, num_h=36

014_03 (62/506):

Original shape: torch.Size([7586, 7550, 3])
New shape: (7616, 7616, 3)
patches.shape=(1156, 224, 224, 3) , num_w=34, num_h=34

014_04 (63/506):

Original shape: torch.Size([8116, 8174, 3])
New shape: (8288, 8288, 3)
patches.shape=(1369, 224, 224, 3) , num_w=37, num_h=37

014_05 (64/506):

Original shape: torch.Size([7533, 7021, 3])
New 

In [10]:
# Show the ids of all images currently stored in the dict
all_patients_dict.keys()

dict_keys(['001_01', '001_02', '001_03', '001_04', '001_05', '001_06', '002_01', '002_02', '002_03', '002_04', '003_01', '003_02', '004_01', '004_02', '004_03', '004_04', '005_01', '005_02', '005_03', '005_04', '006_01', '006_02', '006_03', '006_04', '006_05', '007_01', '007_02', '007_03', '007_04', '008_01', '008_02', '008_03', '008_04', '008_05', '009_01', '009_02', '009_03', '009_04', '009_05', '010_01', '010_02', '010_03', '010_04', '010_05', '010_06', '011_01', '011_02', '011_03', '011_04', '012_01', '012_02', '012_03', '012_04', '012_05', '013_01', '013_02', '013_03', '013_04', '013_05', '014_01', '014_02', '014_03', '014_04', '014_05', '014_06', '015_01', '015_02', '015_03', '015_04', '015_05', '015_06', '016_01', '016_02', '016_03', '016_04', '016_05', '016_06', '016_07', '017_01', '017_02', '017_03', '017_04', '017_05', '017_06', '018_01', '018_02', '018_03', '018_04', '018_05', '018_06', '018_07', '018_08', '018_09', '018_10', '018_11', '018_12', '019_01', '019_02', '019_03',

In [None]:
# Persist the per-image graph dict.
# The target directory is created first: torch.save() does not create missing
# parent directories, and failing here would lose the whole run's results.
IMG_DICT_PATH = './data/adj_graph/img_dict.pth'
os.makedirs(os.path.dirname(IMG_DICT_PATH), exist_ok=True)
torch.save(all_patients_dict, IMG_DICT_PATH)

In [11]:
# Keep a second reference to the in-memory dict (an alias, not a copy)
temp = all_patients_dict

In [12]:
# Reload the per-image graph dict from disk, rebinding `all_patients_dict`.
# NOTE(review): no `map_location` is passed, so tensors are restored to the
# device they were saved from -- confirm this loads on the target machine.
all_patients_dict = torch.load('./data/adj_graph/img_dict.pth')

In [14]:
# Sanity check: build a PyG `Data` object for every stored graph and collect
# the ids of those that fail validation.
invalid_patients = []
for patient, graph in all_patients_dict.items():
    src, dest = graph['edge_list']
    x = graph['enc']

    # Edge index with one row of source nodes and one row of destination nodes
    edge_index = torch.tensor([src, dest], dtype=torch.long)

    data = Data(x=x, edge_index=edge_index.contiguous()).to(device)

    # raise_on_error=False makes validate() report problems through its return
    # value (plus warnings). With raise_on_error=True it raises on the first
    # invalid graph, so the list below could never be filled.
    if not data.validate(raise_on_error=False):
        invalid_patients.append(patient)

print(f'{invalid_patients=}')

invalid_patients=[]
