In [None]:
# --- Update the outdated pyg files
import torch
import torch_geometric
import os
from tqdm import tqdm

def bump(g):
    """Rebuild ``g`` as a ``torch_geometric`` ``Data`` object.

    The attributes stored in the instance ``__dict__`` of ``g`` are handed to
    ``Data.from_dict``, and the object it constructs is returned. ``g`` itself
    is not modified.
    """
    data_cls = torch_geometric.data.data.Data
    stored_fields = g.__dict__
    return data_cls.from_dict(stored_fields)

# Source directory with the graphs to convert and the directory that
# receives the converted copies (same file names).
old_dir = 'data/TCGA_GBMLGG/all_st_cpc_old'
new_dir = 'data/TCGA_GBMLGG/all_st_cpc'

# torch.save does not create missing parent directories, so make sure the
# destination exists before the first file is written.
os.makedirs(new_dir, exist_ok=True)

for file in tqdm(os.listdir(old_dir)):
    # Load the stored object, rebuild it with bump(), and write the result
    # under the same file name in the new directory.
    old = torch.load(os.path.join(old_dir, file))
    new = bump(old)
    torch.save(new, os.path.join(new_dir, file))

In [None]:
# --- Generate VGG features for all patches
import os
import torch
import pickle
from PIL import Image
from torchvision import transforms
from src.networks import get_vgg19
import argparse
from tqdm import tqdm

# Pick the best backend that is actually available instead of assuming
# Apple's MPS: a hard-coded 'mps' device makes every later `.to(device)`
# fail on machines without it.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Change this depending on task
load_path = 'checkpoints/surv/path_instance'
img_dir = 'data/path/patch'
vgg_file = 'data/path/vgg_features_surv.pkl'

# Resume from an earlier run when the feature file already exists; otherwise
# start with an empty mapping.
# NOTE: pickle.load can execute arbitrary code, so `vgg_file` must only ever
# point at a file produced by this notebook.
if os.path.exists(vgg_file):
    with open(vgg_file, 'rb') as f:
        vgg_dict = pickle.load(f)
else:
    vgg_dict = {}

# Option namespace handed to get_vgg19 when the model is built.
opt = argparse.Namespace(
    mil="instance",
    attn_pool=0,
)
# The preprocessing pipeline is the same for every patch and every
# checkpoint, so build it once instead of once per image.
tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

for ckpt in os.listdir(load_path):
    ckpt_dict = {}
    if 'path' not in ckpt:
        # Only checkpoint files whose name contains 'path' are processed.
        continue

    # The split id is the second '_'-separated token, without the extension.
    split = ckpt.split('_')[1].split('.')[0]
    if split in vgg_dict:
        # Features for this split were already loaded from vgg_file.
        print(f"Skipping {split}")
        continue

    model = get_vgg19(opt)
    model_ckpt = torch.load(os.path.join(load_path, ckpt), map_location=device)
    model.load_state_dict(model_ckpt['model'])
    model.to(device)
    model.eval()

    # Pure inference: disabling autograd avoids building a graph for every
    # patch, which only costs memory and time here.
    with torch.no_grad():
        for fname in tqdm(os.listdir(img_dir)):
            # The context manager closes the image file as soon as the RGB
            # copy has been converted to a tensor.
            with Image.open(os.path.join(img_dir, fname)) as img:
                x_path = torch.unsqueeze(tf(img.convert('RGB')), dim=0)
            features, _, _ = model(x_path=x_path.to(device))
            ckpt_dict[fname] = features.cpu().detach().numpy()

    vgg_dict[split] = ckpt_dict
    # Every file in img_dir must have produced exactly one feature entry.
    assert len(vgg_dict[split]) == len(os.listdir(img_dir))

with open(vgg_file, 'wb') as f:  # Rename depending on task
    pickle.dump(vgg_dict, f)

In [None]:
# Optimizing CoxLoss: over 400 times faster on CPU for a batch size of 64
# Performance gains scale with batch size

import torch
import numpy as np

# Paper's implementation
def CoxLoss_old(survtime, censor, hazard_pred, device):
    """Reference Cox loss using an explicit double loop (kept as the baseline).

    Builds an ``n x n`` indicator matrix whose entry ``[i, j]`` is 1 when
    ``survtime[j] >= survtime[i]``, then returns
    ``-mean((theta_i - log(sum_j exp(theta_j) * indicator[i, j])) * censor_i)``
    where ``theta`` is ``hazard_pred`` flattened to 1-D.

    Args:
        survtime: per-sample times; indexed element by element.
        censor: per-sample weights multiplied into each term before the mean
            (presumably 1 = event observed; confirm against the data loader).
        hazard_pred: predicted values, flattened with ``reshape(-1)``.
        device: device the indicator matrix is moved to.

    Returns:
        The loss as a zero-dimensional tensor.
    """
    n_samples = len(survtime)

    # Element-wise Python loop: this is the slow path being benchmarked.
    risk_set = np.zeros([n_samples, n_samples], dtype=int)
    for row in range(n_samples):
        for col in range(n_samples):
            risk_set[row, col] = survtime[col] >= survtime[row]
    risk_set = torch.FloatTensor(risk_set).to(device)

    theta = hazard_pred.reshape(-1)
    # Sum exp(theta_j) over the entries flagged in each row, then take the log.
    log_risk = torch.log(torch.sum(torch.exp(theta) * risk_set, dim=1))
    return -torch.mean((theta - log_risk) * censor)


# Optimized implementation
def cox_loss(
    survtime: torch.Tensor, event: torch.Tensor, hazard_pred: torch.Tensor
) -> torch.Tensor:
    """Vectorised Cox loss over one batch.

    Source: https://github.com/traversc/cox-nnet

    Computes ``-mean((theta_i - log(sum_j exp(theta_j) * R[i, j])) * event_i)``
    with ``R[i, j] = 1`` when ``survtime[j] >= survtime[i]``.

    All three inputs are flattened to 1-D first, so ``(N,)`` and ``(N, 1)``
    tensors give the same result. Without this, a column-shaped ``event``
    broadcasts against the per-sample terms into an ``N x N`` product and a
    column-shaped ``survtime`` yields a mis-shaped indicator matrix, both
    silently producing a wrong loss.

    Args:
        survtime: per-sample times, ``N`` elements in total.
        event: per-sample weights multiplied into each term before the mean,
            ``N`` elements in total.
        hazard_pred: predicted values, ``N`` elements in total.

    Returns:
        The loss as a zero-dimensional tensor.
    """
    # Predictions are not independent in Coxloss; calculating over batch != whole dataset
    survtime = survtime.reshape(-1)
    event = event.reshape(-1)
    # reshape (not view) so non-contiguous predictions are accepted as well.
    theta = hazard_pred.reshape(-1)

    # Broadcasting a row against a column builds the N x N indicator matrix
    # directly, without materialising a repeated copy of survtime.
    R_mat = (survtime.unsqueeze(0) >= survtime.unsqueeze(1)).int()
    exp_theta = theta.exp()
    loss_cox = -torch.mean((theta - (exp_theta * R_mat).sum(dim=1).log()) * event)
    return loss_cox

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = 64

survtime = torch.rand(batch).to(device)
censor = torch.randint(0, 2, (batch,)).to(device)
hazard_pred = torch.rand(batch).to(device)
assert CoxLoss_old(survtime, censor, hazard_pred, device) == cox_loss(survtime, censor, hazard_pred)

%timeit CoxLoss_old(survtime, censor, hazard_pred, device)
%timeit cox_loss(survtime, censor, hazard_pred)