## Reproducible fitting RMSE computation

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from jmstate import MultiStateJointModel
from jmstate.utils import *

# Seed torch's global RNG so that the simulated datasets, and the per-replicate
# seeds later drawn from this RNG, are identical on every fresh run.
torch.manual_seed(42)

<torch._C.Generator at 0x20bacb56f70>

In [2]:
def log_weibull(t1, t0, lambda_, rho_):
    """Log of a Weibull hazard evaluated at the elapsed time ``t1 - t0``.

    Computes ``log(rho / lambda) + (rho - 1) * log((t1 - t0 + 1e-8) / lambda)``.

    Args:
        t1: Evaluation time(s).
        t0: Reference time(s) subtracted from ``t1``.
        lambda_: Weibull scale; converted to a float32 tensor.
        rho_: Weibull shape; converted to a float32 tensor.

    Returns:
        Tensor of log-hazard values, broadcast over the inputs.
    """
    scale = torch.as_tensor(lambda_, dtype=torch.float32)
    shape = torch.as_tensor(rho_, dtype=torch.float32)

    # Small offset keeps the logarithm finite when t1 == t0.
    elapsed = (t1 - t0) + 1e-8

    log_rate = torch.log(shape / scale)
    log_scaled_time = torch.log(elapsed / scale)
    return log_rate + (shape - 1) * log_scaled_time


def double_slope(t, x, psi):
    """Piecewise-linear curve with a single change of slope.

    For each row of ``psi`` the value at time ``t`` is
    ``a + b1 * t`` up to and including the change point ``x0`` and
    ``a + b1 * t + (b2 - b1) * (t - x0)`` strictly after it.

    Args:
        t: Time tensor; must broadcast against the ``(rows, 1)`` columns of
            ``psi`` (presumably a 1-D time grid -- confirm against callers).
        x: Unused; kept so the signature matches the other model functions.
        psi: 2-D tensor whose columns 0..3 are ``x0``, ``a``, ``b1``, ``b2``.

    Returns:
        The curve values with an extra trailing dimension of size 1.
    """
    knot, intercept, slope_before, slope_after = (psi[:, [k]] for k in range(4))

    # Extra contribution that only applies strictly after the change point.
    past_knot = t > knot
    extra = torch.where(
        past_knot,
        (slope_after - slope_before) * (t - knot),
        torch.zeros_like(t),
    )

    curve = intercept + slope_before * t + extra
    return curve.unsqueeze(-1)


def double_slope_grad(t, x, psi):
    """Slope of the two-slope curve at time ``t``.

    Returns ``b1`` where ``t <= x0`` and ``b2`` elsewhere.

    Args:
        t: Time tensor; must broadcast against the ``(rows, 1)`` columns of
            ``psi``.
        x: Unused; kept so the signature matches the other model functions.
        psi: 2-D tensor; column 0 is the change point ``x0``, columns 2 and 3
            are the slopes ``b1`` and ``b2``.

    Returns:
        The selected slopes with an extra trailing dimension of size 1.
    """
    knot = psi[:, [0]]
    slope_before, slope_after = psi[:, [2]], psi[:, [3]]

    at_or_before_knot = t <= knot
    return torch.where(at_or_before_knot, slope_before, slope_after).unsqueeze(-1)


def link(t, x, psi):
    """Stack the curve value and its slope along the last dimension.

    Args:
        t, x, psi: Forwarded unchanged to ``double_slope`` and
            ``double_slope_grad``.

    Returns:
        Tensor whose last dimension has size 2: index 0 is the value from
        ``double_slope``, index 1 the slope from ``double_slope_grad``.
    """
    value = double_slope(t, x, psi)
    slope = double_slope_grad(t, x, psi)
    return torch.cat((value, slope), dim=-1)


def f(gamma, b):
    """Combine population-level parameters and random effects additively.

    Defined with ``def`` rather than a name-bound lambda so the function has a
    proper ``__name__`` and a docstring.

    Args:
        gamma: Population parameters.
        b: Random effects; must broadcast against ``gamma``.

    Returns:
        ``gamma + b``.
    """
    return gamma + b

In [3]:
# ---------------------------------------------------------------------------
# Weibull parameters, one (lambda, rho) pair per transition. They are passed
# to `log_weibull` as `lambda_` (divides the time) and `rho_` (exponent).
# ---------------------------------------------------------------------------
lambda_T01, rho_T01 = 6.33, 1.90
lambda_T02, rho_T02 = 4.24, 3.16
lambda_T12, rho_T12 = 5.70, 1.48

# ---------------------------------------------------------------------------
# True parameter values used to simulate the data.
# ---------------------------------------------------------------------------
gamma = torch.tensor([1.45, 2.33, -1.38, 0.17])
Q_inv = torch.tensor([2.25, 1.34, 0.51, 0.77])
R_inv = torch.tensor([1.19])

# Scale factors derived from Q_inv / R_inv. The simulation multiplies standard
# normal draws by these to obtain random effects and measurement noise.
Q_sqrt = torch.matrix_exp(-torch.diag(Q_inv))
R_sqrt = torch.exp(-R_inv)

# Coefficients keyed by transition (from_state, to_state).
# NOTE(review): alphas have two entries each, presumably one per component
# returned by `link` (value, slope); betas have one entry each, presumably one
# per covariate column -- confirm against the jmstate documentation.
alphas = {
    (0, 1): torch.tensor([0.07, 5.16]),
    (0, 2): torch.tensor([-0.12, 4.84]),
    (1, 2): torch.tensor([-0.02, 0.49]),
}
betas = {
    (0, 1): torch.tensor([-1.34]),
    (0, 2): torch.tensor([-0.91]),
    (1, 2): torch.tensor([-0.54]),
}

