In [17]:
pip install gymnasium



In [37]:
!pip install -q gymnasium
import gymnasium as gym

env = gym.make("CartPole-v1", render_mode="rgb_array")
obs, info = env.reset()

done = False
while not done:
    frame = env.render()
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated


In [38]:
pip install gymnasium[classic_control]



In [39]:
pip install imageio



In [40]:
%%writefile encoder.py
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_ch=1, code_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_ch, 16, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.fc_enc = None
        self.code_dim = code_dim

    def forward(self, x):
        B = x.shape[0]
        feat = self.conv(x)
        if self.fc_enc is None:
            flat_dim = feat.numel() // B
            self.fc_enc = nn.Linear(flat_dim, self.code_dim).to(x.device)
            print(f"‚úÖ Initialized fc_enc with input size {flat_dim}")
        flat = feat.reshape(B, -1)
        z = self.fc_enc(flat)
        return z


Overwriting encoder.py


In [41]:
%%writefile decoder.py
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, code_dim=128, n_prev=5, output_ch=1):
        super().__init__()
        self.n_prev = n_prev
        self.code_dim = code_dim
        total_dim = (n_prev + 1) * code_dim
        self.fc_dec = nn.Linear(total_dim, 64 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, output_ch, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, z, prev_latents=None):
        if prev_latents is not None and len(prev_latents) > 0:
            prev_cat = torch.cat(prev_latents, dim=1)
            z = torch.cat([z, prev_cat], dim=1)
        else:
            B = z.size(0)
            z = torch.cat([z, torch.zeros(B, self.n_prev * self.code_dim, device=z.device)], dim=1)

        feat = self.fc_dec(z).view(-1, 64, 8, 8)
        rec = self.deconv(feat)
        return rec


Overwriting decoder.py


In [42]:
# ===============================
# Checking encoder.py and decoder.py
# ===============================
import torch
from encoder import Encoder
from decoder import Decoder

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(code_dim=128).to(DEVICE)
decoder = Decoder(code_dim=128, n_prev=5).to(DEVICE)

x = torch.rand(1, 1, 64, 64).to(DEVICE)
z = encoder(x)
rec = decoder(z)

print("Input:", x.shape)
print("Latent:", z.shape)
print("Reconstructed:", rec.shape)


‚úÖ Initialized fc_enc with input size 4096
Input: torch.Size([1, 1, 64, 64])
Latent: torch.Size([1, 128])
Reconstructed: torch.Size([1, 1, 64, 64])


In [43]:
# >>> Cell 1: imports + hyperparams
import os
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import imageio
import matplotlib.pyplot as plt

# device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# hyperparams
CODE_DIM = 128
N_PREV = 5
IMG_H = IMG_W = 64
BATCH_SEQ = 1   # we process sequentially per frame; for AE training we'll use batch 1 per time-step
AE_EPOCHS = 12
TEACHER_FORCING_EPOCHS = 5    # use GT prev latents for these many epochs (teacher forcing)
SCHEDULED_SAMPLING_DECAY = 0.95  # multiply p_tf each epoch after TF phase
K_CHOICES = [128, 96, 64, 32]   # absolute K's
LAMBDA = 0.01  # penalty coefficient for K in reward (tune)
LEARNING_RATE = 1e-3

# file names (adjust to your paths)
FRAMES_NPY = "cartpole_frames.npy"  # if using cartpole frames saved earlier


Device: cpu


In [44]:
# >>> Cell 2: helper utilities (truncation, psnr)
import math

def apply_truncation(z, K):
    """
    z: (B, code_dim)
    K: int (<= code_dim)
    returns z_limited with zeros after index K
    """
    if K >= z.shape[1]:
        return z
    z_l = z.clone()
    z_l[:, K:] = 0.0
    return z_l

def psnr_from_mse(mse, max_val=1.0):
    # mse is scalar or numpy
    return 10.0 * math.log10(max_val * max_val / (mse + 1e-10))


In [45]:
# >>> Cell 3: import Encoder and Decoder
from encoder import Encoder
from decoder import Decoder

