# RoboPianist Video Generation

This notebook generates videos from pre-trained RoboPianist models and reports per-episode return and musical metrics (F1, precision, recall). It is particularly useful for rendering videos locally for models trained on HPC clusters, where video recording was disabled.

In [None]:
import glob
import json
import os
import random
import sys
import time
from base64 import b64encode
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import HTML, clear_output, display
from torch.distributions import Normal
from tqdm import tqdm

# Setup paths: switch into the local `robopianist` checkout and make it importable.
# The checkout location is resolved only once per kernel. After the first run the
# cwd already *is* the checkout, so re-deriving it from os.getcwd() on a re-run
# would descend one directory further (into robopianist/robopianist).
# NOTE: relative paths used later (e.g. the video output_dir) therefore resolve
# under the checkout, not under the directory the notebook was started from.
if "robopianist_dir" not in globals():
    base_dir = os.getcwd()
    robopianist_dir = os.path.join(base_dir, "robopianist")
os.chdir(robopianist_dir)
if robopianist_dir not in sys.path:
    sys.path.insert(0, robopianist_dir)

# RoboPianist imports
from robopianist import suite, music
import dm_env_wrappers as wrappers
import robopianist.wrappers as robopianist_wrappers
import dm_env

# Report framework / hardware information and the (post-chdir) working directory.
print(f"PyTorch version: {torch.__version__}")
print("CUDA available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available():
    print("CUDA device: " + torch.cuda.get_device_name())
print("Working directory: " + os.getcwd())


In [None]:
@dataclass(eq=False)
class VideoGenerationArgs:
    """Configuration for rendering videos from trained RoboPianist checkpoints.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``VideoGenerationArgs(seed=0, eval_episodes=1)``,
    and exported with ``dataclasses.asdict``. ``VideoGenerationArgs()`` still
    yields exactly the defaults below, and the defaults remain readable as
    class attributes. ``eq=False`` keeps identity-based equality and hashing.
    """

    # Environment settings (should match training)
    environment_name: str = "RoboPianist-debug-TwinkleTwinkleRousseau-v0"
    seed: int = 42
    control_timestep: float = 0.05
    n_steps_lookahead: int = 10
    trim_silence: bool = True
    gravity_compensation: bool = True
    reduced_action_space: bool = True
    primitive_fingertip_collisions: bool = True
    action_reward_observation: bool = True

    # Video generation settings
    hpc_mode: bool = False  # MUST be False for video generation
    eval_episodes: int = 3  # Number of episodes to record per model
    deterministic: bool = True  # Use deterministic policy for reproducible videos

    # Paths (relative paths resolve against the cwd at the time of use)
    models_dir: str = "/tmp/robopianist"
    output_dir: str = "results/videos"

    # Video settings
    camera_id: str = "piano/back"
    video_height: int = 480
    video_width: int = 640
    record_every: int = 1

    # Device (the default is evaluated once, when the class is defined)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the notebook-wide configuration.
video_args = VideoGenerationArgs()

# Make sure the destination for videos exists (no error if it already does).
output_path = Path(video_args.output_dir)
output_path.mkdir(exist_ok=True, parents=True)
print("Videos will be saved to: " + str(output_path.absolute()))

In [None]:
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers, each followed by a GELU activation.

    Layer widths are ``input_dim -> *hidden_dims -> output_dim``. The Linear
    layers sit at even indices of ``self.network`` (``network.0``,
    ``network.2``, ...), and the checkpoint state-dict keys refer to them by
    those indices. That layout is the same whatever ``activate_final`` is.

    Args:
        input_dim: Size of the last input dimension.
        output_dim: Size of the last output dimension.
        hidden_dims: Hidden-layer widths, in order (any sequence of ints).
        activate_final: Whether a GELU also follows the output layer. Defaults
            to ``True``, the original layout, so existing callers and
            checkpoints behave exactly as before.

    NOTE(review): with ``activate_final=True`` every output is bounded below
    by min(GELU) ~= -0.17. That bound also applies to the policy means /
    log-stds and to the Q-values built on this class. The flag must match the
    layout the checkpoints were trained with, so confirm against the training
    code before switching it off.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Tuple[int, ...],
                 activate_final: bool = True):
        super().__init__()
        # tuple() lets callers pass lists as well as tuples.
        dims = (input_dim,) + tuple(hidden_dims) + (output_dim,)
        last = len(dims) - 2  # index of the output Linear layer
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if activate_final or i != last:
                layers.append(nn.GELU())
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (features on the last dimension)."""
        return self.network(x)

class Policy(nn.Module):
    """Tanh-squashed Gaussian policy.

    The backbone emits ``2 * action_dim`` values per observation. The first
    half are the Gaussian means and the second half are the log standard
    deviations, clamped to ``[-20, 2]``.
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.backbone = MLP(obs_dim, 2 * action_dim, hidden_dims)
        self.action_dim = action_dim

    def forward(self, obs):
        """Return ``(means, log_stds)`` for ``obs``."""
        raw = self.backbone(obs)
        mean, log_std = torch.chunk(raw, 2, dim=-1)
        return mean, log_std.clamp(-20, 2)

    def sample(self, obs, deterministic=False):
        """Return ``(actions, log_probs)``; actions lie in (-1, 1) via tanh.

        In deterministic mode the action is ``tanh(mean)`` and ``log_probs`` is
        ``None``. Otherwise a reparameterised sample is drawn and its log
        probability is corrected for the tanh squashing and summed over the
        action dimension (kept as a trailing size-1 axis).
        """
        mean, log_std = self.forward(obs)
        if deterministic:
            return torch.tanh(mean), None

        gaussian = Normal(mean, log_std.exp())
        pre_tanh = gaussian.rsample()
        squashed = torch.tanh(pre_tanh)
        # Change-of-variables term for tanh; 1e-6 guards against log(0).
        log_prob = gaussian.log_prob(pre_tanh) - torch.log(1 - squashed.pow(2) + 1e-6)
        return squashed, log_prob.sum(dim=-1, keepdim=True)

