# World Model Training Script

This notebook implements training for the world model on MiniWorld OneRoom environment.

## Training Pipeline:
1. Data collection with simple exploration policy
2. Replay buffer storage
3. World model training with:
   - Reconstruction loss
   - KL divergence loss
   - Reward prediction loss
   - Value prediction loss (n-step returns)


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
import miniworld
from collections import deque, defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
from PIL import Image

# Import world model
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath('')))
from fun import WorldModel

# Set device - M4 Mac should use MPS!
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print(f"Using device: MPS (Apple Silicon GPU)")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: CUDA")
else:
    device = torch.device('cpu')
    print(f"Using device: CPU (SLOW!)")

print(f"Device: {device}")


Using device: MPS (Apple Silicon GPU)
Device: mps


In [None]:
# Configuration
config = {
    'env_name': 'MiniWorld-OneRoom-v0',
    'obs_size': (64, 64),
    'action_dim': 3,  # FIXED: OneRoom has 3 actions (was 4)
    'embedding_dim': 128,
    'hidden_dim': 200,
    'stochastic_dim': 64,  # INCREASED: 32→64 for more capacity (harder to collapse)
    
    # Training hyperparameters
    'batch_size': 16,  # Reduced for MPS memory efficiency
    'seq_length': 10,  # Reduced from 50 for faster training (5x speedup!)
    'learning_rate': 3e-4,  # INCREASED: 1e-4→3e-4 for faster learning
    'num_collection_episodes': 100,  # Episodes to collect before training
    'num_training_steps': 10000,
    'collect_every_n_steps': 50,  # Collect less frequently (was 10)
    
    # Loss weights - VERY AGGRESSIVE anti-collapse settings
    'lambda_rec': 100.0,  # MASSIVELY increased (was 10.0) - reconstruction is KING
    'lambda_kl_start': 0.0,  # Start with zero KL weight
    'lambda_kl_end': 0.01,  # VERY weak KL (was 0.5) - barely regularize
    'kl_anneal_steps': 8000,  # VERY slow annealing (was 2000) - let reconstruction learn first
    'lambda_reward': 1.0,
    'lambda_value': 1.0,
    'free_nats': 8.0,  # MUCH higher free bits (was 3.0) - strong protection
    
    # N-step returns
    'n_step': 5,
    'gamma': 0.99,
    
    # Exploration
    'epsilon': 0.3,  # For epsilon-greedy heuristic policy
}

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")


Configuration:
  env_name: MiniWorld-OneRoom-v0
  obs_size: (64, 64)
  action_dim: 4
  embedding_dim: 128
  hidden_dim: 200
  stochastic_dim: 32
  batch_size: 16
  seq_length: 10
  learning_rate: 0.0001
  num_collection_episodes: 100
  num_training_steps: 10000
  collect_every_n_steps: 50
  lambda_rec: 10.0
  lambda_kl_start: 0.0
  lambda_kl_end: 0.5
  kl_anneal_steps: 2000
  lambda_reward: 1.0
  lambda_value: 1.0
  free_nats: 3.0
  n_step: 5
  gamma: 0.99
  epsilon: 0.3


