In [1]:
!apt install freeglut3-dev
!apt-get -y install xvfb
!pip install pyvirtualdisplay miniworld gymnasium

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  freeglut3 libegl-dev libgl-dev libgl1-mesa-dev libgles-dev libgles1
  libglu1-mesa libglu1-mesa-dev libglvnd-core-dev libglvnd-dev libglx-dev
  libopengl-dev libxt-dev
Suggested packages:
  libxt-doc
The following NEW packages will be installed:
  freeglut3 freeglut3-dev libegl-dev libgl-dev libgl1-mesa-dev libgles-dev
  libgles1 libglu1-mesa libglu1-mesa-dev libglvnd-core-dev libglvnd-dev
  libglx-dev libopengl-dev libxt-dev
0 upgraded, 14 newly installed, 0 to remove and 41 not upgraded.
Need to get 1,192 kB of archives.
After this operation, 6,439 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 freeglut3 amd64 2.8.1-6 [74.0 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglx-dev amd64 1.4.0-1 [14.1 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgl-de

In [2]:
from pyvirtualdisplay import Display

display = Display(visible=0, size=(800, 600))
display.start()

import os
# Make sure we are NOT forcing pyglet headless; clear any leftovers
os.environ.pop("PYGLET_HEADLESS", None)
os.environ.pop("MINIWORLD_HEADLESS", None)

In [3]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import defaultdict, deque, OrderedDict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import miniworld
from tqdm import tqdm
import random
from PIL import Image
import sys
import os

# Set device - M4 Mac should use MPS!
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print(f"Using device: MPS (Apple Silicon GPU)")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: CUDA")
else:
    device = torch.device('cpu')
    print(f"Using device: CPU (SLOW!)")

print(f"Device: {device}")

Using device: CUDA
Device: cuda


In [None]:

class Encoder(nn.Module):
    """
    3.1 Encoder: CNN + MLP
    Input: o_t ∈ R^(3×64×64)
    Output: e_t ∈ R^(d_e) (e.g., 128)
    """
    def __init__(self, embedding_dim=128):
        super(Encoder, self).__init__()
        self.embedding_dim = embedding_dim
        
        # CNN: 4 conv layers, stride 2
        # Channels: 32 → 64 → 128 → 256
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        
        # After 4 stride-2 convolutions: 64 → 32 → 16 → 8 → 4
        # So spatial size is 4x4
        self.flatten_dim = 256 * 4 * 4
        
        # MLP: 1024 → embedding_dim
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flatten_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, embedding_dim),
        )
    
    def forward(self, obs):
        """
        Args:
            obs: (B, 3, 64, 64) tensor
        Returns:
            e_t: (B, embedding_dim) tensor
        """
        x = self.cnn(obs)
        e_t = self.mlp(x)
        return e_t


class RSSM(nn.Module):
    """
    3.2 RSSM latent dynamics
    Maintains deterministic state h_t and stochastic state z_t
    """
    def __init__(self, action_dim, embedding_dim=128, hidden_dim=200, stochastic_dim=64):
        super(RSSM, self).__init__()
        self.hidden_dim = hidden_dim
        self.stochastic_dim = stochastic_dim
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        
        # Prior network (prediction)
        # Input: [z_{t-1}, a_{t-1}] → MLP → GRU → h_t
        self.prior_mlp = nn.Sequential(
            nn.Linear(stochastic_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)
        
        # Prior: MLP(h_t) → μ_t^prior, log σ_t^prior
        self.prior_mean = nn.Linear(hidden_dim, stochastic_dim)
        self.prior_std = nn.Linear(hidden_dim, stochastic_dim)
        
        # Posterior network (correction)
        # Input: [h_t, e_t] → MLP → μ_t^post, log σ_t^post
        self.posterior_mlp = nn.Sequential(
            nn.Linear(hidden_dim + embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        self.posterior_mean = nn.Linear(hidden_dim, stochastic_dim)
        self.posterior_std = nn.Linear(hidden_dim, stochastic_dim)
    
    def prior(self, h_prev, z_prev, a_prev):
        """
        Prior (prediction): p(z_t | h_{t-1}, z_{t-1}, a_{t-1})
        
        Args:
            h_prev: (B, hidden_dim) previous deterministic state
            z_prev: (B, stochastic_dim) previous stochastic state
            a_prev: (B, action_dim) previous action
        Returns:
            h_t: (B, hidden_dim) new deterministic state
            z_t_prior: (B, stochastic_dim) sampled prior stochastic state
            prior_dist: Normal distribution for KL loss
        """
        # Concat [z_{t-1}, a_{t-1}]
        x = torch.cat([z_prev, a_prev], dim=-1)
        
        # MLP → GRU → h_t
        x = self.prior_mlp(x)
        h_t = self.gru(x, h_prev)
        
        # MLP(h_t) → μ_t^prior, log σ_t^prior
        mean = self.prior_mean(h_t)
        log_std = self.prior_std(h_t)
        log_std = torch.clamp(log_std, min=-10, max=2)  # Clamp for stability
        std = torch.exp(log_std)
        
        # Sample z_t^prior
        prior_dist = Normal(mean, std)
        z_t_prior = prior_dist.rsample()  # Reparameterization trick
        
        return h_t, z_t_prior, prior_dist
    
    def posterior(self, h_t, e_t):
        """
        Posterior (correction): q(z_t | h_t, e_t)
        
        Args:
            h_t: (B, hidden_dim) deterministic state
            e_t: (B, embedding_dim) encoded observation
        Returns:
            z_t_post: (B, stochastic_dim) sampled posterior stochastic state
            post_dist: Normal distribution for KL loss
        """
        # Concat [h_t, e_t]
        x = torch.cat([h_t, e_t], dim=-1)
        
        # MLP → μ_t^post, log σ_t^post
        x = self.posterior_mlp(x)
        mean = self.posterior_mean(x)
        log_std = self.posterior_std(x)
        log_std = torch.clamp(log_std, min=-10, max=2)  # Clamp for stability
        std = torch.exp(log_std)
        
        # Sample z_t^post
        post_dist = Normal(mean, std)
        z_t_post = post_dist.rsample()  # Reparameterization trick
        
        return z_t_post, post_dist


class Decoder(nn.Module):
    """
    3.3 Decoder: Image reconstruction
    Input: [h_t, z_t]
    Output: o_hat_t (same shape as input: 3×64×64)
    """
    def __init__(self, hidden_dim=200, stochastic_dim=64):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.stochastic_dim = stochastic_dim
        
        # MLP: [h_t, z_t] → flattened features
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + stochastic_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 4 * 4),  # Reshape to 256×4×4
            nn.ReLU(),
        )
        
        # 4 deconv layers
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # Output in [0, 1]
        )
    
    def forward(self, h_t, z_t):
        """
        Args:
            h_t: (B, hidden_dim) deterministic state
            z_t: (B, stochastic_dim) stochastic state
        Returns:
            o_hat_t: (B, 3, 64, 64) reconstructed observation
        """
        # Concat [h_t, z_t]
        x = torch.cat([h_t, z_t], dim=-1)
        
        # MLP → reshape
        x = self.mlp(x)
        x = x.view(-1, 256, 4, 4)
        
        # Deconv layers
        o_hat_t = self.deconv(x)
        
        return o_hat_t


