In [1]:
print ("hi")

hi


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
import torch

In [4]:
#import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from torch.optim import Adam

In [5]:
from pytorch_lightning import seed_everything

from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import scale
from sklearn.metrics import roc_auc_score

from tqdm import tqdm, trange
from math import floor

In [6]:
from lifelines.utils import concordance_index
from torchmtlr import MTLR, mtlr_neg_log_likelihood, mtlr_cif, mtlr_risk, mtlr_survival
from torchmtlr.utils import make_time_bins, encode_survival, reset_parameters

# ezpz

## Some methods

In [7]:
def make_optimizer(opt_cls, model, **kwargs):
    """Creates a PyTorch optimizer for MTLR training."""
    params_dict = dict(model.named_parameters())
    weights = [v for k, v in params_dict.items() if "mtlr" not in k and "bias" not in k]
    biases = [v for k, v in params_dict.items() if "bias" in k]
    mtlr_weights = [v for k, v in params_dict.items() if "mtlr_weight" in k]
    # Don't use weight decay on the biases and MTLR parameters, which have
    # their own separate L2 regularization
    optimizer = opt_cls([
        {"params": weights},
        {"params": biases, "weight_decay": 0.},
        {"params": mtlr_weights, "weight_decay": 0.},
    ], **kwargs)
    return optimizer

In [8]:
def make_data(path, split="training"):
    """Load and preprocess the data."""
    try:
        df = pd.read_csv(path)
    except:
        df = path

    clinical_data = (df
                     .query("split == @split")
                     .set_index("Study ID")
                     .drop(["split"], axis=1, errors="ignore"))
    # if split == "training":
    clinical_data = clinical_data.rename(columns={"death": "event", "survival_time": "time"})
    # Convert time to months
    clinical_data["time"] *= 12

    clinical_data["age at dx"] = scale(clinical_data["age at dx"])
    clinical_data["Dose"] = scale(clinical_data["Dose"])

    # binarize T stage as T1/2 = 0, T3/4 = 1
    clinical_data["T Stage"] = clinical_data["T Stage"].map(
        lambda x: "T1/2" if x in ["T1", "T1a", "T1b", "T2"] else "T3/4", na_action="ignore")

    # use more fine-grained grouping for N stage
    clinical_data["N Stage"] = clinical_data["N Stage"].str.slice(0, 2)

    clinical_data["Stage"] = clinical_data["Stage"].map(
        lambda x: "I/II" if x in ["I", "II", "IIA"] else "III/IV", na_action="ignore")

    clinical_data["ECOG"] = clinical_data["ECOG"].map(
        lambda x: ">0" if x > 0 else "0", na_action="ignore")

    clinical_data = pd.get_dummies(clinical_data,
                                   columns=["Sex",
                                            "N Stage",
                                            "Disease Site"],
                                   drop_first=True)
    clinical_data = pd.get_dummies(clinical_data,
                                   columns=["HPV Combined",
                                            "T Stage",
                                            "Stage",
                                            "ECOG"])

    return clinical_data

In [9]:
def multiple_events(row):
    event        = row["event"]
    cancer_death = row["cancer_death"]

    if event==0:
        return 0
    elif cancer_death==0:
        return 1
    elif cancer_death==1:
        return 2
    else:
        raise UhOh