In [3]:
class ReplayBuffer:
    """
    Replay buffer for storing trajectories.
    Stores: (o_t, a_t, r_t, o_{t+1}, done_t)
    """
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
    
    def add(self, obs, action, reward, next_obs, done):
        """Add a single transition"""
        self.buffer.append({
            'obs': obs,
            'action': action,
            'reward': reward,
            'next_obs': next_obs,
            'done': done,
        })
    
    def add_trajectory(self, trajectory):
        """Add a full trajectory"""
        for transition in trajectory:
            self.add(**transition)
    
    def sample_sequences(self, batch_size, seq_length):
        """
        Sample sequences of length seq_length from the buffer.
        Returns sequences of (obs, action, reward, done)
        """
        if len(self.buffer) < seq_length:
            return None
        
        # Sample random starting indices
        max_start = len(self.buffer) - seq_length
        starts = np.random.randint(0, max_start, size=batch_size)
        
        obs_seq = []
        action_seq = []
        reward_seq = []
        done_seq = []
        
        for start in starts:
            # Extract sequence
            seq = [self.buffer[start + i] for i in range(seq_length)]
            
            obs_seq.append([s['obs'] for s in seq])
            action_seq.append([s['action'] for s in seq])
            reward_seq.append([s['reward'] for s in seq])
            done_seq.append([s['done'] for s in seq])
        
        # Convert to tensors
        # obs: (batch, seq, 3, 64, 64)
        obs_tensor = torch.stack([torch.stack([torch.tensor(o, dtype=torch.float32) for o in obs]) for obs in obs_seq])
        # action: (batch, seq, action_dim) - one-hot
        action_tensor = torch.stack([torch.stack([F.one_hot(torch.tensor(a, dtype=torch.long), config['action_dim']).float() for a in action]) for action in action_seq])
        # reward: (batch, seq)
        reward_tensor = torch.stack([torch.tensor(reward, dtype=torch.float32) for reward in reward_seq])
        # done: (batch, seq)
        done_tensor = torch.stack([torch.tensor(done, dtype=torch.float32) for done in done_seq])
        
        return obs_tensor, action_tensor, reward_tensor, done_tensor
    
    def __len__(self):
        return len(self.buffer)


In [4]:
def preprocess_obs(obs):
    """
    Preprocess observation to (3, 64, 64) tensor.
    MiniWorld returns (H, W, 3) numpy array.
    """
    if isinstance(obs, np.ndarray):
        # Convert to PIL Image if needed
        if obs.dtype != np.uint8:
            obs = (obs * 255).astype(np.uint8)
        img = Image.fromarray(obs)
    else:
        img = obs
    
    # Resize to 64x64
    img = img.resize((64, 64), Image.LANCZOS)
    
    # Convert to numpy and normalize to [0, 1]
    img_array = np.array(img).astype(np.float32) / 255.0
    
    # Convert HWC to CHW: (64, 64, 3) -> (3, 64, 64)
    img_array = np.transpose(img_array, (2, 0, 1))
    
    return img_array


def heuristic_policy(env, epsilon=0.3):
    """
    Simple heuristic policy: prefers moving forward (action 0), 
    occasionally takes random actions.
    """
    if np.random.random() < epsilon:
        return env.action_space.sample()
    else:
        return 0  # move_forward


In [5]:
def collect_trajectory(env, policy_fn, max_steps=500):
    """
    Collect a single trajectory using the exploration policy.
    
    Returns:
        trajectory: list of (obs, action, reward, next_obs, done)
    """
    obs, info = env.reset()
    obs = preprocess_obs(obs)
    
    trajectory = []
    total_reward = 0
    
    for step in range(max_steps):
        # Get action from policy
        action = policy_fn(env, config['epsilon'])
        
        # Take step
        next_obs, reward, terminated, truncated, info = env.step(action)
        next_obs = preprocess_obs(next_obs)
        done = terminated or truncated
        
        # Store transition
        trajectory.append({
            'obs': obs.copy(),
            'action': action,
            'reward': float(reward),
            'next_obs': next_obs.copy(),
            'done': float(done),
        })
        
        total_reward += reward
        obs = next_obs
        
        if done:
            break
    
    return trajectory, total_reward, len(trajectory)


