In [2]:
import pandas as pd
import openpyxl
import h5py
import cv2
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from torchvision.transforms import v2
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.decomposition import PCA
import sys
import os

import time

# Run on the GPU when torch can see one; otherwise stay on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('Using device:', device)

Using device: cpu


## Experiment hyperparameters

In [14]:
is_local = True # todo

# --- Experiment ---
# Locally the seed is fixed; otherwise it is read from the second-to-last
# command-line argument.
seed = int(sys.argv[-2]) if not is_local else 3
torch.manual_seed(seed)
image_size = 256

# --- Data: bounds of the wavenumber window that may be considered at all ---
wv_start, wv_end = 0, 965

# --- Data loading ---
test_set_fraction = 0.2
val_set_fraction = 0.2
batch_size = 64
patch_dim = 101
use_augmentation = True

# --- Network ---
dropout_p = 0.5

# --- Training schedule ---
lr = 1e-5
l2 = 5e-1
max_iters = 5000
pseudo_epoch = 100

# --- Dimensionality reduction ---
r_method = 'linear' # one of {'linear', 'pca', 'fixed'} # todo change to linear
# Target number of channels; only meaningful for r_method = 'pca' or 'linear'.
# Outside local runs it is read from the last command-line argument.
reduce_dim = int(sys.argv[-1]) if not is_local else 64
# Index expression selecting the spectral channels along the last axis.
# NOTE(review): originally annotated "used only when r_method = 'fixed'", but
# the datasets are constructed with it regardless of r_method.
channels_used = np.s_[..., wv_start:wv_end]
print(channels_used)

(Ellipsis, slice(0, 965, None))


In [4]:
def csf_fp(filepath):
    """Rewrite a dataset path for the machine the notebook is running on.

    Every occurrence of the 'D:/datasets' prefix is kept as-is when `is_local`
    is true and replaced by './' otherwise.
    """
    dataset_root = 'D:/datasets'
    replacement_root = dataset_root if is_local else './'
    return filepath.replace(dataset_root, replacement_root)

# Master sheet: one row per core, with slide / patient identifiers and the
# file paths of its spectra, annotation image and mask.
master = pd.read_excel(csf_fp('D:/datasets/pcuk2023_ftir_whole_core/master_sheet.xlsx'))
slide = master['slide'].to_numpy()
patient_id = master['patient_id'].to_numpy()
hdf5_filepaths = np.array([csf_fp(fp) for fp in master['hdf5_filepath']])
annotation_filepaths = np.array([csf_fp(fp) for fp in master['annotation_filepath']])
mask_filepaths = np.array([csf_fp(fp) for fp in master['mask_filepath']])

# `wavenumbers` is the allowed window; `wavenumbers_used` is what the datasets
# actually read. Both are sliced from the *unsliced* axis: applying
# `channels_used` to the already-windowed array would cut the window twice
# whenever wv_start > 0 and disagree with the channel count of the loaded spectra.
# NOTE(review): assumes wavenumbers.npy is aligned with the last axis of the
# HDF5 'spectra' datasets - confirm.
wavenumbers_all = np.load(csf_fp('D:/datasets/pcuk2023_ftir_whole_core/wavenumbers.npy'))
wavenumbers = wavenumbers_all[wv_start:wv_end]
wavenumbers_used = wavenumbers_all[channels_used]

# RGB colour of each annotation class in the annotation images, and the
# matching class names (same order).
annotation_class_colors = np.array([[0,255,0],[128,0,128],[255,0,255],[0,0,255],[255,165,0],[255,0,0]])
annotation_class_names = np.array(['epithelium_n','stroma_n','epithelium_c','stroma_c','corpora_amylacea','blood'])
n_classes = len(annotation_class_names)
print(f"Loaded {len(slide)} cores")
print(f"Using {len(wavenumbers_used)}/{len(wavenumbers)} wavenumbers")

Loaded 228 cores
Using 965/965 wavenumbers


## Define datasets and loaders

In [5]:
# Split at patient level so that all cores of one patient land in the same
# subset, then map the patient split back onto core indices.
unique_pids = np.unique(patient_id)
pids_trainval, pids_test = train_test_split(
    unique_pids, test_size=test_set_fraction, random_state=seed)
