In [None]:
%pip install "gymnasium[classic_control]"



In [None]:
%pip install imageio imageio-ffmpeg



In [1]:
# >>> Cell 1: imports + hyperparams
import os
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import imageio
import matplotlib.pyplot as plt

# Device: use the GPU when PyTorch can see one, otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Reproducibility: training draws random numbers from both torch (torch.rand coin
# flips) and numpy (np.random.choice over K_CHOICES), so seed both RNGs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# hyperparams
CODE_DIM = 128  # latent vector length (passed as code_dim to Encoder/Decoder)
N_PREV = 5      # number of previous latents handed to the decoder as context
IMG_H = IMG_W = 64
BATCH_SEQ = 1   # we process sequentially per frame; for AE training we'll use batch 1 per time-step
AE_EPOCHS = 12
TEACHER_FORCING_EPOCHS = 5    # use GT prev latents for these many epochs (teacher forcing)
SCHEDULED_SAMPLING_DECAY = 0.95  # multiply p_tf each epoch after TF phase
K_CHOICES = [128, 96, 64, 32]   # absolute K's (number of leading latent entries kept)
LAMBDA = 0.01  # penalty coefficient for K in reward (tune)
LEARNING_RATE = 1e-3

# Every truncation budget must be a valid prefix length of the latent.
assert all(0 < k <= CODE_DIM for k in K_CHOICES), "K_CHOICES must lie in (0, CODE_DIM]"

# file names (adjust to your paths)
FRAMES_NPY = "cartpole_frames.npy"  # if using cartpole frames saved earlier


Device: cpu


In [2]:
# >>> Cell 2: helper utilities (truncation, psnr)
import math

def apply_truncation(z, K):
    """
    Keep only the first K latent entries of ``z`` and zero the rest.

    z: tensor whose last axis is the code axis, e.g. (B, code_dim)
    K: non-negative int; K >= code_dim leaves the latent untouched
    returns: ``z`` itself when K >= code_dim (no copy); otherwise a new tensor
             with entries [K:] of the last axis set to 0 (``z`` is not modified)
    raises: ValueError for negative K (a negative slice start would otherwise
            silently zero the *last* |K| entries instead of keeping a prefix)
    """
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")
    if K >= z.shape[-1]:
        return z
    # clone so the caller's tensor is untouched; autograd still flows to the kept prefix
    z_l = z.clone()
    z_l[..., K:] = 0.0
    return z_l

def psnr_from_mse(mse, max_val=1.0):
    """Peak signal-to-noise ratio (dB) corresponding to a mean squared error.

    mse: scalar MSE (Python float or numpy scalar); a 1e-10 epsilon keeps mse == 0 finite
    max_val: peak signal value (1.0 for images scaled to [0, 1])
    """
    peak_power = max_val * max_val
    ratio = peak_power / (mse + 1e-10)
    return 10.0 * math.log10(ratio)


In [None]:
# >>> Cell 3: import Encoder and Decoder
from encoder import Encoder
from decoder import Decoder

# Build both halves of the autoencoder (single-channel frames) on the selected device.
encoder = Encoder(code_dim=CODE_DIM, input_ch=1).to(DEVICE)
decoder = Decoder(code_dim=CODE_DIM, n_prev=N_PREV, output_ch=1).to(DEVICE)
print(encoder, decoder, sep="\n")


ModuleNotFoundError: No module named 'encoder'

In [None]:
# >>> Cell 4: load dataset (npy) or Kaggle dataset snippet

# If you have a local .npy of frames: load it
if os.path.exists(FRAMES_NPY):
    frames_raw = np.load(FRAMES_NPY)   # expected (T, H, W) or (T, H, W, C) with C in {1, 3, 4}
    print("Loaded frames shape:", frames_raw.shape)
else:
    # Example: how to download a Kaggle dataset via the 'kaggle' library - uncomment and adapt if needed
    # !pip install kaggle  # in Colab, set up API token beforehand
    # from kaggle.api.kaggle_api_extended import KaggleApi
    # api = KaggleApi(); api.authenticate()
    # api.dataset_download_files('tanvirnwu/cat-dog-image-and-video', path='dataset_kaggle', unzip=True)
    raise FileNotFoundError(f"{FRAMES_NPY} not found. Upload or create frames by running CartPole recorder.")

# Give (T, H, W) input an explicit trailing channel axis.
frames_thwc = frames_raw[..., None] if frames_raw.ndim == 3 else frames_raw
if frames_thwc.ndim != 4:
    raise ValueError(f"Expected frames shaped (T, H, W) or (T, H, W, C); got {frames_raw.shape}")

