In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
from deepsurvloss import NegativeLogLikelihood, NegativeLogLikelihoodWithRegular
from FusionE2EModel_crossattention import CrossMRIModel,ImageFusionModel,AttentionFusion
from tqdm import tqdm
from utils import c_index
import random
import torch.nn as nn 

In [2]:
class MySampler(torch.utils.data.Sampler):
    """Sampler that interleaves each positive sample with a run of negatives.

    A sample is "positive" when its status (item ``[1]`` of the dataset
    tuple, or the ``status`` column of ``data_source.data_df`` when that
    attribute exists) equals 1. Every pass shuffles positives and negatives
    independently and emits ``pos, neg, neg, ...`` groups, with up to
    ``neg_expend`` negatives following each positive. Negatives that do not
    fit into any group are skipped for that pass.

    Args:
        data_source: indexable, sized dataset.
        neg_expend: maximum number of negatives emitted after each positive.

    Note:
        Labels are read once, on first use, and cached; the sampler assumes
        the dataset's labels do not change afterwards.
    """

    def __init__(self, data_source, neg_expend):
        self.data_source = data_source
        self.neg_expend = neg_expend
        self._positive_flags = None  # filled lazily by _flags()

    def _flags(self):
        """Return (and cache) one boolean per sample: True when status == 1."""
        if self._positive_flags is None:
            frame = getattr(self.data_source, 'data_df', None)
            if frame is not None and 'status' in getattr(frame, 'columns', ()):
                # Cheap path: read the labels from the table instead of
                # materialising every sample just to inspect its status.
                self._positive_flags = [bool(s == 1) for s in frame['status'].tolist()]
            else:
                self._positive_flags = [
                    bool(self.data_source[i][1] == 1)
                    for i in range(len(self.data_source))
                ]
        return self._positive_flags

    def __iter__(self):
        flags = self._flags()
        indices_pos = [i for i, is_pos in enumerate(flags) if is_pos]
        indices_neg = [i for i, is_pos in enumerate(flags) if not is_pos]
        # Same shuffle order as before (positives first) so seeded runs of
        # the `random` module keep producing the same sequence.
        random.shuffle(indices_pos)
        random.shuffle(indices_neg)

        k = self.neg_expend
        indices = []
        for n, pos_idx in enumerate(indices_pos):
            indices.append(pos_idx)
            indices.extend(indices_neg[n * k:(n + 1) * k])
        return iter(indices)

    def __len__(self):
        # Must equal the number of indices __iter__ yields; DataLoader
        # derives its own length (number of batches) from this value.
        flags = self._flags()
        n_pos = sum(flags)
        n_neg = len(flags) - n_pos
        return n_pos + min(n_neg, n_pos * self.neg_expend)

In [3]:
# Radiomics table: the columns at positions 1 and 2 are dropped, everything
# else (including 'ID', which the dataset uses as lookup key) is kept.
rad_features = pd.read_csv('data/rad_features.csv')
rad_features = rad_features.drop(columns=rad_features.columns[[1, 2]])

# Clinical table: keep the lookup key plus the three covariates used here.
# .copy() makes this an independent frame, so the column assignment below
# does not write into a slice of the frame returned by read_csv.
clinical_features = pd.read_csv('data/Clifeatures.csv')
clinical_features = clinical_features[['ID', 'VIRADS', 'Age', 'Recurrent tumor']].copy()

# Rescale Age to [0, 1].
# NOTE(review): the scaler is fitted on every row of the clinical table; if
# that table also contains validation/external patients, their ages influence
# the scaling — confirm this is intended.
scaler = MinMaxScaler()
clinical_features['Age'] = scaler.fit_transform(clinical_features[['Age']])