class RewardHead(nn.Module):
    """
    3.4 Reward head
    Input: [h_t, z_t]
    Output: scalar r_hat_t
    """
    def __init__(self, hidden_dim=200, stochastic_dim=64):
        super(RewardHead, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + stochastic_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
    
    def forward(self, h_t, z_t):
        """
        Args:
            h_t: (B, hidden_dim) deterministic state
            z_t: (B, stochastic_dim) stochastic state
        Returns:
            r_hat_t: (B, 1) predicted reward
        """
        x = torch.cat([h_t, z_t], dim=-1)
        r_hat_t = self.mlp(x)
        return r_hat_t.squeeze(-1)  # (B,)


class ValueHead(nn.Module):
    """
    3.5 Value head
    Input: [h_t, z_t]
    Output: scalar v_hat_t
    """
    def __init__(self, hidden_dim=200, stochastic_dim=64):
        super(ValueHead, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + stochastic_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
    
    def forward(self, h_t, z_t):
        """
        Args:
            h_t: (B, hidden_dim) deterministic state
            z_t: (B, stochastic_dim) stochastic state
        Returns:
            v_hat_t: (B, 1) predicted value
        """
        x = torch.cat([h_t, z_t], dim=-1)
        v_hat_t = self.mlp(x)
        return v_hat_t.squeeze(-1)  # (B,)


class PolicyPriorHead(nn.Module):
    """
    3.6 Policy prior head (for MCTS)
    Input: [h_t, z_t]
    Output: policy prior π_θ(a|s) - logits over actions
    """
    def __init__(self, hidden_dim=200, stochastic_dim=64, action_dim=3):
        super(PolicyPriorHead, self).__init__()
        self.action_dim = action_dim
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + stochastic_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )
    
    def forward(self, h_t, z_t):
        """
        Args:
            h_t: (B, hidden_dim) deterministic state
            z_t: (B, stochastic_dim) stochastic state
        Returns:
            logits: (B, action_dim) action logits
            probs: (B, action_dim) action probabilities (softmax)
        """
        x = torch.cat([h_t, z_t], dim=-1)
        logits = self.mlp(x)
        probs = F.softmax(logits, dim=-1)
        return logits, probs


class WorldModel(nn.Module):
    """
    Complete World Model integrating all components:
    - Encoder
    - RSSM (prior & posterior)
    - Decoder
    - Reward head
    - Value head
    - Policy prior head
    """
    def __init__(
        self,
        action_dim=3,
        embedding_dim=128,
        hidden_dim=200,
        stochastic_dim=64,
        action_space_size=3,
    ):
        super(WorldModel, self).__init__()
        
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.stochastic_dim = stochastic_dim
        
        # Components
        self.encoder = Encoder(embedding_dim=embedding_dim)
        self.rssm = RSSM(
            action_dim=action_dim,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            stochastic_dim=stochastic_dim,
        )
        self.decoder = Decoder(
            hidden_dim=hidden_dim,
            stochastic_dim=stochastic_dim,
        )
        self.reward_head = RewardHead(
            hidden_dim=hidden_dim,
            stochastic_dim=stochastic_dim,
        )
        self.value_head = ValueHead(
            hidden_dim=hidden_dim,
            stochastic_dim=stochastic_dim,
        )
        self.policy_prior_head = PolicyPriorHead(
            hidden_dim=hidden_dim,
            stochastic_dim=stochastic_dim,
            action_dim=action_space_size,
        )
    
    def forward(self, obs, action, h_prev=None, z_prev=None, use_posterior=True):
        """
        Forward pass through the world model.
        
        Args:
            obs: (B, 3, 64, 64) current observation
            action: (B, action_dim) previous action (one-hot or embedding)
            h_prev: (B, hidden_dim) previous deterministic state (None for first step)
            z_prev: (B, stochastic_dim) previous stochastic state (None for first step)
            use_posterior: bool, if True use posterior (training), else use prior (imagination)
        
        Returns:
            dict with all outputs
        """
        batch_size = obs.shape[0]
        device = obs.device
        
        # Initialize states if None
        if h_prev is None:
            h_prev = torch.zeros(batch_size, self.hidden_dim, device=device)
        if z_prev is None:
            z_prev = torch.zeros(batch_size, self.stochastic_dim, device=device)
        
        # Encode observation
        e_t = self.encoder(obs)
        
        # RSSM: Prior prediction
        h_t, z_t_prior, prior_dist = self.rssm.prior(h_prev, z_prev, action)
        
        # RSSM: Posterior correction (if training)
        if use_posterior:
            z_t, post_dist = self.rssm.posterior(h_t, e_t)
        else:
            z_t = z_t_prior
            post_dist = None
        
        # Decoder: Image reconstruction
        o_hat_t = self.decoder(h_t, z_t)
        
        # Reward head
        r_hat_t = self.reward_head(h_t, z_t)
        
        # Value head
        v_hat_t = self.value_head(h_t, z_t)
        
        # Policy prior head
        policy_logits, policy_probs = self.policy_prior_head(h_t, z_t)
        
        return {
            'e_t': e_t,
            'h_t': h_t,
            'z_t': z_t,
            'z_t_prior': z_t_prior,
            'prior_dist': prior_dist,
            'post_dist': post_dist,
            'o_hat_t': o_hat_t,
            'r_hat_t': r_hat_t,
            'v_hat_t': v_hat_t,
            'policy_logits': policy_logits,
            'policy_probs': policy_probs,
        }
    
    def compute_loss(self, obs, action, reward, value_targets=None, h_prev=None, z_prev=None, 
                     recon_loss_weight=1.0, reward_loss_weight=1.0, kl_loss_weight=0.1, value_loss_weight=1.0):
        """
        Compute training losses:
        - Reconstruction loss (MSE)
        - Reward prediction loss (MSE)
        - KL divergence loss (prior vs posterior)
        - Value prediction loss (MSE) - optional
        
        Args:
            obs: (B, 3, 64, 64) observation
            action: (B, action_dim) action
            reward: (B,) true reward
            value_targets: (B,) n-step return targets (optional)
            h_prev, z_prev: previous states
            recon_loss_weight: weight for reconstruction loss
            reward_loss_weight: weight for reward loss
            kl_loss_weight: weight for KL loss
            value_loss_weight: weight for value loss
        
        Returns:
            dict with losses
        """
        # Forward pass with posterior (training)
        outputs = self.forward(obs, action, h_prev, z_prev, use_posterior=True)
        
        # Reconstruction loss (MSE)
        recon_loss = F.mse_loss(outputs['o_hat_t'], obs)
        
        # Reward prediction loss (MSE)
        reward_loss = F.mse_loss(outputs['r_hat_t'], reward)
        
        # KL divergence loss (KL(q(z_t|h_t,e_t) || p(z_t|h_{t-1},z_{t-1},a_{t-1})))
        kl_loss = 0.0
        if outputs['prior_dist'] is not None and outputs['post_dist'] is not None:
            kl_loss = torch.distributions.kl.kl_divergence(
                outputs['post_dist'], outputs['prior_dist']
            ).mean()
        
        # Value prediction loss (MSE) - if targets provided
        value_loss = torch.tensor(0.0, device=obs.device)
        if value_targets is not None:
            value_loss = F.mse_loss(outputs['v_hat_t'], value_targets)
        
        # Total loss
        total_loss = (
            recon_loss_weight * recon_loss +
            reward_loss_weight * reward_loss +
            kl_loss_weight * kl_loss +
            value_loss_weight * value_loss
        )
        
        return {
            'total_loss': total_loss,
            'recon_loss': recon_loss,
            'reward_loss': reward_loss,
            'kl_loss': kl_loss,
            'value_loss': value_loss,
        }

In [5]:
# Configuration
config = {
    'env_name': 'MiniWorld-OneRoom-v0',
    'obs_size': (64, 64),
    'action_dim': 3,  # OneRoom has 3 actions
    'embedding_dim': 128,
    'hidden_dim': 200,
    'stochastic_dim': 64,
    
    # Training hyperparameters
    'batch_size': 16,
    'seq_length': 10,
    'learning_rate': 3e-4,
    'num_collection_episodes': 100,
    'num_training_steps': 10000,
    'collect_every_n_steps': 50,
    
    # Loss weights / regularization balance
    'lambda_rec': 10.0,   # Reduce recon dominance so KL can matter
    'lambda_kl_start': 0.0,
    'lambda_kl_end': 0.10,  # Stronger KL at the end
    'kl_anneal_steps': 4000,  # Faster ramp-up
    'lambda_reward': 1.0,
    'lambda_value': 1.0,
    'free_nats': 1.0,  # Small per-dim allowance before KL is penalized
    
    # N-step returns
    'n_step': 5,
    'gamma': 0.99,
    
    # Exploration
    'epsilon': 0.3,  # For epsilon-greedy heuristic policy
}

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")


Configuration:
  env_name: MiniWorld-OneRoom-v0
  obs_size: (64, 64)
  action_dim: 3
  embedding_dim: 128
  hidden_dim: 200
  stochastic_dim: 64
  batch_size: 16
  seq_length: 10
  learning_rate: 0.0003
  num_collection_episodes: 100
  num_training_steps: 10000
  collect_every_n_steps: 50
  lambda_rec: 10.0
  lambda_kl_start: 0.0
  lambda_kl_end: 0.1
  kl_anneal_steps: 4000
  lambda_reward: 1.0
  lambda_value: 1.0
  free_nats: 1.0
  n_step: 5
  gamma: 0.99
  epsilon: 0.3


In [6]:
class ReplayBuffer:
    """
    Replay buffer for storing trajectories.
    Stores: (o_t, a_t, r_t, o_{t+1}, done_t)
    """
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
    
    def add(self, obs, action, reward, next_obs, done):
        """Add a single transition"""
        self.buffer.append({
            'obs': obs,
            'action': action,
            'reward': reward,
            'next_obs': next_obs,
            'done': done,
        })
    
    def add_trajectory(self, trajectory):
        """Add a full trajectory"""
        for transition in trajectory:
            self.add(**transition)
    
    def sample_sequences(self, batch_size, seq_length):
        """
        Sample sequences of length seq_length from the buffer.
        Returns sequences of (obs, action, reward, done)
        """
        if len(self.buffer) < seq_length:
            return None
        
        # Sample random starting indices
        max_start = len(self.buffer) - seq_length
        starts = np.random.randint(0, max_start, size=batch_size)
        
        obs_seq = []
        action_seq = []
        reward_seq = []
        done_seq = []
        
        for start in starts:
            # Extract sequence
            seq = [self.buffer[start + i] for i in range(seq_length)]
            
            obs_seq.append([s['obs'] for s in seq])
            action_seq.append([s['action'] for s in seq])
            reward_seq.append([s['reward'] for s in seq])
            done_seq.append([s['done'] for s in seq])
        
        # Convert to tensors
        # obs: (batch, seq, 3, 64, 64)
        obs_tensor = torch.stack([torch.stack([torch.tensor(o, dtype=torch.float32) for o in obs]) for obs in obs_seq])
        # action: (batch, seq, action_dim) - one-hot
        action_tensor = torch.stack([torch.stack([F.one_hot(torch.tensor(a, dtype=torch.long), config['action_dim']).float() for a in action]) for action in action_seq])
        # reward: (batch, seq)
        reward_tensor = torch.stack([torch.tensor(reward, dtype=torch.float32) for reward in reward_seq])
        # done: (batch, seq)
        done_tensor = torch.stack([torch.tensor(done, dtype=torch.float32) for done in done_seq])
        
        return obs_tensor, action_tensor, reward_tensor, done_tensor
    
    def __len__(self):
        return len(self.buffer)


In [7]:
def preprocess_obs(obs):
    """
    Preprocess observation to (3, 64, 64) tensor.
    MiniWorld returns (H, W, 3) numpy array.
    """
    if isinstance(obs, np.ndarray):
        # Convert to PIL Image if needed
        if obs.dtype != np.uint8:
            obs = (obs * 255).astype(np.uint8)
        img = Image.fromarray(obs)
    else:
        img = obs
    
    # Resize to 64x64
    img = img.resize((64, 64), Image.LANCZOS)
    
    # Convert to numpy and normalize to [0, 1]
    img_array = np.array(img).astype(np.float32) / 255.0
    
    # Convert HWC to CHW: (64, 64, 3) -> (3, 64, 64)
    img_array = np.transpose(img_array, (2, 0, 1))
    
    return img_array


def heuristic_policy(env, epsilon=0.3):
    """
    Simple heuristic policy: prefers moving forward (action 0), 
    occasionally takes random actions.
    """
    if np.random.random() < epsilon:
        return env.action_space.sample()
    else:
        return 0  # move_forward


In [8]:
def collect_trajectory(env, policy_fn, max_steps=500):
    """
    Collect a single trajectory using the exploration policy.
    
    Returns:
        trajectory: list of (obs, action, reward, next_obs, done)
    """
    obs, info = env.reset()
    obs = preprocess_obs(obs)
    
    trajectory = []
    total_reward = 0
    
    for step in range(max_steps):
        # Get action from policy
        action = policy_fn(env, config['epsilon'])
        
        # Take step
        next_obs, reward, terminated, truncated, info = env.step(action)
        next_obs = preprocess_obs(next_obs)
        done = terminated or truncated
        
        # Store transition
        trajectory.append({
            'obs': obs.copy(),
            'action': action,
            'reward': float(reward),
            'next_obs': next_obs.copy(),
            'done': float(done),
        })
        
        total_reward += reward
        obs = next_obs
        
        if done:
            break
    
    return trajectory, total_reward, len(trajectory)


In [9]:
def compute_n_step_returns(rewards, dones, values, gamma=0.99, n_step=5):
    """
    Optimized vectorized n-step returns computation.
    G_t = r_t + γ*r_{t+1} + ... + γ^{n-1}*r_{t+n-1} + γ^n * V_{t+n}
    
    Args:
        rewards: (batch, seq) tensor of rewards
        dones: (batch, seq) tensor of done flags
        values: (batch, seq) tensor of predicted values (for bootstrapping) - should be detached
        gamma: discount factor
        n_step: number of steps for n-step return
    
    Returns:
        returns: (batch, seq) tensor of n-step returns
    """
    batch_size, seq_length = rewards.shape
    device = rewards.device
    values = values.detach()
    
    # Pre-compute discount factors
    discounts = gamma ** torch.arange(n_step + 1, device=device, dtype=torch.float32)
    
    returns = torch.zeros_like(rewards)
    
    # Vectorized computation per timestep
    for t in range(seq_length):
        # Get rewards for next n steps
        end_idx = min(t + n_step, seq_length)
        n_rewards = rewards[:, t:end_idx]  # (B, n_actual)
        n_dones = dones[:, t:end_idx]  # (B, n_actual)
        
        n_actual = n_rewards.shape[1]
        disc = discounts[:n_actual].unsqueeze(0)  # (1, n_actual)
        
        # Mask out rewards after done (cumulative product to stop after first done)
        done_mask = torch.cumprod(1.0 - n_dones, dim=1)  # (B, n_actual)
        masked_rewards = n_rewards * done_mask
        
        # Sum discounted rewards
        reward_sum = (masked_rewards * disc).sum(dim=1)  # (B,)
        
        # Bootstrap value
        if t + n_step < seq_length:
            # Use value at t+n_step if not done
            bootstrap = discounts[n_step] * values[:, t + n_step] * (1.0 - dones[:, t + n_step])
        else:
            # Use last value if sequence ended
            bootstrap = discounts[n_actual] * values[:, -1] * (1.0 - dones[:, -1])
        
        returns[:, t] = reward_sum + bootstrap
    
    return returns


In [10]:
# Initialize environment
env = gym.make(config['env_name'], render_mode='rgb_array')
print(f"Environment: {config['env_name']}")
print(f"Action space: {env.action_space}")
print(f"Observation space: {env.observation_space}")

# Initialize world model
model = WorldModel(
    action_dim=config['action_dim'],
    embedding_dim=config['embedding_dim'],
    hidden_dim=config['hidden_dim'],
    stochastic_dim=config['stochastic_dim'],
    action_space_size=config['action_dim'],
).to(device)

print(f"\nWorld Model initialized:")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Initialize replay buffer
replay_buffer = ReplayBuffer(capacity=10000)

print("\nInitialization complete!")


Falling back to num_samples=4
Falling back to num_samples=4
Environment: MiniWorld-OneRoom-v0
Action space: Discrete(3)
Observation space: Box(0, 255, (60, 80, 3), uint8)

World Model initialized:
  Total parameters: 11,030,960

Initialization complete!


## 4.1 Data Collection Phase


In [11]:
# Initial data collection
print("Collecting initial trajectories...")
episode_rewards = []
episode_lengths = []

for episode in tqdm(range(config['num_collection_episodes']), desc="Collecting"):
    trajectory, total_reward, traj_length = collect_trajectory(
        env, heuristic_policy, max_steps=500
    )
    replay_buffer.add_trajectory(trajectory)
    episode_rewards.append(total_reward)
    episode_lengths.append(traj_length)

print(f"\nCollected {len(replay_buffer)} transitions")
print(f"Average reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Average length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")


Collecting initial trajectories...


Collecting: 100%|██████████| 100/100 [00:26<00:00,  3.75it/s]


Collected 10000 transitions
Average reward: 0.02 ± 0.14
Average length: 176.42 ± 25.06





## 4.2 Training Loop


In [12]:
# Training loop with KL annealing and free bits
model.train()
losses_history = {
    'total': [],
    'recon': [],
    'reward': [],
    'kl': [],
    'kl_raw': [],  # Track raw KL before free bits
    'value': [],
    'kl_weight': [],  # Track annealing schedule
    'posterior_std': [],  # Track posterior collapse
    'prior_std': [],
}

print("Starting training with KL annealing and free bits...")
print(f"KL weight: {config['lambda_kl_start']:.3f} -> {config['lambda_kl_end']:.3f} over {config['kl_anneal_steps']} steps")
print(f"Free nats: {config['free_nats']:.1f}")

for step in tqdm(range(config['num_training_steps']), desc="Training"):
    # Collect new data periodically
    if step % config['collect_every_n_steps'] == 0 and step > 0:
        trajectory, _, _ = collect_trajectory(env, heuristic_policy, max_steps=500)
        replay_buffer.add_trajectory(trajectory)
    
    # KL annealing schedule
    if step < config['kl_anneal_steps']:
        kl_weight = config['lambda_kl_start'] + (config['lambda_kl_end'] - config['lambda_kl_start']) * (step / config['kl_anneal_steps'])
    else:
        kl_weight = config['lambda_kl_end']
    
    # Sample batch of sequences
    batch = replay_buffer.sample_sequences(
        batch_size=config['batch_size'],
        seq_length=config['seq_length']
    )
    
    if batch is None:
        continue
    
    obs_seq, action_seq, reward_seq, done_seq = batch
    obs_seq = obs_seq.to(device)  # (B, T, 3, 64, 64)
    action_seq = action_seq.to(device)  # (B, T, action_dim)
    reward_seq = reward_seq.to(device)  # (B, T)
    done_seq = done_seq.to(device)  # (B, T)
    
    # Reshape for processing: (B*T, ...)
    B, T = obs_seq.shape[:2]
    obs_flat = obs_seq.view(B * T, 3, 64, 64)
    action_flat = action_seq.view(B * T, config['action_dim'])
    reward_flat = reward_seq.view(B * T)
    
    # Initialize states
    h_prev = None
    z_prev = None
    
    # Forward pass through sequence
    all_outputs = []
    for t in range(T):
        obs_t = obs_seq[:, t]  # (B, 3, 64, 64)
        action_t = action_seq[:, t]  # (B, action_dim)
        
        # Use previous action for RSSM (or zero for first step)
        if t == 0:
            action_prev = torch.zeros(B, config['action_dim'], device=device)
        else:
            action_prev = action_seq[:, t-1]
        
        # Forward pass
        outputs = model(obs_t, action_prev, h_prev, z_prev, use_posterior=True)
        all_outputs.append(outputs)
        
        # Update states for next step
        h_prev = outputs['h_t']
        z_prev = outputs['z_t']
    
    # Stack outputs: (B, T, ...)
    o_hat_seq = torch.stack([o['o_hat_t'] for o in all_outputs], dim=1)  # (B, T, 3, 64, 64)
    r_hat_seq = torch.stack([o['r_hat_t'] for o in all_outputs], dim=1)  # (B, T)
    v_hat_seq = torch.stack([o['v_hat_t'] for o in all_outputs], dim=1)  # (B, T)
    
    # Compute n-step returns for value targets (detach values for bootstrapping)
    with torch.no_grad():
        value_targets = compute_n_step_returns(
            reward_seq, done_seq, v_hat_seq.detach(),
            gamma=config['gamma'],
            n_step=config['n_step']
        )
    
    # Compute losses
    recon_loss = F.mse_loss(o_hat_seq, obs_seq)
    reward_loss = F.mse_loss(r_hat_seq, reward_seq)
    value_loss = F.mse_loss(v_hat_seq, value_targets)
    
    # KL loss with FREE BITS constraint (sum over sequence, mean over batch)
    kl_loss_raw = torch.tensor(0.0, device=device)
    kl_loss = torch.tensor(0.0, device=device)
    posterior_stds = []
    prior_stds = []
    
    for t in range(T):
        if all_outputs[t]['prior_dist'] is not None and all_outputs[t]['post_dist'] is not None:
            # Per-dimension KL divergence
            kl_per_dim = torch.distributions.kl.kl_divergence(
                all_outputs[t]['post_dist'],
                all_outputs[t]['prior_dist']
            )  # (B, stochastic_dim)
            
            # Raw KL (for logging)
            kl_t_raw = kl_per_dim.mean()
            kl_loss_raw += kl_t_raw
            
            # Free bits: don't penalize the first `free_nats` nats per dim
            # Standard form: max(kl_per_dim - free_nats, 0)
            free_nats = config['free_nats']
            kl_per_dim_clamped = torch.clamp(kl_per_dim - free_nats, min=0.0)
            kl_t = kl_per_dim_clamped.mean()
            kl_loss += kl_t
            
            # Track std for diagnosing posterior collapse
            posterior_stds.append(all_outputs[t]['post_dist'].stddev.mean().item())
            prior_stds.append(all_outputs[t]['prior_dist'].stddev.mean().item())
    
    kl_loss = kl_loss / T
    kl_loss_raw = kl_loss_raw / T
    
    # Total loss
    total_loss = (
        config['lambda_rec'] * recon_loss +
        kl_weight * kl_loss +
        config['lambda_reward'] * reward_loss +
        config['lambda_value'] * value_loss
    )
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    
    # Log losses
    losses_history['total'].append(total_loss.item())
    losses_history['recon'].append(recon_loss.item())
    losses_history['reward'].append(reward_loss.item())
    losses_history['kl'].append(kl_loss.item())
    losses_history['kl_raw'].append(kl_loss_raw.item())
    losses_history['value'].append(value_loss.item())
    losses_history['kl_weight'].append(kl_weight)
    losses_history['posterior_std'].append(np.mean(posterior_stds) if posterior_stds else 0)
    losses_history['prior_std'].append(np.mean(prior_stds) if prior_stds else 0)
    
    # Print progress (less frequently to reduce overhead)
    if (step + 1) % 1000 == 0:
        print(f"\nStep {step + 1}/{config['num_training_steps']}")
        print(f"  Total loss: {total_loss.item():.4f}")
        print(f"  Recon: {recon_loss.item():.4f}, Reward: {reward_loss.item():.4f}")
        print(f"  KL (raw/clamped): {kl_loss_raw.item():.4f}/{kl_loss.item():.4f}, KL weight: {kl_weight:.4f}")
        print(f"  Value: {value_loss.item():.4f}")
        print(f"  Post/Prior std: {np.mean(posterior_stds):.4f}/{np.mean(prior_stds):.4f}")

print("\nTraining complete!")


Starting training with KL annealing and free bits...
KL weight: 0.000 -> 0.100 over 4000 steps
Free nats: 1.0


Training:  10%|▉         | 999/10000 [01:44<14:00, 10.71it/s]


Step 1000/10000
  Total loss: 0.0578
  Recon: 0.0058, Reward: 0.0000
  KL (raw/clamped): 0.8172/0.0115, KL weight: 0.0250
  Value: 0.0000
  Post/Prior std: 0.2532/0.8229


Training:  20%|██        | 2000/10000 [03:28<15:27,  8.62it/s]


Step 2000/10000
  Total loss: 0.0454
  Recon: 0.0045, Reward: 0.0000
  KL (raw/clamped): 0.7919/0.0038, KL weight: 0.0500
  Value: 0.0000
  Post/Prior std: 0.2871/0.8988


Training:  30%|███       | 3000/10000 [05:13<11:16, 10.35it/s]


Step 3000/10000
  Total loss: 0.0388
  Recon: 0.0039, Reward: 0.0001
  KL (raw/clamped): 0.8078/0.0008, KL weight: 0.0750
  Value: 0.0000
  Post/Prior std: 0.2861/0.9210


Training:  40%|████      | 4000/10000 [06:58<09:29, 10.54it/s]


Step 4000/10000
  Total loss: 0.0538
  Recon: 0.0054, Reward: 0.0000
  KL (raw/clamped): 0.7484/0.0015, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.2963/0.8818


Training:  50%|█████     | 5000/10000 [08:43<07:45, 10.74it/s]


Step 5000/10000
  Total loss: 0.0286
  Recon: 0.0029, Reward: 0.0000
  KL (raw/clamped): 0.7714/0.0004, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.3209/0.9876


Training:  70%|███████   | 7000/10000 [12:14<04:45, 10.49it/s]


Step 7000/10000
  Total loss: 0.0467
  Recon: 0.0047, Reward: 0.0000
  KL (raw/clamped): 0.8111/0.0017, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.3070/0.9819


Training:  80%|████████  | 8000/10000 [13:57<03:44,  8.91it/s]


Step 8000/10000
  Total loss: 0.0364
  Recon: 0.0036, Reward: 0.0000
  KL (raw/clamped): 0.7622/0.0014, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.3883/1.1406


Training:  90%|█████████ | 9000/10000 [15:41<01:34, 10.57it/s]


Step 9000/10000
  Total loss: 0.0327
  Recon: 0.0033, Reward: 0.0000
  KL (raw/clamped): 0.7944/0.0018, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.4092/1.2455


Training: 100%|██████████| 10000/10000 [17:26<00:00,  9.56it/s]


Step 10000/10000
  Total loss: 0.0247
  Recon: 0.0025, Reward: 0.0000
  KL (raw/clamped): 0.7709/0.0020, KL weight: 0.1000
  Value: 0.0000
  Post/Prior std: 0.4420/1.2964

Training complete!





In [None]:
# Plot training losses with enhanced diagnostics
fig, axes = plt.subplots(3, 3, figsize=(18, 12))

# Row 1: Main losses
axes[0, 0].plot(losses_history['total'])
axes[0, 0].set_title('Total Loss')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True)

