In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2LMHeadModel
import imageio, os
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# model settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3
CONVERTED_FRAMERATE = 16

WINDOW_SIZE = 24
INPUT_WINDOW_SIZE = WINDOW_SIZE - 1
ENCODED_DIM = 1200
NUM_TRANSFORMER_BLOCKS = 6
MLP_HIDDEN_DIM = 2000
NUM_HEADS = 12

DROPOUT = 0.0

In [3]:
# --- TRANSFORMER ---

In [4]:
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads, dropout=0.1):
        super().__init__()
        assert ENCODED_DIM % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = ENCODED_DIM // num_heads

        self.qkv_proj = nn.Linear(ENCODED_DIM, 3 * ENCODED_DIM)
        self.out_proj = nn.Linear(ENCODED_DIM, ENCODED_DIM)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = attn_weights @ v
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, num_heads, mlp_hidden_dim, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(ENCODED_DIM)
        self.attn = CausalSelfAttention(num_heads, dropout)
        self.ln2 = nn.LayerNorm(ENCODED_DIM)
        self.mlp = nn.Sequential(
            nn.Linear(ENCODED_DIM, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, ENCODED_DIM),
            nn.Dropout(dropout)
        )

    def forward(self, x, attn_mask=None):
        x = x + self.attn(self.ln1(x), attn_mask)
        x = x + self.mlp(self.ln2(x))
        return x

In [6]:
class Transformer(nn.Module):
    def __init__(self, seq_len=INPUT_WINDOW_SIZE, num_heads=NUM_HEADS, mlp_hidden_dim=MLP_HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(num_heads, mlp_hidden_dim, dropout)
            for _ in range(NUM_TRANSFORMER_BLOCKS)
        ])
        self.ln_f = nn.LayerNorm(ENCODED_DIM)
        self.register_buffer("causal_mask", torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        B, T, _ = x.shape
        attn_mask = self.causal_mask[:, :, :T, :T]

        for block in self.blocks:
            x = block(x, attn_mask)

        return self.ln_f(x)

In [7]:
# --- AUTOENCODER ---

In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1)
        )
    def forward(self, x):
        return x + self.block(x)

In [9]:
class Autoencoder(nn.Module):
    def __init__(self, in_channels=CHANNELS, latent_dim=ENCODED_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # 64x64
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),          # 16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),         # 8x8
            nn.ReLU()
        )

        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_resolution)
            enc_out = self.encoder(dummy)
            self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, enc_out.shape[1:]),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.encoder_fc(x)

    def decode(self, z):
        z = self.decoder_fc(z)
        return self.decoder(z)

In [10]:
# --- MISC ---