# The validation fraction is expressed relative to all patients, so rescale it
# to the remaining train+val pool.
pids_train, pids_val = train_test_split(
    pids_trainval, test_size=(val_set_fraction/(1-test_set_fraction)), random_state=seed)
where_train = np.where(np.isin(patient_id,pids_train))
where_val = np.where(np.isin(patient_id,pids_val))
where_test = np.where(np.isin(patient_id,pids_test))
# These are counts of cores (rows of the master sheet), not of patients.
print(f"Cores per data split:\n\tTRAIN: {len(where_train[0])}\n\tVAL: {len(where_val[0])}\n\tTEST: {len(where_test[0])}")

Patients per data split:
	TRAIN: 139
	VAL: 48
	TEST: 41


In [15]:
class ftir_annot_dataset(torch.utils.data.Dataset):
    """Whole-core FTIR dataset: one item per HDF5 file.

    Each item is a tuple ``(ftir, annotations, mask)``:
      * ``ftir``        - the selected spectral channels, moved to channels-first
                          and multiplied by the mask;
      * ``annotations`` - one binary plane per annotation class, multiplied by
                          the mask;
      * ``mask``        - the 'mask' dataset of the HDF5 file with a leading
                          singleton dimension.

    Parameters
    ----------
    hdf5_filepaths : sequence of str
        HDF5 files holding a 'spectra' and a 'mask' dataset.
    mask_filepaths : sequence of str
        Stored but not read here (the mask comes from the HDF5 file).
    annotation_filepaths : sequence of str
        Colour-coded annotation images, read with ``cv2.imread``.
    channels_use : tuple
        Index expression (e.g. ``np.s_[..., a:b]``) applied to 'spectra'.
    """
    def __init__(self, 
                 hdf5_filepaths, mask_filepaths, annotation_filepaths, channels_use):
        self.hdf5_filepaths = hdf5_filepaths
        self.mask_filepaths = mask_filepaths
        self.annotation_filepaths = annotation_filepaths
        self.channels_use = channels_use
        
        # class data (taken from the notebook-level definitions)
        self.annotation_class_colors = annotation_class_colors
        self.annotation_class_names = annotation_class_names
        
    def __len__(self):
        """Number of cores (HDF5 files) in the dataset."""
        return len(self.hdf5_filepaths)
    
    # split annotations from H x W x 3 to C x H x W, one/zerohot along C dimension
    def split_annotations(self,annotations_img):
        """Turn an RGB annotation image into one binary plane per class.

        A pixel is set in plane ``c`` when all three of its channels equal
        ``annotation_class_colors[c]``; pixels matching no class colour stay
        zero in every plane.
        """
        class_colors = np.asarray(self.annotation_class_colors)
        # Broadcast (1, H, W, 3) against (C, 1, 1, 3) and require all channels to match.
        matches = np.all(annotations_img[None, ...] == class_colors[:, None, None, :], axis=-1)
        return torch.from_numpy(matches).float()
        
    def __getitem__(self, idx):    
        """Load core ``idx`` and return ``(ftir, annotations, mask)``.

        Raises
        ------
        FileNotFoundError
            If the annotation image cannot be read by OpenCV.
        """
        # The context manager closes the file even when a read fails.
        with h5py.File(self.hdf5_filepaths[idx],'r') as hdf5_file:
            # get mask
            # NOTE(review): assumes the stored mask is 2-D (H, W) - confirm.
            mask = torch.from_numpy(
                hdf5_file['mask'][:],
            ).unsqueeze(0)
            
            # get ftir; tuple(...) is equivalent to the star-unpacked subscript
            # but does not require Python 3.11.
            ftir = torch.from_numpy(
                hdf5_file['spectra'][tuple(self.channels_use)],
            ).permute(2,0,1)
        ftir *= mask
        
        # get annotations (OpenCV loads BGR; reverse the channel axis to RGB)
        annotation_bgr = cv2.imread(self.annotation_filepaths[idx])
        if annotation_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(
                f"Could not read annotation image: {self.annotation_filepaths[idx]}")
        annotations = self.split_annotations(annotation_bgr[:,:,::-1])
        annotations *= mask
        
        
        return ftir, annotations, mask

