In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss
from einops import rearrange
from safetensors.torch import load_file
from vector_quantize_pytorch import VectorQuantize
from tqdm import tqdm

encoder_state = load_file("encoder.safetensors")
vq_state = load_file("vq.safetensors")
decoder_state = load_file("decoder.safetensors")

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.net(x))


class Encoder(nn.Module):
    """
    (B, 3, 84, 84) → (B, 16, 512)

    - 4 conv layers
    - 2 residual blocks per layer
    - 21x21 feature map partitioned into 4x4 non-overlapping 5x5 windows -> 16 tokens
    """

    def __init__(self, in_channels=3, out_channels=512):
        super().__init__()
        self.patch = 5

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            ResidualBlock(64),
            ResidualBlock(64),
        )  # 84 → 42

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            ResidualBlock(128),
            ResidualBlock(128),
        )  # 42 → 21

        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            ResidualBlock(256),
            ResidualBlock(256),
        )  # 21 → 21

        self.conv4 = nn.Sequential(
            nn.Conv2d(256, out_channels, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            ResidualBlock(out_channels),
            ResidualBlock(out_channels),
        )  # 21 → 21

        self.projection = nn.Linear(
            self.patch * self.patch * out_channels, out_channels
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # (B, 512, 21, 21)
        x = x[:, :, 0:20, 0:20]
        x = rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch,
            p2=self.patch,
        )  # (B, 16, 5*5*512)

        x = self.projection(x)  # (B, 16, 512)

        return x  # (B, 16, 512)


class Decoder(nn.Module):
    """
    (B, 16, 512) → (B, 3, 84, 84)

    - reverse of Encoder
    - 16 tokens -> 4x4 grid of 5x5 patches -> (B, 512, 21, 21)
    - 4 deconv layers
    """

    def __init__(self, in_channels=512, out_channels=3):
        super().__init__()
        self.patch = 5

        self.projection = nn.Linear(
            in_channels, self.patch * self.patch * in_channels
        )

        self.deconv4 = nn.Sequential(
            ResidualBlock(in_channels),
            ResidualBlock(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1),
        )  # 21 → 21

        self.deconv3 = nn.Sequential(
            ResidualBlock(256),
            ResidualBlock(256),
            nn.SiLU(),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
        )  # 21 → 21

        self.deconv2 = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            nn.SiLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        )  # 21 → 42

        self.deconv1 = nn.Sequential(
            ResidualBlock(64),
            ResidualBlock(64),
            nn.SiLU(),
            nn.ConvTranspose2d(
                64, out_channels, kernel_size=4, stride=2, padding=1
            ),
        )  # 42 → 84

    def forward(self, x):
        # x: (B, 16, 512)
        x = self.projection(x)  # (B, 16, 5×5×512)
        x = x.view(
            x.size(0), 4, 4, 512, self.patch, self.patch
        )  # (B, 4, 4, 512, 5, 5)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (B, 512, 4, 5, 4, 5)
        x = x.reshape(x.size(0), 512, 4 * 5, 4 * 5)  # (B, 512, 20, 20)

        # 填充到 21×21（补 1 行 1 列）
        x = F.pad(x, (0, 1, 0, 1))  # (B, 512, 21, 21)

        x = self.deconv4(x)  # (B, 256, 21, 21)
        x = self.deconv3(x)  # (B, 128, 21, 21)
        x = self.deconv2(x)  # (B, 64, 42, 42)
        x = self.deconv1(x)  # (B, 3, 84, 84)

        return x

