In [1]:
import torch
import torch.nn as nn
from torchvision.models import vgg16
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model
from decord import VideoReader, cpu
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
import os, gc
from math import floor
from PIL import Image
import numpy as np



# merge import ahh
import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
import os
import random

In [2]:
run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
EPOCHS = 5
LR = 1e-3
WINDOW_SIZE = 16
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3
BOTTLENECK_DIM = 768

In [3]:
loss = nn.MSELoss()

In [4]:
class PreprocessingFrameDataset(Dataset):
    def __init__(self, folder_path, window_size=WINDOW_SIZE,
                 resize=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT),
                 cache_dir='preprocessed_frames'):
        self.folder_path = folder_path
        self.window_size = window_size
        self.resize = resize
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        self.resize_transform = T.Compose([
            T.ToPILImage(),
            T.Resize(resize),
            T.ToTensor()
        ])

        self.frame_files = []
        self.index = []
        self._prepare_frames()

    def _prepare_frames(self):
        video_files = [f for f in os.listdir(self.folder_path) if f.endswith('.mp4')]
        for i, fname in enumerate(video_files):
            base = os.path.splitext(fname)[0]
            cache_path = os.path.join(self.cache_dir, base + '.pt')
            if not os.path.exists(cache_path):
                print(f'Preprocessing {fname} -> {cache_path}')
                vr = VideoReader(os.path.join(self.folder_path, fname), ctx=cpu())
                frames = [self.resize_transform(frame.asnumpy()) for frame in vr]
                torch.save(torch.stack(frames), cache_path)
                del frames, vr
                gc.collect()
            self.frame_files.append(cache_path)
            frame_len = torch.load(cache_path, map_location='cpu').shape[0]
            n_clips = floor(frame_len / self.window_size)
            for j in range(n_clips):
                self.index.append((i, j * self.window_size))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_idx, start = self.index[idx]
        frames = torch.load(self.frame_files[file_idx], mmap=True, map_location='cpu')
        return frames[start:start + self.window_size]

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )
    def forward(self, x):
        return x + self.block(x)

class ConvAutoencoder(nn.Module):
    def __init__(self, in_channels=CHANNELS, latent_dim=BOTTLENECK_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512)
        )

        # Infer shape
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_resolution)
            enc_out = self.encoder(dummy)
            self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, enc_out.shape[1:]),
            ResidualBlock(512),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, in_channels, 4, 2, 1),
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.encoder_fc(x)
    
    def decode(self, z):
        z = self.decoder_fc(z)
        return self.decoder(z)


In [6]:
class Trainer:
    def __init__(self, autoenc, transformer, dataloader, epochs=EPOCHS, lr=LR, device=run_device, loss=loss):
        self.autoenc = autoenc
        self.transformer = transformer
        self.dataloader = dataloader
        self.epochs = epochs
        self.device = device
        params = list(autoenc.parameters()) + list(transformer.parameters())
        self.optimizer = torch.optim.Adam(params, lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=epochs)
        self.loss_fn = loss

    def train(self):
        self.autoenc.train()
        self.transformer.train()
        for epoch in range(self.epochs):
            total_loss = 0.0
            for frames in self.dataloader:
                frames = frames.to(self.device)
                B, T, C, H, W = frames.shape
                latents = self.autoenc.encode(frames.view(B * T, C, H, W)).view(B, T, -1)
                inp = latents[:, :-1, :]
                target_frames = frames[:, 1:, :, :, :]
                self.optimizer.zero_grad()
                pred_latents = self.transformer(inputs_embeds=inp).last_hidden_state
                pred_frames = self.autoenc.decode(pred_latents.reshape(-1, BOTTLENECK_DIM)).view(B, T-1, C, H, W)
                loss = self.loss_fn(pred_frames, target_frames)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                
                del frames
                
            avg_loss = total_loss / len(self.dataloader)
            lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{self.epochs} - Loss: {avg_loss:.4f} - LR: {lr:.6f}')
            self.scheduler.step()


In [7]:
autoenc = ConvAutoencoder().to(run_device)
transformer = GPT2Model.from_pretrained('decap_gpt2_cm2').to(run_device)

dataset = PreprocessingFrameDataset('video_dataset')
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

trainer = Trainer(autoenc, transformer, dataloader)


In [8]:
trainer.train()

Epoch 1/5 - Loss: 0.0582 - LR: 0.001000
Epoch 2/5 - Loss: 0.0406 - LR: 0.000905
Epoch 3/5 - Loss: 0.0318 - LR: 0.000655
Epoch 4/5 - Loss: 0.0266 - LR: 0.000345
Epoch 5/5 - Loss: 0.0235 - LR: 0.000095


In [1]:
torch.save(autoenc, "./checkpoints/run1/autoenc")

NameError: name 'torch' is not defined

In [2]:
WINDOW_SIZE

NameError: name 'WINDOW_SIZE' is not defined

In [10]:
BATCH_SIZE

2

In [11]:
8192 / 64

128.0