In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import AutoModel
from torchvision.utils import save_image
import imageio, os

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
run_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Video settings.
# Frame size in pixels; used as ImageCompressor's default (height, width) resolution.
RESOLUTION_HEIGHT, RESOLUTION_WIDTH = 128, 128
# Number of image channels the compressor encodes/decodes by default.
CHANNELS = 3
# Default latent size, shared by ImageCompressor and generate_latent_sequence.
BOTTLENECK_DIM = 768

In [3]:
class ImageCompressor:
    """Parameter-free image <-> latent codec built on bilinear resampling.

    ``encode`` bilinearly resizes each of the first ``channels`` channels to an
    ``h_down x w_down`` grid, flattens it, and zero-pads or truncates it to
    ``latent_dim // channels`` values; the per-channel slices are concatenated
    into one flat latent. ``decode`` reverses the steps and bilinearly
    upsamples each grid back to ``resolution``.

    Note: the latent has ``channels * (latent_dim // channels)`` entries, which
    is smaller than ``latent_dim`` when ``latent_dim`` is not divisible by
    ``channels``.
    """

    def __init__(self, resolution=(RESOLUTION_HEIGHT, RESOLUTION_WIDTH), channels=CHANNELS, latent_dim=BOTTLENECK_DIM):
        """Derive the per-channel latent budget and the downsampled grid size.

        Args:
            resolution: ``(height, width)`` that ``decode`` upsamples to.
            channels: number of image channels to encode/decode.
            latent_dim: total latent budget; each channel receives
                ``latent_dim // channels`` entries.
        """
        self.H, self.W = resolution
        self.C = channels
        self.latent_dim = latent_dim
        self.per_channel_dim = latent_dim // self.C
        # Shrink both sides by the same factor so the grid area is roughly per_channel_dim.
        scale_factor = math.sqrt(self.per_channel_dim / (self.H * self.W))
        # Clamp to at least one pixel per side so the grid is never empty.
        self.h_down = max(1, int(self.H * scale_factor))
        self.w_down = max(1, int(self.W * scale_factor))

    @staticmethod
    def _fit_width(flat, width):
        """Zero-pad (on the right) or truncate a ``(B, n)`` tensor to ``width`` columns."""
        missing = width - flat.shape[1]
        if missing > 0:
            return F.pad(flat, (0, missing))
        return flat[:, :width]

    def encode(self, image):
        """Compress an image (or batch of images) into a flat latent.

        Args:
            image: tensor of shape ``(B, C, H, W)``; any other rank is treated
                as a single unbatched ``(C, H, W)`` image. Only the first
                ``self.C`` channels are used. The spatial size is free because
                every channel is resized to ``(h_down, w_down)``.

        Returns:
            Tensor of shape ``(B, C * per_channel_dim)``, or
            ``(C * per_channel_dim,)`` for unbatched input.

        Raises:
            ValueError: if the input is not 3-D/4-D or has fewer than
                ``self.C`` channels.
        """
        is_batched = image.dim() == 4
        if not is_batched:
            image = image.unsqueeze(0)
        if image.dim() != 4 or image.shape[1] < self.C:
            raise ValueError(
                f'expected an image of shape (B, C, H, W) or (C, H, W) with at least '
                f'{self.C} channels, got {tuple(image.shape)}'
            )
        B = image.shape[0]
        latent_parts = []
        for c in range(self.C):
            ch = image[:, c:c+1, :, :]
            down = F.interpolate(ch, size=(self.h_down, self.w_down), mode='bilinear', align_corners=False)
            # Each channel must occupy exactly per_channel_dim entries of the latent.
            latent_parts.append(self._fit_width(down.reshape(B, -1), self.per_channel_dim))
        latent = torch.cat(latent_parts, dim=1)
        if not is_batched:
            latent = latent.squeeze(0)
        return latent

    def decode(self, latent):
        """Reconstruct an image (or batch of images) from a latent.

        Args:
            latent: tensor of shape ``(B, D)``; any other rank is treated as a
                single unbatched ``(D,)`` latent. ``D`` must be at least
                ``C * per_channel_dim``; extra trailing entries are ignored.

        Returns:
            Tensor of shape ``(B, C, H, W)``, or ``(C, H, W)`` for unbatched
            input. Values are not clamped.

        Raises:
            ValueError: if the latent is not 1-D/2-D or is too short.
        """
        is_batched = latent.dim() == 2
        if not is_batched:
            latent = latent.unsqueeze(0)
        required = self.C * self.per_channel_dim
        if latent.dim() != 2 or latent.shape[1] < required:
            raise ValueError(
                f'expected a latent of shape (B, D) or (D,) with D >= {required}, '
                f'got {tuple(latent.shape)}'
            )
        B = latent.shape[0]
        grid_size = self.h_down * self.w_down
        channels = []
        for i in range(self.C):
            start = i * self.per_channel_dim
            flat = latent[:, start:start + self.per_channel_dim]
            # Drop encode()'s zero padding, or restore with zeros what encode()
            # truncated when the grid is larger than the per-channel budget
            # (previously that case crashed in reshape).
            flat = self._fit_width(flat, grid_size)
            ch = flat.reshape(B, 1, self.h_down, self.w_down)
            up = F.interpolate(ch, size=(self.H, self.W), mode='bilinear', align_corners=False)
            channels.append(up)
        recon = torch.cat(channels, dim=1)
        if not is_batched:
            recon = recon.squeeze(0)
        return recon