In [10]:
def train_mtlr(x, y, model, time_bins,
               num_epochs=1000, lr=.01, weight_decay=0.,
               C1=1., batch_size=None,
               verbose=True, device="cpu"):
    """Trains the MTLR model using minibatch gradient descent.
    
    Parameters
    ----------
    model : torch.nn.Module
        MTLR model to train.
    data_train : pd.DataFrame
        The training dataset. Must contain a `time` column with the
        event time for each sample and an `event` column containing
        the event indicator.
    num_epochs : int
        Number of training epochs.
    lr : float
        The learning rate.
    weight_decay : float
        Weight decay strength for all parameters *except* the MTLR
        weights. Only used for Deep MTLR training.
    C1 : float
        L2 regularization (weight decay) strenght for MTLR parameters.
    batch_size : int
        The batch size.
    verbose : bool
        Whether to display training progress.
    device : str
        Device name or ID to use for training.
        
    Returns
    -------
    torch.nn.Module
        The trained model.
    """
    optimizer = make_optimizer(Adam, model, lr=lr, weight_decay=weight_decay)
    reset_parameters(model)
    print(x.shape, y.shape)
    model = model.to(device)
    model.train()
    train_loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
    
    pbar =  trange(num_epochs, disable=not verbose)
    for i in pbar:
        for xi, yi in train_loader:
            xi, yi = xi.to(device), yi.to(device)
            y_pred = model(xi)
            loss = mtlr_neg_log_likelihood(y_pred, yi, model, C1, average=True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f"[epoch {i+1: 4}/{num_epochs}]")
        pbar.set_postfix_str(f"loss = {loss.item():.4f}")
    model.eval()
    return model

In [11]:
def validate(model, X, time, other, cancer):
    """Returns inference of model on data X
    params
    ------
    model
        PyTorch model being tested
    X
        data to test on
    time, other, cancer
        true labels for time, other, and cancer
    """
    pred_prob       = model(X)
    two_year_bin    = np.digitize(2, time_bins)
    survival_event  = mtlr_survival(pred_prob[:,:29]).detach().numpy()
    survival_cancer = mtlr_survival(pred_prob[:,29:]).detach().numpy()
    pred_event      = 1 - survival_event[:, two_year_bin]
    pred_cancer     = 1 - survival_cancer[:, two_year_bin]
    
    roc_auc_event   = roc_auc_score(other, pred_event)
    roc_auc_cancer  = roc_auc_score(cancer, pred_cancer)

    pred_risk = mtlr_risk(pred_prob, 2).detach().numpy()
        
    ci_event  = concordance_index(time, -pred_risk[:, 0], event_observed=other)
    ci_cancer = concordance_index(time, -pred_risk[:, 1], event_observed=cancer)
    
    return roc_auc_cancer, roc_auc_event, ci_cancer, ci_event

## load/process data

In [12]:
df1 = pd.read_csv("/cluster/projects/radiomics/RADCURE-challenge/clinical_cancer_death.csv")
df2 = pd.read_csv("/cluster/projects/radiomics/RADCURE-challenge/data/clinical_test.csv")
df3 = pd.read_csv("/cluster/projects/radiomics/RADCURE-challenge/data/clinical.csv")
ddf = pd.merge(df1, pd.concat([df2[["Study ID", "ECOG"]], df3[["Study ID", "ECOG"]]]), 
              how='outer', on='Study ID').drop("EGFRI", axis=1)

df  = make_data(ddf, split="training")

In [13]:
#df        = make_data("/cluster/projects/radiomics/RADCURE-challenge/data/training/clinical.csv")
time_bins = make_time_bins(df["time"], event=df["event"])
multi_events = df.apply(lambda x: multiple_events(x), axis=1)

y = encode_survival(df["time"], multi_events, time_bins)
X = torch.tensor(df.drop(["time", "event", "target_binary", "cancer_death"], axis=1).values, dtype=torch.float)

In [14]:
full_indices = range(len(df))
full_targets = df["target_binary"]
val_size = floor(.1 / .7 * len(full_indices))

train_indices, val_indices = train_test_split(full_indices, test_size=val_size, stratify=full_targets, random_state=1129)

In [15]:
X_train, X_val = X[train_indices], X[val_indices]
y_train, y_val = y[train_indices], y[val_indices]

df_val = df.iloc[val_indices]

In [16]:
print(X_train.shape, X_val.shape)

torch.Size([1545, 24]) torch.Size([257, 24])


In [17]:
df_test = make_data(ddf, split="test")
df_test.insert(11, 'N Stage_NX', np.zeros(df_test.shape[0]))

