In [None]:
#!pip uninstall -y ukko

In [None]:
pip install -e .

In [None]:
import importlib

import ukko

# Reload the two ukko submodules one after the other (core first, then its
# tests) - presumably to pick up source edits without restarting the kernel.
for submodule_name in ("core", "tests_core"):
    importlib.reload(getattr(ukko, submodule_name))

# Run the classification-head test from the freshly reloaded test module.
ukko.tests_core.test_ClassificationHead_new()


# torchsurv

https://github.com/Novartis/torchsurv

```sh
pip install torchsurv
```


In [None]:
#pip install torchsurv

Reference (torchsurv helper code): https://github.com/Novartis/torchsurv/blob/main/docs/notebooks/helpers_introduction.py

In [None]:
# Imports live in this cell so the class definition works on a fresh kernel
# (nothing above this cell imports pandas, torch or Dataset).
import pandas as pd
import torch
from torch.utils.data import Dataset


class Custom_dataset(Dataset):
    """Dataset wrapper around a survival-analysis DataFrame.

    Each row of ``df`` is one sample. The column ``"cens"`` is read as the
    event indicator, ``"time"`` as the observed time, and every remaining
    column as a numeric predictor.

    NOTE(review): the original docstring described this as the
    "GSBG2 brain cancer dataset"; GBSG2 usually refers to the German Breast
    Cancer Study Group 2 data - confirm which dataset is meant.

    Args:
        df: DataFrame containing the columns ``"cens"`` and ``"time"`` plus
            the predictor columns. All predictor values must be convertible
            to float.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def __len__(self):
        """Return the number of rows (samples) in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, (event, time))`` for the row at integer position ``idx``.

        Returns:
            x: 1-D float32 tensor holding every column except ``"cens"`` and
                ``"time"``.
            event: 0-D bool tensor built from ``"cens"``.
            time: 0-D float32 tensor built from ``"time"``.

        Raises:
            ValueError: if a predictor value cannot be converted to float.
        """
        sample = self.df.iloc[idx]
        # Targets
        event = torch.tensor(bool(sample["cens"]))
        time = torch.tensor(float(sample["time"]), dtype=torch.float32)
        # Predictors: a row taken from a mixed-dtype frame (e.g. bool "cens"
        # next to float features) is an object-dtype Series, which torch cannot
        # convert directly - cast to float32 explicitly first.
        features = sample.drop(["cens", "time"]).to_numpy(dtype="float32")
        x = torch.tensor(features)
        return x, (event, time)

In [None]:
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Synthetic inputs drawn at random (no real signal):
import torch

# Fixed seed so the draws below are repeatable; the returned generator is
# bound to `_` only to keep it out of the cell output.
_ = torch.manual_seed(52)
n = 64
# Predictors: n rows of 16 standard-normal values.
x = torch.randn(n, 16)
# Event indicator: integers drawn from {0, 1}, converted to bool.
event = torch.randint(0, 2, (n,)).to(torch.bool)
# Observed time: integers drawn from 1..99, converted to float32.
time = torch.randint(1, 100, (n,)).to(torch.float32)

In [None]:
# Report the sample count: the first dimension of x is the number of rows.
# (Printing x.shape itself would show the full torch.Size, not a count.)
print(f"Number of samples: {x.shape[0]}")

In [None]:
# Kaplan-Meier estimate for the randomly generated `event` / `time` tensors.
#import lifelines
import matplotlib.pyplot as plt
from torchsurv.stats.kaplan_meier import KaplanMeierEstimator

# Create a Kaplan-Meier estimator
km = KaplanMeierEstimator()

# Compute the estimator from the event indicators and the observed times
km(event, time)

# plot_km accepts **kwargs, which are passed on to matplotlib.pyplot.plot
km.plot_km()
#plt.autoscale(True)
# Label the x-axis; "(rnd)" presumably flags that the times are random draws
plt.xlabel("Time (rnd)")

In [None]:
# Cox proportional hazards model
from torch import nn
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.metrics.auc import Auc
from torchsurv.metrics.cindex import ConcordanceIndex

# A single linear layer maps the 16 predictors to one output per sample,
# which is used below as the log hazard.
model_cox = nn.Sequential(nn.Linear(16, 1))
log_hz = model_cox(x)
print(log_hz.shape)
# recorded output: torch.Size([64, 1])

# Training loss: negative partial log-likelihood of the Cox model.
loss = neg_partial_log_likelihood(log_hz, event, time)
print(loss)
# recorded output: tensor(4.1723, grad_fn=<DivBackward0>)

# Metrics do not need gradients, so recompute the predictions without
# building an autograd graph.
with torch.no_grad():
    log_hz = model_cox(x)

cindex = ConcordanceIndex()
print(cindex(log_hz, event, time))
# recorded output: tensor(0.4872)

# Evaluate the AUC at t = 50. The evaluation time is passed as the tensor
# defined here (torchsurv documents `new_time` as a tensor), matching the
# Weibull example, rather than as a bare Python number.
new_time = torch.tensor(50.)
auc = Auc()
print(auc(log_hz, event, time, new_time=new_time))
# recorded output: tensor([0.4737])

## Accelerated failure time (AFT) model - Weibull

In [None]:
from torch import nn
from torchsurv.loss.weibull import log_hazard
from torchsurv.loss.weibull import neg_log_likelihood
from torchsurv.loss.weibull import survival_function
from torchsurv.metrics.auc import Auc
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex

# Linear layer with two outputs per sample, passed on as the log parameters
# of the Weibull model.
model_weibull = nn.Sequential(nn.Linear(16, 2))
log_params = model_weibull(x)
print(log_params.shape)
# recorded output: torch.Size([64, 2])

# Training loss: negative log-likelihood of the Weibull model.
loss = neg_log_likelihood(log_params, event, time)
print(loss)
# recorded output: tensor(82931.5078, grad_fn=<DivBackward0>)

# Log hazard and survival functions, evaluated without tracking gradients
# for the forward pass:
with torch.no_grad():
    log_params = model_weibull(x)

log_hz = log_hazard(log_params, time)
print(log_hz.shape)
# recorded output: torch.Size([64, 64])

surv = survival_function(log_params, time)
print(surv.shape)
# recorded output: torch.Size([64, 64])
display(surv)

# Concordance index computed from the log hazards.
cindex = ConcordanceIndex()
print(cindex(log_hz, event, time))
# recorded output: tensor(0.4062)

# AUC at t = 50, using the log hazard evaluated at that same time.
new_time = torch.tensor(50.)
log_hz_t = log_hazard(log_params, time=new_time)
auc = Auc()
print(auc(log_hz_t, event, time, new_time=new_time))
# recorded output: tensor([0.3509])

# Brier score from the survival estimates, then its integral.
brier_score = BrierScore()
bs = brier_score(surv, event, time)
print(brier_score.integral())
# recorded output: tensor(0.4447)