<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/WM_DEMO_DEC2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.distributions as dist
import numpy as np
import torch.nn.functional as F  # <--- CRUCIAL IMPORT FOR THE FIX

# --- 1. CONFIGURATION (Hyperparameters) ---
IMAGE_SHAPE = (3, 64, 64)  # (Channels, Height, Width)
ACTION_SIZE = 5            # Number of possible actions (e.g., Forward, Left, Right, etc.)
STOCHASTIC_SIZE = 32       # Dimension of the stochastic latent state (z_t)
DETERMINISTIC_SIZE = 256   # Dimension of the deterministic memory state (h_t)
HIDDEN_SIZE = 400          # Size of hidden layers

# --- 2. CORE UTILITY CLASSES ---

# Represents the World Model's Internal State at any time 't'
class State(nn.Module):
    def __init__(self, h, z):
        super().__init__()
        self.h = h  # Deterministic Memory (h_t)
        self.z = z  # Stochastic State (z_t)

# --- 3. THE WORLD MODEL ARCHITECTURE (Simplified Networks) ---

# The Recurrent Model: Learns the transition of the deterministic memory (h_t)
class RecurrentModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Input to GRUCell is [z_prev, a_prev]. Size: 32 + 5 = 37.
        input_size = STOCHASTIC_SIZE + ACTION_SIZE

        # The hidden_size (h_t) is DETERMINISTIC_SIZE (256).
        self.gru = nn.GRUCell(input_size, DETERMINISTIC_SIZE)

    def forward(self, h_prev, z_prev, a_prev):
        # Concatenate stochastic state and action to form the input features (x)
        x = torch.cat([z_prev, a_prev], dim=-1)

        # Pass x as input and h_prev as the hidden state (hx)
        h_t = self.gru(x, h_prev) # Update deterministic memory
        return h_t

# The Prior Model: Predicts the distribution of the next stochastic state (z_t)
class PriorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(DETERMINISTIC_SIZE, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, 2 * STOCHASTIC_SIZE) # Mean and Std
        )

    def forward(self, h_t):
        mean, std = torch.chunk(self.net(h_t), 2, dim=-1)

        # --- FIX APPLIED HERE: Use F.softplus to operate on the tensor ---
        std_positive = F.softplus(std) + 1e-4

        return dist.Normal(mean, std_positive)

# --- 4. THE WORLD MODEL STEP FUNCTION ---

def world_model_step(model, prev_state: State, action_t: torch.Tensor, observation_t: torch.Tensor = None):
    """
    Performs one step of the World Model (either a 'dream' or a 'real' update).
    """

    # 1. Prediction (Transition/Prior Model)
    h_t = model.recurrent_model(prev_state.h, prev_state.z, action_t)
    prior_dist = model.prior_model(h_t)
    z_t_prior = prior_dist.sample()

    # 2. Representation (Posterior Model) - Requires Observation
    if observation_t is not None:
        # We just use the PRIOR z_t as the final state for this conceptual demo
        z_t_final = z_t_prior
    else:
        # When imagining (no observation), the prior becomes the final state.
        z_t_final = z_t_prior

    new_state = State(h_t, z_t_final)
    return new_state, prior_dist

# --- 5. DEMO EXECUTION (Simulating a Single Agent Step) ---

# Initialize the models
class WorldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.recurrent_model = RecurrentModel()
        self.prior_model = PriorModel()

wm = WorldModel()

# 1. Initialize previous state (s_t-1) and action (a_t-1)
batch_size = 1
h_prev = torch.zeros(batch_size, DETERMINISTIC_SIZE)
z_prev = torch.zeros(batch_size, STOCHASTIC_SIZE)
prev_state = State(h_prev, z_prev)

# Simulate an action taken by the Agent
action_t = torch.tensor([[0.0, 0.0, 1.0, 0.0, 0.0]]) # One-hot encoded action
# Simulate a new, high-dimensional image observation (e.g., from a game screen)
observation_t = torch.randn(batch_size, *IMAGE_SHAPE)

# 2. Run the World Model Step (RSSM Loop)
print("--- World Model Step Simulation ---")

# Step 1: Learn/Update the model based on a real observation
new_real_state, real_prior_dist = world_model_step(wm, prev_state, action_t, observation_t)
print(f"| REPRESENATION (Reality): New h_t shape: {new_real_state.h.shape}")
print(f"| Correction: A loss is calculated here between the prior and the posterior (omitted).")