In [3]:
class MiniWorldModel(nn.Module):
    def __init__(self, num_actions=18):
        super().__init__()
        self.obs_embed = nn.Embedding(512, 256)  # 512 tokens, 256 dim
        self.action_embed = nn.Embedding(
            num_actions, 256
        )  # num_actions tokens, 256 dim

        self.pos = nn.Parameter(
            torch.randn(1, 1024, 256)
        )  # positional encoding for 64 tokens

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256,
            nhead=4,
            dim_feedforward=1024,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)

        self.obs_head = nn.Sequential(
            nn.LayerNorm(256),
            nn.Linear(256, 512),  # output obs embedding
        )

        self.reward_head = nn.Sequential(
            nn.LayerNorm(256),
            nn.Linear(256, 1),
        )
        nn.init.zeros_(self.reward_head[-1].weight)
        nn.init.zeros_(self.reward_head[-1].bias)

    def rollout(self, obs_tokens, action_tokens, actor_critic):
        """
        obs_tokens: token ids from VQ-VAE -> (B, 16)
        action_tokens: token ids for actions -> (B,)
        actor_critic: ActorCritic model to compute next state and reward
        returns: next_obs_logits, reward
        """
        pass

    def forward(self, obs_tokens, action_tokens):
        """
        obs_tokens: (B, T, K)  ← each obs is K tokens
        action_tokens: (B, T)  ← each timestep has 1 discrete action
        Returns:
            pred_obs_logits: (B, T, K, vocab_size)
            pred_rewards:    (B, T)
        """
        B, T, K = obs_tokens.shape

        # flatten for embedding
        z = obs_tokens.reshape(B, T * K)  # (B, T*K)
        a = action_tokens  # (B, T)

        # embed tokens
        z_embed = self.obs_embed(z)  # (B, T*K, 256)
        a_embed = self.action_embed(a).unsqueeze(2)  # (B, T, 1, 256)

        # interleave
        tokens = torch.cat(
            [z_embed.view(B, T, K, -1), a_embed], dim=2
        )  # (B, T, K+1, 256)
        tokens = tokens.reshape(B, T * (K + 1), 256)  # (B, T*(K+1), 256)

        # add positional encoding
        pos_embed = self.pos[:, : tokens.size(1), :]  # (1, T*(K+1), 256)
        x = tokens + pos_embed  # (B, T*(K+1), 256)

        # causal mask (L, L)
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1)).to(
            x.device
        )  # (T*(K+1), T*(K+1))

        # decode
        out = self.transformer(x, mask, is_causal=True)  # (B, T*(K+1), 256)

        out = out.reshape(B, T, K + 1, 256)  # (B, T, K+1, 256)

        obs_logits = self.obs_head(out[:, :, :-1, :])  # (B, T, K, 512)
        rewards = self.reward_head(out[:, :, -1, :]).squeeze(-1)  # (B, T)

        return obs_logits, rewards

In [4]:
class ActorCritic(nn.Module):
    def __init__(self, num_actions=18):
        super().__init__()
        self.embed = nn.Embedding(512, 512)
        self.pos = nn.Parameter(torch.randn(1, 16, 512))

        self.blocks = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=512,
                nhead=8,
                activation="gelu",
                dim_feedforward=1024,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=6,
        )

        self.actor_head = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, num_actions),  # num_actions
        )

        self.critic_head = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """
        x: token ids from VQ-VAE -> (B, 16)
        returns:
        - action logits: (B, 18)
        - value: (B, 1)
        """
        x = self.embed(x) + self.pos[:, : x.size(1), :]
        x = self.blocks(x)
        x = x.mean(dim=1)  # (B, 512)

        action_logits = self.actor_head(x)
        value = self.critic_head(x)

        return action_logits, value

In [5]:
from torch.utils.data import Dataset, DataLoader
import zarr


class TrajectoryDataset(Dataset):
    def __init__(self, game, horizon=8):
        self.root = zarr.open_group("dataset200k.zarr", mode="r")
        self.frames = self.root[game]["frames"][:]  # (N, 4, 3, 84, 84)
        self.actions = self.root[game]["actions"][:]  # (N,)
        self.rewards = self.root[game]["rewards"][:]  # (N,)
        self.dones = self.root[game]["dones"][:]  # (N,)
        self.horizon = horizon

    def __len__(self):
        return len(self.frames) - self.horizon

    def __getitem__(self, idx):
        # check if any done in the horizon (to avoid crossing episode boundary)
        if self.dones[idx : idx + self.horizon].any():
            # skip invalid episode (you can also loop until valid one)
            return self.__getitem__((idx + self.horizon) % len(self))

        frame_seq = self.frames[idx : idx + self.horizon, -1]  # (H, 3, 84, 84)
        action_seq = self.actions[idx : idx + self.horizon]  # (H,)
        reward_seq = self.rewards[idx : idx + self.horizon]  # (H,)
        reward_seq = (reward_seq - reward_seq.mean()) / (
            reward_seq.std() + 1e-6
        )  # normalize rewards

        return (
            torch.from_numpy(frame_seq).float().div_(255),  # (H, 3, 84, 84)
            torch.from_numpy(action_seq),  # (H,)
            torch.from_numpy(reward_seq).float(),  # (H,)
        )

In [6]:
encoder = Encoder().to("cuda")
quantizer = VectorQuantize(
    dim=512,
    codebook_size=512,  # each table smaller
    decay=0.8,
    commitment_weight=0.1,
).to("cuda")

encoder.load_state_dict(encoder_state)
quantizer.load_state_dict(vq_state)
world_model = MiniWorldModel(num_actions=6).to("cuda")

world_model.train()
encoder.eval()
quantizer.eval()

optimizer = torch.optim.Adam(world_model.parameters(), lr=1e-4)



In [7]:
import numpy as np
from torch.utils.data import (
    Subset,
    DataLoader,
    RandomSampler,
    SequentialSampler,
)

dataset = TrajectoryDataset("SpaceInvaders", horizon=8)
N = len(dataset)  # 起点的总个数，≈ 帧数 - horizon
h = dataset.horizon
dones = dataset.dones  # (N_total,)

