# LIBERO dataset download

In [None]:
# Fetch the LIBERO repository; a later cell runs its dataset download script from this checkout.
!git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git

In [None]:
# Run LIBERO's download script for the libero_spatial suite only (--use-huggingface flag set).
# "n" is piped to the script's stdin (presumably answering an interactive prompt - TODO confirm).
# NOTE: stdout and stderr are both discarded, so a failed download is silent here.
!echo n | python LIBERO/benchmark_scripts/download_libero_datasets.py --datasets libero_spatial --use-huggingface > /dev/null 2>&1

In [None]:
# Create the target directory first (no error if it already exists), and only then
# move the downloaded suites into it. The commands must run sequentially, not piped.
!mkdir -p dataset && mv LIBERO/libero/datasets/* dataset

In [None]:
import sys
from unittest.mock import MagicMock

# Patch: register one shared MagicMock under "matplotlib" and the submodule
# names listed below, so later imports of those names resolve to the mock
# instead of loading the real package.
mock_mpl = MagicMock()
sys.modules.update(
    {
        f"matplotlib{suffix}": mock_mpl
        for suffix in ("", ".pyplot", ".cm", ".colors", ".transforms", ".ticker", "._path")
    }
)

# Libraries

In [None]:
# Standard library
import os
import sys
import json
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict, replace
from typing import Dict, List, Tuple, Optional, Any, Callable
from unittest.mock import MagicMock

# Scientific computing
import numpy as np
import h5py
import cv2
from scipy.ndimage import gaussian_filter
from PIL import Image
from IPython.display import display

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Vision
import einops
from torchvision.models import resnet18, ResNet18_Weights

# Transformers
from transformers import CLIPTokenizer, CLIPTextModel

# Device setup: report the PyTorch version, then select CUDA when it is
# available and fall back to the CPU otherwise.
print(f"PyTorch Version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"‚úÖ GPU Available: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

def set_seed(seed=42):
    """Seed the NumPy and PyTorch random generators.

    Args:
        seed: integer seed applied to NumPy, to PyTorch, and (when CUDA is
            available) to every CUDA device.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
@dataclass
class TrainingConfig:
    """Flat bundle of training and model hyperparameters.

    The field order below is part of the positional constructor signature
    generated by ``@dataclass``.
    """

    lr: float = 3e-4
    hidden_dim: int = 256
    num_recursions: int = 8
    epochs: int = 20
    batch_size: int = 64
    weight_decay: float = 1e-4
    grad_clip: Optional[float] = 1.0
    sched_T0: Optional[int] = None
    sched_T_mult: int = 1
    lr_min: float = 1e-6
    warmup_epochs: int = 3
    early_stop_patience: Optional[int] = None
    save_path: str = 'model.pt'
    freeze_backbone: bool = True
    augmentation: bool = False
    dropout: float = 0.1
    encoder_dropout: float = 0.1
    use_text_prompts: bool = True
    text_encoder_name: str = 'openai/clip-vit-large-patch14'
    train_text_encoder: bool = False
    text_dropout: float = 0.1
    double_visual_features: bool = False
    use_attention: bool = True
    attention_fusion: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain ``dict`` keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping

    def label(self) -> str:
        """Return a short run tag built from lr, hidden_dim, num_recursions and batch_size."""
        pieces = (
            f"lr{self.lr}",
            f"h{self.hidden_dim}",
            f"rec{self.num_recursions}",
            f"bs{self.batch_size}",
        )
        return "_".join(pieces)

In [None]:
def load_images(dataset):
    """Load an image dataset as ``uint8``, converting float images correctly.

    The source dtype is inspected before reading. HDF5 converts types on the
    fly during ``read_direct``, so reading a float dataset straight into a
    ``uint8`` buffer would truncate values in [0, 1] to 0/1 instead of scaling
    them to [0, 255].

    Args:
        dataset: object exposing ``shape`` and ``read_direct(buffer)`` (an
            ``h5py.Dataset`` in this notebook). ``dtype`` is read when present;
            without it the data is treated as integer-valued.

    Returns:
        ``np.uint8`` array with the dataset's shape. Float data whose maximum
        is <= 1.0 is scaled by 255; other float data is clipped to [0, 255].

    Raises:
        RuntimeError: if the dataset cannot be read.
    """
    shape = dataset.shape
    src_dtype = getattr(dataset, 'dtype', None)
    is_float = src_dtype is not None and np.issubdtype(np.dtype(src_dtype), np.floating)

    if is_float:
        try:
            buffer = np.empty(shape, dtype=np.float32)
            dataset.read_direct(buffer)
        except Exception as e:
            raise RuntimeError(f"Cannot read dataset: {e}") from e

        if buffer.size == 0:
            # max() is undefined on empty arrays; nothing to scale anyway.
            return buffer.astype(np.uint8)
        if buffer.max() <= 1.0:
            # Normalised images: rescale to the 0-255 range.
            return np.clip(buffer * 255, 0, 255).astype(np.uint8)
        return np.clip(buffer, 0, 255).astype(np.uint8)

    try:
        buffer = np.empty(shape, dtype=np.uint8)
        dataset.read_direct(buffer)
        return buffer
    except Exception:
        # Best effort: fall through to the low-level read below.
        pass

    try:
        buffer = np.empty(shape, dtype=np.uint8)
        dataset.id.read(h5py.h5s.ALL, h5py.h5s.ALL, buffer)
        return buffer
    except Exception as e:
        raise RuntimeError(f"Cannot read dataset: {e}") from e

def load_actions(dataset):
    """Read an action dataset and return it as a ``float32`` array.

    A ``float32`` read is tried first; if that fails, the data is read as
    ``float64`` and cast down.

    Raises:
        RuntimeError: if the ``float64`` attempt fails as well.
    """
    dims = dataset.shape

    def _read_as(dtype):
        # Allocate a buffer of the requested dtype and fill it from the dataset.
        out = np.empty(dims, dtype=dtype)
        dataset.read_direct(out)
        return out

    try:
        return _read_as(np.float32)
    except Exception:
        pass

    try:
        return _read_as(np.float64).astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Cannot read actions: {e}")


def explore_libero_dataset(data_path: Path):
    """Explore LIBERO dataset structure and display sample frames.

    Recursively collects ``*.hdf5`` files under ``data_path``, opens the first
    one (in sorted path order), prints the image/action shapes of its first
    demo and displays up to four evenly spaced frames.

    Args:
        data_path: directory searched recursively for ``*.hdf5`` files.

    Returns:
        The HDF5 paths found, sorted so the result does not depend on
        filesystem enumeration order (empty list when none exist). The list is
        returned even when inspecting the first file fails; such failures are
        printed, not raised.
    """
    hdf5_files = sorted(data_path.glob('**/*.hdf5'))

    if not hdf5_files:
        print(f"‚ö†Ô∏è No HDF5 files found in {data_path}")
        return []

    print(f"‚úÖ Found {len(hdf5_files)} HDF5 files")

    demo_file = hdf5_files[0]
    print(f"\nüìÑ Analyzing: {demo_file.name}")

    try:
        with h5py.File(demo_file, 'r') as f:
            if 'data' not in f:
                print("‚ö†Ô∏è Key 'data' not found")
                return hdf5_files

            data_group = f['data']
            demo_keys = list(data_group.keys())
            if not demo_keys:
                # An empty group is not an I/O failure; report it as such.
                print("No demos found under 'data'")
                return hdf5_files

            demo_0 = data_group[demo_keys[0]]

            imgs = None

            if 'obs' in demo_0:
                obs_group = demo_0['obs']

                # Preferred key names first, then any key that looks like an image.
                image_keys = ['agentview_rgb', 'agentview_image', 'rgb', 'image', 'robot0_eye_in_hand_image']
                img_key = next((k for k in image_keys if k in obs_group), None)

                if img_key is None:
                    img_key = next((k for k in obs_group.keys() if 'rgb' in k.lower() or 'image' in k.lower()), None)

                if img_key:
                    print(f"\nüñºÔ∏è Using image key: '{img_key}'")
                    try:
                        imgs = load_images(obs_group[img_key])
                        print(f"  ‚úÖ Images loaded: {imgs.shape}")
                    except Exception as e:
                        print(f"  ‚ùå Image error: {e}")

            if 'actions' in demo_0:
                try:
                    actions = load_actions(demo_0['actions'])
                    print(f"\nüéÆ Actions loaded: {actions.shape}")
                    print(f"  Range: [{actions.min():.3f}, {actions.max():.3f}]")
                except Exception as e:
                    print(f"  ‚ùå Actions error: {e}")

            if imgs is not None and len(imgs) > 0:
                print("\nüé¨ Sample frames:")

                # Up to four frames, evenly spread over the demo.
                num_frames = min(4, len(imgs))
                indices = np.linspace(0, len(imgs) - 1, num_frames, dtype=int)

                for idx in indices:
                    img_array = imgs[idx]

                    # Defensive: PIL needs uint8 here; float frames are assumed to be in [0, 1].
                    if img_array.dtype != np.uint8:
                        img_array = (np.clip(img_array, 0, 1) * 255).astype(np.uint8)

                    pil_img = Image.fromarray(img_array)

                    print(f"--- Frame {idx} ---")
                    display(pil_img)
            else:
                print("\n‚ö†Ô∏è No valid images to display")

    except Exception as e:
        # Exploration is best-effort: report the problem and still return the file list.
        print(f"Critical error opening file: {e}")

    return hdf5_files

hdf5_files = explore_libero_dataset(Path('dataset/libero_spatial'))