axes[0, 1].plot(losses_history['recon'])
axes[0, 1].set_title('Reconstruction Loss')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True)

axes[0, 2].plot(losses_history['reward'])
axes[0, 2].set_title('Reward Prediction Loss')
axes[0, 2].set_xlabel('Step')
axes[0, 2].set_ylabel('Loss')
axes[0, 2].grid(True)

# Row 2: KL diagnostics
axes[1, 0].plot(losses_history['kl_raw'], label='Raw KL', alpha=0.7)
axes[1, 0].plot(losses_history['kl'], label='Clamped KL (free bits)', alpha=0.7)
axes[1, 0].set_title('KL Divergence Loss')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].legend()
axes[1, 0].grid(True)

axes[1, 1].plot(losses_history['kl_weight'])
axes[1, 1].set_title('KL Weight (Annealing Schedule)')
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Weight')
axes[1, 1].grid(True)

axes[1, 2].plot(losses_history['value'])
axes[1, 2].set_title('Value Prediction Loss')
axes[1, 2].set_xlabel('Step')
axes[1, 2].set_ylabel('Loss')
axes[1, 2].grid(True)

# Row 3: Posterior collapse diagnostics
axes[2, 0].plot(losses_history['posterior_std'], label='Posterior', alpha=0.7)
axes[2, 0].plot(losses_history['prior_std'], label='Prior', alpha=0.7)
axes[2, 0].set_title('Latent Distribution Std Devs')
axes[2, 0].set_xlabel('Step')
axes[2, 0].set_ylabel('Std Dev')
axes[2, 0].legend()
axes[2, 0].grid(True)
axes[2, 0].axhline(y=0.1, color='r', linestyle='--', alpha=0.3, label='Collapse threshold')

