In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2LMHeadModel
import imageio, os
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
run_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# video settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3
CONVERTED_FRAMERATE = 16

# model settings
WINDOW_SIZE = 46
ENCODED_DIM = 768

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1)
        )
    def forward(self, x):
        return x + self.block(x)

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder between image frames and flat latent vectors.

    The encoder applies four stride-2 convolutions (each halves H and W, ending
    at 256 channels), flattens, and projects to ``latent_dim`` values squashed
    by ``tanh``. The decoder mirrors this with a linear layer and four stride-2
    transposed convolutions.

    Args:
        in_channels: number of image channels.
        latent_dim: length of the latent vector returned by ``encode``.
        input_resolution: ``(width, height)`` of the frames in pixels; used to
            size the linear bottleneck.

    Note:
        Attribute names and layer order determine the ``state_dict`` keys, so
        they must stay as-is for existing checkpoints to load.
    """

    def __init__(self, in_channels=CHANNELS, latent_dim=ENCODED_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # H/2  x W/2  (64x64 for a 128x128 input)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # H/4  x W/4  (32x32)
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),          # H/8  x W/8  (16x16)
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),         # H/16 x W/16 (8x8)
            nn.ReLU()
        )

        # Probe the conv stack with a dummy frame to get the bottleneck shape.
        # Tensors are laid out (N, C, H, W), so the (width, height) tuple is
        # unpacked and passed as (height, width); splatting it directly would
        # swap the axes for non-square frames.
        width, height = input_resolution
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, height, width)
            enc_out = self.encoder(dummy)
            self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, enc_out.shape[1:]),     # (256, H/16, W/16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # H/8 x W/8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # H/4 x W/4
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # H/2 x W/2
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # H x W
            nn.Tanh()
        )

    def encode(self, x):
        """Map a batch of frames to latents.

        Args:
            x: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, latent_dim) with values in (-1, 1).
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.tanh(self.encoder_fc(x))

        return x

    def decode(self, z):
        """Map latents of shape (N, latent_dim) back to frames (N, in_channels, H, W).

        NOTE(review): ``self.decoder`` already ends in ``Tanh`` (range [-1, 1]),
        so the extra ``sigmoid`` limits outputs to about [0.269, 0.731] rather
        than [0, 1]. It is left unchanged because the loaded checkpoint may have
        been trained with this head; confirm against the training code before
        changing it.
        """
        z = self.decoder_fc(z)
        z = torch.sigmoid(self.decoder(z))

        return z

    def forward(self, x):
        """Reconstruct ``x`` of shape (N, in_channels, H, W) as ``decode(encode(x))``."""
        return self.decode(self.encode(x))

