In [None]:
import numpy as np
import pandas as pd
import scipy.stats as stats
import scipy.linalg as linalg
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt


# Part 2

## Task 7

In [None]:
# Generator (rate) matrix of the continuous-time Markov chain; rates are per month.
# Row i holds the rates of jumping from state i to every other state, and the diagonal
# is minus the total out-rate so that every row sums to zero. The last row is all zero,
# i.e. the last state (death) is absorbing.
Q = np.array([
    [-0.0085,  0.005,  0.0025,  0    , 0.001],
    [ 0     , -0.014,  0.005 ,  0.004, 0.005],
    [ 0     ,  0    , -0.008 ,  0.003, 0.005],
    [ 0     ,  0    ,  0     , -0.009, 0.009],
    [ 0     ,  0    ,  0     ,  0    , 0    ]])

# Row/column index of each state in Q (the key spellings are used by later cells).
state_dict = {
    'no_tumor': 0,
    'local_recurrence': 1,
    'distant_metastatis': 2,
    'local_and_distant': 3,
    'death': 4}

num_states = len(state_dict)

# Sanity checks: Q must be a valid generator whose size matches state_dict.
assert Q.shape == (num_states, num_states), "Q and state_dict disagree on the number of states"
assert np.allclose(Q.sum(axis=1), 0), "rows of a generator must sum to zero"
assert np.all(Q - np.diag(np.diag(Q)) >= 0), "off-diagonal rates must be non-negative"

# Fixed seed so that Restart & Run All reproduces every Monte Carlo result below.
np.random.seed(42)


In [None]:
# Embedded jump chain of the CTMC:
#   out_rate[i]           total rate of leaving state i (minus the diagonal of Q)
#   transition_prob[i, j] probability that a jump out of state i lands in state j
# Rows with zero out-rate are absorbing and get a self-loop of probability 1.
out_rate = -np.diag(Q)
non_arbsorbing_mask = out_rate > 0

transition_prob = np.where(
    non_arbsorbing_mask[:, None],
    Q / np.where(non_arbsorbing_mask, out_rate, 1.0)[:, None],
    Q,
)
np.fill_diagonal(transition_prob, ~non_arbsorbing_mask)


In [None]:
# Monte Carlo simulation of num_trails independent patients, all starting tumor-free.
num_trails = 1000
DISTANT_REAPPEARANCE_HORIZON = 30.5  # months

start_states = np.full(num_trails, state_dict['no_tumor'])
lifetimes = np.zeros(num_trails)

is_alive_mask = np.ones(num_trails, dtype=bool)
current_states = start_states.copy()

# States in which the cancer has reappeared distantly.
distant_states = [state_dict['distant_metastatis'], state_dict['local_and_distant']]
has_had_cancer_reapear_within_30_5_months = np.zeros(num_trails, dtype=bool)

while np.any(is_alive_mask):
    alive_states = current_states[is_alive_mask]

    # Add the Exp(out_rate) holding time; afterwards lifetimes holds the time of the next jump.
    lifetimes[is_alive_mask] += np.random.exponential(1 / out_rate[alive_states])

    # Next state from the embedded jump chain. Only living trials are resampled:
    # death is absorbing, so drawing for dead trials would only waste random numbers.
    new_alive_states = np.array([np.random.choice(num_states, p=transition_prob[state]) for state in alive_states])

    # Flag jumps into a distant-metastasis state that happen within the horizon.
    has_had_cancer_reapear_within_30_5_months[is_alive_mask] |= (
        np.isin(new_alive_states, distant_states)
        & (lifetimes[is_alive_mask] <= DISTANT_REAPPEARANCE_HORIZON))

    current_states[is_alive_mask] = new_alive_states
    is_alive_mask = current_states != state_dict['death']


In [None]:
# Monte Carlo estimate: fraction of simulated patients with a distant reappearance by 30.5 months.
proportion_distant_within_horizon = has_had_cancer_reapear_within_30_5_months.mean()
print(f"Probability of having cancer reappear within 30.5 months: {proportion_distant_within_horizon}")