In [18]:
X_test = torch.tensor(df_test.drop(["time", "event", "target_binary", "cancer_death"], axis=1).values, dtype=torch.float)

In [19]:
seed_everything(1129)

1129

# C1=1. (late nov)

## fit mellow MTLR

In [20]:
device = "cpu"

In [42]:
# fit MTLR model 
mtlr = MTLR(in_features=24, num_time_bins=29, num_events=2)            
mtlr = train_mtlr(X_train, y_train, mtlr, time_bins, num_epochs=350, 
                  lr=1e-3, batch_size=128, verbose=True, device=device, C1=1.)

validate(mtlr, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    3/350]:   0%|          | 1/350 [00:00<00:37,  9.19it/s, loss = 14.7189]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  350/350]: 100%|██████████| 350/350 [00:12<00:00, 27.19it/s, loss = 1.6799]


(0.7783771106941839, 0.758510906569159, 0.764453477868112, 0.6833411130128105)

In [43]:
## TEST MODEL ##
validate(mtlr, X_test, df_test["time"], df_test["event"], df_test["cancer_death"])

(0.7442810457516339,
 0.7232062793324715,
 0.7470975647112881,
 0.6786743839104441)

## N-MTLR

In [36]:
mtlr1 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    MTLR(512, 29, num_events=2)
)

mtlr1 = train_mtlr(X_train, y_train, mtlr1, time_bins, num_epochs=350, 
                   lr=1e-3, batch_size=128, verbose=True, device=device, C1=1.)

validate(mtlr1, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    2/350]:   1%|          | 2/350 [00:00<00:32, 10.82it/s, loss = 31.4964]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  350/350]: 100%|██████████| 350/350 [00:44<00:00,  7.93it/s, loss = 1.8870]


(0.7621013133208255,
 0.7870066826377506,
 0.7438572719060524,
 0.7153933924588973)

In [37]:
## TEST MODEL ##
validate(mtlr1, X_test, df_test["time"], df_test["event"], df_test["cancer_death"])

(0.7676939603001695,
 0.7551910468271437,
 0.7558891101240619,
 0.7256116027862839)

In [38]:
mtlr2 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    MTLR(512, 29, num_events=2)
)

mtlr2 = train_mtlr(X_train, y_train, mtlr2, time_bins, num_epochs=350, 
                   lr=1e-3, batch_size=128, verbose=True, device=device, C1=1.)

validate(mtlr2, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    1/350]:   0%|          | 1/350 [00:00<00:58,  5.94it/s, loss = 42.5312]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  350/350]: 100%|██████████| 350/350 [01:05<00:00,  5.33it/s, loss = 3.3332]


(0.7000938086303939, 0.7291640398436515, 0.6863143631436315, 0.675509569005757)

In [39]:
## TEST MODEL ##
validate(mtlr2, X_test, df_test["time"], df_test["event"], df_test["cancer_death"])

(0.7548338779956427,
 0.6825163912471323,
 0.7146883136774391,
 0.6573005350472794)

In [44]:
mtlr3 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    
    MTLR(512, 29, num_events=2)
)

mtlr3 = train_mtlr(X_train, y_train, mtlr3, time_bins, num_epochs=450, 
                   lr=1e-3, batch_size=128, verbose=True, device=device, C1=1.)

validate(mtlr3, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

  0%|          | 0/450 [00:00<?, ?it/s]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  450/450]: 100%|██████████| 450/450 [02:02<00:00,  3.67it/s, loss = 1.4988]


(0.6676829268292683, 0.689509519606607, 0.6225383920505871, 0.6697526061926249)

In [45]:
## TEST MODEL ##
validate(mtlr3, X_test, df_test["time"], df_test["event"], df_test["cancer_death"])

(0.723160251755023, 0.6591552248192013, 0.6943176596722316, 0.6495384235734877)