# Histogram of final losses
axes[2, 1].hist(losses_history['recon'][-1000:], bins=30, alpha=0.7)
axes[2, 1].set_title('Recent Reconstruction Loss Distribution')
axes[2, 1].set_xlabel('Loss')
axes[2, 1].set_ylabel('Frequency')
axes[2, 1].grid(True)

# Histogram of KL
axes[2, 2].hist(losses_history['kl_raw'][-1000:], bins=30, alpha=0.7, label='Raw')
axes[2, 2].hist(losses_history['kl'][-1000:], bins=30, alpha=0.5, label='Clamped')
axes[2, 2].set_title('Recent KL Loss Distribution')
axes[2, 2].set_xlabel('Loss')
axes[2, 2].set_ylabel('Frequency')
axes[2, 2].legend()
axes[2, 2].grid(True)

plt.tight_layout()
plt.show()

# Print summary statistics
print("\n=== Training Summary ===")
print(f"Final reconstruction loss: {losses_history['recon'][-1]:.6f}")
print(f"Final KL (raw/clamped): {losses_history['kl_raw'][-1]:.6f} / {losses_history['kl'][-1]:.6f}")
print(f"Final posterior std: {losses_history['posterior_std'][-1]:.6f}")
print(f"Final prior std: {losses_history['prior_std'][-1]:.6f}")
print(f"\nPosterior collapse check:")
if losses_history['posterior_std'][-1] < 0.1:
    print("  ⚠️ WARNING: Posterior may have collapsed (std < 0.1)")