In [None]:
# Empirical survival curve: after the i-th (0-based) sorted death time, a fraction
# 1 - (i + 1) / n of the simulated patients is still alive.
n_lifetimes = len(lifetimes)
plt.plot(
    np.sort(lifetimes),
    np.linspace(0, 1, n_lifetimes, endpoint=False)[::-1])
plt.title('Lifetime distribution')
plt.xlabel('Months')
plt.ylabel('Survival probability')
plt.show()


# report the mean lifetime and its standard deviation with 95% confidence intervals
mean_lifetime = np.mean(lifetimes)

# ddof=1: sample standard deviation, consistent with stats.sem (which uses ddof=1)
std_lifetime = np.std(lifetimes, ddof=1)
lower, upper = stats.t.interval(0.95, n_lifetimes - 1, loc=mean_lifetime, scale=stats.sem(lifetimes))
print(f'Mean lifetime: {mean_lifetime:.2f} months')

print(f'Standard deviation: {std_lifetime:.2f} months')

# confidence interval mean (t-interval; relies on the CLT for the sample mean)
print(f'95% confidence interval: ({lower:.2f}, {upper:.2f}) months')

# Confidence interval for the standard deviation. A t-interval centred on the SD and
# scaled by the standard error of the *mean* is not a valid interval for the SD. The
# normal-theory chi-square interval assumes normal data, which these lifetimes (sums of
# exponential holding times) need not be. A bootstrap interval avoids both problems.
sd_interval = stats.bootstrap(
    (lifetimes,),
    lambda sample, axis: np.std(sample, ddof=1, axis=axis),
    confidence_level=0.95,
).confidence_interval
lower, upper = sd_interval.low, sd_interval.high
print(f'95% confidence interval for standard deviation: ({lower:.2f}, {upper:.2f}) months')



## Task 8

In [None]:
# Theoretical lifetime CDF (phase-type distribution):
#     F(t) = 1 - p_0 exp(Q_s t) 1
# where Q_s is Q restricted to the transient (non-death) states and p_0 is the
# initial distribution over those states (all mass on 'no_tumor').
assert state_dict['death'] == num_states - 1, "Q[:-1, :-1] below assumes death is the last state"
p_0 = np.zeros((num_states-1, 1))
p_0[state_dict['no_tumor']] = 1
Q_s = Q[:-1, :-1]

ts = np.linspace(0, lifetimes.max(), 1000)


def F_t(t):
    """Probability of having died by time t when starting from p_0."""
    return 1 - (p_0.T @ linalg.expm(Q_s * t)).sum()


F_true = np.array([F_t(t) for t in ts])

# Empirical CDF: after the i-th (0-based) sorted lifetime a fraction (i + 1) / n has died,
# so the curve ends at exactly 1 and 1 - F_empirical is the Task 7 survival curve.
sorted_lifetimes = np.sort(lifetimes)
F_empirical = np.arange(1, len(sorted_lifetimes) + 1) / len(sorted_lifetimes)

plt.plot(ts, F_true, label='Theoretical')
plt.step(sorted_lifetimes, F_empirical, where='post', label='Empirical')
plt.xlabel('Months')
plt.ylabel('Death probability')
plt.title('Theoretical vs empirical lifetime distribution')
plt.legend()
plt.show()

In [None]:
# Independent copy of the original generator, used as the baseline in the comparison.
Q1 = Q.copy()

# Generator for the alternative scenario: only the off-diagonal rates are written out.
# The diagonal is filled in below so that every row sums to zero.
Q2 = np.array([
    [0, 0.0025, 0.00125, 0    , 0.001],
    [0, 0     , 0      , 0.002, 0.005],
    [0, 0     , 0      , 0.003, 0.005],
    [0, 0     , 0      , 0    , 0.009],
    [0, 0     , 0      , 0    , 0    ]])

