In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from torch.optim import Adam
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import scale
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
from tqdm import tqdm, trange

from torchmtlr import MTLR, mtlr_neg_log_likelihood, mtlr_cif, mtlr_risk, mtlr_survival
from torchmtlr.utils import make_time_bins, encode_survival, reset_parameters

In [3]:
torch.manual_seed(1129)
torch.backends.deterministic = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(1129)
device = "cpu"

sns.set(context="poster", style="white")
plt.rcParams["figure.figsize"] = (10, 7)

## Some methods

In [4]:
def make_optimizer(opt_cls, model, **kwargs):
    """Creates a PyTorch optimizer for MTLR training."""
    params_dict = dict(model.named_parameters())
    weights = [v for k, v in params_dict.items() if "mtlr" not in k and "bias" not in k]
    biases = [v for k, v in params_dict.items() if "bias" in k]
    mtlr_weights = [v for k, v in params_dict.items() if "mtlr_weight" in k]
    # Don't use weight decay on the biases and MTLR parameters, which have
    # their own separate L2 regularization
    optimizer = opt_cls([
        {"params": weights},
        {"params": biases, "weight_decay": 0.},
        {"params": mtlr_weights, "weight_decay": 0.},
    ], **kwargs)
    return optimizer

In [23]:
#dataload function
def make_data(path, split="training"):
    """Load and preprocess the data."""
    try:
        df = pd.read_csv(path)
    except:
        df = path
    clinical_data = (df
                     .query("split == @split")
                     .set_index("Study ID")
                     .drop(["split"], axis=1, errors="ignore"))
    
    clinical_data = clinical_data.rename(columns={"death": "event", "survival_time": "time"})
    # Convert time to months
    clinical_data["time"] *= 12
        
    # binarize T stage as T1/2 = 0, T3/4 = 1
    clinical_data["T Stage"] = clinical_data["T Stage"].map(
        lambda x: "T1/2" if x in ["T1", "T1a", "T1b", "T2"] else "T3/4")
    
    # use more fine-grained grouping for N stage
    clinical_data["N Stage"] = clinical_data["N Stage"].map({
        "N0":  "N0",
        "N1":  "N1",
        "N2":  "N2",
        "N2a": "N2",
        "N2b": "N2",
        "N2c": "N2",
        "N3":  "N3",
        "N3a": "N3",
        "N3b": "N3"
    })
    
    clinical_data["age at dx"] = scale(clinical_data["age at dx"])
    clinical_data["Dose"] = scale(clinical_data["Dose"])
    
    clinical_data["Stage"] = clinical_data["Stage"].map(
        lambda x: "I/II" if x in ["I", "II", "IIA"] else "III/IV")
    
    clinical_data["ECOG"] = clinical_data["ECOG"].map(
        lambda x: ">0" if x > 0 else "0")
    
    clinical_data = pd.get_dummies(clinical_data,
                                   columns=["Sex",
                                            "T Stage",
                                            "N Stage",
                                            "Disease Site",
                                            "Stage",
                                            "ECOG"],
                                   drop_first=True)
    clinical_data = pd.get_dummies(clinical_data, columns=["HPV Combined"])
    return clinical_data

