In [None]:
import argparse
import torch
from data_provider_pretrain.data_factory import data_provider
from models.time_series_diffusion_model import TimeSeriesDiffusionModel
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from utils.callbacks import EMA
from lightning.pytorch.loggers import WandbLogger
import time
import random
import numpy as np
import os
import wandb
from datetime import timedelta
from utils.clean_args import clean_args
os.environ['CURL_CA_BUNDLE'] = ''
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

parser = argparse.ArgumentParser(description='Time-LLM')

fix_seed = 2021
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)
torch.cuda.manual_seed(fix_seed)
torch.cuda.manual_seed_all(fix_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class DotDict(dict):
    """
    A dictionary that supports both dot notation and dictionary access.
    This allows both `args.att` and `args['att']` to work.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__dict__.pop(item, None)

default_config = DotDict({
    # Basic config
    "num_nodes": 1,
    "task_name": "long_term_forecast",
    "is_training": 1,
    "model_id": "ETTh1_ETTh2_512_192",
    "model": "ns_Transformer",
    "precision": "32",
    
    # Data loader
    "data_pretrain": "Glucose",
    "root_path": "/home/yl2428/Time-LLM/dataset/glucose",
    "data_path": "combined_data_Jun_28.csv",
    "data_path_pretrain": "combined_data_Jun_28.csv",
    "features": "M",
    "target": "OT",
    "freq": "t",
    "checkpoints": "/gpfs/gibbs/pi/gerstein/yl2428/checkpoints/",
    "log_dir": "/gpfs/gibbs/pi/gerstein/yl2428/logs",
    
    # Forecasting task
    "seq_len": 128,
    "label_len": 12,
    "pred_len": 32,
    "seasonal_patterns": "Monthly",
    "stride": 8,
    
    # Model define
    "enc_in": 4,
    "dec_in": 4,
    "c_out": 4,
    "d_model": 32,
    "n_heads": 8,
    "e_layers": 2,
    "d_layers": 1,
    "d_ff": 128,
    "moving_avg": 25,
    "factor": 3,
    "dropout": 0.1,
    "embed": "timeF",
    "activation": "gelu",
    "output_attention": False,
    "patch_len": 16,
    "prompt_domain": 0,
    "llm_model": "LLAMA",
    "llm_dim": 4096,
    
    # Optimization
    "num_workers": 10,
    "itr": 1,
    "train_epochs": 100,
    "align_epochs": 10,
    "ema_decay": 0.97,
    "batch_size": 64,
    "eval_batch_size": 2,
    "patience": 10,
    "learning_rate": 0.0004,
    "des": "Exp",
    "loss": "MSE",
    "lradj": "COS",
    "pct_start": 0.2,
    "use_amp": False,
    "llm_layers": 32,
    "percent": 100,
    "num_individuals": -1,
    "enable_covariates": 1,
    "cov_type": "tensor",
    "gradient_accumulation_steps": 1,
    "use_deep_speed": 1,
    
    # Wandb
    "wandb": 1,
    "wandb_group": None,
    "wandb_api_key": "6f1080f993d5d7ad6103e69ef57dd9291f1bf366",
    "num_experts": 8,
    "head_dropout": 0.1,
    
    # TimeMixer-specific parameters
    "channel_independence": 0,
    "decomp_method": "moving_avg",
    "use_norm": 1,
    "down_sampling_layers": 2,
    "down_sampling_window": 1,
    "down_sampling_method": "avg",
    "use_future_temporal_feature": 0,
    
    # Diffusion specific parameters
    "k_z": 1e-2,
    "k_cond": 1,
    "d_z": 8,
    
    # De-stationary projector params
    "p_hidden_dims": [64, 64],
    "p_hidden_layers": 2,
    
    # CART related args
    "diffusion_config_dir": "/home/yl2428/Time-LLM/models/model9_NS_transformer/configs/toy_8gauss.yml",
    "cond_pred_model_pertrain_dir": None,
    "CART_input_x_embed_dim": 32,
    "mse_timestep": 0,
    "MLP_diffusion_net": False,
    
    # Ax args
    "timesteps": 1000,
    
    # Additional parameters
    "master_port": 8889,
    "comment": "TimeLLM-ECL"
})


args = default_config

for ii in range(args.itr):
    train_data, train_loader, args = data_provider(args, args.data_pretrain, args.data_path_pretrain, True, 'train')
    vali_data, vali_loader, args = data_provider(args, args.data_pretrain, args.data_path_pretrain, True, 'val')
    test_data, test_loader, args = data_provider(args, args.data_pretrain, args.data_path_pretrain, False, 'test')
    model = TimeSeriesDiffusionModel(args, train_loader, vali_loader, test_loader)

In [2]:
import argparse
import torch
from data_provider_pretrain.data_factory import data_provider
from models.time_series_flow_matching_model import TimeSeriesFlowMatchingModel
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from utils.callbacks import EMA
from lightning.pytorch.loggers import WandbLogger
import time
import random
import numpy as np
import os
import wandb
from datetime import timedelta
from utils.clean_args import clean_args
import glob
import re
os.environ['CURL_CA_BUNDLE'] = ''
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

fix_seed = 2021
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)
torch.cuda.manual_seed(fix_seed)
torch.cuda.manual_seed_all(fix_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class DotDict(dict):
    """
    A dictionary that supports both dot notation and dictionary access.
    This allows both `args.att` and `args['att']` to work.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__dict__.pop(item, None)

# Flow matching configuration based on train_glucose_diffusion_slurm.sh
flow_matching_config = DotDict({
    # Basic config
    "num_nodes": 1,
    "task_name": "long_term_forecast",
    "is_training": 1,
    "model_id": "ETTh1_ETTh2_512_192",
    "model": "ns_DLinear",  # From shell script
    "precision": "32",
    "generative_model": "flow_matching",  # Key difference from diffusion
    
    # Data loader (from shell script)
    "data_pretrain": "Glucose",
    "root_path": "/home/yl2428/Time-LLM/dataset/glucose",
    "data_path": "combined_data_Jun_28.csv",
    "data_path_pretrain": "combined_data_Jun_28.csv",
    "features": "MS",  # From shell script
    "target": "OT",
    "freq": "t",
    "checkpoints": "/home/yl2428/checkpoints/",
    "log_dir": "/home/yl2428/logs",
    
    # Forecasting task (from shell script)
    "seq_len": 48,
    "label_len": 32,
    "pred_len": 36,
    "seasonal_patterns": "Monthly",
    "stride": 1,  # From shell script
    
    # Model define (from shell script)
    "enc_in": 4,
    "dec_in": 4,
    "c_out": 4,
    "d_model": 32,  # From shell script
    "n_heads": 8,
    "e_layers": 2,
    "d_layers": 1,
    "d_ff": 256,  # From shell script
    "moving_avg": 25,
    "factor": 3,  # From shell script
    "dropout": 0.1,
    "embed": "timeF",
    "activation": "gelu",
    "output_attention": False,
    "patch_len": 16,
    "prompt_domain": 0,
    "llm_model": "LLAMA",
    "llm_dim": 4096,
    
    # VAE-specific parameters for ns_DLinear
    "latent_len": 24,  # Half of seq_len by default
    "vae_hidden_dim": 16,
    
    # Required for Trompt encoder - these will be populated by data_provider
    "col_stats": None,
    "col_names_dict": None,
    
    # Optimization (from shell script)
    "num_workers": 10,
    "itr": 1,
    "train_epochs": 100,  # From shell script
    "align_epochs": 10,
    "ema_decay": 0.995,
    "batch_size": 64,  # From shell script
    "eval_batch_size": 8,
    "patience": 10,
    "learning_rate": 0.0001,  # From shell script
    "des": "Exp",
    "loss": "MSE",
    "lradj": "COS",
    "pct_start": 0.2,
    "use_amp": False,
    "llm_layers": 32,  # From shell script (llama_layers)
    "percent": 100,
    "num_individuals": 100,  # From shell script
    "enable_covariates": 1,  # From shell script
    "cov_type": "tensor",
    "gradient_accumulation_steps": 1,
    "use_deep_speed": 1,  # From shell script
    
    # Wandb
    "wandb": 1,
    "wandb_group": None,
    "wandb_api_key": "6f1080f993d5d7ad6103e69ef57dd9291f1bf366",
    
    # MoE parameters (from shell script)
    "use_moe": 1,
    "num_experts": 8,
    "top_k_experts": 4,
    "moe_layer_indices": [0, 1],
    "moe_loss_weight": 0.01,
    "log_routing_stats": 1,
    "num_universal_experts": 1,
    "universal_expert_weight": 0.3,
    "head_dropout": 0.1,
    
    # TimeMixer-specific parameters
    "channel_independence": 0,
    "decomp_method": "moving_avg",
    "use_norm": 1,
    "down_sampling_layers": 2,
    "down_sampling_window": 1,
    "down_sampling_method": "avg",
    "use_future_temporal_feature": 0,
    
    # Flow matching specific parameters
    "k_z": 1e-2,
    "k_cond": 1,
    "d_z": 8,
    
    # De-stationary projector params
    "p_hidden_dims": [64, 64],
    "p_hidden_layers": 2,
    
    # Flow matching config
    "diffusion_config_dir": "/home/yl2428/Time-LLM/models/model9_NS_transformer/configs/toy_8gauss.yml",
    "cond_pred_model_pertrain_dir": None,
    "CART_input_x_embed_dim": 32,
    "mse_timestep": 0,
    "MLP_diffusion_net": False,
    
    # Flow matching specific timesteps (reduced from 1000 for efficiency)
    "timesteps": 50,
    
    # Flow matching ODE solver parameters
    "ode_solver": "dopri5",
    "ode_rtol": 1e-5,
    "ode_atol": 1e-5,
    "interpolation_type": "linear",
})

def find_best_checkpoint(base_path="/home/yl2428/logs/ns_DLinear/flow_matching", metric="val_loss"):
    """
    Find the best checkpoint based on validation loss.
    
    Args:
        base_path: Base directory to search for checkpoints
        metric: Metric to optimize (default: val_loss, lower is better)
    
    Returns:
        tuple: (best_checkpoint_path, best_metric_value, run_name)
    """
    print(f"Searching for checkpoints in: {base_path}")
    
    # Find all checkpoint directories
    checkpoint_pattern = os.path.join(base_path, "*/checkpoints/epoch=*-step=*-val_loss=*.ckpt/checkpoint")
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    
    if not checkpoint_dirs:
        print("No checkpoints found!")
        return None, None, None
    
    best_checkpoint = None
    best_metric = float('inf')  # Assuming lower is better for val_loss
    best_run = None
    
    print(f"Found {len(checkpoint_dirs)} checkpoints:")
    
    for checkpoint_dir in checkpoint_dirs:
        # Extract metric value from path
        # Pattern: epoch=X-step=Y-val_loss=Z.ckpt
        pattern = r'epoch=(\d+)-step=(\d+)-val_loss=([\d.]+)\.ckpt'
        match = re.search(pattern, checkpoint_dir)
        
        if match:
            epoch, step, val_loss = match.groups()
            val_loss = float(val_loss)
            
            # Extract run name
            run_name = checkpoint_dir.split('/')[-4]  # Get run directory name
            
            print(f"  - {run_name}: epoch={epoch}, step={step}, val_loss={val_loss:.4f}")
            
            if val_loss < best_metric:
                best_metric = val_loss
                best_checkpoint = checkpoint_dir
                best_run = run_name
    
    if best_checkpoint:
        print(f"\nBest checkpoint: {best_run}")
        print(f"  - Path: {best_checkpoint}")
        print(f"  - Val Loss: {best_metric:.4f}")
    
    return best_checkpoint, best_metric, best_run

def load_deepspeed_checkpoint(model, checkpoint_path):
    """
    Load DeepSpeed checkpoint into the model.
    
    Args:
        model: PyTorch Lightning model
        checkpoint_path: Path to the DeepSpeed checkpoint directory
    
    Returns:
        model: Model with loaded weights
    """
    print(f"Loading DeepSpeed checkpoint from: {checkpoint_path}")
    
    # DeepSpeed saves model states in mp_rank_00_model_states.pt
    model_states_path = os.path.join(checkpoint_path, "mp_rank_00_model_states.pt")
    
    if not os.path.exists(model_states_path):
        raise FileNotFoundError(f"Model states file not found: {model_states_path}")
    
    print(f"Loading model states from: {model_states_path}")
    
    # Determine the device to use
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load the checkpoint
    checkpoint = torch.load(model_states_path, map_location=device)
    
    # Extract the model state dict
    if 'module' in checkpoint:
        state_dict = checkpoint['module']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Sometimes the checkpoint is the state dict directly
        state_dict = checkpoint
    
    # Remove any DeepSpeed prefixes if they exist
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        # Remove common prefixes that DeepSpeed might add
        clean_key = key
        if key.startswith('_forward_module.'):
            clean_key = key.replace('_forward_module.', '')
        elif key.startswith('module.'):
            clean_key = key.replace('module.', '')
        
        # Ensure the tensor is on the correct device
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        
        cleaned_state_dict[clean_key] = value
    
    # Load the state dict into the model
    try:
        # First move the model to the device
        model = model.to(device)
        
        # Load the state dict
        missing_keys, unexpected_keys = model.load_state_dict(cleaned_state_dict, strict=False)
        
        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
            
        print("✓ Model weights loaded successfully!")
        
    except Exception as e:
        print(f"Warning: Some keys couldn't be loaded: {e}")
        # Try to load what we can
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in cleaned_state_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        
        # Move model to device first
        model = model.to(device)
        model.load_state_dict(model_dict)
        print(f"✓ Loaded {len(pretrained_dict)}/{len(cleaned_state_dict)} parameters")
    
    # Ensure all submodules are on the correct device
    model = model.to(device)
    
    # Special handling for torch_frame components that might have device issues
    def move_torch_frame_components_to_device(module, device):
        """Recursively move torch_frame components to device"""
        for name, child in module.named_children():
            if hasattr(child, 'fill_values') and isinstance(child.fill_values, torch.Tensor):
                child.fill_values = child.fill_values.to(device)
            if hasattr(child, 'embedding_table') and isinstance(child.embedding_table, torch.Tensor):
                child.embedding_table = child.embedding_table.to(device)
            # Recursively apply to children
            move_torch_frame_components_to_device(child, device)
    
    # Apply device fix to the model
    move_torch_frame_components_to_device(model, device)
    
    print(f"✓ All model components moved to {device}")
    
    return model

def move_batch_to_device(batch, device):
    """
    Move a batch of data to the specified device.
    
    Args:
        batch: Batch data (can be tuple, list, tensor, or TensorFrame)
        device: Target device
    
    Returns:
        batch: Batch moved to device
    """
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_batch_to_device(item, device) for item in batch)
    elif isinstance(batch, torch.Tensor):
        return batch.to(device)
    elif hasattr(batch, 'to'):  # For TensorFrame and similar objects
        return batch.to(device)
    else:
        return batch

