In [1]:
import imageio, os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import GPT2Model

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Model settings: used as the ConvAutoencoder defaults and by generate_frames.
# Frames are RESOLUTION_WIDTH x RESOLUTION_HEIGHT images with CHANNELS channels;
# the autoencoder compresses each frame into a BOTTLENECK_DIM-dimensional latent.
RESOLUTION_WIDTH, RESOLUTION_HEIGHT = 128, 128
CHANNELS = 3  # RGB
BOTTLENECK_DIM = 768  # presumably has to match the GPT-2 hidden size — TODO confirm

In [3]:
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN branch added back onto its input (identity skip).

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the branch output has the same shape as the input and the two can
    be summed directly.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Layer order (and therefore the state_dict indices block.0 .. block.4)
        # must stay fixed so saved checkpoints keep loading.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder mapping images to a flat latent vector and back.

    The encoder applies four stride-2 convolutions (kernel 4, padding 1, each
    halving height and width) interleaved with residual blocks. It then flattens
    the result and projects it to ``latent_dim`` with a linear layer. The decoder
    mirrors this with a linear layer, an ``Unflatten`` and four stride-2
    transposed convolutions (each doubling height and width). It ends in
    ``Tanh``, so reconstructions lie in [-1, 1].

    Args:
        in_channels: Number of image channels.
        latent_dim: Size of the latent vector produced by ``encode``.
        input_resolution: ``(width, height)`` of the input images. It is used once,
            to infer the encoder's flattened output size. Both should be divisible
            by 16 for the decoder output to match the input size.
    """

    def __init__(self, in_channels=CHANNELS, latent_dim=BOTTLENECK_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim
        # Four stride-2 stages: a 128x128 input becomes a 512 x 8 x 8 feature map.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512)
        )

        # Infer the encoder output shape by running a dummy batch through it.
        # Tensors are laid out (N, C, H, W) while input_resolution is
        # (width, height), so the height goes first.
        # The encoder is switched to eval mode for this pass. In training mode,
        # BatchNorm updates its running statistics even under no_grad, which would
        # fold this all-zeros input into them.
        width, height = input_resolution
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, height, width)
            enc_out = self.encoder(dummy)
        self.encoder.train(was_training)
        self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            # Restore the (C, H', W') feature-map shape produced by the encoder.
            nn.Unflatten(1, enc_out.shape[1:]),
            ResidualBlock(512),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, in_channels, 4, 2, 1),
            nn.Tanh()
        )

    def encode(self, x):
        """Encode an image batch into latent vectors of shape (N, latent_dim).

        Assumes x is (N, in_channels, H, W) at the resolution given to the
        constructor — otherwise encoder_fc receives the wrong flattened size.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.encoder_fc(x)

    def decode(self, z):
        """Decode latent vectors of shape (N, latent_dim) into images in [-1, 1]."""
        z = self.decoder_fc(z)
        return self.decoder(z)

In [4]:
# Restore the trained autoencoder and the GPT-2 backbone from checkpoints/run2.
# Both models are moved to run_device and switched to eval mode for inference.
autoenc = ConvAutoencoder()
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load(
    "checkpoints/run2/autoenc.pth",
    map_location=run_device,  # load tensors straight onto the target device
)
autoenc.load_state_dict(state_dict)
autoenc = autoenc.to(run_device).eval()

transformer = GPT2Model.from_pretrained("checkpoints/run2/gpt2_decap").to(run_device).eval()

In [None]:
def _frame_to_tensor(frame):
    """Convert one image into a normalized model-input tensor.

    Args:
        frame: A PIL image, or an array-like of shape (H, W) or (H, W, C) with
            uint8 pixels (e.g. a frame yielded by an imageio reader).

    Returns:
        A float CPU tensor of shape (1, 3, RESOLUTION_HEIGHT, RESOLUTION_WIDTH)
        with values scaled from [0, 255] to [-1, 1].
    """
    if not isinstance(frame, Image.Image):
        frame = Image.fromarray(np.asarray(frame))
    # PIL's resize takes (width, height); the resulting array is (H, W, 3).
    frame = frame.convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))
    pixels = torch.from_numpy(np.array(frame)).permute(2, 0, 1).float()
    return (pixels / 127.5 - 1).unsqueeze(0)