In [6]:
def compute_n_step_returns(rewards, dones, values, gamma=0.99, n_step=5):
    """
    Optimized vectorized n-step returns computation.
    G_t = r_t + γ*r_{t+1} + ... + γ^{n-1}*r_{t+n-1} + γ^n * V_{t+n}
    
    Args:
        rewards: (batch, seq) tensor of rewards
        dones: (batch, seq) tensor of done flags
        values: (batch, seq) tensor of predicted values (for bootstrapping) - should be detached
        gamma: discount factor
        n_step: number of steps for n-step return
    
    Returns:
        returns: (batch, seq) tensor of n-step returns
    """
    batch_size, seq_length = rewards.shape
    device = rewards.device
    values = values.detach()
    
    # Pre-compute discount factors
    discounts = gamma ** torch.arange(n_step + 1, device=device, dtype=torch.float32)
    
    returns = torch.zeros_like(rewards)
    
    # Vectorized computation per timestep
    for t in range(seq_length):
        # Get rewards for next n steps
        end_idx = min(t + n_step, seq_length)
        n_rewards = rewards[:, t:end_idx]  # (B, n_actual)
        n_dones = dones[:, t:end_idx]  # (B, n_actual)
        
        n_actual = n_rewards.shape[1]
        disc = discounts[:n_actual].unsqueeze(0)  # (1, n_actual)
        
        # Mask out rewards after done (cumulative product to stop after first done)
        done_mask = torch.cumprod(1.0 - n_dones, dim=1)  # (B, n_actual)
        masked_rewards = n_rewards * done_mask
        
        # Sum discounted rewards
        reward_sum = (masked_rewards * disc).sum(dim=1)  # (B,)
        
        # Bootstrap value
        if t + n_step < seq_length:
            # Use value at t+n_step if not done
            bootstrap = discounts[n_step] * values[:, t + n_step] * (1.0 - dones[:, t + n_step])
        else:
            # Use last value if sequence ended
            bootstrap = discounts[n_actual] * values[:, -1] * (1.0 - dones[:, -1])
        
        returns[:, t] = reward_sum + bootstrap
    
    return returns


In [7]:
# Initialize environment
env = gym.make(config['env_name'], render_mode='rgb_array')
print(f"Environment: {config['env_name']}")
print(f"Action space: {env.action_space}")
print(f"Observation space: {env.observation_space}")

# Initialize world model
model = WorldModel(
    action_dim=config['action_dim'],
    embedding_dim=config['embedding_dim'],
    hidden_dim=config['hidden_dim'],
    stochastic_dim=config['stochastic_dim'],
    action_space_size=config['action_dim'],
).to(device)

print(f"\nWorld Model initialized:")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Initialize replay buffer
replay_buffer = ReplayBuffer(capacity=10000)

print("\nInitialization complete!")


Falling back to num_samples=4
Falling back to non-multisampled frame buffer
Falling back to num_samples=4
Falling back to non-multisampled frame buffer
Environment: MiniWorld-OneRoom-v0
Action space: Discrete(3)
Observation space: Box(0, 255, (60, 80, 3), uint8)

World Model initialized:
  Total parameters: 10,941,945

Initialization complete!


## 4.1 Data Collection Phase


In [8]:
# Initial data collection
print("Collecting initial trajectories...")
episode_rewards = []
episode_lengths = []

for episode in tqdm(range(config['num_collection_episodes']), desc="Collecting"):
    trajectory, total_reward, traj_length = collect_trajectory(
        env, heuristic_policy, max_steps=500
    )
    replay_buffer.add_trajectory(trajectory)
    episode_rewards.append(total_reward)
    episode_lengths.append(traj_length)

print(f"\nCollected {len(replay_buffer)} transitions")
print(f"Average reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Average length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")


Collecting initial trajectories...


Collecting: 100%|██████████| 100/100 [00:08<00:00, 11.18it/s]


Collected 10000 transitions
Average reward: 0.07 ± 0.24
Average length: 171.25 ± 34.45





## 4.2 Training Loop


In [9]:
# Training loop with KL annealing and free bits
model.train()
losses_history = {
    'total': [],
    'recon': [],
    'reward': [],
    'kl': [],
    'kl_raw': [],  # Track raw KL before free bits
    'value': [],
    'kl_weight': [],  # Track annealing schedule
    'posterior_std': [],  # Track posterior collapse
    'prior_std': [],
}