In [36]:
mtlr2 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    #nn.Dropout(0.333),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    #nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr2 = train_mtlr(X_train, y_train, mtlr2, time_bins, num_epochs=350, 
                   lr=1e-4, batch_size=128, verbose=True, device=device)

validate(mtlr2, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

  0%|          | 0/350 [00:00<?, ?it/s]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  350/350]: 100%|██████████| 350/350 [01:34<00:00,  3.72it/s, loss = 2.1508]


(0.6674953095684804,
 0.7144118017904426,
 0.6264227642276423,
 0.6774804211399824)

In [37]:
## TEST MODEL ##
validate(mtlr2, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7535781287823772, 0.6954197673199874, 0.7233420125593506, 0.679403483976624)

# with C1=10 (late oct)

## fit mellow MTLR

In [15]:
device = "cpu"

In [38]:
# fit MTLR model 
mtlr = MTLR(in_features=24, num_time_bins=29, num_events=2)            
mtlr = train_mtlr(X_train, y_train, mtlr, time_bins, num_epochs=500, 
                  lr=1e-4, batch_size=128, verbose=True, device=device, C1=10.)

validate(mtlr, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    3/500]:   0%|          | 2/500 [00:00<00:28, 17.40it/s, loss = 168.5258]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  127/500]:  25%|██▌       | 127/500 [00:07<00:21, 17.04it/s, loss = 35.5453]


KeyboardInterrupt: 

In [None]:
## TEST MODEL ##
validate(mtlr, X_test, dff["time"], dff["event"], dff["cancer_death"])

## MTLR + hiddenlayers
fc > BN > ReLU > Dropout

In [70]:
mtlr1 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr1 = train_mtlr(X_train, y_train, mtlr1, time_bins, num_epochs=500, 
                   lr=1e-4, batch_size=128, verbose=True, device=device, C1=10.)

validate(mtlr1, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    1/500]:   0%|          | 1/500 [00:00<01:05,  7.59it/s, loss = 486.6324]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:49<00:00, 10.03it/s, loss = 1.7968] 


(0.7807223264540337,
 0.7835392762577228,
 0.7788166214995483,
 0.7217727296302059)

In [71]:
## TEST MODEL ##
validate(mtlr1, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7760000605180343, 0.7641989542209645, 0.7772093735640986, 0.736060167580846)

In [72]:
mtlr2 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr2 = train_mtlr(X_train, y_train, mtlr2, time_bins, num_epochs=500, 
                   lr=1e-4, batch_size=128, verbose=True, device=device, C1=10.)

validate(mtlr2, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    1/500]:   0%|          | 1/500 [00:00<01:32,  5.41it/s, loss = 492.0235]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  500/500]: 100%|██████████| 500/500 [01:19<00:00,  6.30it/s, loss = 1.0344] 


(0.7670262664165103,
 0.7946349766738116,
 0.7597560975609756,
 0.7412219283232198)

In [73]:
## TEST MODEL ##
validate(mtlr2, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7637678528201405,
 0.7779815765875928,
 0.7631030785725226,
 0.7471705308970175)

In [76]:
mtlr3 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr3 = train_mtlr(X_train, y_train, mtlr3, time_bins, num_epochs=500, 
                   lr=1e-4, batch_size=128, verbose=True, device=device, C1=10.)

validate(mtlr3, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

  0%|          | 0/500 [00:00<?, ?it/s]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  500/500]: 100%|██████████| 500/500 [01:52<00:00,  4.45it/s, loss = 1.4053] 


(0.7374765478424015,
 0.7832871012482663,
 0.7228093947606142,
 0.7241584980032156)

In [77]:
## TEST MODEL ##
validate(mtlr3, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7597887920600338,
 0.7682691075941258,
 0.7534997702557819,
 0.7387410124395688)

# PReLU test