def generate_frames(num_frames=6000, seed_count=4, seed_image=None, seed_folder='video_dataset',
                    context_len=16):
    """Autoregressively generate video frames in the autoencoder's latent space.

    Seed latents come from, in order of preference:
      1. ``seed_image``, encoded once and repeated ``seed_count`` times;
      2. the first ``seed_count`` frames of the alphabetically first ``.mp4`` file
         in ``seed_folder`` (if the folder exists and contains one);
      3. ``seed_count`` standard-normal random latents.
    Each further frame is produced by feeding the most recent latents to the
    GPT-2 backbone as ``inputs_embeds``. Its last hidden state becomes the next
    latent, which is then decoded with the autoencoder.

    Args:
        num_frames: Total number of frames to return, seeds included. If the
            video supplies fewer than ``seed_count`` frames, generation makes up
            the difference.
        seed_count: Number of seed frames; must be >= 1.
        seed_image: Optional path to an image used as the seed.
        seed_folder: Folder searched for a seed ``.mp4`` when ``seed_image`` is None.
        context_len: Maximum number of latents kept as transformer context
            (default 16, the previously hard-coded window); must be >= 1.

    Returns:
        List of CPU tensors, one per frame, each of shape (1, C, H, W) in [-1, 1].
        If ``num_frames`` is smaller than the number of seeds, all seed frames are
        still returned.

    Raises:
        ValueError: If ``seed_count`` or ``context_len`` is less than 1.
    """
    if seed_count < 1:
        raise ValueError(f"seed_count must be >= 1, got {seed_count}")
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    frames = []
    latents = []
    # Pure inference. Without no_grad, the seeding encode/decode calls would record
    # autograd graphs and the decoded random-seed frames would have requires_grad=True.
    with torch.no_grad():
        if seed_image is not None:
            with Image.open(seed_image) as img:
                seed = _frame_to_tensor(img)
            # The same image fills every seed slot, so encode it only once.
            seed_latent = autoenc.encode(seed.to(run_device))
            frames.extend([seed] * seed_count)
            latents.extend([seed_latent] * seed_count)
        elif os.path.isdir(seed_folder):
            # sorted() makes the chosen video independent of filesystem listing order.
            videos = sorted(f for f in os.listdir(seed_folder) if f.endswith('.mp4'))
            if videos:
                reader = imageio.get_reader(os.path.join(seed_folder, videos[0]))
                try:
                    for i, img in enumerate(reader):
                        if i >= seed_count:
                            break
                        # Resize so videos of any size match the encoder's input size.
                        t = _frame_to_tensor(img)
                        frames.append(t)
                        latents.append(autoenc.encode(t.to(run_device)))
                finally:
                    reader.close()
        if not latents:
            for _ in range(seed_count):
                z = torch.randn(1, BOTTLENECK_DIM, device=run_device)
                latents.append(z)
                frames.append(autoenc.decode(z).clamp(-1, 1).cpu())

        # Each latent is (1, latent_dim); stacking on dim=1 gives (1, num_seeds, latent_dim).
        seq = torch.stack(latents, dim=1)
        # Count against the seeds actually obtained (a short video may give fewer).
        for _ in range(num_frames - len(frames)):
            hidden = transformer(inputs_embeds=seq).last_hidden_state
            # Presumably the GPT-2 hidden size equals BOTTLENECK_DIM, so the last
            # position's hidden state is used directly as the next latent — TODO confirm.
            next_latent = hidden[:, -1:, :]
            frame = autoenc.decode(next_latent.squeeze(1)).clamp(-1, 1)
            frames.append(frame.cpu())
            # Slide the context window: keep only the most recent context_len latents.
            seq = torch.cat([seq, next_latent], dim=1)[:, -context_len:, :]
    return frames


In [8]:
# Generate a clip seeded with test.png (repeated as the first 4 frames) and save it.
# NOTE(review): save_video is not defined in any visible cell. It presumably comes
# from a deleted or unsaved cell (execution counts jump from In[4] to In[8]).
# Restore its definition so Restart & Run All works.
frames = generate_frames(seed_count=4, seed_image='test.png')
save_video(frames)

Saved output.mp4