In [4]:
class MyDataset(Dataset):
    """Per-patient dataset: two MRI volumes plus radiomics and clinical features.

    Each item is the tuple
    ``(ID, status, time, DWI tensor, T2 tensor, radiomics tensor, clinical tensor)``.
    The four tensors are float32 and placed on ``device``.

    Args:
        data_df: table with at least the columns ``ID``, ``status`` and ``time``.
        rad_df: radiomics table with an ``ID`` column; defaults to the
            notebook-level ``rad_features``.
        cli_df: clinical table with an ``ID`` column; defaults to the
            notebook-level ``clinical_features``.
        dwi_dir: directory holding ``DWI<ID>.npy`` files.
        t2_dir: directory holding ``T2<ID>.npy`` files.
        device: target device for the returned tensors; defaults to CUDA
            when available, otherwise CPU.

    Raises:
        KeyError: from ``__getitem__`` when an ID does not match exactly one
            row of the radiomics or clinical table.
    """

    def __init__(self, data_df, rad_df=None, cli_df=None,
                 dwi_dir='../npy_dwi_normed_224x224x12_ai',
                 t2_dir='../npy_t2_normed_224x224x12_ai',
                 device=None):
        super().__init__()
        self.data_df = data_df
        # Fall back to the notebook globals so `MyDataset(df)` keeps working.
        self.rad_features = rad_features if rad_df is None else rad_df
        self.clif = clinical_features if cli_df is None else cli_df
        self.dwi_dir = dwi_dir
        self.t2_dir = t2_dir
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    @staticmethod
    def _feature_row(table, pid, table_name):
        """Return the single feature row for ``pid`` as a 1-D float tensor."""
        rows = table[table['ID'] == pid].drop('ID', axis=1)
        if len(rows) != 1:
            # Zero or several matches would silently change the tensor shape
            # and only fail later, during batch collation or in the model.
            raise KeyError(
                f'expected exactly one {table_name} row for ID {pid}, found {len(rows)}'
            )
        return torch.tensor(rows.values).float().squeeze(0)

    def __getitem__(self, idx):
        # Column-wise access keeps each column's own dtype; a row-wise
        # `iloc[idx]` upcasts to a common dtype and can turn an integer ID
        # into a float, which would break the file names built below.
        pids = self.data_df['ID'].iloc[idx]
        status = self.data_df['status'].iloc[idx]
        times = self.data_df['time'].iloc[idx]

        # Validate the tabular lookups before touching the disk.
        rad_fe = self._feature_row(self.rad_features, pids, 'radiomics')
        cli_fe = self._feature_row(self.clif, pids, 'clinical')

        DWI_ImgTensor = torch.tensor(np.load(os.path.join(self.dwi_dir, f'DWI{pids}.npy'))).float()
        T2_ImgTensor = torch.tensor(np.load(os.path.join(self.t2_dir, f'T2{pids}.npy'))).float()

        return (
            pids,
            status,
            times,
            DWI_ImgTensor.to(self.device),
            T2_ImgTensor.to(self.device),
            rad_fe.to(self.device),
            cli_fe.to(self.device),
        )

    def __len__(self):
        return len(self.data_df)

In [None]:
# ---- Output folders and data splits ----------------------------------------
fold_dir = 'resnet_model/final'
os.makedirs(os.path.join(fold_dir, 'saved_models'), exist_ok=True)
os.makedirs(os.path.join(fold_dir, 'Figures'), exist_ok=True)
all_df = pd.read_csv('data/label_group.csv')
train_df = all_df[all_df['group'] == 'train'].reset_index(drop=True)
val_df = all_df[all_df['group'] == 'val'].reset_index(drop=True)

# Training batches come from the positive/negative interleaving sampler;
# validation is read in table order.
neg_expend = 3
traindataset = MyDataset(train_df)
trainloader = DataLoader(traindataset, batch_size=8, sampler=MySampler(traindataset, neg_expend=neg_expend))
valdataset = MyDataset(val_df)
valloader = DataLoader(valdataset, batch_size=8, shuffle=False)

# ---- Hyper-parameters -------------------------------------------------------
epochs = 100
max_no_improve = 80   # epochs without a better validation c-index before stopping
weight_decay = 0.05
lr = 0.00001

# ---- Bookkeeping ------------------------------------------------------------
best_val_cindex = 0.0
no_improve_count = 0
train_cindex_list = []
val_cindex_list = []
train_loss_list = []
val_loss_list = []