# 1) 找到每个episode结束的位置
episode_ends = np.where(dones)[0]  # 包含终止帧本身
episode_starts = np.insert(episode_ends[:-1] + 1, 0, 0)

# 2) 为每条 episode 生成“合法起点”区间 [start, end - h]
episode_start_idx = []
for s, e in zip(episode_starts, episode_ends):
    valid = np.arange(s, max(s, e - h + 1))  # 可能为空，但不会跨边界
    if len(valid):
        episode_start_idx.append(valid)

# 3) 按 episode 随机打乱后 9:1 切分
rng = np.random.default_rng(seed=42)
perm = rng.permutation(len(episode_start_idx))
split = int(len(perm) * 0.9)
train_ep_ids = perm[:split]
val_ep_ids = perm[split:]

train_indices = np.concatenate([episode_start_idx[i] for i in train_ep_ids])
val_indices = np.concatenate([episode_start_idx[i] for i in val_ep_ids])

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)

train_loader = DataLoader(
    train_subset,
    batch_size=256,
    sampler=RandomSampler(train_subset, replacement=True),  # 打乱
    num_workers=8,
    prefetch_factor=2,
)

val_loader = DataLoader(
    val_subset,
    batch_size=256,
    sampler=SequentialSampler(val_subset),  # 不打乱，便于复现
    num_workers=8,
    prefetch_factor=2,
)

In [None]:
import wandb

run = wandb.init(
    project="pretrain-world-model",
    name="SpaceInvaders-200k",
    config={
        "batch_size": 256,
        "epochs": 10,
        "learning_rate": 1e-4,
        "horizon": 8,
    },
)

# wandb.watch(world_model, log="gradients", log_freq=100)

global_step = 0
for epoch in range(10):
    bar = tqdm(train_loader, leave=True, desc=f"Epoch {epoch + 1:02d}")
    world_model.train()
    for frames, actions, rewards in bar:
        global_step += 1

        B, H, C, Ht, Wt = frames.shape  # (B, H, 3, 84, 84)

        frames = frames.view(-1, 3, 84, 84).to(
            "cuda", dtype=torch.float32, non_blocking=True
        )  # (B*H, 3, 84, 84)
        actions = actions.to(
            "cuda", dtype=torch.long, non_blocking=True
        )  # (B, H)
        rewards = rewards.to(
            "cuda", dtype=torch.float32, non_blocking=True
        )  # (B, H)

        with torch.no_grad():
            z_e = encoder(frames)
            _, indices, _ = quantizer(z_e)
            obs_tokens = indices.view(B, H, 16)

        pred_obs_logits, pred_rewards = world_model(
            obs_tokens, actions
        )  # (B, H, 16, 512), (B, H)

        # get logits ready for loss computation
        pred_obs_logits = pred_obs_logits.permute(
            0, 1, 3, 2
        )  # (B, H, 512, 16)

        obs_loss = F.cross_entropy(
            pred_obs_logits.reshape(-1, 512),
            obs_tokens.reshape(-1),
            reduction="mean",
        )

        reward_target = (rewards.abs() > 1e-6).float()
        reward_loss = sigmoid_focal_loss(
            pred_rewards,
            reward_target,
            reduction="mean",
            alpha=0.8,
            gamma=4.5,
        )

        # total loss
        loss = obs_loss + reward_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log to wandb
        wandb.log(
            {
                "train/obs_loss": obs_loss.item(),
                "train/reward_loss": reward_loss.item(),
                "train/total_loss": loss.item(),
            },
            step=global_step,
        )

        bar.set_postfix(
            obs_loss=obs_loss.item(),
            reward_loss=reward_loss.item(),
            loss=loss.item(),
        )

    # Validation loop
    val_bar = tqdm(val_loader, leave=True, desc=f"Validation {epoch + 1:02d}")
    world_model.eval()
    obs_loss_sum = 0.0  # 加权和（按 token 计）
    reward_loss_sum = 0.0
    token_cnt = 0  # obs token 总数 (= B*T*K)

    tp = pos = fp = 0  # 召回 / 精度用

    with torch.no_grad():
        for frames, actions, rewards in val_bar:
            B, H, C, Ht, Wt = frames.shape

            # ---- 同你原来的前向部分 -----------------------------------------
            frames = frames.view(-1, 3, 84, 84).to(
                "cuda", dtype=torch.float32, non_blocking=True
            )
            actions = actions.to("cuda", dtype=torch.long, non_blocking=True)
            rewards = rewards.to(
                "cuda", dtype=torch.float32, non_blocking=True
            )

            z_e = encoder(frames)
            _, indices, _ = quantizer(z_e)
            obs_tokens = indices.view(B, H, 16)

            pred_obs_logits, pred_rewards = world_model(obs_tokens, actions)
            pred_obs_logits = pred_obs_logits.permute(0, 1, 3, 2)

            # ---- 损失 --------------------------------------------------------
            obs_loss = F.cross_entropy(
                pred_obs_logits.reshape(-1, 512),
                obs_tokens.reshape(-1),
                reduction="mean",
            )
            reward_target = (rewards.abs() > 1e-6).float()
            reward_loss = sigmoid_focal_loss(
                pred_rewards,
                reward_target,
                reduction="mean",
                alpha=0.8,
                gamma=4.5,
            )

            # ---- 加权累加（注意权重是样本/token 数） --------------------------
            batch_tokens = obs_tokens.numel()  # B*H*16
            obs_loss_sum += obs_loss.item() * batch_tokens
            reward_loss_sum += reward_loss.item() * batch_tokens
            token_cnt += batch_tokens

            # ---- 累加召回 / 精度 --------------------------------------------
            prob = torch.sigmoid(pred_rewards)
            pred = prob > 0.5

            tp += (pred & reward_target.bool()).sum().item()
            pos += reward_target.sum().item()
            fp += (pred & (~reward_target.bool())).sum().item()

    # ─── 2. 循环结束后统一计算 epoch 级指标 ─────────────────────────────────────
    obs_loss_epoch = obs_loss_sum / token_cnt
    reward_loss_epoch = reward_loss_sum / token_cnt
    total_loss_epoch = obs_loss_epoch + reward_loss_epoch

    recall = tp / pos if pos else float("nan")
    precision = tp / (tp + fp) if (tp + fp) else float("nan")
    f1 = (
        2 * precision * recall / (precision + recall)
        if precision and recall
        else float("nan")
    )

    val_bar.set_postfix(
        obs_loss=obs_loss.item(),
        reward_loss=reward_loss.item(),
        total_loss=total_loss_epoch,
        recall=recall,
        precision=precision,
        f1=f1,
    )

    # ─── 3. 只 log 一次 ───────────────────────────────────────────────────────
    wandb.log(
        {
            "val/obs_loss": obs_loss_epoch,
            "val/reward_loss": reward_loss_epoch,
            "val/total_loss": total_loss_epoch,
            "val/recall": recall,
            "val/precision": precision,
            "val/f1": f1,
        },
        step=global_step,  # 用同一个 global_step 标记这一轮验证
    )

