# Advanced Velocity Feature Importance Analysis

This notebook extends the velocity-based feature importance analysis with:
- 3D visualizations of feature importance across time and flow stages
- Patient-specific importance patterns
- Attention weight analysis for transformer-based models
- Uncertainty-aware feature importance
- Interactive visualizations using Plotly

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from scipy import stats
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("viridis")

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

Using device: cuda
PyTorch version: 2.7.0+cu128


In [2]:
# Import project modules
import sys
sys.path.append('/home/yl2428/Time-LLM')

from data_provider_pretrain.data_factory import data_provider
from models.time_series_flow_matching_model import TimeSeriesFlowMatchingModel
from models.model9_NS_transformer.ns_models.ns_Transformer import Model as NSTransformer
import argparse
from tqdm import tqdm

In [3]:
# Configuration and utility functions from feature_importance_analysis.ipynb
class DotDict(dict):
    """A dictionary that supports both dot notation and dictionary access."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__dict__.pop(item, None)

# Flow matching configuration (exact from feature_importance_analysis.ipynb)
flow_matching_config = DotDict({
    "num_nodes": 1,
    "task_name": "long_term_forecast",
    "is_training": 1,
    "model_id": "ETTh1_ETTh2_512_192",
    "model": "ns_Transformer",
    "precision": "32",
    "generative_model": "flow_matching",
    "data_pretrain": "Glucose",
    "root_path": "/home/yl2428/Time-LLM/dataset/glucose",
    "data_path": "output_Junt_16_3.csv",
    "data_path_pretrain": "output_Junt_16_3.csv",
    "features": "MS",
    "target": "OT",
    "freq": "t",
    "checkpoints": "/home/yl2428/checkpoints",
    "log_dir": "/home/yl2428/logs",
    "seq_len": 72,
    "label_len": 32,
    "pred_len": 48,
    "seasonal_patterns": "Monthly",
    "stride": 1,
    "enc_in": 9,
    "dec_in": 9,
    "c_out": 9,
    "d_model": 32,
    "n_heads": 8,
    "e_layers": 2,
    "d_layers": 1,
    "d_ff": 256,
    "moving_avg": 25,
    "factor": 3,
    "dropout": 0.1,
    "embed": "timeF",
    "activation": "gelu",
    "output_attention": False,
    "patch_len": 16,
    "prompt_domain": 0,
    "llm_model": "LLAMA",
    "llm_dim": 4096,
    "vae_hidden_dim": 16,
    "num_workers": 10,
    "itr": 1,
    "train_epochs": 100,
    "align_epochs": 10,
    "ema_decay": 0.995,
    "batch_size": 512,
    "eval_batch_size": 512,
    "patience": 40,
    "learning_rate": 0.0001,
    "des": "Exp",
    "loss": "MSE",
    "lradj": "COS",
    "pct_start": 0.2,
    "use_amp": False,
    "llm_layers": 32,
    "percent": 100,
    "num_individuals": -1,
    "enable_covariates": 1,
    "cov_type": "tensor",
    "gradient_accumulation_steps": 1,
    "use_deep_speed": 1,
    "wandb": 1,
    "wandb_group": None,
    "use_moe": 1,
    "num_experts": 8,
    "latent_len": 36,
    "top_k_experts": 4,
    "moe_layer_indices": [0, 1],
    "moe_loss_weight": 0.01,
    "log_routing_stats": 1,
    "num_universal_experts": 1,
    "universal_expert_weight": 0.3,
    "head_dropout": 0.1,
    "channel_independence": 0,
    "decomp_method": "moving_avg",
    "use_norm": 1,
    "down_sampling_layers": 2,
    "down_sampling_window": 1,
    "down_sampling_method": "avg",
    "use_future_temporal_feature": 0,
    "k_z": 1e-2,
    "k_cond": 0.001,
    "d_z": 8,
    "p_hidden_dims": [64, 64],
    "p_hidden_layers": 2,
    "diffusion_config_dir": "/home/yl2428/Time-LLM/models/model9_NS_transformer/configs/toy_8gauss.yml",  # Added missing config
    "cond_pred_model_pertrain_dir": None,
    "CART_input_x_embed_dim": 32,
    "mse_timestep": 0,
    "MLP_diffusion_net": False,
    "timesteps": 50,
    "ode_solver": "dopri5",
    "ode_rtol": 1e-5,
    "ode_atol": 1e-5,
    "interpolation_type": "linear",
    "expert_layers": 2,
    "loader": "modal",
    "model_comment": "none",
    "enable_context_aware": 1,
    "glucose_dropout_rate": 0.4,
    "use_contrastive_learning": 1,
    "contrastive_loss_weight": 0.1,
    "contrastive_temperature": 0.1,
    "use_momentum_encoder": 1,
    "momentum_factor": 0.999,
    "n_flow_stages": 5,  # For velocity analysis
    "col_stats": {'SEX': {'COUNT': (['F', 'M'], [367, 135])}, 'RACE': {'COUNT': (['WHITE', 'NOT REPORTED', 'ASIAN', 'BLACK/AFRICAN AMERICAN', 'MULTIPLE', 'UNKNOWN', 'AMERICAN INDIAN/ALASKAN NATIVE'], [459, 11, 10, 10, 8, 2, 2])}, 'ETHNIC': {'COUNT': (['Not Hispanic or Latino', 'Hispanic or Latino', 'Do not wish to answer', "Don't know"], [472, 15, 13, 2])}, 'ARMCD': {'COUNT': (['RESISTANCE', 'INTERVAL', 'AEROBIC'], [172, 167, 163])}, 'insulin modality': {'COUNT': (['CLOSED LOOP INSULIN PUMP', 'INSULIN PUMP', 'MULTIPLE DAILY INJECTIONS'], [225, 189, 88])}, 'AGE': {'MEAN': 36.655378486055774, 'STD': 13.941209833786187, 'QUANTILES': [18.0, 25.0, 33.0, 45.75, 70.0]}, 'WEIGHT': {'MEAN': 161.39940239043824, 'STD': 30.624877585598654, 'QUANTILES': [103.0, 140.0, 155.0, 179.0, 280.0]}, 'HEIGHT': {'MEAN': 66.72509960159363, 'STD': 3.505847063905933, 'QUANTILES': [58.0, 64.0, 66.0, 69.0, 77.0]}, 'HbA1c': {'MEAN': 6.642828685258964, 'STD': 0.7633658734231158, 'QUANTILES': [4.8, 6.1, 6.6, 7.1, 10.0]}, 'DIABETES_ONSET': {'MEAN': 18.72725737051793, 'STD': 11.889102915798386, 'QUANTILES': [0.0833, 11.0, 16.0, 24.0, 66.0]}},
    "col_names_dict": {'categorical': ['ARMCD', 'ETHNIC', 'RACE', 'SEX', 'insulin modality'], 'numerical': ['AGE', 'DIABETES_ONSET', 'HEIGHT', 'HbA1c', 'WEIGHT']}
})

# Use the config as args
args = flow_matching_config
print(f"Configuration loaded: {args.model} with d_model={args.d_model}, batch_size={args.batch_size}")
print(f"Diffusion config path: {args.diffusion_config_dir}")

Configuration loaded: ns_Transformer with d_model=32, batch_size=512
Diffusion config path: /home/yl2428/Time-LLM/models/model9_NS_transformer/configs/toy_8gauss.yml


In [4]:
# Checkpoint loading functions (exact from feature_importance_analysis.ipynb)
import os
import glob
import re

def find_best_checkpoint(base_path="/home/yl2428/logs/ns_Transformer/flow_matching/comfy-dust-243", metric="val_loss"):
    """Find the best checkpoint based on validation loss."""
    print(f"Searching for checkpoints in: {base_path}")
    
    checkpoint_pattern = os.path.join(base_path, "checkpoints/epoch=*-step=*-val_loss=*.ckpt/checkpoint")
    print(checkpoint_pattern)
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    
    if not checkpoint_dirs:
        print("No checkpoints found!")
        return None, None, None
    
    best_checkpoint = None
    best_metric = float('inf')
    best_run = None
    
    print(f"Found {len(checkpoint_dirs)} checkpoints:")
    
    for checkpoint_dir in checkpoint_dirs:
        pattern = r'epoch=(\d+)-step=(\d+)-val_loss=([\d.]+)\.ckpt'
        match = re.search(pattern, checkpoint_dir)
        
        if match:
            epoch, step, val_loss = match.groups()
            val_loss = float(val_loss)
            run_name = checkpoint_dir.split('/')[-4]
            
            print(f"  - {run_name}: epoch={epoch}, step={step}, val_loss={val_loss:.4f}")
            
            if val_loss < best_metric:
                best_metric = val_loss
                best_checkpoint = checkpoint_dir
                best_run = run_name
    
    if best_checkpoint:
        print(f"\nBest checkpoint: {best_run}")
        print(f"  - Path: {best_checkpoint}")
        print(f"  - Val Loss: {best_metric:.4f}")
    
    return best_checkpoint, best_metric, best_run


def load_deepspeed_checkpoint(model, checkpoint_path):
    """Load DeepSpeed checkpoint into the model."""
    print(f"Loading DeepSpeed checkpoint from: {checkpoint_path}")
    
    model_states_path = os.path.join(checkpoint_path, "mp_rank_00_model_states.pt")
    
    if not os.path.exists(model_states_path):
        raise FileNotFoundError(f"Model states file not found: {model_states_path}")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    checkpoint = torch.load(model_states_path, map_location=device)
    
    if 'module' in checkpoint:
        state_dict = checkpoint['module']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        clean_key = key
        if key.startswith('_forward_module.'):
            clean_key = key.replace('_forward_module.', '')
        elif key.startswith('module.'):
            clean_key = key.replace('module.', '')
        
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        
        cleaned_state_dict[clean_key] = value
    
    try:
        model = model.to(device)
        missing_keys, unexpected_keys = model.load_state_dict(cleaned_state_dict, strict=False)
        
        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
            
        print("✓ Model weights loaded successfully!")
        
    except Exception as e:
        print(f"Warning: Some keys couldn't be loaded: {e}")
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in cleaned_state_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        
        model = model.to(device)
        model.load_state_dict(model_dict)
        print(f"✓ Loaded {len(pretrained_dict)}/{len(cleaned_state_dict)} parameters")
    
    model = model.to(device)
    print(f"✓ All model components moved to {device}")
    
    return model

print("Checkpoint loading functions defined!")

Checkpoint loading functions defined!


In [5]:
# Load model and data (exact from feature_importance_analysis.ipynb)
print("Loading flow matching model and data...")

# Load data
flow_args = flow_matching_config
train_data_fm, train_loader_fm, flow_args = data_provider(
    flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'train'
)
vali_data_fm, vali_loader_fm, flow_args = data_provider(
    flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'val'
)
test_data_fm, test_loader_fm, flow_args = data_provider(
    flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, False, 'test'
)

# Initialize model
flow_matching_model = TimeSeriesFlowMatchingModel(flow_args, train_loader_fm, vali_loader_fm, test_loader_fm)

# Find and load best checkpoint
checkpoint_path, best_metric, run_name = find_best_checkpoint()

if checkpoint_path:
    flow_matching_model = load_deepspeed_checkpoint(flow_matching_model, checkpoint_path)
    flow_matching_model.eval()
    print(f"✓ Model loaded from {run_name} with val_loss: {best_metric:.4f}")
else:
    print("No checkpoint found - using untrained model")
    flow_matching_model.eval()

# Set model reference
model = flow_matching_model
test_loader = test_loader_fm

print(f"Model on device: {next(flow_matching_model.parameters()).device}")
print(f"Model type: {flow_args.model}")
print(f"Covariates enabled: {flow_args.enable_covariates}")
print(f"Batch size: {flow_args.batch_size}")
print(f"Number of experts: {model.num_experts}")
print("✓ Data and model loaded successfully!")

Loading flow matching model and data...
Mean: [  0.84085476   0.80213351  79.40491332  32.5294183    0.24537674
   5.3128258    3.82292609   5.69710292 144.81223086]
Std: [ 1.61659338  0.60898286 19.36734307 86.98838521  0.43031035 14.4286974
 11.59240125 11.53828031 55.07849221]
Loading data into memory...


100%|██████████| 488/488 [00:44<00:00, 11.09it/s]


Mean: [  0.84085476   0.80213351  79.40491332  32.5294183    0.24537674
   5.3128258    3.82292609   5.69710292 144.81223086]
Std: [ 1.61659338  0.60898286 19.36734307 86.98838521  0.43031035 14.4286974
 11.59240125 11.53828031 55.07849221]
Loading data into memory...


100%|██████████| 488/488 [00:43<00:00, 11.16it/s]


Mean: [  0.84085476   0.80213351  79.40491332  32.5294183    0.24537674
   5.3128258    3.82292609   5.69710292 144.81223086]
Std: [ 1.61659338  0.60898286 19.36734307 86.98838521  0.43031035 14.4286974
 11.59240125 11.53828031 55.07849221]
Loading data into memory...


100%|██████████| 488/488 [00:44<00:00, 11.09it/s]


Searching for checkpoints in: /home/yl2428/logs/ns_Transformer/flow_matching/comfy-dust-243
/home/yl2428/logs/ns_Transformer/flow_matching/comfy-dust-243/checkpoints/epoch=*-step=*-val_loss=*.ckpt/checkpoint
Found 1 checkpoints:
  - comfy-dust-243: epoch=9, step=111310, val_loss=1.0574

Best checkpoint: comfy-dust-243
  - Path: /home/yl2428/logs/ns_Transformer/flow_matching/comfy-dust-243/checkpoints/epoch=9-step=111310-val_loss=1.0574.ckpt/checkpoint
  - Val Loss: 1.0574
Loading DeepSpeed checkpoint from: /home/yl2428/logs/ns_Transformer/flow_matching/comfy-dust-243/checkpoints/epoch=9-step=111310-val_loss=1.0574.ckpt/checkpoint
Using device: cuda
✓ Model weights loaded successfully!
✓ All model components moved to cuda
✓ Model loaded from comfy-dust-243 with val_loss: 1.0574
Model on device: cuda:0
Model type: ns_Transformer
Covariates enabled: 1
Batch size: 512


AttributeError: 'TimeSeriesFlowMatchingModel' object has no attribute 'num_experts'

In [None]:
# Fixed model configuration - access num_experts from args
print(f"Model on device: {next(flow_matching_model.parameters()).device}")
print(f"Model type: {flow_args.model}")
print(f"Covariates enabled: {flow_args.enable_covariates}")
print(f"Batch size: {flow_args.batch_size}")
print(f"Number of experts (from config): {flow_args.num_experts}")
print("✓ Data and model loaded successfully!")

## 3. 3D Visualization of Feature Importance Across Stages

In [7]:
def create_3d_importance_surface(stage_importance, seq_len):
    """
    Create 3D surface plot showing importance across time steps and flow stages.
    """
    n_stages = len(stage_importance)
    
    # Prepare data for 3D surface
    Z = np.zeros((n_stages, seq_len))
    for stage_idx in range(n_stages):
        importance = stage_importance[stage_idx]['x']
        if len(importance.shape) > 1:
            importance = importance.mean(axis=0)  # Average across batch
        Z[stage_idx, :] = importance[:seq_len]
    
    # Create meshgrid
    time_steps = np.arange(seq_len)
    flow_stages = np.arange(n_stages)
    X, Y = np.meshgrid(time_steps, flow_stages)
    
    # Create 3D surface plot
    fig = go.Figure(data=[go.Surface(
        x=X,
        y=Y,
        z=Z,
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title="Feature Importance")
    )])
    
    fig.update_layout(
        title='3D Feature Importance Surface: Time Steps × Flow Stages',
        scene=dict(
            xaxis_title='Historical Time Step',
            yaxis_title='Flow Stage',
            zaxis_title='Feature Importance',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.3)
            )
        ),
        width=900,
        height=700
    )
    
    return fig

# Example usage (with dummy data for demonstration)
dummy_stage_importance = {
    i: {'x': np.random.exponential(scale=0.5, size=(72,)) * np.exp(-np.arange(72) * 0.02)}
    for i in range(5)
}

fig_3d = create_3d_importance_surface(dummy_stage_importance, 72)
fig_3d.show()

In [8]:
!pip install nbformat

Collecting nbformat
  Using cached nbformat-5.10.4-py3-none-any.whl.metadata (3.6 kB)
Collecting fastjsonschema>=2.15 (from nbformat)
  Using cached fastjsonschema-2.21.1-py3-none-any.whl.metadata (2.2 kB)
Collecting jsonschema>=2.6 (from nbformat)
  Using cached jsonschema-4.25.0-py3-none-any.whl.metadata (7.7 kB)
Collecting jsonschema-specifications>=2023.03.6 (from jsonschema>=2.6->nbformat)
  Using cached jsonschema_specifications-2025.4.1-py3-none-any.whl.metadata (2.9 kB)
Collecting referencing>=0.28.4 (from jsonschema>=2.6->nbformat)
  Using cached referencing-0.36.2-py3-none-any.whl.metadata (2.8 kB)
Collecting rpds-py>=0.7.1 (from jsonschema>=2.6->nbformat)
  Downloading rpds_py-0.27.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Using cached nbformat-5.10.4-py3-none-any.whl (78 kB)
Using cached fastjsonschema-2.21.1-py3-none-any.whl (23 kB)
Using cached jsonschema-4.25.0-py3-none-any.whl (89 kB)
Using cached jsonschema_specifications-2025.4.1-py3

## 4. Interactive Heatmap with Plotly

## 7. Cross‑Modality Feature Importance Across Stages

This section computes per‑modality (channel) importance across flow stages by aggregating absolute importance over time (and batch if present). It supports inputs where stage importance for `x` is shaped as:
- [seq_len] (falls back to a single “time” modality), or
- [seq_len, channels], or
- [batch, seq_len, channels] / [batch, channels, seq_len].

We then visualize a Stage × Modality heatmap.


In [8]:
def create_interactive_heatmap(stage_importance, feature_names=None):
    """
    Create interactive heatmap showing importance evolution across stages.
    """
    n_stages = len(stage_importance)
    seq_len = len(stage_importance[0]['x'])
    
    if feature_names is None:
        feature_names = [f'T-{seq_len-i}' for i in range(seq_len)]
    
    # Prepare heatmap data
    heatmap_data = []
    for stage_idx in range(n_stages):
        importance = stage_importance[stage_idx]['x']
        if len(importance.shape) > 1:
            importance = importance.mean(axis=0)
        heatmap_data.append(importance)
    
    heatmap_data = np.array(heatmap_data)
    
    # Create interactive heatmap
    fig = go.Figure(data=go.Heatmap(
        z=heatmap_data,
        x=feature_names,
        y=[f'Stage {i+1}' for i in range(n_stages)],
        colorscale='RdYlBu_r',
        hovertemplate='Feature: %{x}<br>Stage: %{y}<br>Importance: %{z:.4f}<extra></extra>',
        colorbar=dict(title="Importance")
    ))
    
    fig.update_layout(
        title='Interactive Feature Importance Heatmap Across Flow Stages',
        xaxis_title='Time Steps',
        yaxis_title='Flow Stages',
        width=1200,
        height=500,
        xaxis=dict(tickangle=45)
    )
    
    # Add annotations for stage descriptions
    stage_descriptions = [
        "Initial Transport",
        "Coarse Features",
        "Mid Refinement",
        "Fine Details",
        "Final Approach"
    ]
    
    for i, desc in enumerate(stage_descriptions[:n_stages]):
        fig.add_annotation(
            x=-0.1,
            y=i,
            text=desc,
            showarrow=False,
            xref="paper",
            yref="y",
            font=dict(size=10),
            xanchor="right"
        )
    
    return fig

# Create interactive heatmap
fig_heatmap = create_interactive_heatmap(dummy_stage_importance)
fig_heatmap.show()

## 5. Patient-Specific Importance Analysis

In [9]:
class PatientSpecificAnalyzer:
    """
    Analyze importance patterns for individual patients.
    """
    
    def __init__(self, analyzer, data_loader):
        self.analyzer = analyzer
        self.data_loader = data_loader
        
    def compute_patient_importance(self, patient_ids, stage_idx=0):
        """
        Compute importance for specific patients at a given stage.
        """
        patient_importance = {}
        
        for batch in self.data_loader:
            batch_x, batch_y, batch_x_mark, batch_y_mark = batch
            
            # Assume patient IDs are in batch_y_mark or need to be extracted
            # This is a template - adjust based on your data structure
            
            batch_x = batch_x.float().to(device)
            covariates = batch_x_mark.float().to(device) if batch_x_mark is not None else None
            
            t = self.analyzer.stage_centers[stage_idx]
            x_importance, cov_importance = self.analyzer.integrated_gradients(
                batch_x, covariates, t
            )
            
            # Store per-patient importance
            for i in range(batch_x.shape[0]):
                # Extract patient ID (placeholder - adjust based on your data)
                patient_id = f"patient_{i}"
                
                if patient_id not in patient_importance:
                    patient_importance[patient_id] = []
                
                patient_importance[patient_id].append({
                    'x_importance': x_importance[i].cpu().numpy(),
                    'cov_importance': cov_importance[i].cpu().numpy() if cov_importance is not None else None
                })
        
        return patient_importance
    
    def cluster_patients_by_importance(self, patient_importance, n_clusters=3):
        """
        Cluster patients based on their importance patterns.
        """
        from sklearn.cluster import KMeans
        
        # Flatten importance vectors for clustering
        importance_vectors = []
        patient_ids = []
        
        for patient_id, importance_list in patient_importance.items():
            # Average importance across samples for this patient
            avg_importance = np.mean([imp['x_importance'] for imp in importance_list], axis=0)
            importance_vectors.append(avg_importance.flatten())
            patient_ids.append(patient_id)
        
        importance_vectors = np.array(importance_vectors)
        
        # Apply PCA for dimensionality reduction
        pca = PCA(n_components=min(50, importance_vectors.shape[1]))
        importance_pca = pca.fit_transform(importance_vectors)
        
        # Cluster patients
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(importance_pca)
        
        # Visualize clusters
        tsne = TSNE(n_components=2, random_state=42)
        importance_tsne = tsne.fit_transform(importance_pca)
        
        # Create scatter plot
        fig = px.scatter(
            x=importance_tsne[:, 0],
            y=importance_tsne[:, 1],
            color=clusters,
            hover_data={'Patient': patient_ids},
            title='Patient Clustering Based on Feature Importance Patterns',
            labels={'x': 't-SNE 1', 'y': 't-SNE 2', 'color': 'Cluster'},
            color_continuous_scale='Viridis'
        )
        
        return clusters, fig

print("Patient-specific analyzer defined")

Patient-specific analyzer defined


## 6. Attention Weight Analysis for Transformer Models

In [10]:
def analyze_attention_patterns(model, data_loader, stage_idx=2):
    """
    Extract and visualize attention patterns from transformer layers.
    """
    attention_weights = []
    
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            if batch_idx > 5:  # Limit to a few batches for visualization
                break
            
            batch_x, _, _, _ = batch
            batch_x = batch_x.float().to(device)
            
            # Hook to extract attention weights
            # This is model-specific - adjust based on your transformer implementation
            if hasattr(model, 'encoder'):
                # Example: Extract attention from encoder layers
                # You'll need to implement hooks or modify your model to expose attention
                pass
    
    # Create attention visualization
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Layer 1 Attention', 'Layer 2 Attention', 
                       'Layer 3 Attention', 'Average Attention'),
        specs=[[{'type': 'heatmap'}, {'type': 'heatmap'}],
               [{'type': 'heatmap'}, {'type': 'heatmap'}]]
    )
    
    # Dummy attention data for demonstration
    for i in range(4):
        attention_matrix = np.random.random((72, 72))
        attention_matrix = (attention_matrix + attention_matrix.T) / 2  # Make symmetric
        
        row = i // 2 + 1
        col = i % 2 + 1
        
        fig.add_trace(
            go.Heatmap(z=attention_matrix, colorscale='Blues', showscale=(i==3)),
            row=row, col=col
        )
    
    fig.update_layout(
        title='Attention Weight Patterns Across Transformer Layers',
        height=800,
        width=1000
    )
    
    return fig

# Example visualization
# fig_attention = analyze_attention_patterns(model, data_loader)
# fig_attention.show()
print("Attention analysis function defined")

Attention analysis function defined