In [16]:
dataset_train = ftir_annot_dataset(
    hdf5_filepaths[where_train], mask_filepaths[where_train], annotation_filepaths[where_train], channels_used,
)
dataset_val = ftir_annot_dataset(
    hdf5_filepaths[where_val], mask_filepaths[where_val], annotation_filepaths[where_val], channels_used,
)
dataset_test = ftir_annot_dataset(
    hdf5_filepaths[where_test], mask_filepaths[where_test], annotation_filepaths[where_test], channels_used,
)

# Every item is a whole core (all selected channels over the full image), so the
# cores are served one at a time. The patch-level `batch_size` would stack that
# many whole cores into a single tensor, and the recorded inference progress
# ("0/139" for the 139 training cores) shows one core per batch.
core_batch_size = 1
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=core_batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=core_batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=core_batch_size, shuffle=False)

## Define dimensionality reduction methods

In [17]:
class LinearReduction(nn.Module):
    """Learned per-pixel linear projection of the spectral axis.

    Batch-normalises the `input_dim` input channels, maps them to `reduce_dim`
    channels with a 1x1 convolution and batch-normalises the result.
    """
    def __init__(self,input_dim,reduce_dim):
        super().__init__()
        self.reduce_dim = reduce_dim
        self.input_norm = nn.BatchNorm2d(num_features=input_dim)
        self.projection = nn.Conv2d(
            in_channels=input_dim,
            out_channels=reduce_dim,
            kernel_size=1,
            stride=1,
        )
        self.projection_norm = nn.BatchNorm2d(num_features=reduce_dim)
    
    def forward(self,x):
        """Return the normalised, channel-reduced version of `x`."""
        normalised = self.input_norm(x)
        reduced = self.projection(normalised)
        return self.projection_norm(reduced)
    
class PCAReduce(nn.Module): 
    """Fixed (non-trainable) PCA projection of the channel axis.

    `means` is subtracted per channel and each pixel's centred spectrum is
    multiplied by the transposed `loadings`, giving `reduce_dim` output
    channels. Both arrays are stored as float buffers, not parameters.
    """
    def __init__(self,reduce_dim,means,loadings):
        super().__init__()
        self.reduce_dim = reduce_dim
        channel_means = torch.from_numpy(means).float().reshape(1,-1,1,1)
        component_rows = torch.from_numpy(loadings).float()
        self.register_buffer('means', channel_means)
        self.register_buffer('loadings', component_rows)
    
    def forward(self,x):
        """Project a (batch, channels, height, width) tensor onto the stored components."""
        centred = x - self.means
        
        batch, channels, height, width = centred.shape
        # channels-last, one row per pixel, so the projection is a single matmul
        per_pixel = centred.permute(0,2,3,1).reshape(batch, height*width, channels)
        scores = torch.matmul(per_pixel, self.loadings.T)
        # back to channels-first with the reduced channel count
        return scores.reshape(batch, height, width, self.reduce_dim).permute(0,3,1,2)
        
class FixedReduction(nn.Module):
    """No channel reduction: the input is only batch-normalised per channel."""
    def __init__(self,input_dim):
        super().__init__()
        self.input_norm = nn.BatchNorm2d(num_features=input_dim)
    
    def forward(self,x):
        """Return `x` after batch normalisation; the channel count is unchanged."""
        normalised = self.input_norm(x)
        return normalised

