<a href="https://colab.research.google.com/github/florianaewing/CSB430SWIWinter2026/blob/main/videogeneratormodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
import os
print(os.getcwd())
print(os.listdir())

/content
['.config', 'sample_data', '.ipynb_checkpoints']


In [38]:
import cv2
import numpy as np

video_path = "sample_data/daytime_earth.mp4"
cap = cv2.VideoCapture(video_path)

frames = []
frame_count = 0
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # Resize to 32x18 (16:9) for pixel-art
    frame = cv2.resize(frame, (32, 18))
    # Ensure RGB
    if frame.shape[2] == 3:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    elif frame.shape[2] == 4:  # remove alpha if present
        frame = frame[:, :, :3]
    # Normalize to [0,1]
    frame = frame.astype(np.float32) / 255.0
    frames.append(frame)
    frame_count += 1

cap.release()
frames = np.array(frames)
print(f"Loaded {frame_count} frames, shape: {frames.shape}")


Loaded 2795 frames, shape: (2795, 18, 32, 3)


In [39]:
seq_len = 16
X, Y = [], []

num_sequences = len(frames) - seq_len
if num_sequences <= 0:
    raise ValueError(f"Not enough frames ({len(frames)}) for sequence length {seq_len}")

for i in range(num_sequences):
    seq_x = frames[i:i+seq_len]
    seq_y = frames[i+1:i+seq_len+1]
    # Ensure sequences have correct shape
    if seq_x.shape != (seq_len, 18, 32, 3) or seq_y.shape != (seq_len, 18, 32, 3):
        continue
    X.append(seq_x)
    Y.append(seq_y)

X = np.array(X)
Y = np.array(Y)
print(f"Created {X.shape[0]} sequences, X shape: {X.shape}, Y shape: {Y.shape}")

max_sequences = 1000
X = X[:max_sequences]
Y = Y[:max_sequences]
print(f"Using {X.shape[0]} sequences for training")

# Transpose to PyTorch format: (B, T, C, H, W)
X = np.transpose(X, (0, 1, 4, 2, 3))
Y = np.transpose(Y, (0, 1, 4, 2, 3))
print("After transpose, X shape:", X.shape, "Y shape:", Y.shape)


Created 2779 sequences, X shape: (2779, 16, 18, 32, 3), Y shape: (2779, 16, 18, 32, 3)
Using 1000 sequences for training
After transpose, X shape: (1000, 16, 3, 18, 32) Y shape: (1000, 16, 3, 18, 32)


In [40]:
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(input_channels + hidden_channels,
                              4 * hidden_channels,
                              kernel_size,
                              padding=padding)
        self.hidden_channels = hidden_channels

    def forward(self, x, h, c):
        combined = torch.cat([x, h], dim=1)
        conv_out = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_out, 4, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

class ConvLSTM(nn.Module):
    def __init__(self, input_channels=3, hidden_channels=64, kernel_size=3):
        super().__init__()
        self.cell = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
        self.decoder = nn.Conv2d(hidden_channels, input_channels, kernel_size=3, padding=1)

    def forward(self, input_seq):
        B, T, C, H, W = input_seq.shape
        h = torch.zeros(B, 64, H, W).to(input_seq.device)
        c = torch.zeros(B, 64, H, W).to(input_seq.device)
        outputs = []
        for t in range(T):
            h, c = self.cell(input_seq[:, t], h, c)
            out = torch.sigmoid(self.decoder(h))
            outputs.append(out)
        return torch.stack(outputs, dim=1)


In [41]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConvLSTM().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
Y_tensor = torch.tensor(Y, dtype=torch.float32).to(device)

batch_size = 4
num_epochs = 50

for epoch in range(num_epochs):
    for i in range(0, len(X_tensor), batch_size):
        x_batch = X_tensor[i:i+batch_size]
        y_batch = Y_tensor[i:i+batch_size]
        optimizer.zero_grad()
        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")


Epoch 1/50, Loss: 0.006846
Epoch 2/50, Loss: 0.005299
Epoch 3/50, Loss: 0.005066
Epoch 4/50, Loss: 0.004828
Epoch 5/50, Loss: 0.004708
Epoch 6/50, Loss: 0.004612
Epoch 7/50, Loss: 0.004570
Epoch 8/50, Loss: 0.004542
Epoch 9/50, Loss: 0.004506
Epoch 10/50, Loss: 0.004454
Epoch 11/50, Loss: 0.004415
Epoch 12/50, Loss: 0.004384
Epoch 13/50, Loss: 0.004355
Epoch 14/50, Loss: 0.004324
Epoch 15/50, Loss: 0.004298
Epoch 16/50, Loss: 0.004257
Epoch 17/50, Loss: 0.004218
Epoch 18/50, Loss: 0.004171
Epoch 19/50, Loss: 0.004130
Epoch 20/50, Loss: 0.004085
Epoch 21/50, Loss: 0.004038
Epoch 22/50, Loss: 0.004002
Epoch 23/50, Loss: 0.003970
Epoch 24/50, Loss: 0.003930
Epoch 25/50, Loss: 0.003892
Epoch 26/50, Loss: 0.003886
Epoch 27/50, Loss: 0.003877
Epoch 28/50, Loss: 0.003855
Epoch 29/50, Loss: 0.003838
Epoch 30/50, Loss: 0.003813
Epoch 31/50, Loss: 0.003766
Epoch 32/50, Loss: 0.003739
Epoch 33/50, Loss: 0.003731
Epoch 34/50, Loss: 0.003737
Epoch 35/50, Loss: 0.003709
Epoch 36/50, Loss: 0.003719
E

# Real Time Display

In [None]:
model.eval()
seed = X_tensor[0:1]  # first sequence
h = torch.zeros(1, 64, 18, 32).to(device)
c = torch.zeros(1, 64, 18, 32).to(device)

frame = seed[:, 0]

scale = 20  # scale tiny frames for display

import cv2
from google.colab.patches import cv2_imshow # Import cv2_imshow

with torch.no_grad():
    while True:  # keep generating until window closed
        h, c = model.cell(frame, h, c)
        frame = torch.sigmoid(model.decoder(h))
        # Pixel-art quantization
        frame_vis = (frame * 15).round() / 15
        # Convert to HWC for OpenCV
        img = frame_vis.cpu().numpy()[0].transpose(1, 2, 0)
        img_up = cv2.resize(img, (32*scale, 18*scale), interpolation=cv2.INTER_NEAREST)
        img_up = (img_up * 255).astype('uint8')[:, :, ::-1]  # RGB -> BGR

        cv2_imshow(img_up) # Use cv2_imshow instead of cv2.imshow
        if cv2.waitKey(30) & 0xFF == ord('q'):
            break  # press 'q' to quit

cv2.destroyAllWindows()