# ---- Model, loss, optimiser -------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = CrossMRIModel(device=device, img_encoder='resnet50').to(device)
# device.type is 'cuda' on a GPU machine (the value previously hard-coded) and
# 'cpu' otherwise. The first argument is presumably a device spec — TODO
# confirm against deepsurvloss.
criterion = NegativeLogLikelihoodWithRegular(device.type, model, weight_decay=weight_decay)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def _batch_values(field):
    """Return one collated batch field as a flat list of Python scalars/strings."""
    return field.tolist() if torch.is_tensor(field) else list(field)


early_stopped = False
for epoch in range(epochs):
    # ---------------- training pass ----------------
    model.train()
    train_ids = []
    train_status = []
    train_times = []
    train_outputs = []
    all_loss_train = 0
    n_train_batches = 0

    for batch in trainloader:
        train_ids.extend(_batch_values(batch[0]))
        train_status.extend(_batch_values(batch[1]))
        train_times.extend(_batch_values(batch[2]))
        optimizer.zero_grad()
        # .to(device) is a no-op when the dataset already delivers tensors on
        # the right device, and keeps the loop working when it does not.
        output, x = model(batch[3].to(device), batch[4].to(device),
                          batch[5].to(device), batch[6].to(device))
        output = output.squeeze(1)
        train_outputs.extend(output.tolist())
        loss = criterion(output.to(device), batch[2].to(device), batch[1].to(device))
        loss.backward()
        optimizer.step()
        all_loss_train += loss.item()
        n_train_batches += 1

    train_outputs = np.array(train_outputs)
    train_ids = np.array(train_ids)
    train_times = np.array(train_times)
    train_status = np.array(train_status)
    train_cindex = c_index(-train_outputs, train_times, train_status)
    # Average over the batches actually processed: with a custom sampler the
    # loader's reported length need not match the number of batches yielded.
    avg_train_loss = all_loss_train / max(n_train_batches, 1)

    # ---------------- validation pass ----------------
    model.eval()
    with torch.no_grad():
        val_ids = []
        val_status = []
        val_times = []
        val_outputs = []
        all_loss_val = 0
        n_val_batches = 0

        for valbatch in valloader:
            val_ids.extend(_batch_values(valbatch[0]))
            val_status.extend(_batch_values(valbatch[1]))
            val_times.extend(_batch_values(valbatch[2]))
            output, x = model(valbatch[3].to(device), valbatch[4].to(device),
                              valbatch[5].to(device), valbatch[6].to(device))
            output = output.squeeze(1)
            val_outputs.extend(output.tolist())
            loss = criterion(output.to(device), valbatch[2].to(device), valbatch[1].to(device))
            all_loss_val += loss.item()
            n_val_batches += 1

        val_outputs = np.array(val_outputs)
        val_ids = np.array(val_ids)
        val_times = np.array(val_times)
        val_status = np.array(val_status)
        val_cindex = c_index(-val_outputs, val_times, val_status)
        avg_val_loss = all_loss_val / max(n_val_batches, 1)

    train_cindex_list.append(train_cindex)
    val_cindex_list.append(val_cindex)
    train_loss_list.append(avg_train_loss)
    val_loss_list.append(avg_val_loss)

    print(f'Epoch {epoch+1}/{epochs}: '
          f'train cindex:{train_cindex:.4f}, train loss:{avg_train_loss:.4f}; '
          f'val cindex:{val_cindex:.4f}, val loss:{avg_val_loss:.4f}')

    # ---------------- early-stopping bookkeeping ----------------
    if val_cindex > best_val_cindex:
        best_val_cindex = val_cindex
        no_improve_count = 0
    else:
        no_improve_count += 1

    if no_improve_count >= max_no_improve:
        print(f"Early stopping triggered after {epoch+1} epochs with no improvement")
        early_stopped = True
        break

# The weights saved here are those of the LAST epoch run, not of the epoch
# that reached best_val_cindex.
# NOTE(review): confirm that the last-epoch model is the one wanted downstream.
save_path = os.path.join(fold_dir, 'saved_models', 'final_model.pth')
torch.save(model.state_dict(), save_path)
if early_stopped:
    print(f"Saved final model (early stopped) at: {save_path}")