idxs = np.arange(len(Q2), dtype=int)
np.fill_diagonal(Q2, -Q2.sum(axis=1))

def simulate_lifetimes(Q, num_trials):
    """Simulate patient lifetimes from a CTMC with generator ``Q``.

    Every trial starts in ``state_dict['no_tumor']`` and runs until it reaches
    ``state_dict['death']``. Both entries are read from the module-level ``state_dict``.

    Parameters
    ----------
    Q : np.ndarray
        Square generator matrix. Rows with zero out-rate are treated as absorbing.
    num_trials : int
        Number of independent trajectories.

    Returns
    -------
    lifetimes : np.ndarray, shape (num_trials,)
        Time at which each trial reached the death state.
    current_states : np.ndarray, shape (num_trials,)
        Final state of each trial (the death state on return).

    Note: the Part 3 cell redefines this name with a different signature.
    """
    n_states = Q.shape[0]
    out_rate = -np.diag(Q)
    non_absorbing_mask = out_rate > 0

    # Embedded jump chain; absorbing states get a self-loop of probability 1.
    transition_prob = Q.copy()
    transition_prob[non_absorbing_mask] /= out_rate[non_absorbing_mask, None]
    np.fill_diagonal(transition_prob, ~non_absorbing_mask)

    lifetimes = np.zeros(num_trials)
    current_states = np.full(num_trials, state_dict['no_tumor'])
    is_alive_mask = np.ones(num_trials, dtype=bool)

    while np.any(is_alive_mask):
        alive_states = current_states[is_alive_mask]
        lifetimes[is_alive_mask] += np.random.exponential(1 / out_rate[alive_states])
        # Only living trials jump; dead ones remain in the absorbing death state,
        # so no random draws are spent on them.
        current_states[is_alive_mask] = [
            np.random.choice(n_states, p=transition_prob[state]) for state in alive_states
        ]
        is_alive_mask = current_states != state_dict['death']

    return lifetimes, current_states

# Calculate the Kaplan-Meier estimate
def kaplan_meier_estimate(lifetimes, events):
    """Kaplan-Meier estimate of the survival function.

    Parameters
    ----------
    lifetimes : array_like, shape (n,)
        Observed times (time of death or of censoring).
    events : array_like, shape (n,)
        1/True where the time is an observed death, 0/False where it is censored.

    Returns
    -------
    times : np.ndarray
        0 followed by the distinct observed times in increasing order.
    survival : np.ndarray
        Estimated survival probability from each entry of ``times`` onwards
        (1 at time 0); intended for a post-step plot.
    """
    lifetimes = np.asarray(lifetimes, dtype=float)
    events = np.asarray(events, dtype=float)

    unique_times, inverse = np.unique(lifetimes, return_inverse=True)
    # Deaths (not censorings) at each distinct time.
    death_counts = np.bincount(inverse.ravel(), weights=events, minlength=len(unique_times))
    # At risk at time t: everyone whose observed time is >= t (not yet dead or censored).
    sorted_lifetimes = np.sort(lifetimes)
    at_risk_counts = len(sorted_lifetimes) - np.searchsorted(sorted_lifetimes, unique_times, side='left')
    survival_prob = np.cumprod(1 - death_counts / at_risk_counts)

    return np.concatenate(([0], unique_times)), np.concatenate(([1], survival_prob))

# Simulate the same number of patients under each generator (Q1 first, then Q2).
num_trials = 1000
lifetimes1, current_states1 = simulate_lifetimes(Q1, num_trials)
lifetimes2, current_states2 = simulate_lifetimes(Q2, num_trials)

# Death indicators for the Kaplan-Meier estimator. The simulation only stops once
# every trial has died, so there is no censoring here.
events1, events2 = (
    (final_states == state_dict['death']).astype(int)
    for final_states in (current_states1, current_states2)
)
times1, survival_prob1 = kaplan_meier_estimate(lifetimes1, events1)
times2, survival_prob2 = kaplan_meier_estimate(lifetimes2, events2)