print("Starting training with KL annealing and free bits...")
print(f"KL weight: {config['lambda_kl_start']:.3f} -> {config['lambda_kl_end']:.3f} over {config['kl_anneal_steps']} steps")
print(f"Free nats: {config['free_nats']:.1f}")

for step in tqdm(range(config['num_training_steps']), desc="Training"):
    # Collect new data periodically
    if step % config['collect_every_n_steps'] == 0 and step > 0:
        trajectory, _, _ = collect_trajectory(env, heuristic_policy, max_steps=500)
        replay_buffer.add_trajectory(trajectory)
    
    # KL annealing schedule
    if step < config['kl_anneal_steps']:
        kl_weight = config['lambda_kl_start'] + (config['lambda_kl_end'] - config['lambda_kl_start']) * (step / config['kl_anneal_steps'])
    else:
        kl_weight = config['lambda_kl_end']
    
    # Sample batch of sequences
    batch = replay_buffer.sample_sequences(
        batch_size=config['batch_size'],
        seq_length=config['seq_length']
    )
    
    if batch is None:
        continue
    
    obs_seq, action_seq, reward_seq, done_seq = batch
    obs_seq = obs_seq.to(device)  # (B, T, 3, 64, 64)
    action_seq = action_seq.to(device)  # (B, T, action_dim)
    reward_seq = reward_seq.to(device)  # (B, T)
    done_seq = done_seq.to(device)  # (B, T)
    
    # Reshape for processing: (B*T, ...)
    B, T = obs_seq.shape[:2]
    obs_flat = obs_seq.view(B * T, 3, 64, 64)
    action_flat = action_seq.view(B * T, config['action_dim'])
    reward_flat = reward_seq.view(B * T)
    
    # Initialize states
    h_prev = None
    z_prev = None
    
    # Forward pass through sequence
    all_outputs = []
    for t in range(T):
        obs_t = obs_seq[:, t]  # (B, 3, 64, 64)
        action_t = action_seq[:, t]  # (B, action_dim)
        
        # Use previous action for RSSM (or zero for first step)
        if t == 0:
            action_prev = torch.zeros(B, config['action_dim'], device=device)
        else:
            action_prev = action_seq[:, t-1]
        
        # Forward pass
        outputs = model(obs_t, action_prev, h_prev, z_prev, use_posterior=True)
        all_outputs.append(outputs)
        
        # Update states for next step
        h_prev = outputs['h_t']
        z_prev = outputs['z_t']
    
    # Stack outputs: (B, T, ...)
    o_hat_seq = torch.stack([o['o_hat_t'] for o in all_outputs], dim=1)  # (B, T, 3, 64, 64)
    r_hat_seq = torch.stack([o['r_hat_t'] for o in all_outputs], dim=1)  # (B, T)
    v_hat_seq = torch.stack([o['v_hat_t'] for o in all_outputs], dim=1)  # (B, T)
    
    # Compute n-step returns for value targets (detach values for bootstrapping)
    with torch.no_grad():
        value_targets = compute_n_step_returns(
            reward_seq, done_seq, v_hat_seq.detach(),
            gamma=config['gamma'],
            n_step=config['n_step']
        )
    
    # Compute losses
    recon_loss = F.mse_loss(o_hat_seq, obs_seq)
    reward_loss = F.mse_loss(r_hat_seq, reward_seq)
    value_loss = F.mse_loss(v_hat_seq, value_targets)
    
    # KL loss with FREE BITS constraint (sum over sequence, mean over batch)
    kl_loss_raw = torch.tensor(0.0, device=device)
    kl_loss = torch.tensor(0.0, device=device)
    posterior_stds = []
    prior_stds = []
    
    for t in range(T):
        if all_outputs[t]['prior_dist'] is not None and all_outputs[t]['post_dist'] is not None:
            # Per-dimension KL divergence
            kl_per_dim = torch.distributions.kl.kl_divergence(
                all_outputs[t]['post_dist'],
                all_outputs[t]['prior_dist']
            )  # (B, stochastic_dim)
            
            # Raw KL (for logging)
            kl_t_raw = kl_per_dim.mean()
            kl_loss_raw += kl_t_raw
            
            # Free bits: max(kl_per_dim, free_nats)
            # This prevents the KL from going below free_nats per dimension
            kl_per_dim_clamped = torch.maximum(kl_per_dim, torch.tensor(config['free_nats'], device=device))
            kl_t = kl_per_dim_clamped.mean()
            kl_loss += kl_t
            
            # Track std for diagnosing posterior collapse
            posterior_stds.append(all_outputs[t]['post_dist'].stddev.mean().item())
            prior_stds.append(all_outputs[t]['prior_dist'].stddev.mean().item())
    
    kl_loss = kl_loss / T
    kl_loss_raw = kl_loss_raw / T
    
    # Total loss
    total_loss = (
        config['lambda_rec'] * recon_loss +
        kl_weight * kl_loss +
        config['lambda_reward'] * reward_loss +
        config['lambda_value'] * value_loss
    )
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    
    # Log losses
    losses_history['total'].append(total_loss.item())
    losses_history['recon'].append(recon_loss.item())
    losses_history['reward'].append(reward_loss.item())
    losses_history['kl'].append(kl_loss.item())
    losses_history['kl_raw'].append(kl_loss_raw.item())
    losses_history['value'].append(value_loss.item())
    losses_history['kl_weight'].append(kl_weight)
    losses_history['posterior_std'].append(np.mean(posterior_stds) if posterior_stds else 0)
    losses_history['prior_std'].append(np.mean(prior_stds) if prior_stds else 0)
    
    # Print progress (less frequently to reduce overhead)
    if (step + 1) % 100 == 0:
        print(f"\nStep {step + 1}/{config['num_training_steps']}")
        print(f"  Total loss: {total_loss.item():.4f}")
        print(f"  Recon: {recon_loss.item():.4f}, Reward: {reward_loss.item():.4f}")
        print(f"  KL (raw/clamped): {kl_loss_raw.item():.4f}/{kl_loss.item():.4f}, KL weight: {kl_weight:.4f}")
        print(f"  Value: {value_loss.item():.4f}")
        print(f"  Post/Prior std: {np.mean(posterior_stds):.4f}/{np.mean(prior_stds):.4f}")

