In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.functional import softmax
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from sksurv.linear_model import CoxPHSurvivalAnalysis
from losses import CensoredMSELoss
from utils import compute_time_to_event
import numpy as np
import logging
import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf
import os
import logging
from models import TimeToDeath3DCNN
from sklearn.model_selection import train_test_split
from utils import LungCancerDataset

In [None]:
# Compose the Hydra configuration named "experiment_config.yaml" (looked up under
# config_path ".") and print it as YAML so the settings used by the cells below
# are visible in the cell output.
with initialize(config_path=".", version_base=None):
    cfg = compose(config_name="experiment_config.yaml")
    print(OmegaConf.to_yaml(cfg))

In [None]:
def train_model(model, train_dataset, batch_size, criterion, optimizer, writer, device, num_epochs, gamma, logger):
    """Train ``model`` on ``train_dataset`` and log the loss of every epoch.

    For each batch the scans are passed through ``model``, which returns an
    embedding and a probability threshold. The embedding is concatenated with the
    clinical variables, a survival estimator is fitted on those features through
    ``model.fit_survival_estimator``, and the survival functions it predicts are
    converted to event times by ``compute_time_to_event``. ``criterion`` is then
    called with those times, the observed events/times and ``gamma``.

    Args:
        model: Network to train. Must be callable on a batch of scans and expose
            ``to``, ``train`` and ``fit_survival_estimator``.
        train_dataset: Dataset yielding ``(scans, events, times, clinical_vars)``.
        batch_size: Number of samples per batch.
        criterion: Loss callable invoked as
            ``criterion(survival_times, events, times, gamma)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        writer: Object with ``add_scalar``; receives the average loss per epoch
            under the tag ``"Loss/Train"``.
        device: Device the model and the batches are moved to. When ``None``,
            CUDA is used if available, otherwise the CPU.
        num_epochs: Number of passes over the dataset.
        gamma: Extra argument forwarded unchanged to ``criterion``.
        logger: Logger used for the per-epoch summary line.

    Returns:
        The same ``model`` object after training.
    """
    # Honour the device chosen by the caller; only auto-detect when none is given.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        for scans, events, times, clinical_vars in dataloader:
            scans = scans.to(device)
            events = events.to(device)
            times = times.to(device)
            clinical_vars = clinical_vars.to(device)

            optimizer.zero_grad()
            embedding, proba_thresh = model(scans)

            # NumPy conversion only works on CPU tensors, so copy back to the host
            # first; calling .numpy() directly fails when training on CUDA.
            # NOTE(review): the embedding is detached here, so gradients can only
            # reach the network through `proba_thresh` -- confirm this is intended.
            all_features = np.concatenate(
                (embedding.detach().cpu().numpy(), clinical_vars.detach().cpu().numpy()),
                axis=1,
            )
            # NOTE(review): `events` and `times` are passed as tensors on `device`;
            # verify that fit_survival_estimator accepts them in that form.
            survival_estimator = model.fit_survival_estimator(all_features, events, times)
            surv_funcs = survival_estimator.predict_survival_function(all_features)
            # NOTE(review): the keyword is spelled `thershold` exactly as in the
            # original call -- confirm it matches compute_time_to_event's signature.
            survival_times = compute_time_to_event(surv_funcs, thershold = proba_thresh)

            loss = criterion(survival_times, events, times, gamma)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # The print shows the summed loss; the logger and the writer get the mean
        # over batches.
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")
        avg_loss = epoch_loss / len(dataloader)
        logger.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")
        writer.add_scalar("Loss/Train", avg_loss, epoch + 1)

    return model

# TODO: decide whether the commented-out validation routine below should instead be a held-out test evaluation.

# def validate_model(model, dataloader, criterion, writer, device, epoch):
#     model.eval()
#     with torch.no_grad():
#         total_loss = 0
#         for batch in dataloader:
#             scans, events, times = batch
#             scans, events, times = scans.to(device), events.to(device), times.to(device)