if r_method == 'pca':
    # Fit PCA on a random sample of spectra from the training cores.
    # The datasets yield whole cores as (ftir, annotations, mask), so spectra
    # are drawn from the pixels where the mask is non-zero, an equal share per
    # core, up to roughly `max_pca_spectra` in total. Cores are read one at a
    # time straight from the dataset to keep memory bounded.
    max_pca_spectra = 10000
    spectra_per_core = max(1, max_pca_spectra // max(1, len(dataset_train)))
    pca_rng = np.random.default_rng(seed)  # seeded so the sample is reproducible
    spectral_sample = []
    for core_idx in range(len(dataset_train)):
        data, _annot, mask = dataset_train[core_idx]
        # (C, H, W) -> (H, W, C), then keep only pixels inside the mask: (n_pixels, C)
        core_spectra = data.permute(1,2,0)[mask[0] != 0].numpy()
        if len(core_spectra) == 0:
            continue
        keep = pca_rng.choice(
            len(core_spectra), size=min(spectra_per_core, len(core_spectra)), replace=False)
        spectral_sample.append(core_spectra[keep])
    spectral_sample = np.concatenate(spectral_sample,axis=0)
    spectral_means = np.mean(spectral_sample,axis=0)
    spectral_sample -= spectral_means
    pca = PCA(n_components=reduce_dim)
    pca.fit(spectral_sample)
    # rows are components, columns are input channels
    spectral_loadings = pca.components_

## Define model

In [18]:
class patch101_cnn(nn.Module):
    """Large-context branch: CNN over the full spatial patch.

    The spectral front-end is chosen by the notebook-level `r_method`
    ('pca', 'fixed' or 'linear'). Three conv / GELU / max-pool / batch-norm
    stages follow; for a 101x101 patch they end in 64 x 6 x 6 = 2304 values,
    which `fc1` maps to a 256-d feature vector.

    Parameters
    ----------
    input_dim : int
        Number of spectral channels entering the front-end.
    reduce_dim : int
        Number of channels leaving the front-end (input to `conv1`).
    n_classes : int
        Output size of `fc2`.
    dropout_p : float
        Dropout probability applied to the 256-d features.

    Raises
    ------
    ValueError
        If `r_method` is not one of the supported reduction methods.
    """
    def __init__(self,input_dim,reduce_dim,n_classes,dropout_p=0.5):
        super().__init__()
        
        # input processing and dimensionality reduction
        if r_method == 'pca':
            self.input_processing = PCAReduce(reduce_dim,spectral_means,spectral_loadings)
        elif r_method == 'fixed':
            self.input_processing = FixedReduction(input_dim)
        elif r_method == 'linear':
            self.input_processing = LinearReduction(input_dim,reduce_dim)
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown r_method {r_method!r}; expected 'pca', 'fixed' or 'linear'")
        
        # Convolution layers
        self.conv1 = nn.Conv2d(reduce_dim, 32, 5, stride=2, padding=0, padding_mode='reflect')
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, padding_mode='reflect')
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=1, padding_mode='reflect')
        
        # Normalisation Layers
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm1d(256)
        
        # Fc Layers
        self.fc1 = nn.Linear(2304, 256)
        # fc2 is not part of `classifier` below, so forward() never applies it.
        # NOTE(review): presumably kept so saved checkpoints still match - confirm.
        self.fc2 = nn.Linear(256, n_classes)
        
        # Additional kit
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout_p)
        
        
        self.feature_extractor = nn.Sequential(
            self.conv1,
            self.activation,
            self.pool,
            self.bn1,
            self.conv2,
            self.activation,
            self.pool,
            self.bn2,
            self.conv3,
            self.activation,
            self.pool,
            self.bn3,  
        )
        
        self.classifier = nn.Sequential(
            self.fc1,
            self.activation,
            self.bn4,
            self.dropout,
        )

    def forward(self, x):
        """Return the 256-d feature vector of each patch in `x` (not class logits)."""
        inputs = self.input_processing(x)
        features = self.feature_extractor(inputs)
        embedding = self.classifier(features.flatten(1))
        return embedding
    
