In [4]:
import torch
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import List, Dict, Optional
from enum import Enum
import random

action_costs = {}

# 1. Data Structures
class Disease(Enum):
    HEALTHY = 0
    COUGH = 1
    PERSISTENT_COUGH = 2
    NSCLC = 3
    POST_INFECTIOUS = 4
    ACE_COUGH = 5

class Action(Enum):
    WAIT = 0
    XRAY = 1
    STOP_ACE = 2
    CT_SCAN = 3
    BIOPSY = 4

@dataclass
class PatientState:
    disease: Disease
    features: Dict[str, float]
    time: int
    history: List[Disease]
    action_history: List[Action]

# 2. Disease Progression Model (Simplified HSMM)
class DiseaseProgressionModel:
    def __init__(self):
        # Simplified transition probabilities
        self.transition_matrix = {
            Disease.HEALTHY: {Disease.COUGH: 0.1},
            Disease.COUGH: {
                Disease.PERSISTENT_COUGH: 0.3,
                Disease.POST_INFECTIOUS: 0.5,
                Disease.ACE_COUGH: 0.2
            },
            Disease.PERSISTENT_COUGH: {Disease.NSCLC: 0.1},
            Disease.NSCLC: {Disease.NSCLC: 1.0},  # Absorbing state
            Disease.POST_INFECTIOUS: {Disease.HEALTHY: 0.3},
            Disease.ACE_COUGH: {Disease.ACE_COUGH: 0.9, Disease.HEALTHY: 0.1}
        }

    def generate_features(self, disease: Disease) -> Dict[str, float]:
        """Generate synthetic features for each disease state"""
        base_features = {
            'cough_severity': random.gauss(0, 1),
            'duration_weeks': random.gauss(0, 1),
            'age': random.gauss(60, 10),
            'smoker': random.choice([0, 1])
        }
        
        # Modify features based on disease
        if disease == Disease.NSCLC:
            base_features['cough_severity'] += 2
            base_features['duration_weeks'] += 4
        elif disease == Disease.PERSISTENT_COUGH:
            base_features['cough_severity'] += 1
            base_features['duration_weeks'] += 2
            
        return base_features

    def next_state(self, current_state: PatientState) -> PatientState:
        possible_next = self.transition_matrix.get(current_state.disease, {})
        if not possible_next:
            return current_state
            
        next_disease = random.choices(
            list(possible_next.keys()),
            list(possible_next.values())
        )[0]
        
        return PatientState(
            disease=next_disease,
            features=self.generate_features(next_disease),
            time=current_state.time + 1,
            history=current_state.history + [current_state.disease],
            action_history=current_state.action_history
        )

# 3. Simple Decision Maker (replacing LLM for now)
class SimpleDecisionMaker:
    def decide_action(self, state: PatientState) -> Action:
        """Simple rule-based decision making"""
        if state.disease == Disease.COUGH:
            return Action.WAIT
        
        if state.disease == Disease.PERSISTENT_COUGH:
            if state.features['duration_weeks'] > 3:
                return Action.XRAY
            return Action.WAIT
            
        if state.disease == Disease.ACE_COUGH:
            return Action.STOP_ACE
            
        if state.disease == Disease.NSCLC:
            if Action.XRAY not in state.action_history:
                return Action.XRAY
            elif Action.CT_SCAN not in state.action_history:
                return Action.CT_SCAN
            elif Action.BIOPSY not in state.action_history:
                return Action.BIOPSY
                
        return Action.WAIT