In [4]:
# Load the pretrained sequence model from 'decap_gpt2_cm2' and move it to the run device.
decap_gpt2 = AutoModel.from_pretrained('decap_gpt2_cm2')
decap_gpt2 = decap_gpt2.to(run_device)
# Parameter-free codec (default resolution / channels / latent size) used to decode latents.
compressor = ImageCompressor()

In [5]:
def generate_latent_sequence(model, sequence_length=600, latent_dim=BOTTLENECK_DIM, device=run_device, context_length=60):
    """Autoregressively roll out a sequence of latents from an all-zero seed.

    At every step the most recent ``context_length`` latents are passed to
    ``model`` as ``inputs_embeds`` and the last position of its
    ``last_hidden_state`` becomes the next latent.

    Args:
        model: callable accepting ``inputs_embeds`` and returning an object
            with a ``last_hidden_state`` tensor; it is switched to eval mode.
        sequence_length: number of latents to generate after the seed.
        latent_dim: size of the all-zero seed latent.
        device: device on which the seed latent is created.
        context_length: how many of the most recent latents the model sees
            at each step (must be >= 1).

    Returns:
        Tensor with ``sequence_length + 1`` rows: the zero seed followed by
        the generated latents.

    Raises:
        ValueError: if ``context_length`` is smaller than 1.
    """
    if context_length < 1:
        raise ValueError(f'context_length must be >= 1, got {context_length}')
    model.eval()
    with torch.no_grad():
        seed = torch.zeros(1, 1, latent_dim, device=device)
        generated = [seed]
        context = seed
        for _ in range(sequence_length):
            out = model(inputs_embeds=context).last_hidden_state
            next_latent = out[:, -1:, :]
            generated.append(next_latent)
            # Only the trailing window is ever fed to the model, so keep just
            # that instead of re-concatenating the full history every step.
            context = torch.cat([context, next_latent], dim=1)[:, -context_length:]
        return torch.cat(generated, dim=1).squeeze(0)

In [6]:
def decode_latents_to_images(latents, output_folder='generated_video'):
    """Decode each latent with the module-level ``compressor`` and save it as a PNG.

    Frames are written to ``output_folder`` (created if missing) as
    ``frame_0000.png``, ``frame_0001.png``, ... in iteration order. Files
    already present in the folder are not removed.

    Args:
        latents: iterable of latents; each item is unsqueezed to a batch of
            one before being decoded.
        output_folder: directory that receives the PNG frames.
    """
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        for i, latent in enumerate(latents):
            decoded = compressor.decode(latent[None])
            # Clamp to the [0, 1] range before writing the image file.
            frame = torch.clamp(decoded, 0, 1)
            target = os.path.join(output_folder, f'frame_{i:04d}.png')
            save_image(frame, target)

In [7]:
def make_video_imageio(frame_folder='generated_video', output_file='output2.mp4', fps=60):
    """Assemble the ``.png`` frames found in ``frame_folder`` into a video file.

    Frames are appended in natural (numeric-aware) filename order, so
    ``frame_10000.png`` follows ``frame_9999.png`` even though the names are
    only zero-padded to four digits.

    Args:
        frame_folder: directory containing the ``.png`` frames.
        output_file: path of the video to write.
        fps: frames per second passed to the imageio writer.

    Raises:
        FileNotFoundError: if ``frame_folder`` does not exist or contains no
            ``.png`` files.
    """
    import re  # local import: only needed for the natural-sort key

    def natural_key(name):
        # re.split with a capturing group alternates text / digit runs, so
        # odd positions are always digit runs and can be compared as ints.
        parts = [int(tok) if idx % 2 else tok for idx, tok in enumerate(re.split(r'(\d+)', name))]
        return parts, name

    names = sorted((f for f in os.listdir(frame_folder) if f.endswith('.png')), key=natural_key)
    if not names:
        raise FileNotFoundError(f'No .png frames found in {frame_folder!r}')

    # The top-level imageio.imread is deprecated in newer imageio releases;
    # use the v2 namespace when the installed version provides it.
    read_frame = getattr(imageio, 'v2', imageio).imread

    # Context manager guarantees the writer is closed even if a frame fails to load.
    with imageio.get_writer(output_file, fps=fps) as writer:
        for name in names:
            writer.append_data(read_frame(os.path.join(frame_folder, name)))
    print(f'✅ Video saved using imageio: {output_file}')

In [8]:
# === RUN ===
# Pipeline: roll out latents -> decode them to PNG frames on disk -> assemble the frames into a video.
latents = generate_latent_sequence(model=decap_gpt2, sequence_length=6000)
# The frame folder written here is the same one the video step reads from.
decode_latents_to_images(latents, output_folder='generated_video')
make_video_imageio(frame_folder='generated_video', output_file='output2.mp4', fps=60)

  image = imageio.imread(fp)


✅ Video saved using imageio: output2.mp4
