# Libero

download del dataset Libero 
python3 ./LIBERO/benchmark_scripts/download_libero_datasets.py --datasets libero_goal --use-huggingface 
mv ./LIBERO/libero/datasets/* ./dataset

In [52]:
import sys
from unittest.mock import MagicMock

# The Patch
mock_mpl = MagicMock()
sys.modules["matplotlib"] = mock_mpl
sys.modules["matplotlib.pyplot"] = mock_mpl
sys.modules["matplotlib.cm"] = mock_mpl
sys.modules["matplotlib.colors"] = mock_mpl
sys.modules["matplotlib.transforms"] = mock_mpl
sys.modules["matplotlib.ticker"] = mock_mpl
sys.modules["matplotlib._path"] = mock_mpl

In [53]:
import h5py
import numpy as np
from PIL import Image
from IPython.display import display, HTML

# Path to your demo file
file_path = "dataset/libero_goal/pick_up_the_black_bowl_from_table_center_and_place_it_on_the_plate_demo.hdf5"

print(f"Opening file: {file_path}")

try:
    with h5py.File(file_path, "r") as f:
        # Loop through all demos in the file
        for demo_name in f["data"]:
            print(f"\n=== Demo: {demo_name} ===")
            
            # Access the image data (AgentView RGB)
            # Note: Ensure this path exists in your specific HDF5 structure
            if "obs/agentview_rgb" in f[f"data/{demo_name}"]:
                dataset = f[f"data/{demo_name}/obs/agentview_rgb"]
                num_images = dataset.shape[0]
                print(f"Total frames: {num_images}")
                
                # Pick indices: every 15th frame + the last one
                indices = list(range(0, num_images, 15))
                if num_images - 1 not in indices:
                    indices.append(num_images - 1)
                
                # Display images horizontally using HTML/PIL (No Matplotlib)
                images_html = []
                for idx in indices:
                    img_array = dataset[idx]
                    
                    # Convert numpy array to PIL Image
                    # (Robosuite images are usually already correct, but sometimes flipped)
                    img = Image.fromarray(img_array)
                    
                    # Resize for smaller display if needed
                    img_small = img.resize((128, 128)) 
                    
                    # Hack to display inline in loop
                    print(f"Frame {idx}:")
                    #display(img)
            else:
                print(f"Skipping {demo_name}: 'obs/agentview_rgb' not found.")
                
except Exception as e:
    print(f"An error occurred: {e}")

Opening file: dataset/libero_goal/pick_up_the_black_bowl_from_table_center_and_place_it_on_the_plate_demo.hdf5
An error occurred: [Errno 2] Unable to synchronously open file (unable to open file: name = 'dataset/libero_goal/pick_up_the_black_bowl_from_table_center_and_place_it_on_the_plate_demo.hdf5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)


# Libraries

In [None]:
import warnings
warnings.filterwarnings("ignore")

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Check Einops (we'll use it a lot in the recursive model)
try:
    from einops import rearrange
    print("‚úÖ Einops available")
except ImportError:
    print("‚ùå Please install einops: pip install einops")

# Check timm for vision encoder
try:
    import timm
    print("‚úÖ timm available")
except ImportError:
    print("‚ùå Please install timm: pip install timm")

# Check transformers for text encoder
try:
    import transformers
    print("‚úÖ transformers available")
except ImportError:
    print("‚ùå Please install transformers: pip install transformers")

print("‚úÖ Environment ready for the TinyRecursive model.")

PyTorch Version: 2.9.1+cu128
‚úÖ GPU Disponibile: NVIDIA GeForce RTX 3090 Ti
Einops installato correttamente.
‚úÖ Ambiente pronto per il modello TinyRecursive.


In [None]:
# Base configuration for reproducibility (important for thesis)
torch.manual_seed(12345)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(12345)
random.seed(12345)

if torch.cuda.is_available():
    torch.cuda.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)  # If using multiple GPUs

print("‚úÖ Environment ready for the TinyRecursive model.")

# =========================

@dataclass
class TrainingConfig:
    """Structured container for training/model hyperparameters."""

    lr: float = 3e-4
    hidden_dim: int = 256
    num_recursions: int = 8
    epochs: int = 20
    batch_size: int = 64
    weight_decay: float = 1e-4
    grad_clip: Optional[float] = 1.0
    sched_T0: Optional[int] = None
    sched_T_mult: int = 1
    lr_min: float = 1e-6
    warmup_epochs: int = 3
    early_stop_patience: Optional[int] = None
    save_path: str = 'best_model.pt'
    freeze_backbone: bool = True
    augmentation: bool = False
    attention_heads: int = 8
    attention_dropout: float = 0.1
    text_encoder: str = 'bert-base-uncased'
    max_text_length: int = 256
    scheduler: str = 'warmup_cosine'
    dropout: float = 0.2

    # -- PREVIOUS ACTION SUPPORT --
    use_previous_actions: bool = False
    previous_action_steps: int = 4
    action_fusion_strategy: str = 'weighted_sum'  
    weight_learning_rate: float = 1e-3
    weight_decay_fusion: float = 1e-3
    freeze_weights_epochs: int = 5

    # Seed for reproducibility
    seed: int = 12345
    
    # Benchmark
    benchmark: str = 'libero_goal'
    task_filter: Optional[List[str]] = None
    dataset_path: str = 'dataset/libero_goal'
    demo_split_ratio: float = 0.8
    max_demos_per_task: Optional[int] = None
    
    # Loss function
    loss_type: str = 'mse'
    reconstruction_weight: float = 0.0
    
    # Model specification
    model_type: str = 'text_encoder_plus'
    
    # Validation
    val_frequency: int = 5
    video_frequency: int = 50
    
    # Optimization
    optimizer: str = 'AdamW'
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8
    
    # Performance
    device: Optional[str] = None
    mixed_precision: bool = False
    num_workers: int = 4
    pin_memory: bool = True
    cache_data: bool = False

@dataclass

Using device: cuda


In [56]:
import numpy as np
import h5py
import sys
from pathlib import Path
from PIL import Image
from IPython.display import display

# --- HELPER FUNCTIONS (Kept largely the same) ---

def load_images_robust(dataset):
    """
    Carica immagini da dataset HDF5 usando metodo robusto.
    """
    shape = dataset.shape
    
    # METODO 1: Lettura diretta uint8
    try:
        buffer = np.empty(shape, dtype=np.uint8)
        dataset.read_direct(buffer)
        return buffer
    except Exception:
        pass
    
    # METODO 2: Float32 -> Uint8
    try:
        buffer = np.empty(shape, dtype=np.float32)
        dataset.read_direct(buffer)
        if buffer.max() <= 1.0:
            buffer = (buffer * 255).astype(np.uint8)
        else:
            buffer = np.clip(buffer, 0, 255).astype(np.uint8)
        return buffer
    except Exception:
        pass
    
    # METODO 3: Float64 -> Uint8
    try:
        buffer = np.empty(shape, dtype=np.float64)
        dataset.read_direct(buffer)
        if buffer.max() <= 1.0:
            buffer = (buffer * 255).astype(np.uint8)
        else:
            buffer = np.clip(buffer, 0, 255).astype(np.uint8)
        return buffer
    except Exception:
        pass

    # METODO 4: Fallback bytes
    try:
        buffer = np.empty(shape, dtype=np.uint8)
        dataset.id.read(h5py.h5s.ALL, h5py.h5s.ALL, buffer)
        return buffer
    except Exception as e:
        raise RuntimeError(f"Impossibile leggere il dataset: {e}")

def load_actions_robust(dataset):
    """
    Carica azioni da dataset HDF5.
    """
    shape = dataset.shape
    try:
        buffer = np.empty(shape, dtype=np.float32)
        dataset.read_direct(buffer)
        return buffer
    except Exception:
        pass
    
    try:
        buffer = np.empty(shape, dtype=np.float64)
        dataset.read_direct(buffer)
        return buffer.astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Impossibile leggere le azioni: {e}")

# --- MODIFIED EXPLORATION FUNCTION (No Matplotlib) ---

def explore_libero_dataset(data_path: Path):
    
    # Trova file
    hdf5_files = list(data_path.glob('**/*.hdf5'))
    
    if not hdf5_files:
        print(f"‚ö†Ô∏è Nessun file HDF5 trovato in {data_path}")
        return []
    
    print(f"‚úÖ Trovati {len(hdf5_files)} file HDF5")
    
    # Analizza il primo file
    demo_file = hdf5_files[0]
    print(f"\nüìÑ Analizzando: {demo_file.name}")
    
    try:
        with h5py.File(demo_file, 'r') as f:
            if 'data' not in f:
                print("‚ö†Ô∏è Chiave 'data' non trovata")
                return hdf5_files
            
            data_group = f['data']
            demo_keys = list(data_group.keys())
            first_demo_key = demo_keys[0]
            demo_0 = data_group[first_demo_key]
            
            imgs = None
            
            # 1. Caricamento Immagini
            if 'obs' in demo_0:
                obs_group = demo_0['obs']
                
                # Strategia di ricerca chiave immagine
                image_keys = ['agentview_rgb', 'agentview_image', 'rgb', 'image', 'robot0_eye_in_hand_image']
                img_key = next((k for k in image_keys if k in obs_group), None)
                
                # Fallback ricerca generica
                if img_key is None:
                    img_key = next((k for k in obs_group.keys() if 'rgb' in k.lower() or 'image' in k.lower()), None)
                
                if img_key:
                    print(f"\nüñºÔ∏è Usando chiave immagini: '{img_key}'")
                    try:
                        imgs = load_images_robust(obs_group[img_key])
                        print(f"  ‚úÖ Immagini caricate: {imgs.shape}")
                    except Exception as e:
                        print(f"  ‚ùå Errore immagini: {e}")
            
            # 2. Caricamento Azioni
            if 'actions' in demo_0:
                try:
                    actions = load_actions_robust(demo_0['actions'])
                    print(f"\nüéÆ Azioni caricate: {actions.shape}")
                    print(f"  Range: [{actions.min():.3f}, {actions.max():.3f}]")
                except Exception as e:
                    print(f"  ‚ùå Errore azioni: {e}")

            # 3. VISUALIZZAZIONE (Senza Matplotlib)
            if imgs is not None and len(imgs) > 0:
                print("\nüé¨ Visualizzazione frame esempio (PIL/IPython):")
                
                num_frames = min(4, len(imgs))
                indices = np.linspace(0, len(imgs) - 1, num_frames, dtype=int)
                
                for idx in indices:
                    img_array = imgs[idx]
                    
                    # Se l'immagine √® float [0,1], converti a uint8
                    if img_array.dtype != np.uint8:
                         img_array = (np.clip(img_array, 0, 1) * 255).astype(np.uint8)
                    
                    # Crea immagine PIL
                    pil_img = Image.fromarray(img_array)
                    
                    # (Opzionale) Resize per non occupare troppo spazio
                    # pil_img = pil_img.resize((128, 128))
                    
                    print(f"--- Frame {idx} ---")
                    display(pil_img)
            else:
                print("\n‚ö†Ô∏è Nessuna immagine valida da visualizzare")

    except Exception as e:
        print(f"Errore critico durante l'apertura del file: {e}")
    
    return hdf5_files

# Esegui
hdf5_files = explore_libero_dataset(Path('dataset/libero_spatia'))

‚ö†Ô∏è Nessun file HDF5 trovato in dataset/libero_spatia


In [None]:
# --- HELPER FUNCTIONS (Kept largely the same) ---

def _load_images_robust(obs_group):
    """
    Load images from HDF5 dataset using robust method.
    """
    
    # Try multiple possible image keys
    possible_image_keys = [
        'agentview_rgb', 'rgb', 'agentview_image', 
        'camera0_rgb', 'camera_0_rgb', 'frontview_rgb',
        'image', 'obs_rgb', 'robot0_eye_in_hand_image'
    ]
    
    for img_key in possible_image_keys:
        if img_key in obs_group:
            try:
                return np.array(obs_group[img_key])
            except:
                continue
    
    # If no image key is found, raise error
    available_keys = list(obs_group.keys())
    raise ValueError(f"No valid image key found. Available keys: {available_keys}")

def _load_actions_robust(actions_group):
    """
    Load actions from HDF5 dataset.
    """
    return np.array(actions_group).astype(np.float32)

# --- IMPROVED DEMO LOADER ---

def load_demonstrations_from_hdf5(hdf5_path, max_demos=None):
    """
    Loads demonstrations from a single HDF5 file.
    Returns: list of dicts {'obs': images_array, 'actions': actions_array}
    """
    
    demonstrations = []
    
    try:
        with h5py.File(hdf5_path, 'r') as f:
            print(f"üîì Opening {Path(hdf5_path).name}")
            
            # Check 'data' key exists
            if 'data' not in f:
                print("‚ö†Ô∏è 'data' key not found")
                return demonstrations
            
            data_group = f['data']
            demo_keys = [k for k in data_group.keys() if 'demo' in k.lower()]
            
            if max_demos:
                demo_keys = demo_keys[:max_demos]
            
            # 1. Image Loading
            for demo_key in demo_keys:
                try:
                    
                    # Image key search strategy
                    demo_group = data_group[demo_key]
                    obs_group = demo_group['obs']
                    
                    imgs = None
                    for img_key in ['agentview_rgb', 'rgb', 'frontview_rgb']:
                        if img_key in obs_group:
                            print(f"\nüñºÔ∏è Using image key: '{img_key}'")
                            try:
                                imgs = _load_images_robust(obs_group)
                                print(f"  ‚úÖ Images loaded: {imgs.shape}")
                                break
                            except Exception as e:
                                print(f"  ‚ùå Image error: {e}")
                    
                    # 2. Action Loading
                    actions = None
                    try:
                        actions = _load_actions_robust(demo_group['actions'])
                        print(f"\nüéÆ Actions loaded: {actions.shape}")
                    except Exception as e:
                        print(f"  ‚ùå Error loading actions: {e}")
                        continue
                    
                    if imgs is not None and actions is not None:
                        demonstrations.append({
                            'obs': imgs,
                            'actions': actions
                        })
                        print(f"‚úÖ Demo {demo_key} loaded successfully")
                    else:
                        print(f"‚ùå Demo {demo_key} incomplete")
                    
                except Exception as e:
                    print(f"‚ùå Error in demo {demo_key}: {e}")
                    continue
            
            print(f"üéØ Loaded {len(demonstrations)} demonstrations from {Path(hdf5_path).name}")
            
    except Exception as e:
        print(f"üí• Critical error opening file: {e}")
    
    return demonstrations

# Load and test with one HDF5 file

if True:  # Set to False to skip this test
    test_file = "dataset/libero_goal/put_the_bowl_on_the_plate_demo.hdf5"
    
    if Path(test_file).exists():
        print(f"\n{'='*60}")
        print("üß™ TESTING DEMO LOADING")
        print(f"{'='*60}")
        
        demos = load_demonstrations_from_hdf5(test_file, max_demos=2)
        
        if demos:
            print(f"\nüìä SUMMARY:")
            print(f"   Loaded {len(demos)} demonstrations")
            for i, demo in enumerate(demos):
                print(f"   Demo {i}: {demo['obs'].shape} obs, {demo['actions'].shape} actions")
        else:
            print("‚ùå No demonstrations loaded!")
    else:
        print(f"‚ö†Ô∏è Test file not found: {test_file}")


# =========================