# 4. Counterfactual Simulator
class CounterfactualSimulator:
    def __init__(self, disease_model: DiseaseProgressionModel, decision_maker: SimpleDecisionMaker):
        self.disease_model = disease_model
        self.decision_maker = decision_maker

    def simulate_trajectory(self, initial_state: PatientState, forced_actions: List[Action] = None) -> List[PatientState]:
        trajectory = [initial_state]
        current_state = initial_state
        
        for t in range(20):  # Max 20 time steps
            # Get next action (either forced or from decision maker)
            if forced_actions and t < len(forced_actions):
                action = forced_actions[t]
            else:
                action = self.decision_maker.decide_action(current_state)
                
            # Update state based on action and disease progression
            next_state = self.disease_model.next_state(current_state)
            next_state.action_history.append(action)
            
            trajectory.append(next_state)
            current_state = next_state
            
            # Stop if we reach a terminal state
            if action == Action.BIOPSY:
                break
                
        return trajectory

    def generate_counterfactuals(self, initial_state: PatientState) -> Dict[str, List[PatientState]]:
        counterfactuals = {
            "actual": self.simulate_trajectory(initial_state),
            "wait_only": self.simulate_trajectory(initial_state, [Action.WAIT] * 10),
            "aggressive": self.simulate_trajectory(initial_state, [Action.XRAY, Action.CT_SCAN, Action.BIOPSY])
        }
        return counterfactuals

# 5. Analysis Tools
def analyze_trajectory(trajectory: List[PatientState]) -> Dict:
    """Analyze a trajectory for key metrics"""
    return {
        "time_to_diagnosis": len(trajectory),
        "final_disease": trajectory[-1].disease,
        "actions_taken": len([a for a in trajectory[-1].action_history if a != Action.WAIT]),
        "cost": sum([action_costs.get(a, 0) for a in trajectory[-1].action_history])
    }

# 6. Usage Example
if __name__ == "__main__":
    # Initialize models
    disease_model = DiseaseProgressionModel()
    decision_maker = SimpleDecisionMaker()
    simulator = CounterfactualSimulator(disease_model, decision_maker)

    # Create initial patient state
    initial_state = PatientState(
        disease=Disease.COUGH,
        features={
            'cough_severity': 1.0,
            'duration_weeks': 2.0,
            'age': 62,
            'smoker': 1
        },
        time=0,
        history=[],
        action_history=[]
    )

    # Generate counterfactuals
    counterfactuals = simulator.generate_counterfactuals(initial_state)

    # Analyze results
    for scenario, trajectory in counterfactuals.items():
        print(f"\nScenario: {scenario}")
        analysis = analyze_trajectory(trajectory)
        print(f"Time to diagnosis: {analysis['time_to_diagnosis']}")
        print(f"Final disease: {analysis['final_disease']}")
        print(f"Actions taken: {analysis['actions_taken']}")
        
        # Print trajectory
        print("\nTrajectory:")
        for state in trajectory:
            print(f"Time {state.time}: Disease={state.disease}, " + 
                  f"Action={state.action_history[-1] if state.action_history else 'None'}")



Scenario: actual
Time to diagnosis: 9
Final disease: Disease.NSCLC
Actions taken: 9

Trajectory:
Time 0: Disease=Disease.COUGH, Action=Action.BIOPSY
Time 1: Disease=Disease.POST_INFECTIOUS, Action=Action.BIOPSY
Time 2: Disease=Disease.HEALTHY, Action=Action.BIOPSY
Time 3: Disease=Disease.COUGH, Action=Action.BIOPSY
Time 4: Disease=Disease.PERSISTENT_COUGH, Action=Action.BIOPSY
Time 5: Disease=Disease.NSCLC, Action=Action.BIOPSY
Time 6: Disease=Disease.NSCLC, Action=Action.BIOPSY
Time 7: Disease=Disease.NSCLC, Action=Action.BIOPSY
Time 8: Disease=Disease.NSCLC, Action=Action.BIOPSY

Scenario: wait_only
Time to diagnosis: 21
Final disease: Disease.COUGH
Actions taken: 9

Trajectory:
Time 0: Disease=Disease.COUGH, Action=Action.BIOPSY
Time 1: Disease=Disease.POST_INFECTIOUS, Action=Action.BIOPSY
Time 2: Disease=Disease.HEALTHY, Action=Action.BIOPSY
Time 3: Disease=Disease.COUGH, Action=Action.BIOPSY
Time 4: Disease=Disease.ACE_COUGH, Action=Action.BIOPSY
Time 5: Disease=Disease.ACE_COUGH