# 03 — Train MDN-RNN (Memory Model M)

Train the MDN-RNN to predict the next-step latent vector z given the current z and action.

**Prerequisites:**
1. Train VAE (notebook 02)
2. Run `scripts/encode_rollouts.py` to generate encoded sequences

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# Put the parent of the working directory on sys.path so the project's
# `src` package can be imported below.
_repo_root = Path.cwd().parent
sys.path.insert(0, str(_repo_root))
from src.config import Config
from src.train_rnn import train_rnn

config = Config()
# Switch to Apple's Metal ("mps") backend when PyTorch reports it available;
# otherwise keep whatever device Config() selected.
_mps_ok = torch.backends.mps.is_available()
if _mps_ok:
    config.device = "mps"
print(f"Device: {config.device}")

In [None]:
# Train MDN-RNN on the encoded sequences (produced by
# scripts/encode_rollouts.py, see the prerequisites above).
# The returned model is used by the visualization and dream cells below.
encoded_dir = Path.cwd().parent.joinpath("datasets", "encoded")
model = train_rnn(
    config.rnn,
    encoded_dir,
    device=config.device,
)

In [None]:
# Visualize one-step predictions vs actual next z on the first dataset sequence
from src.rnn_dataset import RNNDataset

dataset = RNNDataset(encoded_dir, sequence_length=config.rnn.sequence_length)
z_input, actions, z_target = dataset[0]
# Add a leading batch dimension of size 1 and move to the model's device.
z_input = z_input.unsqueeze(0).to(config.device)
actions = actions.unsqueeze(0).to(config.device)

model.eval()
with torch.no_grad():
    pi, mu, sigma, _ = model(z_input, actions)

# Prediction = mean of the most probable mixture component.
# mu is indexed as (batch, T, n_components, latent_dim). pi is 4-D as well;
# NOTE(review): its last axis is assumed to be either 1 (weights shared by all
# latent dims) or latent_dim (one mixture per dim) -- confirm against the model.
# Taking the argmax per trailing index and gathering from mu handles both; when
# that axis is 1 this is exactly "one component for every dim". Selecting with
# pi[..., 0] alone would let dim 0's weights decide for every dimension.
best_comp = pi.argmax(dim=2, keepdim=True)  # (1, T, 1, 1 or latent_dim)
best_comp = best_comp.expand(-1, -1, -1, mu.shape[-1])  # (1, T, 1, latent_dim)
pred_z = mu.gather(2, best_comp).squeeze(2)[0].cpu().numpy()  # (T, latent_dim)
actual_z = z_target.cpu().numpy()

# Plot up to the first 8 latent dims; blank unused panels so a smaller
# latent space does not raise IndexError.
n_plot = min(8, actual_z.shape[1])
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
for i, ax in enumerate(axes.flat):
    if i >= n_plot:
        ax.axis("off")
        continue
    ax.plot(actual_z[:, i], label="actual", alpha=0.7)
    ax.plot(pred_z[:, i], label="predicted", alpha=0.7)
    ax.set_title(f"z dim {i}")
    ax.legend(fontsize=8)
plt.suptitle("MDN-RNN: Predicted vs Actual z")
plt.tight_layout()
plt.show()

In [None]:
# Dream: autoregressive generation
# Start from a real z, then let the RNN dream forward on its own samples.
from src.vae import ConvVAE

vae = ConvVAE(latent_dim=config.vae.latent_dim)
vae_ckpt = Path.cwd().parent / config.vae.checkpoint_dir / "vae_final.pt"
vae.load_state_dict(torch.load(vae_ckpt, map_location=config.device, weights_only=True))
vae.to(config.device).eval()

dream_steps = 50  # number of frames to dream
frame_stride = 2  # show every `frame_stride`-th frame in the plot
# Use the action width of the real data loaded in the previous cell
# instead of hard-coding it.
action_dim = actions.shape[-1]

# Seed with real data
z_t = z_input[:, 0:1, :]  # (1, 1, latent_dim)
hidden = model.init_hidden(1, torch.device(config.device))
dream_frames = []

with torch.no_grad():
    for _ in range(dream_steps):
        # Decode current z to an image. Index the batch rather than squeeze()
        # so a single-channel output would keep its channel axis.
        frame = vae.decoder(z_t.squeeze(1))  # (1, C, H, W)
        dream_frames.append(frame[0].permute(1, 2, 0).cpu().numpy())

        # Random Gaussian action (std 0.3), created directly on the device
        a_t = torch.randn(1, 1, action_dim, device=config.device) * 0.3
        pi, mu, sigma, hidden = model(z_t, a_t, hidden)

        # Sample from the most likely component. mu/sigma are indexed as
        # (batch, T, n_components, latent_dim). NOTE(review): pi's last axis is
        # assumed to be 1 (shared weights) or latent_dim (per-dim mixtures) --
        # confirm. Per-index argmax + gather covers both, and equals a single
        # shared component when that axis is 1.
        best = pi.argmax(dim=2, keepdim=True).expand(-1, -1, -1, mu.shape[-1])
        mu_k = mu.gather(2, best).squeeze(2)  # (1, 1, latent_dim)
        sigma_k = sigma.gather(2, best).squeeze(2)  # (1, 1, latent_dim)
        z_t = mu_k + sigma_k * torch.randn_like(sigma_k)

fig, axes = plt.subplots(2, 10, figsize=(20, 4))
for i, ax in enumerate(axes.flat):
    idx = i * frame_stride
    if idx < len(dream_frames):
        ax.imshow(np.clip(dream_frames[idx], 0, 1))
        ax.set_title(f"t={idx}")
    ax.axis("off")
plt.suptitle("MDN-RNN Dream Sequence", fontsize=14)
plt.tight_layout()
plt.show()