In [None]:
class LIBERODataset(Dataset):
    """
    Dataset for LIBERO demonstrations.

    Loads visual observations and actions from HDF5 files into memory.
    Supports data augmentation and z-score action normalization.
    Uses demo-level split: splits demos within each file (e.g., 80% train, 20% val).

    Each item is a dict with keys ``'observations'`` (float tensor, channels
    first, values in [0, 1]), ``'actions'`` (float tensor) and ``'prompt'``
    (task description derived from the HDF5 file name).
    """

    def __init__(
        self,
        hdf5_files: List[Path],
        sequence_length: int = 1,
        image_size: Tuple[int, int] = (128, 128),
        normalize_actions: bool = True,
        augmentation: bool = False,
        max_demos_per_task: Optional[int] = None,
        demo_split_ratio: float = 0.8,
        is_train: bool = True,
        action_stats: Optional[Dict] = None
    ):
        """
        Args:
            hdf5_files: list of paths to HDF5 files
            sequence_length: sequence length (1 = single-step prediction)
            image_size: image dimensions as (height, width)
            normalize_actions: if True, normalize actions with z-score; if False,
                actions are returned unnormalized even when ``action_stats`` is given
            augmentation: if True, apply data augmentation (training split only)
            max_demos_per_task: max demos per task (for debugging)
            demo_split_ratio: percentage of demos for training (default 0.8 = 80%)
            is_train: if True, use first demo_split_ratio% demos; otherwise use rest
            action_stats: pre-computed action statistics (for validation set), a dict
                with ``'mean'`` and ``'std'``; entries that are None are treated as
                "no statistics available"

        Raises:
            ValueError: if no demonstration could be loaded, or no transition
                can be indexed from the loaded demonstrations.
        """
        self.hdf5_files = hdf5_files
        self.sequence_length = sequence_length
        self.image_size = (int(image_size[0]), int(image_size[1]))
        # Augmentation is never applied to the validation split.
        self.augmentation = augmentation and is_train
        self.normalize_actions = normalize_actions
        self.demo_split_ratio = demo_split_ratio
        self.is_train = is_train

        self.data = []
        self.action_stats = {'mean': None, 'std': None}
        self.samples: List[Tuple[int, int]] = []

        split_name = "TRAIN" if is_train else "VAL"
        print(f"Loading {len(hdf5_files)} HDF5 files for {split_name} (demo split: {demo_split_ratio:.0%})...")
        all_actions = []

        for hdf5_file in hdf5_files:
            try:
                with h5py.File(hdf5_file, 'r') as f:
                    if 'data' not in f:
                        print(f"‚ö†Ô∏è 'data' key not found in {hdf5_file.name}, skipping...")
                        continue

                    demo_keys = list(f['data'].keys())

                    if max_demos_per_task is not None:
                        demo_keys = demo_keys[:max_demos_per_task]

                    # Demo-level split: the first part of each file trains, the rest validates.
                    n_demos = len(demo_keys)
                    n_train_demos = int(n_demos * demo_split_ratio)

                    if is_train:
                        selected_demo_keys = demo_keys[:n_train_demos]
                    else:
                        selected_demo_keys = demo_keys[n_train_demos:]

                    if len(selected_demo_keys) == 0:
                        print(f"‚ö†Ô∏è No demos selected from {hdf5_file.name} for {split_name}, skipping...")
                        continue

                    task_prompt = self._prompt_from_filename(hdf5_file)

                    for demo_key in selected_demo_keys:
                        try:
                            demo = f[f'data/{demo_key}']

                            obs_group = demo['obs']
                            img_key = self._find_image_key(obs_group)

                            if img_key is None:
                                print(f"‚ö†Ô∏è No image key found in {hdf5_file.name}/{demo_key}, skipping...")
                                continue

                            obs = self._load_images(obs_group[img_key])
                            actions = self._load_actions(demo['actions'])

                            # Keep observations and actions aligned frame by frame.
                            min_len = min(len(obs), len(actions))
                            if min_len < self.sequence_length:
                                print(f"‚ö†Ô∏è Demo too short ({min_len} < {self.sequence_length}), skipping...")
                                continue

                            obs = obs[:min_len]
                            actions = actions[:min_len]

                            self.data.append({
                                'observations': obs,
                                'actions': actions,
                                'prompt': task_prompt
                            })

                            all_actions.append(actions)

                        except Exception as e:
                            # A broken demo must not abort loading of the remaining ones.
                            print(f"‚ö†Ô∏è Error loading demo {demo_key} from {hdf5_file.name}: {e}")
                            continue

            except Exception as e:
                print(f"‚ùå Error opening file {hdf5_file}: {e}")
                continue

        print(f"‚úÖ Loaded {len(self.data)} demonstrations for {split_name}")

        if len(self.data) == 0:
            raise ValueError(f"No valid demonstrations loaded for {split_name}!")

        if action_stats is not None:
            provided_mean = action_stats.get('mean')
            provided_std = action_stats.get('std')
            # Stats taken from a dataset built without normalization hold None
            # entries; in that case there is nothing to adopt.
            if provided_mean is not None and provided_std is not None:
                print(f"üìä Using provided action statistics")
                self.action_stats = {
                    'mean': np.array(provided_mean, dtype=np.float32),
                    'std':  np.clip(np.array(provided_std, dtype=np.float32), 0.1, None).astype(np.float32)
                }
        elif self.normalize_actions and len(all_actions) > 0:
            all_actions_concat = np.concatenate(all_actions, axis=0)

            mean = all_actions_concat.mean(axis=0).astype(np.float32)
            std  = all_actions_concat.std(axis=0).astype(np.float32)
            # Floor the std so near-constant action dimensions are not blown up.
            std_clipped = np.clip(std, 0.1, None)

            print(f"üìä Action statistics computed from {split_name} set:")
            print(f"   Mean: {np.round(mean, 3)}")
            print(f"   Std (clipped to >=0.1): {np.round(std_clipped, 3)}")

            self.action_stats['mean'] = mean
            self.action_stats['std']  = std_clipped

        self.samples = self._build_sample_index()
        print(f"üì¶ Generated {len(self.samples)} transitions for {split_name}")

    @staticmethod
    def _prompt_from_filename(hdf5_file: Path) -> str:
        """Convert HDF5 filename to natural language prompt.

        Drops a trailing ``_demo``, turns ``_`` and ``-`` into spaces and
        collapses repeated whitespace.
        """
        name = hdf5_file.stem
        if name.endswith('_demo'):
            name = name[:-5]
        name = name.replace('_', ' ').replace('-', ' ')
        return ' '.join(name.split()).strip()


    def _find_image_key(self, obs_group) -> Optional[str]:
        """Find correct image key in observation group.

        Known key names are tried in priority order; otherwise the first key
        containing ``rgb`` or ``image`` (case-insensitive) is used.

        Returns:
            The selected key, or None when nothing matches.
        """
        possible_keys = [
            'agentview_rgb',
            'agentview_image',
            'rgb',
            'image',
            'robot0_eye_in_hand_image',
            'frontview_image',
            'sideview_image'
        ]

        obs_keys = list(obs_group.keys())

        for key in possible_keys:
            if key in obs_keys:
                return key

        for key in obs_keys:
            if 'rgb' in key.lower() or 'image' in key.lower():
                return key

        return None

    def _load_images(self, dataset) -> np.ndarray:
        """Load images from HDF5 dataset as ``uint8``, converting float images correctly.

        The source dtype is inspected before reading: HDF5 converts types on
        the fly during ``read_direct``, so reading a float dataset straight
        into a ``uint8`` buffer would truncate values in [0, 1] to 0/1.

        Args:
            dataset: object exposing ``shape`` and ``read_direct(buffer)``;
                ``dtype`` is read when present.

        Returns:
            ``np.uint8`` array with the dataset's shape. Float data whose
            maximum is <= 1.0 is scaled by 255; other float data is clipped
            to [0, 255].

        Raises:
            RuntimeError: if the dataset cannot be read.
        """
        shape = dataset.shape
        src_dtype = getattr(dataset, 'dtype', None)
        is_float = src_dtype is not None and np.issubdtype(np.dtype(src_dtype), np.floating)

        if is_float:
            try:
                buffer = np.empty(shape, dtype=np.float32)
                dataset.read_direct(buffer)
            except Exception as e:
                raise RuntimeError(f"Cannot read image dataset: {e}") from e

            if buffer.size == 0:
                # max() is undefined on empty arrays; nothing to scale anyway.
                return buffer.astype(np.uint8)
            if buffer.max() <= 1.0:
                return np.clip(buffer * 255, 0, 255).astype(np.uint8)
            return np.clip(buffer, 0, 255).astype(np.uint8)

        try:
            buffer = np.empty(shape, dtype=np.uint8)
            dataset.read_direct(buffer)
            return buffer
        except Exception:
            # Best effort: fall through to the low-level read below.
            pass

        try:
            buffer = np.empty(shape, dtype=np.uint8)
            dataset.id.read(h5py.h5s.ALL, h5py.h5s.ALL, buffer)
            return buffer
        except Exception as e:
            raise RuntimeError(f"Cannot read image dataset: {e}") from e

    def _load_actions(self, dataset) -> np.ndarray:
        """Load actions from HDF5 dataset as ``float32``.

        A ``float32`` read is tried first, then ``float64`` followed by a cast.

        Raises:
            RuntimeError: if both attempts fail.
        """
        shape = dataset.shape

        try:
            buffer = np.empty(shape, dtype=np.float32)
            dataset.read_direct(buffer)
            return buffer
        except Exception:
            pass

        try:
            buffer = np.empty(shape, dtype=np.float64)
            dataset.read_direct(buffer)
            return buffer.astype(np.float32)
        except Exception as e:
            raise RuntimeError(f"Cannot read actions dataset: {e}") from e

    def _build_sample_index(self) -> List[Tuple[int, int]]:
        """Pre-compute (demo_idx, start_idx) indices for each transition.

        Raises:
            ValueError: if no demo yields at least one window of ``sequence_length``.
        """
        indices: List[Tuple[int, int]] = []
        for demo_idx, demo in enumerate(self.data):
            demo_transitions = len(demo['observations']) - self.sequence_length + 1
            if demo_transitions <= 0:
                continue
            indices.extend((demo_idx, start) for start in range(demo_transitions))
        if not indices:
            raise ValueError("Dataset index is empty after preprocessing")
        return indices

    def __len__(self) -> int:
        """Return the number of indexed transitions."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return a transition (observation, action).

        With ``sequence_length == 1`` the leading time axis is dropped, so the
        observation is (C, H, W); otherwise it is (T, C, H, W).
        """
        demo_idx, start_idx = self.samples[idx]
        demo = self.data[demo_idx]
        end_idx = start_idx + self.sequence_length

        # Copy so preprocessing never touches the cached demo arrays.
        obs = demo['observations'][start_idx:end_idx].copy()
        actions = demo['actions'][start_idx:end_idx].copy()

        obs = self._preprocess_obs(obs)
        actions = self._preprocess_actions(actions)

        if self.sequence_length == 1:
            obs = obs[0]
            actions = actions[0]

        # Convert HWC -> CHW for PyTorch
        if obs.ndim == 3:
            obs = np.transpose(obs, (2, 0, 1))
        elif obs.ndim == 4:
            obs = np.transpose(obs, (0, 3, 1, 2))

        return {
            'observations': torch.from_numpy(obs).float(),
            'actions': torch.from_numpy(actions).float(),
            'prompt': demo.get('prompt', '')
        }

    def _preprocess_obs(self, obs: np.ndarray) -> np.ndarray:
        """Resize frames to ``image_size``, scale to [0, 1] and optionally augment."""
        processed = []
        target_h, target_w = self.image_size
        for img in obs:
            if img.shape[0] != target_h or img.shape[1] != target_w:
                # cv2.resize takes the target size as (width, height).
                img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_AREA)
            processed.append(img)
        obs = np.stack(processed, axis=0)

        obs = obs.astype(np.float32) / 255.0

        if self.augmentation:
            obs = self._augment_obs(obs)

        return obs

    def _augment_obs(self, obs: np.ndarray) -> np.ndarray:
        """Apply random brightness, contrast and crop-and-resize augmentation.

        Each transform is drawn independently (probabilities 0.5, 0.3 and 0.3)
        from NumPy's global random state; a drawn transform uses the same
        parameters for every frame in ``obs``.
        """
        if np.random.rand() < 0.5:
            brightness = np.random.uniform(0.8, 1.2)
            obs = np.clip(obs * brightness, 0, 1)

        if np.random.rand() < 0.3:
            contrast = np.random.uniform(0.8, 1.2)
            mean = obs.mean(axis=(1, 2), keepdims=True)
            obs = np.clip((obs - mean) * contrast + mean, 0, 1)

        if np.random.rand() < 0.3:
            crop_ratio = np.random.uniform(0.85, 0.95)
            crop_size_h = int(self.image_size[0] * crop_ratio)
            crop_size_w = int(self.image_size[1] * crop_ratio)

            start_y = np.random.randint(0, self.image_size[0] - crop_size_h + 1)
            start_x = np.random.randint(0, self.image_size[1] - crop_size_w + 1)

            cropped = []
            for img in obs:
                img_crop = img[start_y:start_y+crop_size_h, start_x:start_x+crop_size_w]
                img_resized = cv2.resize(img_crop, (self.image_size[1], self.image_size[0]))
                cropped.append(img_resized)
            obs = np.stack(cropped)

        return obs

    def _preprocess_actions(self, actions: np.ndarray) -> np.ndarray:
        """Cast actions to float32 and z-score them when normalization is active.

        Normalization is applied only if ``normalize_actions`` is True and
        both statistics are available.
        """
        actions = actions.astype(np.float32)
        mean, std = self.action_stats['mean'], self.action_stats['std']
        if self.normalize_actions and mean is not None and std is not None:
            actions = (actions - mean) / std
        return actions

    def get_action_stats(self) -> Dict[str, np.ndarray]:
        """Return action statistics for denormalization.

        The arrays are copies, so callers cannot mutate the dataset's state.
        Entries are None when no statistics are available.
        """
        return {
            key: (None if value is None else np.array(value, copy=True))
            for key, value in self.action_stats.items()
        }

    def denormalize_actions(self, actions: np.ndarray) -> np.ndarray:
        """Denormalize actions for simulator execution.

        Inverse of ``_preprocess_actions``: returns ``actions`` unchanged when
        normalization is disabled or no statistics are available.
        """
        mean, std = self.action_stats['mean'], self.action_stats['std']
        if self.normalize_actions and mean is not None and std is not None:
            return actions * std + mean
        return actions

