In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2LMHeadModel
import imageio, os
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# video settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3
CONVERTED_FRAMERATE = 24

# model settings
WINDOW_SIZE = 48
ENCODED_DIM = 768

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1)
        )
    def forward(self, x):
        return x + self.block(x)

In [4]:
class ConvAutoencoder(nn.Module):
    def __init__(self, in_channels=CHANNELS, latent_dim=ENCODED_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # 64x64
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),          # 16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),         # 8x8
            nn.ReLU()
        )

        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_resolution)
            enc_out = self.encoder(dummy)
            self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, enc_out.shape[1:]),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = F.gelu(self.encoder_fc(x))
        
        return x

    def decode(self, z):
        z = F.gelu(self.decoder_fc(z))
        z = torch.sigmoid(self.decoder(z))
        
        return z

In [5]:
class ImageProcessor:
    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """
        Convert a tensor to a PIL Image.
        
        Args:
            image_tensor (torch.Tensor): A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        
        Returns:
            Image.Image: A PIL Image object.
        """
        # Clamp to [0, 1], convert to [0, 255] and uint8
        image_np = (image_tensor.clamp(0, 1).mul(255).byte().cpu().permute(1, 2, 0).numpy())
        return Image.fromarray(image_np)
    
    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a PyTorch tensor of shape (C, H, W) with values in [0, 1].
        
        Args:
            image (Image.Image): A PIL Image object.
        
        Returns:
            torch.Tensor: A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        """
        return transforms.ToTensor()(image)  # Already returns (C, H, W)

In [6]:
# Load the autoencoder
autoencoder = ConvAutoencoder()
autoenced_state_dict = torch.load("checkpoints/run19/autoenc.pth", map_location=run_device)
autoencoder.load_state_dict(autoenced_state_dict)
autoencoder = autoencoder.to(run_device).eval()

# load the transformer
transformer = GPT2LMHeadModel.from_pretrained("checkpoints/run19/gpt2_decap").transformer
transformer = transformer.to(run_device).eval()

# Load image processor
proc = ImageProcessor()

In [7]:
def generate_frames(num_frames, context_length=WINDOW_SIZE, autoencoder=autoencoder, transformer=transformer):
    autoencoder.eval()
    transformer.eval()
    
    total_seq = torch.zeros(num_frames + context_length, CHANNELS, RESOLUTION_HEIGHT, RESOLUTION_WIDTH, device=run_device)

    with torch.no_grad():
        for i in tqdm(range(num_frames)):
            current_slice = total_seq[i:i + context_length]
            
            slice_latents = autoencoder.encode(current_slice)
            
            slice_latents += torch.rand(slice_latents.shape, device=run_device) * 999
            
            prediction_latents = transformer(inputs_embeds=slice_latents.unsqueeze(0)).last_hidden_state
            
            prediction_frame = autoencoder.decode(prediction_latents.squeeze(0))  # shape: (context_length, C, H, W)
            
            total_seq[context_length + i] = prediction_frame[-1]
    
    return total_seq

In [8]:
def save_video(frames, output='output.mp4', fps=60):
    writer = imageio.get_writer(output, fps=fps)
    for frame in tqdm(frames):
        img = frame.permute(1,2,0).numpy()
        img = ((img + 1)/2 * 255).astype('uint8')
        writer.append_data(img)
    writer.close()
    print(f'Saved {output}')

In [9]:
init_img = proc.pil_to_tensor(Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))).to(run_device)

frames = generate_frames(512).detach().cpu()
torch.cuda.empty_cache()
save_video(frames)

100%|██████████| 512/512 [00:10<00:00, 50.62it/s]
100%|██████████| 560/560 [00:00<00:00, 2864.44it/s]

Saved output.mp4





In [10]:
frames.shape

torch.Size([560, 3, 128, 128])

In [11]:
img = Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))

In [12]:
enc_img = proc.pil_to_tensor(img).to(run_device).unsqueeze(0)

In [13]:
latent = autoenc.encode(enc_img)

NameError: name 'autoenc' is not defined

In [None]:
prediction = transformer(inputs_embeds=latent).last_hidden_state

In [None]:
decoded = autoenc.decode(prediction)

In [None]:
proc.tensor_to_pil(decoded.squeeze(0))

In [None]:
init_img

tensor([[[0.4941, 0.4980, 0.4980,  ..., 0.4157, 0.4157, 0.4157],
         [0.4980, 0.5020, 0.4980,  ..., 0.4196, 0.4196, 0.4196],
         [0.4980, 0.5020, 0.5059,  ..., 0.4196, 0.4157, 0.4196],
         ...,
         [0.4314, 0.3804, 0.4667,  ..., 0.2784, 0.3804, 0.3686],
         [0.6471, 0.6235, 0.5843,  ..., 0.2627, 0.2706, 0.2196],
         [0.4588, 0.4196, 0.2902,  ..., 0.3020, 0.2784, 0.2784]],

        [[0.6745, 0.6745, 0.6784,  ..., 0.6196, 0.6196, 0.6235],
         [0.6784, 0.6824, 0.6784,  ..., 0.6275, 0.6275, 0.6275],
         [0.6784, 0.6824, 0.6863,  ..., 0.6275, 0.6235, 0.6275],
         ...,
         [0.4510, 0.4039, 0.4980,  ..., 0.3294, 0.4275, 0.4039],
         [0.6863, 0.6706, 0.6353,  ..., 0.3216, 0.3255, 0.2706],
         [0.5137, 0.4667, 0.3176,  ..., 0.3765, 0.3490, 0.3529]],

        [[0.8392, 0.8392, 0.8392,  ..., 0.8392, 0.8353, 0.8314],
         [0.8353, 0.8353, 0.8392,  ..., 0.8431, 0.8392, 0.8353],
         [0.8353, 0.8392, 0.8431,  ..., 0.8392, 0.8392, 0.