In [4]:
class ImageProcessor:
    """Converts between PIL images and float tensors with values in [0, 1].

    Tensors use channel-first (C, H, W) layout in both directions.
    """

    def __init__(self):
        # PIL image -> float tensor (C, H, W) in [0, 1]
        self.pil_to_tensor_transform = transforms.ToTensor()

        # Float tensor (C, H, W) -> PIL image. Steps, in order: clamp into
        # [0, 1], scale to 0..255 bytes (truncating), move channels last as a
        # CPU numpy array, then wrap the array as a PIL image.
        to_pil_steps = (
            lambda t: t.clamp(0, 1),
            lambda t: (t * 255).byte(),
            lambda t: t.permute(1, 2, 0).cpu().numpy(),
            lambda arr: Image.fromarray(arr),
        )
        self.tensor_to_pil_transform = transforms.Compose(
            [transforms.Lambda(step) for step in to_pil_steps]
        )

    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Return ``image`` as a float tensor of shape (C, H, W) with values in [0, 1]."""
        transform = self.pil_to_tensor_transform
        return transform(image)

    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """Return a PIL image built from a (C, H, W) tensor, clamped to [0, 1] first."""
        transform = self.tensor_to_pil_transform
        return transform(image_tensor)

In [5]:
# All weights for this run come from one experiment directory.
CHECKPOINT_DIR = "checkpoints/exp12"

# Load the autoencoder. The file is consumed as a state_dict, so restrict
# unpickling to tensors/containers with weights_only=True instead of allowing
# arbitrary objects to be deserialized.
autoencoder = ConvAutoencoder()
autoenced_state_dict = torch.load(
    os.path.join(CHECKPOINT_DIR, "autoencoder.pth"),
    map_location=run_device,
    weights_only=True,
)
autoencoder.load_state_dict(autoenced_state_dict)
autoencoder = autoencoder.to(run_device).eval()

# Load the transformer: only the GPT-2 backbone (`.transformer`) is kept; the
# LM head of the saved GPT2LMHeadModel is discarded.
transformer = GPT2LMHeadModel.from_pretrained(os.path.join(CHECKPOINT_DIR, "gpt2_transformer")).transformer
transformer = transformer.to(run_device).eval()

# Load image processor
proc = ImageProcessor()

In [6]:
def generate_frames(num_frames, initial_frame=None, initial_frame_length=WINDOW_SIZE, context_length=WINDOW_SIZE, autoencoder=autoencoder, transformer=transformer, noise_scale=0.0):
    """Autoregressively roll out ``num_frames`` new frames.

    A buffer of ``context_length + num_frames`` frames is allocated on
    ``run_device``. At step ``i`` the window ``[i, i + context_length)`` is
    encoded to latents, passed through the transformer as ``inputs_embeds``,
    and decoded back to images. The last decoded frame is written to index
    ``context_length + i``.

    Args:
        num_frames: number of frames to generate.
        initial_frame: optional seed image tensor, assumed (C, H, W) matching
            CHANNELS / RESOLUTION_HEIGHT / RESOLUTION_WIDTH. It is repeated
            ``initial_frame_length`` times at the start of the buffer. Any
            remaining context slots stay zero. If None, the initial context is
            all zeros.
        initial_frame_length: number of leading buffer slots filled with
            ``initial_frame``, capped at the buffer length.
        context_length: number of frames the transformer sees per step.
        autoencoder: model providing ``encode`` / ``decode``.
        transformer: model called with ``inputs_embeds`` whose output exposes
            ``last_hidden_state``.
        noise_scale: amplitude of uniform noise in [-noise_scale, noise_scale]
            added to the window latents before the transformer. The default of
            0.0 disables it.

    Returns:
        Tensor of shape (context_length + num_frames, CHANNELS,
        RESOLUTION_HEIGHT, RESOLUTION_WIDTH) on ``run_device``: the seed
        context followed by the generated frames.
    """
    autoencoder.eval()
    transformer.eval()

    total_seq = torch.zeros(num_frames + context_length, CHANNELS, RESOLUTION_HEIGHT, RESOLUTION_WIDTH, device=run_device)

    if initial_frame is not None:
        # Cap the seed so the slice assignment cannot overrun the buffer.
        seed_length = min(initial_frame_length, total_seq.shape[0])
        initial_frame = initial_frame.to(run_device)
        total_seq[:seed_length] = initial_frame.repeat(seed_length, 1, 1, 1)  # (seed_length, C, H, W)

    with torch.no_grad():
        for i in tqdm(range(num_frames)):
            current_slice = total_seq[i:i + context_length]

            slice_latents = autoencoder.encode(current_slice)

            # Optional latent jitter. It is skipped entirely when disabled,
            # rather than drawing random numbers only to scale them by zero.
            if noise_scale:
                slice_latents = slice_latents + (torch.rand_like(slice_latents) * 2 - 1) * noise_scale

            prediction_latents = transformer(inputs_embeds=slice_latents.unsqueeze(0)).last_hidden_state

            prediction_frame = autoencoder.decode(prediction_latents.squeeze(0))  # (context_length, C, H, W)

            # Min-max rescale to [0, 1]. The range is taken over every decoded
            # frame in the window, though only the last one is kept. clamp_min
            # prevents 0/0 -> NaN on a constant decode; a NaN frame would
            # otherwise be fed back into every later window.
            prediction_frame = prediction_frame - prediction_frame.min()
            prediction_frame = prediction_frame / prediction_frame.max().clamp_min(1e-8)

            total_seq[context_length + i] = prediction_frame[-1]

    return total_seq

In [7]:
def save_video(frames, output='test.mp4', fps=CONVERTED_FRAMERATE):
    """Write frames to a video file with imageio.

    Args:
        frames: iterable of float tensors shaped (C, H, W) with values expected
            in [0, 1], e.g. a (T, C, H, W) tensor. Values outside [0, 1] are
            clipped. Frames may live on any device.
        output: destination file path.
        fps: frame rate of the written video.
    """
    # The context manager closes the writer (and finalizes the file) even if
    # a frame fails to convert.
    with imageio.get_writer(output, fps=fps) as writer:
        for frame in tqdm(frames):
            img = frame.detach().cpu().permute(1, 2, 0).numpy()
            # Clip before the uint8 cast: converting out-of-range floats to
            # uint8 is not well-defined and typically wraps around.
            img = (np.clip(img, 0.0, 1.0) * 255).astype('uint8')
            writer.append_data(img)
    print(f'Saved {output}')

In [8]:
# Seed the rollout with a single still image, resized to the model's frame size.
SAMPLE_IMAGE_PATH = "sample_images/20087.jpg"
NUM_GENERATED_FRAMES = 512

# The context manager closes the underlying file once the image has been decoded.
with Image.open(SAMPLE_IMAGE_PATH) as sample_image:
    init_img = proc.pil_to_tensor(sample_image.convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))).to(run_device)

# Move the result to the CPU before releasing cached GPU memory.
frames = generate_frames(NUM_GENERATED_FRAMES, initial_frame=init_img).detach().cpu()
torch.cuda.empty_cache()
save_video(frames)

100%|██████████| 512/512 [00:13<00:00, 37.69it/s]
100%|██████████| 558/558 [00:00<00:00, 2789.99it/s]

Saved test.mp4