# Overlay both survival step curves.
fig, ax = plt.subplots(figsize=(10, 6))
ax.step(times1, survival_prob1, where='post', label='Kaplan-Meier Estimate (Q1)')
ax.step(times2, survival_prob2, where='post', label='Kaplan-Meier Estimate (Q2)', linestyle='--')
ax.set(xlabel='Time', ylabel='Survival Probability', title='Kaplan-Meier Survival Curves')
ax.grid(True)
ax.legend()
plt.show()


# Part 3

## Task 12

In [None]:
def simulate_lifetimes(
    Q: np.ndarray,
    num_trails: int,
    *,
    months_between_doctor_visits: int=48,
    state_dict: dict[str, int]=state_dict,
    return_Q_estimators: bool=False
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Simulate CTMC trajectories and record the state seen at periodic doctor visits.

    Each of the ``num_trails`` trials starts in ``state_dict['no_tumor']`` and is run
    until it reaches ``state_dict['death']``. Visit k takes place at time
    k * months_between_doctor_visits, and row 0 (time 0) is left at state 0.

    Parameters
    ----------
    Q : generator matrix. Rows with zero out-rate are treated as absorbing.
    num_trails : number of independent trials.
    months_between_doctor_visits : spacing between consecutive visits.
    state_dict : maps state names to indices. Its length gives the number of states,
        and its 'no_tumor' / 'death' entries give the start and stop states. The
        default is the module-level state_dict as it was when this function was defined.
    return_Q_estimators : if True, also return the complete-data statistics
        (sojourn times and jump counts) of the simulated continuous paths.

    Returns
    -------
    state_time_series : int array of shape (num_visits, num_trails); row k is the
        state recorded at visit k. num_visits is the smallest power of two that covers
        every trial's death visit, so the trailing rows repeat the death state.
    sojourn_times : (only if return_Q_estimators) total simulated time spent in each
        state, summed over trials. The death state is never accumulated.
    num_jumps : (only if return_Q_estimators) num_jumps[i, j] counts simulated i -> j jumps.
    """

    num_states = len(state_dict)
    unique_states = np.arange(num_states, dtype=int)

    # when the transition happens
    out_rate = -np.diag(Q)

    # what new state will be (embedded jump chain; absorbing rows get a self-loop of 1)
    transition_prob = Q.copy()
    non_arbsorbing_mask = out_rate > 0
    transition_prob[non_arbsorbing_mask] /= out_rate[non_arbsorbing_mask, None]
    np.fill_diagonal(transition_prob, ~non_arbsorbing_mask)

    # returns (state_time_series grows by doubling as later visits are needed)
    state_time_series = np.zeros((1, num_trails), dtype=int)
    sojourn_times = np.zeros(num_states)
    num_jumps = np.zeros((num_states, num_states), dtype=int)

    # init
    start_states = np.full(num_trails, state_dict['no_tumor'])
    lifetimes = np.zeros(num_trails)

    is_alive_mask = np.ones(num_trails, dtype=bool)
    current_states = start_states.copy()

    # simulate
    while np.any(is_alive_mask):

        # get how long we spend in the current state (one draw per living trial)
        sojourn_time = np.random.exponential(1 / out_rate[current_states[is_alive_mask]])

        # update the sojourn; indxs index into the living subset, aligned with sojourn_time
        for state in unique_states:
            indxs = np.where(current_states[is_alive_mask] == state)
            sojourn_times[state] += sojourn_time[indxs].sum()

        # update the life total and the state; lifetimes now holds the jump time
        lifetimes[is_alive_mask] += sojourn_time
        new_states = np.array([np.random.choice(num_states, p=transition_prob[state]) for state in current_states[is_alive_mask]])

        # update the number of jumps
        for i in unique_states:
            for j in unique_states:
                num_jumps[i, j] += np.sum((current_states[is_alive_mask] == i) & (new_states == j))

        # update the doctor visit: a jump at time t is first seen at visit ceil(t / spacing).
        # If several jumps land before the same visit, the later write wins, so the visit
        # records the most recent state.
        visit_idx = np.ceil(lifetimes[is_alive_mask] / months_between_doctor_visits).astype(int)
        while visit_idx.max() >= state_time_series.shape[0]:
            state_time_series = np.concatenate([state_time_series, np.zeros_like(state_time_series)], axis=0)
        state_time_series[visit_idx, is_alive_mask] = new_states

        # ready for the next iteration
        current_states[is_alive_mask] = new_states
        is_alive_mask = current_states != state_dict['death']

    # fill in the gaps: unwritten visits are 0 and get the running maximum of earlier visits.
    # NOTE(review): this is only correct when states never move to a lower index (true for
    # the upper-triangular Q used in this notebook) and the start state is index 0.
    state_time_series = pd.DataFrame(state_time_series).cummax().to_numpy()

    if return_Q_estimators:
        return state_time_series, sojourn_times, num_jumps

    return state_time_series


def estimate_Q_from_obervations(
    state_time_series: np.ndarray,
    *,
    months_between_doctor_visits: int=48
) -> np.ndarray:
    """
    Crude initial estimate of the generator from periodically observed states.

    Every change between two consecutive visits counts as one direct jump. Every
    visit spent in state ``i`` (all rows except the last) adds a full
    ``months_between_doctor_visits`` of exposure to state ``i``. Then
    ``Q[i, j] = jumps(i -> j) / exposure(i)`` (0 when state ``i`` has no exposure),
    and the diagonal is set so that every row sums to zero.

    state_time_series: integer states, shape (num_visits, num_trails)
    """
    num_states = state_time_series.max() + 1

    origin = state_time_series[:-1].ravel()
    target = state_time_series[1:].ravel()

    # jump_counts[i, j]: number of consecutive visit pairs observed as i followed by j
    jump_counts = np.zeros((num_states, num_states), dtype=int)
    np.add.at(jump_counts, (origin, target), 1)
    np.fill_diagonal(jump_counts, 0)

    # exposure[i]: number of visits (excluding the last row) in state i times the spacing
    exposure = np.bincount(origin, minlength=num_states) * months_between_doctor_visits

    rates = np.zeros((num_states, num_states))
    has_exposure = exposure > 0
    rates[has_exposure] = jump_counts[has_exposure] / exposure[has_exposure, None]

    np.fill_diagonal(rates, -rates.sum(axis=1))
    return rates


def mcem(
    Q_est,
    state_time_series,
    *,
    months_between_doctor_visits: int=48
):
    """
    One iteration of Monte Carlo Expectation Maximization for the generator.

    E-step (Monte Carlo): for every observed change of state between visits k and
    k + 1, draw the holding time in the current state from the exponential law given
    by the current estimate. The draw is conditioned on the jump falling inside that
    visit interval.
    M-step: Q[i, j] = N_ij / R_i, where N_ij is the number of i -> j jumps and R_i is
    the total simulated time spent in state i.

    Simplification: each observed change is treated as a single direct jump;
    intermediate states between two visits are not imputed.

    Q_est: current estimate of the Q matrix
        shape: (num_states, num_states)

    state_time_series: observed state time series
        shape: (num_visits, num_trails); row k is the state at time
        k * months_between_doctor_visits
        num_visits is dynamic

    Returns the updated estimate, with the same shape as Q_est. States with no
    simulated sojourn time (e.g. the absorbing state) get an all-zero row instead of
    a division by zero.
    """

    out_rates = -np.diag(Q_est)
    num_states = Q_est.shape[0]
    sojourn_times = np.zeros(num_states)
    num_jumps = np.zeros((num_states, num_states))

    for trail_idx in range(state_time_series.shape[1]):
        trajectory = state_time_series[:, trail_idx]
        jump_idxs, = np.where(trajectory[:-1] != trajectory[1:])

        # time of the most recent simulated jump in this trajectory
        lifetime = 0.0
        for jump_idx in jump_idxs:
            current_state = trajectory[jump_idx]
            next_state = trajectory[jump_idx + 1]

            # The holding time must end inside (jump_idx, jump_idx + 1] * spacing.
            lower = max(jump_idx * months_between_doctor_visits - lifetime, 0.0)
            upper = (jump_idx + 1) * months_between_doctor_visits - lifetime
            sojourn_time = _sample_truncated_exponential(out_rates[current_state], lower, upper)

            lifetime += sojourn_time
            sojourn_times[current_state] += sojourn_time
            num_jumps[current_state, next_state] += 1

    Q_new = np.zeros_like(num_jumps)
    visited = sojourn_times > 0
    Q_new[visited] = num_jumps[visited] / sojourn_times[visited, None]
    np.fill_diagonal(Q_new, -Q_new.sum(axis=1))

    return Q_new


def _sample_truncated_exponential(rate, lower, upper):
    """Draw one value from Exp(rate) conditioned to lie in [lower, upper] (exact inverse CDF)."""
    if rate <= 0:
        # With zero rate the conditional law degenerates to uniform on the interval;
        # rejection sampling would never terminate here.
        return np.random.uniform(lower, upper)
    # By memorylessness, X - lower is Exp(rate) truncated to [0, upper - lower].
    mass = -np.expm1(-rate * (upper - lower))
    return lower - np.log1p(-np.random.uniform() * mass) / rate


In [None]:
# Same model as in Part 2 (identical rates and state indices), restated for Part 3.
# Rates are per month; each row sums to zero and the last state (death) is absorbing.
Q = np.array([
    [-0.0085,  0.005,  0.0025,  0    , 0.001],
    [ 0     , -0.014,  0.005 ,  0.004, 0.005],
    [ 0     ,  0    , -0.008 ,  0.003, 0.005],
    [ 0     ,  0    ,  0     , -0.009, 0.009],
    [ 0     ,  0    ,  0     ,  0    , 0    ]])

# State names listed in index order; the position in the list is the index in Q.
state_dict = {name: index for index, name in enumerate([
    'no_tumor',
    'local_recurrence',
    'distant_metastatis',
    'local_and_distant',
    'death',
])}


In [None]:
# Doctor-visit observations for 1000 simulated patients, one column per patient.
# Despite its name, observed_lifetimes holds the state recorded at each visit.
observed_lifetimes = simulate_lifetimes(Q, 1000)
num_highlitghed = 5

fig, ax = plt.subplots()
# Remaining trajectories as a faint background, with a few highlighted ones drawn on top.
ax.plot(observed_lifetimes[:, num_highlitghed:], alpha=0.005, color='k')
ax.plot(observed_lifetimes[:, :num_highlitghed], linewidth=3)
ax.set_xlabel('Doctor visit')
ax.set_ylabel('State')
plt.show()


In [None]:
# Start from the crude estimate and iterate MCEM until the largest entry-wise change
# (infinity norm of the difference) drops below the tolerance. The iteration cap
# guarantees termination if Monte Carlo noise keeps the change above the tolerance.
MCEM_TOLERANCE = 1e-4
MCEM_MAX_ITERATIONS = 1000

Q_est = estimate_Q_from_obervations(observed_lifetimes)

with tqdm.tqdm(desc='MCEM iterations') as pbar:
    for iteration in range(MCEM_MAX_ITERATIONS):
        Q_est_new = mcem(Q_est, observed_lifetimes)
        max_diff = np.abs(Q_est - Q_est_new).max()
        Q_est = Q_est_new
        pbar.update(1)
        if max_diff < MCEM_TOLERANCE:
            break
    else:
        print(f"MCEM did not converge within {MCEM_MAX_ITERATIONS} iterations "
              f"(last max change {max_diff:.2e})")

print(
    "Estimated Q matrix:",
    np.array2string(Q_est, precision=4, suppress_small=True),
    "True Q matrix:",
    np.array2string(Q, precision=4, suppress_small=True),
    sep='\n')