In [9]:
import os
import random
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython.display import clear_output,display,HTML
import base64
import torchvision.transforms as transforms

In [3]:
# Reproducibility: seed every RNG used below (frame sampling, weight init, noise).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Output directory for the synthetic dataset; must match the path used by the
# generation loop and the Dataset below (was misspelled 'tranining_dataset').
os.makedirs('training_dataset', exist_ok=True)
num_videos = 10000
frames_per_video = 10
img_size = (64, 64)  # (width, height) of every rendered frame
shape_size = 10

In [4]:
# Training prompts. Each entry is (prompt text, shape to draw, motion key); the
# generation loop passes the shape and motion key to create_image_with_moving_shape.
prompts_and_movements = [
    ("circle moving down", "circle", "down"),  # Move circle downward
    ("circle moving left", "circle", "left"),  # Move circle leftward
    ("circle moving right", "circle", "right"),  # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),  # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),  # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),  # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),  # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),  # Shrink circle
    ("circle expanding", "circle", "expand"),  # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),  # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),  # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),  # Zigzag circle horizontally
    # "up_right" is a supported motion and the prompt used for generation
    # later on; it was the only "moving X-Y" variant missing from training.
    ("circle moving up-right", "circle", "up_right"),  # Move circle up-right
    ("circle moving up-left", "circle", "up_left"),  # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),  # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),  # Move circle down-left
]

In [5]:
# Unit (dx, dy) vectors for straight-line motions; image y grows downwards.
_LINEAR_MOTIONS = {
    "up": (0, -1),
    "down": (0, 1),
    "left": (-1, 0),
    "right": (1, 0),
    "diagonal_up_right": (1, -1),
    "up_right": (1, -1),
    "diagonal_up_left": (-1, -1),
    "up_left": (-1, -1),
    "diagonal_down_right": (1, 1),
    "down_right": (1, 1),
    "diagonal_down_left": (-1, 1),
    "down_left": (-1, 1),
}


def _triangle_wave(t, amplitude):
    """Triangle wave: 0 at t=0, rises to +amplitude, falls to -amplitude, period 4*amplitude."""
    if amplitude <= 0:
        return 0
    phase = t % (4 * amplitude)
    if phase <= amplitude:
        return phase
    if phase <= 3 * amplitude:
        return 2 * amplitude - phase
    return phase - 4 * amplitude


def _shape_pose(width, height, frame_num, direction, shape_size, step):
    """Return the (x, y, radius) of the shape at ``frame_num`` for motion ``direction``."""
    cx, cy = width // 2, height // 2
    travel = frame_num * step
    radius = shape_size

    if direction in _LINEAR_MOTIONS:
        dx, dy = _LINEAR_MOTIONS[direction]
        # Wrap around the canvas so the shape never leaves the frame for good.
        return (cx + dx * travel) % width, (cy + dy * travel) % height, radius
    if direction in ("rotate_clockwise", "rotate_counter_clockwise"):
        # A circle spinning about its own centre looks identical in every frame,
        # so rotation is rendered as an orbit around the canvas centre
        # (10 degrees per frame; with y pointing down, +angle is clockwise).
        sign = 1 if direction == "rotate_clockwise" else -1
        angle = np.deg2rad(sign * frame_num * 10)
        orbit = max(min(cx, cy) - radius, 0)
        return cx + orbit * np.cos(angle), cy + orbit * np.sin(angle), radius
    if direction == "shrink":
        return cx, cy, max(radius - frame_num, 1)
    if direction == "expand":
        return cx, cy, radius + frame_num
    if direction == "bounce_vertical":
        return cx, cy + _triangle_wave(travel, cy - radius), radius
    if direction == "bounce_horizontal":
        return cx + _triangle_wave(travel, cx - radius), cy, radius
    if direction == "zigzag_vertical":
        # Advance downwards while alternating left/right every frame.
        wobble = step if frame_num % 2 else -step
        return cx + wobble, (cy + travel) % height, radius
    if direction == "zigzag_horizontal":
        # Advance rightwards while alternating up/down every frame.
        wobble = step if frame_num % 2 else -step
        return (cx + travel) % width, cy + wobble, radius
    # Unknown motion: static shape at the centre.
    return cx, cy, radius


def create_image_with_moving_shape(size, frame_num, shape, direction, shape_size=10, step=5):
    """Render one frame of a synthetic "moving shape" video.

    The shape is drawn in black on a white canvas. At frame 0 it sits at the
    canvas centre; later frames apply the motion named by ``direction``.
    Translating motions wrap around the canvas edges.

    Args:
        size: ``(width, height)`` of the frame in pixels.
        frame_num: 0-based index of the frame within its video.
        shape: ``"circle"`` or ``"square"``.
        direction: motion key, e.g. ``"down"``, ``"diagonal_up_right"``,
            ``"rotate_clockwise"``, ``"shrink"``, ``"expand"``,
            ``"bounce_vertical"``, ``"zigzag_horizontal"``. Unknown keys leave
            the shape static at the centre.
        shape_size: circle radius / half side of the square, in pixels
            (default 10, same as the notebook's ``shape_size`` constant).
        step: pixels travelled per frame by translating motions.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in RGB channel order.

    Raises:
        ValueError: if ``shape`` is not supported.
    """
    width, height = size
    x, y, radius = _shape_pose(width, height, frame_num, direction, shape_size, step)

    rows, cols = np.ogrid[:height, :width]
    if shape == "circle":
        mask = (cols - x) ** 2 + (rows - y) ** 2 <= radius ** 2
    elif shape == "square":
        mask = (np.abs(cols - x) <= radius) & (np.abs(rows - y) <= radius)
    else:
        raise ValueError(f"unsupported shape: {shape!r}")

    img = np.full((height, width, 3), 255, dtype=np.uint8)
    img[mask] = 0
    return img

