In [22]:
"""
Comprehensive Medical Trajectory Modeling with Interpretable Embeddings
This notebook demonstrates a complete implementation of:
1. Disease state embedding
2. Interpretable transitions
3. Trajectory analysis
4. Visualization
5. Synthetic data generation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from torch.utils.data import Dataset, DataLoader
import umap
from sklearn.manifold import TSNE
import networkx as nx
from tqdm import tqdm
import random
from collections import defaultdict

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# 1. Data Structures and Medical Knowledge Base
@dataclass
class ClinicalConcept:
    """Represents a clinical concept with attributes"""
    name: str
    category: str
    related_concepts: List[str]
    temporal_profile: Dict[str, float]  # e.g., {"acute": 0.8, "chronic": 0.2}
    severity_profile: Dict[str, float]  # e.g., {"mild": 0.7, "severe": 0.3}

@dataclass
class Disease:
    """Represents a disease with its attributes"""
    name: str
    concepts: List[ClinicalConcept]
    common_symptoms: List[str]
    progression_profile: Dict[str, float]
    typical_duration: int  # in days

@dataclass
class PatientState:
    """Represents a patient's state at a point in time"""
    disease_states: List[str]  # Can have multiple concurrent conditions
    symptoms: Dict[str, float]
    vitals: Dict[str, float]
    lab_values: Dict[str, float]
    timestamp: int

# 2. Medical Knowledge Base Setup
def create_medical_knowledge_base():
    """Create a synthetic medical knowledge base"""
    # Clinical Concepts
    respiratory = ClinicalConcept(
        name="respiratory",
        category="system",
        related_concepts=["cough", "dyspnea"],
        temporal_profile={"acute": 0.6, "chronic": 0.4},
        severity_profile={"mild": 0.7, "moderate": 0.2, "severe": 0.1}
    )
    
    infectious = ClinicalConcept(
        name="infectious",
        category="etiology",
        related_concepts=["fever", "inflammation"],
        temporal_profile={"acute": 0.8, "chronic": 0.2},
        severity_profile={"mild": 0.6, "moderate": 0.3, "severe": 0.1}
    )
    
    # Diseases
    diseases = {
        "Viral Upper Respiratory Infection": Disease(
            name="Viral Upper Respiratory Infection",
            concepts=[respiratory, infectious],
            common_symptoms=["cough", "fever", "rhinorrhea"],
            progression_profile={"improvement": 0.7, "persistence": 0.2, "worsening": 0.1},
            typical_duration=10
        ),
        "Bacterial Pneumonia": Disease(
            name="Bacterial Pneumonia",
            concepts=[respiratory, infectious],
            common_symptoms=["cough", "fever", "dyspnea"],
            progression_profile={"improvement": 0.6, "persistence": 0.2, "worsening": 0.2},
            typical_duration=14
        ),
        # Add more diseases as needed
    }
    
    return diseases

# 3. Synthetic Data Generation
class SyntheticPatientGenerator:
    def __init__(self, medical_kb: Dict[str, Disease]):
        self.medical_kb = medical_kb
        self.vital_ranges = {
            "temperature": (36.5, 38.5),
            "heart_rate": (60, 100),
            "respiratory_rate": (12, 20),
            "blood_pressure_systolic": (90, 140),
            "blood_pressure_diastolic": (60, 90),
            "oxygen_saturation": (95, 100)
        }
        
    def generate_vitals(self, disease: Disease, severity: float) -> Dict[str, float]:
        """Generate synthetic vital signs based on disease and severity"""
        vitals = {}
        for vital, (low, high) in self.vital_ranges.items():
            # Add disease-specific modifications
            if disease.name == "bacterial_pneumonia" and vital == "temperature":
                low += severity * 2  # Higher fever with more severe pneumonia
            
            vitals[vital] = np.random.uniform(low, high)
        return vitals
    
    def generate_trajectory(self, 
                          initial_disease: str, 
                          duration: int,
                          complication_prob: float = 0.1) -> List[PatientState]:
        """Generate a synthetic patient trajectory"""
        trajectory = []
        current_disease = self.medical_kb[initial_disease]
        severity = np.random.beta(2, 5)  # Initial severity
        
        for t in range(duration):
            # Generate vitals and symptoms
            vitals = self.generate_vitals(current_disease, severity)
            symptoms = {symptom: np.random.uniform(0.5, 1.0) 
                      for symptom in current_disease.common_symptoms}
            
            # Create patient state
            state = PatientState(
                disease_states=[current_disease.name],
                symptoms=symptoms,
                vitals=vitals,
                lab_values={},  # Add if needed
                timestamp=t
            )
            trajectory.append(state)
            
            # Update severity based on progression profile
            prog = np.random.choice(
                list(current_disease.progression_profile.keys()),
                p=list(current_disease.progression_profile.values())
            )
            
            if prog == "improvement":
                severity *= 0.8
            elif prog == "worsening":
                severity *= 1.2
                
            # Possible complication
            if np.random.random() < complication_prob:
                # Add complication logic here
                pass
                
        return trajectory