print("\nTraining complete!")


Starting training with KL annealing and free bits...
KL weight: 0.000 -> 0.500 over 2000 steps
Free nats: 3.0


Training:   1%|          | 100/10000 [00:17<27:20,  6.04it/s]


Step 100/10000
  Total loss: 0.1574
  Recon: 0.0083, Reward: 0.0001
  KL (raw/clamped): 0.0242/3.0000, KL weight: 0.0248
  Value: 0.0001
  Post/Prior std: 0.9567/1.0013


Training:   2%|▏         | 200/10000 [00:34<27:22,  5.97it/s]


Step 200/10000
  Total loss: 0.2202
  Recon: 0.0071, Reward: 0.0000
  KL (raw/clamped): 0.4367/3.0000, KL weight: 0.0498
  Value: 0.0001
  Post/Prior std: 0.7474/1.0054


Training:   3%|▎         | 300/10000 [00:52<28:49,  5.61it/s]


Step 300/10000
  Total loss: 0.2882
  Recon: 0.0064, Reward: 0.0000
  KL (raw/clamped): 1.2786/3.0000, KL weight: 0.0747
  Value: 0.0002
  Post/Prior std: 0.3154/1.0255


Training:   4%|▍         | 400/10000 [01:10<28:26,  5.62it/s]


Step 400/10000
  Total loss: 0.3671
  Recon: 0.0068, Reward: 0.0000
  KL (raw/clamped): 1.6696/3.0001, KL weight: 0.0998
  Value: 0.0001
  Post/Prior std: 0.1639/1.0651


Training:   5%|▌         | 500/10000 [01:28<29:01,  5.45it/s]