# ---------------------------------------------------------------------------
# Ground-truth parameter container. The "diag" / "ball" tags are forwarded to
# ModelParams as-is (their exact meaning is defined by jmstate).
# ---------------------------------------------------------------------------
real_params = ModelParams(gamma, (Q_inv, "diag"), (R_inv, "ball"), alphas, betas)

In [4]:
# Named functions instead of name-bound lambdas: they get a readable
# ``__name__`` in tracebacks and can carry docstrings. The Weibull constants
# are still looked up when the function is called, exactly as before.
def log_weibull_T01(t1, t0):
    """Log Weibull hazard for transition 0 -> 1."""
    return log_weibull(t1, t0, lambda_T01, rho_T01)


def log_weibull_T02(t1, t0):
    """Log Weibull hazard for transition 0 -> 2."""
    return log_weibull(t1, t0, lambda_T02, rho_T02)


def log_weibull_T12(t1, t0):
    """Log Weibull hazard for transition 1 -> 2."""
    return log_weibull(t1, t0, lambda_T12, rho_T12)


# For every transition (from_state, to_state): the hazard function paired with
# the link function.
surv = {
    (0, 1): (log_weibull_T01, link),
    (0, 2): (log_weibull_T02, link),
    (1, 2): (log_weibull_T12, link),
}

In [5]:
# Bundle the model components: the parameter-combination function `f`, the
# longitudinal curve `double_slope`, and the per-transition survival
# specification `surv`.
model_design = ModelDesign(f, double_slope, surv)

# Model instantiated at the true parameters; it is the one used to simulate
# trajectories.
real_model = MultiStateJointModel(model_design, real_params)

In [6]:
# n: number of simulated individuals; p: number of covariate columns.
n, p = 200, 1


def get_data():
    """Simulate one dataset from the true model.

    Reads the module-level ``n``, ``p``, ``real_params``, ``real_model``,
    ``model_design``, ``Q_sqrt`` and ``R_sqrt``. The order of the random draws
    (censoring times, covariates, random effects, trajectories, measurement
    noise) is fixed so that results are reproducible for a given RNG state.

    Returns:
        A ``ModelData`` built from the covariates, the time grid, the noisy
        longitudinal values (NaN after censoring), the sampled trajectories
        and the censoring times.
    """
    # Time grid: 16 equally spaced points 0, 1, ..., 15.
    t = torch.linspace(0, 15, 16)
    # Censoring times, uniform on [10, 15).
    c = torch.rand(n) * 5 + 10
    # Covariates: use the declared width `p` instead of a hard-coded 1.
    x = torch.randn(n, p)
    # Random effects: standard normal draws scaled by Q_sqrt.
    b = torch.randn(n, real_params.gamma.shape[0]) @ Q_sqrt
    psi = f(real_params.gamma, b)

    # One initial trajectory entry per individual.
    # NOTE(review): (0.0, 0) is presumably (time, state) -- confirm.
    trajectories_init = [[(0.0, 0)] for _ in range(n)]

    sample_data = SampleData(x, trajectories_init, psi)
    trajectories = real_model.sample_trajectories(sample_data, c)

    # Longitudinal values plus Gaussian noise scaled by R_sqrt.
    y = model_design.h(t, x, psi)
    y += torch.randn_like(y) * R_sqrt
    # Observations scheduled after an individual's censoring time are missing.
    y[t.repeat(n, 1) > c.view(-1, 1)] = torch.nan

    return ModelData(x, t, y, trajectories, c)

In [7]:
# One entry per replicate: a list with the summed squared error of each
# parameter group (true value vs. fitted value).
mse = []

loops = 50

for _ in range(loops):
    # Draw a fresh seed from the current global RNG state and reseed with it.
    seed = int(torch.randint(low=0, high=2**32, size=(1,)))
    torch.manual_seed(seed)

    data = get_data()

    # Same starting point for every replicate; all parameters except the first
    # group start at zero.
    init_params = ModelParams(
        torch.tensor([2.0, 2.0, -1.0, 1.0]),
        (torch.zeros_like(Q_inv), "diag"),
        (torch.zeros_like(R_inv), "ball"),
        {transition: torch.zeros_like(coef) for transition, coef in alphas.items()},
        {transition: torch.zeros_like(coef) for transition, coef in betas.items()},
    )

    model = MultiStateJointModel(model_design, init_params)
    model.fit(data, n_iter=3000)

    mse.append(
        [
            torch.sum((truth - estimate.detach()) ** 2)
            for truth, estimate in zip(real_params.as_list, model.params_.as_list)
        ]
    )

Fitting joint model: 100%|██████████| 3000/3000 [01:17<00:00, 38.82it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:20<00:00, 37.32it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:20<00:00, 37.44it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:16<00:00, 39.08it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.31it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.23it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:11<00:00, 42.05it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:11<00:00, 41.74it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.59it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:11<00:00, 42.22it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.52it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.63it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01:12<00:00, 41.49it/s]
Fitting joint model: 100%|██████████| 3000/3000 [01

In [8]:
# Stack the per-replicate squared errors into a (replicates, groups) matrix,
# average over replicates and take the square root: one RMSE per parameter
# group.
tensor_mse = torch.cat([torch.tensor(row).view(1, -1) for row in mse], dim=0)
RMSE = torch.sqrt(torch.mean(tensor_mse, dim=0))
print(RMSE)

tensor([0.1142, 0.2266, 0.0231, 0.3593, 0.4059, 0.3868, 0.1142, 0.1011, 0.1283])