class QNetwork(nn.Module):
    """State-action value network: concatenates (obs, action) and maps it to one output."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.network = MLP(obs_dim + action_dim, 1, hidden_dims)

    def forward(self, obs, action):
        """Return Q(obs, action) with a trailing output dimension of size 1."""
        obs_action = torch.cat((obs, action), dim=-1)
        return self.network(obs_action)

class TwinQNetwork(nn.Module):
    """Pair of independently parameterised Q-networks (clipped double-Q style)."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dims: Tuple[int, ...]):
        super().__init__()
        self.q1 = QNetwork(obs_dim, action_dim, hidden_dims)
        self.q2 = QNetwork(obs_dim, action_dim, hidden_dims)

    def forward(self, obs, action):
        """Return the tuple ``(q1(obs, action), q2(obs, action))``."""
        first = self.q1(obs, action)
        second = self.q2(obs, action)
        return first, second

class SAC:
    """Soft Actor-Critic container used here for inference only.

    It builds the actor and twin critics, restores them from a checkpoint and
    selects actions. No training/update step is implemented. Optional
    hyperparameters are read from ``args`` with ``getattr``, so a
    video-generation config lacking training fields can be passed directly.
    """

    def __init__(self, obs_dim: int, action_dim: int, args):
        self.device = torch.device(args.device)
        self.discount = getattr(args, 'discount', 0.99)
        self.tau = getattr(args, 'tau', 0.005)
        # Usual SAC target-entropy heuristic: minus the action dimension.
        self.target_entropy = -action_dim

        # Networks (use default hidden dims if not specified).
        # NOTE(review): these must match the training-time architecture, or
        # load_state_dict() will fail with shape mismatches.
        hidden_dims = getattr(args, 'hidden_dims', (256, 256, 256))
        self.actor = Policy(obs_dim, action_dim, hidden_dims).to(self.device)
        self.critic = TwinQNetwork(obs_dim, action_dim, hidden_dims).to(self.device)
        self.target_critic = TwinQNetwork(obs_dim, action_dim, hidden_dims).to(self.device)

        self.actor = self.actor.float()
        self.critic = self.critic.float()
        self.target_critic = self.target_critic.float()

        self.target_critic.load_state_dict(self.critic.state_dict())

        # Entropy temperature, stored in log space.
        init_temp = getattr(args, 'init_temperature', 1.0)
        self.log_alpha = torch.tensor(np.log(init_temp), dtype=torch.float32, requires_grad=True, device=self.device)

    @property
    def alpha(self):
        """Entropy temperature ``exp(log_alpha)`` as a tensor."""
        return self.log_alpha.exp()

    def select_action(self, obs, deterministic=False):
        """Return a float32 numpy action for a single numpy observation.

        A leading batch axis is added for the forward pass and removed from
        the result. Gradients are disabled.
        """
        with torch.no_grad():
            obs_tensor = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0).to(self.device)
            action, _ = self.actor.sample(obs_tensor, deterministic=deterministic)
            return action.cpu().numpy()[0].astype(np.float32)

    def load(self, filepath):
        """Restore network weights and the temperature from a checkpoint.

        Args:
            filepath: File written with ``torch.save`` holding the state dicts
                ``'actor'``, ``'critic'``, ``'target_critic'`` and the value
                ``'log_alpha'`` (a tensor or a plain number).

        Raises:
            KeyError: if one of those entries is missing.
            RuntimeError: if a state dict does not match the network layout.
        """
        checkpoint = torch.load(filepath, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.target_critic.load_state_dict(checkpoint['target_critic'])
        # Normalise log_alpha to a float32 leaf tensor on self.device. It may
        # have been saved as a plain Python float, and `alpha` calls .exp().
        self.log_alpha = (
            torch.as_tensor(checkpoint['log_alpha'], dtype=torch.float32)
            .detach()
            .to(self.device)
            .requires_grad_(True)
        )
        print(f"Model loaded from {filepath}")


In [None]:
def get_env(args: VideoGenerationArgs, record_dir: Optional[Path] = None):
    """Build the wrapped RoboPianist environment used for rollouts and videos.

    Wrappers are applied innermost-first in this order:
    ``PianoSoundVideoWrapper`` (only when ``record_dir`` is given),
    ``EpisodeStatisticsWrapper``, ``MidiEvaluationWrapper``,
    ``ObservationActionRewardWrapper`` (only if
    ``args.action_reward_observation``), ``ConcatObservationWrapper``,
    ``CanonicalSpecWrapper(clip=True)``, ``SinglePrecisionWrapper`` and finally
    ``DmControlWrapper``. NOTE(review): presumably this stack must match the
    one used at training time for the observation sizes to line up.

    Args:
        args: Configuration; task and video fields are forwarded from it.
        record_dir: Directory to record videos into, or ``None`` to disable.

    Returns:
        The fully wrapped environment.
    """
    task_kwargs = {
        "n_steps_lookahead": args.n_steps_lookahead,
        "trim_silence": args.trim_silence,
        "gravity_compensation": args.gravity_compensation,
        "reduced_action_space": args.reduced_action_space,
        "control_timestep": args.control_timestep,
        "primitive_fingertip_collisions": args.primitive_fingertip_collisions,
        "change_color_on_activation": True,
    }
    environment = suite.load(
        environment_name=args.environment_name,
        seed=args.seed,
        task_kwargs=task_kwargs,
    )

    if record_dir is not None:
        environment = robopianist_wrappers.PianoSoundVideoWrapper(
            environment=environment,
            record_dir=record_dir,
            record_every=args.record_every,
            camera_id=args.camera_id,
            height=args.video_height,
            width=args.video_width,
        )

    # Bookkeeping wrappers (deque_size=1).
    environment = wrappers.EpisodeStatisticsWrapper(environment=environment, deque_size=1)
    environment = robopianist_wrappers.MidiEvaluationWrapper(environment=environment, deque_size=1)

    if args.action_reward_observation:
        environment = wrappers.ObservationActionRewardWrapper(environment)

    # Observation/action formatting, applied in this exact order.
    environment = wrappers.ConcatObservationWrapper(environment)
    environment = wrappers.CanonicalSpecWrapper(environment, clip=True)
    environment = wrappers.SinglePrecisionWrapper(environment)
    return wrappers.DmControlWrapper(environment)

# Build a probe environment once to discover the observation/action sizes.
# obs_dim and action_dim are read as globals by generate_videos_for_model.
print("Setting up environment...")
test_env = get_env(video_args)
test_timestep = test_env.reset()

obs_dim, action_dim = test_timestep.observation.shape[0], test_env.action_spec().shape[0]

print("Environment setup complete!")
print(f"Observation dimension: {obs_dim}")
print(f"Action dimension: {action_dim}")

In [None]:
def find_model_files(directory: str, pattern: str = "*.pt") -> List[Path]:
    """Recursively collect files matching ``pattern`` under ``directory``.

    Returns:
        Matching paths at any depth, newest modification time first. If the
        directory does not exist, a message is printed and ``[]`` is returned.
    """
    root = Path(directory)
    if root.exists():
        return sorted(root.rglob(pattern), key=lambda p: p.stat().st_mtime, reverse=True)
    print(f"Directory {directory} does not exist!")
    return []

def extract_model_info(model_path: Path) -> Dict[str, Any]:
    """Parse training metadata out of a checkpoint filename.

    Recognised patterns (checked on the file stem):
      * contains ``final_model``: ``is_final`` is set to ``True``.
      * contains ``model_step_``: the integer after a ``step`` token goes to
        ``step`` and the integer after an ``eval`` token goes to ``eval_num``
        (e.g. ``model_step_10000_eval_3``). Non-integer values are ignored.

    Returns:
        Dict with keys ``path``, ``filename`` (the stem), ``step``,
        ``eval_num`` (``None`` when not found) and ``is_final``.
    """
    stem = model_path.stem
    info: Dict[str, Any] = {
        'path': model_path,
        'filename': stem,
        'step': None,
        'eval_num': None,
        'is_final': False,
    }

    if 'final_model' in stem:
        info['is_final'] = True
        return info
    if 'model_step_' not in stem:
        return info

    # Map marker token -> info key; the value is the token that follows it.
    field_for_token = {'step': 'step', 'eval': 'eval_num'}
    tokens = stem.split('_')
    for token, following in zip(tokens, tokens[1:]):
        field = field_for_token.get(token)
        if field is None:
            continue
        try:
            info[field] = int(following)
        except ValueError:
            pass  # e.g. "step_final": leave the field as None
    return info

def play_video(filename: str, width: int = 640, height: int = 480):
    """Return an inline HTML5 ``<video>`` element embedding an mp4 file.

    The file is embedded as a base64 data URL, so the notebook output is
    self-contained.

    Returns:
        An ``IPython.display.HTML`` object, or ``None`` (after printing a
        message) if ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        print(f"Video file not found: {filename}")
        return None

    # Use a context manager so the file handle is closed promptly rather than
    # being left to the garbage collector.
    with open(filename, "rb") as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

    return HTML(f"""
    <video controls width="{width}" height="{height}">
        <source src="{data_url}" type="video/mp4">
    </video>
    """)

def _to_jsonable(value):
    """Recursively convert NumPy / pathlib values into JSON-serialisable builtins."""
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, np.ndarray):
        return _to_jsonable(value.tolist())
    if isinstance(value, np.bool_):
        return bool(value)
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, Path):
        return str(value)
    return value


def save_evaluation_results(results: Dict, output_file: Path):
    """Write ``results`` to ``output_file`` as indented UTF-8 JSON.

    Values are converted recursively. Nested containers, such as the
    per-episode return lists, can hold NumPy scalars like ``np.float32`` that
    ``json`` cannot encode. NumPy integers become ints, NumPy floats become
    floats, NumPy bools become bools, arrays become lists and paths become
    strings.
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(_to_jsonable(results), f, indent=2)
    print(f"Results saved to: {output_file}")


In [None]:
def _snapshot_mp4s(directory: Path) -> Dict[Path, int]:
    """Map each ``*.mp4`` file directly inside ``directory`` to its mtime (ns)."""
    return {path: path.stat().st_mtime_ns for path in directory.glob("*.mp4")}


def _musical_metrics(env) -> Tuple[float, float, float]:
    """Return ``(f1, precision, recall)`` reported by ``env``, or zeros on failure.

    A failure prints a warning instead of aborting the recording.
    """
    try:
        metrics = env.get_musical_metrics()
        return metrics['f1'], metrics['precision'], metrics['recall']
    except (AttributeError, KeyError, ValueError) as e:
        print(f"Warning: Could not get musical metrics: {e}")
        return 0.0, 0.0, 0.0


def generate_videos_for_model(model_path: Path, args: VideoGenerationArgs,
                              output_subdir: Optional[str] = None) -> Dict[str, Any]:
    """Roll out one checkpoint for ``args.eval_episodes`` episodes and record videos.

    Videos and an ``evaluation_results.json`` summary are written to
    ``<args.output_dir>/<output_subdir>``. ``output_subdir`` defaults to the
    checkpoint's file stem.

    NOTE(review): the agent is sized from the module-level ``obs_dim`` and
    ``action_dim`` globals, so the environment-probe cell must have run first.

    Args:
        model_path: Checkpoint to restore with ``SAC.load``.
        args: Environment / video configuration.
        output_subdir: Optional name of the output sub-directory.

    Returns:
        The summary dict that is also saved as JSON, or ``{}`` if the
        checkpoint could not be loaded. ``video_files`` lists only mp4 files
        created or overwritten during this call, oldest first.
    """
    print(f"\n{'='*60}")
    print(f"Processing model: {model_path.name}")
    print(f"{'='*60}")

    video_output_dir = Path(args.output_dir) / (output_subdir or model_path.stem)
    video_output_dir.mkdir(parents=True, exist_ok=True)

    # Load the checkpoint before building the rendering environment, so a bad
    # checkpoint does not leave an unclosed environment behind.
    agent = SAC(obs_dim, action_dim, args)
    try:
        agent.load(model_path)
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        return {}

    # Remember which videos already exist, so only files written by this run
    # are reported (the directory may be reused across runs or models).
    videos_before = _snapshot_mp4s(video_output_dir)
    env = get_env(args, record_dir=video_output_dir)

    print(f"Generating {args.eval_episodes} videos...")

    eval_returns = []
    eval_f1s = []
    eval_precisions = []
    eval_recalls = []

    try:
        for episode in range(args.eval_episodes):
            print(f"Recording episode {episode + 1}/{args.eval_episodes}...")

            timestep = env.reset()
            eval_return = 0.0
            step_count = 0
            while not timestep.last():
                action = agent.select_action(timestep.observation, deterministic=args.deterministic)
                timestep = env.step(action)
                # float() keeps the running sum a plain Python float (JSON-friendly).
                eval_return += float(timestep.reward)
                step_count += 1
            eval_returns.append(eval_return)

            # NOTE(review): with deque_size=1 in get_env these are presumably
            # the metrics of the episode that just finished.
            f1, precision, recall = _musical_metrics(env)
            eval_f1s.append(f1)
            eval_precisions.append(precision)
            eval_recalls.append(recall)

            print(f"  Episode {episode + 1}: Return = {eval_return:.2f}, F1 = {eval_f1s[-1]:.4f}, Steps = {step_count}")
    finally:
        # Always close the environment, even if a rollout fails.
        # NOTE(review): this presumably also lets the video wrapper flush a
        # recording it writes lazily; confirm against PianoSoundVideoWrapper.
        env.close()

    # Videos created or overwritten during this call, oldest first.
    videos_after = _snapshot_mp4s(video_output_dir)
    video_files = sorted(
        (path for path, mtime in videos_after.items() if videos_before.get(path) != mtime),
        key=lambda path: videos_after[path],
    )

    # Calculate summary statistics
    results = {
        'model_path': str(model_path),
        'model_name': model_path.stem,
        'video_output_dir': str(video_output_dir),
        'eval_episodes': args.eval_episodes,
        'returns': eval_returns,
        'f1_scores': eval_f1s,
        'precision_scores': eval_precisions,
        'recall_scores': eval_recalls,
        'mean_return': np.mean(eval_returns),
        'std_return': np.std(eval_returns),
        'mean_f1': np.mean(eval_f1s),
        'std_f1': np.std(eval_f1s),
        'mean_precision': np.mean(eval_precisions),
        'mean_recall': np.mean(eval_recalls),
        'video_files': [str(v) for v in video_files],
        'total_videos': len(video_files)
    }

    print(f"\nResults for {model_path.name}:")
    print(f"  Mean Return: {results['mean_return']:.2f} ± {results['std_return']:.2f}")
    print(f"  Mean F1: {results['mean_f1']:.4f} ± {results['std_f1']:.4f}")
    print(f"  Videos generated: {results['total_videos']}")
    print(f"  Videos saved to: {video_output_dir}")

    # Save results to JSON
    results_file = video_output_dir / "evaluation_results.json"
    save_evaluation_results(results, results_file)

    return results

def _model_subdir_name(model_path: Path, model_info: Dict[str, Any], index: int) -> str:
    """Return the organised output sub-directory name for one checkpoint.

    ``index`` is the checkpoint's 1-based position in the batch. It is only
    used for checkpoints whose filename could not be parsed.
    """
    if model_info['is_final']:
        return f"final_model_{model_path.parent.name}"
    if model_info['step'] is not None:
        name = f"step_{model_info['step']:06d}"
        # Only add the eval suffix when one was parsed (avoids "..._eval_None").
        if model_info['eval_num'] is not None:
            name += f"_eval_{model_info['eval_num']}"
        return name
    return f"model_{index:03d}_{model_path.stem}"


def batch_generate_videos(models_dir: str, args: VideoGenerationArgs,
                          max_models: Optional[int] = None,
                          model_filter: Optional[str] = None) -> List[Dict]:
    """Generate videos for every checkpoint found under ``models_dir``.

    Args:
        models_dir: Directory searched recursively for ``*.pt`` files.
        args: Environment / video configuration passed to each run.
        max_models: Process at most this many (newest first). ``None`` or
            ``0`` means no limit.
        model_filter: If given, keep only files whose name contains it.

    Returns:
        The non-empty result dicts, each with an added ``model_info`` entry.
        A checkpoint that raises is reported and skipped.
    """
    print(f"Searching for models in: {models_dir}")

    model_files = find_model_files(models_dir)
    if not model_files:
        print("No model files found!")
        return []

    if model_filter:
        model_files = [f for f in model_files if model_filter in f.name]
        print(f"Filtered to {len(model_files)} models matching '{model_filter}'")

    # Truthiness check: None or 0 means "no limit".
    if max_models:
        model_files = model_files[:max_models]
        print(f"Processing first {len(model_files)} models")

    print(f"Found {len(model_files)} model files to process")

    all_results = []
    used_subdirs = set()

    for i, model_path in enumerate(model_files):
        try:
            model_info = extract_model_info(model_path)
            print(f"\nProcessing model {i+1}/{len(model_files)}: {model_path.name}")

            subdir = _model_subdir_name(model_path, model_info, i + 1)
            if subdir in used_subdirs:
                # Identical step/eval checkpoints from different runs would
                # otherwise write into (and overwrite) the same directory.
                subdir = f"{subdir}_{model_path.parent.name}_{i + 1:03d}"
            used_subdirs.add(subdir)

            results = generate_videos_for_model(model_path, args, subdir)
            if results:
                results['model_info'] = model_info
                all_results.append(results)
        except Exception as e:
            # Batch boundary: report and keep going with the next checkpoint.
            print(f"Error processing {model_path}: {e}")

    print(f"\n{'='*60}")
    print("Batch processing complete!")
    print(f"Successfully processed {len(all_results)}/{len(model_files)} models")
    print(f"{'='*60}")

    return all_results


In [None]:
print("Searching for available models...")
available_models = find_model_files(video_args.models_dir)

if not available_models:
    print(f"No model files found in {video_args.models_dir}")
    print("Please check the models_dir path in the configuration above.")
else:
    print(
        f"Found {len(available_models)} model files:",
        "\nAvailable models:",
        "-" * 80,
        sep="\n",
    )

    # Describe at most the 10 newest checkpoints.
    for i, model_path in enumerate(available_models[:10]):
        model_info = extract_model_info(model_path)
        size_mb = model_path.stat().st_size / (1024 * 1024)

        if model_info['is_final']:
            model_type = "Final Model"
        elif model_info['step'] is None:
            model_type = "Unknown"
        else:
            model_type = f"Step {model_info['step']}, Eval {model_info['eval_num']}"

        # One entry per checkpoint, followed by a blank line.
        print(
            f"{i+1:2d}. {model_path.name}",
            f"    Path: {model_path}",
            f"    Type: {model_type}",
            f"    Size: {size_mb:.1f} MB",
            f"    Modified: {time.ctime(model_path.stat().st_mtime)}",
            "",
            sep="\n",
        )

    if len(available_models) > 10:
        print(f"... and {len(available_models) - 10} more models")


In [None]:
# Example 1: Generate videos for a single model (modify the path as needed)
if not available_models:
    print("No models available for video generation.")
else:
    # Default to the newest checkpoint (find_model_files sorts newest first).
    selected_model = available_models[0]
    print(f"Generating videos for: {selected_model.name}")

    single_results = generate_videos_for_model(
        model_path=selected_model,
        args=video_args,
        output_subdir="single_model_demo",
    )

    # Show the first recorded video inline, if any were produced.
    if single_results and single_results['video_files']:
        print("\nDisplaying first video:")
        first_video = single_results['video_files'][0]
        display(play_video(first_video))

# Example 2: Batch process multiple models (uncomment to run)
# 
# # Process evaluation models only (models with "eval" in the name)
# batch_results = batch_generate_videos(
#     models_dir=video_args.models_dir,
#     args=video_args,
#     max_models=5,  # Limit to first 5 models
#     model_filter="eval"  # Only process evaluation models
# )
# 
# # Display summary
# if batch_results:
#     print("\nBatch Processing Summary:")
#     print("=" * 60)
#     for result in batch_results:
#         print(f"Model: {result['model_name']}")
#         print(f"  F1 Score: {result['mean_f1']:.4f} ± {result['std_f1']:.4f}")
#         print(f"  Return: {result['mean_return']:.2f} ± {result['std_return']:.2f}")
#         print(f"  Videos: {result['total_videos']}")
#         print()

# Example 3: Custom model selection
# Uncomment and modify this section to process specific models

# # Define specific models to process
# custom_models = [
#     "final_model.pt",
#     "model_step_10000_eval_1.pt", 
#     "model_step_20000_eval_2.pt",
#     "model_step_30000_eval_3.pt"
# ]
# 
# custom_results = []
# models_dir_path = Path(video_args.models_dir)
# 
# for model_name in custom_models:
#     model_files = list(models_dir_path.rglob(model_name))
#     
#     if model_files:
#         model_path = model_files[0]  # Use first match
#         print(f"Processing: {model_name}")
#         
#         results = generate_videos_for_model(
#             model_path=model_path,
#             args=video_args
#         )
#         
#         if results:
#             custom_results.append(results)
#     else:
#         print(f"Model not found: {model_name}")
# 
# print(f"Processed {len(custom_results)} custom models")