In [11]:
class ImageProcessor:
    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """
        Convert a tensor to a PIL Image.
        
        Args:
            image_tensor (torch.Tensor): A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        
        Returns:
            Image.Image: A PIL Image object.
        """
        # Clamp to [0, 1], convert to [0, 255] and uint8
        image_np = (image_tensor.clamp(0, 1).mul(255).byte().cpu().permute(1, 2, 0).numpy())
        return Image.fromarray(image_np)
    
    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a PyTorch tensor of shape (C, H, W) with values in [0, 1].
        
        Args:
            image (Image.Image): A PIL Image object.
        
        Returns:
            torch.Tensor: A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        """
        return transforms.ToTensor()(image)  # Already returns (C, H, W)

In [12]:
# Load the autoencoder
autoencoder = Autoencoder()
autoenced_state_dict = torch.load("checkpoints/run4/autoencoder.pth", map_location=run_device)
autoencoder.load_state_dict(autoenced_state_dict)
autoencoder = autoencoder.to(run_device).eval()

# load the transformer
transformer = Transformer()
transformer_state_dict = torch.load("checkpoints/run4/transformer.pth", map_location=run_device)
transformer.load_state_dict(transformer_state_dict)
transformer = transformer.to(run_device).eval()

# Load image processor
proc = ImageProcessor()

In [None]:
1: custom transformer doesnt seem to improve results by much
2: noticed that numbers coming from decoder were astronomical, look into activ funcs
3: vgg might not be so good after all, consider MSE only training

In [13]:
def generate_frames(num_frames, context_length=INPUT_WINDOW_SIZE, autoencoder=autoencoder, transformer=transformer):
    autoencoder.eval()
    transformer.eval()
    
    total_seq = torch.zeros(num_frames + context_length, CHANNELS, RESOLUTION_HEIGHT, RESOLUTION_WIDTH, device=run_device)

    with torch.no_grad():
        for i in tqdm(range(num_frames)):
            current_slice = total_seq[i:i + context_length]
            
            slice_latents = autoencoder.encode(current_slice)
            
            slice_latents += torch.rand(slice_latents.shape, device=run_device) * 99999
            
            prediction_latents = transformer(slice_latents.unsqueeze(0))
            
            prediction_frame = autoencoder.decode(prediction_latents.squeeze(0))  # shape: (context_length, C, H, W)
            
            total_seq[context_length + i] = prediction_frame[-1]
    
    return total_seq

In [14]:
def save_video(frames, output='output.mp4', fps=60):
    writer = imageio.get_writer(output, fps=fps)
    for frame in tqdm(frames):
        img = frame.permute(1,2,0).numpy()
        img = ((img + 1)/2 * 255).astype('uint8')
        writer.append_data(img)
    writer.close()
    print(f'Saved {output}')

In [15]:
init_img = proc.pil_to_tensor(Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))).to(run_device)

frames = generate_frames(512).detach().cpu()
torch.cuda.empty_cache()
save_video(frames)

100%|██████████| 512/512 [00:06<00:00, 78.59it/s]
100%|██████████| 535/535 [00:00<00:00, 2609.30it/s]

Saved output.mp4





In [16]:
frames.shape

torch.Size([535, 3, 128, 128])

In [17]:
img = Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))

In [18]:
enc_img = proc.pil_to_tensor(img).to(run_device).unsqueeze(0)

In [19]:
latent = autoenc.encode(enc_img)

NameError: name 'autoenc' is not defined

In [None]:
prediction = transformer(inputs_embeds=latent).last_hidden_state

In [None]:
decoded = autoenc.decode(prediction)

In [None]:
proc.tensor_to_pil(decoded.squeeze(0))

In [None]:
init_img

tensor([[[0.4941, 0.4980, 0.4980,  ..., 0.4157, 0.4157, 0.4157],
         [0.4980, 0.5020, 0.4980,  ..., 0.4196, 0.4196, 0.4196],
         [0.4980, 0.5020, 0.5059,  ..., 0.4196, 0.4157, 0.4196],
         ...,
         [0.4314, 0.3804, 0.4667,  ..., 0.2784, 0.3804, 0.3686],
         [0.6471, 0.6235, 0.5843,  ..., 0.2627, 0.2706, 0.2196],
         [0.4588, 0.4196, 0.2902,  ..., 0.3020, 0.2784, 0.2784]],

        [[0.6745, 0.6745, 0.6784,  ..., 0.6196, 0.6196, 0.6235],
         [0.6784, 0.6824, 0.6784,  ..., 0.6275, 0.6275, 0.6275],
         [0.6784, 0.6824, 0.6863,  ..., 0.6275, 0.6235, 0.6275],
         ...,
         [0.4510, 0.4039, 0.4980,  ..., 0.3294, 0.4275, 0.4039],
         [0.6863, 0.6706, 0.6353,  ..., 0.3216, 0.3255, 0.2706],
         [0.5137, 0.4667, 0.3176,  ..., 0.3765, 0.3490, 0.3529]],

        [[0.8392, 0.8392, 0.8392,  ..., 0.8392, 0.8353, 0.8314],
         [0.8353, 0.8353, 0.8392,  ..., 0.8431, 0.8392, 0.8353],
         [0.8353, 0.8392, 0.8431,  ..., 0.8392, 0.8392, 0.