In [23]:
def create_vocabulary(trajectories: List[List[PatientState]], 
                     medical_kb: Dict[str, Disease]) -> Dict[str, Dict[str, int]]:
    """
    Create vocabulary mappings for diseases, symptoms, and clinical concepts
    Returns:
        Dict with mappings for diseases, symptoms, and concepts to indices
    """
    vocab = {
        'diseases': {'<PAD>': 0},
        'symptoms': {'<PAD>': 0},
        'concepts': {'<PAD>': 0}
    }
    
    # Collect all unique values
    diseases = set()
    symptoms = set()
    concepts = set()
    
    # From trajectories
    for trajectory in trajectories:
        for state in trajectory:
            diseases.update(state.disease_states)
            symptoms.update(state.symptoms.keys())
    
    # From medical knowledge base
    for disease in medical_kb.values():
        diseases.add(disease.name)
        symptoms.update(disease.common_symptoms)
        for concept in disease.concepts:
            concepts.add(concept.name)
            concepts.update(concept.related_concepts)
    
    # Create mappings
    for i, disease in enumerate(sorted(diseases), start=1):
        vocab['diseases'][disease] = i
    
    for i, symptom in enumerate(sorted(symptoms), start=1):
        vocab['symptoms'][symptom] = i
        
    for i, concept in enumerate(sorted(concepts), start=1):
        vocab['concepts'][concept] = i
    
    return vocab

# Additional utility functions

def save_vocabulary(vocab: Dict[str, Dict[str, int]], path: str):
    """Save vocabulary to file"""
    with open(path, 'w') as f:
        json.dump(vocab, f, indent=2)

def load_vocabulary(path: str) -> Dict[str, Dict[str, int]]:
    """Load vocabulary from file"""
    with open(path, 'r') as f:
        return json.load(f)