class VisionEncoder(nn.Module):
    """
    Load visual observations and actions from HDF5 files.
    
    This dataset handles training/validation split in a more sophisticated way:
    instead of splitting entire files, it splits the demos WITHIN each file
    (e.g., 80% train, 20% val per file).
    """
    
    def __init__(
        self,
        pretrained_name: str = 'resnet18',
        output_dim: int = 256,
        freeze: bool = True
    ):
        super().__init__()
        
        # Load pretrained vision model
        if 'resnet' in pretrained_name:
            self.backbone = getattr(torchvision.models, pretrained_name)(pretrained=True)
            # Remove final classification layer
            backbone_dim = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            raise ValueError(f"Unsupported pretrained_name: {pretrained_name}")
        
        # Freeze if requested
        if freeze:
            for param in self.backbone.parameters():
                param.requires_grad = False
        
        # Adapter layer
        self.adapter = nn.Sequential(
            nn.Linear(backbone_dim, output_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Normalization parameters for ImageNet
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.float()

        if x.shape[-1] != 224 or x.shape[-2] != 224:
            x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        x = (x - self.mean) / self.std
        features = self.backbone(x).flatten(start_dim=1) # [B, 512]
        output = self.adapter(features) # try passing features directly instead of the adapter
        return output

In [None]:
class PretrainedVisualEncoder(nn.Module):
    """Visual encoder based on ResNet18 with adaptive head."""

    def __init__(self, hidden_dim: int = 256, freeze_backbone: bool = True, dropout: float = 0.1, double_visual_features: bool = False):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.freeze_backbone = freeze_backbone
        self.double_visual_features = double_visual_features

        # Load pretrained ResNet18 from torchvision
        resnet18 = torchvision.models.resnet18(weights='IMAGENET1K_V1')
        
        # Remove final classification layer
        self.backbone = nn.Sequential(*list(resnet18.children())[:-1])  # Remove last fc layer
        
        # Freeze backbone if requested
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            print("üîí Visual backbone frozen")
        
        # Adaptive head
        resnet_output_dim = 512  # ResNet18 fc input features
        
        if double_visual_features:
            self.adaptive_head = nn.Sequential(
                nn.Linear(resnet_output_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim * 2, hidden_dim)
            )
        else:
            self.adaptive_head = nn.Sequential(
                nn.Linear(resnet_output_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # ImageNet normalization
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self._init_weights()

    def _init_weights(self):
        """Initialize adaptive head weights."""
        for m in self.adaptive_head:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Input tensor [B, C, H, W] (assumes 0-1 normalized)
            
        Returns:
            Encoded visual features [B, hidden_dim]
        """
        
        if x.shape[-1] != 224 or x.shape[-2] != 224:
            x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        # ImageNet normalization
        x = (x - self.mean) / self.std

        # Pass through backbone
        with torch.no_grad() if self.freeze_backbone else torch.enable_grad():
            features = self.backbone(x)  # [B, 512, 1, 1] for ResNet18
        
        features = features.flatten(start_dim=1)  # [B, 512]
        
        # Apply adaptive head
        visual_features = self.adaptive_head(features)  # [B, hidden_dim]
        
        return visual_features

# =========================

class PromptEncoder(nn.Module):
    """Encodes natural-language task prompts via CLIP ViT-L/14 text tower."""

    def __init__(self, hidden_dim: int = 256, pretrained_name: str = "bert-base-uncased", max_length: int = 256, dropout: float = 0.1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_length = max_length
        self.pretrained_name = pretrained_name

        # BERT tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
        self.bert = AutoModel.from_pretrained(pretrained_name)
        
        # Freeze BERT weights (optional)
        for param in self.bert.parameters():
            param.requires_grad = False
        print(f"üîí BERT model '{pretrained_name}' loaded and frozen")

        # Map BERT output to desired hidden dimension
        bert_hidden_dim = self.bert.config.hidden_size  # 768 for base BERT
        self.projection = nn.Sequential(
            nn.Linear(bert_hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights."""
        for m in self.projection:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, prompts: List[str]) -> torch.Tensor:
        """
        Encode a batch of text prompts.
        
        Args:
            prompts: List of text strings
            
        Returns:
            Text embeddings [B, hidden_dim]
        """
        device = next(self.parameters()).device
        
        # Tokenize prompts
        encoded = self.tokenizer(
            prompts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).to(device)
        
        # Pass through BERT
        with torch.no_grad():  # BERT is frozen
            outputs = self.bert(**encoded)
            # Use [CLS] token representation
            text_features = outputs.last_hidden_state[:, 0]  # [B, 768]
        
        # Project to desired dimension
        text_features = self.projection(text_features)  # [B, hidden_dim]
        
        return text_features

# =========================

class TinyRecursiveModel(nn.Module):
    """
    TinyRecursiveModel with integrated text prompts, visual observations, and action prediction.
    
    Key modifications:
    - Added prompt encoding via BERT
    - Supports both visual-only and vision+language modes
    - Action prediction head with proper normalization
    - Recursive transformer with cross-attention
    """

    def __init__(self, config: TrainingConfig):
        super().__init__()
        
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.num_recursions = config.num_recursions
        self.use_text_prompts = True  # Enable text prompts by default
        
        print(f"üèóÔ∏è Building TinyRecursiveModel:")
        print(f"   Hidden dim: {self.hidden_dim}")
        print(f"   Recursions: {self.num_recursions}")
        print(f"   Text prompts: {self.use_text_prompts}")

        # Vision encoder
        self.visual_encoder = PretrainedVisualEncoder(
            hidden_dim=self.hidden_dim,
            freeze_backbone=config.freeze_backbone,
            dropout=config.dropout
        )
        
        # Text encoder (conditional)
        if self.use_text_prompts:
            self.prompt_encoder = PromptEncoder(
                hidden_dim=self.hidden_dim,
                pretrained_name=config.text_encoder,
                max_length=config.max_text_length,
                dropout=config.dropout
            )

        # Recursive Transformer
        self.recursive_transformer = RecursiveTransformerWithAttention(
            hidden_dim=self.hidden_dim,
            num_heads=config.attention_heads,
            num_recursions=self.num_recursions,
            dropout=config.attention_dropout
        )

        # Action prediction head
        self.action_head = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_dim, 7),  # 7D action space for robotic manipulation
        )

        self._init_weights()
        
        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"üìä Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    def _init_weights(self):
        """Initialize action head weights."""
        for m in self.action_head:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, images: torch.Tensor, prompts: Optional[List[str]] = None) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            images: Visual observations [B, C, H, W]
            prompts: Optional text prompts (list of strings)
            
        Returns:
            Predicted actions [B, action_dim]
        """
        batch_size = images.shape[0]
        
        # Encode visual observations
        visual_features = self.visual_encoder(images)  # [B, hidden_dim]
        
        # Encode text prompts (if provided)
        if self.use_text_prompts and prompts is not None:
            text_features = self.prompt_encoder(prompts)  # [B, hidden_dim]
            
            # Combine visual and text features
            # Simple concatenation + projection approach
            combined_features = visual_features + text_features  # Element-wise addition
        else:
            combined_features = visual_features

        # Apply recursive transformer
        refined_features = self.recursive_transformer(combined_features)  # [B, hidden_dim]

        # Predict actions
        actions = self.action_head(refined_features)  # [B, 7]

        return actions

ok


In [None]:
# =========================

class LiberoDataset(Dataset):
    """
    Dataset for LIBERO demonstrations
    
    Supports both training/validation split with stratified sampling (some demonstrations
    from each file go to training, others to validation), as well as split by ratio  
    within each file (e.g., 80% train, 20% val per file).
    """
    
    def __init__(
        self, 
        data_path: str, 
        split: str = 'train',
        demo_split_ratio: float = 0.8,
        max_demos_per_task: Optional[int] = None,
        action_stats: Optional[Dict[str, np.ndarray]] = None,
        normalize_actions: bool = True,
        augment_images: bool = False
    ):
        """
        Args:
            data_path: path to folder with HDF5 files
            split: 'train' or 'val' (ignored if demo_split_ratio is used)
            max_demos_per_task: maximum limit of demos per task (for debugging)
            demo_split_ratio: percentage of demos for training (default 0.8 = 80%)
            action_stats: pre-computed statistics (mean, std) for action normalization
            normalize_actions: whether to normalize actions using statistics
            augment_images: apply augmentations to images during training
        """
        
        self.data_path = Path(data_path)
        self.split = split
        self.demo_split_ratio = demo_split_ratio
        self.max_demos_per_task = max_demos_per_task
        self.normalize_actions = normalize_actions
        self.augment_images = augment_images and split == 'train'
        
        self.action_stats = action_stats.copy() if action_stats else {}
        
        # Load HDF5 files
        hdf5_files = list(self.data_path.glob("*.hdf5"))
        if not hdf5_files:
            raise FileNotFoundError(f"No HDF5 files found in {data_path}")
        
        print(f"üìÅ Found {len(hdf5_files)} HDF5 files")
        
        self.data = []
        all_actions = [] 
        
        for hdf5_file in hdf5_files:
            try:
                with h5py.File(hdf5_file, 'r') as f:
                    prompt = self._prompt_from_filename(hdf5_file)
                    
                    demo_keys = [k for k in f['data'].keys() if k.startswith('demo_')]
                    demo_keys.sort()
                    
                    if self.max_demos_per_task:
                        demo_keys = demo_keys[:self.max_demos_per_task]
                    
                    # Split demos for train/val
                    num_train_demos = int(len(demo_keys) * self.demo_split_ratio)
                    
                    if self.split == 'train':
                        selected_demos = demo_keys[:num_train_demos]
                    elif self.split == 'val':
                        selected_demos = demo_keys[num_train_demos:]
                    else:
                        selected_demos = demo_keys
                    
                    if not selected_demos:
                        print(f"‚ö†Ô∏è No demos for {self.split} split in {hdf5_file.name}")
                        continue
                    
                    task_actions = []
                    for demo_name in selected_demos:
                        try:
                            demo_path = f'data/{demo_name}'
                            obs_path = f'{demo_path}/obs'
                            
                            if 'agentview_rgb' not in f[obs_path]:
                                print(f"‚ö†Ô∏è No 'agentview_rgb' in {demo_name}")
                                continue
                                
                            images = f[f'{obs_path}/agentview_rgb'][:]
                            actions = f[f'{demo_path}/actions'][:]
                            
                            if len(images) != len(actions):
                                print(f"‚ö†Ô∏è Length mismatch: images={len(images)}, actions={len(actions)} for {demo_name}")
                                continue
                            
                            for i in range(len(images)):
                                self.data.append({
                                    'image': images[i],
                                    'action': actions[i].astype(np.float32),
                                    'prompt': prompt,
                                    'demo_id': demo_name,
                                    'timestep': i
                                })
                                
                            task_actions.append(actions)
                            
                        except Exception as e:
                            print(f"‚ö†Ô∏è Error loading {demo_name}: {e}")
                    
                    all_actions.extend(task_actions)
                    print(f"‚úÖ Loaded {len(selected_demos)} demos from {hdf5_file.name} for {self.split}")
                    
            except Exception as e:
                print(f"‚ùå Critical error opening file: {e}")
                continue
        
        split_name = f"{self.split} ({len(self.data)} samples)"
        
        if len(self.data) == 0:
            raise ValueError(f"No valid demonstrations loaded for {split_name}! Check your data files.")
        
        # Calculate action statistics for normalization (only for training set or if not provided)
        if self.normalize_actions and len(all_actions) > 0 and action_stats is None:
            all_actions_concat = np.concatenate(all_actions, axis=0)
        
            mean = all_actions_concat.mean(axis=0).astype(np.float32)
            std  = all_actions_concat.std(axis=0).astype(np.float32)
        
            # ‚ö†Ô∏è Safety floor: avoid too small std that explodes normalization
            std_clipped = np.clip(std, 0.1, None)
        
            # Detailed logging
            print(f"üìä Action statistics computed from {split_name} set:")
            print(f"   Mean:        {np.round(mean, 3)}")
            print(f"   Std (raw):   {np.round(std, 3)}")
            print(f"   Std (clipped to >=0.1): {np.round(std_clipped, 3)}")
        
            self.action_stats['mean'] = mean
            self.action_stats['std']  = std_clipped
        
        elif action_stats is not None:
            print(f"üìä Using provided action statistics")
            self.action_stats = {
                'mean': action_stats['mean'].astype(np.float32),
                'std':  np.clip(action_stats['std'], 0.1, None).astype(np.float32)
            }

        # Build transition index for O(1) access
        self.samples = self._build_sample_index()
        print(f"üì¶ Generated {len(self.samples)} transitions for {split_name}")

    @staticmethod
    def _prompt_from_filename(hdf5_file: Path) -> str:
        """Converts the HDF5 filename to a natural language prompt."""
        name = hdf5_file.stem
        if name.endswith('_demo'):
            name = name[:-5]
        name = name.replace('_', ' ').replace('-', ' ')
        return name.title()
        
    def _build_sample_index(self) -> List[int]:
        """Build a list of valid sample indices."""
        return list(range(len(self.data)))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample_idx = self.samples[idx]
        item = self.data[sample_idx]
        
        image = item['image'].copy()
        action = item['action'].copy()
        prompt = item['prompt']
        
        # Convert image from HWC to CHW and normalize to [0, 1]
        if len(image.shape) == 3 and image.shape[-1] == 3:
            image = image.transpose(2, 0, 1)
        image = image.astype(np.float32) / 255.0
        
        # Apply augmentation if enabled
        if self.augment_images:
            image = self._apply_augmentation(image)
        
        # Normalize actions if statistics are available
        if self.normalize_actions and 'mean' in self.action_stats:
            action = (action - self.action_stats['mean']) / self.action_stats['std']
        
        return {
            'image': torch.from_numpy(image),
            'action': torch.from_numpy(action),
            'prompt': prompt,
            'demo_id': item['demo_id'],
            'timestep': item['timestep']
        }
    
    def _apply_augmentation(self, image):
        """Apply random augmentations to image tensor (CHW format)."""
        # Simple random horizontal flip
        if np.random.random() > 0.5:
            image = image[:, :, ::-1].copy()  # Flip along width dimension
        return image

In [None]:
def train_tiny_recursive_model(
    base_config: TrainingConfig,
    train_dataset: LiberoDataset,
    val_dataset: LiberoDataset,
) -> TinyRecursiveModel:
    """Perform final training starting from the chosen configuration."""

    print(f"\nüöÇ FINAL TRAINING - TinyRecursiveModel")
    print(f"{'='*60}")
    print(f"   Hidden Dim: {base_config.hidden_dim}")
    print(f"   Recursions: {base_config.num_recursions}")
    print(f"   Learning Rate: {base_config.lr}")
    print(f"   Batch Size: {base_config.batch_size}")
    print(f"   Epochs: {base_config.epochs}")
    print(f"   Text Encoder: {base_config.text_encoder}")
    print(f"   Model: {base_config.model_type}")
    
    print(f"\nüìä Dataset Statistics:")
    print(f"   Training samples: {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    print(f"   Train Action Stats: mean={np.round(train_dataset.action_stats['mean'], 3)}")
    print(f"                       std={np.round(train_dataset.action_stats['std'], 3)}")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"   Device: {device}")

    # Create model
    model = TinyRecursiveModel(base_config).to(device)
    
    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=base_config.batch_size, 
        shuffle=True,
        num_workers=base_config.num_workers,
        pin_memory=base_config.pin_memory
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=base_config.batch_size, 
        shuffle=False,
        num_workers=base_config.num_workers,
        pin_memory=base_config.pin_memory
    )

    # Training setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_config.lr, weight_decay=base_config.weight_decay)
    
    # Scheduler
    if base_config.scheduler == 'warmup_cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=base_config.sched_T0 if base_config.sched_T0 else base_config.epochs,
            T_mult=base_config.sched_T_mult,
            eta_min=base_config.lr_min
        )
    elif base_config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    else:
        scheduler = None

    criterion = nn.MSELoss()
    
    print(f"\nüéØ Starting Training Loop...")
    
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    
    for epoch in range(base_config.epochs):
        
        # === TRAINING PHASE ===
        model.train()
        epoch_train_loss = 0.0
        num_train_batches = 0
        
        for batch in train_loader:
            images = batch['image'].to(device)
            actions = batch['action'].to(device)
            prompts = batch['prompt']
            
            optimizer.zero_grad()
            
            predicted_actions = model(images, prompts)
            loss = criterion(predicted_actions, actions)
            
            loss.backward()
            
            if base_config.grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), base_config.grad_clip)
            
            optimizer.step()
            
            epoch_train_loss += loss.item()
            num_train_batches += 1
        
        avg_train_loss = epoch_train_loss / num_train_batches if num_train_batches > 0 else 0
        train_losses.append(avg_train_loss)
        
        # === VALIDATION PHASE ===
        model.eval()
        epoch_val_loss = 0.0
        num_val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                actions = batch['action'].to(device)
                prompts = batch['prompt']
                
                predicted_actions = model(images, prompts)
                loss = criterion(predicted_actions, actions)
                
                epoch_val_loss += loss.item()
                num_val_batches += 1
        
        avg_val_loss = epoch_val_loss / num_val_batches if num_val_batches > 0 else float('inf')
        val_losses.append(avg_val_loss)
        
        # Learning rate scheduling
        if scheduler:
            scheduler.step()
        
        current_lr = optimizer.param_groups[0]['lr']
        
        # Logging
        print(f"Epoch {epoch+1:2d}/{base_config.epochs}: "
              f"Train Loss = {avg_train_loss:.6f}, "
              f"Val Loss = {avg_val_loss:.6f}, "
              f"LR = {current_lr:.2e}")
        
        # Early stopping and best model saving
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            
            # Save best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': base_config,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'epoch': epoch,
                'best_val_loss': best_val_loss,
                'action_stats': train_dataset.action_stats
            }, base_config.save_path)
            print(f"         ‚ú® New best model saved! Val Loss: {best_val_loss:.6f}")
        else:
            patience_counter += 1
        
        # Early stopping check
        if base_config.early_stop_patience and patience_counter >= base_config.early_stop_patience:
            print(f"üõë Early stopping triggered after {patience_counter} epochs without improvement")
            break
    
    print(f"\n‚úÖ Training completed!")
    print(f"   Best validation loss: {best_val_loss:.6f}")
    print(f"   Model saved to: {base_config.save_path}")
    
    # Load best model for return
    best_checkpoint = torch.load(base_config.save_path, map_location=device, weights_only=False)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    
    return model

In [None]:
if __name__ == "__main__":
    
    # Quick test to create datasets and verify they work
    print("üß™ TESTING DATASET CREATION")
    print("="*60)
    
    # Parameters
    data_path = 'dataset/libero_goal'  # Use libero_goal for testing
    
    if not Path(data_path).exists():
        print(f"‚ùå Data path not found: {data_path}")
        print("   Please ensure you have the LIBERO dataset downloaded")
    else:
        print(f"‚úÖ Data path found: {data_path}")
        
        # Find HDF5 files
        hdf5_files = list(Path(data_path).glob("*.hdf5"))
        print(f"üìÅ Found {len(hdf5_files)} HDF5 files")
        
        if len(hdf5_files) == 0:
            print("‚ùå No HDF5 files found!")
        else:
            # Demo-level split: use ALL files for both datasets
            # but split the demos WITHIN each file (80% train, 20% val per task)
            demo_split_ratio = 0.8
            print(f"\nüìä Demo-level split: {demo_split_ratio:.0%} train / {1-demo_split_ratio:.0%} val per task")
            print(f"   All {len(hdf5_files)} tasks present in both train and val")
            
            # Create datasets with demo-level split
            print("\nCreating TRAIN dataset...")
            train_dataset = LiberoDataset(
                data_path,  # Use ALL files
                split='train',
                demo_split_ratio=demo_split_ratio,
                max_demos_per_task=50,  # Limit for speed
                normalize_actions=True,
                augment_images=False
            )
            
            # Use the same statistics from training set for validation
            train_action_stats = train_dataset.action_stats
            
            print("\nCreating VAL dataset...")
            val_dataset = LiberoDataset(
                data_path,  # Use ALL files
                split='val',
                demo_split_ratio=demo_split_ratio,
                action_stats=train_action_stats,  # Use training stats
                normalize_actions=True,
                augment_images=False
            )
            
            print(f"\n‚úÖ Dataset creation completed!")
            print(f"   Training samples: {len(train_dataset)}")
            print(f"   Validation samples: {len(val_dataset)}")
            print(f"   Action dimension: {len(train_action_stats['mean'])}")
            
            # Test a few samples
            print(f"\nüß™ Testing sample loading...")
            
            # Test training dataset
            for i in range(min(3, len(train_dataset))):
                sample = train_dataset[i]
                print(f"   Train sample {i}: image={sample['image'].shape}, action={sample['action'].shape}")
                print(f"                     prompt='{sample['prompt'][:50]}...'")
            
            # Test validation dataset
            for i in range(min(2, len(val_dataset))):
                sample = val_dataset[i]
                print(f"   Val sample {i}: image={sample['image'].shape}, action={sample['action'].shape}")
                
            print(f"\n‚úÖ Sample loading test passed!")
            
            # Quick training test with minimal configuration
            print(f"\nüöÇ QUICK TRAINING TEST")
            print("="*40)
            
            config = TrainingConfig(
                lr=1e-3,
                hidden_dim=128,  # Reduced for speed
                num_recursions=2,  # Reduced for speed
                epochs=3,  # Very few epochs for testing
                batch_size=8,   # Small batch for testing
                text_encoder='bert-base-uncased',
                save_path='test_model.pt',
                early_stop_patience=None  # Disable early stopping for test
            )
            
            try:
                model = train_tiny_recursive_model(config, train_dataset, val_dataset)
                print(f"‚úÖ Quick training test completed successfully!")
                
                # Test inference
                print(f"\nüîÆ Testing inference...")
                model.eval()
                with torch.no_grad():
                    sample = train_dataset[0]
                    test_image = sample['image'].unsqueeze(0)  # Add batch dimension
                    test_prompt = [sample['prompt']]
                    
                    device = next(model.parameters()).device
                    test_image = test_image.to(device)
                    
                    predicted_action = model(test_image, test_prompt)
                    print(f"   Input image shape: {test_image.shape}")
                    print(f"   Input prompt: '{test_prompt[0]}'")
                    print(f"   Predicted action: {predicted_action[0].cpu().numpy()}")
                    print(f"   Ground truth action: {sample['action'].numpy()}")
                    
                print(f"‚úÖ Inference test completed!")
                
                # Clean up test model
                if Path(config.save_path).exists():
                    Path(config.save_path).unlink()
                    print(f"üóëÔ∏è Cleaned up test model file")
                    
            except Exception as e:
                print(f"‚ùå Training test failed: {e}")
                import traceback
                traceback.print_exc()
        
    print(f"\nüéØ Dataset testing completed!")

In [62]:
main_pipeline(quick_search=False, train_final=False, evaluate=False, data_path='dataset/libero_spatial')


    ü§ñ TinyRecursiveModels per Controllo Robotico
    
    Obiettivi:
    1. Adattare architettura TRM per robotica
    2. Training con Behavior Cloning su LIBERO
    3. Valutazione in simulazione con metriche quantitative e qualitative
    
    
‚úì Using device: cuda


STEP 1: Caricamento Dataset
‚úì Trovati 10 file HDF5 (task)

üìä Demo-level split: 80% train / 20% val per ogni task
   Tutti i 10 task presenti in entrambi train e val

Creating TRAIN dataset...
Loading 10 HDF5 files for TRAIN (demo split: 80%)...


‚úÖ Loaded 400 demonstrations for TRAIN
üìä Action statistics computed from TRAIN set:
   Mean:        [ 0.15   0.134 -0.155 -0.005 -0.011 -0.02   0.093]
   Std (raw):   [0.414 0.347 0.508 0.038 0.072 0.058 0.996]
   Std (clipped to >=0.1): [0.414 0.347 0.508 0.1   0.1   0.1   0.996]
üì¶ Generated 49735 transitions for TRAIN

Creating VAL dataset...
Loading 10 HDF5 files for VAL (demo split: 80%)...
‚úÖ Loaded 100 demonstrations for VAL
üìä Using provided action statistics
üì¶ Generated 12515 transitions for VAL

‚úì Dataset creati con demo-level split
  Train samples: 49735
  Val samples: 12515

‚úÖ Pipeline completata!


# Evaluation

In [63]:
import sys
import numpy as np
from unittest.mock import MagicMock
from dataclasses import fields, dataclass
from typing import Any, Dict, List, Callable, Optional, Tuple
import torch
import json
import os
from pathlib import Path

In [None]:
# %%
import json
import numpy as np
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path

class TextExplainer:
    """
    Class to calculate gradient-based explainability for text prompts.
    
    Calculates which prompt tokens are most important for action prediction,
    using backpropagation through the text encoder.
    """
    
    def __init__(self, model: nn.Module, prompt_encoder: PromptEncoder, device: torch.device):
        """
        Args:
            model: TRMPolicy model
            prompt_encoder: PromptEncoder for tokenization and encoding
            device: device for computation
        """
        self.model = model
        self.prompt_encoder = prompt_encoder
        self.device = device
        
    def compute_token_saliency(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None
    ) -> Dict[str, Any]:
        """
        Calculate the saliency of each prompt token with respect to action prediction.
        
        Args:
            obs: Visual observation [1, C, H, W]
            prompt: Text prompt string
            target_action: Optional target action for supervised saliency [1, action_dim]
        
        Returns:
            Dict containing token strings, saliency scores, and metadata
        """
        
        # Tokenize prompt
        tokenizer = self.prompt_encoder.tokenizer
        encoded = tokenizer(
            prompt,
            max_length=self.prompt_encoder.max_length,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        
        input_ids = encoded['input_ids']  # [1, seq_len]
        attention_mask = encoded['attention_mask']  # [1, seq_len]
        
        # Extract token strings for interpretation
        token_strings = tokenizer.convert_ids_to_tokens(input_ids[0])
        
        # Get text embeddings with gradient computation enabled
        with torch.no_grad():
            # Get BERT embeddings
            bert_outputs = self.prompt_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            bert_embeddings = bert_outputs.last_hidden_state  # [1, seq_len, 768]
            
        # Extract [CLS] token and project
        cls_embedding = bert_embeddings[:, 0]  # [1, 768]
        cls_embedding.requires_grad_(True)
        
        # Project to model hidden dim
        text_features = self.prompt_encoder.projection(cls_embedding)  # [1, hidden_dim]
        
        # Get model prediction
        with torch.no_grad():
            visual_features = self.model.visual_encoder(obs)  # [1, hidden_dim]
        
        # Combine features and predict action
        combined_features = visual_features + text_features
        refined_features = self.model.recursive_transformer(combined_features)
        predicted_actions = self.model.action_head(refined_features)  # [1, action_dim]
        
        # Compute loss for gradient computation
        if target_action is not None:
            # Supervised: use provided target
            loss = nn.functional.mse_loss(predicted_actions, target_action)
        else:
            # Unsupervised: use norm of predicted actions as proxy
            loss = torch.norm(predicted_actions, dim=1).mean()
        
        # Backward pass to compute gradients
        loss.backward()
        
        # Get gradients w.r.t. CLS embedding
        cls_gradients = cls_embedding.grad  # [1, 768]
        
        # Compute importance scores
        # Use gradient x input as attribution measure
        attribution = torch.abs(cls_gradients * cls_embedding.detach())  # [1, 768]
        token_importance = attribution.sum(dim=-1).squeeze().cpu().numpy()  # [seq_len,] proxy
        
        # Since we only have CLS gradients, we'll distribute importance across all tokens
        # This is a simplified approach - more sophisticated methods exist
        num_tokens = attention_mask.sum().item()
        if num_tokens > 0:
            # Distribute CLS importance across all active tokens
            distributed_importance = np.ones(len(token_strings)) * (token_importance / num_tokens)
            # Mask out padding tokens
            mask = attention_mask[0].cpu().numpy()
            distributed_importance = distributed_importance * mask
        else:
            distributed_importance = np.zeros(len(token_strings))
        
        # Normalize scores
        if distributed_importance.max() > 0:
            normalized_scores = distributed_importance / distributed_importance.max()
        else:
            normalized_scores = distributed_importance
            
        return {
            'token_strings': token_strings,
            'raw_scores': distributed_importance,
            'normalized_scores': normalized_scores,
            'attention_mask': attention_mask[0].cpu().numpy(),
            'predicted_actions': predicted_actions.detach().cpu().numpy(),
            'loss': loss.item()
        }
    
    def get_top_k_tokens(
        self,
        saliency_result: Dict[str, Any],
        k: int = 10,
        filter_special: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Extract top-k most important tokens from saliency result.
        
        Args:
            saliency_result: Output from compute_token_saliency
            k: Number of top tokens to return
            filter_special: Whether to filter out special tokens ([CLS], [SEP], [PAD])
        
        Returns:
            List of (token, score) tuples sorted by importance
        """
        tokens = saliency_result['token_strings']
        scores = saliency_result['normalized_scores']
        mask = saliency_result['attention_mask']
        
        # Filter out padding tokens and special tokens if requested
        filtered_tokens_scores = []
        special_tokens = ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']
        
        for i, (token, score) in enumerate(zip(tokens, scores)):
            # Skip padding tokens
            if mask[i] == 0:
                continue
                
            # Skip special tokens if requested
            if filter_special and token in special_tokens:
                continue
            
            filtered_tokens_scores.append((token, float(score)))
        
        # Sort by score (descending) and return top-k
        filtered_tokens_scores.sort(key=lambda x: x[1], reverse=True)
        return filtered_tokens_scores[:k]


