In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model
import imageio, os

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    run_device = torch.device('cuda')
else:
    run_device = torch.device('cpu')

In [None]:
# Model settings.
# NOTE(review): ConvAutoencoder hard-codes a 512 x 4 x 4 bottleneck feature map
# behind four stride-2 stages, which works out to 64 x 64 frames rather than
# 128 x 128 -- confirm which resolution is intended.
RESOLUTION_WIDTH = RESOLUTION_HEIGHT = 128  # frame size in pixels
CHANNELS = 3  # default `in_channels` of ConvAutoencoder
# Default latent size of ConvAutoencoder. generate_frames feeds latents of this
# width to the transformer as `inputs_embeds` and decodes its hidden states
# directly, so it presumably has to equal the transformer's hidden size.
BOTTLENECK_DIM = 768

In [None]:
class ResidualBlock(nn.Module):
    """Conv -> BN -> ReLU -> Conv -> BN branch added back onto its input.

    Every convolution uses kernel 3, stride 1, padding 1 and keeps the channel
    count, so the branch output has the same shape as the input and the two
    can be summed element-wise.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        ]
        # Attribute name and layer order are kept as-is: state_dict keys
        # ("block.0.weight", ...) are derived from them.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the convolutional branch of ``x`` (no final activation)."""
        branch_out = self.block(x)
        return x + branch_out

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder mapping images to a flat latent vector and back.

    The encoder is four stride-2 convolutions (kernel 4, padding 1, each
    halving height and width) interleaved with residual blocks, followed by a
    linear layer sized for a 512 x 4 x 4 feature map. The decoder mirrors it
    with four stride-2 transposed convolutions starting from 512 x 4 x 4 and
    ends in ``Tanh``, so decoded images are 64 x 64 with values in [-1, 1].

    Because the 4 x 4 bottleneck is hard-coded, ``encode`` only accepts inputs
    that reduce to 4 x 4 after the four halvings (e.g. 64 x 64).
    NOTE(review): the notebook's RESOLUTION_WIDTH/HEIGHT constants say 128,
    which would yield an 8 x 8 map and a shape error in ``encoder_fc`` --
    confirm the intended frame size.

    Layer order and attribute names are part of the checkpoint format
    (state_dict keys / pickled modules) and must not be changed.

    Args:
        in_channels: Number of image channels on input and output.
        latent_dim: Size of the latent vector produced by ``encode``.
    """

    def __init__(self, in_channels=CHANNELS, latent_dim=BOTTLENECK_DIM):
        super().__init__()
        self.latent_dim = latent_dim
        # Spatial sizes in the comments below assume a 64 x 64 input.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),   # 64 -> 32
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),           # 32 -> 16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.Conv2d(128, 256, 4, 2, 1),          # 16 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.Conv2d(256, 512, 4, 2, 1),          # 8 -> 4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512)
        )
        # Flattened 512 x 4 x 4 feature map <-> latent vector.
        self.encoder_fc = nn.Linear(512 * 4 * 4, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (512, 4, 4)),
            ResidualBlock(512),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 4 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8 -> 16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16 -> 32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, in_channels, 4, 2, 1),  # 32 -> 64
            nn.Tanh()
        )

    def encode(self, x):
        """Map an image batch to latent vectors of size ``latent_dim``.

        Args:
            x: Image batch ``(batch, in_channels, H, W)`` whose encoder feature
                map is 4 x 4 (e.g. H = W = 64).

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.encoder_fc(x)

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` back to images.

        Returns:
            Tensor of shape ``(batch, in_channels, 64, 64)`` in [-1, 1].
        """
        z = self.decoder_fc(z)
        return self.decoder(z)

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it.

        Defined so the module can be called directly (``model(x)``); without
        it ``nn.Module.__call__`` raises ``NotImplementedError``. It relies
        only on ``encode``/``decode``, so instances unpickled from older
        checkpoints gain it without any state change.
        """
        return self.decode(self.encode(x))


In [None]:
def _load_eval_module(path):
    """Unpickle a whole ``nn.Module`` checkpoint onto ``run_device`` in eval mode.

    The checkpoint files hold complete pickled modules rather than
    state_dicts (``.to()``/``.eval()`` are called on the loaded object), so
    ``weights_only`` has to be disabled explicitly: PyTorch >= 2.6 defaults to
    ``weights_only=True`` and refuses to unpickle arbitrary classes. The
    keyword itself requires PyTorch >= 1.13.

    SECURITY: full unpickling can execute arbitrary code. Only load
    checkpoints from a trusted source.
    """
    module = torch.load(path, map_location=run_device, weights_only=False)
    return module.to(run_device).eval()


# Unpickling needs the model classes to be defined/importable first (the
# ConvAutoencoder and ResidualBlock cells above, and `transformers`).
autoenc = _load_eval_module('checkpoints/run1/autoenc')
transformer = _load_eval_module('checkpoints/run1/transformer')


In [None]:
def generate_frames(num_frames=600, max_context=None):
    """Autoregressively sample latents from ``transformer`` and decode them to frames.

    Generation starts from a single all-zero latent. At every step the whole
    context is fed to the global ``transformer`` as ``inputs_embeds``; the last
    hidden state is taken as the next latent, decoded with the global
    ``autoenc`` and appended to the context.

    Args:
        num_frames: Number of frames to generate.
        max_context: Maximum number of latents fed to the transformer per
            step. When ``None`` it is read from
            ``transformer.config.n_positions`` if that attribute exists;
            if no limit can be determined the context grows without bound,
            exactly as before. Once the limit is reached the oldest latents
            (including the zero seed) are dropped, so long generations no
            longer overflow the model's positional table.

    Returns:
        List of ``num_frames`` CPU tensors, each the output of
        ``autoenc.decode`` for one latent, clamped to [-1, 1].

    Raises:
        ValueError: If ``max_context`` is given and smaller than 1.
    """
    if max_context is None:
        # GPT-2 style configs expose the positional-table size as n_positions.
        model_config = getattr(transformer, 'config', None)
        max_context = getattr(model_config, 'n_positions', None)
    if max_context is not None and max_context < 1:
        raise ValueError(f'max_context must be >= 1, got {max_context}')

    seq = torch.zeros(1, 1, BOTTLENECK_DIM, device=run_device)
    frames = []
    with torch.no_grad():
        for _ in range(num_frames):
            # Sliding window: keep only the most recent latents that fit.
            if max_context is not None and seq.size(1) > max_context:
                seq = seq[:, -max_context:, :]
            out = transformer(inputs_embeds=seq).last_hidden_state
            next_latent = out[:, -1:, :]  # keep the sequence axis for the concat below
            frame = autoenc.decode(next_latent.squeeze(1)).clamp(-1, 1)
            frames.append(frame.cpu())
            seq = torch.cat([seq, next_latent], dim=1)
    return frames

def save_video(frames, output='output.mp4', fps=60):
    """Encode a sequence of image tensors into a video file with imageio.

    Args:
        frames: Iterable of tensors with values in [-1, 1]. Each is squeezed
            on dim 0 and permuted ``(1, 2, 0)``, i.e. expected to be
            ``(1, C, H, W)`` or ``(C, H, W)``.
        output: Destination path passed to ``imageio.get_writer``.
        fps: Frame rate passed to ``imageio.get_writer``.

    The writer is used as a context manager so it is closed (and the
    underlying file/encoder released) even if a frame fails to convert.
    """
    with imageio.get_writer(output, fps=fps) as writer:
        for frame in frames:
            # CHW -> HWC; detach/cpu are no-ops for plain CPU tensors but make
            # the conversion safe for GPU or grad-tracking tensors.
            img = frame.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
            # Rescale [-1, 1] -> [0, 255]; astype truncates toward zero.
            img = ((img + 1)/2 * 255).astype('uint8')
            writer.append_data(img)
    print(f'Saved {output}')


In [None]:
# Generate a clip with the default settings and write it to the default
# output path; `frames` stays bound so it can be inspected in later cells.
frames = generate_frames()
save_video(frames)