encoder = Encoder(input_ch=1, code_dim=CODE_DIM).to(DEVICE)
decoder = Decoder(output_ch=1, code_dim=CODE_DIM, n_prev=N_PREV).to(DEVICE)
print(encoder)
print(decoder)


Encoder(
  (conv): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU(inplace=True)
  )
)
Decoder(
  (fc_dec): Linear(in_features=768, out_features=4096, bias=True)
  (deconv): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): Sigmoid()
  )
)


In [47]:
# Debug + robust CartPole frames generator (works with various gym versions)
import os
import gym
import numpy as np
from PIL import Image
import imageio

FRAMES_NPY = "cartpole_frames.npy"
N_FRAMES = 2000
SAVE_GIF = True  # —Å–æ—Ö—Ä–∞–Ω—è—Ç—å –ø—Ä–µ–≤—å—é .gif —Å –ø–µ—Ä–≤—ã–º–∏ 10 –∫–∞–¥—Ä–∞–º–∏

if not os.path.exists(FRAMES_NPY):
    print("üé• Creating CartPole video dataset (robust mode)...")
    # —Å–æ–∑–¥–∞—ë–º env ‚Äî —Å—Ç–∞—Ä—ã–π gym –æ–±—ã—á–Ω–æ –ø—Ä–∏–Ω–∏–º–∞–µ—Ç render_mode=None –∏ –≤–æ–∑–≤—Ä–∞—â–∞–µ—Ç array –∏–∑ env.render()
    # –µ—Å–ª–∏ —Å–≤—è–∑–∫–∞ gym/gymnasium —Å–º–µ—à–∞–ª–∞—Å—å, —ç—Ç–æ—Ç –∫–æ–¥ –≤—Å—ë —Ä–∞–≤–Ω–æ –ø–æ—Å—Ç–∞—Ä–∞–µ—Ç—Å—è –æ–±—Ä–∞–±–æ—Ç–∞—Ç—å frame
    try:
        env = gym.make("CartPole-v1", render_mode="rgb_array")
    except TypeError:
        # —Å—Ç–∞—Ä—ã–µ gym –º–æ–≥—É—Ç –Ω–µ –ø—Ä–∏–Ω—è—Ç—å render_mode –≤ –∫–æ–Ω—Å—Ç—Ä—É–∫—Ç–æ—Ä–µ
        env = gym.make("CartPole-v1")
    frames = []

    # old gym: env.reset() -> obs OR (obs, info) for newer landscape. –ø–æ–ø—Ä–æ–±—É–µ–º –æ–±–æ–∏—Ö:
    reset_res = env.reset()
    # –£–¥–æ–±–Ω–æ –ø–æ–ª—É—á–∏—Ç—å –ø–µ—Ä–≤–æ–Ω–∞—á–∞–ª—å–Ω–æ–µ obs (–≤ –ª—é–±–æ–º —Ñ–æ—Ä–º–∞—Ç–µ)
    if isinstance(reset_res, tuple):
        obs = reset_res[0]
    else:
        obs = reset_res

    for i in range(N_FRAMES):
        frame = env.render()  # –º–æ–∂–µ—Ç –≤–µ—Ä–Ω—É—Ç—å np.ndarray –∏–ª–∏ list –∏–ª–∏ PIL image-like

        # ---- Debug prints for first few iterations ----
        if i < 3:
            print(f"DEBUG: step {i} | type(frame) = {type(frame)}")
            # –µ—Å–ª–∏ —ç—Ç–æ numpy array ‚Äî –≤—ã–≤–µ–¥–µ–º —Ñ–æ—Ä–º—É; –µ—Å–ª–∏ list ‚Äî —Ä–∞–∑–º–µ—Ä, –ø—Ä–∏–º–µ—Ä
            try:
                arr_try = np.array(frame)
                print("DEBUG: -> np.array(frame).shape =", arr_try.shape, "dtype=", arr_try.dtype)
            except Exception as e:
                print("DEBUG: cannot convert frame to np.array directly:", e)
            # show a tiny repr snippet:
            rep = repr(frame)
            print("DEBUG: repr(frame)[:200] =", rep[:200])

        # ---- Convert robustly to numpy array ----
        # –ï—Å–ª–∏ frame ‚Äî PIL Image, —Å–¥–µ–ª–∞—Ç—å np.array(frame)
        if hasattr(frame, "convert"):  # PIL Image-like
            frame_arr = np.array(frame)
        else:
            # try to coerce to np.array
            frame_arr = np.asarray(frame)

        # now ensure frame_arr is numeric np array
        if frame_arr.ndim == 2:
            # already grayscale HxW
            gray = frame_arr.astype(np.uint8)
            gray = np.expand_dims(gray, -1)  # H W 1
        elif frame_arr.ndim == 3:
            # Could be HxWx3 RGB, or HxWx4 RGBA; take first 3
            if frame_arr.shape[2] >= 3:
                rgb = frame_arr[..., :3].astype(np.float32)
                # convert to grayscale
                gray = (0.2989 * rgb[...,0] + 0.5870 * rgb[...,1] + 0.1140 * rgb[...,2]).astype(np.uint8)
                gray = np.expand_dims(gray, -1)
            else:
                # unexpected channels, collapse to single channel
                gray = frame_arr[..., 0:1].astype(np.uint8)
        else:
            # fallback: try flatten/pad or raise clearer error
            raise RuntimeError(f"Cannot interpret frame array with ndim={frame_arr.ndim}, repr first chars: {repr(frame)[:200]}")

        frames.append(gray)

        # step environment ‚Äî –∏—Å–ø–æ–ª—å–∑–æ–≤–∞—Ç—å old API: env.step(action) -> (obs, reward, done, info)
        action = env.action_space.sample()
        step_res = env.step(action)
        # handle both old and new API
        if isinstance(step_res, tuple):
            if len(step_res) == 4:
                obs, reward, done, info = step_res
            elif len(step_res) == 5:  # new gymnasium: (obs, reward, terminated, truncated, info)
                obs, reward, terminated, truncated, info = step_res
                done = terminated or truncated
            else:
                # fallback
                obs = step_res[0]
                done = False
        else:
            # unexpected type
            done = False

        if done:
            # reset, keep going
            reset_res = env.reset()
            obs = reset_res[0] if isinstance(reset_res, tuple) else reset_res

    env.close()
    if isinstance(frame, list) and len(frame) == 1:
      frame_arr = np.asarray(frame[0])
    else:
      frame_arr = np.asarray(frame)

    np.save(FRAMES_NPY, frames)
    print(f"‚úÖ Saved {FRAMES_NPY}, shape {frames.shape}")