In [None]:
class PretrainedVisualEncoder(nn.Module):
    """Visual encoder based on ResNet18 with adaptive head.

    Images are resized to 224x224, normalized with the stored per-channel
    mean/std, passed through ResNet18 without its final layer, and projected
    by a small MLP adapter.
    """

    def __init__(self, hidden_dim: int = 256, freeze_backbone: bool = True, dropout: float = 0.1, double_visual_features: bool = False):
        """
        Args:
            hidden_dim: base output width of the adapter.
            freeze_backbone: if True, backbone parameters receive no gradients
                and the backbone is kept in eval mode (see ``train``).
            dropout: dropout probability inside the adapter.
            double_visual_features: if True, the adapter outputs
                ``hidden_dim * 2`` features instead of ``hidden_dim``.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.freeze_backbone = freeze_backbone

        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        if freeze_backbone:
            for param in resnet.parameters():
                param.requires_grad = False

        # Drop the last child of ResNet18; the remaining stack yields 512 features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        self.adapter = nn.Sequential(
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, hidden_dim * 2) if double_visual_features else nn.Linear(512, hidden_dim)
        )
        # NOTE(review): `ln` is created but never applied in forward(). It is kept
        # so existing checkpoints (which contain its parameters) still load.
        self.ln = nn.LayerNorm(hidden_dim * (2 if double_visual_features else 1))

        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self._init_weights()

        if freeze_backbone:
            self.backbone.eval()

    def _init_weights(self):
        """Xavier-initialize every ``nn.Linear`` in this module and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        Disabling gradients alone does not stop BatchNorm layers from updating
        their running statistics in training mode, so a frozen backbone is
        forced back to eval after the usual mode switch.
        """
        super().train(mode)
        if getattr(self, 'freeze_backbone', False):
            self.backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: image batch with 3 channels in dim 1 (it is normalized against
                the (1, 3, 1, 1) mean/std buffers); any spatial size.

        Returns:
            Adapter output of shape (B, hidden_dim) or (B, hidden_dim * 2).
        """
        x = x.float()

        if x.shape[-1] != 224 or x.shape[-2] != 224:
            x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        x = (x - self.mean) / self.std
        features = self.backbone(x).flatten(start_dim=1)
        output = self.adapter(features)
        return output

