In [None]:
# Mount Google Drive: the MSR-VTT frames, metadata and checkpoints used below live under MyDrive.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install torch torchvision transformers



In [None]:
import json
import os
import re
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from transformers import BertModel, BertTokenizer

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
class MSRVTTDataset(Dataset):
    """MSR-VTT clips stored as folders of pre-extracted frames, paired with their captions.

    Expects ``<root_dir>/frames/<video_id>/`` to contain the frame images of each
    video, and a metadata JSON with a ``"videos"`` list (entries carry
    ``"video_id"`` and ``"split"``) and a ``"sentences"`` list (entries carry
    ``"video_id"`` and ``"caption"``).

    Each item is a tuple ``(frames, captions, image_input)``:
        frames: tensor stacked from exactly ``num_frames`` transformed frames,
            shape (num_frames, C, H, W).
        captions: a fresh list of the video's caption strings (empty if the
            metadata has no sentences for it).
        image_input: the first frame, ``frames[0]``.
    """

    def __init__(self, root_dir, metadata_path, split="train", transform=None, num_frames=20):
        """
        Args:
            root_dir (str): Directory containing the "frames" folder.
            metadata_path (str): Path to the JSON file with video and caption metadata.
            split (str): Compared verbatim with each video's "split" field in the metadata.
            transform (callable, optional): Applied to every RGB PIL frame. When None,
                frames are converted with ``transforms.ToTensor()`` so they can be stacked.
            num_frames (int): Number of frames returned per video.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.num_frames = num_frames

        with open(metadata_path, "r") as f:
            self.metadata = json.load(f)

        # Keep only the videos belonging to the requested split.
        self.videos = [video for video in self.metadata["videos"] if video["split"] == split]

        # Group every caption under its video id.
        self.video_to_captions = defaultdict(list)
        for sentence in self.metadata["sentences"]:
            self.video_to_captions[sentence["video_id"]].append(sentence["caption"])

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video_id = self.videos[idx]["video_id"]
        frame_folder = os.path.join(self.root_dir, "frames", video_id)

        # Natural sort so "frame_10" follows "frame_9" instead of "frame_1".
        frame_names = sorted(os.listdir(frame_folder), key=self._natural_sort_key)
        frame_files = [os.path.join(frame_folder, name) for name in frame_names]

        if len(frame_files) > self.num_frames:
            frame_files = self._sample_frames(frame_files, self.num_frames)
        frames = [self._load_frame(path) for path in frame_files]

        if len(frames) < self.num_frames:
            frames = self._pad_frames(frames, self.num_frames)

        # PIL images cannot be stacked; without a user transform fall back to ToTensor.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        frames = torch.stack([to_tensor(frame) for frame in frames])  # (num_frames, C, H, W)

        # .get avoids inserting empty entries into the defaultdict; the copy keeps
        # callers from mutating the dataset's caption index.
        captions = list(self.video_to_captions.get(video_id, []))

        # Use the first frame as the conditioning image input.
        image_input = frames[0]

        return frames, captions, image_input

    @staticmethod
    def _natural_sort_key(name):
        """Sort key that compares embedded integers numerically ("f2" < "f10")."""
        # re.split with a capture group alternates text (even idx) and digit runs (odd idx).
        return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", name))]

    @staticmethod
    def _load_frame(path):
        """Read one frame as 3-channel RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it outlives the file handle.
            return img.convert("RGB")

    def _sample_frames(self, frame_files, num_frames):
        """
        Pick ``num_frames`` evenly spaced frames spanning the whole clip.

        Returns a copy of ``frame_files`` unchanged when it has no more than
        ``num_frames`` entries.
        """
        total = len(frame_files)
        if total <= num_frames:
            return list(frame_files)
        # i * total // num_frames is strictly increasing when total > num_frames.
        return [frame_files[i * total // num_frames] for i in range(num_frames)]

    def _pad_frames(self, frames, num_frames):
        """
        Pad frames with the last frame to ensure a fixed number of frames.

        Raises:
            ValueError: if ``frames`` is empty (nothing to repeat).
        """
        if len(frames) == 0:
            raise ValueError("No frames found in the video folder.")
        last_frame = frames[-1]
        while len(frames) < num_frames:
            frames.append(last_frame)
        return frames


class VideoGenerationModel(nn.Module):
    """Generate a short RGB clip conditioned on a text embedding and a single image.

    The text embedding goes through a linear projection. The image goes through a
    ResNet-50 whose classifier head is replaced by a projection. Both land in
    ``hidden_size`` features and are concatenated into a
    ``(batch, 2 * hidden_size, 1, 1)`` seed.

    A DCGAN-style stack of transposed convolutions upsamples the seed
    1 -> 4 -> 8 -> ... -> 256 pixels with ``3 * num_frames`` output channels. The
    result is resized to ``frame_size`` and reshaped to
    ``(batch, num_frames, 3, frame_size, frame_size)``. The final Tanh keeps every
    output value in ``[-1, 1]``.

    Args:
        text_embedding_size: Width of the incoming text embedding (768 for BERT-base [CLS]).
        image_embedding_size: Stored for backward compatibility but unused; the
            ResNet-50 head width is read from the backbone itself.
        hidden_size: Width of each projected modality.
        num_frames: Number of frames produced per clip.
        frame_size: Output height and width in pixels.
        verbose: If True, print intermediate tensor shapes on every forward pass.
    """

    def __init__(self, text_embedding_size=768, image_embedding_size=2048, hidden_size=512, num_frames=20, frame_size=224, verbose=False):
        super().__init__()
        self.text_embedding_size = text_embedding_size
        self.image_embedding_size = image_embedding_size
        self.hidden_size = hidden_size
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.verbose = verbose

        # Text encoder (BERT embeddings are already encoded, so this is just a linear layer)
        self.text_fc = nn.Linear(text_embedding_size, hidden_size)

        # Image encoder: ImageNet-pretrained ResNet-50. IMAGENET1K_V1 is exactly what the
        # deprecated `pretrained=True` used to load.
        self.image_encoder = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.image_encoder.fc = nn.Linear(self.image_encoder.fc.in_features, hidden_size)

        # Frame generator. Layer order/indices are unchanged so saved state_dicts still load.
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(hidden_size * 2, 1024, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # -> 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # -> 128x128
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # -> 256x256; forward() resizes this to frame_size.
            nn.ConvTranspose2d(32, 3 * num_frames, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.Tanh()
        )

    def _log(self, *args):
        """Print debugging output only when the model was built with verbose=True."""
        if self.verbose:
            print(*args)

    def forward(self, text_embedding, image_embedding):
        """Generate a clip.

        Args:
            text_embedding: Tensor of shape (batch, text_embedding_size).
            image_embedding: Despite the name, a raw image batch (batch, 3, H, W) that is
                fed to ResNet-50. In this notebook's training it is the dataset's
                ImageNet-normalized first frame.

        Returns:
            Tensor of shape (batch, num_frames, 3, frame_size, frame_size) with values in [-1, 1].
        """
        # Project both modalities and lift them to 1x1 spatial maps.
        text_hidden = self.text_fc(text_embedding)[:, :, None, None]  # (batch, hidden_size, 1, 1)
        image_hidden = self.image_encoder(image_embedding)[:, :, None, None]  # (batch, hidden_size, 1, 1)
        self._log("Text hidden shape:", text_hidden.shape)
        self._log("Image hidden shape:", image_hidden.shape)

        combined_hidden = torch.cat([text_hidden, image_hidden], dim=1)  # (batch, hidden_size * 2, 1, 1)

        frames = self.generator(combined_hidden)  # (batch, 3 * num_frames, 256, 256)
        self._log("Generated frames shape before interpolation:", frames.shape)

        frames = torch.nn.functional.interpolate(
            frames, size=(self.frame_size, self.frame_size), mode="bilinear", align_corners=False
        )
        self._log("Generated frames shape after interpolation:", frames.shape)

        # Split the stacked channel axis into (num_frames, 3).
        return frames.reshape(frames.size(0), self.num_frames, 3, self.frame_size, self.frame_size)

Using device: cuda


In [None]:
# Train the video generator against MSR-VTT clips, using a frozen BERT as text encoder.


def collate_videos(batch):
    """Collate (frames, captions, image_input) samples while keeping captions grouped per video.

    PyTorch's default collate transposes a batch of caption lists into a list of
    per-caption-index tuples (caption-major), which silently breaks the per-video
    grouping the loop below relies on. Here captions stay as one list per video.
    """
    frames, captions, image_inputs = zip(*batch)
    return torch.stack(frames), [list(video_captions) for video_captions in captions], torch.stack(image_inputs)


# Load dataset
dataset = MSRVTTDataset(
    root_dir="/content/drive/MyDrive/msr-vtt",  # Path to the folder containing the "frames" folder
    metadata_path="/content/drive/MyDrive/msr-vtt/train_val_videodatainfo.json",  # Path to metadata file
    split="train",
    transform=transform,
    num_frames=20
)

dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_videos)

# Initialize models: BERT stays frozen (eval mode + no_grad), only the generator trains.
text_encoder = BertModel.from_pretrained("bert-base-uncased").to(device).eval()
video_generator = VideoGenerationModel().to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(video_generator.parameters(), lr=1e-4)

# The dataset frames are ImageNet-normalized (roughly [-2.1, 2.6]), but the generator
# ends in Tanh ([-1, 1]). Undo the normalization and rescale targets to [-1, 1] so
# they are actually reachable by the model.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 1, 3, 1, 1)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    video_generator.train()
    running_loss, num_batches = 0.0, 0
    for frames, captions, image_input in dataloader:
        frames = frames.to(device)
        image_input = image_input.to(device)  # stays ImageNet-normalized for ResNet-50

        # Flatten captions video-major and remember how many each video contributed.
        caption_counts = [len(video_captions) for video_captions in captions]
        if 0 in caption_counts:
            raise ValueError("Every video in the batch needs at least one caption.")
        flat_captions = [caption for video_captions in captions for caption in video_captions]

        inputs = tokenizer(flat_captions, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            cls_embeddings = text_encoder(**inputs).last_hidden_state[:, 0, :]  # [CLS] token

        # Average each video's own captions -> (batch_size, 768); works for uneven counts too.
        text_embeddings = torch.stack(
            [video_chunk.mean(dim=0) for video_chunk in cls_embeddings.split(caption_counts)]
        )

        generated_frames = video_generator(text_embeddings, image_input)

        target_frames = (frames * imagenet_std + imagenet_mean) * 2 - 1
        loss = criterion(generated_frames, target_frames)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the epoch mean rather than only the last batch's loss.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}")

Batch size: 3
Num captions per video: 20
Final text embedding shape: torch.Size([3, 768])
Text embeddings shape after reshaping: torch.Size([3, 768])
Image input shape: torch.Size([3, 3, 224, 224])
Text hidden shape: torch.Size([3, 512, 1, 1])
Image hidden shape: torch.Size([3, 512, 1, 1])
Generated frames shape before interpolation: torch.Size([3, 60, 256, 256])
Generated frames shape after interpolation: torch.Size([3, 60, 224, 224])
Epoch [1/10], Loss: 1.9987
Batch size: 3
Num captions per video: 20
Final text embedding shape: torch.Size([3, 768])
Text embeddings shape after reshaping: torch.Size([3, 768])
Image input shape: torch.Size([3, 3, 224, 224])
Text hidden shape: torch.Size([3, 512, 1, 1])
Image hidden shape: torch.Size([3, 512, 1, 1])
Generated frames shape before interpolation: torch.Size([3, 60, 256, 256])
Generated frames shape after interpolation: torch.Size([3, 60, 224, 224])
Epoch [2/10], Loss: 1.9875
Batch size: 3
Num captions per video: 20
Final text embedding shap

In [None]:
# Persist only the generator's learned weights (its state_dict) to Drive.
checkpoint_path = "/content/drive/MyDrive/video_generator.pth"
generator_state = video_generator.state_dict()
torch.save(generator_state, checkpoint_path)

In [None]:
# Inference: generate a clip from one conditioning image plus a caption and write it as MP4.
# Relies on `tokenizer`, `BertModel` and `VideoGenerationModel` from the setup cell.
import torch
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model. map_location lets a GPU-trained checkpoint load on a CPU-only
# runtime; weights_only avoids unpickling arbitrary objects from the file.
checkpoint_path = "/content/drive/MyDrive/video_generator.pth"
video_generator = VideoGenerationModel().to(device)
video_generator.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
video_generator.eval()

# Same per-frame transform as training (ImageNet normalization for the ResNet-50 encoder).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the text encoder (BERT)
text_encoder = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

# Load an input image (closing the file once it is decoded)
image_path = "/content/drive/MyDrive/msr-vtt/inputCartoon.jpeg"  # Replace with the path to your input image
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")
image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

# Define a text caption
text = "a cartoon character runs through an ice cave"  # Replace with your desired caption

inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
inputs = {key: value.to(device) for key, value in inputs.items()}

with torch.no_grad():
    text_embedding = text_encoder(**inputs).last_hidden_state[:, 0, :]  # [CLS] token
    generated_frames = video_generator(text_embedding, image)

# Index the batch dimension explicitly; squeeze() would also drop a size-1 frame axis.
frames = generated_frames[0].cpu().numpy()  # (num_frames, 3, H, W), Tanh range [-1, 1]
frames = np.transpose(frames, (0, 2, 3, 1))  # (num_frames, H, W, 3)
# Map [-1, 1] -> [0, 255]. A plain `* 255` would send negatives through a wrapping uint8 cast.
frames = np.clip((frames + 1.0) * 127.5, 0, 255).astype(np.uint8)

# Save frames as a video
output_video_path = "/content/drive/MyDrive/generated_video.mp4"  # Replace with your desired output path
fps = 10  # Frames per second
height, width = frames.shape[1:3]
frame_size = (width, height)  # OpenCV expects (width, height)

out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
if not out.isOpened():
    raise RuntimeError(f"Could not open video writer for {output_video_path}")
try:
    for frame in frames:
        # OpenCV writes contiguous BGR images; the generated frames are RGB views.
        out.write(cv2.cvtColor(np.ascontiguousarray(frame), cv2.COLOR_RGB2BGR))
finally:
    out.release()

print(f"Video saved to {output_video_path}")

Text hidden shape: torch.Size([1, 512, 1, 1])
Image hidden shape: torch.Size([1, 512, 1, 1])
Generated frames shape before interpolation: torch.Size([1, 60, 256, 256])
Generated frames shape after interpolation: torch.Size([1, 60, 224, 224])
Video saved to /content/drive/MyDrive/generated_video.mp4