else:
    print("Found existing", FRAMES_NPY)
    frames = np.load(FRAMES_NPY)
    print("Loaded shape:", frames.shape)

# Optional: save first 10 frames as gif for quick visual check
if SAVE_GIF:
    preview = (frames[:10, ..., 0]).astype(np.uint8)  # (10, H, W)
    pil_frames = [Image.fromarray(f) for f in preview]
    pil_frames[0].save("preview_cartpole.gif", save_all=True, append_images=pil_frames[1:], duration=100, loop=0)
    print("Saved preview_cartpole.gif (first 10 frames).")


üé• Creating CartPole video dataset (robust mode)...
DEBUG: step 0 | type(frame) = <class 'list'>
DEBUG: -> np.array(frame).shape = (1, 400, 600, 3) dtype= uint8
DEBUG: repr(frame)[:200] = [array([[[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [


RuntimeError: Cannot interpret frame array with ndim=4, repr first chars: [array([[[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [

In [None]:
# >>> Cell 4: load dataset (npy) or Kaggle dataset snippet

# If you have a local .npy of frames: load it
if os.path.exists(FRAMES_NPY):
    frames = np.load(FRAMES_NPY)   # shape expected (T, H, W, C) where C=1 or 3
    print("Loaded frames shape:", frames.shape)
else:
    # Example: how to download Kaggle dataset via 'kaggle' library ‚Äî uncomment and adapt if needed
    # !pip install kaggle  # in Colab, set up API token beforehand
    # from kaggle.api.kaggle_api_extended import KaggleApi
    # api = KaggleApi(); api.authenticate()
    # api.dataset_download_files('tanvirnwu/cat-dog-image-and-video', path='dataset_kaggle', unzip=True)
    raise FileNotFoundError(f"{FRAMES_NPY} not found. Upload or create frames by running CartPole recorder.")

# Normalize to [0,1] float32; convert to shape (T, C, H, W)
frames = frames.astype(np.float32) / 255.0
# if frames shape is (T,H,W) or (T,H,W,1)
if frames.ndim == 3:
    frames = frames[..., None]
if frames.shape[-1] == 3:
    # convert to grayscale
    frames_gray = np.dot(frames[...,:3], [0.2989, 0.5870, 0.1140])
    frames = frames_gray[..., None]
frames = np.transpose(frames, (0, 3, 1, 2))  # -> (T, C, H, W)
T_total = frames.shape[0]
print("Processed frames:", frames.shape)


In [48]:
# >>> Cell 5: sequential dataset helper (generator)
class SequentialFrames:
    def __init__(self, frames_array):
        self.frames = frames_array  # numpy (T, C, H, W)
        self.T = frames_array.shape[0]

    def iterate_epochs(self, epoch_count=1):
        # yield sequences frame-by-frame (no shuffling)
        for e in range(epoch_count):
            for t in range(self.T):
                # yield single frame as torch tensor (1, C, H, W)
                img = torch.tensor(self.frames[t:t+1], dtype=torch.float32, device=DEVICE)
                yield t, img

seq_dataset = SequentialFrames(frames)


AttributeError: 'list' object has no attribute 'shape'

In [49]:
# >>> Cell 6: AE training loop (sequential, with buffer and scheduled sampling)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=LEARNING_RATE)
criterion = nn.MSELoss()

p_tf = 1.0  # teacher forcing probability
for epoch in range(1, AE_EPOCHS + 1):
    total_loss = 0.0
    # scheduled sampling adjustment after TF epochs
    use_teacher = epoch <= TEACHER_FORCING_EPOCHS
    if not use_teacher:
        p_tf = max(0.0, p_tf * SCHEDULED_SAMPLING_DECAY)

    # process sequence: maintain buffer of previous z_limited (or z if teacher forcing)
    buffer = deque(maxlen=N_PREV)
    # Warmup first N_PREV frames: encode/decode without context (or with zeros)
    for t in range(T_total):
        x_t = torch.tensor(frames[t:t+1], dtype=torch.float32, device=DEVICE)  # (1,C,H,W)
        z_t = encoder(x_t)  # (1, code_dim)

        # Warmup: if buffer not full, we will decode without prev or with zeros
        if len(buffer) < N_PREV:
            z_limited = z_t  # full latent during AE training warmup
            rec = decoder.decode(z_limited, prev_latents=None)
            loss = criterion(rec, x_t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # store limited latent (we store full z for teacher-forcing phase)
            buffer.append(z_limited.detach())
            continue

        # now buffer is full (or at least has some items)
        # scheduled sampling: decide whether to use GT latents (teacher) or model's past latents:
        if use_teacher or (torch.rand(1).item() < p_tf):
            prev_stack = torch.cat(list(buffer), dim=0).unsqueeze(0)  # buffer: list of (1,code_dim) -> (n_prev, code_dim)
            # reshape to (1, n_prev, code_dim)
            prev_stack = prev_stack.view(1, N_PREV, CODE_DIM)
        else:
            # closed-loop: use previous latents that were produced (we stored them in buffer already)
            prev_stack = torch.cat(list(buffer), dim=0).unsqueeze(0).view(1, N_PREV, CODE_DIM)

        # Optionally: random K during AE training to make decoder robust to truncation
        # We'll randomly pick K among K_CHOICES some fraction of times
        if torch.rand(1).item() < 0.3:
            K = int(np.random.choice(K_CHOICES))
        else:
            K = CODE_DIM

        z_limited = apply_truncation(z_t, K)

        rec = decoder.decode(z_limited, prev_latents=prev_stack)
        loss = criterion(rec, x_t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # push limited latent used for context (detached)
        buffer.append(z_limited.detach())

    avg_loss = total_loss / T_total
    print(f"AE Epoch {epoch}/{AE_EPOCHS} | avg loss: {avg_loss:.6f} | use_teacher={use_teacher} p_tf={p_tf:.3f}")

# save models
torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")
print("Saved encoder/decoder.")


NameError: name 'T_total' is not defined

  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
# >>> Cell 7: quick validation function - greedy K chosen externally or max K
def reconstruct_sequence_greedy(encoder, decoder, frames_np, K_choice=128, use_prev_recon=True):
    encoder.eval(); decoder.eval()
    T = frames_np.shape[0]
    recon_frames = []
    buffer = deque(maxlen=N_PREV)
    with torch.no_grad():
        for t in range(T):
            x = torch.tensor(frames_np[t:t+1], dtype=torch.float32, device=DEVICE)
            z = encoder(x)  # (1, code_dim)
            z_l = apply_truncation(z, K_choice)
            if len(buffer) < N_PREV:
                prev = None
            else:
                prev = torch.cat(list(buffer), dim=0).unsqueeze(0).view(1, N_PREV, CODE_DIM)
            rec = decoder.decode(z_l, prev_latents=prev)  # (1,1,H,W)
            recon_frames.append(rec.cpu().numpy()[0,0])  # (H,W)
            buffer.append(z_l)  # use z_l as prev latent (closed-loop)
    return np.stack(recon_frames, axis=0)

# example: reconstruct with K=128
recon = reconstruct_sequence_greedy(encoder, decoder, frames, K_choice=128)
print("recon shape", recon.shape)
# save a short preview as mp4 (normalize to 0..255)
sample_vid = (recon[:200] * 255).astype(np.uint8)
imageio.mimsave("recon_preview_k128.mp4", sample_vid, fps=30)
print("Saved recon_preview_k128.mp4")


In [None]:
# >>> New Cell: visualize original vs reconstructed
import matplotlib.pyplot as plt

# –ø–æ–∫–∞–∂–µ–º –ø–µ—Ä–≤—ã–µ 5 –∫–∞–¥—Ä–æ–≤ –æ—Ä–∏–≥–∏–Ω–∞–ª—å–Ω—ã—Ö –∏ —Ä–µ–∫–æ–Ω—Å—Ç—Ä—É–∏—Ä–æ–≤–∞–Ω–Ω—ã—Ö
n_show = 5
fig, axes = plt.subplots(2, n_show, figsize=(12, 4))
for i in range(n_show):
    axes[0, i].imshow(frames[i, 0], cmap='gray')
    axes[0, i].set_title(f"Original {i}")
    axes[0, i].axis('off')
    axes[1, i].imshow(recon[i], cmap='gray')
    axes[1, i].set_title(f"Reconstructed {i}")
    axes[1, i].axis('off')
plt.tight_layout()
plt.show()


In [None]:
# >>> New Cell: make GIF comparison of original vs reconstructed
import imageio
import numpy as np

n_frames = 10  # —Å–∫–æ–ª—å–∫–æ –∫–∞–¥—Ä–æ–≤ –ø–æ–∫–∞–∑–∞—Ç—å
frames_orig = (frames[:n_frames, 0] * 255).astype(np.uint8)
frames_rec = (recon[:n_frames] * 255).astype(np.uint8)

# –°–æ–∑–¥–∞–¥–∏–º side-by-side –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è
combined = []
for i in range(n_frames):
    top = frames_orig[i]
    bottom = frames_rec[i]
    concat = np.concatenate([top, bottom], axis=1)  # —Å–ª–µ–≤–∞ –æ—Ä–∏–≥–∏–Ω–∞–ª, —Å–ø—Ä–∞–≤–∞ —Ä–µ–∫–æ–Ω—Å—Ç—Ä—É–∫—Ü–∏—è
    combined.append(concat)

# –°–æ—Ö—Ä–∞–Ω—è–µ–º GIF
imageio.mimsave("comparison.gif", combined, fps=2)
print("‚úÖ Saved comparison.gif (original vs reconstructed)")

from IPython.display import Image
Image(filename="comparison.gif")