# Normalize to float32 in [0, 1]. Integer frames (e.g. uint8) are 0..255; float frames
# are rescaled only when they are clearly not already in [0, 1].
frames_thwc = frames_thwc.astype(np.float32)  # always a copy, so frames_raw is untouched
if np.issubdtype(frames_raw.dtype, np.integer) or (frames_thwc.size and frames_thwc.max() > 1.0):
    frames_thwc /= 255.0

# Colour -> grayscale: the models are built with input_ch=1 / output_ch=1.
n_channels = frames_thwc.shape[-1]
if n_channels in (3, 4):
    # RGB luma weights (kept in float32); an alpha channel, if present, is dropped.
    luma = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
    frames_thwc = (frames_thwc[..., :3] @ luma)[..., None]
elif n_channels != 1:
    raise ValueError(f"Unsupported channel count {n_channels}; expected 1, 3 or 4")

# (T, H, W, 1) -> (T, C, H, W). np.array(..., order="C") forces a fresh copy with standard
# C strides. A bare np.transpose view keeps permuted strides; torch.tensor() preserves them,
# and for C == 1 they coincide with channels_last strides, so conv outputs can come out
# channels_last and later .view() calls fail. np.ascontiguousarray would NOT copy here
# because numpy already reports the view as contiguous.
# NOTE(review): likely cause of the recorded `view size is not compatible` error - confirm.
frames = np.array(np.transpose(frames_thwc, (0, 3, 1, 2)), dtype=np.float32, order="C")
if frames.shape[2:] != (IMG_H, IMG_W):
    print(f"Warning: frames are {frames.shape[2:]}, expected {(IMG_H, IMG_W)}")
T_total = frames.shape[0]
print("Processed frames:", frames.shape)


Loaded frames shape: (2000, 64, 64, 1)
Processed frames: (2000, 1, 64, 64)


In [None]:
# >>> Cell 5: sequential dataset helper (generator)
class SequentialFrames:
    """In-order, frame-by-frame iterator over a numpy frame array (T, C, H, W)."""

    def __init__(self, frames_array):
        self.frames = frames_array  # numpy (T, C, H, W)
        self.T = frames_array.shape[0]

    def iterate_epochs(self, epoch_count=1):
        """Yield ``(t, img)`` for every frame, ``epoch_count`` times, without shuffling.

        ``img`` is a float32 tensor of shape (1, C, H, W) on ``DEVICE``.
        """
        for _ in range(epoch_count):
            for t in range(self.T):
                # Fresh C-ordered float32 copy with standard strides. torch.tensor() would
                # keep the strides of a transposed source array, which for C == 1 look like
                # channels_last to PyTorch and can break .view() in downstream modules.
                frame = np.array(self.frames[t:t+1], dtype=np.float32, order="C")
                yield t, torch.from_numpy(frame).to(DEVICE)

seq_dataset = SequentialFrames(frames)


In [None]:
# >>> Cell 6: AE training loop (sequential, with context buffers and scheduled sampling)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=LEARNING_RATE)
criterion = nn.MSELoss()
RANDOM_K_PROB = 0.3  # fraction of post-warmup steps trained on a randomly truncated latent

# Convert the whole sequence once instead of re-wrapping a numpy slice every step.
# np.array(..., order="C") makes a fresh copy with standard C strides: torch keeps the
# strides of its numpy source, and a transposed (T, H, W, 1) -> (T, 1, H, W) array has
# strides PyTorch can interpret as channels_last.
frames_tensor = torch.from_numpy(np.array(frames, dtype=np.float32, order="C")).to(DEVICE)


def _stack_context(latents):
    """Stack a deque of (1, CODE_DIM) latents into one (1, len, CODE_DIM) context tensor."""
    return torch.stack(list(latents), dim=1)