#             predictions = model(scans).squeeze()
#             loss = criterion(predictions, events, times)
#             total_loss += loss.item()

#         print(f"Validation Loss: {total_loss:.4f}")
#         avg_loss = total_loss / len(dataloader)
#         logger.info(f"Validation Loss: {avg_loss:.4f}")
#         writer.add_scalar("Loss/Validation", avg_loss, epoch)
#         return avg_loss

In [None]:
def _count_subjects(clinical_vars_path):
    """Return the number of rows in the clinical-variables table at ``clinical_vars_path``.

    The reader is chosen from the file extension; anything that is not
    recognised is read as CSV.
    NOTE(review): the on-disk format is not visible from this notebook --
    confirm it matches what LungCancerDataset reads.
    """
    suffix = os.path.splitext(str(clinical_vars_path))[1].lower()
    if suffix in (".pkl", ".pickle"):
        # Unpickling can execute arbitrary code: only use with trusted files.
        table = pd.read_pickle(clinical_vars_path)
    elif suffix in (".xls", ".xlsx"):
        table = pd.read_excel(clinical_vars_path)
    elif suffix == ".parquet":
        table = pd.read_parquet(clinical_vars_path)
    else:
        table = pd.read_csv(clinical_vars_path)
    return len(table)


def main(cfg=cfg):
    """Build the model, the training split and the logging objects, then train.

    All hyperparameters and paths are read from ``cfg``; make sure the paths it
    contains are correct before running.

    Args:
        cfg: Configuration object providing ``results_dir``, ``experiment_name``,
            ``batch_size``, ``learning_rate``, ``num_epochs``, ``in_channel``,
            ``out_channels_conv1``/``2``/``3``, ``kernel_conv``, ``kernel_pool``,
            ``dropout``, ``scans_path``, ``clinical_vars_path`` and ``gamma``.
            Defaults to the configuration composed earlier in the notebook.

    Returns:
        The trained model returned by ``train_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    results_dir = os.path.join(cfg.results_dir, cfg.experiment_name)
    os.makedirs(results_dir, exist_ok=True)

    # Seed both NumPy (train/test split) and PyTorch (weight init, shuffling)
    # so that a rerun of the notebook is reproducible.
    seed = 42
    random_state = np.random.RandomState(seed=seed)
    torch.manual_seed(seed)

    batch_size = cfg.batch_size
    learning_rate = cfg.learning_rate
    num_epochs = cfg.num_epochs
    # NOTE(review): the key is `in_channel` (singular) -- confirm against the YAML.
    in_channels = cfg.in_channel
    out_channels_conv1 = cfg.out_channels_conv1
    out_channels_conv2 = cfg.out_channels_conv2
    out_channels_conv3 = cfg.out_channels_conv3
    kernel_conv = cfg.kernel_conv
    kernel_pool = cfg.kernel_pool
    dropout = cfg.dropout
    scans_path = cfg.scans_path
    clinical_vars_path = cfg.clinical_vars_path
    gamma = cfg.gamma

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logger = logging.getLogger(__name__)

    model = TimeToDeath3DCNN(
        in_channels,
        out_channels_conv1,
        out_channels_conv2,
        out_channels_conv3,
        kernel_conv,
        kernel_pool,
        dropout)

    criterion = CensoredMSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # pandas has no top-level `load`; read the table with an actual reader.
    n_sub = _count_subjects(clinical_vars_path)
    train_indices, test_indices = train_test_split(np.arange(n_sub), test_size=0.2, random_state=random_state)
    # NOTE(review): `scans_path` is passed for both of the first two positional
    # arguments -- confirm against LungCancerDataset's signature.
    train_dataset = LungCancerDataset(scans_path, scans_path, clinical_vars_path, train_indices, test_indices, return_train = True)

    writer = SummaryWriter(results_dir)
    try:
        return train_model(model, train_dataset, batch_size, criterion, optimizer, writer, device, num_epochs, gamma, logger)
    finally:
        # Flush pending events and release the event file even if training fails.
        writer.close()
    
    