def load_flow_matching_model_with_weights(checkpoint_path=None, auto_find_best=True):
    """
    Load and initialize the flow matching model with the specified configuration and weights.
    
    Args:
        checkpoint_path: Specific path to checkpoint directory (optional)
        auto_find_best: If True, automatically find the best checkpoint (default: True)
    
    Returns:
        tuple: (model, args, loaders, checkpoint_info)
    """
    
    # Load data with flow matching config
    flow_args = flow_matching_config
    print("Loading data for Flow Matching model...")
    
    train_data_fm, train_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'train'
    )
    vali_data_fm, vali_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'val'
    )
    test_data_fm, test_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, False, 'test'
    )
    
    # Initialize Flow Matching model
    print("Initializing Time Series Flow Matching Model...")
    flow_matching_model = TimeSeriesFlowMatchingModel(flow_args, train_loader_fm, vali_loader_fm, test_loader_fm)
    
    checkpoint_info = {}
    
    # Load weights if specified
    if checkpoint_path or auto_find_best:
        if auto_find_best and not checkpoint_path:
            print("\nFinding best checkpoint...")
            checkpoint_path, best_metric, run_name = find_best_checkpoint()
            checkpoint_info = {
                'path': checkpoint_path,
                'val_loss': best_metric,
                'run_name': run_name
            }
        
        if checkpoint_path:
            print(f"\nLoading weights from checkpoint...")
            flow_matching_model = load_deepspeed_checkpoint(flow_matching_model, checkpoint_path)
            if not checkpoint_info:
                checkpoint_info = {'path': checkpoint_path}
        else:
            print("No checkpoint found to load.")
    
    print("✓ Flow Matching model loaded successfully!")
    print(f"  - Model type: {flow_args.model}")
    print(f"  - Generative model: {flow_args.generative_model}")
    print(f"  - ODE Solver: {flow_args.ode_solver}")
    print(f"  - Timesteps: {flow_args.timesteps}")
    print(f"  - Batch size: {flow_args.batch_size}")
    print(f"  - Learning rate: {flow_args.learning_rate}")
    print(f"  - MoE enabled: {flow_args.use_moe}")
    print(f"  - Covariates enabled: {flow_args.enable_covariates}")
    print(f"  - Model dimensions: d_model={flow_args.d_model}, d_ff={flow_args.d_ff}")
    print(f"  - Sequence lengths: seq_len={flow_args.seq_len}, pred_len={flow_args.pred_len}")
    
    if checkpoint_info:
        print(f"\nCheckpoint info:")
        if 'run_name' in checkpoint_info:
            print(f"  - Run: {checkpoint_info['run_name']}")
        if 'val_loss' in checkpoint_info:
            print(f"  - Validation Loss: {checkpoint_info['val_loss']:.4f}")
        print(f"  - Path: {checkpoint_info['path']}")
    
    print(f"\n📝 Usage Tips:")
    print(f"  - Use model.eval() before inference")
    print(f"  - Move data to device: batch = move_batch_to_device(batch, model.device)")
    print(f"  - For sampling: model.sample_step(batch, batch_idx)")
    
    return flow_matching_model, flow_args, (train_loader_fm, vali_loader_fm, test_loader_fm), checkpoint_info

