# Setup

In [None]:
!pip install --upgrade ale-py

In [None]:
!pip install minari
!pip install --upgrade gymnasium 


In [None]:
import gymnasium as gym
import ale_py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from collections import deque

import minari
import random

from IPython.display import Video

import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Output folder for RecordVideo. `exist_ok=True` keeps the cell idempotent (plain `mkdir videos`
# errors on a second Run All), and pathlib avoids depending on the shell.
from pathlib import Path

Path("videos").mkdir(exist_ok=True)

In [None]:
# Breakout environment for inspection:
#   * AtariPreprocessing: 4-frame skip, grayscale, pixel values scaled (scale_obs=True), and an
#     episode end on every life loss;
#   * RecordVideo: saves rendered RGB frames of recorded episodes under videos/;
#   * FrameStackObservation: each observation is the stack of the last 4 preprocessed frames.
env = gym.make("BreakoutNoFrameskip-v4", obs_type="rgb", render_mode="rgb_array")
env = gym.wrappers.AtariPreprocessing(
    env,
    frame_skip=4,
    grayscale_obs=True,
    scale_obs=True,
    terminal_on_life_loss=True,
)
env = gym.wrappers.RecordVideo(env, name_prefix="breakout_test", video_folder="videos/")
env = gym.wrappers.FrameStackObservation(env, 4)

In [None]:
# Sanity-check the wrapped environment: take up to 10 random actions, plot every frame of the
# stacked observation after each step, and report the reward collected.
obs, info = env.reset()
frames = [obs]  # every stacked observation seen, starting with the one from reset()
total_reward = 0.0

for step in range(10):
    action = env.action_space.sample()  # random action
    obs, reward, done, truncated, info = env.step(action)  # `done` is Gymnasium's `terminated`

    # squeeze=False keeps `axes` 2-D, so this also works for a stack of a single frame.
    fig, axes = plt.subplots(1, obs.shape[0], figsize=(12, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(obs[i, :, :], cmap="gray")
        ax.set_title(f"Step {step+1}, Frame {i+1}")
        ax.axis("off")
    plt.show()
    plt.close(fig)

    frames.append(obs)
    total_reward += reward
    if done or truncated:
        break

env.close()  # closing lets RecordVideo write out the clip it was recording
print(f"Random rollout: {len(frames) - 1} steps, total reward {total_reward}")


# SL: autoencoder training on the Minari expert Breakout dataset

In [None]:
class MinariDataset(Dataset):
    """Frame-stack dataset for autoencoder training, built from Minari Atari episodes.

    Per-frame preprocessing: RGB -> grayscale, resize to 84x84, binarize at 0.1 of full
    brightness and map to {-1, +1}, crop to 76x76 (rows 8:, cols 4:80), then overwrite the
    top 8 rows of the crop with the constant 0.5.

    Each item is ``(x, target, diff_stack, ball_pos)``:

    * ``target``     -- ``stack_size`` consecutive frames, shape (stack_size, 76, 76).
    * ``x``          -- ``target`` with one randomly chosen frame set to 0.
    * ``diff_stack`` -- (frame[t+1] - frame[t]) / 2 for consecutive frames of the stack,
                        shape (stack_size - 1, 76, 76), values in {-1, 0, +1}.
    * ``ball_pos``   -- heuristic (row / H, col / W) ball position per frame, (0, 0) when
                        nothing was detected; shape (stack_size, 2).

    Stacks never span two episodes: only start indices whose ``stack_size`` frames all belong
    to the same episode are served.

    Args:
        dataset: iterable of episodes, each exposing ``.observations`` as a sequence of RGB
            frames (assumed (T, H, W, 3) uint8 -- TODO confirm against the Minari dataset).
        stack_size: number of consecutive frames per item.
    """

    RESIZE = (84, 84)            # cv2 dsize (width, height) before cropping
    CROP_ROWS = slice(8, None)   # -> 76 rows
    CROP_COLS = slice(4, 80)     # -> 76 columns
    MASKED_TOP_ROWS = 8          # top rows of the crop overwritten with a constant
    PADDLE_ROWS = slice(67, 70)  # rows ignored by the ball detector

    def __init__(self, dataset, stack_size=4):
        # Preprocess episode by episode so full-resolution frames are never held for the whole
        # dataset at once, and keep the episode lengths to respect their boundaries.
        episode_obs = [self._preprocess(ep.observations) for ep in dataset]
        if not episode_obs:
            raise ValueError("dataset contains no episodes")
        episode_lengths = [len(o) for o in episode_obs]
        obs = np.concatenate(episode_obs, axis=0)
        del episode_obs
        n_frames = obs.shape[0]
        self.obs = obs

        # diffs[t] = (obs[t+1] - obs[t]) / 2 over the concatenated timeline. Entries that straddle
        # two episodes exist but are never served: no valid stack contains both of their frames.
        self.diffs = (obs[1:] - obs[:-1]) / 2

        # Diff indices joining the last frame of one episode to the first frame of the next.
        boundary_diffs = np.cumsum(episode_lengths)[:-1] - 1
        boundary_diffs = boundary_diffs[(boundary_diffs >= 0) & (boundary_diffs < n_frames - 1)]
        self.ball_pos = self._ball_positions(self.diffs, n_frames, boundary_diffs)

        self.stack_size = stack_size
        self.valid_starts = self._valid_starts(episode_lengths, stack_size)
        self.max_idx = len(self.valid_starts)

    @classmethod
    def _preprocess(cls, frames):
        """Map raw RGB frames to cropped, binarized float32 frames of shape (N, 76, 76)."""
        # Convert and resize one frame at a time: same values as converting the whole array to
        # float32 first, without a full-resolution float copy of every frame.
        resized = np.array(
            [
                cv2.resize(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32), cls.RESIZE)
                for frame in frames
            ],
            dtype=np.float32,
        ).reshape(-1, cls.RESIZE[1], cls.RESIZE[0])
        binary = ((resized / 255.0) > 0.1).astype(np.float32)
        scaled = (binary - 0.5) * 2  # {0, 1} -> {-1, +1}
        cropped = np.ascontiguousarray(scaled[:, cls.CROP_ROWS, cls.CROP_COLS])
        cropped[:, : cls.MASKED_TOP_ROWS, :] = 0.5
        return cropped

    @staticmethod
    def _valid_starts(episode_lengths, stack_size):
        """Global start indices whose ``stack_size`` frames all lie inside a single episode."""
        starts = []
        offset = 0
        for length in episode_lengths:
            starts.extend(range(offset, offset + length - stack_size + 1))
            offset += length
        return np.asarray(starts, dtype=np.int64)

    @classmethod
    def _ball_positions(cls, diffs, n_frames, cut_diffs=()):
        """Heuristic per-frame ball position from frame differences.

        A pixel counts for frame t if it switched on when frame t arrived (diffs[t-1] == +1) or
        switches off right after it (diffs[t] == -1). Paddle rows and the diff indices listed in
        ``cut_diffs`` (e.g. those straddling two episodes) are ignored.

        Returns:
            float32 array (n_frames, 2) with the mean (row / H, col / W) of the counted pixels,
            or (0, 0) for frames where no pixel qualifies.
        """
        h, w = diffs.shape[1:]
        ds = diffs.copy()
        ds[:, cls.PADDLE_ROWS, :] = 0.0
        ds[np.asarray(cut_diffs, dtype=np.int64)] = 0.0

        hits = np.zeros((n_frames, h, w), dtype=bool)
        hits[1:] |= ds == 1    # appeared when frame t arrived
        hits[:-1] |= ds == -1  # disappears right after frame t

        counts = hits.sum(axis=(1, 2))
        row_sums = hits.sum(axis=2) @ np.arange(h)
        col_sums = hits.sum(axis=1) @ np.arange(w)
        found = counts > 0

        ball_pos = np.zeros((n_frames, 2), dtype=np.float32)
        ball_pos[found, 0] = row_sums[found] / counts[found] / h
        ball_pos[found, 1] = col_sums[found] / counts[found] / w
        return ball_pos

    def __len__(self):
        return self.max_idx

    def __getitem__(self, idx):
        start = int(self.valid_starts[idx])
        end = start + self.stack_size

        # --- Target: the clean stack, (stack_size, H, W) ---
        stack = self.obs[start:end]
        target = torch.tensor(stack, dtype=torch.float32)  # torch.tensor copies the data

        # --- Input: the same stack with one random frame blanked to 0 ---
        masked_stack = stack.copy()
        masked_stack[random.randint(0, self.stack_size - 1)] = 0.0
        x = torch.tensor(masked_stack, dtype=torch.float32)

        # --- Differences between consecutive frames of the stack: (stack_size - 1, H, W) ---
        diff_stack = torch.tensor(self.diffs[start:end - 1], dtype=torch.float32)

        # --- Per-frame ball position: (stack_size, 2) ---
        ball_pos = torch.tensor(self.ball_pos[start:end], dtype=torch.float32)

        return x, target, diff_stack, ball_pos

In [None]:
# Download (if needed) the expert Breakout episodes and wrap them as a frame-stack dataset.
# The raw Minari dataset gets its own name instead of being overwritten by the torch Dataset.
raw_minari_dataset = minari.load_dataset('atari/breakout/expert-v0', download=True)
minari_dataset = MinariDataset(dataset=raw_minari_dataset)
assert len(minari_dataset) > 0, "no frame stack fits inside any episode"
minari_dataloader = DataLoader(minari_dataset, batch_size=512, shuffle=True)

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder over a stack of frames, with three heads on a shared latent.

    The hard-coded 6*6*64 linear layers and the transposed-conv decoders are sized for 76x76
    input frames: the encoder maps 76 -> 18 -> 8 -> 6 and each decoder maps 6 -> 8 -> 18 -> 76.

    Args:
        input_channels: number of stacked frames C in the input.
        latent_dim: size of the shared latent vector.

    ``forward(x)`` with x of shape (B, C, 76, 76) returns:
        recon_full: (B, C, 76, 76), tanh output in (-1, 1).
        recon_diff: (B, C - 1, 76, 76), tanh output in (-1, 1), one map per consecutive pair.
        ball_pos:   (B, C, 2), sigmoid output in (0, 1), one (row, col) pair per frame.
        latent:     (B, latent_dim), after leaky ReLU.
    """

    def __init__(self, input_channels=4, latent_dim=128):
        super().__init__()
        self.input_channels = input_channels

        # --- Encoder ---
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=8, stride=4)  # -> (32, 18, 18)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)              # -> (64, 8, 8)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)              # -> (64, 6, 6)
        self.fc1 = nn.Linear(6*6*64, latent_dim)

        # --- Full-frame decoder ---
        self.fc2 = nn.Linear(latent_dim, 6*6*64)
        self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1)   # -> (64, 8, 8)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2)   # -> (32, 18, 18)
        self.deconv3 = nn.ConvTranspose2d(32, input_channels, kernel_size=8, stride=4)  # -> (C, 76, 76)

        # --- Difference decoder: (input_channels - 1) difference frames per stack ---
        self.diff_fc2 = nn.Linear(latent_dim, 64*6*6)
        self.diff_deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1)
        self.diff_deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2)
        self.diff_deconv3 = nn.ConvTranspose2d(32, input_channels-1, kernel_size=8, stride=4)

        # --- Ball position head: a (row, col) pair for every frame of the stack ---
        # With the default 4 frames this is Linear(latent_dim, 8), as before (checkpoint-compatible).
        self.ball_fc = nn.Linear(latent_dim, 2 * input_channels)

    def _decode(self, latent, fc, deconv1, deconv2, deconv3):
        """Shared decoder trunk: latent -> (B, C_out, 76, 76) squashed with tanh."""
        h = F.leaky_relu(fc(latent)).view(latent.size(0), 64, 6, 6)
        h = F.leaky_relu(deconv1(h))
        h = F.leaky_relu(deconv2(h))
        return torch.tanh(deconv3(h))

    def forward(self, x):
        # --- Encoder ---
        h = F.leaky_relu(self.conv1(x))
        h = F.leaky_relu(self.conv2(h))
        h = F.leaky_relu(self.conv3(h))
        latent = F.leaky_relu(self.fc1(torch.flatten(h, 1)))

        # --- Heads ---
        recon_full = self._decode(latent, self.fc2, self.deconv1, self.deconv2, self.deconv3)
        recon_diff = self._decode(
            latent, self.diff_fc2, self.diff_deconv1, self.diff_deconv2, self.diff_deconv3
        )
        ball_pos = torch.sigmoid(self.ball_fc(latent)).view(latent.size(0), self.input_channels, 2)

        return recon_full, recon_diff, ball_pos, latent

In [None]:
# Build the model with default settings and move its parameters to the selected device
# (nn.Module.to returns the module itself).
autoencoder = Autoencoder().to(device)

In [None]:
# Optionally warm-start from an earlier checkpoint stored as a Kaggle input.
import os

CHECKPOINT_PATH = "/kaggle/input/asdfghj/autoencoder_6.pth"

if os.path.exists(CHECKPOINT_PATH):
    # The checkpoint is assumed to be a plain state_dict (like the ones the training loop saves),
    # so weights_only=True refuses arbitrary pickled objects.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    # strict=False tolerates architecture drift; report what did not line up instead of
    # silently continuing from partially random weights.
    load_result = autoencoder.load_state_dict(state_dict, strict=False)
    print(f"Loaded {CHECKPOINT_PATH}")
    print(f"  missing keys:    {load_result.missing_keys}")
    print(f"  unexpected keys: {load_result.unexpected_keys}")
else:
    print(f"No checkpoint at {CHECKPOINT_PATH}; training from scratch.")

In [None]:
# Adam over every autoencoder parameter: shared encoder plus all three heads.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(params=autoencoder.parameters(), lr=LEARNING_RATE)

In [None]:
# Train on three objectives at once: full-stack reconstruction, weighted frame-difference
# prediction and per-frame ball position. At epochs 0, 5, 10 and every 50th epoch, save a
# checkpoint and visualize the first sample of the last batch.
num_epochs = 1001


def diff_loss_weights(diff_target):
    """Per-pixel weights for the difference loss, normalized to sum to 1 over the whole batch.

    Changed pixels (diff != 0) get weight 1, or 11 inside rows 9:15 and 29:67 (the rows this
    loss up-weights -- presumably where the ball moves; confirm). Every pixel then gets +0.01
    so unchanged pixels still contribute a little.
    """
    changed = (diff_target != 0).float()
    ball_rows = torch.zeros_like(changed)
    ball_rows[:, :, 29:67, :] = 1.0
    ball_rows[:, :, 9:15, :] = 1.0
    weights = changed + 10 * changed * ball_rows + 0.01
    return weights / weights.sum()


for epoch in tqdm(range(num_epochs)):

    autoencoder.train()
    total_loss = 0.0

    # `frames_in` rather than `input`, which would shadow the builtin.
    for frames_in, target, diff_target, pos_target in tqdm(minari_dataloader, leave=False):

        frames_in = frames_in.to(device)
        target = target.to(device)
        diff_target = diff_target.to(device)
        pos_target = pos_target.to(device)

        # --- Forward pass ---
        recon, recon_diff, recon_pos, _ = autoencoder(frames_in)

        # --- Losses ---
        loss_recon = F.mse_loss(recon, target)                    # full-stack reconstruction
        weight_diff = diff_loss_weights(diff_target)
        loss_diff = ((recon_diff - diff_target) ** 2 * weight_diff).sum()
        loss_pos = F.mse_loss(recon_pos, pos_target)              # ball position per frame
        loss = loss_recon + loss_diff + loss_pos

        # --- Backprop ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is exact even with a smaller last batch.
        total_loss += loss.item() * frames_in.size(0)

    avg_loss = total_loss / len(minari_dataloader.dataset)
    tqdm.write(f"Epoch {epoch:03d} | Loss: {avg_loss:.6f}")  # does not break the progress bars

    # --- Visualization / checkpointing ---
    if epoch % 50 == 0 or epoch in (5, 10):

        torch.save(autoencoder.state_dict(), f"autoencoder_{epoch:03d}.pth")

        shards = {
            "Input": frames_in,
            "Target": target,
            "Recon": recon,
            "Diff_Target": diff_target,
            "Weight-Diff": weight_diff,
            "Recon_Diff": recon_diff,
        }

        for key, value in shards.items():
            shard = value[0].detach().cpu().numpy()
            # squeeze=False keeps `axes` 2-D even when the shard has a single channel.
            fig, axes = plt.subplots(1, shard.shape[0], figsize=(9, 3), squeeze=False)
            for i, ax in enumerate(axes[0]):
                ax.imshow(shard[i], cmap='gray')
                ax.set_title(f'{key} - Frame {i+1}')
                ax.axis('off')
            plt.show()
            plt.close(fig)

        print(f"Pos Target: {pos_target[0].cpu()}")
        print(f"Recon Pos: {recon_pos[0].detach().cpu()}")



In [None]:
# Qualitative check: reconstruct a few random stacks from the dataset with the current model
# and plot inputs, targets and predictions side by side.
autoencoder.eval()
num_samples = 5  # number of samples to visualize
sample_indices = random.sample(range(len(minari_dataset)), min(num_samples, len(minari_dataset)))

for idx in sample_indices:
    # `frames_in` rather than `input`, which would shadow the builtin.
    frames_in, target, diff_target, pos_target = minari_dataset[idx]

    # Add a batch dimension, run the model without tracking gradients.
    with torch.no_grad():
        recon, recon_diff, recon_pos, _ = autoencoder(frames_in.unsqueeze(0).to(device))

    # Everything as CPU numpy arrays, one (C, H, W) stack per entry.
    shards = {
        "Input": frames_in.cpu().numpy(),
        "Target": target.cpu().numpy(),
        "Recon": recon[0].cpu().numpy(),
        "Diff_Target": diff_target.cpu().numpy(),
        "Recon_Diff": recon_diff[0].cpu().numpy(),
    }

    for key, value in shards.items():
        fig, axes = plt.subplots(1, value.shape[0], figsize=(9, 3), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ax.imshow(value[i], cmap='gray')
            ax.set_title(f'{key} - Frame {i+1}')
            ax.axis('off')
        fig.suptitle(f'Sample {idx} - {key}')
        plt.show()
        plt.close(fig)

    # Both printed as (stack_size, 2) CPU tensors so they can be compared directly.
    print(f"Pos Target: {pos_target}")
    print(f"Pos Recon: {recon_pos[0].cpu()}")