# Step 2: Imagine a future step (The 'Dream' - no observation)
action_imagine = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0]]) # Imagine a different action
imagined_state, imagined_prior_dist = world_model_step(wm, new_real_state, action_imagine, observation_t=None)
print(f"| PREDICTION (Dream): Imagined h_t shape: {imagined_state.h.shape}")
print(f"| Prediction: The agent can now calculate reward/value for this imagined state.")

--- World Model Step Simulation ---
| REPRESENATION (Reality): New h_t shape: torch.Size([1, 256])
| Correction: A loss is calculated here between the prior and the posterior (omitted).
| PREDICTION (Dream): Imagined h_t shape: torch.Size([1, 256])
| Prediction: The agent can now calculate reward/value for this imagined state.


## ðŸš€ Full Conceptual Code: World Model + Controller

In [6]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
import numpy as np

# --- 1. CONFIGURATION (Hyperparameters) ---
IMAGE_SHAPE = (3, 64, 64)
ACTION_SIZE = 5
STOCHASTIC_SIZE = 32
DETERMINISTIC_SIZE = 256
HIDDEN_SIZE = 400
# Controller Hyperparameters
PLANNING_HORIZON = 15      # How many steps the agent "looks ahead" in its dream
GAMMA = 0.99               # Discount factor for future rewards

# --- 2. CORE UTILITY CLASSES ---

class State(nn.Module):
    """Holds the deterministic and stochastic parts of the latent state."""
    def __init__(self, h, z):
        super().__init__()
        self.h = h
        self.z = z
        # Combined latent feature (used as input to Actor, Critic, and Reward Models)
        self.feature = torch.cat([h, z], dim=-1)

# --- 3. THE WORLD MODEL ARCHITECTURE ---

class RecurrentModel(nn.Module):
    def __init__(self):
        super().__init__()
        input_size = STOCHASTIC_SIZE + ACTION_SIZE
        self.gru = nn.GRUCell(input_size, DETERMINISTIC_SIZE)
    def forward(self, h_prev, z_prev, a_prev):
        x = torch.cat([z_prev, a_prev], dim=-1)
        h_t = self.gru(x, h_prev)
        return h_t

class PriorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(DETERMINISTIC_SIZE, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, 2 * STOCHASTIC_SIZE)
        )
    def forward(self, h_t):
        mean, std = torch.chunk(self.net(h_t), 2, dim=-1)
        std_positive = F.softplus(std) + 1e-4
        return dist.Normal(mean, std_positive)

# --- 4. THE CONTROLLER (ACTOR & CRITIC) ---

class Actor(nn.Module):
    """The Policy: Chooses the action (a_t) based on the current latent state (s_t)."""
    def __init__(self):
        super().__init__()
        input_size = DETERMINISTIC_SIZE + STOCHASTIC_SIZE
        self.net = nn.Sequential(
            nn.Linear(input_size, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, ACTION_SIZE)
        )
    # FIX: Input is expected to be the tensor feature, not the State object.
    def forward(self, feature: torch.Tensor) -> dist.Distribution:
        logits = self.net(feature)
        return dist.OneHotCategorical(logits=logits)

class Critic(nn.Module):
    """The Value Function: Estimates the total future return (V(s_t)) from the state."""
    def __init__(self):
        super().__init__()
        input_size = DETERMINISTIC_SIZE + STOCHASTIC_SIZE
        self.net = nn.Sequential(
            nn.Linear(input_size, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, 1)
        )
    # FIX: Input is expected to be the tensor feature, not the State object.
    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.net(feature)

# --- 5. THE WORLD MODEL STEP FUNCTION ---

def world_model_step(wm, prev_state: State, action_t: torch.Tensor, observation_t: torch.Tensor = None):
    # Transition Model: h_t = f(h_t-1, z_t-1, a_t-1)
    h_t = wm.recurrent_model(prev_state.h, prev_state.z, action_t)

    # Prior Model: z_t ~ p(z_t | h_t)
    prior_dist = wm.prior_model(h_t)
    z_t_final = prior_dist.sample()

    # Create the new state object
    new_state = State(h_t, z_t_final)

    # FIX: Pass the new state's .feature tensor to the RewardModel
    reward_t = wm.reward_model(new_state.feature)

    return new_state, reward_t