Step 500/10000
  Total loss: 0.4409
  Recon: 0.0067, Reward: 0.0000
  KL (raw/clamped): 1.8894/3.0000, KL weight: 0.1247
  Value: 0.0001
  Post/Prior std: 0.1264/1.0856


Training:   6%|▌         | 600/10000 [01:47<30:56,  5.06it/s]


Step 600/10000
  Total loss: 0.5037
  Recon: 0.0054, Reward: 0.0000
  KL (raw/clamped): 2.1045/3.0000, KL weight: 0.1497
  Value: 0.0000
  Post/Prior std: 0.1022/1.0815


Training:   7%|▋         | 674/10000 [02:03<28:29,  5.46it/s]


KeyboardInterrupt: 

In [None]:
# Plot training losses with enhanced diagnostics
fig, axes = plt.subplots(3, 3, figsize=(18, 12))

# Row 1: Main losses
axes[0, 0].plot(losses_history['total'])
axes[0, 0].set_title('Total Loss')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True)

axes[0, 1].plot(losses_history['recon'])
axes[0, 1].set_title('Reconstruction Loss')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True)

axes[0, 2].plot(losses_history['reward'])
axes[0, 2].set_title('Reward Prediction Loss')
axes[0, 2].set_xlabel('Step')
axes[0, 2].set_ylabel('Loss')
axes[0, 2].grid(True)

# Row 2: KL diagnostics
axes[1, 0].plot(losses_history['kl_raw'], label='Raw KL', alpha=0.7)
axes[1, 0].plot(losses_history['kl'], label='Clamped KL (free bits)', alpha=0.7)
axes[1, 0].set_title('KL Divergence Loss')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].legend()
axes[1, 0].grid(True)

axes[1, 1].plot(losses_history['kl_weight'])
axes[1, 1].set_title('KL Weight (Annealing Schedule)')
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Weight')
axes[1, 1].grid(True)

axes[1, 2].plot(losses_history['value'])
axes[1, 2].set_title('Value Prediction Loss')
axes[1, 2].set_xlabel('Step')
axes[1, 2].set_ylabel('Loss')
axes[1, 2].grid(True)

# Row 3: Posterior collapse diagnostics
axes[2, 0].plot(losses_history['posterior_std'], label='Posterior', alpha=0.7)
axes[2, 0].plot(losses_history['prior_std'], label='Prior', alpha=0.7)
axes[2, 0].set_title('Latent Distribution Std Devs')
axes[2, 0].set_xlabel('Step')
axes[2, 0].set_ylabel('Std Dev')
axes[2, 0].legend()
axes[2, 0].grid(True)
axes[2, 0].axhline(y=0.1, color='r', linestyle='--', alpha=0.3, label='Collapse threshold')

# Histogram of final losses
axes[2, 1].hist(losses_history['recon'][-1000:], bins=30, alpha=0.7)
axes[2, 1].set_title('Recent Reconstruction Loss Distribution')
axes[2, 1].set_xlabel('Loss')
axes[2, 1].set_ylabel('Frequency')
axes[2, 1].grid(True)

# Histogram of KL
axes[2, 2].hist(losses_history['kl_raw'][-1000:], bins=30, alpha=0.7, label='Raw')
axes[2, 2].hist(losses_history['kl'][-1000:], bins=30, alpha=0.5, label='Clamped')
axes[2, 2].set_title('Recent KL Loss Distribution')
axes[2, 2].set_xlabel('Loss')
axes[2, 2].set_ylabel('Frequency')
axes[2, 2].legend()
axes[2, 2].grid(True)

plt.tight_layout()
plt.show()

# Print summary statistics
print("\n=== Training Summary ===")
print(f"Final reconstruction loss: {losses_history['recon'][-1]:.6f}")
print(f"Final KL (raw/clamped): {losses_history['kl_raw'][-1]:.6f} / {losses_history['kl'][-1]:.6f}")
print(f"Final posterior std: {losses_history['posterior_std'][-1]:.6f}")
print(f"Final prior std: {losses_history['prior_std'][-1]:.6f}")
print(f"\nPosterior collapse check:")
if losses_history['posterior_std'][-1] < 0.1:
    print("  ⚠️ WARNING: Posterior may have collapsed (std < 0.1)")