class patch3_cnn(nn.Module):
    """Small-context branch: CNN over a 3x3 spatial neighbourhood.

    The spectral front-end is chosen by the notebook-level `r_method`
    ('pca', 'fixed' or 'linear'). `conv1` is an unpadded 3x3 convolution, so a
    3x3 input collapses to a single position with 512 channels; two 1x1
    convolutions follow and `fc1` maps the 512 values to a 256-d feature vector.

    Parameters
    ----------
    input_dim : int
        Number of spectral channels entering the front-end.
    reduce_dim : int
        Number of channels leaving the front-end (input to `conv1`).
    n_classes : int
        Output size of `fc2`.
    dropout_p : float
        Dropout probability applied to the 256-d features.

    Raises
    ------
    ValueError
        If `r_method` is not one of the supported reduction methods.
    """
    def __init__(self,input_dim,reduce_dim,n_classes,dropout_p=0.5):
        super().__init__()
        
        # reduction
        # input processing and dimensionality reduction
        if r_method == 'pca':
            self.input_processing = PCAReduce(reduce_dim,spectral_means,spectral_loadings)
        elif r_method == 'fixed':
            self.input_processing = FixedReduction(input_dim)
        elif r_method == 'linear':
            self.input_processing = LinearReduction(input_dim,reduce_dim)
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown r_method {r_method!r}; expected 'pca', 'fixed' or 'linear'")
        
        # Convolution layers
        self.conv1 = nn.Conv2d(reduce_dim, 512, 3, stride=1, padding=0, padding_mode='reflect')
        self.conv2 = nn.Conv2d(512, 512, 1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(512, 512, 1, stride=1, padding=0)
        
        # Normalisation Layers
        # input_norm is not applied in forward() (the front-end does its own
        # normalisation). NOTE(review): presumably kept so saved checkpoints
        # still match - confirm.
        self.input_norm = nn.BatchNorm2d(input_dim)
        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(512)
        self.bn3 = nn.BatchNorm2d(512)
        self.bn4 = nn.BatchNorm1d(256)
        
        # Fc Layers
        self.fc1 = nn.Linear(512, 256)
        # fc2 is not part of `classifier` below, so forward() never applies it.
        self.fc2 = nn.Linear(256, n_classes)
        
        # Additional kit
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout_p)
        
        self.feature_extractor = nn.Sequential(
            self.conv1,
            self.activation,
            self.bn1,
            self.conv2,
            self.activation,
            self.bn2,
            self.conv3,
            self.activation,
            self.bn3,  
        )
        
        self.classifier = nn.Sequential(
            self.fc1,
            self.activation,
            self.bn4,
            self.dropout,
        )

    def forward(self, x):
        """Return the 256-d feature vector of each patch in `x` (not class logits)."""
        inputs = self.input_processing(x)
        features = self.feature_extractor(inputs)
        embedding = self.classifier(features.flatten(1))
        return embedding

In [19]:
class fusion_model(nn.Module):
    """Fuse a small-context and a large-context branch into class scores.

    Note the attribute roles: `model1` receives only the central 3x3 crop of
    each patch, `model3` receives the full patch. Each branch must return a
    256-d vector per patch; the two are concatenated and classified.

    Parameters
    ----------
    model1 : callable
        Branch applied to the central 3x3 crop.
    model3 : callable
        Branch applied to the whole patch.
    n_classes : int
        Number of output classes.
    """
    def __init__(self,model1,model3,n_classes):
        super().__init__()
        self.model1 = model1
        self.model3 = model3
        
        self.classifier = nn.Sequential(
            nn.Linear(256*2,256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256,n_classes)
        )
        
    def forward(self,x):
        """Return class scores of shape (batch, n_classes) for a batch of patches."""
        # Central 3x3 window, located from the actual patch size rather than
        # hard-coded offsets; for 101x101 patches this is rows/cols 49..51,
        # i.e. the same slice as 49:-49.
        height, width = x.shape[-2:]
        top, left = height // 2 - 1, width // 2 - 1
        centre = x[:, :, top:top + 3, left:left + 3]
        x1 = self.model1(centre)
        x3 = self.model3(x)
        out = self.classifier(torch.cat([x1,x3],dim=1))
        return out

In [20]:
# Both branches are built with identical arguments; with r_method = 'fixed'
# no channels are dropped, so the reduced width equals the input width.
branch_kwargs = dict(
    input_dim=len(wavenumbers_used),
    reduce_dim=len(wavenumbers_used) if r_method == 'fixed' else reduce_dim,
    n_classes=n_classes,
    dropout_p=dropout_p,
)
model3 = patch3_cnn(**branch_kwargs)
model101 = patch101_cnn(**branch_kwargs)
# fusion_model's first slot (`model1`) takes the 3x3 branch and its second
# slot (`model3`) takes the full-patch branch.
model = fusion_model(model1=model3, model3=model101, n_classes=n_classes)


