In [1]:
from google.colab import drive
drive.mount('/content/mydrive')

Mounted at /content/mydrive


In [2]:
import os

# Set paths
video_dir = "/content/mydrive/MyDrive/msr-vtt/TrainValVideo"
caption_file = "/content/mydrive/MyDrive/msr-vtt/train_val_videodatainfo.json"
output_dir = "/content/mydrive/MyDrive/msr-vtt/frames"

# Create output directory
# os.makedirs(output_dir, exist_ok=True)

# Verify paths
print("Video files:", os.listdir(video_dir))
if os.path.exists(caption_file):
    print("Caption file found!")
else:
    print("Caption file not found.")

Video files: ['video6409.mp4', 'video6421.mp4', 'video6437.mp4', 'video6418.mp4', 'video6423.mp4', 'video6435.mp4', 'video6441.mp4', 'video6413.mp4', 'video6425.mp4', 'video6408.mp4', 'video6426.mp4', 'video6440.mp4', 'video6427.mp4', 'video6434.mp4', 'video6428.mp4', 'video6432.mp4', 'video6430.mp4', 'video6410.mp4', 'video6422.mp4', 'video6439.mp4', 'video641.mp4', 'video6436.mp4', 'video6415.mp4', 'video6420.mp4', 'video642.mp4', 'video6424.mp4', 'video6431.mp4', 'video6412.mp4', 'video6411.mp4', 'video6438.mp4', 'video6429.mp4', 'video6433.mp4', 'video644.mp4', 'video6417.mp4', 'video6407.mp4', 'video643.mp4', 'video6406.mp4', 'video6414.mp4', 'video6416.mp4', 'video6474.mp4', 'video6453.mp4', 'video6467.mp4', 'video6454.mp4', 'video6445.mp4', 'video6479.mp4', 'video6451.mp4', 'video6457.mp4', 'video6462.mp4', 'video6471.mp4', 'video6473.mp4', 'video6461.mp4', 'video645.mp4', 'video6450.mp4', 'video6483.mp4', 'video6460.mp4', 'video6469.mp4', 'video646.mp4', 'video6484.mp4', 'video

In [None]:
# Step 1: Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')

# Step 2: Set paths
video_dir = "/content/mydrive/MyDrive/msr-vtt/TrainValVideo"  # Directory containing video clips
caption_file = "/content/mydrive/MyDrive/msr-vtt/train_val_videodatainfo.json"  # Path to the caption file
output_dir = "/content/mydrive/MyDrive/msr-vtt/frames"  # Directory to save extracted frames
progress_file = "/content/mydrive/MyDrive/msr-vtt/processed_videos.txt"  # File to track processed videos

# Step 3: Import libraries
import os
import cv2
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Step 4: Create output directory if it doesn't exist
# os.makedirs(output_dir, exist_ok=True)

# Step 5: Load captions
with open(caption_file, "r") as f:
    captions_data = json.load(f)

# Extract video metadata and sentences
videos = captions_data["videos"]
sentences = captions_data["sentences"]

# Step 6: Filter videos to start from video3758
# videos = [video for video in videos if int(video["video_id"][5:]) >= 6976]

# Step 7: Create a mapping from video_id to captions
video_to_captions = {}
for sentence in sentences:
    video_id = sentence["video_id"]
    caption = sentence["caption"]
    if video_id not in video_to_captions:
        video_to_captions[video_id] = []
    video_to_captions[video_id].append(caption)

# Step 8: Load processed videos (if the file exists)
if os.path.exists(progress_file):
    with open(progress_file, "r") as f:
        processed_videos = set(f.read().splitlines())
else:
    processed_videos = set()

# Step 9: Extract frames from videos
def extract_frames(video_path, output_dir, frame_rate=1):
    """
    Extracts frames from a video and saves them to the output directory.

    Args:
        video_path (str): Path to the video file.
        output_dir (str): Directory to save the extracted frames.
        frame_rate (int): Number of frames to extract per second.

    Returns:
        List of frame file paths.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    frame_interval = max(1, fps // frame_rate)
    frame_count = 0
    saved_frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if frame_count % frame_interval == 0:
            frame_file = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg")
            cv2.imwrite(frame_file, frame)
            saved_frames.append(frame_file)

        frame_count += 1

    cap.release()
    return saved_frames

# Step 10: Prepare the dataset
class MSRVTTDataset(Dataset):
    def __init__(self, video_dir, videos, video_to_captions, output_dir, transform=None, frame_rate=1):
        self.video_dir = video_dir
        self.videos = videos
        self.video_to_captions = video_to_captions
        self.output_dir = output_dir
        self.transform = transform
        self.frame_rate = frame_rate
        self.data = []

        # Process each video
        for video in videos:
            video_id = video["video_id"]

            # Skip videos without captions
            if video_id not in video_to_captions:
                continue

            # Skip already processed videos
            if video_id in processed_videos:
                print(f"Skipping {video_id} (already processed)")
                continue

            # Define the directory where frames for this video will be saved
            frame_dir = os.path.join(output_dir, video_id)

            # Path to the video file
            video_path = os.path.join(video_dir, f"{video_id}.mp4")

            # Extract frames
            frame_files = extract_frames(video_path, frame_dir, frame_rate)

            # Save video-caption pairs
            captions = video_to_captions[video_id]
            for caption in captions:
                self.data.append({"frames": frame_files, "caption": caption})

            # Mark video as processed
            processed_videos.add(video_id)
            with open(progress_file, "a") as f:
                f.write(f"{video_id}\n")

            print(f"Processed {video_id}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        frame_files = self.data[idx]["frames"]
        caption = self.data[idx]["caption"]

        # Load frames
        frames = []
        for frame_file in frame_files:
            frame = Image.open(frame_file).convert("RGB")
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)

        # Stack frames into a tensor
        frames = torch.stack(frames)  # Shape: (num_frames, 3, H, W)

        return frames, caption

# Step 11: Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Step 12: Create dataset and dataloader
dataset = MSRVTTDataset(
    video_dir=video_dir,
    videos=videos,
    video_to_captions=video_to_captions,
    output_dir=output_dir,
    transform=transform,
    frame_rate=1  # Extract 1 frame per second
)

print(dataset)

# dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# # Step 13: Test the dataset
# for frames, caption in dataloader:
#     print("Frames shape:", frames.shape)  # Should be (batch_size, num_frames, 3, H, W)
#     print("Caption:", caption)
#     break