def _train_step(x, z_in, prev_latents):
    """Decode ``z_in`` (optionally with context), take one Adam step on MSE vs ``x``; return the loss."""
    rec = decoder.decode(z_in, prev_latents=prev_latents)
    loss = criterion(rec, x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


p_tf = 1.0  # probability of using ground-truth (untruncated) latents as context
for epoch in range(1, AE_EPOCHS + 1):
    # Cell 7 switches both models to eval(); make sure a re-run trains in train mode.
    encoder.train()
    decoder.train()
    total_loss = 0.0
    # scheduled sampling adjustment after TF epochs
    use_teacher = epoch <= TEACHER_FORCING_EPOCHS
    if not use_teacher:
        p_tf = max(0.0, p_tf * SCHEDULED_SAMPLING_DECAY)

    # Two context histories, reset every epoch:
    #   gt_buffer    - full, untruncated latents (teacher-forced context)
    #   model_buffer - the possibly truncated latents actually decoded, i.e. the kind of
    #                  context reconstruct_sequence_greedy feeds back at inference
    gt_buffer = deque(maxlen=N_PREV)
    model_buffer = deque(maxlen=N_PREV)

    for t in range(T_total):
        x_t = frames_tensor[t:t+1]   # (1, C, H, W)
        z_t = encoder.encode(x_t)    # assumed (1, CODE_DIM)

        if len(model_buffer) < N_PREV:
            # Warmup: not enough history yet -> decode the full latent without context.
            z_limited = z_t
            prev_stack = None
        else:
            # Scheduled sampling: ground-truth context with probability p_tf (always
            # during the teacher-forcing epochs), otherwise closed-loop context.
            if use_teacher or torch.rand(1).item() < p_tf:
                prev_stack = _stack_context(gt_buffer)
            else:
                prev_stack = _stack_context(model_buffer)

            # Sometimes truncate the current latent so the decoder learns to cope with small K.
            if torch.rand(1).item() < RANDOM_K_PROB:
                K = int(np.random.choice(K_CHOICES))
            else:
                K = CODE_DIM
            z_limited = apply_truncation(z_t, K)

        total_loss += _train_step(x_t, z_limited, prev_stack)

        # Context is detached: gradients never flow into earlier time steps.
        gt_buffer.append(z_t.detach())
        model_buffer.append(z_limited.detach())

    avg_loss = total_loss / max(T_total, 1)
    print(f"AE Epoch {epoch}/{AE_EPOCHS} | avg loss: {avg_loss:.6f} | use_teacher={use_teacher} p_tf={p_tf:.3f}")

# save models
torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")
print("Saved encoder/decoder.")


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
# >>> Cell 7: quick validation function - greedy K chosen externally or max K
def reconstruct_sequence_greedy(encoder, decoder, frames_np, K_choice=128, use_prev_recon=True):
    """Reconstruct a frame sequence in order with a fixed latent budget ``K_choice``.

    Each frame is encoded, its latent truncated to the first ``K_choice`` entries
    (``apply_truncation``) and decoded. Once ``N_PREV`` frames have been processed,
    the decoder also receives the last ``N_PREV`` latents as context:

    * ``use_prev_recon=True`` (default): the truncated latents that were actually
      decoded (closed loop).
    * ``use_prev_recon=False``: the full, untruncated latents (teacher-forced context).

    Both models are put in eval mode (and left there); gradients are disabled.

    Args:
        encoder, decoder: models exposing ``encode(x)`` and ``decode(z, prev_latents=...)``.
        frames_np: frames indexed (T, C, H, W); channel 0 of each reconstruction is kept.
        K_choice: number of leading latent entries kept per frame.
        use_prev_recon: which latents are fed back as decoder context (see above).

    Returns:
        numpy array stacking the T reconstructions along axis 0; an empty
        (0, H, W) float32 array when ``frames_np`` has no frames.
    """
    encoder.eval(); decoder.eval()
    # Fresh C-ordered float32 copy, converted once: torch keeps the strides of its numpy
    # source, and a transposed frame array can otherwise look channels_last to PyTorch.
    frames_t = torch.from_numpy(np.array(frames_np, dtype=np.float32, order="C")).to(DEVICE)
    T = frames_t.shape[0]
    if T == 0:
        return np.empty((0,) + tuple(frames_np.shape[2:]), dtype=np.float32)

    recon_frames = []
    context = deque(maxlen=N_PREV)
    with torch.no_grad():
        for t in range(T):
            z = encoder.encode(frames_t[t:t+1])  # (1, code_dim)
            z_l = apply_truncation(z, K_choice)
            # (1, N_PREV, code_dim) once enough history exists, otherwise no context
            prev = torch.stack(list(context), dim=1) if len(context) == N_PREV else None
            rec = decoder.decode(z_l, prev_latents=prev)  # (1,1,H,W)
            recon_frames.append(rec.cpu().numpy()[0, 0])  # (H,W)
            context.append(z_l if use_prev_recon else z)
    return np.stack(recon_frames, axis=0)

# example: reconstruct with K=128 and save a short preview video
PREVIEW_K = 128        # latent budget used for the preview
PREVIEW_FRAMES = 200   # number of leading frames written to the video
recon = reconstruct_sequence_greedy(encoder, decoder, frames, K_choice=PREVIEW_K)
print("recon shape", recon.shape)
# Nothing here guarantees decoder outputs stay in [0, 1]; clip before the uint8 cast so
# out-of-range values saturate instead of wrapping around, and round to the nearest level.
sample_vid = np.round(np.clip(recon[:PREVIEW_FRAMES], 0.0, 1.0) * 255).astype(np.uint8)
# Writing .mp4 goes through imageio's ffmpeg backend (the imageio-ffmpeg package).
imageio.mimsave(f"recon_preview_k{PREVIEW_K}.mp4", sample_vid, fps=30)
print(f"Saved recon_preview_k{PREVIEW_K}.mp4")


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.