def get_vocabulary_sizes(vocab: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Get sizes of each vocabulary"""
    return {k: len(v) for k, v in vocab.items()}

class VocabularyHandler:
    """Class to handle vocabulary operations"""
    def __init__(self, vocab: Dict[str, Dict[str, int]]):
        self.vocab = vocab
        self.inverse_vocab = {
            category: {idx: token for token, idx in mappings.items()}
            for category, mappings in vocab.items()
        }
    
    def encode(self, category: str, items: List[str]) -> torch.Tensor:
        """Encode items to indices"""
        return torch.tensor([
            self.vocab[category].get(item, self.vocab[category]['<PAD>'])
            for item in items
        ])
    
    def decode(self, category: str, indices: torch.Tensor) -> List[str]:
        """Decode indices to items"""
        return [
            self.inverse_vocab[category].get(idx.item(), '<PAD>')
            for idx in indices
        ]
    
    def get_size(self, category: str) -> int:
        """Get vocabulary size for category"""
        return len(self.vocab[category])

# # Example usage in main():
# def main():
#     # Create medical knowledge base
#     medical_kb = create_medical_knowledge_base()
    
#     # Generate synthetic data
#     generator = SyntheticPatientGenerator(medical_kb)
#     trajectories = [
#         generator.generate_trajectory(
#             initial_disease='viral_uri',
#             duration=random.randint(5, 15)
#         )
#         for _ in range(100)
#     ]
    
#     # Create vocabulary
#     vocab = create_vocabulary(trajectories, medical_kb)
#     vocab_handler = VocabularyHandler(vocab)
    
#     print("Vocabulary sizes:")
#     for category, size in get_vocabulary_sizes(vocab).items():
#         print(f"{category}: {size} items")
    
#     # Example encoding/decoding
#     example_diseases = ['viral_uri', 'bacterial_pneumonia']
#     encoded = vocab_handler.encode('diseases', example_diseases)
#     decoded = vocab_handler.decode('diseases', encoded)
#     print("\nEncoding/Decoding example:")
#     print(f"Original: {example_diseases}")
#     print(f"Encoded: {encoded}")
#     print(f"Decoded: {decoded}")
    
#     # Create dataset with vocabulary handler
#     dataset = ClinicalTrajectoryDataset(trajectories, medical_kb, vocab)
    
#     # Continue with the rest of the main function...


In [24]:
# 4. Embedding Models
class ClinicalConceptEmbedder(nn.Module):
    """Embeds clinical concepts into a learned space"""
    def __init__(self, n_concepts: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(n_concepts, embedding_dim)
        self.concept_projection = nn.Linear(embedding_dim, embedding_dim)
        
    def forward(self, concept_ids: torch.Tensor) -> torch.Tensor:
        concept_embeddings = self.embedding(concept_ids)
        return self.concept_projection(concept_embeddings)

class DiseaseStateEmbedder(nn.Module):
    """Combines multiple elements to create a disease state embedding"""
    def __init__(self, 
                 n_diseases: int,
                 n_symptoms: int,
                 n_concepts: int,
                 embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        
        # Component embedders
        self.disease_embedding = nn.Embedding(n_diseases, embedding_dim)
        self.symptom_embedding = nn.Embedding(n_symptoms, embedding_dim)
        self.concept_embedder = ClinicalConceptEmbedder(n_concepts, embedding_dim)
        
        # Attention mechanism for combining embeddings
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads=4)
        
        # Final projection
        self.final_projection = nn.Sequential(
            nn.Linear(embedding_dim * 3, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU()
        )
        
    def forward(self, 
                disease_ids: torch.Tensor,
                symptom_ids: torch.Tensor,
                concept_ids: torch.Tensor) -> torch.Tensor:
        # Get component embeddings
        disease_emb = self.disease_embedding(disease_ids)
        symptom_emb = self.symptom_embedding(symptom_ids)
        concept_emb = self.concept_embedder(concept_ids)
        
        # Combine using attention
        attended_disease, _ = self.attention(
            disease_emb.unsqueeze(0),
            torch.cat([symptom_emb, concept_emb], dim=0),
            torch.cat([symptom_emb, concept_emb], dim=0)
        )
        
        # Concatenate and project
        combined = torch.cat([
            attended_disease.squeeze(0),
            symptom_emb.mean(0).unsqueeze(0),
            concept_emb.mean(0).unsqueeze(0)
        ], dim=-1)
        
        return self.final_projection(combined)

class TransitionModel(nn.Module):
    """Models transitions between disease states"""
    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        
        # GRU for temporal modeling
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=embedding_dim,
            num_layers=2,
            dropout=0.1,
            batch_first=True
        )
        
        # Transition scoring
        self.transition_scorer = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, 1)
        )
        
    def forward(self, 
                current_embedding: torch.Tensor,
                history_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor:
        if history_embeddings is not None:
            # Include history in transition computation
            _, hidden = self.gru(history_embeddings)
            hidden = hidden[-1]  # Take last layer's hidden state
        else:
            hidden = current_embedding
            
        # Predict next embedding
        next_embedding = self.gru(current_embedding.unsqueeze(1), hidden.unsqueeze(0))
        return next_embedding[0]

class InterpretableClinicModel(nn.Module):
    """Complete model with interpretability layers"""
    def __init__(self,
                 n_diseases: int,
                 n_symptoms: int,
                 n_concepts: int,
                 embedding_dim: int,
                 n_clinical_patterns: int):
        super().__init__()
        
        # Core components
        self.state_embedder = DiseaseStateEmbedder(
            n_diseases, n_symptoms, n_concepts, embedding_dim)
        self.transition_model = TransitionModel(embedding_dim)
        
        # Interpretability components
        self.pattern_prototypes = nn.Parameter(
            torch.randn(n_clinical_patterns, embedding_dim))
        self.pattern_descriptions = [f"Pattern_{i}" for i in range(n_clinical_patterns)]
        
        # Decoders for interpretation
        self.disease_decoder = nn.Linear(embedding_dim, n_diseases)
        self.symptom_decoder = nn.Linear(embedding_dim, n_symptoms)
        
        # Clinical reasoning module
        self.reasoning_templates = {
            "progression": "Disease progression suggests {}, based on {}",
            "complication": "Possible complication of {} due to {}",
            "improvement": "Improvement noted in {} with evidence of {}"
        }
        
    def forward(self, 
                current_state: Dict[str, torch.Tensor],
                history: Optional[List[Dict[str, torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        # Embed current state
        current_embedding = self.state_embedder(
            current_state['disease_ids'],
            current_state['symptom_ids'],
            current_state['concept_ids']
        )
        
        # Get history embeddings if available
        history_embeddings = None
        if history:
            history_embeddings = torch.stack([
                self.state_embedder(h['disease_ids'], 
                                  h['symptom_ids'],
                                  h['concept_ids'])
                for h in history
            ])
        
        # Predict next state
        next_embedding = self.transition_model(current_embedding, history_embeddings)
        
        # Generate interpretations
        interpretations = self.interpret_transition(current_embedding, next_embedding)
        
        return {
            'next_embedding': next_embedding,
            'interpretations': interpretations
        }
        
    def interpret_transition(self, 
                           current_embedding: torch.Tensor,
                           next_embedding: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Compare with clinical patterns
        pattern_similarities = F.cosine_similarity(
            next_embedding.unsqueeze(1),
            self.pattern_prototypes.unsqueeze(0),
            dim=-1
        )
        
        # Decode probable diseases and symptoms
        disease_probs = F.softmax(self.disease_decoder(next_embedding), dim=-1)
        symptom_probs = torch.sigmoid(self.symptom_decoder(next_embedding))
        
        # Measure change
        state_change = next_embedding - current_embedding
        change_magnitude = torch.norm(state_change, dim=-1)
        
        return {
            'pattern_similarities': pattern_similarities,
            'disease_probabilities': disease_probs,
            'symptom_probabilities': symptom_probs,
            'change_magnitude': change_magnitude
        }
        
    def generate_explanation(self, interpretations: Dict[str, torch.Tensor]) -> str:
        """Generate natural language explanation of transition"""
        top_pattern_idx = interpretations['pattern_similarities'].argmax()
        top_pattern = self.pattern_descriptions[top_pattern_idx]
        
        top_diseases = torch.topk(
            interpretations['disease_probabilities'],
            k=3
        )
        
        explanation = (
            f"Clinical pattern most closely matches {top_pattern}. "
            f"Suggesting possible progression to: {top_diseases}. "
            f"Magnitude of change: {interpretations['change_magnitude']:.2f}"
        )
        
        return explanation


In [40]:
# 5. Dataset and Training Infrastructure
class ClinicalTrajectoryDataset(Dataset):
    """Dataset for clinical trajectories"""
    def __init__(self, 
                 trajectories: List[List[PatientState]],
                 medical_kb: Dict[str, Disease],
                 vocab: Dict[str, Dict[str, int]]):
        self.trajectories = trajectories
        self.medical_kb = medical_kb
        self.vocab_handler = VocabularyHandler(vocab)
        
    def __len__(self):
        return len(self.trajectories)
    
    def __getitem__(self, idx):
        trajectory = self.trajectories[idx]
        
        # Convert trajectory to tensors
        sequence = []
        for state in trajectory:
            # Convert state to indices using vocab handler
            disease_ids = self.vocab_handler.encode('diseases', state.disease_states)
            symptom_ids = self.vocab_handler.encode('symptoms', list(state.symptoms.keys()))
            concept_ids = self.vocab_handler.encode(
                'concepts',
                [c.name for d in state.disease_states
                 for c in self.medical_kb[d].concepts]
            )
            
            sequence.append({
                'disease_ids': disease_ids,
                'symptom_ids': symptom_ids,
                'concept_ids': concept_ids,
                'vitals': torch.tensor(list(state.vitals.values())),
                'timestamp': state.timestamp
            })
            
        return sequence


class TrainingEngine:
    """Handles model training and validation"""
    def __init__(self,
                 model: InterpretableClinicModel,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        
        self.optimizer = torch.optim.Adam(model.parameters())
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=5)
        
        # Loss functions
        self.transition_loss = nn.MSELoss()
        self.classification_loss = nn.CrossEntropyLoss()
        
        # Metrics tracking
        self.train_metrics = []
        self.val_metrics = []
        
    def train_epoch(self):
        self.model.train()
        epoch_metrics = defaultdict(list)
        
        for batch in tqdm(self.train_loader):
            self.optimizer.zero_grad()
            total_loss = 0
            
            # Process each trajectory in batch
            for trajectory in batch:
                trajectory = [
                    {k: torch.Tensor(v).to(self.device) for k, v in state.items()}
                    for state in trajectory
                ]
                
                # Forward pass through sequence
                for t in range(len(trajectory) - 1):
                    current_state = trajectory[t]
                    next_state = trajectory[t + 1]
                    history = trajectory[:t] if t > 0 else None
                    
                    print(current_state)
                    output = self.model(current_state, history)
                    
                    # Compute losses
                    next_embedding = self.model.state_embedder(
                        next_state['disease_ids'],
                        next_state['symptom_ids'],
                        next_state['concept_ids']
                    )
                    
                    transition_loss = self.transition_loss(
                        output['next_embedding'],
                        next_embedding
                    )
                    
                    disease_loss = self.classification_loss(
                        output['interpretations']['disease_probabilities'],
                        next_state['disease_ids']
                    )
                    
                    total_loss += transition_loss + disease_loss
                    
                    # Track metrics
                    epoch_metrics['transition_loss'].append(transition_loss.item())
                    epoch_metrics['disease_loss'].append(disease_loss.item())
            
            # Backward pass
            total_loss.backward()
            self.optimizer.step()
        
        # Compute epoch metrics
        return {k: np.mean(v) for k, v in epoch_metrics.items()}
    
    def validate(self):
        self.model.eval()
        val_metrics = defaultdict(list)
        
        with torch.no_grad():
            for batch in self.val_loader:
                # Similar to train_epoch but without backward pass
                pass
                
        return {k: np.mean(v) for k, v in val_metrics.items()}
    
    def train(self, n_epochs: int):
        for epoch in range(n_epochs):
            train_metrics = self.train_epoch()
            val_metrics = self.validate()
            
            self.train_metrics.append(train_metrics)
            self.val_metrics.append(val_metrics)
            
            self.scheduler.step(val_metrics['total_loss'])
            
            print(f"Epoch {epoch + 1}/{n_epochs}")
            print(f"Train metrics: {train_metrics}")
            print(f"Val metrics: {val_metrics}")

# 6. Visualization Tools
class TrajectoryVisualizer:
    """Tools for visualizing clinical trajectories"""
    def __init__(self, model: InterpretableClinicModel):
        self.model = model
        self.umap_transform = umap.UMAP(n_components=2)
        self.tsne_transform = TSNE(n_components=2)
        
    def plot_embedding_space(self, 
                           embeddings: torch.Tensor,
                           labels: List[str],
                           method: str = 'umap'):
        """Plot embeddings in 2D space"""
        # Transform embeddings to 2D
        transform = self.umap_transform if method == 'umap' else self.tsne_transform
        embeddings_2d = transform.fit_transform(embeddings.detach().cpu().numpy())
        
        # Create plot
        plt.figure(figsize=(12, 8))
        scatter = plt.scatter(
            embeddings_2d[:, 0],
            embeddings_2d[:, 1],
            c=np.arange(len(embeddings_2d)),
            cmap='viridis'
        )
        
        # Add labels
        for i, label in enumerate(labels):
            plt.annotate(
                label,
                (embeddings_2d[i, 0], embeddings_2d[i, 1]),
                xytext=(5, 5),
                textcoords='offset points'
            )
            
        plt.colorbar(scatter)
        plt.title(f'Disease State Embeddings ({method.upper()})')
        plt.show()
        
    def plot_trajectory(self, 
                       trajectory: List[PatientState],
                       with_interpretations: bool = True):
        """Plot patient trajectory with interpretations"""
        # Create trajectory graph
        G = nx.DiGraph()
        
        # Add nodes and edges
        for i in range(len(trajectory) - 1):
            current_state = trajectory[i]
            next_state = trajectory[i + 1]
            
            # Add nodes
            G.add_node(i, state=current_state)
            G.add_node(i + 1, state=next_state)
            
            # Add edge
            if with_interpretations:
                # Get model interpretation
                current_tensor = self._state_to_tensor(current_state)
                interpretation = self.model.interpret_transition(
                    current_tensor,
                    self._state_to_tensor(next_state)
                )
                G.add_edge(i, i + 1, interpretation=interpretation)
            else:
                G.add_edge(i, i + 1)
        
        # Plot
        plt.figure(figsize=(15, 10))
        pos = nx.spring_layout(G)
        nx.draw(
            G, pos,
            with_labels=True,
            node_color='lightblue',
            node_size=1000,
            arrowsize=20
        )
        
        if with_interpretations:
            edge_labels = {
                (i, i + 1): self.model.generate_explanation(G.edges[i, i + 1]['interpretation'])
                for i in range(len(trajectory) - 1)
            }
            nx.draw_networkx_edge_labels(G, pos, edge_labels)
            
        plt.title('Patient Trajectory Graph')
        plt.show()
        
    def _state_to_tensor(self, state: PatientState) -> torch.Tensor:
        """Convert PatientState to tensor for model input"""
        # Implementation depends on specific state representation
        pass

# 7. Analysis Tools
class TrajectoryAnalyzer:
    """Tools for analyzing clinical trajectories"""
    def __init__(self, model: InterpretableClinicModel, medical_kb: Dict[str, Disease]):
        self.model = model
        self.medical_kb = medical_kb
        
    def analyze_transition_patterns(self, 
                                  trajectories: List[List[PatientState]]) -> pd.DataFrame:
        """Analyze common transition patterns in trajectories"""
        patterns = []
        
        for trajectory in trajectories:
            for i in range(len(trajectory) - 1):
                current_state = trajectory[i]
                next_state = trajectory[i + 1]
                
                # Get model interpretation
                output = self.model(
                    self._state_to_tensor(current_state),
                    history=[self._state_to_tensor(s) for s in trajectory[:i]]
                )
                
                patterns.append({
                    'from_state': current_state.disease_states,
                    'to_state': next_state.disease_states,
                    'change_magnitude': output['interpretations']['change_magnitude'].item(),
                    'top_pattern': self.model.pattern_descriptions[
                        output['interpretations']['pattern_similarities'].argmax().item()
                    ]
                })
        
        return pd.DataFrame(patterns)
    
    def identify_critical_transitions(self,
                                   trajectory: List[PatientState],
                                   threshold: float = 0.8) -> List[int]:
        """Identify critical transition points in trajectory"""
        critical_points = []
        
        for i in range(len(trajectory) - 1):
            current_state = trajectory[i]
            next_state = trajectory[i + 1]
            
            output = self.model(self._state_to_tensor(current_state))
            
            # Check if transition is critical based on change magnitude
            if output['interpretations']['change_magnitude'].item() > threshold:
                critical_points.append(i)
                
        return critical_points
    
    def generate_counterfactuals(self,
                               trajectory: List[PatientState],
                               n_alternatives: int = 3) -> List[List[PatientState]]:
        """Generate counterfactual trajectories"""
        counterfactuals = []
        
        # Implementation would depend on specific counterfactual generation strategy
        return counterfactuals

In [41]:

medical_kb = create_medical_knowledge_base()

# Generate synthetic data
generator = SyntheticPatientGenerator(medical_kb)
trajectories = [
    generator.generate_trajectory(
        initial_disease='Viral Upper Respiratory Infection',
        duration=random.randint(5, 15)
    )
    for _ in range(100)
]

# Create vocabulary
vocab = create_vocabulary(trajectories, medical_kb)

# Create dataset
dataset = ClinicalTrajectoryDataset(trajectories, medical_kb, vocab)
train_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=lambda x: x  # Custom collate function for variable length sequences
    )
    
# Initialize model
model = InterpretableClinicModel(
    n_diseases=len(vocab['diseases']),
    n_symptoms=len(vocab['symptoms']),
    n_concepts=len(vocab['concepts']),
    embedding_dim=64,
    n_clinical_patterns=10
)

# Train model
trainer = TrainingEngine(model, train_loader, None)
trainer.train(n_epochs=10)

# Visualize results
visualizer = TrajectoryVisualizer(model)
analyzer = TrajectoryAnalyzer(model, medical_kb)

# Example trajectory analysis
trajectory = trajectories[0]
visualizer.plot_trajectory(trajectory, with_interpretations=True)

# Analyze patterns
patterns_df = analyzer.analyze_transition_patterns(trajectories)
print("\nCommon Transition Patterns:")
print(patterns_df.head())

# Identify critical points
critical_points = analyzer.identify_critical_transitions(trajectory)
print("\nCritical Transition Points:", critical_points)

# Generate and visualize counterfactuals
counterfactuals = analyzer.generate_counterfactuals(trajectory)
for i, cf_trajectory in enumerate(counterfactuals):
    print(f"\nCounterfactual Trajectory {i + 1}:")
    visualizer.plot_trajectory(cf_trajectory)


  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

{'disease_ids': tensor([2]), 'symptom_ids': tensor([1, 3, 4]), 'concept_ids': tensor([6, 4]), 'vitals': tensor([ 37.8337,  61.2573,  14.8813, 102.8987,  66.7386,  95.8988]), 'timestamp': tensor([])}





AssertionError: For batched (3-D) `query`, expected `key` and `value` to be 3-D but found 2-D and 2-D tensors respectively