else:
    print(f"Saved final model (completed all epochs) at: {save_path}")


In [None]:
# Reload the checkpoint written by the training cell into the in-memory model.
saved_models_dir = os.path.join(fold_dir, 'saved_models')
final_model_path = os.path.join(saved_models_dir, 'final_model.pth')

if os.path.exists(final_model_path):
    # map_location remaps the stored tensors onto the device the notebook is
    # using now, so a checkpoint written on a GPU still loads on a CPU-only
    # machine.
    model.load_state_dict(torch.load(final_model_path, map_location=device))
    print(f"final_model.pth")
else:
    # Make the skip visible instead of silently continuing.
    print(f"Checkpoint not found at {final_model_path}; keeping the weights currently in memory")

def predict_risk(data_loader, group_name, net=None):
    """Run the model over a loader and tabulate one risk score per sample.

    Args:
        data_loader: iterable of batches laid out as
            ``(ids, status, time, dwi, t2, rad, cli)``; items 3-6 are passed
            to the model in that order.
        group_name: label written to the ``group`` column of every row.
        net: model to evaluate; defaults to the notebook-level ``model``.
            It is switched to eval mode and left in that mode.

    Returns:
        pandas.DataFrame with the columns ``ID``, ``status``, ``time``,
        ``risk`` and ``group``, one row per sample in loader order.
    """
    net = model if net is None else net
    net.eval()

    def _as_list(field):
        # Numeric fields are collated into tensors; convert them to plain
        # Python scalars so the resulting columns get a normal dtype instead
        # of holding 0-d tensor objects. String fields arrive as sequences.
        return field.tolist() if torch.is_tensor(field) else list(field)

    all_ids = []
    all_status = []
    all_times = []
    all_risks = []

    with torch.no_grad():
        for batch in data_loader:
            outputs, _ = net(batch[3], batch[4], batch[5], batch[6])
            risks = outputs.squeeze(1).cpu().numpy()

            # Status and time are taken from the batch itself, so every row
            # is guaranteed to describe the sample that produced the risk;
            # no lookup by ID against another table is required.
            all_ids.extend(_as_list(batch[0]))
            all_status.extend(_as_list(batch[1]))
            all_times.extend(_as_list(batch[2]))
            all_risks.extend(risks)

    return pd.DataFrame(
        {
            'ID': all_ids,
            'status': all_status,
            'time': all_times,
            'risk': all_risks,
            'group': group_name,
        },
        columns=['ID', 'status', 'time', 'risk', 'group'],
    )

In [None]:
# Cohort tables for the two additional groups stored in the label file.
multi_df = all_df[all_df['group'] == 'multi'].reset_index(drop=True)
nac_df = all_df[all_df['group'] == 'nac'].reset_index(drop=True)

# Inference datasets/loaders: table order, batches of 32, no custom sampler.
# `traindataset` / `trainloader` are rebound here on purpose.
traindataset, multidataset, nacdataset = [
    MyDataset(frame) for frame in (train_df, multi_df, nac_df)
]
trainloader, multiloader, nacloader = [
    DataLoader(dataset, batch_size=32, shuffle=False)
    for dataset in (traindataset, multidataset, nacdataset)
]

# One risk table per cohort; the validation loader is reused as it is.
train_results, val_results, multi_results, nac_results = [
    predict_risk(loader, tag)
    for loader, tag in (
        (trainloader, 'train'),
        (valloader, 'val'),
        (multiloader, 'multi'),
        (nacloader, 'nac'),
    )
]

# Stack the cohorts and write them to a single CSV.
all_results = pd.concat(
    [train_results, val_results, multi_results, nac_results],
    ignore_index=True,
)
output_csv = os.path.join(fold_dir, 'risk_predictions.csv')
all_results.to_csv(output_csv, index=False)

# Concordance index per cohort; the risk is negated before it is handed to
# c_index, exactly as in the training loop.
train_cindex, val_cindex, multi_cindex, nac_cindex = [
    c_index(-res['risk'].values, res['time'].values, res['status'].values)
    for res in (train_results, val_results, multi_results, nac_results)
]