# --- 6. THE IMAGINATION LOOP (Where Policy Training Happens) ---

def imagine_trajectories(wm, policy, start_state: State, horizon: int):
    """
    Generates imagined sequences of states and rewards.
    """

    current_state = start_state
    imagined_states = [current_state]
    imagined_rewards = []

    print(f"\n--- IMAGINATION: Starting trajectory of length {horizon} ---")

    for t in range(horizon):

        # 1. Controller (Actor) chooses the best action from the current imagined state
        # FIX: Pass the current state's .feature tensor to the policy (Actor)
        action_dist = policy(current_state.feature)
        action_t = action_dist.sample()

        # 2. World Model (RSSM) predicts the next state and reward based on the action
        next_state, reward_t = world_model_step(wm, current_state, action_t, observation_t=None)

        # 3. Store the results and update the current state
        imagined_states.append(next_state)
        imagined_rewards.append(reward_t)
        current_state = next_state

        print(f"| Step {t+1}: Action chosen (size {action_t.shape[-1]}), Reward predicted (value {reward_t.item():.4f})")

    print(f"--- IMAGINATION COMPLETE. Policy ready for update. ---")

    return imagined_states, imagined_rewards

# --- 7. DEMO EXECUTION ---

# Initialize the full suite of models
class AgentModels(nn.Module):
    def __init__(self):
        super().__init__()
        self.recurrent_model = RecurrentModel()
        self.prior_model = PriorModel()
        self.actor = Actor()
        self.critic = Critic()
        # Dummy reward model needs input size of DETERMINISTIC_SIZE + STOCHASTIC_SIZE
        input_size = DETERMINISTIC_SIZE + STOCHASTIC_SIZE
        self.reward_model = nn.Sequential(
            nn.Linear(input_size, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, 1)
        )
wm = AgentModels()

# 1. Initialize previous state and action
batch_size = 1
h_prev = torch.zeros(batch_size, DETERMINISTIC_SIZE)
z_prev = torch.zeros(batch_size, STOCHASTIC_SIZE)
prev_state = State(h_prev, z_prev)
action_t = torch.tensor([[0.0, 0.0, 1.0, 0.0, 0.0]])
observation_t = torch.randn(batch_size, *IMAGE_SHAPE)

# 2. Step into reality to get a starting state (Representation)
# The output is discarded with _, but the state is updated
new_real_state, _ = world_model_step(wm, prev_state, action_t, observation_t)
print("--- Real World Initialization Complete. Starting Imagination ---")

# 3. Run the Imagination Loop (Prediction and Policy Training)
imagine_trajectories(wm, wm.actor, new_real_state, PLANNING_HORIZON)

--- Real World Initialization Complete. Starting Imagination ---

--- IMAGINATION: Starting trajectory of length 15 ---
| Step 1: Action chosen (size 5), Reward predicted (value -0.0521)
| Step 2: Action chosen (size 5), Reward predicted (value 0.0098)
| Step 3: Action chosen (size 5), Reward predicted (value 0.0109)
| Step 4: Action chosen (size 5), Reward predicted (value -0.0217)
| Step 5: Action chosen (size 5), Reward predicted (value -0.0651)
| Step 6: Action chosen (size 5), Reward predicted (value -0.0492)
| Step 7: Action chosen (size 5), Reward predicted (value 0.0355)
| Step 8: Action chosen (size 5), Reward predicted (value -0.0672)
| Step 9: Action chosen (size 5), Reward predicted (value -0.0698)
| Step 10: Action chosen (size 5), Reward predicted (value -0.0425)
| Step 11: Action chosen (size 5), Reward predicted (value -0.0331)
| Step 12: Action chosen (size 5), Reward predicted (value -0.0131)
| Step 13: Action chosen (size 5), Reward predicted (value -0.0241)
| Step 1

([State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State(),
  State()],
 [tensor([[-0.0521]], grad_fn=<AddmmBackward0>),
  tensor([[0.0098]], grad_fn=<AddmmBackward0>),
  tensor([[0.0109]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0217]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0651]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0492]], grad_fn=<AddmmBackward0>),
  tensor([[0.0355]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0672]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0698]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0425]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0331]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0131]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0241]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0618]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0509]], grad_fn=<AddmmBackward0>)])