def load_flow_matching_model():
    """Load and initialize the flow matching model with the specified configuration (without weights)."""
    
    # Load data with flow matching config
    flow_args = flow_matching_config
    print("Loading data for Flow Matching model...")
    
    train_data_fm, train_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'train'
    )
    vali_data_fm, vali_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, True, 'val'
    )
    test_data_fm, test_loader_fm, flow_args = data_provider(
        flow_args, flow_args.data_pretrain, flow_args.data_path_pretrain, False, 'test'
    )
    
    # Initialize Flow Matching model
    print("Initializing Time Series Flow Matching Model...")
    flow_matching_model = TimeSeriesFlowMatchingModel(flow_args, train_loader_fm, vali_loader_fm, test_loader_fm)
    
    print("✓ Flow Matching model loaded successfully!")
    print(f"  - Model type: {flow_args.model}")
    print(f"  - Generative model: {flow_args.generative_model}")
    print(f"  - ODE Solver: {flow_args.ode_solver}")
    print(f"  - Timesteps: {flow_args.timesteps}")
    print(f"  - Batch size: {flow_args.batch_size}")
    print(f"  - Learning rate: {flow_args.learning_rate}")
    print(f"  - MoE enabled: {flow_args.use_moe}")
    print(f"  - Covariates enabled: {flow_args.enable_covariates}")
    print(f"  - Model dimensions: d_model={flow_args.d_model}, d_ff={flow_args.d_ff}")
    print(f"  - Sequence lengths: seq_len={flow_args.seq_len}, pred_len={flow_args.pred_len}")
    
    return flow_matching_model, flow_args, (train_loader_fm, vali_loader_fm, test_loader_fm)



In [None]:
model, args, loaders, checkpoint_info = load_flow_matching_model_with_weights(checkpoint_path="/home/yl2428/logs/ns_DLinear/flow_matching/fine-sponge-387/checkpoints/epoch=9-step=23070-val_loss=1.8386.ckpt/checkpoint")
train_loader, val_loader, test_loader = loaders

print("\nModel summary:")
print(f"Flow matching model has {sum(p.numel() for p in model.parameters())} parameters")
print(f"Condition prediction model has {sum(p.numel() for p in model.cond_pred_model.parameters())} parameters")

In [None]:

trainer = pl.Trainer(
    accelerator='cuda',
    devices=1, precision='64')
model.cuda()

In [5]:
state_dict = torch.load('/gpfs/gibbs/pi/gerstein/yl2428/logs/ns_Transformer/desert-sweep-6/checkpoints/checkpoints_1.pt')
# turn into double
for key in state_dict.keys():
    state_dict[key] = state_dict[key].double()

In [None]:
model.state_dict

In [None]:
model.load_state_dict(state_dict, strict=False)

In [None]:
next(iter(train_loader))[0][0]

In [None]:
train_loader

In [None]:
wandb.init(project="ns_Transformer", name="test")
trainer.test(model, train_loader)

In [None]:
model.sample_outputs[1].keys()

In [None]:
model.sample_outputs[0]['batch_x'].shape

In [8]:
torch.save(model.sample_outputs, '/gpfs/gibbs/pi/gerstein/yl2428/logs/ns_Transformer/desert-sweep-6/checkpoints/sample_outputs_May11.pt')

In [9]:
model.sample_outputs = torch.load('/gpfs/gibbs/pi/gerstein/yl2428/logs/ns_Transformer/desert-sweep-6/checkpoints/sample_outputs.pt')

In [11]:
from torch_frame import stype

In [None]:
model.cuda()

In [None]:
batch = model.sample_outputs[0]['batch']
batch_x, batch_y, batch_x_mark, batch_y_mark = batch[0]
batch_cov = batch[1]
with torch.no_grad():
    new_batch_x = batch_x.clone()
    new_batch_x[53, :, 1] = batch_x[53, :, 1].min()
    model.eval()
    new_batch = [None, None]
    new_batch[0] = new_batch_x, batch_y, batch_x_mark, batch_y_mark
    new_batch[1] = batch_cov
    model.sample_step(new_batch, 0)

In [None]:
len(model.sample_outputs)

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from scipy.ndimage import uniform_filter1d
%matplotlib inline
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

def plot_time_series_with_ci(groundtruth, sampled_output, cov, batch_x=None, num_series=5):
    fig, axes = plt.subplots(num_series, 1, figsize=(12, 6*num_series), sharex=True)
    if num_series == 1:
        axes = [axes]
    idx_list = [53, 11, 19]  # Adjust or randomize this list as needed
    for i in range(num_series):
        # Randomly select a time series from the batch
        idx = idx_list[i]
        hba1c = cov.feat_dict[stype.numerical][idx, 3]
        diabetes_onset = cov.feat_dict[stype.numerical][idx, 1]
        weight = cov.feat_dict[stype.numerical][idx, 4]
        
        if batch_x is not None:
            # Extract previous glucose values
            previous_glucose = batch_x[idx, :, -1].cpu().numpy()
            hr = batch_x[idx, :, 0] * 20.41707644 + 7.93461185e+01
            steps = batch_x[idx, :, 1] * 20.84327263 +  6.53019535e+00
            print(steps)
            hr_mean = np.mean(hr.cpu().numpy())
            steps_mean = np.sum(steps.cpu().numpy())
            
            # Concatenate previous glucose with groundtruth and mean predictions
            full_groundtruth = np.concatenate([previous_glucose, groundtruth[idx, :, -1]])
        
        else:
            full_groundtruth = groundtruth[idx, :, -1].cpu().numpy()

        # Plot ground truth (concatenated)
        axes[i].plot(full_groundtruth, color='#1f77b4', label='Ground Truth (with previous)', lw=2)
        
        # Add textual information
        axes[i].text(0, 2.8, f'idx: {idx}, hba1c: {hba1c.cpu().numpy():.2f}, diabetes_onset: {diabetes_onset.cpu().numpy():.2f}, weight: {weight.cpu().numpy():.2f}, steps: {steps_mean:.2f}', 
                     fontsize=12, color='black', bbox=dict(facecolor='white', alpha=0.5))
        
        # Calculate mean and confidence interval for predicted values
        mean = np.mean(sampled_output[idx, :, :, 0], axis=0)
        ci_lower = np.percentile(sampled_output[idx, :, :, 0], 80, axis=0)
        ci_upper = np.percentile(sampled_output[idx, :, :, 0], 20, axis=0)
        
        # Smooth the CI with a moving average
        ci_lower_smooth = uniform_filter1d(ci_lower, size=5)
        ci_upper_smooth = uniform_filter1d(ci_upper, size=5)
        
        # Concatenate previous glucose with predicted mean and CI
        full_mean = np.concatenate([previous_glucose, mean])
        full_ci_lower = np.concatenate([previous_glucose, ci_lower_smooth])
        full_ci_upper = np.concatenate([previous_glucose, ci_upper_smooth])
        
        # Plot mean prediction (concatenated)
        axes[i].plot(full_mean, color='#ff7f0e', label='Mean Prediction (with previous)', lw=2)
        
        # Plot confidence interval (concatenated)
        axes[i].fill_between(range(full_mean.shape[0]), full_ci_lower, full_ci_upper, color='#ff7f0e', alpha=0.3, label='95% CI')
        
        axes[i].set_title(f'Time Series {i+1}', fontsize=14)
        axes[i].set_xlabel('Time Step', fontsize=12)
        axes[i].set_ylabel('Value', fontsize=12)
        
        # Set y limit to be the same for all plots
        axes[i].set_ylim([-3, 3])
        axes[i].legend(loc='upper right', fontsize=12)
    
    plt.tight_layout(pad=3.0)
    plt.savefig('time_series_with_ci.pdf')
    plt.show()