elif losses_history['posterior_std'][-1] > 0.5:
    print("  ✅ GOOD: Posterior is active (std > 0.5)")
else:
    print("  ⚠️ MARGINAL: Posterior std is low but not collapsed")


In [None]:
# Visualize reconstructions with detailed statistics
model.eval()
with torch.no_grad():
    # Sample a sequence
    batch = replay_buffer.sample_sequences(batch_size=1, seq_length=10)
    if batch is not None:
        obs_seq, action_seq, reward_seq, done_seq = batch
        obs_seq = obs_seq.to(device)
        action_seq = action_seq.to(device)
        
        # Reconstruct
        h_prev = None
        z_prev = None
        reconstructions = []
        
        for t in range(min(10, obs_seq.shape[1])):
            obs_t = obs_seq[:, t]
            if t == 0:
                action_prev = torch.zeros(1, config['action_dim'], device=device)
            else:
                action_prev = action_seq[:, t-1]
            
            outputs = model(obs_t, action_prev, h_prev, z_prev, use_posterior=True)
            reconstructions.append(outputs['o_hat_t'])
            h_prev = outputs['h_t']
            z_prev = outputs['z_t']
        
        # Plot original vs reconstructed
        fig, axes = plt.subplots(2, 10, figsize=(20, 4))
        for t in range(min(10, len(reconstructions))):
            # Original
            orig = obs_seq[0, t].cpu().numpy().transpose(1, 2, 0)
            axes[0, t].imshow(np.clip(orig, 0, 1))
            axes[0, t].set_title(f'Original t={t}')
            axes[0, t].axis('off')
            
            # Reconstructed
            recon = reconstructions[t][0].cpu().numpy().transpose(1, 2, 0)
            axes[1, t].imshow(np.clip(recon, 0, 1))
            axes[1, t].set_title(f'Reconstructed t={t}')
            axes[1, t].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        # Print detailed statistics for the first frame
        print("\n=== Reconstruction Statistics (t=0) ===")
        orig_0 = obs_seq[0, 0].cpu().numpy()
        recon_0 = reconstructions[0][0].cpu().numpy()
        
        print(f"Original image:")
        print(f"  Mean: {orig_0.mean():.4f}, Std: {orig_0.std():.4f}")
        print(f"  Min: {orig_0.min():.4f}, Max: {orig_0.max():.4f}")
        
        print(f"\nReconstructed image:")
        print(f"  Mean: {recon_0.mean():.4f}, Std: {recon_0.std():.4f}")
        print(f"  Min: {recon_0.min():.4f}, Max: {recon_0.max():.4f}")
        
        # Per-channel statistics
        print(f"\nPer-channel (RGB) statistics:")
        for c, color in enumerate(['Red', 'Green', 'Blue']):
            print(f"  {color} - Orig: {orig_0[c].mean():.4f}, Recon: {recon_0[c].mean():.4f}")
        
        # MSE
        mse = np.mean((orig_0 - recon_0) ** 2)
        print(f"\nMSE: {mse:.6f}")
        print(f"RMSE: {np.sqrt(mse):.6f}")
        
        # Check if reconstruction is constant
        if recon_0.std() < 0.01:
            print("\n⚠️ WARNING: Reconstruction has very low variance - likely outputting constant values!")
            print(f"This suggests the decoder is not learning. Check:")
            print(f"  1. Is the reconstruction loss actually backpropagating?")
            print(f"  2. Is the KL loss too strong initially?")
            print(f"  3. Are the latent codes carrying information?")
        else:
            print(f"\n✅ Reconstruction has variance: {recon_0.std():.4f}")


In [None]:
# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config,
    'losses_history': losses_history,
}, 'worldmodel_checkpoint.pth')
print("Model saved to worldmodel_checkpoint.pth")


MCTS         