In [92]:
mtlr1 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.PReLU(),
    nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr1 = train_mtlr(X_train, y_train, mtlr1, time_bins, num_epochs=500, 
                   lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr1, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    2/500]:   0%|          | 2/500 [00:00<00:38, 13.10it/s, loss = 52.7896]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:46<00:00, 10.73it/s, loss = 2.9112]


(0.7800656660412758, 0.783097969991174, 0.7724932249322494, 0.7295005445775634)

In [93]:
validate(mtlr1, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7817568385378842,
 0.7744878827050508,
 0.7718027262980548,
 0.7380287377595316)

In [94]:
mtlr1 = nn.Sequential(
    nn.Linear(24, 512),
    nn.BatchNorm1d(512),
    nn.PReLU(),
    nn.Dropout(0.333),
    
    MTLR(512, 29, num_events=2)
)

mtlr1 = train_mtlr(X_train, y_train, mtlr1, time_bins, num_epochs=500, 
                   lr=.0001, batch_size=128, verbose=True, device=device, C1=10.)

validate(mtlr1, X_val, df_val["time"], df_val["event"], df_val["cancer_death"])

[epoch    1/500]:   0%|          | 1/500 [00:00<00:53,  9.36it/s, loss = 496.2234]

torch.Size([1545, 24]) torch.Size([1545, 58])


[epoch  500/500]: 100%|██████████| 500/500 [00:47<00:00, 10.56it/s, loss = 2.0669] 


(0.7844746716697937, 0.7841697137813644, 0.783423667570009, 0.7311083450028526)

In [95]:
validate(mtlr1, X_test, dff["time"], dff["event"], dff["cancer_death"])

(0.7778836843379326,
 0.7745985163446643,
 0.7824169091744524,
 0.7447869345268141)

It's cool and works well >> but i'll pass for the sake of keeping this study simpler

# some other trial runs

In [87]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    6/1000]:   0%|          | 3/1000 [00:00<00:33, 29.65it/s, loss = 430.5145]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:33<00:00, 30.16it/s, loss = 1.8031]


(0.761211845951391, 0.7718318809005082, 0.7206873731085461, 0.776948996783581)

In [118]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    3/1000]:   0%|          | 2/1000 [00:00<01:05, 15.33it/s, loss = 464.4609]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:30<00:00, 32.27it/s, loss = 1.7256]


(0.7649035158206104,
 0.7766279351246671,
 0.7271539298493567,
 0.7790779598713432)

In [119]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    6/1000]:   0%|          | 4/1000 [00:00<00:32, 30.88it/s, loss = 422.3957]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [00:31<00:00, 31.29it/s, loss = 2.3196]


(0.7649268071131609,
 0.7703038005325586,
 0.7253087458357169,
 0.7750497779139225)

In [120]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    nn.Linear(532, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    4/1000]:   0%|          | 2/1000 [00:00<00:52, 19.07it/s, loss = 453.1712]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [01:01<00:00, 16.14it/s, loss = 1.6403]


(0.7330934330200654,
 0.7645848462841927,
 0.7245628203833945,
 0.7376474192066166)

In [121]:
mtlr_test = nn.Sequential(
    nn.Linear(20, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    nn.Linear(532, 532),
    nn.ReLU(),
    #nn.Dropout(0.4),
    MTLR(532, 29, num_events=2)
)
mtlr_test = train_mtlr(X_train, y_train, mtlr_test, time_bins, num_epochs=1000, 
                  lr=.0001, batch_size=128, verbose=True, device=device)

validate(mtlr_test, X_test, dff["time"], dff["event"], dff["cancer_death"])

[epoch    4/1000]:   0%|          | 2/1000 [00:00<00:50, 19.87it/s, loss = 448.1322]

torch.Size([1351, 20]) torch.Size([1351, 58])


[epoch  1000/1000]: 100%|██████████| 1000/1000 [01:04<00:00, 15.58it/s, loss = 1.8524]


(0.7412453854126634, 0.771536855482934, 0.7329867303787956, 0.7451064481543881)