# training functions
def train_mtlr(X, y, model, time_bins,
               num_epochs=1000, lr=.01, weight_decay=0.,
               C1=10., batch_size=None,
               verbose=True, device="cpu"):
    """Trains the MTLR model using minibatch gradient descent.
    
    Parameters
    ----------
    model : torch.nn.Module
        MTLR model to train.
    data_train : pd.DataFrame
        The training dataset. Must contain a `time` column with the
        event time for each sample and an `event` column containing
        the event indicator.
    num_epochs : int
        Number of training epochs.
    lr : float
        The learning rate.
    weight_decay : float
        Weight decay strength for all parameters *except* the MTLR
        weights. Only used for Deep MTLR training.
    C1 : float
        L2 regularization (weight decay) strenght for MTLR parameters.
    batch_size : int
        The batch size.
    verbose : bool
        Whether to display training progress.
    device : str
        Device name or ID to use for training.
        
    Returns
    -------
    torch.nn.Module
        The trained model.
    """
    optimizer = make_optimizer(Adam, model, lr=lr, weight_decay=weight_decay)
    reset_parameters(model)
    print(X.shape, y.shape)
    model = model.to(device)
    model.train()
    train_loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)
    
    pbar =  trange(num_epochs, disable=not verbose)
    for i in pbar:
        for xi, yi in train_loader:
            xi, yi = xi.to(device), yi.to(device)
            y_pred = model(xi)
            loss = mtlr_neg_log_likelihood(y_pred, yi, model, C1, average=True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f"[epoch {i+1: 4}/{num_epochs}]")
        pbar.set_postfix_str(f"loss = {loss.item():.4f}")
    model.eval()
    return model

In [6]:
def multiple_events(row):
    event        = row["event"]
    cancer_death = row["cancer_death"]

    if event==0:
        return 0
    elif cancer_death==0:
        return 1
    elif cancer_death==1:
        return 2
    else:
        raise UhOh

In [38]:
def validate(model, X, time, other, cancer):
    """Returns inference of model on data X
    params
    ------
    model
        PyTorch model being tested
    X
        data to test moddel on
    time, other, cancer
        true labels for time, other, and cancer
    
    
    
    """
    pred_prob       = model(X)
    two_year_bin    = np.digitize(2, time_bins)
    survival_event  = mtlr_survival(pred_prob[:,:29]).detach().numpy()
    survival_cancer = mtlr_survival(pred_prob[:,29:]).detach().numpy()
    pred_event      = 1 - survival_event[:, two_year_bin]
    pred_cancer     = 1 - survival_cancer[:, two_year_bin]
    
    roc_auc_event   = roc_auc_score(other, pred_event)
    roc_auc_cancer  = roc_auc_score(cancer, pred_cancer)

    pred_risk = mtlr_risk(pred_prob, 2).detach().numpy()
        
    ci_event  = concordance_index(time, -pred_risk[:, 0], event_observed=other)
    ci_cancer = concordance_index(time, -pred_risk[:, 1], event_observed=cancer)
    
    return roc_auc_event, roc_auc_cancer, ci_event, ci_cancer

## load/process data

In [8]:
df        = make_data("/cluster/projects/radiomics/RADCURE-challenge/data/training/clinical.csv")
time_bins = make_time_bins(df["time"], event=df["event"])
multi_events = df.apply(lambda x: multiple_events(x), axis=1)

y = encode_survival(df["time"], multi_events, time_bins)
X = torch.tensor(df.drop(["time", "event", "target_binary", "cancer_death"], axis=1).values, dtype=torch.float)


In [9]:
full_indices = range(len(df))
full_targets = df["target_binary"]
train_indices, val_indices = train_test_split(full_indices, test_size=0.25, stratify=full_targets)

In [10]:
X_train, X_val = X[train_indices], X[val_indices]
y_train, y_val = y[train_indices], y[val_indices]

In [11]:
val_data = df.iloc[val_indices]

true_time   = val_data["time"]
true_event  = val_data["event"]
true_cancer = val_data["cancer_death"]

## fit MTLR

In [149]:
# fit MTLR model 
mtlr = MTLR(in_features=20, num_time_bins=29, num_events=2)            
mtlr = train_mtlr(X_train, y_train, mtlr, time_bins, num_epochs=500, 
                  lr=.0002, batch_size=128, verbose=True, device=device)

[epoch    8/500]:   1%|          | 4/500 [00:00<00:13, 37.96it/s, loss = 132.0287]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:12<00:00, 40.34it/s, loss = 1.9435] 


In [150]:
validate(mtlr, X_val, true_time, true_event, true_cancer)

(0.7295530934893564,
 0.7305797650625236,
 0.6536868196630382,
 0.7182150264681298)

