In [1]:
import torch
import torch.nn as nn
from torchvision.models import vgg16
from transformers import GPT2LMHeadModel, GPT2Model
from decord import VideoReader, cpu
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from math import floor
import os
import random
import gc
import math

In [2]:
# pytorch settings
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# gpt2 settings
decap_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2").transformer.to(run_device)
#decap_gpt2 = GPT2Model.from_pretrained("decap_gpt2_ae3").to(run_device)

# training settings
BATCH_SIZE = 24
EPOCHS = 512
LR = 0.001
WINDOW_SIZE = 64

# video settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3

# model settings
BOTTLENECK_DIM = 768

In [3]:
class ImageProcessor:
    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """
        Convert a tensor to a PIL Image.
        
        Args:
            image_tensor (torch.Tensor): A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        
        Returns:
            Image.Image: A PIL Image object.
        """
        # Clamp to [0, 1], convert to [0, 255] and uint8
        image_np = (image_tensor.clamp(0, 1).mul(255).byte().cpu().permute(1, 2, 0).numpy())
        return Image.fromarray(image_np)
    
    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a PyTorch tensor of shape (C, H, W) with values in [0, 1].
        
        Args:
            image (Image.Image): A PIL Image object.
        
        Returns:
            torch.Tensor: A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        """
        return transforms.ToTensor()(image)  # Already returns (C, H, W)

In [4]:
class ImageCompressor:
    def __init__(self, resolution=(RESOLUTION_HEIGHT, RESOLUTION_WIDTH), channels=CHANNELS, latent_dim=BOTTLENECK_DIM):
        self.H, self.W = resolution
        self.C = channels
        self.latent_dim = latent_dim
        self.per_channel_dim = latent_dim // self.C

        # Compute shared downscaled size for all channels
        scale_factor = math.sqrt(self.per_channel_dim / (self.H * self.W))
        self.h_down = max(1, int(self.H * scale_factor))
        self.w_down = max(1, int(self.W * scale_factor))

        # Store meta info for decoding
        self.meta = {
            'H': self.H,
            'W': self.W,
            'h_down': self.h_down,
            'w_down': self.w_down
        }

    def encode(self, image):
        # image: (3, H, W) or (B, 3, H, W)
        is_batched = image.dim() == 4
        if not is_batched:
            image = image.unsqueeze(0)  # Make it (1, 3, H, W)

        B = image.shape[0]
        latent_parts = []

        for c in range(self.C):
            ch = image[:, c:c+1, :, :]  # (B, 1, H, W)
            down = F.interpolate(ch, size=(self.h_down, self.w_down), mode='bilinear', align_corners=False)
            flat = down.view(B, -1)

            if flat.shape[1] < self.per_channel_dim:
                pad = torch.zeros((B, self.per_channel_dim - flat.shape[1]), device=flat.device, dtype=flat.dtype)
                flat = torch.cat([flat, pad], dim=1)
            else:
                flat = flat[:, :self.per_channel_dim]

            latent_parts.append(flat)

        latent = torch.cat(latent_parts, dim=1)  # (B, latent_dim)

        if not is_batched:
            latent = latent.squeeze(0)

        return latent

    def decode(self, latent):
        # latent: (latent_dim,) or (B, latent_dim)
        is_batched = latent.dim() == 2
        if not is_batched:
            latent = latent.unsqueeze(0)

        B = latent.shape[0]
        per_channel_dim = self.per_channel_dim
        h_down, w_down = self.h_down, self.w_down
        H, W = self.H, self.W

        channels = []
        for i in range(self.C):
            start = i * per_channel_dim
            end = start + per_channel_dim
            flat = latent[:, start:end]
            ch = flat[:, :h_down * w_down].reshape(B, 1, h_down, w_down)
            up = F.interpolate(ch, size=(H, W), mode='bilinear', align_corners=False)
            channels.append(up)

        recon = torch.cat(channels, dim=1)  # (B, C, H, W)

        if not is_batched:
            recon = recon.squeeze(0)

        return recon


In [5]:
class PreprocessingLatentDataset(Dataset):
    def __init__(self, folder_path, compressor, window_size=WINDOW_SIZE, resize=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT),
                 cache_dir="preprocessed_latents", run_device="cuda"):
        self.folder_path = folder_path
        self.compressor = compressor
        self.window_size = window_size
        self.resize = resize
        self.cache_dir = cache_dir
        self.run_device = run_device

        os.makedirs(self.cache_dir, exist_ok=True)

        self.resize_transform = T.Compose([
            T.ToPILImage(),
            T.Resize(resize, interpolation=Image.BILINEAR),
            T.ToTensor()
        ])

        # Step 1: Build index and preprocess as needed
        self.latent_files = []
        self.index = []  # (file_idx, start_frame)
        self._prepare_latents()

    def _prepare_latents(self):
        video_files = [f for f in os.listdir(self.folder_path) if f.endswith(".mp4")]

        for i, fname in enumerate(video_files):
            base = os.path.splitext(fname)[0]
            cache_path = os.path.join(self.cache_dir, base + ".pt")

            # Preprocess if missing
            if not os.path.exists(cache_path):
                print(f"Preprocessing {fname} -> {cache_path}")
                video_path = os.path.join(self.folder_path, fname)
                vr = VideoReader(video_path, ctx=cpu())

                latents = []
                for frame in vr:
                    latent = self.compressor.encode(self.resize_transform(frame.asnumpy()))
                    latents.append(latent)
                    del latent
                    torch.cuda.empty_cache()

                torch.save(torch.stack(latents), cache_path)
                del latents, vr
                gc.collect()
                torch.cuda.empty_cache()

            self.latent_files.append(cache_path)

            # Build (file_idx, start_frame) index
            latent_len = torch.load(cache_path, map_location='cpu').shape[0]
            n_clips = floor(latent_len / self.window_size)
            for j in range(n_clips):
                self.index.append((i, j * self.window_size))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_idx, start = self.index[idx]
        latent_seq = torch.load(self.latent_files[file_idx], map_location='cpu')
        return latent_seq[start: start + self.window_size]

In [6]:
class Trainer:
    def __init__(self, model, compressor, dataloader, epochs=EPOCHS, lr=LR, device=run_device, log_dir="runs/transformer/gpt2_cm2"):
        self.model = model
        self.compressor = compressor
        self.dataloader = dataloader
        self.epochs = epochs
        self.device = device

        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=epochs)
        self.loss_fn = nn.MSELoss()

        self.writer = SummaryWriter(log_dir)

    def train(self):
        self.model.train()

        for epoch in range(self.epochs):
            epoch_loss = 0.0

            for batch_idx, batch in enumerate(self.dataloader):
                batch = batch.to(self.device)  # [B, T, D]
                input_seq = batch[:, :-1, :]   # input: all except last time step
                target_seq = batch[:, 1:, :]   # target: all except first time step

                self.optimizer.zero_grad()
                output = self.model(inputs_embeds=input_seq).last_hidden_state  # [B, T-1, D]
                loss = self.loss_fn(output, target_seq)

                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(self.dataloader)
            current_lr = self.optimizer.param_groups[0]["lr"]

            print(f"Epoch {epoch + 1}/{self.epochs} - Loss: {avg_loss:.4f} - LR: {current_lr:.6f}")
            self.writer.add_scalar("Loss/train", avg_loss, epoch)
            self.writer.add_scalar("LR", current_lr, epoch)

            self.scheduler.step()
            self.log_images(epoch, batch)

    def log_images(self, epoch, last_batch_latents):
        self.model.eval()

        with torch.no_grad():
            input_seq = last_batch_latents[:, :-1, :].to(self.device)
            gpt2_outputs = self.model(inputs_embeds=input_seq).last_hidden_state  # [B, T-1, D]

            decoded_images = self.compressor.decode(gpt2_outputs[0, :8, :])
            decoded_images = decoded_images.clamp(0, 1)

            self.writer.add_images("GPT2_Reconstructions", decoded_images, epoch)

        self.model.train()

In [7]:
compressor = ImageCompressor()
dataset = PreprocessingLatentDataset("video_dataset", compressor=compressor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
trainer = Trainer(decap_gpt2, compressor, dataloader, epochs=EPOCHS, lr=LR, device=run_device)

In [8]:
trainer.train()

Epoch 1/512 - Loss: 1.0024 - LR: 0.001000
Epoch 2/512 - Loss: 0.2519 - LR: 0.001000
Epoch 3/512 - Loss: 0.2103 - LR: 0.001000
Epoch 4/512 - Loss: 0.1802 - LR: 0.001000
Epoch 5/512 - Loss: 0.1585 - LR: 0.001000
Epoch 6/512 - Loss: 0.1383 - LR: 0.001000
Epoch 7/512 - Loss: 0.1238 - LR: 0.001000
Epoch 8/512 - Loss: 0.1126 - LR: 0.001000
Epoch 9/512 - Loss: 0.1034 - LR: 0.000999
Epoch 10/512 - Loss: 0.0967 - LR: 0.000999
Epoch 11/512 - Loss: 0.0914 - LR: 0.000999
Epoch 12/512 - Loss: 0.0875 - LR: 0.000999
Epoch 13/512 - Loss: 0.0841 - LR: 0.000999
Epoch 14/512 - Loss: 0.0818 - LR: 0.000998
Epoch 15/512 - Loss: 0.0797 - LR: 0.000998
Epoch 16/512 - Loss: 0.0783 - LR: 0.000998
Epoch 17/512 - Loss: 0.0777 - LR: 0.000998
Epoch 18/512 - Loss: 0.0753 - LR: 0.000997
Epoch 19/512 - Loss: 0.0750 - LR: 0.000997
Epoch 20/512 - Loss: 0.0736 - LR: 0.000997
Epoch 21/512 - Loss: 0.0723 - LR: 0.000996
Epoch 22/512 - Loss: 0.0716 - LR: 0.000996
Epoch 23/512 - Loss: 0.0707 - LR: 0.000995
Epoch 24/512 - Loss:

In [9]:
proc.tensor_to_pil(compressor.latent_to_image(dataset[100][0]))

NameError: name 'proc' is not defined

In [None]:
proc = ImageProcessor()

In [None]:
x = 0

In [None]:
x += 1
proc.tensor_to_pil(autoenc.decode(dataset[0][x].to(run_device).unsqueeze(0)).cpu().squeeze(0))

NameError: name 'autoenc' is not defined

In [None]:
import torch
from torchvision.utils import save_image
import os
from tqdm import tqdm
import subprocess
import imageio

def generate_latent_sequence(model, autoencoder, sequence_length=1000, latent_dim=768, device="cuda"):
    model.eval()
    autoencoder.eval()

    with torch.no_grad():
        generated = []

        current = torch.zeros(1, 1, latent_dim).to(device)  # [B=1, T=1, D]

        for _ in tqdm(range(sequence_length), desc="Generating"):
            out = model(inputs_embeds=current).last_hidden_state  # [1, T, D]
            next_latent = out[:, -1:, :]  # [1, 1, D]
            current = torch.cat([current, next_latent], dim=1)

        return current.squeeze(0)  # [T, D]

def decode_latents_to_images(latents, autoencoder, output_folder="generated_video"):
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        for i, latent in enumerate(tqdm(latents, desc="Decoding")):
            img = autoencoder.decode(latent.unsqueeze(0)).clamp(0, 1)
            save_image(img, os.path.join(output_folder, f"frame_{i:04d}.png"))

def make_video_imageio(frame_folder="generated_video", output_file="output.mp4", fps=60):
    frames = sorted([
        os.path.join(frame_folder, f)
        for f in os.listdir(frame_folder)
        if f.endswith(".png")
    ])

    writer = imageio.get_writer(output_file, fps=fps)
    for frame_path in frames:
        image = imageio.imread(frame_path)
        writer.append_data(image)
    writer.close()
    print(f"✅ Video saved using imageio: {output_file}")

# === CONFIG ===
LATENT_DIM = 768
FRAMES = 1000  # 1 second @ 60 FPS
DEVICE = "cuda"

# === RUN ===
latents = generate_latent_sequence(model=decap_gpt2, autoencoder=autoenc, sequence_length=FRAMES, latent_dim=LATENT_DIM, device=DEVICE)
decode_latents_to_images(latents, autoencoder=autoenc, output_folder="generated_video")
make_video_imageio(frame_folder="generated_video", output_file="generated_output.mp4", fps=60)

In [None]:
#decap_gpt2.save_pretrained("decap_gpt2_ae3")