run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mdejayvu[0m ([33mdejayvu-university-of-oxford[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 01: 100%|██████████| 692/692 [12:56<00:00,  1.12s/it, loss=0.442, obs_loss=0.433, reward_loss=0.00847]
Validation 01: 100%|██████████| 78/78 [01:20<00:00,  1.03s/it]
Epoch 02: 100%|██████████| 692/692 [12:56<00:00,  1.12s/it, loss=0.272, obs_loss=0.265, reward_loss=0.00749]
Validation 02: 100%|██████████| 78/78 [01:20<00:00,  1.03s/it]
Epoch 03: 100%|██████████| 692/692 [12:56<00:00,  1.12s/it, loss=0.189, obs_loss=0.182, reward_loss=0.00735]
Validation 03: 100%|██████████| 78/78 [01:20<00:00,  1.03s/it]
Epoch 04: 100%|██████████| 692/692 [12:57<00:00,  1.12s/it, loss=0.146, obs_loss=0.138, reward_loss=0.00743]
Validation 04: 100%|██████████| 78/78 [01:20<00:00,  1.03s/it]
Epoch 05: 100%|██████████| 692/692 [12:57<00:00,  1.12s/it, loss=0.145, obs_loss=0.138, reward_loss=0.0068] 
Validation 05: 100%|██████████| 78/78 [01:20<00:00,  1.03s/it]
Epoch 06: 100%|██████████| 692/692 [12:57<00:00,  1.12s/it, loss=0.126, obs_loss=0.121, reward_loss=0.00506]
Validation 06: 100%|██████████|

0,1
train/obs_loss,█▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/reward_loss,▆▅▇▅▅█▄▅▇▅▅█▃▅▇▄▇▆▂▆▅▅▃▅▅▄▅▂▅▅▃▃▂▁▅▁▃▅▄▃
train/total_loss,█▅▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/f1,▁▁▄▅▆▆▇███
val/obs_loss,█▃▂▂▁▁▁▁▁▁
val/precision,▃█▅▆▄▃▄▃▅▁
val/recall,▂▁▃▄▅▆▆▇▇█
val/reward_loss,█▇▅▄▂▂▂▂▁▂
val/total_loss,█▃▂▂▁▁▁▁▁▁

0,1
train/obs_loss,0.13655
train/reward_loss,0.00641
train/total_loss,0.14296
val/f1,0.40732
val/obs_loss,0.17235
val/precision,0.28091
val/recall,0.74063
val/reward_loss,0.00745
val/total_loss,0.17979