class PromptEncoder(nn.Module):
    """Encodes natural-language task prompts via CLIP ViT-L/14 text tower.

    Prompts are tokenized one by one (results are cached per distinct string),
    batched, run through the CLIP text model, and the pooled output is
    projected to ``hidden_dim`` by a small adapter.
    """

    def __init__(
        self,
        hidden_dim: int,
        model_name: str = 'openai/clip-vit-large-patch14',
        trainable: bool = False,
        dropout: float = 0.3,
        max_length: int = 77
    ):
        """
        Args:
            hidden_dim: output width of the adapter.
            model_name: name passed to ``from_pretrained`` for both the
                tokenizer and the text model.
            trainable: if False, the text model is frozen and kept in eval
                mode (see ``train``).
            dropout: dropout probability at the end of the adapter.
            max_length: token length used for padding/truncation, capped at the
                text model's ``max_position_embeddings`` when that is defined.
        """
        super().__init__()

        self.trainable = trainable
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name)
        self.max_length = min(max_length, getattr(self.text_model.config, 'max_position_embeddings', max_length))

        if not trainable:
            self.text_model.eval()
            for param in self.text_model.parameters():
                param.requires_grad = False

        self.text_hidden = self.text_model.config.hidden_size
        self.adapter = nn.Sequential(
            nn.LayerNorm(self.text_hidden),
            nn.Linear(self.text_hidden, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Tokenizer output per distinct prompt string; task prompts repeat heavily.
        self._token_cache: Dict[str, Dict[str, torch.Tensor]] = {}

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen text model in eval mode.

        Without this override, the ``eval()`` call made in ``__init__`` would
        be undone the first time a parent module calls ``train()``.
        """
        super().train(mode)
        if not getattr(self, 'trainable', True):
            self.text_model.eval()
        return self

    def _tokenize(self, prompt: str) -> Dict[str, torch.Tensor]:
        """Tokenize one prompt (cached) and return fresh copies of the tensors."""
        if prompt not in self._token_cache:
            tokens = self.tokenizer(
                prompt,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=self.max_length
            )
            self._token_cache[prompt] = {k: v for k, v in tokens.items()}
        cached = self._token_cache[prompt]
        # Clone so callers can never modify the cached tensors in place.
        return {k: v.clone() for k, v in cached.items()}

    def forward(self, prompts: List[str], device: torch.device) -> torch.Tensor:
        """Encode a batch of prompts.

        Args:
            prompts: task descriptions, one per batch element; a single string
                is treated as a batch of one.
            device: device the token tensors are moved to before encoding.

        Returns:
            Tensor of shape (len(prompts), hidden_dim).

        Raises:
            ValueError: if ``prompts`` is None or empty.
        """
        if prompts is None:
            raise ValueError("PromptEncoder received no prompts (got None)")
        if isinstance(prompts, str):
            # A bare string would otherwise be iterated character by character.
            prompts = [prompts]
        if len(prompts) == 0:
            raise ValueError("PromptEncoder received an empty batch of prompts")

        token_batches = [self._tokenize(p) for p in prompts]
        batch = {
            key: torch.cat([tokens[key] for tokens in token_batches], dim=0).to(device)
            for key in token_batches[0]
        }

        outputs = self.text_model(**batch)
        pooled = outputs.pooler_output if outputs.pooler_output is not None else outputs.last_hidden_state[:, -1, :]
        return self.adapter(pooled)

class RecursiveBlock(nn.Module):
    """One TRM refinement step: optional attention over the conditioning, then an MLP.

    NOTE(review): in attention mode the query and the key/value sequences both
    have length 1, so the attention weight over the single key is 1 and the
    result appears not to depend on ``h`` at all (there is no residual from
    ``h`` either) - confirm this is intended.
    """

    def __init__(self, hidden_dim=256, num_heads=4, dropout=0.1, use_attention: bool = True):
        super().__init__()
        self.use_attention = use_attention

        # The attention layer and its output dropout exist only in attention mode.
        self.attention = (
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
            if use_attention
            else None
        )
        self.dropout = nn.Dropout(dropout) if use_attention else None
        self.norm = nn.LayerNorm(hidden_dim)

        expanded_dim = hidden_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, h, x_cond):
        """Compute the next hidden state.

        Args:
            h: (B, D) current hidden state
            x_cond: (B, D) conditioning from input
        Returns:
            (B, D) new hidden state
        """
        if not self.use_attention:
            # MLP-only mode: residual MLP on the hidden state, then LayerNorm.
            state = h.unsqueeze(1)
            return self.norm(state + self.mlp(state)).squeeze(1)

        # Attention mode: the hidden state queries the conditioning vector.
        context = x_cond.unsqueeze(1)
        attended, _ = self.attention(h.unsqueeze(1), context, context)
        attended = self.dropout(attended)
        return self.norm(attended + self.mlp(attended)).squeeze(1)

class CrossAttentionFusion(nn.Module):
    """Fuse visual and text features: the visual vector attends to the text vector."""

    def __init__(self, visual_dim: int, text_dim: int, hidden_dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # Projections into the shared attention width.
        self.query_proj = nn.Linear(visual_dim, hidden_dim)
        self.key_proj = nn.Linear(text_dim, hidden_dim)
        self.value_proj = nn.Linear(text_dim, hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
        self.resid_norm = nn.LayerNorm(hidden_dim)

        ffn_width = hidden_dim * 2
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, hidden_dim),
            nn.Dropout(dropout),
        )
        self.out_norm = nn.LayerNorm(hidden_dim)

    def forward(self, visual_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
        """Return the fused features, one vector of width ``hidden_dim`` per batch element.

        The projected visual vector is the query; the projected text vector is
        key and value. Each is given a sequence axis of length 1, which is
        removed again before returning.
        """
        query, key, value = (
            projection(features).unsqueeze(1)
            for projection, features in (
                (self.query_proj, visual_feats),
                (self.key_proj, text_feats),
                (self.value_proj, text_feats),
            )
        )
        attended, _ = self.attn(query, key, value)
        # Two post-norm residual stages: attention first, then the feed-forward net.
        hidden = self.resid_norm(query + attended)
        hidden = self.out_norm(hidden + self.ffn(hidden))
        return hidden.squeeze(1)

class TRMPolicy(nn.Module):
    """
    Policy based on Tiny Recursive Models for robotic control.

    Architecture:
    1. Visual Encoder: ResNet18 pre-trained for image feature extraction
    2. Text Prompt Encoder: CLIP text tower for prompt encoding
    3. Recursive Block: applied N times for iterative reasoning
    4. Action Head: predicts actions from final hidden state
    """

    def __init__(
        self,
        obs_shape=(3, 128, 128),
        action_dim=7,
        hidden_dim=256,
        num_heads=4,
        num_recursions=8,
        dropout=0.1,
        freeze_backbone=True,
        encoder_dropout=0.1,
        use_text_prompts=True,
        text_encoder_name='openai/clip-vit-large-patch14',
        train_text_encoder=False,
        text_dropout=0.1,
        double_visual_features=False,
        use_attention: bool = True,
        attention_fusion: bool = False
    ):
        """
        Args:
            obs_shape: observation shape; stored on the instance only.
            action_dim: size of the predicted action vector.
            hidden_dim: width of the hidden state and of all fused features.
            num_heads: attention heads used by the recursive block and by
                cross-attention fusion.
            num_recursions: number of times the recursive block is applied.
            dropout: dropout for fusion, recursive block and action head.
            freeze_backbone: forwarded to the visual encoder.
            encoder_dropout: dropout inside the visual encoder.
            use_text_prompts: if True, a prompt encoder is built and prompts
                are required in ``forward``.
            text_encoder_name: model name for the prompt encoder.
            train_text_encoder: whether the text encoder is trainable.
            text_dropout: dropout inside the prompt encoder.
            double_visual_features: if True, the visual encoder emits
                ``hidden_dim * 2`` features.
            use_attention: forwarded to the recursive block.
            attention_fusion: fuse vision and text with cross-attention instead
                of concatenation; ignored when ``use_text_prompts`` is False.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_recursions = num_recursions
        self.obs_shape = obs_shape
        self.use_text_prompts = use_text_prompts
        self.use_attention = use_attention
        # Cross-attention fusion needs text features to attend to.
        self.attention_fusion = attention_fusion and use_text_prompts

        self.encoder = PretrainedVisualEncoder(
            hidden_dim=hidden_dim,
            freeze_backbone=freeze_backbone,
            dropout=encoder_dropout,
            double_visual_features=double_visual_features
        )

        visual_dim = hidden_dim * (2 if double_visual_features else 1)
        self.fusion_adapter: nn.Module
        if self.use_text_prompts:
            self.prompt_encoder = PromptEncoder(
                hidden_dim=hidden_dim,
                model_name=text_encoder_name,
                trainable=train_text_encoder,
                dropout=text_dropout
            )
            if self.attention_fusion:
                self.fusion_adapter = CrossAttentionFusion(
                    visual_dim=visual_dim,
                    text_dim=hidden_dim,
                    hidden_dim=hidden_dim,
                    num_heads=num_heads,
                    dropout=dropout
                )
            else:
                # Concatenate [visual, text] and project back to hidden_dim.
                fusion_in = visual_dim + hidden_dim
                self.fusion_adapter = nn.Sequential(
                    nn.LayerNorm(fusion_in),
                    nn.Linear(fusion_in, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout)
                )
        else:
            self.prompt_encoder = None
            fusion_in = visual_dim
            if fusion_in != hidden_dim:
                self.fusion_adapter = nn.Sequential(
                    nn.LayerNorm(fusion_in),
                    nn.Linear(fusion_in, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout)
                )
            else:
                # Visual features already have the right width.
                self.fusion_adapter = nn.Identity()

        self.recursive_block = RecursiveBlock(hidden_dim, num_heads, dropout, use_attention=self.use_attention)

        self.action_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, obs, prompts: Optional[List[str]] = None, return_all_states=False):
        """
        Args:
            obs: (B, C, H, W) visual observations
            prompts: list of strings (B) with task descriptions; required when the
                policy was built with ``use_text_prompts=True``. A single string
                is applied to every element of the batch.
            return_all_states: if True, return all hidden states
        Returns:
            actions: (B, action_dim) predicted actions
            (optional) states: list of hidden states (initial state followed by
                one entry per recursion)
        Raises:
            ValueError: if prompts are required but missing, or their number
                does not match the batch size.
        """
        visual_features = self.encoder(obs)

        if self.use_text_prompts:
            batch_size = obs.shape[0]
            if prompts is None:
                raise ValueError("TRMPolicy was built with use_text_prompts=True but no prompts were given")
            if isinstance(prompts, str):
                prompts = [prompts] * batch_size
            if len(prompts) != batch_size:
                raise ValueError(
                    f"Number of prompts ({len(prompts)}) does not match batch size ({batch_size})"
                )

            text_features = self.prompt_encoder(prompts, device=obs.device)
            if self.attention_fusion:
                x_cond = self.fusion_adapter(visual_features, text_features)
            else:
                x_cond = self.fusion_adapter(torch.cat([visual_features, text_features], dim=-1))
        else:
            x_cond = self.fusion_adapter(visual_features)

        # The hidden state starts as a copy of the conditioning vector.
        h = x_cond.clone()

        states = [h] if return_all_states else None

        for _ in range(self.num_recursions):
            h = self.recursive_block(h, x_cond)
            if return_all_states:
                states.append(h)

        actions = self.action_head(h)

        if return_all_states:
            return actions, states
        return actions
    
def build_policy_from_config(config: TrainingConfig, obs_shape: Tuple[int, int, int] = (3, 128, 128)) -> TRMPolicy:
    """Instantiate a TRMPolicy whose hyperparameters are read from ``config``.

    Args:
        config: object exposing the hyperparameter attributes listed below.
            ``use_attention`` and ``attention_fusion`` are optional and fall
            back to ``True`` and ``False`` respectively when absent.
        obs_shape: observation shape forwarded unchanged to the policy.

    Returns:
        A newly constructed ``TRMPolicy`` with ``action_dim`` fixed to 7.
    """
    # Attributes that must be present on the config; each maps one-to-one
    # onto a TRMPolicy keyword argument of the same name.
    mandatory_fields = (
        'hidden_dim',
        'num_recursions',
        'dropout',
        'freeze_backbone',
        'encoder_dropout',
        'use_text_prompts',
        'text_encoder_name',
        'train_text_encoder',
        'text_dropout',
        'double_visual_features',
    )
    policy_kwargs = {field: getattr(config, field) for field in mandatory_fields}

    # These two may be missing from the config object, so they carry defaults.
    policy_kwargs['use_attention'] = getattr(config, 'use_attention', True)
    policy_kwargs['attention_fusion'] = getattr(config, 'attention_fusion', False)

    return TRMPolicy(obs_shape=obs_shape, action_dim=7, **policy_kwargs)

In [None]:
class BehaviorCloningTrainer:
    """Trainer for Behavior Cloning.

    Fits ``model`` to target actions with a fixed ``0.7 * MSE + 0.3 * L1``
    objective, validates after every epoch, writes the best checkpoint to
    ``config.save_path`` and stops early once the validation loss has not
    improved for ``early_stop_patience`` consecutive epochs.
    """

    # Weights of the two terms of the combined training/validation loss.
    MSE_WEIGHT = 0.7
    L1_WEIGHT = 0.3

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        device: torch.device
    ):
        """
        Args:
            model: network called as ``model(obs, prompts=prompts)``
            train_loader: loader yielding dict batches with 'observations',
                'actions' and optionally 'prompt'
            val_loader: loader with the same batch layout, used for validation
            config: supplies lr, weight_decay, grad_clip, early_stop_patience,
                epochs, save_path and ``to_dict()``
            device: device the model and the batches are moved to
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.steps_per_epoch = max(len(train_loader), 1)

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay
        )

        # No LR scheduler is configured here; _train_epoch steps one per batch
        # if a caller assigns it.
        self.scheduler = None

        # Mixed precision is only enabled on CUDA devices.
        self.use_amp = (self.device.type == 'cuda')
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self.grad_clip = config.grad_clip
        self._epochs_no_improve = 0
        # A missing patience defaults to 5; a falsy value (0) disables early stopping.
        self.early_stop_patience = 5 if config.early_stop_patience is None else config.early_stop_patience
        self.best_val_loss = float('inf')
        self.best_model_path = config.save_path

    def train(self):
        """Complete training loop.

        Returns:
            The best (lowest) validation loss observed.
        """

        for epoch in range(self.config.epochs):
            train_metrics = self._train_epoch(epoch)
            val_metrics = self._validate_epoch(epoch)
            print(f"Epoch {epoch}, training loss: {train_metrics['loss']:.4f}, validation loss: {val_metrics['loss']:.4f}")

            # A NaN validation loss (empty loader) never compares as an
            # improvement, so it counts towards early stopping.
            if val_metrics['loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['loss']
                self._save_checkpoint(epoch, val_metrics['loss'])
                print(f"  ‚úì Saved best model (val_loss: {val_metrics['loss']:.4f})")
                self._epochs_no_improve = 0
            else:
                self._epochs_no_improve += 1

            if self.early_stop_patience and self._epochs_no_improve >= self.early_stop_patience:
                print("‚èπÔ∏è  Early stopping triggered")
                break

        print(f"\n‚úÖ Training completed! Best val loss: {self.best_val_loss:.4f}")
        return self.best_val_loss

    def _save_checkpoint(self, epoch, val_loss):
        """Write model/optimizer state to ``best_model_path``, creating its directory."""
        target = self.best_model_path
        # torch.save does not create missing parent directories, so do it here
        # when the target is a filesystem path.
        if isinstance(target, (str, os.PathLike)):
            parent_dir = os.path.dirname(os.fspath(target))
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config.to_dict()
        }, target)

    def _compute_losses(self, pred_actions, target_actions):
        """Return ``(combined, mse, l1)`` losses for one batch."""
        mse = F.mse_loss(pred_actions, target_actions)
        l1 = F.l1_loss(pred_actions, target_actions)
        loss = self.MSE_WEIGHT * mse + self.L1_WEIGHT * l1
        return loss, mse, l1

    @staticmethod
    def _average_metrics(totals, n_batches):
        """Average accumulated sums; every metric is NaN when no batch was seen."""
        if n_batches == 0:
            return {name: float('nan') for name in totals}
        return {name: total / n_batches for name, total in totals.items()}

    def _train_epoch(self, epoch):
        """Training for one epoch.

        Returns:
            Dict with the epoch-mean 'loss', 'action_mse' and 'action_l1'.
        """
        self.model.train()

        totals = {'loss': 0.0, 'action_mse': 0.0, 'action_l1': 0.0}
        n_batches = 0

        for batch in self.train_loader:
            obs = batch['observations'].to(self.device, non_blocking=True)
            target_actions = batch['actions'].to(self.device, non_blocking=True)
            prompts = batch.get('prompt')

            self.optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=self.use_amp):
                pred_actions = self.model(obs, prompts=prompts)
                loss, mse, l1 = self._compute_losses(pred_actions, target_actions)

            if self.use_amp:
                self.scaler.scale(loss).backward()
                if self.grad_clip:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.grad_clip:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()

            totals['loss'] += loss.item()
            totals['action_mse'] += mse.item()
            totals['action_l1'] += l1.item()
            n_batches += 1

            if self.scheduler is not None:
                self.scheduler.step()

        return self._average_metrics(totals, n_batches)

    def _validate_epoch(self, epoch):
        """Validation for one epoch.

        Returns:
            Dict with the epoch-mean 'loss', 'action_mse' and 'action_l1'.
        """
        self.model.eval()

        totals = {'loss': 0.0, 'action_mse': 0.0, 'action_l1': 0.0}
        n_batches = 0

        with torch.no_grad():
            for batch in self.val_loader:
                obs = batch['observations'].to(self.device)
                target_actions = batch['actions'].to(self.device)
                prompts = batch.get('prompt')

                pred_actions = self.model(obs, prompts=prompts)
                loss, mse, l1 = self._compute_losses(pred_actions, target_actions)

                totals['loss'] += loss.item()
                totals['action_mse'] += mse.item()
                totals['action_l1'] += l1.item()
                n_batches += 1

        return self._average_metrics(totals, n_batches)


def build_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    batch_size: int,
    loader_kwargs: Dict[str, Any]
) -> Tuple[DataLoader, DataLoader]:
    """Wrap the training and validation datasets in DataLoaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps the dataset order and every sample. Both share
    ``batch_size`` and the extra options in ``loader_kwargs``.

    Returns:
        ``(train_loader, val_loader)``
    """
    # Copy first so both loaders receive the same options from a dict that is
    # separate from the caller's.
    shared_options = loader_kwargs.copy()

    training_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **shared_options
    )
    validation_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, **shared_options
    )
    return training_loader, validation_loader


def _set_dataset_augmentation(dataset, flag: bool):
    """Override ``dataset.augmentation`` and hand back an undo callable.

    Datasets without an ``augmentation`` attribute are left untouched and the
    returned callable does nothing. Otherwise the attribute is set to ``flag``
    and calling the returned function puts the previous value back.
    """
    try:
        previous_value = dataset.augmentation
    except AttributeError:
        return lambda: None

    dataset.augmentation = flag

    def restore():
        dataset.augmentation = previous_value

    return restore

In [None]:
def train_model(
    config: TrainingConfig,
    train_dataset: Dataset,
    val_dataset: Dataset,
    loader_kwargs: Dict[str, Any],
    device
) -> Tuple[nn.Module, float]:
    """Train a freshly built policy and return it with its best validation loss.

    Steps: build the loaders, build the policy from ``config``, run the
    BehaviorCloningTrainer with the training set's augmentation switched to
    ``config.augmentation`` (restored afterwards, even on failure), then reload
    the checkpoint at ``config.save_path`` if that file exists.

    Returns:
        ``(model, final_val_loss)`` where the loss is the trainer's return value.
    """

    print(f"\n{'='*60}")
    print("üéØ MODEL TRAINING")
    print(f"Config: {config.label()}")
    print(f"{'='*60}")

    loaders = build_dataloaders(train_dataset, val_dataset, config.batch_size, loader_kwargs)
    policy = build_policy_from_config(config)
    trainer = BehaviorCloningTrainer(policy, *loaders, config, device)

    # The augmentation flag is flipped only for the duration of training.
    undo_augmentation = _set_dataset_augmentation(train_dataset, config.augmentation)
    try:
        final_val_loss = trainer.train()
    finally:
        undo_augmentation()

    # Swap in the weights stored at the save path when a checkpoint is present.
    if os.path.exists(config.save_path):
        saved_state = torch.load(config.save_path, map_location=device)['model_state_dict']
        policy.load_state_dict(saved_state)

    print(f"\n‚úÖ Model trained! Val loss: {final_val_loss:.4f}")

    return policy, final_val_loss

In [None]:
def main_pipeline(
    data_path: str = 'dataset/libero_spatial',
    train: bool = True
):
    """
    End-to-end entry point: dataset construction, then optional training.

    The function searches ``data_path`` recursively for ``*.hdf5`` files,
    builds train/validation LIBERODataset objects with an 80/20 demo-level
    split, writes the training action statistics to ``action_stats.json`` and,
    when requested, trains a model with a fixed TrainingConfig.

    Args:
        data_path: path to LIBERO data
        train: if True, execute model training

    Returns:
        None. The function returns early when no HDF5 file is found.
    """

    print(f"""
    {'='*80}
    ü§ñ TinyRecursiveModels for Robotic Control
    {'='*80}
    """)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"‚úì Using device: {device}\n")

    # STEP 1: Load Dataset
    print(f"\n{'='*80}")
    print("STEP 1: Loading Dataset")
    print(f"{'='*80}")

    data_path = Path(data_path)
    hdf5_files = list(data_path.glob('**/*.hdf5'))

    if not hdf5_files:
        print(f"‚ùå No HDF5 files found in {data_path}")
        return

    print(f"‚úì Found {len(hdf5_files)} HDF5 files (tasks)")

    demo_split_ratio = 0.8
    print(f"\nüìä Demo-level split: {demo_split_ratio:.0%} train / {1-demo_split_ratio:.0%} val per task")

    # Options shared by both splits; only is_train / action_stats differ.
    shared_dataset_options = dict(
        sequence_length=1,
        image_size=(128, 128),
        augmentation=False,
        max_demos_per_task=50,
        demo_split_ratio=demo_split_ratio,
    )

    print("\nCreating TRAIN dataset...")
    train_dataset = LIBERODataset(hdf5_files, is_train=True, **shared_dataset_options)

    # The validation split reuses the statistics computed on the training split.
    action_stats = train_dataset.action_stats

    print("\nCreating VAL dataset...")
    val_dataset = LIBERODataset(
        hdf5_files,
        is_train=False,
        action_stats=action_stats,
        **shared_dataset_options
    )

    num_workers = min(4, os.cpu_count() or 1)
    has_workers = num_workers > 0
    loader_common = {
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
        'persistent_workers': has_workers
    }
    if has_workers:
        loader_common['prefetch_factor'] = 2

    print(f"\n‚úì Datasets created with demo-level split")
    print(f"  Train samples: {len(train_dataset)}")
    print(f"  Val samples: {len(val_dataset)}")

    # Persist the statistics as plain lists in JSON.
    serializable_stats = {
        'mean': action_stats['mean'].tolist(),
        'std': action_stats['std'].tolist()
    }
    with open('action_stats.json', 'w') as f:
        json.dump(serializable_stats, f)

    # STEP 2: Training
    trained_model = None

    if train:
        print(f"\n{'='*80}")
        print("STEP 2: Training")
        print(f"{'='*80}")

        hyperparameters = dict(
            lr=0.0001,
            hidden_dim=256,
            num_recursions=16,
            epochs=30,
            batch_size=512,
            weight_decay=0.1,
            grad_clip=1.0,
            sched_T0=None,
            sched_T_mult=1,
            lr_min=1e-06,
            warmup_epochs=3,
            early_stop_patience=10,
            freeze_backbone=False,
            augmentation=True,
            dropout=0.1,
            encoder_dropout=0.1,
            use_text_prompts=True,
            text_encoder_name="openai/clip-vit-large-patch14",
            train_text_encoder=False,
            text_dropout=0.1,
            double_visual_features=True,
            use_attention=False,
            attention_fusion=False,
            save_path='models/model.pt',
        )
        config = TrainingConfig(**hyperparameters)

        trained_model, final_val_loss = train_model(
            config, train_dataset, val_dataset, loader_common, device
        )

    print(f"\n{'='*80}")
    print("‚úÖ Pipeline completed!")
    print(f"{'='*80}")

In [None]:
# Run the pipeline with training disabled: the datasets are still built and
# action_stats.json is still written; only STEP 2 (training) is skipped.
main_pipeline(train=False)

# Evaluation

In [None]:
class VisualExplainer:
    """
    Computes saliency maps using gradient-based methods.

    Available methods:
    - Vanilla Gradient
    - SmoothGrad
    - GradCAM
    - Integrated Gradients

    Every method differentiates the norm of the model's predicted actions with
    respect to the input observation. Gradient tracking is enabled locally, so
    the methods also work when called inside a ``torch.no_grad()`` block, and
    parameter gradients are not left behind on the model.

    The maps are reduced with a batch-axis squeeze, so the methods are written
    for a single observation (batch size 1) and a single prompt string.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        """
        Args:
            model: policy called as ``model(obs, prompts)``
            device: stored on the instance; not otherwise used by this class
        """
        self.model = model
        self.device = device
        self.feature_maps = None
        self.gradients = None
        # Handles of the registered hooks, kept so they can be removed again.
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register hooks for GradCAM.

        The hooks are attached to the last ``nn.Conv2d`` found in
        ``model.encoder.backbone``. When the model has no such attribute path,
        or the backbone has no convolution, nothing is registered and GradCAM
        falls back to the vanilla gradient.
        """
        def forward_hook(module, input, output):
            self.feature_maps = output.detach()

        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        if not (hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'backbone')):
            return

        last_conv = None
        for module in self.model.encoder.backbone.modules():
            if isinstance(module, nn.Conv2d):
                last_conv = module

        if last_conv is None:
            return

        self.last_conv = last_conv
        self._hook_handles = [
            last_conv.register_forward_hook(forward_hook),
            last_conv.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the GradCAM hooks from the model (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def compute_saliency(
        self,
        obs: torch.Tensor,
        prompt: str,
        method: str = 'gradcam'
    ) -> np.ndarray:
        """Compute saliency map with specified method.

        Args:
            obs: observation tensor passed to the model
            prompt: task description; only used when the model has a truthy
                ``use_text_prompts`` attribute
            method: one of 'vanilla', 'smoothgrad', 'gradcam', 'integrated'

        Returns:
            Map shifted to start at 0 and, if not constant, scaled to peak at 1.

        Raises:
            ValueError: if ``method`` is not one of the names above.
        """
        self.model.eval()

        dispatch = {
            'vanilla': self._vanilla_gradient,
            'smoothgrad': self._smoothgrad,
            'gradcam': self._gradcam,
            'integrated': self._integrated_gradients,
        }
        if method not in dispatch:
            raise ValueError(f"Unknown method: {method}")

        saliency = dispatch[method](obs, prompt)

        saliency = saliency - saliency.min()
        if saliency.max() > 0:
            saliency = saliency / saliency.max()

        return saliency

    def compute_all_methods(self, obs: torch.Tensor, prompt: str) -> Dict[str, np.ndarray]:
        """Compute saliency with all available methods.

        Best-effort: a method that raises is reported on stdout and replaced
        by an all-zero map of the observation's spatial size.
        """
        methods = ['vanilla', 'smoothgrad', 'gradcam', 'integrated']
        results = {}

        for method in methods:
            try:
                results[method] = self.compute_saliency(obs, prompt, method=method)
            except Exception as e:
                print(f"‚ö†Ô∏è Error computing {method}: {e}")
                results[method] = np.zeros((obs.shape[2], obs.shape[3]), dtype=np.float32)

        return results

    def _predict_actions(self, obs: torch.Tensor, prompt: str) -> torch.Tensor:
        """Run the model, passing the prompt only to text-conditioned models."""
        if getattr(self.model, 'use_text_prompts', False):
            return self.model(obs, [prompt])
        return self.model(obs, None)

    @staticmethod
    def _channel_magnitude(tensor: torch.Tensor) -> np.ndarray:
        """Sum absolute values over the channel axis and drop the batch axis."""
        # squeeze(0) removes only the batch axis, so a size-1 spatial axis survives.
        return tensor.detach().abs().sum(dim=1).squeeze(0).cpu().numpy()

    def _forward_with_grad(self, obs: torch.Tensor, prompt: str) -> torch.Tensor:
        """Forward pass with gradients.

        Returns:
            Gradient of the action norm with respect to a detached copy of ``obs``.
        """
        # enable_grad makes this work inside a caller's torch.no_grad() block.
        with torch.enable_grad():
            obs = obs.clone().detach().requires_grad_(True)
            loss = self._predict_actions(obs, prompt).norm()
            # autograd.grad returns the input gradient directly and does not
            # accumulate into the parameters' .grad fields.
            (grad,) = torch.autograd.grad(loss, obs)

        return grad

    def _vanilla_gradient(self, obs: torch.Tensor, prompt: str) -> np.ndarray:
        """Smoothed per-pixel magnitude of the input gradient."""
        grad = self._forward_with_grad(obs, prompt)
        return gaussian_filter(self._channel_magnitude(grad), sigma=2)

    def _smoothgrad(self, obs: torch.Tensor, prompt: str, n_samples: int = 20, noise: float = 0.1) -> np.ndarray:
        """Average the gradient magnitude over ``n_samples`` noisy copies of ``obs``.

        ``noise`` is the standard deviation of the Gaussian noise that is added.
        """
        saliency_sum = None

        for _ in range(n_samples):
            noisy_obs = obs + torch.randn_like(obs) * noise
            saliency = self._channel_magnitude(self._forward_with_grad(noisy_obs, prompt))

            if saliency_sum is None:
                saliency_sum = saliency
            else:
                saliency_sum += saliency

        return gaussian_filter(saliency_sum / n_samples, sigma=2)

    def _gradcam(self, obs: torch.Tensor, prompt: str) -> np.ndarray:
        """Grad-CAM on the hooked convolution, or vanilla gradient without hooks."""
        self.feature_maps = None
        self.gradients = None

        with torch.enable_grad():
            obs_grad = obs.clone().detach().requires_grad_(True)
            actions = self._predict_actions(obs_grad, prompt)

            self.model.zero_grad()
            # backward() is what triggers the hook that records self.gradients.
            actions.norm().backward()

        # Only the hook captures are needed; drop the parameter gradients.
        self.model.zero_grad(set_to_none=True)

        if self.feature_maps is None or self.gradients is None:
            return self._vanilla_gradient(obs, prompt)

        # Channel weights are the spatial mean of the captured gradients.
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.feature_maps).sum(dim=1).squeeze(0)
        cam = F.relu(cam)
        cam = cam.cpu().numpy()
        cam = cv2.resize(cam, (obs.shape[3], obs.shape[2]))
        return gaussian_filter(cam, sigma=3)

    def _integrated_gradients(self, obs: torch.Tensor, prompt: str, steps: int = 50) -> np.ndarray:
        """Path-averaged gradient magnitude from an all-zero baseline to ``obs``.

        NOTE(review): this averages absolute gradients and multiplies by the
        channel-summed ``|obs - baseline|``, so it is a magnitude-based variant
        and not the signed attribution of the Integrated Gradients definition.
        Left as is because changing it would alter the produced heatmaps.
        """
        baseline = torch.zeros_like(obs)
        grads_sum = None

        for i in range(1, steps + 1):
            scaled_input = baseline + (float(i) / steps) * (obs - baseline)
            grad_np = self._channel_magnitude(self._forward_with_grad(scaled_input, prompt))

            if grads_sum is None:
                grads_sum = grad_np
            else:
                grads_sum += grad_np

        avg_grads = grads_sum / steps
        integrated = avg_grads * self._channel_magnitude(obs - baseline)
        return gaussian_filter(integrated, sigma=2)

In [None]:
def generate_heatmap_video(
    frames_data: List[Dict],
    original_frames: List[np.ndarray],
    output_path: str,
    fps: int = 5
):
    """
    Write a video in which every frame is a 2x3 grid of saliency views.

    Grid layout:
    +-----------------+------------------+-----------------+
    |    Original     | Vanilla Gradient |   SmoothGrad    |
    +-----------------+------------------+-----------------+
    |    Grad-CAM     |    Integrated    |    Step Info    |
    |                 |    Gradients     |                 |
    +-----------------+------------------+-----------------+

    Args:
        frames_data: one dict per frame; saliency maps are looked up under
            'vanilla', 'smoothgrad', 'gradcam' and 'integrated', each falling
            back to 'saliency_map'; 'step' falls back to the frame index
        original_frames: the frames to overlay; the first one fixes the panel size
        output_path: destination of the mp4v-encoded video
        fps: frames per second of the output

    Only ``min(len(frames_data), len(original_frames))`` frames are rendered.
    A frame that fails to render is reported and skipped.
    """
    if len(frames_data) == 0 or len(original_frames) == 0:
        print("‚ö†Ô∏è No frames to generate video")
        return

    n_frames = min(len(frames_data), len(original_frames))
    print(f"üìπ Generating heatmap video with {n_frames} frames...")

    title_height = 25
    h, w = original_frames[0].shape[:2]
    # Each panel is one frame plus its title bar; the grid is 3 wide, 2 tall.
    panel_h, panel_w = h + title_height, w
    grid_size = (panel_w * 3, panel_h * 2)

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, grid_size)

    if not out.isOpened():
        print(f"‚ùå Failed to open VideoWriter for {output_path}")
        return

    method_titles = {
        'original': 'Original Input',
        'vanilla': 'Vanilla Gradient',
        'smoothgrad': 'SmoothGrad',
        'gradcam': 'Grad-CAM',
        'integrated': 'Integrated Grad',
        'info': 'Step Info'
    }
    # Saliency panels in the order they appear in the grid.
    overlay_keys = ('vanilla', 'smoothgrad', 'gradcam', 'integrated')

    def add_title_bar(img: np.ndarray, title: str) -> np.ndarray:
        """Stack a dark title strip with centred text on top of ``img``."""
        canvas = np.zeros((img.shape[0] + title_height, img.shape[1], 3), dtype=np.uint8)
        canvas[:title_height, :] = (40, 40, 40)
        # Three-pixel accent line along the bottom of the title strip.
        canvas[title_height-3:title_height, :] = (100, 100, 255)
        canvas[title_height:, :] = img

        font = cv2.FONT_HERSHEY_SIMPLEX
        text_width = cv2.getTextSize(title, font, 0.4, 1)[0][0]
        text_x = (img.shape[1] - text_width) // 2
        cv2.putText(canvas, title, (text_x, 17), font, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
        return canvas

    def create_heatmap_overlay(frame: np.ndarray, saliency: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Blend a JET-coloured saliency map over ``frame``."""
        if saliency is None or saliency.size == 0:
            return frame.copy()

        saliency = np.array(saliency, dtype=np.float32)
        if saliency.shape[:2] != frame.shape[:2]:
            saliency = cv2.resize(saliency, (frame.shape[1], frame.shape[0]))

        low, high = saliency.min(), saliency.max()
        if high > low:
            saliency = (saliency - low) / (high - low)

        coloured = cv2.applyColorMap((saliency * 255).astype(np.uint8), cv2.COLORMAP_JET)
        coloured = cv2.cvtColor(coloured, cv2.COLOR_BGR2RGB)

        return cv2.addWeighted(frame, 1 - alpha, coloured, alpha, 0)

    def create_info_panel(frame: np.ndarray, step: int) -> np.ndarray:
        """Dark panel of the frame's size showing the step number."""
        panel = np.zeros_like(frame)
        panel[:] = (30, 30, 30)
        cv2.putText(panel, f"Step: {step}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        return panel

    frames_written = 0
    for i in range(n_frames):
        try:
            frame = original_frames[i].copy()
            data = frames_data[i]

            # Bring the frame to uint8: values up to 1.0 are scaled by 255,
            # anything else is clipped to the byte range.
            if frame.dtype != np.uint8:
                if frame.max() <= 1.0:
                    frame = (frame * 255).astype(np.uint8)
                else:
                    frame = np.clip(frame, 0, 255).astype(np.uint8)

            # Bring the frame to three channels.
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

            if frame.shape[0] != h or frame.shape[1] != w:
                frame = cv2.resize(frame, (w, h))

            # Panel order: original, the four saliency overlays, step info.
            panels = [add_title_bar(frame.copy(), method_titles['original'])]

            for key in overlay_keys:
                saliency_map = data.get(key, data.get('saliency_map', None))
                overlay = create_heatmap_overlay(frame, saliency_map)
                panels.append(add_title_bar(overlay, method_titles[key]))

            step = data.get('step', i)
            panels.append(add_title_bar(create_info_panel(frame, step), method_titles['info']))

            panels = [
                p if (p.shape[0] == panel_h and p.shape[1] == panel_w)
                else cv2.resize(p, (panel_w, panel_h))
                for p in panels
            ]

            grid = np.vstack([np.hstack(panels[:3]), np.hstack(panels[3:])])

            # Panels are assembled in RGB; the writer is given BGR.
            out.write(cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
            frames_written += 1

        except Exception as e:
            print(f"‚ö†Ô∏è Error processing frame {i}: {e}")
            continue

    out.release()

    if frames_written > 0:
        print(f"‚úÖ Heatmap video saved: {output_path} ({frames_written} frames)")
    else:
        print(f"‚ùå No frames written!")

In [None]:
# Matplotlib mock to prevent kernel crash due to NumPy/ABI conflicts
# One shared MagicMock is registered under every matplotlib module name below,
# so any later import of these names resolves to the mock.
mock_mpl = MagicMock()
sys.modules["matplotlib"] = mock_mpl
sys.modules["matplotlib.pyplot"] = mock_mpl
sys.modules["matplotlib.cm"] = mock_mpl
sys.modules["matplotlib.colors"] = mock_mpl
sys.modules["matplotlib.transforms"] = mock_mpl
sys.modules["matplotlib.ticker"] = mock_mpl
sys.modules["matplotlib._path"] = mock_mpl

# LIBERO imports
# Put the local LIBERO checkout at the front of sys.path (once) when the
# directory exists, so the `libero` imports below can resolve to it.
LIBERO_REPO_ROOT = Path('LIBERO')
if LIBERO_REPO_ROOT.exists() and str(LIBERO_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(LIBERO_REPO_ROOT))

# Best-effort import: any failure is deliberately ignored and the name is not
# used in this cell. NOTE(review): presumably a pre-load of robosuite's numba
# helper before the LIBERO imports -- confirm why it is needed.
try:
    from robosuite.utils.numba import jit_decorator
except Exception:
    pass

from libero.libero import get_libero_path
from libero.libero.benchmark import get_benchmark
from libero.libero.envs import OffScreenRenderEnv
from libero.libero.utils.time_utils import Timer
from libero.libero.utils.video_utils import VideoWriter

print("‚úì LIBERO imports successful")

In [None]:
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def _merge_training_config(stored: Dict[str, Any]):
    """Expose the entries of a stored config dict as attributes of an object.

    Every key of ``stored`` becomes an attribute of the returned object; no
    defaults are added for keys that are missing.
    """
    class ConfigObj:
        """Plain attribute container filled from keyword arguments."""

        def __init__(self, **entries):
            vars(self).update(entries)

    config_obj = ConfigObj(**stored)
    return config_obj

def _stack_vector_obs(obs: Any) -> Dict[str, np.ndarray]:
    """Turn a list of per-environment observation dicts into one batched dict.

    The keys are taken from the first element and each value is stacked along
    a new leading axis. Anything that is not a list is returned unchanged.
    """
    if not isinstance(obs, list):
        return obs

    stacked = {}
    for key in obs[0].keys():
        per_env_values = [single_obs[key] for single_obs in obs]
        stacked[key] = np.stack(per_env_values, axis=0)
    return stacked

def _select_camera_key(obs_batch: Dict[str, np.ndarray]) -> str:
    """Pick the observation key to use as the camera image.

    The known agent-view names are tried in a fixed priority order; when none
    is present, the first key of ``obs_batch`` is returned instead.
    """
    preferred_names = ('agentview_rgb', 'agentview_image', 'robot0_agentview_image')
    match = next((name for name in preferred_names if name in obs_batch), None)
    if match is not None:
        return match
    return list(obs_batch.keys())[0]

def _prepare_policy_input(images: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert a NumPy image batch into a float tensor for the policy.

    The array is moved to ``device`` as float32, divided by 255, and its axes
    are reordered with ``permute(0, 3, 1, 2)`` (last axis moved to position 1);
    the result is made contiguous.
    """
    scaled = torch.from_numpy(images).to(device=device, dtype=torch.float32).div(255.0)
    channels_second = scaled.permute(0, 3, 1, 2)
    return channels_second.contiguous()


class SequentialVectorEnv:
    """Minimal vectorized-environment facade that drives its environments one by one."""

    def __init__(self, env_fns: List[Callable]):
        # Every factory is called immediately, in the order given.
        created = []
        for make_env in env_fns:
            created.append(make_env())
        self.envs = created

    def step(self, actions):
        """Step each environment with its action.

        Returns:
            ``(observations, rewards, dones, infos)`` where observations and
            infos are lists and rewards and dones are NumPy arrays.
        """
        transitions = []
        for env, action in zip(self.envs, actions):
            transitions.append(env.step(action))

        # Transpose the per-environment 4-tuples into four per-field sequences.
        observations, rewards, dones, infos = zip(*transitions)
        return list(observations), np.array(rewards), np.array(dones), list(infos)

    def reset(self):
        """Reset every environment and return the list of their results."""
        initial_observations = []
        for env in self.envs:
            initial_observations.append(env.reset())
        return initial_observations

    def seed(self, seed):
        """Seed environment ``i`` with ``seed + i``; environments without ``seed`` are skipped."""
        for offset, env in enumerate(self.envs):
            if not hasattr(env, 'seed'):
                continue
            env.seed(seed + offset)

    def set_init_state(self, states):
        """Call ``set_init_state`` on each environment with its state; return the results."""
        outcomes = []
        for env, state in zip(self.envs, states):
            outcomes.append(env.set_init_state(state))
        return outcomes

    def close(self):
        """Close every environment in order."""
        for env in self.envs:
            env.close()

In [None]:
def evaluate_model(
    checkpoint_path: str = 'models/model.pt',
    action_stats_path: str = 'action_stats.json',
    benchmark: str = 'libero_spatial',
    task_id: int = 0,
    env_num: int = 10,
    max_steps: int = 600,
    seed: int = 42,
    save_videos: bool = True,
    video_dir: str = 'evaluation_videos',
    camera_height: int = 128,
    camera_width: int = 128,
    video_skip: int = 1,
    generate_heatmaps: bool = False,
    heatmap_interval: int = 10
) -> Dict[str, Any]:
    """
    Evaluate model on a LIBERO task.

    Rolls the policy out in ``env_num`` environments that are stepped
    sequentially, starting each one from a stored initial state, and counts an
    episode as successful as soon as it yields a non-zero reward.

    Args:
        checkpoint_path: path to model checkpoint
        action_stats_path: path to action statistics (JSON with 'mean' and 'std')
        benchmark: benchmark name ('libero_spatial', 'libero_goal', etc.)
        task_id: task ID to evaluate
        env_num: number of environments (episodes) to roll out
        max_steps: maximum steps per episode
        seed: random seed
        save_videos: whether to save execution videos
        video_dir: video output directory
        camera_height, camera_width: camera dimensions
        video_skip: frame skip for videos
        generate_heatmaps: whether to generate saliency heatmap videos
        heatmap_interval: sampling interval for heatmaps

    Returns:
        Dict with 'success_rate', 'episodes', 'max_steps', 'task_id' and
        'task_prompt'.

    Raises:
        ValueError: if the task's initial-state file contains no states.
    """
    print(f"üîç Evaluating Task {task_id} on {benchmark}...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = _merge_training_config(ckpt.get('config', {}))
    policy = build_policy_from_config(cfg, obs_shape=(3, camera_height, camera_width)).to(device)
    policy.load_state_dict(ckpt['model_state_dict'])
    policy.eval()

    # Context manager so the stats file handle is closed deterministically.
    with open(action_stats_path) as stats_file:
        stats = json.load(stats_file)
    # Leading batch axis lets the statistics broadcast over the action batch.
    action_mean = torch.tensor(stats['mean'], device=device).unsqueeze(0)
    action_std = torch.tensor(stats['std'], device=device).unsqueeze(0)
    action_dim = int(action_mean.shape[-1])

    benchmark_map = {
        'libero_10': 'LIBERO_10',
        'libero_spatial': 'LIBERO_SPATIAL',
        'libero_goal': 'LIBERO_GOAL'
    }
    suite = get_benchmark(benchmark_map.get(benchmark, benchmark))(0)
    task = suite.get_task(task_id)
    task_prompt = task.language
    use_prompts = getattr(policy, 'use_text_prompts', False)

    if use_prompts:
        print(f"   Prompt: '{task_prompt}'")

    visual_explainer = None
    heatmap_frames = []
    original_frames = []

    if generate_heatmaps:
        visual_explainer = VisualExplainer(policy, device)
        print(f"   Heatmap generation enabled (interval: {heatmap_interval})")

    env_args = {
        'bddl_file_name': str(Path(get_libero_path('bddl_files')) / task.problem_folder / task.bddl_file),
        'camera_heights': camera_height,
        'camera_widths': camera_width
    }
    env = SequentialVectorEnv([lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)])

    try:
        # SECURITY NOTE: same unpickling caveat as the checkpoint load above.
        init_states = torch.load(
            str(Path(get_libero_path('init_states')) / task.problem_folder / task.init_states_file),
            map_location='cpu', weights_only=False
        )
        num_init_states = len(init_states)
        if num_init_states == 0:
            raise ValueError(f"No initial states available for task {task_id}")
        # Cycle through the stored states so every environment receives one,
        # even when env_num exceeds the number of stored states (a plain
        # slice would leave the surplus environments in their reset state).
        selected_states = [init_states[i % num_init_states] for i in range(env_num)]

        reset_obs = env.reset()
        env.seed(seed)
        init_obs = env.set_init_state(selected_states)
        # set_init_state moves the simulation away from the reset state, so the
        # observation returned by reset() is stale. Use the observations that
        # set_init_state returns; fall back to the reset observations only if
        # the underlying environments do not return any.
        if init_obs and all(o is not None for o in init_obs):
            obs = init_obs
        else:
            obs = reset_obs

        dones = [False] * env_num
        successes = np.zeros(env_num, dtype=bool)
        video_path = Path(video_dir) / f"task_{task_id:02d}"

        with VideoWriter(str(video_path), save_videos) as video_writer:
            for step in range(max_steps):
                obs_batch = _stack_vector_obs(obs)
                cam_key = _select_camera_key(obs_batch)

                # Only unfinished episodes are fed to the policy.
                alive = [i for i, d in enumerate(dones) if not d]
                if not alive:
                    break

                vis_batch = obs_batch[cam_key][alive]
                p_in = _prepare_policy_input(vis_batch, device)

                with torch.no_grad():
                    prompt_batch = [task_prompt for _ in alive] if use_prompts else None
                    actions_alive = policy(p_in, prompt_batch)
                    # Undo the action normalisation using the stored statistics.
                    actions_alive = actions_alive * action_std + action_mean
                    # Finished environments are still stepped, with all-zero actions.
                    full_actions = np.zeros((env_num, action_dim), dtype=np.float32)
                    full_actions[alive] = actions_alive.detach().cpu().numpy()

                if visual_explainer is not None and step % heatmap_interval == 0 and len(alive) > 0:
                    # Best effort: a failing saliency computation must not abort the rollout.
                    try:
                        # Saliency is computed for the first alive environment only.
                        single_obs = p_in[0:1]
                        saliency_maps = visual_explainer.compute_all_methods(single_obs, task_prompt)

                        heatmap_frames.append({
                            'step': step,
                            'vanilla': saliency_maps['vanilla'],
                            'smoothgrad': saliency_maps['smoothgrad'],
                            'gradcam': saliency_maps['gradcam'],
                            'integrated': saliency_maps['integrated']
                        })

                        orig_frame = vis_batch[0].copy()
                        if orig_frame.dtype != np.uint8:
                            orig_frame = ((orig_frame * 255) if orig_frame.max() <= 1.0 else orig_frame).astype(np.uint8)
                        original_frames.append(orig_frame)

                    except Exception as e:
                        print(f"‚ö†Ô∏è Heatmap error at step {step}: {e}")

                obs, reward, done_batch, _ = env.step(full_actions)

                for i in alive:
                    # Any non-zero reward marks the episode as a success.
                    if reward[i] != 0.0:
                        successes[i] = True
                    dones[i] = dones[i] or bool(done_batch[i])

                if save_videos and step % video_skip == 0:
                    video_writer.append_vector_obs(obs, dones, camera_name=cam_key)

        success_rate = float(successes.mean())
        print(f"   Result: {successes.sum()}/{env_num} successes ({success_rate:.0%})")

        results = {
            'success_rate': success_rate,
            'episodes': int(env_num),
            'max_steps': int(max_steps),
            'task_id': task_id,
            'task_prompt': task_prompt
        }

    finally:
        env.close()

    if save_videos:
        # Flatten "<video_dir>/task_XX/video.mp4" into "<video_dir>/task_XX.mp4".
        src = Path(video_dir) / f"task_{task_id:02d}" / "video.mp4"
        dst = Path(video_dir) / f"task_{task_id:02d}.mp4"
        if src.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            src.rename(dst)
            try:
                src.parent.rmdir()
            except OSError:
                # Directory not empty (or already gone): leave it in place.
                pass

    if generate_heatmaps and heatmap_frames and original_frames:
        heatmap_video_path = str(Path(video_dir) / f"task_{task_id:02d}_heatmaps.mp4")
        generate_heatmap_video(heatmap_frames, original_frames, heatmap_video_path, fps=5)

    return results

In [None]:
# Evaluate each of the ten tasks and keep only its success rate.
final_results = []
for task_id in range(10):
    result = evaluate_model(
        task_id=task_id,
        env_num=5,
        max_steps=600,
        save_videos=True,
        generate_heatmaps=True,
        heatmap_interval=10,
    )
    final_results.append(result['success_rate'])

# Aggregate report over all evaluated tasks.
mean_success = np.mean(final_results)
print(f"\n{'='*60}")
print("EVALUATION SUMMARY")
print(f"{'='*60}")
print(f"Mean success rate: {mean_success:.2%}")
print(f"Success rates: {[f'{r:.0%}' for r in final_results]}")

# Task indices ordered by success rate, best first; ties keep task order.
top_3 = sorted(
    range(len(final_results)),
    key=final_results.__getitem__,
    reverse=True,
)[:3]
print(f"Best tasks: {top_3} ({[f'{final_results[i]:.0%}' for i in top_3]})")

# Same ranking in ascending order gives the weakest tasks.
worst_3 = sorted(
    range(len(final_results)),
    key=final_results.__getitem__,
)[:3]
print(f"Worst tasks: {worst_3} ({[f'{final_results[i]:.0%}' for i in worst_3]})")