def visualize_explainability_results(
    json_path: str,
    top_n_frames: int = 5,
    top_n_tokens: int = 10
):
    """
    Visualize explainability results from JSON file.
    
    Args:
        json_path: Path to explainability JSON file
        top_n_frames: Number of frames to display
        top_n_tokens: Number of tokens to display per frame
    """
    
    if not Path(json_path).exists():
        print(f"‚ùå File not found: {json_path}")
        return
    
    with open(json_path, 'r') as f:
        data = json.load(f)
    
    metadata = data['metadata']
    frames = data['frames']
    
    print(f"üìä Explainability Results for Task {metadata['task_id']}")
    print(f"   Benchmark: {metadata['benchmark']}")
    print(f"   Task prompt: '{metadata['task_prompt']}'")
    print(f"   Success rate: {metadata['success_rate']:.2%}")
    print(f"   Frames analyzed: {metadata['num_frames_analyzed']}")
    print("=" * 80)
    
    # Show top N frames
    for i, frame in enumerate(frames[:top_n_frames]):
        print(f"\nüìΩÔ∏è Frame {i+1} (Step {frame['step']}):")
        print(f"   Top {top_n_tokens} important words:")
        
        for j, token_data in enumerate(frame['top_tokens'][:top_n_tokens], 1):
            token = token_data['token']
            score = token_data['score']
            print(f"     {j:2d}. '{token}': {score:.4f}")
        
        if i < len(frames) - 1:
            print("-" * 40)


def analyze_explainability_across_tasks(
    video_dir: str,
    task_ids: List[int]
) -> Tuple[Dict[str, Any], Dict[int, Dict[str, Any]]]:
    """
    Analyze explainability patterns across multiple tasks.
    
    Args:
        video_dir: Directory containing explainability JSON files
        task_ids: List of task IDs to analyze
    
    Returns:
        Tuple of (aggregated_token_stats, task_summaries)
    """
    
    all_tokens = {}
    task_summaries = {}
    
    for task_id in task_ids:
        json_path = Path(video_dir) / f"task_{task_id:02d}_explainability.json"
        
        if not json_path.exists():
            print(f"‚ö†Ô∏è Skipping task {task_id}: file not found")
            continue
        
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            # Extract task summary
            metadata = data['metadata']
            task_summaries[task_id] = {
                'prompt': metadata['task_prompt'],
                'success_rate': metadata['success_rate'],
                'frames_analyzed': metadata['num_frames_analyzed']
            }
            
            # Aggregate token statistics
            for frame in data['frames']:
                for token_data in frame['top_tokens'][:5]:  # Top 5 per frame
                    token = token_data['token']
                    score = token_data['score']
                    
                    if token not in all_tokens:
                        all_tokens[token] = {
                            'scores': [],
                            'task_appearances': set(),
                            'total_appearances': 0
                        }
                    
                    all_tokens[token]['scores'].append(score)
                    all_tokens[token]['task_appearances'].add(task_id)
                    all_tokens[token]['total_appearances'] += 1
            
            print(f"‚úÖ Processed task {task_id}: {len(data['frames'])} frames")
            
        except Exception as e:
            print(f"‚ùå Error processing task {task_id}: {e}")
    
    # Calculate aggregate statistics
    token_stats = {}
    for token, data in all_tokens.items():
        scores = data['scores']
        token_stats[token] = {
            'mean_score': np.mean(scores),
            'std_score': np.std(scores),
            'max_score': np.max(scores),
            'min_score': np.min(scores),
            'total_appearances': data['total_appearances'],
            'task_count': len(data['task_appearances']),
            'task_ids': sorted(list(data['task_appearances']))
        }
    
    # Sort by mean score
    sorted_tokens = sorted(
        token_stats.items(),
        key=lambda x: x[1]['mean_score'],
        reverse=True
    )
    
    print(f"\nüìä Cross-task Token Analysis:")
    print(f"   Total unique tokens: {len(token_stats)}")
    print(f"   Most important tokens across all tasks:")
    
    for i, (token, stats) in enumerate(sorted_tokens[:15], 1):
        print(f"     {i:2d}. '{token}': avg={stats['mean_score']:.4f}, "
              f"appeared {stats['total_appearances']} times across {stats['task_count']} tasks")
    
    return token_stats, task_summaries


def compare_explainability_success_correlation(
    video_dir: str,
    min_appearances: int = 5
):
    """
    Analyze correlation between token importance and task success rates.
    
    Args:
        video_dir: Directory containing explainability JSON files
        min_appearances: Minimum appearances required for token to be analyzed
    """
    
    # Find all explainability files
    json_files = list(Path(video_dir).glob("*_explainability.json"))
    
    if not json_files:
        print("‚ùå No explainability files found")
        return
    
    print(f"üìä Analyzing {len(json_files)} tasks for success correlation...")
    
    # Collect data
    task_data = []
    all_tokens = {}
    
    for json_path in json_files:
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            metadata = data['metadata']
            task_id = metadata['task_id']
            success_rate = metadata['success_rate']
            
            # Get average token scores for this task
            task_token_scores = {}
            for frame in data['frames']:
                for token_data in frame['top_tokens'][:5]:
                    token = token_data['token']
                    score = token_data['score']
                    
                    if token not in task_token_scores:
                        task_token_scores[token] = []
                    task_token_scores[token].append(score)
            
            # Average scores per token for this task
            task_avg_scores = {
                token: np.mean(scores)
                for token, scores in task_token_scores.items()
            }
            
            task_data.append({
                'task_id': task_id,
                'success_rate': success_rate,
                'token_scores': task_avg_scores,
                'prompt': metadata['task_prompt']
            })
            
            # Aggregate across all tasks
            for token, avg_score in task_avg_scores.items():
                if token not in all_tokens:
                    all_tokens[token] = {
                        'success_rates': [],
                        'importance_scores': []
                    }
                all_tokens[token]['success_rates'].append(success_rate)
                all_tokens[token]['importance_scores'].append(avg_score)
                
        except Exception as e:
            print(f"‚ö†Ô∏è Error processing {json_path.name}: {e}")
    
    # Analyze correlations
    print(f"\nüìà Token-Success Correlation Analysis:")
    print(f"   Minimum appearances threshold: {min_appearances}")
    
    correlations = []
    for token, data in all_tokens.items():
        if len(data['success_rates']) >= min_appearances:
            success_rates = np.array(data['success_rates'])
            importance_scores = np.array(data['importance_scores'])
            
            # Calculate Pearson correlation
            if len(success_rates) > 1 and np.std(success_rates) > 0 and np.std(importance_scores) > 0:
                correlation = np.corrcoef(success_rates, importance_scores)[0, 1]
                correlations.append((token, correlation, len(success_rates)))
    
    # Sort by absolute correlation
    correlations.sort(key=lambda x: abs(x[1]), reverse=True)
    
    print(f"\n   Top positively correlated tokens (higher importance ‚Üí higher success):")
    positive_corrs = [c for c in correlations if c[1] > 0][:10]
    for i, (token, corr, count) in enumerate(positive_corrs, 1):
        print(f"     {i:2d}. '{token}': r={corr:.3f} (n={count})")
    
    print(f"\n   Top negatively correlated tokens (higher importance ‚Üí lower success):")
    negative_corrs = [c for c in correlations if c[1] < 0][:10]
    for i, (token, corr, count) in enumerate(negative_corrs, 1):
        print(f"     {i:2d}. '{token}': r={corr:.3f} (n={count})")
    
    # Task performance summary
    print(f"\nüìä Task Performance Summary:")
    task_data.sort(key=lambda x: x['success_rate'], reverse=True)
    for task_info in task_data:
        print(f"   Task {task_info['task_id']}: {task_info['success_rate']:.2%} - {task_info['prompt']}")

print("\n‚úÖ Explainability analysis functions loaded!")

In [65]:
# ============================================================================
# VISUAL EXPLAINABILITY: Gradient-based Saliency Maps & Heatmap Videos
# ============================================================================

import cv2
from scipy.ndimage import gaussian_filter