# Sample invocation of the function with your data
# plot_time_series_with_ci(groundtruth, sampled_output, cov, num_series=5)

j = 0
groundtruth_to_plot = model.sample_outputs[j]['true']
sampled_output_to_plot = model.sample_outputs[j]['pred']
cov_to_plot = model.sample_outputs[j]['batch_cov']
batch_x_to_plot = model.sample_outputs[j]['batch_x']
# Call the function to plot 2 random time series
plot_time_series_with_ci(groundtruth_to_plot, sampled_output_to_plot, cov_to_plot, batch_x_to_plot, num_series=3)

In [None]:
# Clone the specific slice you intend to modify
original_tensor_slice = batch_cov_orig.feat_dict[stype.numerical][:, 3]
modified_slice = original_tensor_slice.clone() * 1.1
batch_cov_orig.feat_dict[stype.numerical][:, 3] = modified_slice

In [None]:
# Clone the specific slice you intend to modify
original_tensor_slice = batch_cov_orig.feat_dict[stype.numerical][:, 3]
modified_slice = original_tensor_slice.clone() * 1.1
batch_cov_orig.feat_dict[stype.numerical][:, 3] = modified_slice

In [None]:
# Clone the specific slice you intend to modify
original_tensor_slice = batch_cov_orig.feat_dict[stype.numerical][:, 3]
modified_slice = original_tensor_slice.clone() * 1.1
batch_cov_orig.feat_dict[stype.numerical][:, 3] = modified_slice

In [None]:
# Clone the specific slice you intend to modify
original_tensor_slice = batch_cov_orig.feat_dict[stype.numerical][:, 3]
modified_slice = original_tensor_slice.clone() * 1.1
batch_cov_orig.feat_dict[stype.numerical][:, 3] = modified_slice

In [None]:
batch_x

In [7]:
import torch
import numpy as np
import random
from torch_frame import stype # Ensure this is consistent with how stype is used/imported earlier

# Ensure 'model' is loaded and model.sample_outputs is populated from previous cells.
# For example, if needed:
# model.sample_outputs = torch.load('/gpfs/gibbs/pi/gerstein/yl2428/logs/ns_Transformer/desert-sweep-6/checkpoints/sample_outputs.pt')

def perturb_hba1c_covariates(batch_cov, individual_indices, percentage_increase):
    """
    Perturbs HBA1c for specified individuals in the batch_cov.
    HBA1c is assumed to be at index 3 of the numerical features based on notebook analysis.
    Input batch_cov is expected to be a torch_frame.MaterializedFrame object or similar.
    Returns a new batch_cov object with perturbations; does not modify the input object.
    """
    if stype.numerical not in batch_cov.feat_dict:
        print(f"Warning: stype.numerical ('{stype.numerical}') not found in batch_cov.feat_dict. Returning original batch_cov.")
        return batch_cov

    # Clone the numerical features tensor to ensure modifications do not affect the original batch_cov
    original_numerical_tensor = batch_cov.feat_dict[stype.numerical]
    perturbed_numerical_tensor = original_numerical_tensor.clone()
    
    for idx in individual_indices:
        if 0 <= idx < perturbed_numerical_tensor.shape[0]:
            # Modify the cloned tensor
            current_hba1c_val = perturbed_numerical_tensor[idx, 3]
            perturbed_numerical_tensor[idx, 3] = current_hba1c_val * (1 + percentage_increase / 100.0)
            # print(f"Individual {idx}: HbA1c changed from {current_hba1c_val.item():.2f} to {perturbed_numerical_tensor[idx, 3].item():.2f}")
        else:
            print(f"Warning: Index {idx} is out of bounds for numerical_feats (shape: {perturbed_numerical_tensor.shape}). Skipping perturbation for this index.")
            
    from copy import deepcopy
    new_feat_dict = deepcopy(batch_cov) 
    # 
    new_feat_dict.feat_dict[stype.numerical] = perturbed_numerical_tensor
    
    # Prepare arguments for constructing the new frame object.
    # It's important to pass all necessary attributes from the original batch_cov
    # that are required by its constructor (e.g., col_names_dict, col_stats).
    
    return new_feat_dict



In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# --- Perturbation Analysis ---
# Select a batch for analysis (e.g., the last one processed or a specific one)
# If model.sample_outputs is a list of outputs from trainer.test:
# Each element in model.sample_outputs would typically be a dictionary
# from a single batch processed by test_step.
# We'll use the last batch's data as an example.
# You might need to adjust which batch or how data is selected based on your exact structure.

if not model.sample_outputs:
    print("Error: model.sample_outputs is empty. Please ensure the model has processed data and populated this list.")
else:
    # Assuming the structure seen in the notebook: model.sample_outputs[i]['batch']
    # and model.sample_outputs[i]['pred'] for predictions.
    # We need an original batch to get 'batch_x', 'batch_y', 'batch_x_mark', 'batch_y_mark', and 'batch_cov'.
    
    # Let's use the data from the last entry in sample_outputs for perturbation
    # This corresponds to the data used for the last plot in the notebook (j = -1)
    # Or you can select a specific batch index, e.g., batch_index_to_perturb = 0
    batch_index_to_perturb = -1 # Use the last batch by default
    
    original_batch_data_dict = model.sample_outputs[batch_index_to_perturb]
    original_batch_tuple = original_batch_data_dict['batch'] # This is [ (batch_x, batch_y, batch_x_mark, batch_y_mark), batch_cov ]
    
    batch_x_orig, batch_y_orig, batch_x_mark_orig, batch_y_mark_orig = original_batch_tuple[0]
    batch_cov_orig = original_batch_tuple[1]

    # Get original predictions (samples before perturbation)
    # These are the 'pred' values from the model's output for this original batch
    # Assuming 'pred' stores the 50 samples: [batch_size, num_samples, pred_len, num_features]
    # And we are interested in the glucose feature, which is the last one (index -1 or 3 for c_out=4)
    sampled_output_before_perturb = original_batch_data_dict['pred'][..., -1] # Taking only glucose

    num_individuals_in_batch = batch_x_orig.shape[0]
    num_to_perturb = min(10, num_individuals_in_batch) # Perturb up to 3 individuals, or fewer if batch is small

    # Randomly select individuals to perturb
    # Ensure reproducibility if desired, by setting random.seed elsewhere or here for this specific selection
    # random.seed(42) # for reproducibility of selection
    individuals_to_perturb_indices = random.sample(range(num_individuals_in_batch), num_to_perturb)
    print(f"Original batch size: {num_individuals_in_batch}")
    print(f"Randomly selected individuals to perturb (indices): {individuals_to_perturb_indices}")

    # Perturb HBA1c for the selected individuals
    percentage_increase_hba1c = 200.0
    batch_cov_perturbed = perturb_hba1c_covariates(batch_cov_orig, individuals_to_perturb_indices, percentage_increase_hba1c)

    # Prepare the new batch for the model's sample_step or equivalent generation function
    # The model.sample_step(batch, batch_idx) was used in the notebook
    # We need to simulate how samples are generated or find the appropriate generation function.
    # If model.sample_step appends to model.sample_outputs, we need to handle that.
    # For now, let's assume we need to call a generation function.
    # The `sample_step` in TimeSeriesDiffusionModel takes `batch` and `batch_idx`
    # and seems to append to `self.sample_outputs`.
    # To get samples for the perturbed data without altering `model.sample_outputs` from original runs,
    # we might need to call a more direct sampling/prediction method of the model if available,
    # or temporarily store and then restore `model.sample_outputs`.

    # Let's try to get new samples.
    # The model's `predict_step` or a similar generation function is needed.
    # In TimeSeriesDiffusionModel, `sample_step` is used during `test_step` and it appends to `self.sample_outputs`.
    # A more direct way to get samples would be to call `model.model.sample()` (for the inner diffusion model)
    # or `model.cond_pred_model.predict()` if it's about conditional prediction.
    # Given the existing notebook structure, `model.sample_step` is what was used to generate `model.sample_outputs`.

    # To avoid confusion with previously stored sample_outputs, we will call a direct sampling method
    # of the underlying diffusion model if possible.
    # The TimeSeriesDiffusionModel has a `sample` method.
    # Signature: sample(self, batch_x, batch_x_mark, batch_y_mark, N=50, cond_scale=0.)
    
    # We need to get cond from batch_cov_perturbed
    # The model has `self.cond_pred_model.encode_cond(batch_cov)`
    # And then uses this `cond` in its own `sample` method, which calls `self.model.sample`.
    
    model.eval() # Ensure model is in eval mode
    with torch.no_grad():
        # 1. Encode covariates to get the condition
        # The `encode_cond` method might need the batch_cov on the correct device
        device = batch_x_orig.device # Assuming batch_x_orig is already on the correct device
        
        # The covariates in batch_cov_perturbed need to be on the same device as the model
        # Typically, the data loader handles this. Here we do it manually if needed.
        # Assuming batch_cov_perturbed.feat_dict[stype.numerical] is a tensor.
        
        # Create a new batch structure for the perturbed data
        perturbed_batch_for_sampling = [
            (batch_x_orig.to(device), batch_y_orig.to(device), batch_x_mark_orig.to(device), batch_y_mark_orig.to(device)), # Original x, y, x_mark, y_mark
            batch_cov_perturbed # Perturbed covariates
        ]

        model.sample_step(perturbed_batch_for_sampling , 0)
    
    sampled_output_after_perturb = model.sample_outputs[-1]['pred']

    print(f"Shape of original sampled output (glucose only): {sampled_output_before_perturb.shape}")
    print(f"Shape of perturbed sampled output (glucose only): {sampled_output_after_perturb.shape}")

    # Store HBA1c values for individuals of interest for plotting/stats
    original_hba1c_values = {}
    perturbed_hba1c_values = {}

    for i_idx in individuals_to_perturb_indices:
        original_hba1c_values[i_idx] = batch_cov_orig.feat_dict[stype.numerical][i_idx, 3].item()
        perturbed_hba1c_values[i_idx] = batch_cov_perturbed.feat_dict[stype.numerical][i_idx, 3].item()
        print(f"Individual {i_idx}: Original HbA1c: {original_hba1c_values[i_idx]:.2f}, Perturbed HbA1c: {perturbed_hba1c_values[i_idx]:.2f}")

