In [1]:
import os
import random
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython.display import clear_output,display,HTML
import base64

In [2]:
# Synthetic-dataset configuration.
# The folder name must match the one the generation loop writes into
# ('training_dataset'); it was previously misspelled 'tranining_dataset'.
os.makedirs('training_dataset', exist_ok=True)
num_videos = 10000       # number of synthetic videos to generate
frames_per_video = 10    # frames rendered for each video
img_size = (64, 64)      # canvas size passed to PIL Image.new, i.e. (width, height)
shape_size = 10          # NOTE(review): not passed to the frame generator in the visible code

In [3]:
# Catalogue of (prompt, shape, motion) triples that the video generator samples
# from. Every entry uses the "circle" shape; the third element is the motion key
# passed to create_image_with_moving_shape as its `direction` argument.
# NOTE(review): there is no plain "up" motion, and "circle moving up-left"
# ("up_left") describes the same movement as the "diagonal_up_left" entry.
prompts_and_movements = [
    (prompt_text, "circle", motion)
    for prompt_text, motion in [
        # straight-line moves
        ("circle moving down", "down"),
        ("circle moving left", "left"),
        ("circle moving right", "right"),
        # diagonal moves
        ("circle moving diagonally up-right", "diagonal_up_right"),
        ("circle moving diagonally down-left", "diagonal_down_left"),
        ("circle moving diagonally up-left", "diagonal_up_left"),
        ("circle moving diagonally down-right", "diagonal_down_right"),
        # rotations
        ("circle rotating clockwise", "rotate_clockwise"),
        ("circle rotating counter-clockwise", "rotate_counter_clockwise"),
        # size changes
        ("circle shrinking", "shrink"),
        ("circle expanding", "expand"),
        # oscillating moves
        ("circle bouncing vertically", "bounce_vertical"),
        ("circle bouncing horizontally", "bounce_horizontal"),
        ("circle zigzagging vertically", "zigzag_vertical"),
        ("circle zigzagging horizontally", "zigzag_horizontal"),
        # combined moves
        ("circle moving up-left", "up_left"),
        ("circle moving down-right", "down_right"),
        ("circle moving down-left", "down_left"),
    ]
]

In [4]:
def create_image_with_moving_shape(size, frame_num, shape, direction, shape_size=10):
    """Render one frame of a synthetic "moving shape" video.

    A black shape is drawn on a white RGB canvas. Its position (or the canvas
    rotation / the shape's size) depends on ``frame_num`` and ``direction``, so
    calling this for frame_num = 0, 1, 2, ... yields an animation.

    Args:
        size: ``(width, height)`` of the canvas in pixels (PIL convention).
        frame_num: 0-based index of the frame within the video.
        shape: Shape to draw, ``"circle"`` or ``"square"``.
        direction: Motion key, e.g. ``"down"``, ``"diagonal_up_left"``,
            ``"rotate_clockwise"``, ``"shrink"``, ``"zigzag_vertical"``.
            Unknown keys leave the shape static at the canvas centre.
        shape_size: Radius of the circle (half side length of the square) in
            pixels at frame 0. Defaults to 10.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)``, dtype ``uint8``, in RGB
        channel order. Only pure black and white are used.

    Raises:
        ValueError: If ``shape`` is not supported.
    """
    if shape not in ("circle", "square"):
        raise ValueError(f"unsupported shape: {shape!r}")

    width, height = size
    center_x, center_y = width // 2, height // 2
    step = frame_num * 5  # distance travelled so far: 5 px per frame
    even_frame = frame_num % 2 == 0

    # Signed (dx, dy) offsets from the canvas centre; y grows downwards.
    # Negation is applied to the whole step: `-n * 5 % w` would evaluate to a
    # positive number and send "left"/"up" motions the wrong way.
    offsets = {
        "down": (0, step),
        "left": (-step, 0),
        "right": (step, 0),
        "diagonal_up_right": (step, -step),
        "diagonal_down_left": (-step, step),
        "diagonal_up_left": (-step, -step),
        "diagonal_down_right": (step, step),
        "up_right": (step, -step),
        "up_left": (-step, -step),
        "down_right": (step, step),
        "down_left": (-step, step),
        # Move away from the centre and back again (triangle wave in [0, center]).
        "bounce_vertical": (0, center_y - abs(step % height - center_y)),
        "bounce_horizontal": (center_x - abs(step % width - center_x), 0),
        # Alternate sides of the centre with growing amplitude.
        "zigzag_vertical": (0, -step if even_frame else step),
        "zigzag_horizontal": (-step if even_frame else step, 0),
    }
    dx, dy = offsets.get(direction, (0, 0))
    # Wrap around the canvas so the shape centre always stays inside the frame.
    x = (center_x + dx) % width
    y = (center_y + dy) % height

    radius = shape_size
    if direction == "shrink":
        radius = max(1, shape_size - frame_num)
    elif direction == "expand":
        radius = shape_size + frame_num

    img = Image.new('RGB', size, color=(255, 255, 255))
    draw = ImageDraw.Draw(img)
    bbox = [x - radius, y - radius, x + radius, y + radius]
    if shape == "circle":
        draw.ellipse(bbox, fill=(0, 0, 0))
    else:
        draw.rectangle(bbox, fill=(0, 0, 0))

    # Rotate *after* drawing so the shape itself is rotated. PIL's
    # Image.rotate turns counter-clockwise for positive angles, so clockwise
    # motion needs a negative angle.
    # NOTE(review): a centred circle looks identical under rotation; these
    # motions are only visible for non-symmetric shapes such as the square.
    angle = frame_num * 10 % 360
    if direction == "rotate_clockwise":
        img = img.rotate(-angle, center=(center_x, center_y), fillcolor=(255, 255, 255))
    elif direction == "rotate_counter_clockwise":
        img = img.rotate(angle, center=(center_x, center_y), fillcolor=(255, 255, 255))

    return np.array(img)

In [5]:
# Generate the synthetic dataset: one folder per video containing the text
# prompt (prompt.txt) and the rendered frames (frame_<n>.png).
random.seed(42)  # make the sampled prompt sequence reproducible across re-runs
for i in range(num_videos):
    prompt, shape, direction = random.choice(prompts_and_movements)
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    with open(f'{video_dir}/prompt.txt', 'w', encoding='utf-8') as f:
        f.write(prompt)

    for frame_num in range(frames_per_video):
        frame_rgb = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        # The generator builds an 'RGB' PIL image, but cv2.imwrite expects BGR.
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_path = f'{video_dir}/frame_{frame_num}.png'
        # cv2.imwrite signals failure via its return value, not an exception.
        if not cv2.imwrite(frame_path, frame_bgr):
            raise IOError(f'failed to write {frame_path}')

In [None]:
class Text2VideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_dirs = [os.path.join(root_dir,d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir,d))]
        self.frame_paths=[]
        self.prompts=[]

        for video_dir in self.video_dirs:
            frames =[os.path.join(video_dir,f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            with open(os.path.join(video_dir,'prompt.txt'),'r')as f:
                prompt = f.read.strip()

            self.prompts.extend([prompt] * len(frames))


    def __len__(self):
        return len(self.frame_paths)
    
    