In [6]:
# Render the synthetic training set: one folder per video holding the prompt
# text and `frames_per_video` PNG frames.
for i in range(num_videos):
    prompt, shape, direction = random.choice(prompts_and_movements)
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    with open(f'{video_dir}/prompt.txt', 'w', encoding='utf-8') as f:
        f.write(prompt)

    for frame_num in range(frames_per_video):
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        frame_path = f'{video_dir}/frame_{frame_num}.png'
        # The renderer returns an RGB array (np.array of a PIL RGB image),
        # while cv2.imwrite interprets arrays as BGR.
        # cv2.imwrite reports failure by returning False instead of raising.
        if not cv2.imwrite(frame_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
            raise OSError(f'failed to write {frame_path}')

In [12]:
class Text2VideoDataset(Dataset):
    """Frame-level dataset over a directory of rendered videos.

    ``root_dir`` must contain one sub-directory per video, each holding a
    ``prompt.txt`` and any number of ``*.png`` frames. Every frame becomes one
    independent sample paired with its video's prompt, so temporal order is
    not used.

    Args:
        root_dir: directory containing the per-video folders.
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so sample indices are deterministic across runs and filesystems.
        self.video_dirs = sorted(
            os.path.join(root_dir, d)
            for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.frame_paths = []
        self.prompts = []

        for video_dir in self.video_dirs:
            frames = sorted(
                os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')
            )
            # Read the prompt before touching the lists so a missing prompt.txt
            # cannot leave frame_paths and prompts out of step.
            with open(os.path.join(video_dir, 'prompt.txt'), 'r', encoding='utf-8') as f:
                prompt = f.read().strip()
            self.frame_paths.extend(frames)
            self.prompts.extend([prompt] * len(frames))

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        """Return ``(image, prompt)`` for sample ``idx``."""
        # Load eagerly inside ``with`` so the file handle is closed, and force
        # 3 channels (the Discriminator's first conv takes 3 input channels).
        with Image.open(self.frame_paths[idx]) as im:
            image = im.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.prompts[idx]


    
    

In [13]:
# ToTensor scales uint8 pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1] (the single-value mean/std is broadcast over all channels).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
dataset = Text2VideoDataset(root_dir='training_dataset', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [15]:
class TextEmbedding(nn.Module):
    """Lookup table mapping integer token ids to dense vectors.

    Thin wrapper around ``nn.Embedding``: ``forward`` takes a tensor of token
    indices and returns one ``embed_size``-dim vector per index.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # Attribute name kept as ``embedding`` so state_dict keys stay "embedding.weight".
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)

    def forward(self, x):
        table = self.embedding
        return table(x)

In [20]:
class Generator(nn.Module):
    """Text-conditioned image generator.

    Concatenates a 100-dim noise vector with a text embedding, projects the
    result to a 256x8x8 feature map, and upsamples it three times
    (8 -> 16 -> 32 -> 64) with stride-2 transposed convolutions into a
    3-channel 64x64 image squashed to [-1, 1] by Tanh.
    """

    def __init__(self, text_embed_size):
        super().__init__()
        # Layer attribute names form the state_dict keys saved to 'generator.pth'; keep them.
        self.fc1 = nn.Linear(text_embed_size + 100, 8 * 8 * 256)
        self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1)

        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, noise, text_embed):
        hidden = self.fc1(torch.cat([noise, text_embed], dim=1))
        hidden = hidden.view(-1, 256, 8, 8)
        for upsample in (self.deconv1, self.deconv2):
            hidden = self.relu(upsample(hidden))
        return self.tanh(self.deconv3(hidden))


