# Cumulative Incidence Function (CIF) Comparison Across Time Discretizations

This notebook demonstrates how the updated CIF calculation ensures consistency across different time discretizations. Previously, different time discretizations could lead to inconsistent CIF estimates, but the updated algorithm resolves this issue.

We use the "empirical" method for CIF calculation, which our testing shows works better than the default "aalen-johansen" method.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import seaborn as sns

from multistate_nn import fit, ModelConfig, TrainConfig
from multistate_nn.utils import (
    generate_synthetic_data, 
    simulate_cohort_trajectories,
    calculate_cif,
    compare_cifs, 
    plot_cif
)

# Seed NumPy and PyTorch up front so repeated runs of the notebook match
np.random.seed(42)
torch.manual_seed(42)

# Plot styling
plt.style.use('ggplot')
sns.set_context("notebook", font_scale=1.2)

# Parameters shared by both synthetic datasets
n_samples = 1000
n_covariates = 3
n_states = 4

# Allowed transitions, keyed by source state. State 0 can jump directly to
# every other state so the intermediate states are reliably reached under
# either discretization; state 3 has no outgoing transitions.
state_transitions = {
    0: [1, 2, 3],
    1: [2, 3],
    2: [3],
    3: [],
}

# Arguments that are identical for both generate_synthetic_data calls;
# only the time grid differs between the two datasets.
_shared_data_kwargs = dict(
    n_samples=n_samples,
    n_covariates=n_covariates,
    n_states=n_states,
    state_transitions=state_transitions,
    random_seed=42,
)

# Fine grid (e.g., monthly): 0, 30, ..., 360
fine_time_values = np.arange(0, 361, 30)
# Reset the global RNGs immediately before generating the data
np.random.seed(42)
torch.manual_seed(42)
fine_data = generate_synthetic_data(
    n_time_points=len(fine_time_values),
    time_values=fine_time_values,
    **_shared_data_kwargs,
)

# Coarse grid (e.g., quarterly): 0, 90, ..., 360
coarse_time_values = np.arange(0, 361, 90)
# Same seeds again so both datasets start from the same RNG state
np.random.seed(42)
torch.manual_seed(42)
coarse_data = generate_synthetic_data(
    n_time_points=len(coarse_time_values),
    time_values=coarse_time_values,
    **_shared_data_kwargs,
)

# Peek at the first rows of each dataset
print("Fine discretization data sample:")
display(fine_data.head())

print("\nCoarse discretization data sample:")
display(coarse_data.head())

In [ ]:
# Names of the covariate columns handed to the models
covariates = [f"covariate_{idx}" for idx in range(3)]


def create_configs(n_states=4, state_transitions=None):
    """Build a fresh (ModelConfig, TrainConfig) pair for one model fit.

    The network input width is taken from the module-level ``covariates``
    list; the remaining hyper-parameters are fixed here.

    Args:
        n_states: Passed to ``ModelConfig`` as ``num_states``.
        state_transitions: Passed through to ``ModelConfig`` unchanged.

    Returns:
        Tuple ``(model_config, train_config)``.
    """
    return (
        ModelConfig(
            input_dim=len(covariates),
            hidden_dims=[32, 16],
            num_states=n_states,
            state_transitions=state_transitions,
        ),
        TrainConfig(
            batch_size=64,
            epochs=100,  # raised for better convergence
            learning_rate=0.01,
            use_original_time=True,  # critical: real time values, not indices
        ),
    )

# --- Train one model per discretization ---------------------------------
print("Training model with fine time discretization...")
model_config_fine, train_config_fine = create_configs(
    n_states=n_states,
    state_transitions=state_transitions,
)
torch.manual_seed(42)  # identical seed before each fit
fine_model = fit(
    df=fine_data,
    covariates=covariates,
    model_config=model_config_fine,
    train_config=train_config_fine,
)

print("\nTraining model with coarse time discretization...")
model_config_coarse, train_config_coarse = create_configs(
    n_states=n_states,
    state_transitions=state_transitions,
)
torch.manual_seed(42)  # identical seed before each fit
coarse_model = fit(
    df=coarse_data,
    covariates=covariates,
    model_config=model_config_coarse,
    train_config=train_config_coarse,
)

print("\nBoth models trained successfully!")

# --- Simulate trajectories from both models -----------------------------
test_features = torch.zeros((1, 3))  # one patient with all-zero (neutral) features

max_time = 360  # last time point in the generated data; we ONLY simulate inside that range
n_simulations = 1000  # larger count for more stable estimates

# Everything except the model is the same for the two simulation runs
_sim_kwargs = dict(
    cohort_features=test_features,
    start_state=0,
    max_time=max_time,
    n_simulations_per_patient=n_simulations,
    seed=1234,
    use_original_time=True,
)

# Reset the global RNGs before each run to reduce run-to-run variability
torch.manual_seed(1234)
np.random.seed(1234)

print("Simulating trajectories with fine model...")
fine_trajectories = simulate_cohort_trajectories(model=fine_model, **_sim_kwargs)

torch.manual_seed(1234)
np.random.seed(1234)

print("Simulating trajectories with coarse model...")
coarse_trajectories = simulate_cohort_trajectories(model=coarse_model, **_sim_kwargs)

print("Simulation complete!")

# Show up to ten rows of simulation 0 from each run
print("\nFine model trajectory example:")
display(fine_trajectories.loc[fine_trajectories['simulation'] == 0].head(10))

print("\nCoarse model trajectory example:")
display(coarse_trajectories.loc[coarse_trajectories['simulation'] == 0].head(10))

In [ ]:
# Shared evaluation grid: 100 evenly spaced points over [0, 360]
time_grid = np.linspace(0, 360, 100)

# Non-absorbing states to inspect; each gets its own panel
states_to_check = [1, 2]

# Arguments common to every calculate_cif call in this cell
_cif_kwargs = dict(
    time_grid=time_grid,
    max_time=360,  # explicitly stay inside the observed time range
    method="empirical",
)

plt.figure(figsize=(15, 10))

for panel, target_state in enumerate(states_to_check, start=1):
    # CIFs for both models, evaluated on the same grid
    fine_cif = calculate_cif(fine_trajectories, target_state=target_state, **_cif_kwargs)
    coarse_cif = calculate_cif(coarse_trajectories, target_state=target_state, **_cif_kwargs)

    ax = plt.subplot(1, len(states_to_check), panel)

    # Curves first ...
    ax.plot(fine_cif['time'], fine_cif['cif'], 'b-', linewidth=2,
            label='Fine Discretization')
    ax.plot(coarse_cif['time'], coarse_cif['cif'], 'r-', linewidth=2,
            label='Coarse Discretization')

    # ... then the shaded confidence bands
    for cif_df, band_colour in ((fine_cif, 'blue'), (coarse_cif, 'red')):
        ax.fill_between(cif_df['time'], cif_df['lower_ci'], cif_df['upper_ci'],
                        color=band_colour, alpha=0.1)

    ax.set_title(f'CIF for State {target_state}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Cumulative Incidence')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Largest pointwise gap between the two curves, annotated inside the panel
    max_diff = np.abs(fine_cif['cif'].values - coarse_cif['cif'].values).max()
    ax.text(0.05, 0.95, f'Max Diff: {max_diff:.3f}', transform=ax.transAxes,
            backgroundcolor='white', fontsize=10)

plt.suptitle('CIF Comparison with Different Time Discretizations', fontsize=16)
plt.tight_layout()
plt.savefig('cif_discretization_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

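## 3. Old (Index-Based) CIF Calculation for an Intermediate State

The next cell re-implements the previous, index-based CIF calculation and applies it to an intermediate (non-absorbing) state, so that its output can be plotted alongside the time-based CIFs computed above.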

In [None]:
def calculate_cif_old(trajectories, target_state, max_time=None):
    """Old index-based CIF calculation (simplified for demonstration).

    Incidence is tracked per *time index* rather than per actual time value.

    Args:
        trajectories: DataFrame with 'simulation', 'time' and 'state'
            columns. If it has no 'time_idx' column, one is derived from
            each row's position within its simulation. The input is not
            modified.
        target_state: State whose cumulative incidence is computed.
        max_time: If given, rows with ``time > max_time`` are dropped
            (after the time indices have been assigned).

    Returns:
        DataFrame with one row per distinct time index and the columns
        'time_idx', 'time' (the 'time' of the first row carrying that
        index) and 'cif' (fraction of simulations whose first row in
        ``target_state`` has an index <= this one).
    """
    frame = trajectories.copy()

    # Fall back to the row position inside each simulation as the index
    if 'time_idx' not in frame.columns:
        frame['time_idx'] = frame.groupby('simulation').cumcount()

    # Optional truncation on the real time axis
    if max_time is not None:
        frame = frame[frame['time'] <= max_time]

    # Sorted distinct indices; the array form allows a vectorised comparison
    time_points = sorted(frame['time_idx'].unique())
    index_grid = np.asarray(time_points)

    per_sim = frame.groupby('simulation')
    incidence = np.zeros((len(per_sim), len(time_points)))

    for row, (_, sim_rows) in enumerate(per_sim):
        hit_indices = sim_rows.loc[sim_rows['state'] == target_state, 'time_idx']
        if hit_indices.empty:
            continue  # this simulation never visits the target state
        # Mark every index at or after the first matching row
        incidence[row, index_grid >= hit_indices.iloc[0]] = 1

    # 'time' of the first row seen for each index, used to label the output
    first_time_at = frame.drop_duplicates(subset='time_idx').set_index('time_idx')['time']

    return pd.DataFrame({
        'time_idx': time_points,
        'time': [first_time_at.loc[t] for t in time_points],
        'cif': incidence.mean(axis=0),
    })

# Old-style CIFs for an intermediate (non-absorbing) state
target_state = 2
max_time = 360  # stay inside the observed time range

fine_cif_old = calculate_cif_old(fine_trajectories, target_state=target_state, max_time=max_time)
coarse_cif_old = calculate_cif_old(coarse_trajectories, target_state=target_state, max_time=max_time)

# Figure 1: the two discretizations under the old approach
plt.figure(figsize=(12, 8))

for old_cif, line_fmt, curve_label in (
    (fine_cif_old, 'b-', 'Fine (Old Method)'),
    (coarse_cif_old, 'r-', 'Coarse (Old Method)'),
):
    plt.plot(old_cif['time'], old_cif['cif'], line_fmt, linewidth=2, label=curve_label)

plt.xlabel('Time')
plt.ylabel('Cumulative Incidence')
plt.title('Old Implementation: CIF Comparison with Different Time Discretizations')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

# Figure 2: the new implementation for direct comparison.
# NOTE: fine_cif / coarse_cif are not computed in this cell; they are
# whatever an earlier cell left in the namespace.
fig, ax = compare_cifs(
    [fine_cif, coarse_cif],
    labels=['Fine (New Method)', 'Coarse (New Method)'],
    title='New Implementation: CIF Comparison with Different Time Discretizations',
    common_time_grid=True,
)

plt.show()

## 8. Conclusions

This notebook demonstrates how our updated CIF calculation ensures consistent results across different time discretizations. The key improvements include:

1. Using actual time values rather than time indices for CIF calculation
2. Evaluating CIFs on a consistent time grid
3. Identifying first occurrences of events based on original time values
4. Providing options for custom time grids to facilitate comparison
5. **IMPORTANT:** Limiting simulation and CIF calculation to the observed time range
6. Using the "empirical" method instead of "aalen-johansen" for more reliable results

These changes ensure that CIFs are comparable regardless of the time discretization used, making the MultiStateNN package more robust and reliable for clinical research and other applications.

### Note on CIF convergence

When CIFs are calculated for absorbing states (like death) over unlimited time, they will naturally converge to 1.0 if all trajectories eventually reach that state. To get more realistic CIFs:

1. **Limit to observed time range**: Only simulate and calculate CIFs within the time range observed in the training data
2. **Use intermediate states**: Calculate CIFs for non-absorbing states when appropriate
3. **Consider competing risks**: In real-world scenarios, competing risks create CIFs that may plateau below 1.0

In [None]:
# A single test patient whose features are all zero (neutral)
test_features = torch.zeros((1, 3))

# Simulation settings
max_time = 360  # one year
n_simulations = 500

# Everything except the model is shared between the two simulation runs
_sim_kwargs = dict(
    cohort_features=test_features,
    start_state=0,
    max_time=max_time,
    n_simulations_per_patient=n_simulations,
    seed=123,
    use_original_time=True,
)

print("Simulating trajectories with fine model...")
fine_trajectories = simulate_cohort_trajectories(model=fine_model, **_sim_kwargs)

print("Simulating trajectories with coarse model...")
coarse_trajectories = simulate_cohort_trajectories(model=coarse_model, **_sim_kwargs)

print("Simulation complete!")

# Show up to ten rows of simulation 0 from each run
print("\nFine model trajectory example:")
display(fine_trajectories.loc[fine_trajectories['simulation'] == 0].head(10))

print("\nCoarse model trajectory example:")
display(coarse_trajectories.loc[coarse_trajectories['simulation'] == 0].head(10))

## 5. Calculate and Compare CIFs

Now we'll calculate the Cumulative Incidence Functions (CIFs) for both models and compare them. If our fix is working correctly, the CIFs should be very similar despite the different time discretizations.

In [ ]:
# Shared evaluation grid: 100 evenly spaced points over [0, 360]
time_grid = np.linspace(0, 360, 100)

# Target: the absorbing state (state 3)
target_state = 3

# Both CIFs use the same target, grid and estimator
_cif_kwargs = dict(
    target_state=target_state,
    time_grid=time_grid,
    method="empirical",
)
fine_cif = calculate_cif(fine_trajectories, **_cif_kwargs)
coarse_cif = calculate_cif(coarse_trajectories, **_cif_kwargs)

# Overlay the two curves
fig, ax = compare_cifs(
    [fine_cif, coarse_cif],
    labels=['Fine Discretization', 'Coarse Discretization'],
    title=f'CIF Comparison for State {target_state} with Different Time Discretizations',
    common_time_grid=True,  # ensure comparable visualization
)

# Save to disk, then render
plt.savefig('cif_discretization_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Quantitative Comparison

Let's calculate some metrics to quantify the differences between the two CIFs.

In [None]:
# Raw CIF values of the two curves (both were evaluated on time_grid)
fine_values = fine_cif['cif'].values
coarse_values = coarse_cif['cif'].values

# Pointwise gap and its summaries
signed_diff = fine_values - coarse_values
abs_diff = np.abs(signed_diff)
mean_abs_diff = abs_diff.mean()
max_abs_diff = abs_diff.max()

# Root mean squared error
rmse = np.sqrt(np.mean(signed_diff ** 2))

# Report the three summary numbers
print("Quantitative Comparison of CIFs:")
print(f"Mean Absolute Difference: {mean_abs_diff:.6f}")
print(f"Maximum Absolute Difference: {max_abs_diff:.6f}")
print(f"Root Mean Squared Error: {rmse:.6f}")

# Absolute difference over time, with the mean drawn as a dashed reference line
plt.figure(figsize=(10, 6))
plt.plot(time_grid, abs_diff, 'r-', linewidth=2)
plt.axhline(mean_abs_diff, color='k', linestyle='--', label=f'Mean: {mean_abs_diff:.6f}')
plt.fill_between(time_grid, 0, abs_diff, alpha=0.2, color='r')
plt.xlabel('Time')
plt.ylabel('Absolute Difference')
plt.title('Absolute Difference Between Fine and Coarse CIFs')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

## 7. Previous vs Current Implementation Comparison

For educational purposes, let's implement the old (index-based) CIF calculation and compare it with our new approach. This demonstrates the improvements we've made.

In [None]:
def calculate_cif_old(trajectories, target_state, max_time=None):
    """Old index-based CIF calculation (simplified for demonstration).

    Incidence is tracked per *time index* rather than per actual time value.

    Args:
        trajectories: DataFrame with 'simulation', 'time' and 'state'
            columns. If it has no 'time_idx' column, one is derived from
            each row's position within its simulation. The input is not
            modified.
        target_state: State whose cumulative incidence is computed.
        max_time: If given, rows with ``time > max_time`` are dropped
            (after the time indices have been assigned).

    Returns:
        DataFrame with one row per distinct time index and the columns
        'time_idx', 'time' (the 'time' of the first row carrying that
        index) and 'cif' (fraction of simulations whose first row in
        ``target_state`` has an index <= this one).
    """
    frame = trajectories.copy()

    # Fall back to the row position inside each simulation as the index
    if 'time_idx' not in frame.columns:
        frame['time_idx'] = frame.groupby('simulation').cumcount()

    # Optional truncation on the real time axis
    if max_time is not None:
        frame = frame[frame['time'] <= max_time]

    # Sorted distinct indices; the array form allows a vectorised comparison
    time_points = sorted(frame['time_idx'].unique())
    index_grid = np.asarray(time_points)

    per_sim = frame.groupby('simulation')
    incidence = np.zeros((len(per_sim), len(time_points)))

    for row, (_, sim_rows) in enumerate(per_sim):
        hit_indices = sim_rows.loc[sim_rows['state'] == target_state, 'time_idx']
        if hit_indices.empty:
            continue  # this simulation never visits the target state
        # Mark every index at or after the first matching row
        incidence[row, index_grid >= hit_indices.iloc[0]] = 1

    # 'time' of the first row seen for each index, used to label the output
    first_time_at = frame.drop_duplicates(subset='time_idx').set_index('time_idx')['time']

    return pd.DataFrame({
        'time_idx': time_points,
        'time': [first_time_at.loc[t] for t in time_points],
        'cif': incidence.mean(axis=0),
    })

# Old-style CIFs for state 3 (no max_time cut-off here)
fine_cif_old = calculate_cif_old(fine_trajectories, target_state=3)
coarse_cif_old = calculate_cif_old(coarse_trajectories, target_state=3)

# Figure 1: the two discretizations under the old approach
plt.figure(figsize=(12, 8))

for old_cif, line_fmt, curve_label in (
    (fine_cif_old, 'b-', 'Fine (Old Method)'),
    (coarse_cif_old, 'r-', 'Coarse (Old Method)'),
):
    plt.plot(old_cif['time'], old_cif['cif'], line_fmt, linewidth=2, label=curve_label)

plt.xlabel('Time')
plt.ylabel('Cumulative Incidence')
plt.title('Old Implementation: CIF Comparison with Different Time Discretizations')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

# Figure 2: the new implementation for direct comparison.
# fine_cif / coarse_cif are taken from the namespace as left by earlier cells.
fig, ax = compare_cifs(
    [fine_cif, coarse_cif],
    labels=['Fine (New Method)', 'Coarse (New Method)'],
    title='New Implementation: CIF Comparison with Different Time Discretizations',
    common_time_grid=True,
)

plt.show()

## 8. Conclusions

This notebook demonstrates how our updated CIF calculation ensures consistent results across different time discretizations. The key improvements include:

1. Using actual time values rather than time indices for CIF calculation
2. Evaluating CIFs on a consistent time grid
3. Identifying first occurrences of events based on original time values
4. Providing options for custom time grids to facilitate comparison
5. Using the "empirical" method instead of "aalen-johansen" for more reliable results

These changes ensure that CIFs are comparable regardless of the time discretization used, making the MultiStateNN package more robust and reliable for clinical research and other applications.