# Multistate Model for AIDS SI Data: Real-Data Example

This notebook demonstrates how to use MultiStateNN with the real-world AIDSSI dataset shipped with pymsm. The dataset follows HIV-infected patients and records their progression either to AIDS or to the syncytium-inducing (SI) viral phenotype, which are treated as competing events.

## Setup and Data Loading

First, let's import the necessary packages and load the data.

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
# `display` is injected automatically inside a Jupyter kernel; importing it
# explicitly keeps the notebook runnable when exported to a .py script.
from IPython.display import display

# Import from pymsm for the aidssi dataset
from pymsm.datasets import load_aidssi, prep_aidssi

# Import MultiStateNN package
from multistate_nn import (
    fit,
    plot_transition_heatmap,
    plot_transition_graph,
    simulate_patient_trajectory,
    simulate_cohort_trajectories,
    calculate_cif,
    plot_cif,
    compare_cifs
)

# Seed NumPy's global RNG and PyTorch for reproducibility
SEED = 42
np.random.seed(SEED)
# trailing ';' suppresses the returned torch Generator object from being shown as cell output
torch.manual_seed(SEED);

In [None]:
# Fetch the raw AIDSSI table and let pymsm turn it into a competing-risks frame.
# prep_aidssi also hands back the covariate column names and the state labels
# (presumably a state-code -> name mapping; it is indexed by state code below).
data = load_aidssi()
competing_risk_dataset, covariate_cols, state_labels = prep_aidssi(data)

# Peek at the first five prepared rows
competing_risk_dataset.head(5)

## Understanding the Dataset

Let's explore the dataset structure and the state transition definitions.

In [None]:
# Quick summary of the prepared table: row count, covariates and state labelling
print(f"Number of samples: {competing_risk_dataset.shape[0]}")
print(f"Covariates: {covariate_cols}")
print(f"State labels: {state_labels}")

In [None]:
# Count observed transitions for every (from_state, to_state) pair, then reshape
# into an origin x destination matrix (unobserved pairs become 0).
transition_counts = (
    competing_risk_dataset
    .groupby(['from_state', 'to_state'])
    .size()
    .reset_index(name='count')
)
transition_pivot = (
    transition_counts
    .pivot(index='from_state', columns='to_state', values='count')
    .fillna(0)
    .astype(int)
)

print("Transition Count Matrix:")
display(transition_pivot)

# Same matrix rendered as an annotated heatmap
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(transition_pivot, annot=True, fmt='d', cmap='YlGnBu', ax=ax)
ax.set_title('Transition Counts')
ax.set_xlabel('To State')
ax.set_ylabel('From State')
fig.tight_layout()
plt.show()

In [None]:
# Row-normalise the count matrix: each origin state's row now sums to 1,
# giving the empirical probability of each destination.
row_totals = transition_pivot.sum(axis=1)
transition_probs = transition_pivot.div(row_totals, axis=0)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(transition_probs, annot=True, fmt='.2f', cmap='YlGnBu', ax=ax)
ax.set_title('Empirical Transition Probabilities')
ax.set_xlabel('To State')
ax.set_ylabel('From State')
fig.tight_layout()
plt.show()

## Preparing Data for the Multistate Model

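Next, we derive the model's state transition structure directly from the data. Each state may only transition to the destination states actually observed from it. States with no observed outgoing transitions are treated as absorbing.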

In [None]:
# Every state that occurs anywhere in the data, either as an origin or a destination
unique_from_states = competing_risk_dataset['from_state'].unique()
unique_to_states = competing_risk_dataset['to_state'].unique()
all_states = np.unique(np.concatenate([unique_from_states, unique_to_states]))

# Allowed transitions = destinations actually observed from each state, in order of
# first appearance. States never seen as an origin get an empty list (absorbing).
state_transitions = {}
for state in all_states:
    transitions_from_state = competing_risk_dataset.loc[
        competing_risk_dataset['from_state'] == state, 'to_state'
    ].unique()
    state_transitions[state] = list(transitions_from_state)

# Display the state transition structure.
# `to_states` is a Python list (lists have no `.size` attribute), so test emptiness by truthiness.
for from_state, to_states in state_transitions.items():
    if to_states:
        print(f"From state {from_state} ({state_labels[from_state]}): Can transition to states {to_states} ({[state_labels[s] for s in to_states]})")
    else:
        print(f"State {from_state} ({state_labels[from_state]}): Absorbing state (no outgoing transitions)")

In [None]:
# Cast origin keys and destination states to plain Python ints (they are
# presumably NumPy scalars coming from the `unique()` calls above), keeping the
# original destination order.
model_state_transitions = {
    int(origin): [int(target) for target in targets]
    for origin, targets in state_transitions.items()
}

print("Model state transitions:")
model_state_transitions

In [None]:
# Build the model-ready frame from a copy of the prepared data. No columns are
# renamed here; `df` keeps the column names of `competing_risk_dataset`.
from sklearn.preprocessing import StandardScaler

df = competing_risk_dataset.copy()

# Integer-encode every covariate that is not already numeric. Testing with
# `is_numeric_dtype` (rather than `dtype == object`) also catches pandas
# 'category' and 'string' dtypes, which StandardScaler cannot consume when
# their values are non-numeric. Missing values become code -1.
for col in covariate_cols:
    if not pd.api.types.is_numeric_dtype(df[col]):
        df[col] = df[col].astype('category').cat.codes

# Standardise covariates to zero mean / unit variance. Note: the scaler is fit on
# the full dataset; this notebook has no held-out split.
scaler = StandardScaler()
df[covariate_cols] = scaler.fit_transform(df[covariate_cols])

df.head()

## Training the MultiState Neural Network Model

In [None]:
# Network shape: one input unit per covariate, two hidden layers, and an output
# space indexed by state code (assumes states are non-negative integer codes).
input_dim = len(covariate_cols)
hidden_dims = [64, 32]
num_states = int(all_states.max()) + 1

# Train on the scaled frame with the observed transition structure
model = fit(
    df=df,
    covariates=covariate_cols,
    state_transitions=model_state_transitions,
    input_dim=input_dim,
    hidden_dims=hidden_dims,
    num_states=num_states,
    epochs=50,
    batch_size=64,
    learning_rate=0.001,
)

## Making Predictions for Sample Patients

Now let's make predictions for a few sample patients and analyze the results.

In [None]:
# Summary statistics of the covariates on their original (unscaled) scale.
# Use them as a reference when reading the sampled patients below; the model
# itself sees the standardised values in `df`.
covariate_stats = competing_risk_dataset[covariate_cols].describe()
display(covariate_stats)

In [None]:
# Draw 5 real patients (already encoded and scaled) to use as model inputs.
# A fixed random_state makes the draw reproducible when the cell is re-run, and
# independent of whatever earlier cells drew from NumPy's global RNG.
sample_patients = df.sample(5, random_state=42)
sample_patients_covs = sample_patients[covariate_cols].to_numpy()

# Model input: float32 tensor, one row per sampled patient (same row order as sample_patients)
x_test = torch.tensor(sample_patients_covs, dtype=torch.float32)

# Display the sample patients
display(sample_patients)

In [None]:
def print_transition_probs(x, model, time_idx, from_state):
    """Print a header and return the model's next-state probabilities as a DataFrame.

    Parameters
    ----------
    x : torch.Tensor
        Covariate rows to score (the notebook passes ``x_test``).
    model
        Fitted MultiStateNN model exposing ``predict_proba`` and ``state_transitions``.
    time_idx : int
        Time index forwarded to ``model.predict_proba``.
    from_state : int
        Origin state whose outgoing transition probabilities are predicted.

    Returns
    -------
    pandas.DataFrame
        One column per destination in ``model.state_transitions[from_state]``,
        named via the global ``state_labels``. Assumes ``predict_proba`` returns a
        2-D (n_rows, n_destinations) tensor in that destination order — TODO confirm.
    """
    probs = model.predict_proba(x, time_idx, from_state).detach().cpu().numpy()
    next_states = model.state_transitions[from_state]

    # Create DataFrame for display
    result_df = pd.DataFrame(probs, columns=[state_labels[s] for s in next_states])
    print(f"Transition probabilities from {state_labels[from_state]} state (t={time_idx}):")
    return result_df


# Predict from state 1 (HIV). If that fails, fall back to the first *other* state
# that has outgoing transitions; `from_state` is reused by the heatmap cell below.
from_state = 1  # HIV
try:
    probs_hiv = print_transition_probs(x_test, model, time_idx=0, from_state=from_state)
    display(probs_hiv)
except Exception as e:
    print(f"Error predicting from state {from_state}: {e}")
    failed_state = from_state
    # Exclude the state that just failed so the fallback does not retry it
    available_states = [
        s for s, targets in model_state_transitions.items()
        if targets and s != failed_state
    ]
    if available_states:
        from_state = available_states[0]
        print(f"Trying with state {from_state} instead...")
        # Guard the fallback too, so a second failure is reported instead of crashing the cell
        try:
            probs = print_transition_probs(x_test, model, time_idx=0, from_state=from_state)
            display(probs)
        except Exception as fallback_error:
            print(f"Error predicting from state {from_state}: {fallback_error}")

## Visualizing Transitions with Heatmap

Let's use the built-in visualization tools to explore the predicted transitions.

In [None]:
# Heatmap of predicted next-state probabilities for every sampled patient,
# starting from `from_state` as chosen in the prediction cell above.
plt.figure(figsize=(12, 8))
try:
    ax = plot_transition_heatmap(model, x_test, time_idx=0, from_state=from_state)

    # Swap the numeric tick labels for state names (x) and patient numbers (y)
    ax.set_xticklabels([state_labels[s] for s in model.state_transitions[from_state]])
    ax.set_yticklabels([f"Patient {i+1}" for i in range(len(x_test))])

    plt.title(f"Transition Probabilities from {state_labels[from_state]} State")
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error plotting transition heatmap: {e}")

## Transition Network Visualization

Let's visualize the transition network for a specific patient.

In [None]:
# First sampled patient; slicing (not indexing) keeps a leading batch dimension of size 1.
patient1 = x_test[:1]

# threshold=0.01 is forwarded to plot_transition_graph (presumably hides edges
# with probability below 1% — confirm in the library docs).
try:
    fig, ax = plot_transition_graph(model, patient1, time_idx=0, threshold=0.01)
    plt.title("AIDS Patient Transition Network (Patient 1)")
    plt.show()
except Exception as e:
    print(f"Error plotting transition graph: {e}")

## Simulating Patient Trajectories

Now, let's simulate patient trajectories to better understand the disease progression over time.

In [None]:
# Simulate repeated disease-progression paths for the first sampled patient and
# draw a few of them as step functions of state over time.
try:
    start_state = 1      # every path starts in state 1 (HIV)
    max_time = 10        # simulation horizon
    n_simulations = 20   # number of simulated paths

    trajectories = simulate_patient_trajectory(
        model=model,
        x=patient1,
        start_state=start_state,
        max_time=max_time,
        n_simulations=n_simulations,
        seed=42,
    )

    # One long table holding every simulated path, plus a human-readable state column
    all_trajectories = pd.concat(trajectories, ignore_index=True)
    all_trajectories['state_label'] = all_trajectories['state'].map(lambda s: state_labels[s])

    plt.figure(figsize=(14, 8))

    # Up to 5 distinct simulation indices, drawn from NumPy's global RNG
    sample_sims = np.random.choice(n_simulations, min(5, n_simulations), replace=False)
    for sim_idx in sample_sims:
        sim_data = all_trajectories.loc[all_trajectories['simulation'] == sim_idx]
        plt.step(sim_data['time'], sim_data['state'], where='post', label=f'Simulation {sim_idx+1}')

    plt.xlabel('Time')
    plt.ylabel('State')
    # y ticks at each state code, labelled with the state name
    plt.yticks(list(state_labels), list(state_labels.values()))
    plt.title('Sample Simulated Patient Trajectories')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error simulating patient trajectories: {e}")

## Calculating Cumulative Incidence Functions (CIFs)

Let's calculate the cumulative incidence function (CIF) for reaching the AIDS state, estimated from the simulated trajectories.

In [None]:
# Cumulative incidence of AIDS for patient 1, estimated from the simulated trajectories.
try:
    target_state = 2  # AIDS

    # Fallback for out-of-order execution: if the trajectory cell did not define
    # `all_trajectories`, re-simulate here with the same settings as that cell
    # (start in HIV, horizon 10, 20 simulations, seed 42). Keeping n_simulations
    # identical matters because the patient-comparison cell reuses it.
    # (At notebook top level, globals() is the user namespace.)
    if 'all_trajectories' not in globals():
        start_state = 1
        max_time = 10
        n_simulations = 20

        trajectories = simulate_patient_trajectory(
            model=model,
            x=patient1,
            start_state=start_state,
            max_time=max_time,
            n_simulations=n_simulations,
            seed=42
        )

        all_trajectories = pd.concat(trajectories, ignore_index=True)

    # Calculate CIF for reaching AIDS
    cif_aids = calculate_cif(
        trajectories=all_trajectories,
        target_state=target_state,
        max_time=max_time
    )

    # Plot CIF
    plt.figure(figsize=(10, 6))
    ax = plot_cif(
        cif_df=cif_aids,
        label=f"Progression to {state_labels[target_state]}"
    )
    plt.title(f"Cumulative Incidence of {state_labels[target_state]}")
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error calculating CIF: {e}")

## Comparing Different Patients

Let's compare CIFs for different patients.

In [None]:
# Per-patient comparison: simulate the first few sampled patients independently
# (a distinct seed per patient), then overlay their CIFs for the target state.
try:
    n_patients = min(3, len(x_test))

    # Simulated paths per patient, each tagged with a `patient_id` column
    all_patient_trajectories = []
    for i in range(n_patients):
        simulated_paths = simulate_patient_trajectory(
            model=model,
            x=x_test[i:i+1],
            start_state=start_state,
            max_time=max_time,
            n_simulations=n_simulations,
            seed=42 + i,
        )
        all_patient_trajectories.append(
            pd.concat(simulated_paths, ignore_index=True).assign(patient_id=i)
        )

    combined_trajectories = pd.concat(all_patient_trajectories, ignore_index=True)

    # One CIF per patient, computed from that patient's slice of the combined table
    patient_cifs = [
        calculate_cif(
            trajectories=combined_trajectories[combined_trajectories['patient_id'] == i],
            target_state=target_state,
            max_time=max_time,
        )
        for i in range(n_patients)
    ]

    fig, ax = compare_cifs(
        cif_list=patient_cifs,
        labels=[f"Patient {i+1}" for i in range(n_patients)],
        title=f"Comparison of Cumulative Incidence of {state_labels[target_state]}",
    )
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error comparing patient CIFs: {e}")

## Simulating Cohort Trajectories

Let's simulate trajectories for a cohort of patients.

In [None]:
# Simulate the whole sampled cohort, then overlay each patient's CIF (thin blue)
# on the pooled cohort CIF and its confidence band (red).
try:
    # Simulate trajectories for the entire cohort
    cohort_trajectories = simulate_cohort_trajectories(
        model=model,
        cohort_features=x_test,
        start_state=start_state,
        max_time=max_time,
        n_simulations_per_patient=10,
        seed=42
    )

    # Cohort-level CIF (pooled over all patients)
    cohort_cif = calculate_cif(
        trajectories=cohort_trajectories,
        target_state=target_state,
        max_time=max_time
    )

    # Individual patient CIFs (grouped by patient_id)
    individual_cifs = calculate_cif(
        trajectories=cohort_trajectories,
        target_state=target_state,
        max_time=max_time,
        by_patient=True
    )

    plt.figure(figsize=(12, 8))

    # Individual curves. The loop variable is not named `data`, so it does not
    # clobber the raw-dataset global of that name. Only the first curve gets a
    # legend label; '_nolegend_' hides the rest.
    for curve_idx, (_pid, patient_curve) in enumerate(individual_cifs.groupby('patient_id')):
        plt.plot(
            patient_curve['time'], patient_curve['cif'], 'b-', alpha=0.3, linewidth=1,
            label='Individual patients' if curve_idx == 0 else '_nolegend_'
        )

    # Cohort average CIF with its confidence band
    plt.plot(cohort_cif['time'], cohort_cif['cif'], 'r-', linewidth=2, label='Cohort Average')
    plt.fill_between(
        cohort_cif['time'],
        cohort_cif['lower_ci'],
        cohort_cif['upper_ci'],
        color='r',
        alpha=0.2,
        label='Cohort confidence interval'
    )

    plt.xlabel('Time')
    plt.ylabel('Cumulative Incidence')
    plt.title(f"Cumulative Incidence of {state_labels[target_state]} - Individual and Cohort Average")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error simulating cohort trajectories: {e}")

## Conclusion

In this notebook, we've demonstrated how to use MultiStateNN with the real-world AIDSSI dataset to:

1. Prepare and explore real multistate data
2. Train a neural network model for predicting state transitions
3. Visualize transition probabilities and state networks
4. Simulate patient trajectories over time
5. Calculate and visualize cumulative incidence functions
6. Compare risks between different patients

The MultiStateNN framework provides a flexible approach for modeling complex multistate processes using neural networks, making it suitable for a wide range of applications in healthcare, finance, and other domains.