## test model

In [24]:
df1 = pd.read_csv("/cluster/projects/radiomics/RADCURE-challenge/clinical_cancer_death.csv")
df2 = pd.read_csv("/cluster/projects/radiomics/RADCURE-challenge/data/clinical_test.csv")
dff = make_data(pd.merge(df1, df2[["Study ID", "ECOG"]], how='inner', on='Study ID'), split="test").drop("EGFRI", axis=1)

In [26]:
X_test = torch.tensor(dff.drop(["time", "event", "target_binary", "cancer_death"], axis=1).values, dtype=torch.float)

true_time_test   = dff["time"]
true_event_test  = dff["event"]
true_cancer_test = dff["cancer_death"]

In [154]:
validate(mtlr, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.6467351430667645,
 0.7082955095618493,
 0.6211315633027111,
 0.7054066472660439)

# MTLR + hiddenlayers

In [155]:
mtlr1 = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr1 = train_mtlr(X_train, y_train, mtlr1, time_bins, num_epochs=500, 
                   lr=.0002, batch_size=128, verbose=True, device=device)

validate(mtlr1, X_val, true_time, true_event, true_cancer)

[epoch    6/500]:   1%|          | 3/500 [00:00<00:17, 28.93it/s, loss = 354.2757]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:15<00:00, 32.96it/s, loss = 1.6268] 


(0.747904477364206, 0.746321207528104, 0.7070217707827775, 0.7516794105810606)

In [163]:
mtlr2 = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    nn.Linear(532, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr2 = train_mtlr(X_train, y_train, mtlr2, time_bins, num_epochs=500, 
                   lr=.0002, batch_size=128, verbose=True, device=device)

validate(mtlr2, X_val, true_time, true_event, true_cancer)

[epoch    4/500]:   0%|          | 2/500 [00:00<00:25, 19.67it/s, loss = 401.8694]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:38<00:00, 13.15it/s, loss = 2.1577] 


(0.7480185899829868, 0.752273588480485, 0.7160351521874784, 0.737098721480977)

## test deep MTLR

In [157]:
validate(mtlr1, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7507948153582782,
 0.7715822440087146,
 0.7164642011867506,
 0.7748812988206464)

In [164]:
validate(mtlr2, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7502358243370716,
 0.7674670176712659,
 0.7297842985496518,
 0.7476183182723235)

# fuck me

In [87]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    6/1000]:   0%|          | 3/1000 [00:00<00:33, 29.65it/s, loss = 430.5145]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:33<00:00, 30.16it/s, loss = 1.8031]


(0.761211845951391, 0.7718318809005082, 0.7206873731085461, 0.776948996783581)

In [118]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    3/1000]:   0%|          | 2/1000 [00:00<01:05, 15.33it/s, loss = 464.4609]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:30<00:00, 32.27it/s, loss = 1.7256]


(0.7649035158206104,
 0.7766279351246671,
 0.7271539298493567,
 0.7790779598713432)

In [119]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    6/1000]:   0%|          | 4/1000 [00:00<00:32, 30.88it/s, loss = 422.3957]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:31<00:00, 31.29it/s, loss = 2.3196]


(0.7649268071131609,
 0.7703038005325586,
 0.7253087458357169,
 0.7750497779139225)

In [120]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    nn.Linear(532, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    4/1000]:   0%|          | 2/1000 [00:00<00:52, 19.07it/s, loss = 453.1712]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [01:01<00:00, 16.14it/s, loss = 1.6403]


(0.7330934330200654,
 0.7645848462841927,
 0.7245628203833945,
 0.7376474192066166)

In [121]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    nn.Linear(532, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    4/1000]:   0%|          | 2/1000 [00:00<00:50, 19.87it/s, loss = 448.1322]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [01:04<00:00, 15.58it/s, loss = 1.8524]


(0.7412453854126634, 0.771536855482934, 0.7329867303787956, 0.7451064481543881)