# Ground truth for plotting (from the original selected batch)
groundtruth_for_plot = original_batch_data_dict['true'][..., -1]
batch_x_for_plot = original_batch_data_dict['batch_x'][..., -1]
sampled_output_after_perturb = model.sample_outputs[-1]['pred']

In [None]:
import matplotlib.pyplot as plt
import numpy as np

if not model.sample_outputs or 'individuals_to_perturb_indices' not in locals():
    print("Error: Ensure the perturbation analysis cell has been run and required variables are available.")
else:
    # Assuming c_out was 1 or f_dim correctly selected the single glucose feature for pred_len output
    # sampled_output_before_perturb shape: (batch_size, num_samples, pred_len)
    # sampled_output_after_perturb shape: (batch_size, num_samples, pred_len, 1) from notebook output
    # groundtruth_for_plot shape: (batch_size, pred_len)
    # batch_x_for_plot shape: (batch_size, seq_len)
    
    seq_len = batch_x_for_plot.shape[1]
    pred_len = groundtruth_for_plot.shape[1]
    time_history = np.arange(seq_len)
    time_pred = np.arange(seq_len, seq_len + pred_len)
    
    for idx in individuals_to_perturb_indices:
        history_data = batch_x_for_plot[idx].cpu().numpy()
        true_future_data = groundtruth_for_plot[idx]
        
        # Predictions before perturbation
        preds_before_raw = sampled_output_before_perturb[idx] # (num_samples, pred_len)
        mean_preds_before = np.mean(preds_before_raw, axis=0)
        std_preds_before = np.std(preds_before_raw, axis=0)
        
        # Predictions after perturbation
        # sampled_output_after_perturb has shape (batch_size, num_samples, pred_len, 1)
        preds_after_raw = sampled_output_after_perturb[idx, ..., 0] # (num_samples, pred_len)
        mean_preds_after = np.mean(preds_after_raw, axis=0)
        std_preds_after = np.std(preds_after_raw, axis=0)
        
        plt.figure(figsize=(15, 7))
        
        # Plot history
        plt.plot(time_history, history_data, label='Input History (Glucose)', color='black', linewidth=1.5)
        
        # Plot true future
        plt.plot(time_pred, true_future_data, label='Ground Truth Future (Glucose)', color='green', linestyle='--', linewidth=2)
        
        # Plot predictions before perturbation
        plt.plot(time_pred, mean_preds_before, 
                 label=f'Mean Pred (Before Perturb, Orig HbA1c: {original_hba1c_values[idx]:.2f})', 
                 color='blue', linewidth=1.5)
        plt.fill_between(time_pred, mean_preds_before - std_preds_before, mean_preds_before + std_preds_before, 
                         color='blue', alpha=0.2, label='Std Dev (Before)')
        
        # Plot predictions after perturbation
        plt.plot(time_pred, mean_preds_after, 
                 label=f'Mean Pred (After Perturb, New HbA1c: {perturbed_hba1c_values[idx]:.2f})', 
                 color='red', linewidth=1.5)
        plt.fill_between(time_pred, mean_preds_after - std_preds_after, mean_preds_after + std_preds_after, 
                         color='red', alpha=0.2, label='Std Dev (After)')
        
        plt.title(f'Glucose Prediction Perturbation Analysis for Individual {idx}', fontsize=16)
        plt.xlabel('Time Steps', fontsize=14)
        plt.ylabel('Glucose Value', fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()

In [None]:
import numpy as np

if 'individuals_to_perturb_indices' not in locals() or \
   'sampled_output_before_perturb' not in locals() or \
   'sampled_output_after_perturb' not in locals():
    print("Error: Ensure the perturbation analysis and plotting cells have been run, and variables are available.")
else:
    print("\n--- Comparison of Average Standard Deviations (Time-Averaged) ---\n")
    avg_std_devs_before_list = []
    avg_std_devs_after_list = []

    for idx in individuals_to_perturb_indices:
        # Predictions before perturbation: shape (num_samples, pred_len)"
        preds_before_raw = sampled_output_before_perturb[idx]
        # Std dev across samples for each time step: shape (pred_len,)"
        std_dev_over_samples_before = np.std(preds_before_raw, axis=0)
        # Average this std dev over the prediction length"
        avg_std_before = np.mean(std_dev_over_samples_before)
        avg_std_devs_before_list.append(avg_std_before)

        # Predictions after perturbation: shape (num_samples, pred_len) after [..., 0] slicing"
        preds_after_raw = sampled_output_after_perturb[idx, ..., 0]
        # Std dev across samples for each time step: shape (pred_len,)"
        std_dev_over_samples_after = np.std(preds_after_raw, axis=0)
        # Average this std dev over the prediction length"
        avg_std_after = np.mean(std_dev_over_samples_after)
        avg_std_devs_after_list.append(avg_std_after)

        print(f"Individual {idx}:")
        print(f"  Avg. Std. Dev (Before Perturbation): {avg_std_before:.4f}")
        print(f"  Avg. Std. Dev (After Perturbation):  {avg_std_after:.4f}")
        if avg_std_after > avg_std_before:
            print(f"  Comparison: Uncertainty (std dev) INCREASED by {avg_std_after - avg_std_before:.4f} after perturbation.")
        elif avg_std_after < avg_std_before:
            print(f"  Comparison: Uncertainty (std dev) DECREASED by {avg_std_before - avg_std_after:.4f} after perturbation.")
        else:
            print(f"  Comparison: Uncertainty (std dev) remained the same after perturbation.")
        print("-----")

    # Overall average if desired
    if avg_std_devs_before_list and avg_std_devs_after_list:
        overall_avg_std_before = np.mean(avg_std_devs_before_list)
        overall_avg_std_after = np.mean(avg_std_devs_after_list)
        print("\nOverall Average Across Perturbed Individuals:")
        print(f"  Overall Avg. Std. Dev (Before): {overall_avg_std_before:.4f}")
        print(f"  Overall Avg. Std. Dev (After):  {overall_avg_std_after:.4f}")
        if overall_avg_std_after > overall_avg_std_before:
            print(f"  Overall: Uncertainty INCREASED by {overall_avg_std_after - overall_avg_std_before:.4f}")
        elif overall_avg_std_after < overall_avg_std_before:
            print(f"  Overall: Uncertainty DECREASED by {overall_avg_std_before - overall_avg_std_after:.4f}")
        else:
            print(f"  Overall: Uncertainty remained the same.")

# Perturbation on Steps

In [58]:
def perturb_steps_batch_x(original_data_tuple, individual_indices, percentage_increase):
    """
    Perturbs 'steps' for specified individuals in batch_x.
    'steps' are derived from batch_x[:, :, STEPS_FEATURE_INDEX_IN_BATCH_X] using a specific formula.
    This function assumes batch_x is the first element of original_data_tuple:
    original_data_tuple = (batch_x, batch_y, batch_x_mark, batch_y_mark)
    Returns a new data_tuple with batch_x perturbed; does not modify the input tuple or its tensors.
    """
    # Constants for step calculation, as provided
    STEPS_SCALE = 20.84327263
    STEPS_OFFSET = 6.53019535e+00
    STEPS_FEATURE_INDEX_IN_BATCH_X = 1 # 0-indexed


    original_batch_x = original_data_tuple[0]
    perturbed_batch_x = original_batch_x.clone() # Ensure we don't modify the original tensor

    for idx in individual_indices:
        # Extract the scaled feature series for steps for the specific individual
        scaled_steps_series = perturbed_batch_x[idx, :, STEPS_FEATURE_INDEX_IN_BATCH_X]

        # Calculate current "true" step values (element-wise for the series)
        current_true_steps = scaled_steps_series * STEPS_SCALE + STEPS_OFFSET

        # Perturb the "true" step values
        perturbed_true_steps = current_true_steps * (1 + percentage_increase / 100.0)
        # Convert perturbed "true" steps back to scaled values for storage in batch_x
        new_scaled_steps_series = (perturbed_true_steps - STEPS_OFFSET) / STEPS_SCALE

        # Update the cloned batch_x with the new scaled step series
        perturbed_batch_x[idx, :, STEPS_FEATURE_INDEX_IN_BATCH_X] = new_scaled_steps_series


    # Reconstruct the data tuple with the perturbed batch_x
    new_data_list = list(original_data_tuple)
    new_data_list[0] = perturbed_batch_x
    new_data_tuple = tuple(new_data_list)
    
    return new_data_tuple

def perturb_steps_and_hr_batch_x(original_data_tuple, individual_indices, percentage_increase):
    """
    Perturbs both steps AND heart rate together to maintain physiological correlation
    """
    # Constants
    STEPS_SCALE = 20.84327263
    STEPS_OFFSET = 6.53019535e+00
    STEPS_FEATURE_INDEX = 1
    
    HR_SCALE = 79.3461185
    HR_OFFSET = 20.41707644  
    HR_FEATURE_INDEX = 0
    
    original_batch_x = original_data_tuple[0]
    perturbed_batch_x = original_batch_x.clone()

    for idx in individual_indices:
        # Perturb steps
        scaled_steps = perturbed_batch_x[idx, :, STEPS_FEATURE_INDEX]
        true_steps = scaled_steps * STEPS_SCALE + STEPS_OFFSET
        perturbed_true_steps = true_steps * (1 + percentage_increase / 100.0)
        new_scaled_steps = (perturbed_true_steps - STEPS_OFFSET) / STEPS_SCALE
        perturbed_batch_x[idx, :, STEPS_FEATURE_INDEX] = new_scaled_steps
        
        # Perturb heart rate proportionally (maybe smaller increase, e.g., 50% of steps increase)
        scaled_hr = perturbed_batch_x[idx, :, HR_FEATURE_INDEX]
        true_hr = scaled_hr * HR_SCALE + HR_OFFSET
        hr_percentage_increase = percentage_increase * 0.5  # Adjust this ratio as needed
        perturbed_true_hr = true_hr * (1 + hr_percentage_increase / 100.0)
        new_scaled_hr = (perturbed_true_hr - HR_OFFSET) / HR_SCALE
        perturbed_batch_x[idx, :, HR_FEATURE_INDEX] = new_scaled_hr

    # Reconstruct tuple
    new_data_list = list(original_data_tuple)
    new_data_list[0] = perturbed_batch_x
    return tuple(new_data_list)

In [66]:
# Steps Perturbation Analysis
# Using the last batch from model.sample_outputs for perturbation analysis

if not model.sample_outputs:
    print("Error: model.sample_outputs is empty. Please ensure the model has processed data.")
else:
    # Use the last batch for perturbation
    batch_index_to_perturb = 1
    original_batch_data_dict = model.sample_outputs[batch_index_to_perturb]
    original_batch_tuple = original_batch_data_dict['batch']
    
    batch_x_orig, batch_y_orig, batch_x_mark_orig, batch_y_mark_orig = original_batch_tuple[0]
    batch_cov_orig = original_batch_tuple[1]
    
    # Get original predictions before perturbation
    sampled_output_before_perturb = original_batch_data_dict['pred'][..., -1]  # Glucose channel only
    
    num_individuals_in_batch = batch_x_orig.shape[0]
    num_to_perturb = min(64, num_individuals_in_batch)
    
    # Randomly select individuals to perturb
    individuals_to_perturb_indices = random.sample(range(num_individuals_in_batch), num_to_perturb)
    print(f"Original batch size: {num_individuals_in_batch}")
    print(f"Randomly selected individuals to perturb (indices): {individuals_to_perturb_indices}")
    
    # Apply steps perturbation with 50% increase
    percentage_increase_steps = 50.0
    perturbed_batch_tuple = perturb_steps_and_hr_batch_x(original_batch_tuple[0], individuals_to_perturb_indices, percentage_increase_steps)
    
    # Calculate original and perturbed steps values for comparison
    STEPS_SCALE = 20.84327263
    STEPS_OFFSET = 6.53019535e+00
    STEPS_FEATURE_INDEX_IN_BATCH_X = 1
    
    original_steps_values = {}
    perturbed_steps_values = {}
    
    for i_idx in individuals_to_perturb_indices:
        # Calculate original steps (average across time series)
        orig_scaled = batch_x_orig[i_idx, :, STEPS_FEATURE_INDEX_IN_BATCH_X].mean().item()
        orig_true_steps = orig_scaled * STEPS_SCALE + STEPS_OFFSET
        
        # Calculate perturbed steps
        pert_scaled = perturbed_batch_tuple[0][i_idx, :, STEPS_FEATURE_INDEX_IN_BATCH_X].mean().item()
        pert_true_steps = pert_scaled * STEPS_SCALE + STEPS_OFFSET
        
        original_steps_values[i_idx] = orig_true_steps
        perturbed_steps_values[i_idx] = pert_true_steps
        
        print(f"Individual {i_idx}: Original Steps: {orig_true_steps:.2f}, Perturbed Steps: {pert_true_steps:.2f}")
    
    # Create perturbed batch for model inference
    perturbed_batch_for_sampling = [
        perturbed_batch_tuple,
        batch_cov_orig
    ]
    
    # Run model inference with perturbed data
    model.eval()
    with torch.no_grad():
        model.sample_step(perturbed_batch_for_sampling, 1)
    
    # Get perturbed predictions
    sampled_output_after_perturb = model.sample_outputs[-1]['pred'][..., -1]
    
    print(f"\nShape of original sampled output (glucose only): {sampled_output_before_perturb.shape}")
    print(f"Shape of perturbed sampled output (glucose only): {sampled_output_after_perturb.shape}")

Original batch size: 64
Randomly selected individuals to perturb (indices): [15, 62, 17, 58, 52, 48, 27, 46, 0, 16, 18, 23, 13, 49, 33, 37, 12, 5, 38, 43, 32, 29, 6, 31, 45, 25, 41, 63, 20, 30, 2, 35, 3, 44, 61, 24, 26, 28, 11, 51, 34, 19, 42, 54, 22, 57, 8, 10, 50, 60, 36, 53, 7, 59, 4, 14, 21, 39, 40, 1, 56, 55, 9, 47]
Individual 15: Original Steps: 3.80, Perturbed Steps: 5.70
Individual 62: Original Steps: 0.93, Perturbed Steps: 1.40
Individual 17: Original Steps: 2.97, Perturbed Steps: 4.45
Individual 58: Original Steps: 2.26, Perturbed Steps: 3.39
Individual 52: Original Steps: -0.00, Perturbed Steps: -0.00
Individual 48: Original Steps: 6.48, Perturbed Steps: 9.72
Individual 27: Original Steps: -0.00, Perturbed Steps: -0.00
Individual 46: Original Steps: 2.83, Perturbed Steps: 4.25
Individual 0: Original Steps: 1.62, Perturbed Steps: 2.43
Individual 16: Original Steps: 0.26, Perturbed Steps: 0.39
Individual 18: Original Steps: 3.86, Perturbed Steps: 5.79
Individual 23: Original S

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [None]:
# Perturbation on Steps

In [67]:
# Plotting and Statistical Analysis of Steps Perturbation Results

# Calculate prediction statistics
pred_mean_before = sampled_output_before_perturb.mean(axis=1)  # Mean across samples
pred_std_before = sampled_output_before_perturb.std(axis=1)    # Std across samples
pred_mean_after = sampled_output_after_perturb.mean(axis=1)   # Mean across samples  
pred_std_after = sampled_output_after_perturb.std(axis=1)     # Std across samples

# Print individual statistics for perturbed individuals
print("\n=== Individual Statistics ===")
for i_idx in individuals_to_perturb_indices:
    print(f"\nIndividual {i_idx}:")
    print(f"  Steps: {original_steps_values[i_idx]:.2f} → {perturbed_steps_values[i_idx]:.2f} ({percentage_increase_steps}% increase)")
    print(f"  Pred Mean: {pred_mean_before[i_idx].mean():.3f} → {pred_mean_after[i_idx].mean():.3f}")
    print(f"  Pred Std:  {pred_std_before[i_idx].mean():.3f} → {pred_std_after[i_idx].mean():.3f}")

# Overall statistics
print(f"\n=== Overall Statistics ===")
print(f"Overall prediction mean change: {pred_mean_before.mean():.3f} → {pred_mean_after.mean():.3f}")
print(f"Overall prediction std change:  {pred_std_before.mean():.3f} → {pred_std_after.mean():.3f}")

# Create plots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Steps Perturbation Analysis Results', fontsize=16)

# Plot 1: Individual glucose predictions before/after for selected individuals
ax1 = axes[0, 0]
n_individuals_to_plot = min(3, len(individuals_to_perturb_indices))
for plot_idx, i_idx in enumerate(individuals_to_perturb_indices[:n_individuals_to_plot]):
    time_steps = range(pred_mean_before.shape[1])
    ax1.plot(time_steps, pred_mean_before[i_idx], 'b-', alpha=0.7, label=f'Before (Ind {i_idx})' if plot_idx == 0 else "")
    ax1.plot(time_steps, pred_mean_after[i_idx], 'r--', alpha=0.7, label=f'After (Ind {i_idx})' if plot_idx == 0 else "")

ax1.set_title('Individual Glucose Predictions')
ax1.set_xlabel('Time Steps')
ax1.set_ylabel('Glucose Level')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Average prediction change across all perturbed individuals
ax2 = axes[0, 1]
perturbed_before = pred_mean_before[individuals_to_perturb_indices]
perturbed_after = pred_mean_after[individuals_to_perturb_indices]
time_steps = range(pred_mean_before.shape[1])
ax2.plot(time_steps, perturbed_before.mean(axis=0), 'b-', linewidth=2, label='Before Perturbation')
ax2.plot(time_steps, perturbed_after.mean(axis=0), 'r-', linewidth=2, label='After Perturbation')
ax2.fill_between(time_steps, 
                 perturbed_before.mean(axis=0) - perturbed_before.std(axis=0),
                 perturbed_before.mean(axis=0) + perturbed_before.std(axis=0),
                 alpha=0.2, color='blue')
ax2.fill_between(time_steps,
                 perturbed_after.mean(axis=0) - perturbed_after.std(axis=0), 
                 perturbed_after.mean(axis=0) + perturbed_after.std(axis=0),
                 alpha=0.2, color='red')
ax2.set_title('Average Glucose Predictions (Perturbed Individuals)')
ax2.set_xlabel('Time Steps')
ax2.set_ylabel('Glucose Level')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Steps values comparison
ax3 = axes[1, 0]
individuals_plot = list(individuals_to_perturb_indices[:5])  # Show first 5
orig_steps = [original_steps_values[i] for i in individuals_plot]
pert_steps = [perturbed_steps_values[i] for i in individuals_plot]
x_pos = range(len(individuals_plot))
width = 0.35
ax3.bar([x - width/2 for x in x_pos], orig_steps, width, label='Original Steps', alpha=0.7)
ax3.bar([x + width/2 for x in x_pos], pert_steps, width, label='Perturbed Steps', alpha=0.7)
ax3.set_title('Steps Values: Before vs After Perturbation')
ax3.set_xlabel('Individual Index')
ax3.set_ylabel('Steps Value')
ax3.set_xticks(x_pos)
ax3.set_xticklabels([f'Ind {i}' for i in individuals_plot])
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Prediction difference distribution
ax4 = axes[1, 1]
pred_diff = pred_mean_after - pred_mean_before
perturbed_diff = pred_diff[individuals_to_perturb_indices].flatten()
control_diff = np.delete(pred_diff, individuals_to_perturb_indices, axis=0).flatten()
ax4.hist(perturbed_diff, bins=30, alpha=0.7, label='Perturbed Individuals', color='red')
ax4.hist(control_diff, bins=30, alpha=0.7, label='Control Individuals', color='blue')
ax4.set_title('Distribution of Prediction Changes')
ax4.set_xlabel('Glucose Prediction Change')
ax4.set_ylabel('Frequency')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary analysis
print(f"\n=== Summary Analysis ===")
print(f"Average steps increase applied: {percentage_increase_steps}%")
print(f"Number of individuals perturbed: {len(individuals_to_perturb_indices)}")
mean_pred_change_perturbed = pred_diff[individuals_to_perturb_indices].mean()
mean_pred_change_control = np.delete(pred_diff, individuals_to_perturb_indices, axis=0).mean()
print(f"Average glucose prediction change (perturbed): {mean_pred_change_perturbed:.4f}")
print(f"Average glucose prediction change (control): {mean_pred_change_control:.4f}")
print(f"Differential effect: {mean_pred_change_perturbed - mean_pred_change_control:.4f}")


=== Individual Statistics ===

Individual 15:
  Steps: 3.80 → 5.70 (50.0% increase)
  Pred Mean: -0.663 → -0.602
  Pred Std:  0.377 → 0.425

Individual 62:
  Steps: 0.93 → 1.40 (50.0% increase)
  Pred Mean: -0.030 → 0.016
  Pred Std:  0.474 → 0.507

Individual 17:
  Steps: 2.97 → 4.45 (50.0% increase)
  Pred Mean: -0.215 → -0.177
  Pred Std:  0.386 → 0.409

Individual 58:
  Steps: 2.26 → 3.39 (50.0% increase)
  Pred Mean: -0.211 → -0.207
  Pred Std:  0.329 → 0.348

Individual 52:
  Steps: -0.00 → -0.00 (50.0% increase)
  Pred Mean: 0.040 → 0.050
  Pred Std:  0.364 → 0.372

Individual 48:
  Steps: 6.48 → 9.72 (50.0% increase)
  Pred Mean: -0.011 → -0.035
  Pred Std:  0.286 → 0.310

Individual 27:
  Steps: -0.00 → -0.00 (50.0% increase)
  Pred Mean: -0.330 → -0.337
  Pred Std:  0.299 → 0.303

Individual 46:
  Steps: 2.83 → 4.25 (50.0% increase)
  Pred Mean: 0.447 → 0.440
  Pred Std:  0.661 → 0.662

Individual 0:
  Steps: 1.62 → 2.43 (50.0% increase)
  Pred Mean: -0.605 → -0.577
  Pred S

In [70]:
import matplotlib.pyplot as plt
import numpy as np

# Individual time series plots for each perturbed individual (following HbA1c style)
if 'individuals_to_perturb_indices' not in locals() or \
   'sampled_output_before_perturb' not in locals() or \
   'sampled_output_after_perturb' not in locals():
    print("Error: Ensure the perturbation analysis has been run and variables are available.")
else:
    # Get the time series data for plotting
    seq_len = batch_x_orig.shape[1]
    # Defensive: pred_len should be the last dimension of the predictions
    # For before perturb: shape (num_individuals, num_samples, pred_len)
    # For after perturb: shape (num_individuals, num_samples, pred_len, num_channels) or (num_individuals, num_samples, pred_len) if squeezed
    # We'll handle both cases
    if sampled_output_before_perturb.ndim == 3:
        pred_len = sampled_output_before_perturb.shape[2]
    elif sampled_output_before_perturb.ndim == 2:
        pred_len = sampled_output_before_perturb.shape[1]
    else:
        raise ValueError("Unexpected shape for sampled_output_before_perturb")
    time_history = np.arange(seq_len)
    time_pred = np.arange(seq_len, seq_len + pred_len)

    # Plot individual time series for each perturbed individual
    for idx in individuals_to_perturb_indices:
        # Get input history (glucose channel)
        history_data = batch_x_orig[idx, :, 0].cpu().numpy()  # Assuming glucose is channel 0

        # Get ground truth future
        true_future_data = batch_y_orig[idx, -pred_len:, 0].cpu().numpy()  # Glucose channel

        # Get predictions before perturbation
        preds_before_raw = sampled_output_before_perturb[idx]  # (num_samples, pred_len)
        if preds_before_raw.ndim == 2:
            mean_preds_before = np.mean(preds_before_raw, axis=0)
            std_preds_before = np.std(preds_before_raw, axis=0)
        elif preds_before_raw.ndim == 1:
            mean_preds_before = preds_before_raw
            std_preds_before = np.zeros_like(mean_preds_before)
        else:
            raise ValueError("Unexpected shape for preds_before_raw")

        # Get predictions after perturbation
        preds_after_raw = sampled_output_after_perturb[idx]
        # Handle possible extra channel dimension
        if preds_after_raw.ndim == 3:
            # (num_samples, pred_len, num_channels)
            preds_after_raw = preds_after_raw[..., 0]  # Glucose channel
        if preds_after_raw.ndim == 2:
            mean_preds_after = np.mean(preds_after_raw, axis=0)
            std_preds_after = np.std(preds_after_raw, axis=0)
        elif preds_after_raw.ndim == 1:
            mean_preds_after = preds_after_raw
            std_preds_after = np.zeros_like(mean_preds_after)
        else:
            raise ValueError("Unexpected shape for preds_after_raw")

        # Ensure all arrays are 1D and of length pred_len
        mean_preds_before = np.asarray(mean_preds_before).flatten()
        std_preds_before = np.asarray(std_preds_before).flatten()
        mean_preds_after = np.asarray(mean_preds_after).flatten()
        std_preds_after = np.asarray(std_preds_after).flatten()
        true_future_data = np.asarray(true_future_data).flatten()
        # Defensive: truncate or pad to pred_len if needed
        if mean_preds_before.shape[0] != pred_len:
            mean_preds_before = mean_preds_before[:pred_len]
            std_preds_before = std_preds_before[:pred_len]
        if mean_preds_after.shape[0] != pred_len:
            mean_preds_after = mean_preds_after[:pred_len]
            std_preds_after = std_preds_after[:pred_len]
        if true_future_data.shape[0] != pred_len:
            true_future_data = true_future_data[:pred_len]

        plt.figure(figsize=(15, 7))

        # Plot input history
        plt.plot(time_history, history_data, label='Input History (Glucose)',
                 color='black', linewidth=1.5)

        # Plot ground truth future
        plt.plot(time_pred, true_future_data, label='Ground Truth Future (Glucose)',
                 color='green', linestyle='--', linewidth=2)

        # Plot predictions before perturbation
        plt.plot(time_pred, mean_preds_before,
                 label=f'Mean Pred (Before Perturb, Orig Steps: {original_steps_values[idx]:.1f})',
                 color='blue', linewidth=1.5)
        plt.fill_between(time_pred, mean_preds_before - std_preds_before,
                         mean_preds_before + std_preds_before,
                         color='blue', alpha=0.2, label='Std Dev (Before)')

        # Plot predictions after perturbation
        plt.plot(time_pred, mean_preds_after,
                 label=f'Mean Pred (After Perturb, New Steps: {perturbed_steps_values[idx]:.1f})',
                 color='red', linewidth=1.5)
        plt.fill_between(time_pred, mean_preds_after - std_preds_after,
                         mean_preds_after + std_preds_after,
                         color='red', alpha=0.2, label='Std Dev (After)')

        plt.title(f'Glucose Prediction Steps Perturbation Analysis for Individual {idx}', fontsize=16)
        plt.xlabel('Time Steps', fontsize=14)
        plt.ylabel('Glucose Value', fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()

# Summary statistics analysis
print(f"\n=== Steps Perturbation Summary Analysis ===")
print(f"Perturbation applied: {percentage_increase_steps}% change in steps")
print(f"Number of individuals perturbed: {len(individuals_to_perturb_indices)}")

# Calculate prediction differences
# Defensive: mean over axis=1 (samples), keep pred_len dimension
pred_mean_before = sampled_output_before_perturb.mean(axis=1)  # (num_individuals, pred_len)
if sampled_output_after_perturb.ndim == 4:
    pred_mean_after = sampled_output_after_perturb[..., 0].mean(axis=1)  # (num_individuals, pred_len)
else:
    pred_mean_after = sampled_output_after_perturb.mean(axis=1)   # (num_individuals, pred_len)
pred_diff = pred_mean_after - pred_mean_before  # (num_individuals, pred_len)

# Calculate effect sizes
perturbed_diff = pred_diff[individuals_to_perturb_indices].mean()
control_indices = [i for i in range(pred_diff.shape[0]) if i not in individuals_to_perturb_indices]
control_diff = pred_diff[control_indices].mean() if control_indices else 0

print(f"Average glucose prediction change (perturbed individuals): {perturbed_diff:.4f}")
print(f"Average glucose prediction change (control individuals): {control_diff:.4f}")
print(f"Differential effect of steps perturbation: {perturbed_diff - control_diff:.4f}")

# Individual statistics
print(f"\n=== Individual Results ===")
for i_idx in individuals_to_perturb_indices:
    steps_change = perturbed_steps_values[i_idx] - original_steps_values[i_idx]
    # Defensive: mean over pred_len dimension
    pred_change = pred_mean_after[i_idx, -1].mean() - pred_mean_before[i_idx, -1].mean()
    print(f"Individual {i_idx}:")
    print(f"  Steps change: {original_steps_values[i_idx]:.1f} → {perturbed_steps_values[i_idx]:.1f} ({steps_change:+.1f})")
    print(f"  Avg glucose prediction change: {pred_change:+.4f}")


=== Steps Perturbation Summary Analysis ===
Perturbation applied: 50.0% change in steps
Number of individuals perturbed: 64
Average glucose prediction change (perturbed individuals): 0.0087
Average glucose prediction change (control individuals): 0.0000
Differential effect of steps perturbation: 0.0087

=== Individual Results ===
Individual 15:
  Steps change: 3.8 → 5.7 (+1.9)
  Avg glucose prediction change: -0.0905
Individual 62:
  Steps change: 0.9 → 1.4 (+0.5)
  Avg glucose prediction change: +0.1444
Individual 17:
  Steps change: 3.0 → 4.4 (+1.5)
  Avg glucose prediction change: +0.2248
Individual 58:
  Steps change: 2.3 → 3.4 (+1.1)
  Avg glucose prediction change: +0.0946
Individual 52:
  Steps change: -0.0 → -0.0 (-0.0)
  Avg glucose prediction change: +0.0348
Individual 48:
  Steps change: 6.5 → 9.7 (+3.2)
  Avg glucose prediction change: -0.0475
Individual 27:
  Steps change: -0.0 → -0.0 (-0.0)
  Avg glucose prediction change: +0.0376
Individual 46:
  Steps change: 2.8 → 4.2

In [None]:
# Perturbation on Steps