class VisualExplainer:
    """
    Classe per calcolare l'explainability visuale basata sui gradienti.
    
    Calcola quali regioni dell'immagine sono pi√π importanti per la predizione
    dell'azione usando backpropagation attraverso il visual encoder.
    
    Metodi disponibili:
    - Vanilla Gradient: gradienti diretti sull'input
    - Integrated Gradients: integrazione lungo un path baseline->input
    - SmoothGrad: media di gradienti con rumore
    - GradCAM-like: gradienti pesati sulle feature maps
    """
    
    def __init__(self, model: nn.Module, device: torch.device):
        """
        Args:
            model: TRMPolicy model
            device: device per computazione
        """
        self.model = model
        self.device = device
        self.feature_maps = None
        self.gradients = None
        
        # Hook per GradCAM-like saliency
        self._register_hooks()
    
    def _register_hooks(self):
        """Registra hooks per catturare feature maps e gradienti dal backbone"""
        def forward_hook(module, input, output):
            self.feature_maps = output.detach()
        
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()
        
        # Trova l'ultimo layer convoluzionale nel backbone
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'backbone'):
            backbone = self.model.encoder.backbone
            # ResNet backbone - ultimo layer prima del pooling
            for name, module in backbone.named_modules():
                if isinstance(module, nn.Conv2d):
                    self.last_conv = module
            
            if hasattr(self, 'last_conv'):
                self.last_conv.register_forward_hook(forward_hook)
                self.last_conv.register_full_backward_hook(backward_hook)
    
    def compute_visual_saliency(
        self,
        obs: torch.Tensor,
        prompt: str,
        method: str = 'gradcam',
        target_action: Optional[torch.Tensor] = None,
        smooth_samples: int = 20,
        noise_level: float = 0.1
    ) -> Dict[str, Any]:
        """
        Calcola la saliency map per un'immagine di input.
        
        Args:
            obs: Immagine di input (1, C, H, W)
            prompt: Prompt testuale per il task
            method: 'vanilla', 'smoothgrad', 'gradcam', 'integrated'
            target_action: Azione target opzionale
            smooth_samples: Numero di campioni per SmoothGrad
            noise_level: Livello di rumore per SmoothGrad
            
        Returns:
            Dict contenente:
                - saliency_map: Mappa di saliency normalizzata (H, W)
                - importance_score: Score totale dell'importanza visiva
                - raw_gradients: Gradienti grezzi
        """
        self.model.eval()
        
        if method == 'vanilla':
            saliency = self._vanilla_gradient(obs, prompt, target_action)
        elif method == 'smoothgrad':
            saliency = self._smoothgrad(obs, prompt, target_action, smooth_samples, noise_level)
        elif method == 'gradcam':
            saliency = self._gradcam(obs, prompt, target_action)
        elif method == 'integrated':
            saliency = self._integrated_gradients(obs, prompt, target_action)
        else:
            raise ValueError(f"Unknown method: {method}")
        
        # Normalizza saliency map in [0, 1]
        saliency_norm = saliency - saliency.min()
        if saliency_norm.max() > 0:
            saliency_norm = saliency_norm / saliency_norm.max()
        
        # Calcola score totale (somma pesata dell'importanza)
        importance_score = float(saliency.sum())
        
        return {
            'saliency_map': saliency_norm,
            'importance_score': importance_score,
            'raw_saliency': saliency,
            'method': method
        }
    
    def _forward_with_grad(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass con gradienti abilitati"""
        obs = obs.clone().detach().requires_grad_(True)
        
        # Forward
        if hasattr(self.model, 'use_text_prompts') and self.model.use_text_prompts:
            actions = self.model(obs, [prompt])
        else:
            actions = self.model(obs, None)
        
        # Loss
        if target_action is not None:
            loss = F.mse_loss(actions, target_action)
        else:
            loss = actions.norm()
        
        # Backward
        loss.backward()
        
        return obs.grad
    
    def _vanilla_gradient(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """Calcola vanilla gradient saliency"""
        grad = self._forward_with_grad(obs, prompt, target_action)
        
        # Aggrega sui canali con valore assoluto
        saliency = grad.abs().sum(dim=1).squeeze().cpu().numpy()
        
        # Smooth per visualizzazione
        saliency = gaussian_filter(saliency, sigma=2)
        
        return saliency
    
    def _smoothgrad(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None,
        n_samples: int = 20,
        noise_level: float = 0.1
    ) -> np.ndarray:
        """SmoothGrad: media di gradienti con input rumorosi"""
        saliency_sum = None
        
        for _ in range(n_samples):
            # Aggiungi rumore
            noise = torch.randn_like(obs) * noise_level
            noisy_obs = obs + noise
            
            grad = self._forward_with_grad(noisy_obs, prompt, target_action)
            saliency = grad.abs().sum(dim=1).squeeze().cpu().numpy()
            
            if saliency_sum is None:
                saliency_sum = saliency
            else:
                saliency_sum += saliency
        
        saliency = saliency_sum / n_samples
        saliency = gaussian_filter(saliency, sigma=2)
        
        return saliency
    
    def _gradcam(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """GradCAM-like saliency usando feature maps del backbone"""
        self.feature_maps = None
        self.gradients = None
        
        obs_grad = obs.clone().detach().requires_grad_(True)
        
        # Forward
        if hasattr(self.model, 'use_text_prompts') and self.model.use_text_prompts:
            actions = self.model(obs_grad, [prompt])
        else:
            actions = self.model(obs_grad, None)
        
        # Loss
        if target_action is not None:
            loss = F.mse_loss(actions, target_action)
        else:
            loss = actions.norm()
        
        # Backward
        self.model.zero_grad()
        loss.backward()
        
        # Se abbiamo feature maps e gradienti dai hooks
        if self.feature_maps is not None and self.gradients is not None:
            # Global average pooling dei gradienti
            weights = self.gradients.mean(dim=[2, 3], keepdim=True)
            
            # Weighted sum delle feature maps
            cam = (weights * self.feature_maps).sum(dim=1).squeeze()
            cam = F.relu(cam)
            
            # Resize alla dimensione dell'input
            cam = cam.cpu().numpy()
            cam = cv2.resize(cam, (obs.shape[3], obs.shape[2]))
            
            cam = gaussian_filter(cam, sigma=3)
        else:
            # Fallback a vanilla gradient
            cam = self._vanilla_gradient(obs, prompt, target_action)
        
        return cam
    
    def _integrated_gradients(
        self,
        obs: torch.Tensor,
        prompt: str,
        target_action: Optional[torch.Tensor] = None,
        steps: int = 50
    ) -> np.ndarray:
        """Integrated Gradients con baseline nero"""
        baseline = torch.zeros_like(obs)
        scaled_inputs = [baseline + (float(i) / steps) * (obs - baseline) 
                         for i in range(1, steps + 1)]
        
        grads_sum = None
        
        for scaled_input in scaled_inputs:
            grad = self._forward_with_grad(scaled_input, prompt, target_action)
            grad_np = grad.abs().sum(dim=1).squeeze().cpu().numpy()
            
            if grads_sum is None:
                grads_sum = grad_np
            else:
                grads_sum += grad_np
        
        # Media dei gradienti * (input - baseline)
        avg_grads = grads_sum / steps
        integrated = avg_grads * (obs - baseline).abs().sum(dim=1).squeeze().cpu().numpy()
        
        integrated = gaussian_filter(integrated, sigma=2)
        
        return integrated
    
    def generate_heatmap_overlay(
        self,
        obs: torch.Tensor,
        saliency_map: np.ndarray,
        alpha: float = 0.5,
        colormap: int = cv2.COLORMAP_JET
    ) -> np.ndarray:
        """
        Genera un'immagine con heatmap sovrapposta.
        
        Args:
            obs: Immagine originale (1, C, H, W) o (C, H, W) in [0, 1]
            saliency_map: Mappa di saliency normalizzata (H, W)
            alpha: Trasparenza della heatmap
            colormap: Colormap OpenCV
            
        Returns:
            Immagine con heatmap sovrapposta (H, W, 3) in formato uint8
        """
        # Converti osservazione in numpy (H, W, 3)
        if obs.dim() == 4:
            obs = obs.squeeze(0)
        
        img = obs.permute(1, 2, 0).cpu().numpy()
        img = (img * 255).clip(0, 255).astype(np.uint8)
        
        # Assicurati che la saliency abbia la stessa dimensione
        if saliency_map.shape != (img.shape[0], img.shape[1]):
            saliency_map = cv2.resize(saliency_map, (img.shape[1], img.shape[0]))
        
        # Converti saliency in heatmap colorata
        heatmap = (saliency_map * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        
        # Overlay
        overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)
        
        return overlay
    
    def compute_visual_importance_score(
        self,
        obs: torch.Tensor,
        prompt: str,
        method: str = 'gradcam'
    ) -> float:
        """
        Calcola uno score scalare dell'importanza visiva totale.
        
        Utile per confrontare l'importanza visiva vs testuale.
        """
        result = self.compute_visual_saliency(obs, prompt, method=method)
        return result['importance_score']


class MultimodalExplainer:
    """
    Combina TextExplainer e VisualExplainer per analisi multimodale.
    
    Calcola:
    - Importanza relativa text vs visual
    - Correlazione con success rate
    - Pattern temporali di attenzione multimodale
    """
    
    def __init__(
        self,
        model: nn.Module,
        prompt_encoder: 'PromptEncoder',
        device: torch.device
    ):
        self.model = model
        self.device = device
        self.text_explainer = TextExplainer(model, prompt_encoder, device)
        self.visual_explainer = VisualExplainer(model, device)
    
    def compute_multimodal_importance(
        self,
        obs: torch.Tensor,
        prompt: str,
        visual_method: str = 'gradcam'
    ) -> Dict[str, Any]:
        """
        Calcola importanza multimodale (text + visual) con normalizzazione comparabile.
        
        Utilizza Mean-Pooling normalization come in:
        - Chefer et al. (2021) "Generic Attention-model Explainability for Interpreting Bi-Modal 
          and Encoder-Decoder Transformers"
        - Sundararajan et al. (2017) "Axiomatic Attribution for Deep Networks"
        
        La normalizzazione per dimensione rende comparabili i contributi delle due modalit√†:
        - Text: media delle norme L2 sui token (invece di somma)
        - Visual: media dei valori di saliency sui pixel (invece di somma)
        
        Returns:
            Dict con:
                - text_importance: Score importanza testuale (media per token)
                - visual_importance: Score importanza visiva (media per pixel)
                - text_ratio: Proporzione del contributo testuale [0, 1]
                - visual_ratio: Proporzione del contributo visivo [0, 1]
                - text_visual_ratio: Rapporto text/visual
                - top_tokens: Token pi√π importanti
                - saliency_map: Mappa saliency visiva
        """
        # Text importance
        text_result = self.text_explainer.compute_token_saliency(obs, prompt)
        
        # Visual importance
        visual_result = self.visual_explainer.compute_visual_saliency(
            obs, prompt, method=visual_method
        )
        
        # =========================================================================
        # NORMALIZZAZIONE PER DIMENSIONE (Mean-Pooling)
        # Approccio standard in letteratura per confronto multimodale equo
        # =========================================================================
        
        # Numero di elementi per ciascuna modalit√†
        num_tokens = sum(text_result['attention_mask'])  # Solo token validi (no padding)
        saliency_map = visual_result['saliency_map']
        num_pixels = saliency_map.size  # H * W
        
        # Importanza MEDIA per elemento (invece di somma)
        # Text: media delle norme L2 sui token
        text_importance_mean = float(text_result['total_saliency']) / max(num_tokens, 1)
        
        # Visual: media dei valori di saliency sui pixel
        visual_importance_mean = float(saliency_map.sum()) / max(num_pixels, 1)
        
        # Calcola ratio normalizzati
        total = text_importance_mean + visual_importance_mean + 1e-8
        text_ratio = text_importance_mean / total
        visual_ratio = visual_importance_mean / total
        
        # Top tokens
        top_tokens = self.text_explainer.get_top_k_tokens(text_result, k=5, filter_special=True)
        
        return {
            'text_importance': text_importance_mean,
            'visual_importance': visual_importance_mean,
            'text_importance_raw': float(text_result['total_saliency']),  # Valore originale
            'visual_importance_raw': float(saliency_map.sum()),  # Valore originale
            'num_tokens': num_tokens,
            'num_pixels': num_pixels,
            'text_ratio': text_ratio,
            'visual_ratio': visual_ratio,
            'text_visual_ratio': text_importance_mean / (visual_importance_mean + 1e-8),
            'top_tokens': top_tokens,
            'saliency_map': visual_result['saliency_map'],
            'token_saliency': text_result
        }
    
    def compute_all_visual_methods(
        self,
        obs: torch.Tensor,
        prompt: str
    ) -> Dict[str, np.ndarray]:
        """
        Calcola saliency maps con tutti i metodi disponibili.
        
        Metodi implementati secondo standard in letteratura:
        - Vanilla Gradient (Simonyan et al., 2014)
        - SmoothGrad (Smilkov et al., 2017)
        - Grad-CAM (Selvaraju et al., 2017)
        - Integrated Gradients (Sundararajan et al., 2017)
        
        Returns:
            Dict con saliency map per ogni metodo
        """
        methods = ['vanilla', 'smoothgrad', 'gradcam', 'integrated']
        results = {}
        
        for method in methods:
            try:
                result = self.visual_explainer.compute_visual_saliency(
                    obs, prompt, method=method
                )
                results[method] = result['saliency_map']
            except Exception as e:
                print(f"‚ö†Ô∏è Error computing {method}: {e}")
                results[method] = np.zeros((obs.shape[2], obs.shape[3]), dtype=np.float32)
        
        return results


def generate_grid_explainability_video(
    frames_data: List[Dict],
    original_frames: List[np.ndarray],
    output_path: str,
    fps: int = 5,
    include_text_overlay: bool = True
):
    """
    Genera un video con griglia 2x3 contenente tutte le metodologie di explainability.
    
    Layout griglia:
    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
    ‚îÇ    Original     ‚îÇ Vanilla Gradient‚îÇ    SmoothGrad   ‚îÇ
    ‚îÇ                 ‚îÇ (Simonyan 2014) ‚îÇ (Smilkov 2017)  ‚îÇ
    ‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§
    ‚îÇ    Grad-CAM     ‚îÇ   Integrated    ‚îÇ  Text Saliency  ‚îÇ
    ‚îÇ (Selvaraju 2017)‚îÇ   Gradients     ‚îÇ   (Top Tokens)  ‚îÇ
    ‚îÇ                 ‚îÇ(Sundararajan17) ‚îÇ                 ‚îÇ
    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
    
    Args:
        frames_data: Lista di dict con saliency maps per ogni metodo
        original_frames: Lista di frame originali (H, W, 3)
        output_path: Path per il video output
        fps: Frame per secondo
        include_text_overlay: Se mostrare i top token
    """
    if len(frames_data) == 0 or len(original_frames) == 0:
        print("‚ö†Ô∏è No frames to generate video")
        return
    
    n_frames = min(len(frames_data), len(original_frames))
    print(f"üìπ Generating grid video with {n_frames} frames...")
    
    # Colori per i titoli
    title_bg_color = (40, 40, 40)
    title_text_color = (255, 255, 255)
    title_height = 25
    
    # Dimensioni griglia: 2 righe x 3 colonne
    # IMPORTANTE: Considerare altezza barra titolo per ogni pannello
    h, w = original_frames[0].shape[:2]
    panel_h = h + title_height  # Altezza pannello CON barra titolo
    panel_w = w
    grid_h = panel_h * 2  # 2 righe
    grid_w = panel_w * 3  # 3 colonne
    
    print(f"   Frame size: {w}x{h}, Panel size: {panel_w}x{panel_h}, Grid size: {grid_w}x{grid_h}")
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (grid_w, grid_h))
    
    if not out.isOpened():
        print(f"‚ùå Failed to open VideoWriter for {output_path}")
        return
    
    method_info = {
        'original': ('Original Input', (200, 200, 200)),
        'vanilla': ('Vanilla Gradient', (255, 100, 100)),
        'smoothgrad': ('SmoothGrad', (100, 255, 100)),
        'gradcam': ('Grad-CAM', (100, 100, 255)),
        'integrated': ('Integrated Grad', (255, 255, 100)),
        'text': ('Token Importance', (255, 150, 255))
    }
    
    def add_title_bar(img: np.ndarray, title: str, color: Tuple[int, int, int]) -> np.ndarray:
        """Aggiunge una barra titolo all'immagine"""
        result = np.zeros((img.shape[0] + title_height, img.shape[1], 3), dtype=np.uint8)
        # Barra titolo
        result[:title_height, :] = title_bg_color
        # Linea colorata sotto il titolo
        result[title_height-3:title_height, :] = color
        # Immagine
        result[title_height:, :] = img
        # Testo
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        text_size = cv2.getTextSize(title, font, font_scale, 1)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        cv2.putText(result, title, (text_x, 17), font, font_scale, title_text_color, 1, cv2.LINE_AA)
        return result
    
    def create_heatmap_overlay(frame: np.ndarray, saliency: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Crea overlay heatmap su immagine"""
        if saliency is None or saliency.size == 0:
            return frame.copy()
        
        # Normalizza saliency
        saliency = np.array(saliency, dtype=np.float32)
        if saliency.shape[:2] != frame.shape[:2]:
            saliency = cv2.resize(saliency, (frame.shape[1], frame.shape[0]))
        
        if saliency.max() > saliency.min():
            saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min())
        
        heatmap = (saliency * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
        
        overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
        return overlay
    
    def create_text_importance_panel(frame: np.ndarray, top_tokens: List[Tuple[str, float]]) -> np.ndarray:
        """Crea pannello con importanza dei token testuali"""
        panel = frame.copy()
        
        if not top_tokens:
            return panel
        
        # Background semi-trasparente
        overlay = panel.copy()
        cv2.rectangle(overlay, (5, 5), (w-5, min(len(top_tokens)*22 + 10, h-5)), (0, 0, 0), -1)
        panel = cv2.addWeighted(overlay, 0.7, panel, 0.3, 0)
        
        # Disegna barre di importanza per ogni token
        max_score = max(score for _, score in top_tokens) if top_tokens else 1.0
        y_offset = 20
        bar_max_width = w - 80
        
        for token, score in top_tokens[:8]:  # Max 8 token
            # Normalizza per larghezza barra
            bar_width = int((score / max_score) * bar_max_width)
            
            # Colore barra (gradiente rosso-giallo-verde basato su score)
            hue = int((score / max_score) * 60)  # Da rosso a giallo
            bar_color = tuple(int(c) for c in cv2.cvtColor(
                np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2RGB)[0][0])
            
            # Disegna barra
            cv2.rectangle(panel, (10, y_offset-12), (10 + bar_width, y_offset+2), bar_color, -1)
            
            # Testo token
            text = f"{token[:10]}: {score:.3f}"
            cv2.putText(panel, text, (15, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 
                       0.4, (255, 255, 255), 1, cv2.LINE_AA)
            y_offset += 22
        
        return panel
    
    frames_written = 0
    for i in range(n_frames):
        try:
            frame = original_frames[i].copy()
            data = frames_data[i]
            
            # Assicurati formato corretto
            if frame.dtype != np.uint8:
                if frame.max() <= 1.0:
                    frame = (frame * 255).astype(np.uint8)
                else:
                    frame = np.clip(frame, 0, 255).astype(np.uint8)
            
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            
            if frame.shape[0] != h or frame.shape[1] != w:
                frame = cv2.resize(frame, (w, h))
            
            # Crea i 6 pannelli
            panels = []
            
            # 1. Original
            original_panel = add_title_bar(frame.copy(), *method_info['original'][:2])
            panels.append(original_panel)
            
            # 2. Vanilla Gradient
            vanilla_saliency = data.get('vanilla_saliency', data.get('saliency_map', None))
            vanilla_overlay = create_heatmap_overlay(frame, vanilla_saliency)
            vanilla_panel = add_title_bar(vanilla_overlay, *method_info['vanilla'][:2])
            panels.append(vanilla_panel)
            
            # 3. SmoothGrad
            smoothgrad_saliency = data.get('smoothgrad_saliency', data.get('saliency_map', None))
            smoothgrad_overlay = create_heatmap_overlay(frame, smoothgrad_saliency)
            smoothgrad_panel = add_title_bar(smoothgrad_overlay, *method_info['smoothgrad'][:2])
            panels.append(smoothgrad_panel)
            
            # 4. Grad-CAM
            gradcam_saliency = data.get('gradcam_saliency', data.get('saliency_map', None))
            gradcam_overlay = create_heatmap_overlay(frame, gradcam_saliency)
            gradcam_panel = add_title_bar(gradcam_overlay, *method_info['gradcam'][:2])
            panels.append(gradcam_panel)
            
            # 5. Integrated Gradients
            integrated_saliency = data.get('integrated_saliency', data.get('saliency_map', None))
            integrated_overlay = create_heatmap_overlay(frame, integrated_saliency)
            integrated_panel = add_title_bar(integrated_overlay, *method_info['integrated'][:2])
            panels.append(integrated_panel)
            
            # 6. Text Token Importance
            top_tokens = data.get('top_tokens', [])
            text_panel = create_text_importance_panel(frame, top_tokens)
            text_panel = add_title_bar(text_panel, *method_info['text'][:2])
            panels.append(text_panel)
            
            # Verifica dimensioni pannelli prima di costruire griglia
            for idx, p in enumerate(panels):
                if p.shape[0] != panel_h or p.shape[1] != panel_w:
                    panels[idx] = cv2.resize(p, (panel_w, panel_h))
            
            # Costruisci griglia 2x3
            row1 = np.hstack([panels[0], panels[1], panels[2]])
            row2 = np.hstack([panels[3], panels[4], panels[5]])
            grid = np.vstack([row1, row2])
            
            # Verifica dimensioni griglia finale
            if grid.shape[0] != grid_h or grid.shape[1] != grid_w:
                grid = cv2.resize(grid, (grid_w, grid_h))
            
            # Aggiungi step counter
            step = data.get('step', i)
            cv2.putText(grid, f"Step: {step}", (10, grid.shape[0] - 10),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
            
            # Converti per OpenCV (BGR)
            grid_bgr = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
            out.write(grid_bgr)
            frames_written += 1
            
        except Exception as e:
            print(f"‚ö†Ô∏è Error processing frame {i}: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    out.release()
    
    if frames_written > 0:
        print(f"‚úÖ Grid explainability video saved to {output_path} ({frames_written} frames)")
    else:
        print(f"‚ùå No frames written to video!")


In [66]:
# --- MONKEY PATCH (CRITICAL FIX) ---
# FIX: Mock di Matplotlib per prevenire il kernel crash dovuto a conflitti di NumPy/ABI.
mock_mpl = MagicMock()
sys.modules["matplotlib"] = mock_mpl
sys.modules["matplotlib.pyplot"] = mock_mpl
sys.modules["matplotlib.cm"] = mock_mpl
sys.modules["matplotlib.colors"] = mock_mpl
sys.modules["matplotlib.transforms"] = mock_mpl
sys.modules["matplotlib.ticker"] = mock_mpl
sys.modules["matplotlib._path"] = mock_mpl
# --------------------------------------

# --- SETUP PATHS ---
LIBERO_REPO_ROOT = Path('LIBERO')
if LIBERO_REPO_ROOT.exists() and str(LIBERO_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(LIBERO_REPO_ROOT))

# Fix for Numba
try:
    from robosuite.utils.numba import jit_decorator
except Exception:
    pass

# Now these imports will work because matplotlib is mocked
from libero.libero import get_libero_path
from libero.libero.benchmark import get_benchmark
from libero.libero.envs import OffScreenRenderEnv
from libero.libero.utils.time_utils import Timer
from libero.libero.utils.video_utils import VideoWriter

print("‚úì LIBERO imports successful")

‚úì LIBERO imports successful


In [67]:
# Helper functions per evaluation
def _merge_training_config(stored: Dict[str, Any]) -> TrainingConfig:
    class ConfigObj:
        def __init__(self, **entries): self.__dict__.update(entries)
    return ConfigObj(**stored)

def _stack_vector_obs(obs: Any) -> Dict[str, np.ndarray]:
    if isinstance(obs, list):
        keys = obs[0].keys()
        return {k: np.stack([o[k] for o in obs], axis=0) for k in keys}
    return obs

def _select_camera_key(obs_batch: Dict[str, np.ndarray]) -> str:
    for key in ('agentview_rgb', 'agentview_image', 'robot0_agentview_image'):
        if key in obs_batch: return key
    return list(obs_batch.keys())[0]

def _prepare_policy_input(images: np.ndarray, device: torch.device) -> torch.Tensor:
    imgs = torch.from_numpy(images).to(device=device, dtype=torch.float32) / 255.0
    return imgs.permute(0, 3, 1, 2).contiguous()

class SequentialVectorEnv:
    def __init__(self, env_fns: List[Callable]):
        self.envs = [fn() for fn in env_fns]
    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs_list, rews_list, dones_list, infos_list = zip(*results)
        return list(obs_list), np.array(rews_list), np.array(dones_list), list(infos_list)
    def reset(self):
        return [env.reset() for env in self.envs]
    def seed(self, seed):
        for i, env in enumerate(self.envs):
            if hasattr(env, 'seed'):
                env.seed(seed + i)
    def set_init_state(self, states):
        return [env.set_init_state(s) for env, s in zip(self.envs, states)]
    def close(self):
        for env in self.envs:
            env.close()

In [None]:
# --- EVALUATION FUNCTION (FIXED) ---
def evaluate_model(
    checkpoint_path: str = 'models/back.pt',
    action_stats_path: str = 'action_stats.json',
    benchmark: str = 'libero_spatial',
    task_id: int = 6,
    env_num: int = 10,
    max_steps: int = 800,
    seed: int = 42,
    save_videos: bool = True,
    video_dir: str = 'evaluation_videos',
    camera_height: int = 128,
    camera_width: int = 128,
    video_skip: int = 1,
    enable_explainability: bool = True,
    explainability_interval: int = 10,
    top_k_tokens: int = 10
) -> Dict[str, Any]:
    print(f"Starting evaluation on {benchmark} Task {task_id} (Envs: {env_num}, Max Steps: {max_steps})...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = _merge_training_config(ckpt.get('config', {}))
    policy = build_policy_from_config(cfg, obs_shape=(3, camera_height, camera_width)).to(device)
    policy.load_state_dict(ckpt['model_state_dict'])
    policy.eval()

    stats = json.load(open(action_stats_path))
    action_mean = torch.tensor(stats['mean'], device=device).unsqueeze(0)
    action_std = torch.tensor(stats['std'], device=device).unsqueeze(0)
    action_dim = int(action_mean.shape[-1])

    benchmark_map = {'libero_10': 'LIBERO_10', 'libero_spatial': 'LIBERO_SPATIAL', 'libero_goal': 'LIBERO_GOAL'}
    suite = get_benchmark(benchmark_map.get(benchmark, benchmark))(0)
    task = suite.get_task(task_id)
    task_prompt = task.language
    #task_prompt = "pick up the black bowl next to the ramekin and place it on the plate"
    use_prompts = getattr(policy, 'use_text_prompts', False)
    if use_prompts:
        print(f"Using language prompt: '{task_prompt}'")
    
    # Initialize TextExplainer if requested
    text_explainer = None
    explainability_data = []
    if enable_explainability and use_prompts and hasattr(policy, 'prompt_encoder'):
        text_explainer = TextExplainer(policy, policy.prompt_encoder, device)
        print(f"‚úÖ Explainability enabled (interval: every {explainability_interval} steps, top-{top_k_tokens} tokens)")

    env_args = {
        'bddl_file_name': str(Path(get_libero_path('bddl_files')) / task.problem_folder / task.bddl_file),
        'camera_heights': camera_height,
        'camera_widths': camera_width
    }
    env = SequentialVectorEnv([lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)])

    try:
        init_states= torch.load(str(Path(get_libero_path('init_states')) / task.problem_folder / task.init_states_file), map_location='cpu', weights_only=False)
        obs = env.reset()
        env.seed(seed)
        env.set_init_state(init_states[0:env_num])

        dones = [False] * env_num
        successes = np.zeros(env_num, dtype=bool)
        video_path = Path(video_dir) / f"task_{task_id:02d}"
        
        
        with VideoWriter(str(video_path), save_videos) as video_writer:
            for step in range(max_steps):
                obs_batch = _stack_vector_obs(obs)
                cam_key = _select_camera_key(obs_batch)

                alive = [i for i, d in enumerate(dones) if not d]
                if not alive:
                    break

                vis_batch = obs_batch[cam_key][alive]
                p_in = _prepare_policy_input(vis_batch, device)

                with torch.no_grad():
                    prompt_batch = [task_prompt for _ in alive] if use_prompts else None
                    actions_alive = policy(p_in, prompt_batch)
                    actions_alive = actions_alive * action_std + action_mean
                    full_actions = np.zeros((env_num, action_dim), dtype=np.float32)
                    full_actions[alive] = actions_alive.detach().cpu().numpy().astype(np.float32)
                
                # Explainability: calculate token saliency
                if text_explainer is not None and step % explainability_interval == 0 and len(alive) > 0:
                    try:
                        # Use the first alive environment for analysis
                        first_alive = alive[0]
                        single_obs = p_in[0:1]  # First batch element
                        
                        # Calculate saliency
                        saliency_result = text_explainer.compute_token_saliency(
                            single_obs,
                            task_prompt,
                            target_action=None  # Unsupervised saliency
                        )
                        
                        # Extract top-k tokens
                        top_tokens = text_explainer.get_top_k_tokens(
                            saliency_result,
                            k=top_k_tokens,
                            filter_special=True
                        )
                        
                        # Save data for this frame
                        frame_data = {
                            'step': step,
                            'prompt': task_prompt,
                            'top_tokens': [
                                {'token': token, 'score': float(score)}
                                for token, score in top_tokens
                            ],
                            'all_tokens': {
                                'tokens': saliency_result['token_strings'],
                                'scores': saliency_result['normalized_scores'],
                                'attention_mask': saliency_result['attention_mask']
                            }
                        }
                        explainability_data.append(frame_data)
                        
                        # Debug logging
                        if step % (explainability_interval * 5) == 0:
                            print(f"\n  [Step {step}] Top important words:")
                            for token, score in top_tokens[:5]:
                                print(f"    '{token}': {score:.4f}")
                    
                    except Exception as e:
                        print(f"‚ö†Ô∏è Explainability error at step {step}: {e}")

                obs, reward, done_batch, info = env.step(full_actions)
                
                for i in alive:
                    if reward[i] != 0.0:
                        successes[i] = True
                        #print(f"‚úì Success detected for env {i} at step {step}, reward = {reward[i]}")
                    dones[i] = dones[i] or bool(done_batch[i])
                
                if save_videos and step % video_skip == 0:
                    video_writer.append_vector_obs(obs, dones, camera_name=cam_key)

        success_rate = float(successes.mean())
        print(f"\nüìä Final Results on task {task_id}: {successes.sum()}/{env_num} successes")
        results = {
            'success_rate': success_rate,
            'episodes': int(env_num),
            'max_steps': int(max_steps)
        }
    finally:
        env.close()
    
    # Save explainability data
    if explainability_data:
        explainability_json_path = Path(video_dir) / f"task_{task_id:02d}_explainability.json"
        explainability_json_path.parent.mkdir(parents=True, exist_ok=True)
        
        # Add metadata
        explainability_output = {
            'metadata': {
                'task_id': task_id,
                'benchmark': benchmark,
                'task_prompt': task_prompt,
                'num_frames_analyzed': len(explainability_data),
                'success_rate': success_rate,
                'explainability_interval': explainability_interval,
                'top_k_tokens': top_k_tokens
            },
            'frames': explainability_data
        }
        
        with open(explainability_json_path, 'w') as f:
            json.dump(explainability_output, f, indent=2)
        
        print(f"‚úì Explainability data saved to {explainability_json_path}")
        print(f"  Analyzed {len(explainability_data)} frames")
        
        # Print summary of most common words among top tokens
        if explainability_data:
            all_top_tokens = {}
            for frame in explainability_data:
                for token_data in frame['top_tokens'][:5]:  # Only top 5 per frame
                    token = token_data['token']
                    score = token_data['score']
                    if token in all_top_tokens:
                        all_top_tokens[token].append(score)
                    else:
                        all_top_tokens[token] = [score]
            
            # Calculate average per token
            token_avg_scores = {
                token: np.mean(scores)
                for token, scores in all_top_tokens.items()
            }
            
            # Sort by average score
            sorted_tokens = sorted(token_avg_scores.items(), key=lambda x: x[1], reverse=True)
            
            print(f"\n  üìä Most consistently important words across all frames:")
            for i, (token, avg_score) in enumerate(sorted_tokens[:10], 1):
                freq = len(all_top_tokens[token])
                print(f"    {i}. '{token}': avg={avg_score:.4f}, appeared in top-5 {freq} times")
    
    if save_videos:
        src = Path(video_dir) / f"task_{task_id:02d}" / "video.mp4"
        dst = Path(video_dir) / f"task_{task_id:02d}.mp4"
        if src.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            src.rename(dst)
            print(f"‚úì Videos saved to {dst}")   
            try:
                src.parent.rmdir()
            except OSError:
                pass 

    return results

In [69]:
def visualize_explainability_results(
    json_path: str,
    top_n_frames: int = 5,
    top_n_tokens: int = 10
):
    """
    Visualizza i risultati dell'explainability da un file JSON.
    
    Args:
        json_path: Path al file JSON con i risultati
        top_n_frames: Numero di frame da mostrare
        top_n_tokens: Numero di token da mostrare per frame
    """
    with open(json_path, 'r') as f:
        data = json.load(f)
    
    metadata = data['metadata']
    frames = data['frames']
    
    print(f"\n{'='*80}")
    print(f"EXPLAINABILITY RESULTS: Task {metadata['task_id']}")
    print(f"{'='*80}")
    print(f"Task Prompt: '{metadata['task_prompt']}'")
    print(f"Success Rate: {metadata['success_rate']:.2%}")
    print(f"Frames Analyzed: {metadata['num_frames_analyzed']}")
    print(f"{'='*80}\n")
    
    # Mostra i frame pi√π significativi
    print(f"Top {top_n_frames} analyzed frames:\n")
    for i, frame in enumerate(frames[:top_n_frames], 1):
        print(f"\nFrame #{i} (Step {frame['step']}):")
        print(f"  Most Important Words:")
        for j, token_data in enumerate(frame['top_tokens'][:top_n_tokens], 1):
            token = token_data['token']
            score = token_data['score']
            bar = '‚ñà' * int(score * 50)
            print(f"    {j:2d}. '{token:15s}' {score:6.4f} {bar}")
    
    # Calcola statistiche aggregate
    print(f"\n{'='*80}")
    print("AGGREGATE STATISTICS")
    print(f"{'='*80}")
    
    # Conta frequenza dei token nei top-k
    token_frequencies = {}
    token_total_scores = {}
    
    for frame in frames:
        for token_data in frame['top_tokens'][:5]:  # Top 5 per frame
            token = token_data['token']
            score = token_data['score']
            
            if token in token_frequencies:
                token_frequencies[token] += 1
                token_total_scores[token] += score
            else:
                token_frequencies[token] = 1
                token_total_scores[token] = score
    
    # Calcola medie e ordina
    token_avg_scores = {
        token: token_total_scores[token] / token_frequencies[token]
        for token in token_frequencies
    }
    
    sorted_by_freq = sorted(token_frequencies.items(), key=lambda x: x[1], reverse=True)
    sorted_by_score = sorted(token_avg_scores.items(), key=lambda x: x[1], reverse=True)
    
    print(f"\nMost Frequent Important Words:")
    for i, (token, freq) in enumerate(sorted_by_freq[:10], 1):
        avg_score = token_avg_scores[token]
        pct = (freq / len(frames)) * 100
        print(f"  {i:2d}. '{token:15s}': appeared {freq:3d} times ({pct:5.1f}%), avg score={avg_score:.4f}")
    
    print(f"\nHighest Average Importance Scores:")
    for i, (token, avg_score) in enumerate(sorted_by_score[:10], 1):
        freq = token_frequencies[token]
        print(f"  {i:2d}. '{token:15s}': avg score={avg_score:.4f}, appeared {freq:3d} times")


def analyze_explainability_across_tasks(
    video_dir: str = 'evaluation_videos',
    task_ids: Optional[List[int]] = None
):
    """
    Analizza i risultati dell'explainability attraverso pi√π task.
    
    Args:
        video_dir: Directory contenente i file JSON
        task_ids: Lista di task ID da analizzare (None = tutti)
    """
    video_path = Path(video_dir)
    
    # Trova tutti i file JSON di explainability
    if task_ids is None:
        json_files = list(video_path.glob('task_*_explainability.json'))
    else:
        json_files = [video_path / f'task_{tid:02d}_explainability.json' for tid in task_ids]
        json_files = [f for f in json_files if f.exists()]
    
    if not json_files:
        print(f"‚ö†Ô∏è No explainability JSON files found in {video_dir}")
        return
    
    print(f"\n{'='*80}")
    print(f"CROSS-TASK EXPLAINABILITY ANALYSIS")
    print(f"{'='*80}")
    print(f"Analyzing {len(json_files)} tasks\n")
    
    # Aggregazione dati da tutti i task
    all_token_scores = {}
    all_token_tasks = {}
    task_summaries = []
    
    for json_file in json_files:
        with open(json_file, 'r') as f:
            data = json.load(f)
        
        task_id = data['metadata']['task_id']
        task_prompt = data['metadata']['task_prompt']
        success_rate = data['metadata']['success_rate']
        
        task_summaries.append({
            'task_id': task_id,
            'prompt': task_prompt,
            'success_rate': success_rate
        })
        
        # Aggrega token da questo task
        for frame in data['frames']:
            for token_data in frame['top_tokens'][:5]:
                token = token_data['token']
                score = token_data['score']
                
                if token in all_token_scores:
                    all_token_scores[token].append(score)
                    all_token_tasks[token].add(task_id)
                else:
                    all_token_scores[token] = [score]
                    all_token_tasks[token] = {task_id}
    
    # Calcola statistiche globali
    token_stats = []
    for token, scores in all_token_scores.items():
        token_stats.append({
            'token': token,
            'avg_score': np.mean(scores),
            'std_score': np.std(scores),
            'total_appearances': len(scores),
            'num_tasks': len(all_token_tasks[token])
        })
    
    # Ordina per importanza media
    token_stats.sort(key=lambda x: x['avg_score'], reverse=True)
    
    print(f"Most Important Words Across All Tasks:")
    print(f"{'Rank':<6} {'Token':<20} {'Avg Score':<12} {'Std':<10} {'Appearances':<13} {'Tasks':<8}")
    print(f"{'-'*80}")
    for i, stat in enumerate(token_stats[:20], 1):
        print(f"{i:<6} '{stat['token']:<18}' {stat['avg_score']:<12.4f} {stat['std_score']:<10.4f} "
              f"{stat['total_appearances']:<13} {stat['num_tasks']:<8}")
    
    # Analizza per success rate
    print(f"\n{'='*80}")
    print("Task Success Rates:")
    print(f"{'='*80}")
    task_summaries.sort(key=lambda x: x['success_rate'], reverse=True)
    for i, task in enumerate(task_summaries, 1):
        print(f"{i:2d}. Task {task['task_id']:2d} ({task['success_rate']:6.1%}): {task['prompt']}")
    
    return token_stats, task_summaries


def compare_explainability_success_correlation(
    video_dir: str = 'evaluation_videos',
    min_appearances: int = 5
):
    """
    Analizza la correlazione tra token importanti e success rate dei task.
    
    Args:
        video_dir: Directory con i file JSON
        min_appearances: Minimo numero di apparizioni per considerare un token
    """
    video_path = Path(video_dir)
    json_files = list(video_path.glob('task_*_explainability.json'))
    
    if len(json_files) < 2:
        print("‚ö†Ô∏è Need at least 2 tasks for correlation analysis")
        return
    
    print(f"\n{'='*80}")
    print("EXPLAINABILITY vs SUCCESS RATE CORRELATION")
    print(f"{'='*80}\n")
    
    # Raccogli dati
    task_data = []
    all_tokens = set()
    
    for json_file in json_files:
        with open(json_file, 'r') as f:
            data = json.load(f)
        
        task_id = data['metadata']['task_id']
        success_rate = data['metadata']['success_rate']
        
        # Calcola token importance per questo task
        token_importance = {}
        for frame in data['frames']:
            for token_data in frame['top_tokens'][:5]:
                token = token_data['token']
                score = token_data['score']
                
                if token in token_importance:
                    token_importance[token].append(score)
                else:
                    token_importance[token] = [score]
                
                all_tokens.add(token)
        
        # Media per token
        token_avg = {token: np.mean(scores) for token, scores in token_importance.items()}
        
        task_data.append({
            'task_id': task_id,
            'success_rate': success_rate,
            'token_importance': token_avg
        })
    
    # Filtra token con poche apparizioni
    token_task_count = {token: 0 for token in all_tokens}
    for task in task_data:
        for token in task['token_importance']:
            token_task_count[token] += 1
    
    frequent_tokens = [token for token, count in token_task_count.items() 
                       if count >= min_appearances]
    
    print(f"Analyzing {len(frequent_tokens)} tokens that appear in >={min_appearances} tasks")
    
    # Per ogni token, calcola correlazione con success rate
    correlations = []
    for token in frequent_tokens:
        success_rates = []
        importance_scores = []
        
        for task in task_data:
            if token in task['token_importance']:
                success_rates.append(task['success_rate'])
                importance_scores.append(task['token_importance'][token])
        
        if len(success_rates) >= min_appearances:
            # Calcola correlazione di Pearson
            corr = np.corrcoef(success_rates, importance_scores)[0, 1]
            correlations.append({
                'token': token,
                'correlation': corr,
                'num_tasks': len(success_rates),
                'avg_importance': np.mean(importance_scores)
            })
    
    # Ordina per correlazione assoluta
    correlations.sort(key=lambda x: abs(x['correlation']), reverse=True)
    
    print(f"\nTokens Most Correlated with Success (positive = important for success):")
    print(f"{'Rank':<6} {'Token':<20} {'Correlation':<13} {'Avg Importance':<16} {'Tasks':<8}")
    print(f"{'-'*80}")
    for i, stat in enumerate(correlations[:15], 1):
        corr_str = f"{stat['correlation']:+.4f}"
        print(f"{i:<6} '{stat['token']:<18}' {corr_str:<13} {stat['avg_importance']:<16.4f} {stat['num_tasks']:<8}")

In [70]:
final_results = []
for id in range(10):
    #id = 8
    final_results.append(
        evaluate_model(
            task_id=id, 
            env_num=5, 
            max_steps=600, 
            save_videos=True,
            enable_explainability=True,  # Abilita explainability
            explainability_interval=5,   # Analizza ogni 5 step
            top_k_tokens=5,             # Salva top 5 token
        )["success_rate"]
    )

mean_success = np.mean(final_results)
print(f"Mean success rate over all tasks: {mean_success:.2%}")

print("All success rates:", final_results)

top_3 = sorted(range(len(final_results)), key=lambda i: final_results[i], reverse=True)[:3]
print("Top 3 tasks by success rate:", top_3)

worst_3 = sorted(range(len(final_results)), key=lambda i: final_results[i])[:3]
print("Worst 3 tasks by success rate:", worst_3)

Starting evaluation on libero_spatial Task 0 (Envs: 5, Max Steps: 600)...
[info] using task orders [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Using language prompt: 'pick up the black bowl between the plate and the ramekin and place it on the plate'
‚úÖ Explainability enabled (interval: every 5 steps, top-5 tokens)

  [Step 0] Top important words:
    'plate': 0.1070
    'ram': 0.0831
    'between': 0.0802
    'it': 0.0584
    'the': 0.0526

  [Step 25] Top important words:
    'ram': 0.1644
    'kin': 0.0891
    'plate': 0.0876
    'between': 0.0785
    'e': 0.0743

  [Step 50] Top important words:
    'ram': 0.1419
    'plate': 0.0912
    'between': 0.0777
    'kin': 0.0662
    'e': 0.0617

  [Step 75] Top important words:
    'ram': 0.1194
    'plate': 0.0852
    'the': 0.0724
    'between': 0.0706
    'the': 0.0547

  [Step 100] Top important words:
    'plate': 0.1369
    'between': 0.1366
    'and': 0.0906
    'and': 0.0855
    'ram': 0.0795

  [Step 125] Top important words:
    'between': 

KeyboardInterrupt: 

# Visualizzazione Explainability Results

In [71]:
# Esempio: Visualizza risultati per un singolo task
# Prima esegui l'evaluation con enable_explainability=True per generare i file JSON

explainability_file = 'evaluation_videos/task_00_explainability.json'
if Path(explainability_file).exists():
    visualize_explainability_results(
        explainability_file,
        top_n_frames=5,
        top_n_tokens=10
    )
else:
    print(f"‚ö†Ô∏è File {explainability_file} non trovato.")
    print("Esegui prima evaluate_model con enable_explainability=True per generare i dati.")


EXPLAINABILITY RESULTS: Task 0
Task Prompt: 'pick up the black bowl between the plate and the ramekin and place it on the plate'
Success Rate: 40.00%
Frames Analyzed: 120

Top 5 analyzed frames:


Frame #1 (Step 0):
  Most Important Words:
     1. 'plate          ' 0.1070 ‚ñà‚ñà‚ñà‚ñà‚ñà
     2. 'ram            ' 0.0831 ‚ñà‚ñà‚ñà‚ñà
     3. 'between        ' 0.0802 ‚ñà‚ñà‚ñà‚ñà
     4. 'it             ' 0.0584 ‚ñà‚ñà
     5. 'the            ' 0.0526 ‚ñà‚ñà

Frame #2 (Step 5):
  Most Important Words:
     1. 'ram            ' 0.1135 ‚ñà‚ñà‚ñà‚ñà‚ñà
     2. 'plate          ' 0.1055 ‚ñà‚ñà‚ñà‚ñà‚ñà
     3. 'between        ' 0.0894 ‚ñà‚ñà‚ñà‚ñà
     4. 'e              ' 0.0684 ‚ñà‚ñà‚ñà
     5. 'and            ' 0.0566 ‚ñà‚ñà

Frame #3 (Step 10):
  Most Important Words:
     1. 'between        ' 0.0983 ‚ñà‚ñà‚ñà‚ñà
     2. 'ram            ' 0.0969 ‚ñà‚ñà‚ñà‚ñà
     3. 'plate          ' 0.0798 ‚ñà‚ñà‚ñà
     4. 'e              ' 0.0743 ‚ñà‚ñà‚ñà
     5. 'and            ' 0.0634 ‚ñà‚ñà‚ñà



In [72]:
# Analisi cross-task
# Verifica che esistano file di explainability
if list(Path('evaluation_videos').glob('task_*_explainability.json')):
    token_stats, task_summaries = analyze_explainability_across_tasks(
        video_dir='evaluation_videos'
    )
else:
    print("‚ö†Ô∏è Nessun file di explainability trovato in evaluation_videos/")
    print("Esegui prima evaluate_model con enable_explainability=True")


CROSS-TASK EXPLAINABILITY ANALYSIS
Analyzing 10 tasks

Most Important Words Across All Tasks:
Rank   Token                Avg Score    Std        Appearances   Tasks   
--------------------------------------------------------------------------------
1      'plate             ' 0.1006       0.0218     1111          10      
2      'ram               ' 0.0953       0.0284     343           3       
3      'center            ' 0.0883       0.0150     103           1       
4      'drawer            ' 0.0877       0.0150     119           1       
5      'between           ' 0.0820       0.0162     117           1       
6      'cabinet           ' 0.0790       0.0156     231           2       
7      'stove             ' 0.0756       0.0060     115           1       
8      'pick              ' 0.0732       0.0104     432           8       
9      'wooden            ' 0.0707       0.0168     149           2       
10     'the               ' 0.0706       0.0113     975           10      

In [73]:
# Studio approfondito: correlazione tra importanza delle parole e successo/fallimento

def deep_explainability_analysis(video_dir: str = 'evaluation_videos'):
    """
    Analisi approfondita della correlazione tra importanza dei token e successo nei task.
    
    Analizza:
    1. Differenze di importanza tra task riusciti vs falliti
    2. Evoluzione temporale dell'attenzione (early vs late frames)
    3. Token discriminativi (pi√π importanti per successo vs fallimento)
    4. Pattern di stabilit√† dell'attenzione
    """
    video_path = Path(video_dir)
    json_files = sorted(video_path.glob('task_*_explainability.json'))
    
    if len(json_files) < 2:
        print("‚ö†Ô∏è Serve almeno 2 task per l'analisi di correlazione")
        return
    
    print(f"\n{'='*100}")
    print(f"STUDIO APPROFONDITO: CORRELAZIONE TRA IMPORTANZA DELLE PAROLE E SUCCESSO/FALLIMENTO")
    print(f"{'='*100}\n")
    
    # Raccogli dati da tutti i task
    task_data = []
    all_tokens = set()
    
    for json_file in json_files:
        with open(json_file, 'r') as f:
            data = json.load(f)
        
        task_id = data['metadata']['task_id']
        task_prompt = data['metadata']['task_prompt']
        success_rate = data['metadata']['success_rate']
        frames = data['frames']
        
        # Analisi temporale: dividi in fasi
        n_frames = len(frames)
        early_frames = frames[:n_frames//3]
        mid_frames = frames[n_frames//3:2*n_frames//3]
        late_frames = frames[2*n_frames//3:]
        
        def get_token_importance(frame_list):
            token_scores = {}
            for frame in frame_list:
                for token_data in frame['top_tokens'][:5]:
                    token = token_data['token']
                    score = token_data['score']
                    all_tokens.add(token)
                    if token in token_scores:
                        token_scores[token].append(score)
                    else:
                        token_scores[token] = [score]
            # Calcola media e varianza
            return {
                token: {
                    'mean': np.mean(scores),
                    'std': np.std(scores),
                    'count': len(scores)
                }
                for token, scores in token_scores.items()
            }
        
        task_data.append({
            'task_id': task_id,
            'prompt': task_prompt,
            'success_rate': success_rate,
            'early_importance': get_token_importance(early_frames),
            'mid_importance': get_token_importance(mid_frames),
            'late_importance': get_token_importance(late_frames),
            'overall_importance': get_token_importance(frames),
            'num_frames': n_frames
        })
    
    # Separa task in successi (>50%) e fallimenti (<=50%)
    successful_tasks = [t for t in task_data if t['success_rate'] > 0.5]
    failed_tasks = [t for t in task_data if t['success_rate'] <= 0.5]
    
    print(f"üìä Dataset Overview:")
    print(f"   ‚Ä¢ Task totali: {len(task_data)}")
    print(f"   ‚Ä¢ Task riusciti (>50%): {len(successful_tasks)}")
    print(f"   ‚Ä¢ Task falliti (‚â§50%): {len(failed_tasks)}")
    print(f"   ‚Ä¢ Success rate medio: {np.mean([t['success_rate'] for t in task_data]):.1%}\n")
    
    # ANALISI 1: Token discriminativi tra successo e fallimento
    print(f"\n{'='*100}")
    print("ANALISI 1: TOKEN DISCRIMINATIVI (Successo vs Fallimento)")
    print(f"{'='*100}\n")
    
    def aggregate_token_importance(task_list):
        """Aggrega importanza dei token da una lista di task"""
        agg = {}
        for task in task_list:
            for token, stats in task['overall_importance'].items():
                if token in agg:
                    agg[token]['scores'].append(stats['mean'])
                    agg[token]['counts'].append(stats['count'])
                else:
                    agg[token] = {
                        'scores': [stats['mean']],
                        'counts': [stats['count']]
                    }
        return {
            token: {
                'mean': np.mean(data['scores']),
                'std': np.std(data['scores']),
                'total_count': sum(data['counts']),
                'num_tasks': len(data['scores'])
            }
            for token, data in agg.items()
        }
    
    if successful_tasks and failed_tasks:
        success_tokens = aggregate_token_importance(successful_tasks)
        fail_tokens = aggregate_token_importance(failed_tasks)
        
        # Trova token che appaiono in entrambi i gruppi
        common_tokens = set(success_tokens.keys()) & set(fail_tokens.keys())
        
        # Calcola differenza di importanza (successo - fallimento)
        token_differences = []
        for token in common_tokens:
            diff = success_tokens[token]['mean'] - fail_tokens[token]['mean']
            token_differences.append({
                'token': token,
                'success_importance': success_tokens[token]['mean'],
                'fail_importance': fail_tokens[token]['mean'],
                'difference': diff,
                'success_tasks': success_tokens[token]['num_tasks'],
                'fail_tasks': fail_tokens[token]['num_tasks'],
                'total_appearances': success_tokens[token]['total_count'] + fail_tokens[token]['total_count']
            })
        
        # Ordina per differenza assoluta (token pi√π discriminativi)
        token_differences.sort(key=lambda x: abs(x['difference']), reverse=True)
        
        print(f"Top 15 token pi√π discriminativi (differenza di importanza tra successo e fallimento):\n")
        print(f"{'Rank':<5} {'Token':<20} {'Success':<12} {'Failure':<12} {'Diff':<12} {'Verdict':<15}")
        print(f"{'-'*100}")
        
        for i, td in enumerate(token_differences[:15], 1):
            verdict = "‚úì Pro-Success" if td['difference'] > 0 else "‚úó Pro-Failure"
            print(f"{i:<5} '{td['token']:<18}' {td['success_importance']:<12.4f} "
                  f"{td['fail_importance']:<12.4f} {td['difference']:+.4f}      {verdict:<15}")
        
        print(f"\nInterpretazione:")
        print(f"  ‚Ä¢ Token con differenza positiva ‚Üí pi√π importanti nei task RIUSCITI")
        print(f"  ‚Ä¢ Token con differenza negativa ‚Üí pi√π importanti nei task FALLITI")
    
    # ANALISI 2: Evoluzione temporale dell'attenzione
    print(f"\n\n{'='*100}")
    print("ANALISI 2: EVOLUZIONE TEMPORALE DELL'ATTENZIONE")
    print(f"{'='*100}\n")
    
    # Per ogni token frequente, analizza come cambia l'importanza nel tempo
    token_temporal_patterns = {}
    for task in task_data:
        for phase, phase_name in [('early_importance', 'early'), 
                                   ('mid_importance', 'mid'), 
                                   ('late_importance', 'late')]:
            for token, stats in task[phase].items():
                if token not in token_temporal_patterns:
                    token_temporal_patterns[token] = {
                        'early': [], 'mid': [], 'late': [],
                        'tasks_early': 0, 'tasks_mid': 0, 'tasks_late': 0
                    }
                token_temporal_patterns[token][phase_name].append(stats['mean'])
                token_temporal_patterns[token][f'tasks_{phase_name}'] += 1
    
    # Calcola pattern temporale
    temporal_analysis = []
    for token, patterns in token_temporal_patterns.items():
        if len(patterns['early']) >= 3:  # Minimo 3 task
            early_mean = np.mean(patterns['early']) if patterns['early'] else 0
            mid_mean = np.mean(patterns['mid']) if patterns['mid'] else 0
            late_mean = np.mean(patterns['late']) if patterns['late'] else 0
            
            # Calcola trend (crescente, decrescente, stabile)
            early_late_diff = late_mean - early_mean
            
            temporal_analysis.append({
                'token': token,
                'early': early_mean,
                'mid': mid_mean,
                'late': late_mean,
                'trend': early_late_diff,
                'pattern': 'Crescente' if early_late_diff > 0.01 else ('Decrescente' if early_late_diff < -0.01 else 'Stabile')
            })
    
    temporal_analysis.sort(key=lambda x: abs(x['trend']), reverse=True)
    
    print(f"Token con pattern temporale pi√π marcato:\n")
    print(f"{'Rank':<5} {'Token':<20} {'Early':<10} {'Mid':<10} {'Late':<10} {'Trend':<12} {'Pattern':<15}")
    print(f"{'-'*100}")
    
    for i, ta in enumerate(temporal_analysis[:15], 1):
        print(f"{i:<5} '{ta['token']:<18}' {ta['early']:<10.4f} {ta['mid']:<10.4f} "
              f"{ta['late']:<10.4f} {ta['trend']:+.4f}     {ta['pattern']:<15}")
    
    print(f"\nInterpretazione:")
    print(f"  ‚Ä¢ Pattern Crescente ‚Üí attenzione aumenta nel tempo (esecuzione finale)")
    print(f"  ‚Ä¢ Pattern Decrescente ‚Üí attenzione diminuisce (pianificazione iniziale)")
    print(f"  ‚Ä¢ Pattern Stabile ‚Üí importanza costante durante tutto il task")
    
    # ANALISI 3: Correlazione tra importanza e success rate
    print(f"\n\n{'='*100}")
    print("ANALISI 3: CORRELAZIONE QUANTITATIVA (Token Importance vs Success Rate)")
    print(f"{'='*100}\n")
    
    # Per ogni token, calcola correlazione con success rate
    correlations = []
    for token in all_tokens:
        success_rates = []
        importance_scores = []
        
        for task in task_data:
            if token in task['overall_importance']:
                success_rates.append(task['success_rate'])
                importance_scores.append(task['overall_importance'][token]['mean'])
        
        if len(success_rates) >= 3:  # Minimo 3 task
            # Correlazione di Pearson
            if np.std(success_rates) > 0 and np.std(importance_scores) > 0:
                corr = np.corrcoef(success_rates, importance_scores)[0, 1]
                
                correlations.append({
                    'token': token,
                    'correlation': corr,
                    'num_tasks': len(success_rates),
                    'avg_importance': np.mean(importance_scores),
                    'importance_std': np.std(importance_scores)
                })
    
    # Ordina per correlazione assoluta
    correlations.sort(key=lambda x: abs(x['correlation']), reverse=True)
    
    print(f"Token con correlazione pi√π forte (positiva = predice successo, negativa = predice fallimento):\n")
    print(f"{'Rank':<5} {'Token':<20} {'Correlation':<13} {'Avg Importance':<16} {'Tasks':<8} {'Verdict':<20}")
    print(f"{'-'*100}")
    
    for i, c in enumerate(correlations[:20], 1):
        corr_str = f"{c['correlation']:+.4f}"
        if c['correlation'] > 0.3:
            verdict = "‚úì‚úì Forte predictor di SUCCESSO"
        elif c['correlation'] < -0.3:
            verdict = "‚úó‚úó Forte predictor di FALLIMENTO"
        elif c['correlation'] > 0:
            verdict = "‚úì Lieve pro-success"
        else:
            verdict = "‚úó Lieve pro-failure"
        
        print(f"{i:<5} '{c['token']:<18}' {corr_str:<13} {c['avg_importance']:<16.4f} "
              f"{c['num_tasks']:<8} {verdict:<20}")
    
    # ANALISI 4: Stabilit√† dell'attenzione
    print(f"\n\n{'='*100}")
    print("ANALISI 4: STABILIT√Ä DELL'ATTENZIONE (Varianza)")
    print(f"{'='*100}\n")
    
    # Alta varianza = attenzione instabile, bassa varianza = attenzione stabile
    stability_analysis = []
    for task in task_data:
        for token, stats in task['overall_importance'].items():
            if stats['count'] >= 5:  # Minimo 5 apparizioni
                stability_analysis.append({
                    'token': token,
                    'task_id': task['task_id'],
                    'success_rate': task['success_rate'],
                    'mean_importance': stats['mean'],
                    'std_importance': stats['std'],
                    'stability': stats['std'] / stats['mean'] if stats['mean'] > 0 else 0  # Coefficient of variation
                })
    
    # Confronta stabilit√† tra successi e fallimenti
    if successful_tasks and failed_tasks:
        success_stability = [s['stability'] for s in stability_analysis if s['success_rate'] > 0.5]
        fail_stability = [s['stability'] for s in stability_analysis if s['success_rate'] <= 0.5]
        
        print(f"Stabilit√† media dell'attenzione (coefficiente di variazione):")
        print(f"  ‚Ä¢ Task riusciti:  {np.mean(success_stability):.4f} ¬± {np.std(success_stability):.4f}")
        print(f"  ‚Ä¢ Task falliti:   {np.mean(fail_stability):.4f} ¬± {np.std(fail_stability):.4f}")
        print(f"\n  ‚Üí {'Task riusciti hanno attenzione PI√ô STABILE' if np.mean(success_stability) < np.mean(fail_stability) else 'Task falliti hanno attenzione PI√ô STABILE'}")
    
    # RIEPILOGO FINALE
    print(f"\n\n{'='*100}")
    print("RIEPILOGO CONCLUSIVO")
    print(f"{'='*100}\n")
    
    # Task-specific summary
    print("Performance per task:\n")
    task_data.sort(key=lambda x: x['success_rate'], reverse=True)
    for task in task_data:
        status = "‚úì SUCCESSO" if task['success_rate'] > 0.5 else "‚úó FALLIMENTO"
        print(f"  Task {task['task_id']:2d} ({task['success_rate']:6.1%}) {status}: {task['prompt'][:70]}...")
        
        # Top 3 token per questo task
        top_tokens = sorted(task['overall_importance'].items(), 
                          key=lambda x: x[1]['mean'], reverse=True)[:3]
        token_str = ", ".join([f"'{t[0]}' ({t[1]['mean']:.3f})" for t in top_tokens])
        print(f"           Top token: {token_str}\n")
    
    return {
        'task_data': task_data,
        'token_differences': token_differences if successful_tasks and failed_tasks else None,
        'temporal_analysis': temporal_analysis,
        'correlations': correlations,
        'summary': {
            'total_tasks': len(task_data),
            'successful_tasks': len(successful_tasks),
            'failed_tasks': len(failed_tasks),
            'mean_success_rate': np.mean([t['success_rate'] for t in task_data])
        }
    }


# Esegui l'analisi approfondita
if list(Path('evaluation_videos').glob('task_*_explainability.json')):
    results = deep_explainability_analysis(video_dir='evaluation_videos')
else:
    print("‚ö†Ô∏è Nessun file di explainability trovato in evaluation_videos/")
    print("Esegui prima evaluate_model con enable_explainability=True")



STUDIO APPROFONDITO: CORRELAZIONE TRA IMPORTANZA DELLE PAROLE E SUCCESSO/FALLIMENTO

üìä Dataset Overview:
   ‚Ä¢ Task totali: 10
   ‚Ä¢ Task riusciti (>50%): 6
   ‚Ä¢ Task falliti (‚â§50%): 4
   ‚Ä¢ Success rate medio: 56.0%


ANALISI 1: TOKEN DISCRIMINATIVI (Successo vs Fallimento)

Top 15 token pi√π discriminativi (differenza di importanza tra successo e fallimento):

Rank  Token                Success      Failure      Diff         Verdict        
----------------------------------------------------------------------------------------------------
1     'ram               ' 0.1091       0.0881       +0.0210      ‚úì Pro-Success  
2     'plate             ' 0.0928       0.1097       -0.0170      ‚úó Pro-Failure  
3     'wooden            ' 0.0640       0.0737       -0.0097      ‚úó Pro-Failure  
4     'e                 ' 0.0578       0.0625       -0.0047      ‚úó Pro-Failure  
5     'on                ' 0.0681       0.0636       +0.0045      ‚úì Pro-Success  
6     'it            

In [74]:
# ============================================================================
# STUDIO CORRELAZIONE: Text vs Visual Importance e Success Rate
# ============================================================================

def analyze_multimodal_correlation(
    checkpoint_path: str = 'models/back.pt',
    action_stats_path: str = 'action_stats.json',
    benchmark: str = 'libero_spatial',
    task_ids: Optional[List[int]] = None,
    env_num: int = 5,
    max_steps: int = 600,
    analysis_interval: int = 25,
    visual_method: str = 'gradcam',
    save_videos: bool = True,
    video_dir: str = 'evaluation_videos'
) -> Dict[str, Any]:
    """
    Analizza la correlazione tra importanza text/visual e success rate.
    
    Per ogni task:
    1. Esegue valutazione raccogliendo explainability multimodale
    2. Calcola rapporto text/visual importance per ogni step
    3. Analizza pattern temporali
    4. Correla con outcome (success/failure)
    
    Returns:
        Dict completo con analisi per task e correlazioni
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print(f"\n{'='*100}")
    print("ANALISI MULTIMODALE: CORRELAZIONE TEXT vs VISUAL IMPORTANCE")
    print(f"{'='*100}\n")
    
    # Carica modello
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = _merge_training_config(ckpt.get('config', {}))
    policy = build_policy_from_config(cfg, obs_shape=(3, 128, 128)).to(device)
    policy.load_state_dict(ckpt['model_state_dict'])
    policy.eval()
    
    # Action stats
    stats = json.load(open(action_stats_path))
    action_mean = torch.tensor(stats['mean'], device=device).unsqueeze(0)
    action_std = torch.tensor(stats['std'], device=device).unsqueeze(0)
    
    # Benchmark
    benchmark_map = {'libero_10': 'LIBERO_10', 'libero_spatial': 'LIBERO_SPATIAL', 'libero_goal': 'LIBERO_GOAL'}
    suite = get_benchmark(benchmark_map.get(benchmark, benchmark))(0)
    
    if task_ids is None:
        task_ids = list(range(10))
    
    # Inizializza explainer multimodale
    multimodal_explainer = None
    if hasattr(policy, 'prompt_encoder'):
        multimodal_explainer = MultimodalExplainer(policy, policy.prompt_encoder, device)
        print("‚úÖ MultimodalExplainer initialized")
    
    # Raccogli dati per tutti i task
    all_task_data = []
    
    for task_id in task_ids:
        task = suite.get_task(task_id)
        task_prompt = task.language
        
        print(f"\n{'='*60}")
        print(f"Task {task_id}: {task_prompt[:60]}...")
        print(f"{'='*60}")
        
        # Setup environment
        env_args = {
            'bddl_file_name': str(Path(get_libero_path('bddl_files')) / task.problem_folder / task.bddl_file),
            'camera_heights': 128,
            'camera_widths': 128
        }
        env = SequentialVectorEnv([lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)])
        
        try:
            init_states = torch.load(
                str(Path(get_libero_path('init_states')) / task.problem_folder / task.init_states_file),
                map_location='cpu', weights_only=False
            )
            obs = env.reset()
            env.set_init_state(init_states[0:env_num])
            
            dones = [False] * env_num
            successes = np.zeros(env_num, dtype=bool)
            
            # Dati per questo task
            task_analysis = {
                'task_id': task_id,
                'prompt': task_prompt,
                'frames': [],
                'original_frames': []
            }
            
            for step in range(max_steps):
                obs_batch = _stack_vector_obs(obs)
                cam_key = _select_camera_key(obs_batch)
                
                alive = [i for i, d in enumerate(dones) if not d]
                if not alive:
                    break
                
                vis_batch = obs_batch[cam_key][alive]
                p_in = _prepare_policy_input(vis_batch, device)
                
                # Action prediction
                with torch.no_grad():
                    prompt_batch = [task_prompt for _ in alive]
                    actions_alive = policy(p_in, prompt_batch)
                    actions_alive = actions_alive * action_std + action_mean
                    full_actions = np.zeros((env_num, 7), dtype=np.float32)
                    full_actions[alive] = actions_alive.detach().cpu().numpy()
                
                # Multimodal explainability
                if multimodal_explainer is not None and step % analysis_interval == 0 and len(alive) > 0:
                    try:
                        single_obs = p_in[0:1]
                        
                        # Compute multimodal importance
                        # Calcola importanza multimodale base
                        mm_result = multimodal_explainer.compute_multimodal_importance(
                            single_obs, task_prompt, visual_method='vanilla'
                        )
                        
                        # Calcola TUTTE le saliency map per la griglia video
                        all_saliency = multimodal_explainer.compute_all_visual_methods(
                            single_obs, task_prompt
                        )
                        
                        frame_data = {
                            'step': step,
                            'text_importance': mm_result['text_importance'],
                            'visual_importance': mm_result['visual_importance'],
                            'text_ratio': mm_result['text_ratio'],
                            'visual_ratio': mm_result['visual_ratio'],
                            'text_visual_ratio': mm_result['text_visual_ratio'],
                            'top_tokens': mm_result['top_tokens'],
                            # Tutte le saliency map per la griglia
                            'vanilla_saliency': all_saliency['vanilla'],
                            'smoothgrad_saliency': all_saliency['smoothgrad'],
                            'gradcam_saliency': all_saliency['gradcam'],
                            'integrated_saliency': all_saliency['integrated']
                        }
                        task_analysis['frames'].append(frame_data)
                        
                        # Salva frame originale per video
                        # vis_batch √® in formato (N, H, W, C) con valori uint8 [0, 255]
                        orig_frame = vis_batch[0].copy()
                        if orig_frame.dtype != np.uint8:
                            if orig_frame.max() <= 1.0:
                                orig_frame = (orig_frame * 255).astype(np.uint8)
                            else:
                                orig_frame = np.clip(orig_frame, 0, 255).astype(np.uint8)
                        task_analysis['original_frames'].append(orig_frame)
                        
                        if step % (analysis_interval * 4) == 0:
                            print(f"  Step {step:3d} | Text: {mm_result['text_ratio']:.2%} | "
                                  f"Visual: {mm_result['visual_ratio']:.2%} | "
                                  f"Top: '{mm_result['top_tokens'][0][0] if mm_result['top_tokens'] else 'N/A'}'")
                    
                    except Exception as e:
                        print(f"‚ö†Ô∏è Error at step {step}: {str(e)[:50]}")
                
                # Step environment
                obs, reward, done_batch, info = env.step(full_actions)
                
                for i in alive:
                    if reward[i] != 0.0:
                        successes[i] = True
                    dones[i] = dones[i] or bool(done_batch[i])
            
            success_rate = float(successes.mean())
            task_analysis['success_rate'] = success_rate
            
            # Calcola statistiche aggregate per il task
            if task_analysis['frames']:
                text_ratios = [f['text_ratio'] for f in task_analysis['frames']]
                visual_ratios = [f['visual_ratio'] for f in task_analysis['frames']]
                tv_ratios = [f['text_visual_ratio'] for f in task_analysis['frames']]
                
                task_analysis['stats'] = {
                    'mean_text_ratio': np.mean(text_ratios),
                    'std_text_ratio': np.std(text_ratios),
                    'mean_visual_ratio': np.mean(visual_ratios),
                    'std_visual_ratio': np.std(visual_ratios),
                    'mean_tv_ratio': np.mean(tv_ratios),
                    'std_tv_ratio': np.std(tv_ratios)
                }
            
            all_task_data.append(task_analysis)
            
            print(f"\nüìä Task {task_id} Results:")
            print(f"   Success Rate: {success_rate:.1%}")
            if 'stats' in task_analysis:
                print(f"   Mean Text Ratio: {task_analysis['stats']['mean_text_ratio']:.2%}")
                print(f"   Mean Visual Ratio: {task_analysis['stats']['mean_visual_ratio']:.2%}")
            
            # Genera video explainability con GRIGLIA di tutte le metodologie
            if save_videos and task_analysis['frames'] and task_analysis['original_frames']:
                video_path = str(Path(video_dir) / f"task_{task_id:02d}_grid_explainability.mp4")
                generate_grid_explainability_video(
                    task_analysis['frames'],
                    task_analysis['original_frames'],
                    video_path,
                    fps=5
                )
        
        except Exception as e:
            print(f"‚ùå Error on task {task_id}: {e}")
            import traceback
            traceback.print_exc()
        
        finally:
            env.close()
    
    # =========================================================================
    # ANALISI DI CORRELAZIONE
    # =========================================================================
    
    print(f"\n\n{'='*100}")
    print("RISULTATI ANALISI DI CORRELAZIONE")
    print(f"{'='*100}\n")
    
    # Prepara dati per correlazione
    valid_tasks = [t for t in all_task_data if 'stats' in t]
    
    if len(valid_tasks) < 2:
        print("‚ö†Ô∏è Insufficient data for correlation analysis")
        return {'task_data': all_task_data}
    
    success_rates = [t['success_rate'] for t in valid_tasks]
    text_ratios = [t['stats']['mean_text_ratio'] for t in valid_tasks]
    visual_ratios = [t['stats']['mean_visual_ratio'] for t in valid_tasks]
    tv_ratios = [t['stats']['mean_tv_ratio'] for t in valid_tasks]
    
    # Correlazioni di Pearson
    corr_text_success = np.corrcoef(success_rates, text_ratios)[0, 1]
    corr_visual_success = np.corrcoef(success_rates, visual_ratios)[0, 1]
    corr_tv_success = np.corrcoef(success_rates, tv_ratios)[0, 1]
    
    print("üìä CORRELAZIONE CON SUCCESS RATE:")
    print(f"{'='*60}")
    print(f"   Text Importance vs Success:   r = {corr_text_success:+.4f}")
    print(f"   Visual Importance vs Success: r = {corr_visual_success:+.4f}")
    print(f"   Text/Visual Ratio vs Success: r = {corr_tv_success:+.4f}")
    
    # Interpretazione
    print(f"\nüìà INTERPRETAZIONE:")
    if corr_text_success > 0.3:
        print(f"   ‚úÖ Maggiore attenzione al TESTO ‚Üí predice SUCCESSO (r={corr_text_success:.3f})")
    elif corr_text_success < -0.3:
        print(f"   ‚ö†Ô∏è Maggiore attenzione al TESTO ‚Üí predice FALLIMENTO (r={corr_text_success:.3f})")
    
    if corr_visual_success > 0.3:
        print(f"   ‚úÖ Maggiore attenzione VISIVA ‚Üí predice SUCCESSO (r={corr_visual_success:.3f})")
    elif corr_visual_success < -0.3:
        print(f"   ‚ö†Ô∏è Maggiore attenzione VISIVA ‚Üí predice FALLIMENTO (r={corr_visual_success:.3f})")
    
    # Analisi per gruppo (success vs failure)
    print(f"\n{'='*60}")
    print("üìä CONFRONTO SUCCESSI vs FALLIMENTI:")
    print(f"{'='*60}")
    
    success_tasks = [t for t in valid_tasks if t['success_rate'] > 0.5]
    failure_tasks = [t for t in valid_tasks if t['success_rate'] <= 0.5]
    
    if success_tasks:
        success_text = np.mean([t['stats']['mean_text_ratio'] for t in success_tasks])
        success_visual = np.mean([t['stats']['mean_visual_ratio'] for t in success_tasks])
        print(f"\n   Task RIUSCITI ({len(success_tasks)} tasks):")
        print(f"      Text Ratio medio:   {success_text:.2%}")
        print(f"      Visual Ratio medio: {success_visual:.2%}")
    
    if failure_tasks:
        failure_text = np.mean([t['stats']['mean_text_ratio'] for t in failure_tasks])
        failure_visual = np.mean([t['stats']['mean_visual_ratio'] for t in failure_tasks])
        print(f"\n   Task FALLITI ({len(failure_tasks)} tasks):")
        print(f"      Text Ratio medio:   {failure_text:.2%}")
        print(f"      Visual Ratio medio: {failure_visual:.2%}")
    
    if success_tasks and failure_tasks:
        diff_text = success_text - failure_text
        diff_visual = success_visual - failure_visual
        print(f"\n   DIFFERENZA (Success - Failure):")
        print(f"      Text:   {diff_text:+.2%}")
        print(f"      Visual: {diff_visual:+.2%}")
        
        if abs(diff_text) > abs(diff_visual):
            modality = "TESTUALE"
            direction = "maggiore" if diff_text > 0 else "minore"
        else:
            modality = "VISIVA"
            direction = "maggiore" if diff_visual > 0 else "minore"
        
        print(f"\n   ‚Üí L'attenzione {modality} √® {direction} nei task riusciti")
    
    # Analisi temporale multimodale
    print(f"\n{'='*60}")
    print("üìä EVOLUZIONE TEMPORALE TEXT vs VISUAL:")
    print(f"{'='*60}")
    
    # Dividi frames in fasi
    for task in valid_tasks[:3]:  # Top 3 per esempio
        frames = task['frames']
        if len(frames) < 6:
            continue
        
        n = len(frames)
        early = frames[:n//3]
        mid = frames[n//3:2*n//3]
        late = frames[2*n//3:]
        
        early_text = np.mean([f['text_ratio'] for f in early])
        mid_text = np.mean([f['text_ratio'] for f in mid])
        late_text = np.mean([f['text_ratio'] for f in late])
        
        status = "‚úì" if task['success_rate'] > 0.5 else "‚úó"
        print(f"\n   Task {task['task_id']} ({task['success_rate']:.0%}) {status}:")
        print(f"      Early Text: {early_text:.1%} ‚Üí Mid: {mid_text:.1%} ‚Üí Late: {late_text:.1%}")
        
        trend = "‚Üó Crescente" if late_text > early_text + 0.02 else (
            "‚Üò Decrescente" if late_text < early_text - 0.02 else "‚Üí Stabile"
        )
        print(f"      Pattern: {trend}")
    
    # Salva risultati
    results = {
        'task_data': all_task_data,
        'correlations': {
            'text_success': corr_text_success,
            'visual_success': corr_visual_success,
            'tv_ratio_success': corr_tv_success
        },
        'summary': {
            'total_tasks': len(valid_tasks),
            'mean_success_rate': np.mean(success_rates),
            'mean_text_ratio': np.mean(text_ratios),
            'mean_visual_ratio': np.mean(visual_ratios)
        }
    }
    
    # Salva JSON
    output_path = Path(video_dir) / 'multimodal_analysis.json'
    with open(output_path, 'w') as f:
        # Converti numpy per JSON
        def convert(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, (np.float32, np.float64)):
                return float(obj)
            elif isinstance(obj, (np.int32, np.int64)):
                return int(obj)
            elif isinstance(obj, dict):
                return {k: convert(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert(v) for v in obj]
            return obj
        
        # Rimuovi saliency maps per JSON (troppo grandi)
        json_results = convert({
            'correlations': results['correlations'],
            'summary': results['summary'],
            'per_task': [
                {
                    'task_id': t['task_id'],
                    'prompt': t['prompt'],
                    'success_rate': t['success_rate'],
                    'stats': t.get('stats', {})
                }
                for t in all_task_data
            ]
        })
        json.dump(json_results, f, indent=2)
    
    print(f"\n‚úì Results saved to {output_path}")
    
    return results


In [81]:
# ============================================================================
# ESEGUI ANALISI MULTIMODALE COMPLETA
# ============================================================================

# Esegui l'analisi di correlazione text vs visual
print("üîç Avvio analisi multimodale Text vs Visual Importance...")
print("   Questo generer√†:")
print("   1. Video con heatmap di saliency per ogni task")
print("   2. Analisi correlazione text/visual con success rate")
print("   3. Pattern temporali di attenzione")
print()

multimodal_results = analyze_multimodal_correlation(
    checkpoint_path='models/back.pt',
    action_stats_path='action_stats.json',
    benchmark='libero_spatial',
    task_ids= [8], #list(range(10)),  # Tutti i 10 task
    env_num=1,
    max_steps=500,
    analysis_interval=5,  # Analizza ogni 5 step
    visual_method='gradcam',  # Usa GradCAM per heatmap
    save_videos=True,
    video_dir='evaluation_videos'
)


üîç Avvio analisi multimodale Text vs Visual Importance...
   Questo generer√†:
   1. Video con heatmap di saliency per ogni task
   2. Analisi correlazione text/visual con success rate
   3. Pattern temporali di attenzione


ANALISI MULTIMODALE: CORRELAZIONE TEXT vs VISUAL IMPORTANCE

[info] using task orders [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
‚úÖ MultimodalExplainer initialized

Task 8: pick up the black bowl next to the plate and place it on the...
  Step   0 | Text: 29.63% | Visual: 70.37% | Top: 'plate'
  Step  20 | Text: 60.45% | Visual: 39.55% | Top: 'plate'
  Step  40 | Text: 25.20% | Visual: 74.80% | Top: 'plate'
  Step  60 | Text: 58.09% | Visual: 41.91% | Top: 'plate'
  Step  80 | Text: 62.56% | Visual: 37.44% | Top: 'plate'
  Step 100 | Text: 66.41% | Visual: 33.59% | Top: 'plate'
  Step 120 | Text: 70.75% | Visual: 29.25% | Top: 'plate'
  Step 140 | Text: 68.57% | Visual: 31.43% | Top: 'plate'
  Step 160 | Text: 48.68% | Visual: 51.32% | Top: 'plate'
  Step 180 | Text: 38.57%

In [76]:
# ============================================================================
# VISUALIZZAZIONE RISULTATI E CONCLUSIONI
# ============================================================================

def visualize_multimodal_results(results: Dict[str, Any]):
    """
    Visualizza i risultati dell'analisi multimodale con grafici ASCII.
    """
    if not results or 'task_data' not in results:
        print("‚ö†Ô∏è No results to visualize")
        return
    
    valid_tasks = [t for t in results['task_data'] if 'stats' in t]
    
    if not valid_tasks:
        print("‚ö†Ô∏è No valid task data")
        return
    
    print(f"\n{'='*100}")
    print("VISUALIZZAZIONE RISULTATI ANALISI MULTIMODALE")
    print(f"{'='*100}\n")
    
    # Ordina per success rate
    valid_tasks.sort(key=lambda x: x['success_rate'], reverse=True)
    
    # Grafico a barre ASCII: Text vs Visual per task
    print("üìä IMPORTANZA RELATIVA TEXT vs VISUAL PER TASK:")
    print(f"{'='*80}")
    print(f"{'Task':<8} {'Success':>8} {'Text':<30} {'Visual':<30}")
    print(f"{'-'*80}")
    
    for task in valid_tasks:
        text_ratio = task['stats']['mean_text_ratio']
        visual_ratio = task['stats']['mean_visual_ratio']
        
        text_bar = '‚ñà' * int(text_ratio * 50)
        visual_bar = '‚ñà' * int(visual_ratio * 50)
        
        status = "‚úì" if task['success_rate'] > 0.5 else "‚úó"
        print(f"Task {task['task_id']:2d} {status} {task['success_rate']:>6.0%}  "
              f"{text_bar:<25} {text_ratio:>4.1%} | "
              f"{visual_bar:<25} {visual_ratio:>4.1%}")
    
    # Scatter plot ASCII: Success Rate vs Text Importance
    print(f"\n\n{'='*100}")
    print("SCATTER: SUCCESS RATE vs TEXT IMPORTANCE RATIO")
    print(f"{'='*100}")
    print("Success Rate %")
    print("100 |", end="")
    
    # Crea griglia 20x60
    grid = [[' ' for _ in range(60)] for _ in range(10)]
    
    for task in valid_tasks:
        sr = task['success_rate']
        tr = task['stats']['mean_text_ratio']
        
        row = 9 - int(sr * 9)  # 0-9, invertito
        col = int(tr * 100) - 20  # Assumendo text ratio tra 20% e 80%
        col = max(0, min(59, col))
        
        if task['success_rate'] > 0.5:
            grid[row][col] = '‚óè'
        else:
            grid[row][col] = '‚óã'
    
    for i, row in enumerate(grid):
        sr_label = f"{100 - i * 10:3d}" if i % 2 == 0 else "   "
        print(f"\n{sr_label} |{''.join(row)}", end="")
    
    print(f"\n  0 +{'-'*60}")
    print(f"      20%{' '*15}40%{' '*15}60%{' '*15}80%")
    print(f"                    Text Importance Ratio")
    print(f"      ‚óè = Success (>50%)    ‚óã = Failure (‚â§50%)")
    
    # Conclusioni
    print(f"\n\n{'='*100}")
    print("CONCLUSIONI DELL'ANALISI")
    print(f"{'='*100}\n")
    
    corr = results.get('correlations', {})
    
    print("üî¨ CORRELAZIONI CHIAVE:")
    print(f"   ‚Ä¢ Correlazione Text ‚Üî Success: r = {corr.get('text_success', 0):+.4f}")
    print(f"   ‚Ä¢ Correlazione Visual ‚Üî Success: r = {corr.get('visual_success', 0):+.4f}")
    print(f"   ‚Ä¢ Correlazione Text/Visual Ratio ‚Üî Success: r = {corr.get('tv_ratio_success', 0):+.4f}")
    
    # Interpretazione automatica
    text_corr = corr.get('text_success', 0)
    visual_corr = corr.get('visual_success', 0)
    
    print(f"\nüìà INTERPRETAZIONE:")
    
    if text_corr > 0.3:
        print("   ‚úÖ Il modello ha MIGLIORI performance quando presta MAGGIORE attenzione al TESTO")
        print("      ‚Üí Il language grounding √® cruciale per il successo")
    elif text_corr < -0.3:
        print("   ‚ö†Ô∏è Il modello ha PEGGIORI performance quando presta MAGGIORE attenzione al TESTO")
        print("      ‚Üí Potrebbe indicare over-reliance sul linguaggio vs percezione visiva")
    
    if visual_corr > 0.3:
        print("   ‚úÖ Il modello ha MIGLIORI performance quando presta MAGGIORE attenzione all'INPUT VISIVO")
        print("      ‚Üí La percezione visiva √® fondamentale per l'esecuzione corretta")
    elif visual_corr < -0.3:
        print("   ‚ö†Ô∏è Il modello ha PEGGIORI performance quando presta MAGGIORE attenzione all'INPUT VISIVO")
        print("      ‚Üí Potrebbe indicare difficolt√† nel processing visivo")
    
    # Raccomandazioni
    print(f"\nüí° RACCOMANDAZIONI:")
    
    if abs(text_corr) > abs(visual_corr):
        print("   ‚Üí L'ATTENZIONE TESTUALE √® pi√π correlata al successo")
        print("   ‚Üí Migliorare il language grounding potrebbe essere la priorit√†")
    else:
        print("   ‚Üí L'ATTENZIONE VISIVA √® pi√π correlata al successo")
        print("   ‚Üí Migliorare il visual encoder potrebbe essere la priorit√†")
    
    # Statistiche finali
    summary = results.get('summary', {})
    print(f"\nüìä STATISTICHE FINALI:")
    print(f"   ‚Ä¢ Task analizzati: {summary.get('total_tasks', len(valid_tasks))}")
    print(f"   ‚Ä¢ Success rate medio: {summary.get('mean_success_rate', 0):.1%}")
    print(f"   ‚Ä¢ Text ratio medio: {summary.get('mean_text_ratio', 0):.1%}")
    print(f"   ‚Ä¢ Visual ratio medio: {summary.get('mean_visual_ratio', 0):.1%}")
    
    print(f"\nüìÅ OUTPUT GENERATI:")
    print(f"   ‚Ä¢ Video heatmap: evaluation_videos/task_XX_heatmap.mp4")
    print(f"   ‚Ä¢ Analisi JSON: evaluation_videos/multimodal_analysis.json")


# Visualizza i risultati se disponibili
if 'multimodal_results' in dir() and multimodal_results:
    visualize_multimodal_results(multimodal_results)
else:
    print("‚ö†Ô∏è Esegui prima la cella di analisi multimodale")



VISUALIZZAZIONE RISULTATI ANALISI MULTIMODALE

üìä IMPORTANZA RELATIVA TEXT vs VISUAL PER TASK:
Task      Success Text                           Visual                        
--------------------------------------------------------------------------------
Task  2 ‚úó     0%  ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà 68.4% | ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà           31.6%
Task  8 ‚úó     0%  ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  49.5% | ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà 50.5%


SCATTER: SUCCESS RATE vs TEXT IMPORTANCE RATIO
Success Rate %
100 |
100 |                                                            
    |                                                            
 80 |                                                            
    |                                                            
 60 |                         