def _trainable_params_in_millions(module):
    """Number of trainable parameters of `module`, in millions."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6


print(f"fusion_model with {_trainable_params_in_millions(model):.3f}M params composed of:")
print(f"\tpatch3_model with {_trainable_params_in_millions(model3):.3f}M params")
print(f"\tpatch101_model with {_trainable_params_in_millions(model101):.3f}M params")
model = model.to(device)

fusion_model with 1.919M params composed of:
	patch3_model with 1.023M params
	patch101_model with 0.763M params


## Load weights

In [21]:
# Load the trained fusion-model checkpoint.
# map_location lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine (this notebook falls back to CPU when CUDA is unavailable).
model_weights = torch.load(
    "../models/patch_multiscale_linear64_seed3.pt",
    map_location=device,
    weights_only=True,
) # todo remove this branch
# strict=False tolerates key mismatches; report them so that a model left
# partly at its random initialisation does not go unnoticed.
load_result = model.load_state_dict(model_weights,strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: checkpoint does not fully match the model | "
          f"missing keys: {load_result.missing_keys} | "
          f"unexpected keys: {load_result.unexpected_keys}")
model.eval()
model = model.to(device)

## Loop through datasets

#### Train

In [22]:
# Running total of wall-clock seconds spent in the inference loops below;
# each loop adds its own elapsed time, so re-running a loop cell adds again.
time_total = 0

In [31]:
# Dense sliding-window inference over every core of the TRAIN split.
set_preds,set_targets = [], []
pred_images, annot_images = [], []
loader_use = train_loader
half_patch = patch_dim // 2   # zero border so every pixel gets a full patch_dim x patch_dim window
cols_per_pass = 64            # number of pixels of one image row sent through the model at once
bg_channel = 229              # spectral channel rendered as the greyscale backdrop
start_t = time.time()
with torch.no_grad():
    for bidx, (data, annot, mask) in enumerate(loader_use):
        print(f"{bidx}/{len(loader_use)}",end='\r')
        data = data.to(device); annot = annot.to(device); mask = mask.to(device)
        n_cores, _, height, width = data.shape
        # pixels carrying any class label, shape (batch, H, W)
        has_annot = (annot.sum(dim=1) != 0).cpu()
        
        out = torch.zeros((n_cores,n_classes,height,width))
        pad_data = torch.nn.functional.pad(data,(half_patch,half_patch,half_patch,half_patch))
        
        # Strided view holding one patch per pixel: (batch, H, W, channels, patch, patch).
        unfolded = pad_data.unfold(
            2,patch_dim,1
        ).unfold(
            3,patch_dim,1
        ).permute(
            0,2,3,1,4,5
        )
        for c_idx in range(n_cores):
            for r_idx in range(height):
                for col in range(0,width,cols_per_pass):
                    # model output is (pixels, classes); `out` stores classes first
                    out[c_idx,:,r_idx,col:col+cols_per_pass] = model(unfolded[c_idx,r_idx,col:col+cols_per_pass]).permute(1,0)
        
        # metrics are computed on annotated pixels only
        targets = annot.argmax(dim=1).cpu()[has_annot]
        preds = out.argmax(dim=1)[has_annot]
        set_targets.extend(targets.numpy())
        set_preds.extend(preds.numpy())
        
        # greyscale backdrop: one spectral channel rescaled to [0, 1] per core
        bg = data[:,bg_channel,:,:].clone()
        for b in range(n_cores):
            bg[b] -= bg[b].min()
            peak = bg[b].max()
            if peak > 0:  # a flat channel would otherwise give 0/0 = NaN
                bg[b] /= peak
        bg = torch.stack([bg,bg,bg],dim=-1).cpu().numpy()
        
        # Prediction images
        pred_image = annotation_class_colors[out.argmax(1).numpy()] / 255.0
        pred_image *= mask[:,0].unsqueeze(-1).cpu().numpy()
        pred_images.extend(pred_image)
        # Annotation images: class colour where annotated, backdrop elsewhere
        annot_image = annotation_class_colors[annot.argmax(1).cpu().numpy()] / 255.0
        annot_image = np.where(has_annot.unsqueeze(-1).numpy(), annot_image, bg)
        annot_images.extend(annot_image)
time_total += time.time() - start_t 

# calculate train set metrics
set_acc = accuracy_score(set_targets, set_preds)
set_f1m = f1_score(set_targets, set_preds, average='macro')
set_f1 = f1_score(set_targets, set_preds, average=None)

print(f"DATASET TRAIN --- | OA: {set_acc:.4} | f1: {set_f1m:.4}")
for cls_idx, f1 in enumerate(set_f1):
    print(f"{annotation_class_names[cls_idx]}{(20 - len(annotation_class_names[cls_idx])) * ' '} : {f1:.4}")

0/139

KeyboardInterrupt: 

In [26]:
if is_local:
    # Two cores per figure; each panel shows prediction | annotation side by side.
    for i in range(0,len(pred_images),2):
        fig,ax = plt.subplots(1,2,figsize=(16.5,16.5/4)); ax = ax.flatten()
        for panel, core_idx in zip(ax, (i, i+1)):
            # with an odd number of cores the last figure has one empty panel
            if core_idx < len(pred_images):
                panel.matshow(np.hstack([pred_images[core_idx],annot_images[core_idx]]))
            panel.set_axis_off()
        fig.tight_layout()

#### Val

In [None]:
# Dense sliding-window inference over every core of the VAL split.
set_preds,set_targets = [], []
pred_images, annot_images = [], []
loader_use = val_loader
half_patch = patch_dim // 2   # zero border so every pixel gets a full patch_dim x patch_dim window
cols_per_pass = 64            # number of pixels of one image row sent through the model at once
bg_channel = 229              # spectral channel rendered as the greyscale backdrop
start_t = time.time()
with torch.no_grad():
    for bidx, (data, annot, mask) in enumerate(loader_use):
        print(f"{bidx}/{len(loader_use)}",end='\r')
        data = data.to(device); annot = annot.to(device); mask = mask.to(device)
        n_cores, _, height, width = data.shape
        # pixels carrying any class label, shape (batch, H, W)
        has_annot = (annot.sum(dim=1) != 0).cpu()
        
        out = torch.zeros((n_cores,n_classes,height,width))
        pad_data = torch.nn.functional.pad(data,(half_patch,half_patch,half_patch,half_patch))
        
        # Strided view holding one patch per pixel: (batch, H, W, channels, patch, patch).
        unfolded = pad_data.unfold(
            2,patch_dim,1
        ).unfold(
            3,patch_dim,1
        ).permute(
            0,2,3,1,4,5
        )
        for c_idx in range(n_cores):
            for r_idx in range(height):
                for col in range(0,width,cols_per_pass):
                    # model output is (pixels, classes); `out` stores classes first
                    out[c_idx,:,r_idx,col:col+cols_per_pass] = model(unfolded[c_idx,r_idx,col:col+cols_per_pass]).permute(1,0)
        
        # metrics are computed on annotated pixels only
        targets = annot.argmax(dim=1).cpu()[has_annot]
        preds = out.argmax(dim=1)[has_annot]
        set_targets.extend(targets.numpy())
        set_preds.extend(preds.numpy())
        
        # greyscale backdrop: one spectral channel rescaled to [0, 1] per core
        bg = data[:,bg_channel,:,:].clone()
        for b in range(n_cores):
            bg[b] -= bg[b].min()
            peak = bg[b].max()
            if peak > 0:  # a flat channel would otherwise give 0/0 = NaN
                bg[b] /= peak
        bg = torch.stack([bg,bg,bg],dim=-1).cpu().numpy()
        
        # Prediction images
        pred_image = annotation_class_colors[out.argmax(1).numpy()] / 255.0
        pred_image *= mask[:,0].unsqueeze(-1).cpu().numpy()
        pred_images.extend(pred_image)
        # Annotation images: class colour where annotated, backdrop elsewhere
        annot_image = annotation_class_colors[annot.argmax(1).cpu().numpy()] / 255.0
        annot_image = np.where(has_annot.unsqueeze(-1).numpy(), annot_image, bg)
        annot_images.extend(annot_image)
time_total += time.time() - start_t 
        
# calculate validation set metrics
set_acc = accuracy_score(set_targets, set_preds)
set_f1m = f1_score(set_targets, set_preds, average='macro')
set_f1 = f1_score(set_targets, set_preds, average=None)

print(f"DATASET VAL ----- | OA: {set_acc:.4} | f1: {set_f1m:.4}")
for cls_idx, f1 in enumerate(set_f1):
    print(f"{annotation_class_names[cls_idx]}{(20 - len(annotation_class_names[cls_idx])) * ' '} : {f1:.4}")

In [None]:
if is_local:
    # Two cores per figure; each panel shows prediction | annotation side by side.
    for i in range(0,len(pred_images),2):
        fig,ax = plt.subplots(1,2,figsize=(16.5,16.5/4)); ax = ax.flatten()
        for panel, core_idx in zip(ax, (i, i+1)):
            # with an odd number of cores the last figure has one empty panel
            if core_idx < len(pred_images):
                panel.matshow(np.hstack([pred_images[core_idx],annot_images[core_idx]]))
            panel.set_axis_off()
        fig.tight_layout()

#### Test

In [None]:
# Dense sliding-window inference over every core of the TEST split.
set_preds,set_targets = [], []
pred_images, annot_images = [], []
loader_use = test_loader
half_patch = patch_dim // 2   # zero border so every pixel gets a full patch_dim x patch_dim window
cols_per_pass = 64            # number of pixels of one image row sent through the model at once
bg_channel = 229              # spectral channel rendered as the greyscale backdrop
start_t = time.time()
with torch.no_grad():
    for bidx, (data, annot, mask) in enumerate(loader_use):
        print(f"{bidx}/{len(loader_use)}",end='\r')
        data = data.to(device); annot = annot.to(device); mask = mask.to(device)
        n_cores, _, height, width = data.shape
        # pixels carrying any class label, shape (batch, H, W)
        has_annot = (annot.sum(dim=1) != 0).cpu()
        
        out = torch.zeros((n_cores,n_classes,height,width))
        pad_data = torch.nn.functional.pad(data,(half_patch,half_patch,half_patch,half_patch))
        
        # Strided view holding one patch per pixel: (batch, H, W, channels, patch, patch).
        unfolded = pad_data.unfold(
            2,patch_dim,1
        ).unfold(
            3,patch_dim,1
        ).permute(
            0,2,3,1,4,5
        )
        for c_idx in range(n_cores):
            for r_idx in range(height):
                for col in range(0,width,cols_per_pass):
                    # model output is (pixels, classes); `out` stores classes first
                    out[c_idx,:,r_idx,col:col+cols_per_pass] = model(unfolded[c_idx,r_idx,col:col+cols_per_pass]).permute(1,0)
        
        # metrics are computed on annotated pixels only
        targets = annot.argmax(dim=1).cpu()[has_annot]
        preds = out.argmax(dim=1)[has_annot]
        set_targets.extend(targets.numpy())
        set_preds.extend(preds.numpy())
        
        # greyscale backdrop: one spectral channel rescaled to [0, 1] per core
        bg = data[:,bg_channel,:,:].clone()
        for b in range(n_cores):
            bg[b] -= bg[b].min()
            peak = bg[b].max()
            if peak > 0:  # a flat channel would otherwise give 0/0 = NaN
                bg[b] /= peak
        bg = torch.stack([bg,bg,bg],dim=-1).cpu().numpy()
        
        # Prediction images
        pred_image = annotation_class_colors[out.argmax(1).numpy()] / 255.0
        pred_image *= mask[:,0].unsqueeze(-1).cpu().numpy()
        pred_images.extend(pred_image)
        # Annotation images: class colour where annotated, backdrop elsewhere
        annot_image = annotation_class_colors[annot.argmax(1).cpu().numpy()] / 255.0
        annot_image = np.where(has_annot.unsqueeze(-1).numpy(), annot_image, bg)
        annot_images.extend(annot_image)
time_total += time.time() - start_t 
        
# calculate test set metrics
set_acc = accuracy_score(set_targets, set_preds)
set_f1m = f1_score(set_targets, set_preds, average='macro')
set_f1 = f1_score(set_targets, set_preds, average=None)

print(f"DATASET TEST ---- | OA: {set_acc:.4} | f1: {set_f1m:.4}")
for cls_idx, f1 in enumerate(set_f1):
    print(f"{annotation_class_names[cls_idx]}{(20 - len(annotation_class_names[cls_idx])) * ' '} : {f1:.4}")

In [None]:
if is_local:
    # Two cores per figure; each panel shows prediction | annotation side by side.
    for i in range(0,len(pred_images),2):
        fig,ax = plt.subplots(1,2,figsize=(16.5,16.5/4)); ax = ax.flatten()
        for panel, core_idx in zip(ax, (i, i+1)):
            # with an odd number of cores the last figure has one empty panel
            if core_idx < len(pred_images):
                panel.matshow(np.hstack([pred_images[core_idx],annot_images[core_idx]]))
            panel.set_axis_off()
        fig.tight_layout()

In [ ]:
# Wall-clock seconds accumulated in `time_total` by the inference loops that were run.
print(f"TIME TOTAL FOR ALL IMAGES: {time_total}")