elif losses_history['posterior_std'][-1] > 0.5:
    print("  ✅ GOOD: Posterior is active (std > 0.5)")
else:
    print("  ⚠️ MARGINAL: Posterior std is low but not collapsed")


In [None]:
# Visualize reconstructions with detailed statistics
model.eval()
with torch.no_grad():
    # Sample a sequence
    batch = replay_buffer.sample_sequences(batch_size=1, seq_length=10)
    if batch is not None:
        obs_seq, action_seq, reward_seq, done_seq = batch
        obs_seq = obs_seq.to(device)
        action_seq = action_seq.to(device)
        
        # Reconstruct
        h_prev = None
        z_prev = None
        reconstructions = []
        
        for t in range(min(10, obs_seq.shape[1])):
            obs_t = obs_seq[:, t]
            if t == 0:
                action_prev = torch.zeros(1, config['action_dim'], device=device)
            else:
                action_prev = action_seq[:, t-1]
            
            outputs = model(obs_t, action_prev, h_prev, z_prev, use_posterior=True)
            reconstructions.append(outputs['o_hat_t'])
            h_prev = outputs['h_t']
            z_prev = outputs['z_t']
        
        # Plot original vs reconstructed
        fig, axes = plt.subplots(2, 10, figsize=(20, 4))
        for t in range(min(10, len(reconstructions))):
            # Original
            orig = obs_seq[0, t].cpu().numpy().transpose(1, 2, 0)
            axes[0, t].imshow(np.clip(orig, 0, 1))
            axes[0, t].set_title(f'Original t={t}')
            axes[0, t].axis('off')
            
            # Reconstructed
            recon = reconstructions[t][0].cpu().numpy().transpose(1, 2, 0)
            axes[1, t].imshow(np.clip(recon, 0, 1))
            axes[1, t].set_title(f'Reconstructed t={t}')
            axes[1, t].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        # Print detailed statistics for the first frame
        print("\n=== Reconstruction Statistics (t=0) ===")
        orig_0 = obs_seq[0, 0].cpu().numpy()
        recon_0 = reconstructions[0][0].cpu().numpy()
        
        print(f"Original image:")
        print(f"  Mean: {orig_0.mean():.4f}, Std: {orig_0.std():.4f}")
        print(f"  Min: {orig_0.min():.4f}, Max: {orig_0.max():.4f}")
        
        print(f"\nReconstructed image:")
        print(f"  Mean: {recon_0.mean():.4f}, Std: {recon_0.std():.4f}")
        print(f"  Min: {recon_0.min():.4f}, Max: {recon_0.max():.4f}")
        
        # Per-channel statistics
        print(f"\nPer-channel (RGB) statistics:")
        for c, color in enumerate(['Red', 'Green', 'Blue']):
            print(f"  {color} - Orig: {orig_0[c].mean():.4f}, Recon: {recon_0[c].mean():.4f}")
        
        # MSE
        mse = np.mean((orig_0 - recon_0) ** 2)
        print(f"\nMSE: {mse:.6f}")
        print(f"RMSE: {np.sqrt(mse):.6f}")
        
        # Check if reconstruction is constant
        if recon_0.std() < 0.01:
            print("\n⚠️ WARNING: Reconstruction has very low variance - likely outputting constant values!")
            print(f"This suggests the decoder is not learning. Check:")
            print(f"  1. Is the reconstruction loss actually backpropagating?")
            print(f"  2. Is the KL loss too strong initially?")
            print(f"  3. Are the latent codes carrying information?")
        else:
            print(f"\n✅ Reconstruction has variance: {recon_0.std():.4f}")


In [None]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config,
    'losses_history': losses_history,
}, 'worldmodel_checkpoint.pth')
print("Model saved to worldmodel_checkpoint.pth")