In [21]:
class Discriminator(nn.Module):
    """Unconditional real/fake image classifier.

    Three stride-2 convolutions reduce a 3x64x64 image to a 256x8x8 feature
    map (the flatten below hard-codes that size, so input must be 64x64),
    followed by a linear layer and a sigmoid giving one probability per image.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names form the state_dict keys saved to 'discriminator.pth'; keep them.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)

        self.fc1 = nn.Linear(8 * 8 * 256, 1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        features = input
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.leaky_relu(conv(features))
        flat = features.view(-1, 256 * 8 * 8)
        return self.sigmoid(self.fc1(flat))

    


        

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]
# Sorted so the word -> index mapping is identical in every kernel session:
# iteration order of a set of str depends on PYTHONHASHSEED and varies per run.
vocab = {word: idx for idx, word in enumerate(sorted(set(" ".join(all_prompts).split())))}
vocab_size = len(vocab)
embed_size = 10


def encode_text(prompt):
    """Map a whitespace-tokenised prompt to a 1-D LongTensor of vocab ids.

    Raises:
        KeyError: if the prompt contains a word that is not in ``vocab``.
    """
    # Explicit dtype so even an empty prompt yields an index (long) tensor.
    return torch.tensor([vocab[word] for word in prompt.split()], dtype=torch.long)


text_embedding = TextEmbedding(vocab_size, embed_size).to(device)
# The embedding's parameters are in neither optimizer below, so it acts as a
# fixed random projection of the prompt. Freeze it explicitly so backward passes
# do not accumulate never-used gradients in it and inference builds no graph.
# To learn it instead, add its parameters to optimizerG and zero its grads each step.
text_embedding.requires_grad_(False)
netG = Generator(embed_size).to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss().to(device)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

In [None]:
# Number of epochs
num_epochs = 13
for epoch in range(num_epochs):
    sum_lossD = 0.0
    sum_lossG = 0.0
    num_batches = 0
    for data, prompts in dataloader:
        real_data = data.to(device)
        batch_size = real_data.size(0)
        # NOTE(review): netD only ever sees images, never the prompt, so nothing
        # forces the generated image to match the text; prompt-faithful output
        # would need a text-conditioned discriminator.

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        optimizerD.zero_grad()
        labels = torch.ones(batch_size, 1, device=device)
        lossD_real = criterion(netD(real_data), labels)
        lossD_real.backward()

        noise = torch.randn(batch_size, 100, device=device)
        # One conditioning vector per prompt: the mean of its word embeddings.
        text_embeds = torch.stack(
            [text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts]
        )
        fake_data = netG(noise, text_embeds)
        labels = torch.zeros(batch_size, 1, device=device)
        # detach(): this step updates D only, so do not backprop into G.
        lossD_fake = criterion(netD(fake_data.detach()), labels)
        lossD_fake.backward()
        optimizerD.step()

        # --- Generator step: push D to label generated images as real ---
        optimizerG.zero_grad()
        labels = torch.ones(batch_size, 1, device=device)
        lossG = criterion(netD(fake_data), labels)
        lossG.backward()
        optimizerG.step()

        # .item() gives plain floats (no tensor repr / autograd graph kept alive).
        sum_lossD += (lossD_real + lossD_fake).item()
        sum_lossG += lossG.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches; is 'training_dataset' empty?")
    # Print epoch information (mean losses over the epoch's batches)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {sum_lossD / num_batches:.4f}, Loss G: {sum_lossG / num_batches:.4f}")

In [None]:
# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')

# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')

# The generator is only usable with the exact prompt encoding it was trained on:
# the randomly initialised text-embedding weights and the word -> index vocab.
# Save both so generation can be reproduced in a later session.
torch.save({'state_dict': text_embedding.state_dict(), 'vocab': vocab}, 'text_embedding.pth')

In [None]:
def generate_video(text_prompt, num_frames=10):
    """Generate ``num_frames`` PNG frames for ``text_prompt`` with the trained generator.

    Frames are written to
    ``generated_video_<prompt with spaces replaced by underscores>/frame_<k>.png``.
    NOTE(review): each frame uses fresh, independent noise and the generator has
    no temporal input, so consecutive frames are not constrained to form a
    coherent motion.

    Returns:
        The output directory path.

    Raises:
        KeyError: if the prompt contains a word outside the training vocabulary.
    """
    out_dir = f'generated_video_{text_prompt.replace(" ", "_")}'
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)
        for frame_num in range(num_frames):
            noise = torch.randn(1, 100).to(device)
            fake_frame = netG(noise, text_embed)
            # netG ends in Tanh, i.e. [-1, 1]; save_image expects [0, 1] and clamps
            # everything else, so undo the training Normalize((0.5,), (0.5,)) first.
            save_image(fake_frame * 0.5 + 0.5, f'{out_dir}/frame_{frame_num}.png')
    return out_dir


generate_video('circle moving up-right')

In [None]:

folder_path = 'generated_video_circle_moving_up-right'


def _frame_index(filename):
    """Numeric index from a name like 'frame_12.png', so frame_10 sorts after frame_9."""
    return int(os.path.splitext(filename)[0].rsplit('_', 1)[-1])


image_files = sorted((f for f in os.listdir(folder_path) if f.endswith('.png')), key=_frame_index)
if not image_files:
    raise FileNotFoundError(f'no .png frames found in {folder_path!r}')

frames = []
for image_file in image_files:
    image_path = os.path.join(folder_path, image_file)
    frame = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if frame is None:
        raise OSError(f'could not read frame {image_path!r}')
    frames.append(frame)

# Define the frame rate (frames per second)
fps = 10
height, width = frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (width, height))
try:
    # Write each frame to the video
    for frame in frames:
        out.write(frame)
finally:
    # Release the video writer even if a write fails
    out.release()