<H1 style="display:center">Generative Adversarial Networks</H1>

In [100]:
# Operating System module for interacting with the operating system
import os

# Module for generating random numbers
import random

# Module for numerical operations
import numpy as np

# OpenCV library for image processing
import cv2

# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont

# PyTorch library for deep learning
import torch

# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset

# Module for image transformations
import torchvision.transforms as transforms

# Neural network module in PyTorch
import torch.nn as nn

# Optimization algorithms in PyTorch
import torch.optim as optim

# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence

# Function for saving images in PyTorch
from torchvision.utils import save_image

# Module for plotting graphs and images
import matplotlib.pyplot as plt

# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML

# Module for encoding and decoding binary data to text
import base64

## Making Training Data

In [101]:
os.makedirs("training_data", exist_ok=True) # Initialising the training data directory
num_vids = 10000 #Number of videos to generate
frames_per_video = 10 # fps
img_size = (64,64) #size of image
shape_size = 10 #Defining the size of shape (Circle)

In [102]:
#Prompts vs Mvt...
prompts_and_movements = [
    ("circle moving down", "circle", "down"),  # Move circle downward
    ("circle moving left", "circle", "left"),  # Move circle leftward
    ("circle moving right", "circle", "right"),  # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),  # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),  # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),  # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),  # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),  # Shrink circle
    ("circle expanding", "circle", "expand"),  # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),  # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),  # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),  # Zigzag circle horizontally
    ("circle moving up-left", "circle", "up_left"),  # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),  # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),  # Move circle down-left
]

In [103]:
#Defining the circle with parameters
def create_moving_shape(size,frame_num,shape,direction):
    img = Image.new('RGB',size=size,color=(255,255,255))
    draw = ImageDraw.Draw(img)
    centerX, centerY = size[0] // 2, size[1] // 2
    position = (centerX, centerY)
    direction_map = {
        'down' : (0,frame_num*5 % size[1]),
        'left' : (-frame_num*5 % size[0],0),
        'right' : (frame_num*5 % size[0],0),
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(centerX, centerY), fillcolor=(255, 255, 255)),  
        "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(centerX, centerY), fillcolor=(255, 255, 255)),  
        "bounce_vertical": (0, centerY - abs(frame_num * 5 % size[1] - centerY)),  
        "bounce_horizontal": (centerX - abs(frame_num * 5 % size[0] - centerX), 0),  
        "zigzag_vertical": (0, centerY - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, centerY + frame_num * 5 % size[1]),  
        "zigzag_horizontal": (centerX - frame_num * 5 % size[0], centerY) if frame_num % 2 == 0 else (centerX + frame_num * 5 % size[0], centerY),  
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])  
    }
    #Check if the direction is in the map
    if direction in direction_map:
        if isinstance(direction_map[direction], tuple):
            position = tuple(np.add(position, direction_map[direction]))
        else:
            img = direction_map[direction]
    return np.array(img)

In [104]:
#Looping over the number of videos to generate the videos
for i in range(num_vids):
    prompt, shape, direction = random.choice(prompts_and_movements)
    video_dir = f'training_dataset/video{i}'
    os.makedirs(video_dir, exist_ok=True)
    with open(f'{video_dir}/prompt.txt','w') as f:
        f.write(prompt)

    for frame_num in range(frames_per_video):
        img = create_moving_shape(size=img_size, frame_num=frame_num, shape=shape,direction=direction)
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

KeyboardInterrupt: 

In [105]:
class TextToVideoDataset:
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_dirs = [os.path.join(root_dir,d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        self.frame_paths = []
        self.prompts = []

        for video_dir in self.video_dirs:
            frames = [os.path.join(video_dir,f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            with open(os.path.join(video_dir,'prompt.txt'),'r') as f:
                prompt = f.read().strip()

            self.prompts.extend([prompt]*len(frames))

    def __len__(self):
        return len(self.frame_paths)
    
    def __getitem__(self,idx):
        frame_path = self.frame_paths[idx]
        image = Image.open(frame_path)
        prompt = self.prompts[idx]

        if self.transform:
            image = self.transform(image)
            return image, prompt
                
                
        print(f"Image shape: {image.size()}, Prompt: {prompt}")


In [106]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])
dataset = TextToVideoDataset(root_dir='training_dataset',transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [107]:
class TextEmbedding(nn.Module):
    def __init__(self,vocab_size, embed_size):
        super(TextEmbedding,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embed_size)

    def forward(self,x):
        return self.embedding(x)

In [108]:
class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(TextEmbedding,self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self,x):
        return self.embedding(x)

In [109]:
class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator,self).__init__()
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8) # This is the flatten layer
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) # Conv Layer 1
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1) # Conv Layer 2
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Conv Layer 3
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, noise, text_embed):
        text_embed = text_embed.repeat(noise.size(0), 1)
        x = torch.cat((noise,text_embed),dim=1)
        x = self.fc1(x).view(-1, 256, 8, 8)
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.tanh(self.deconv3(x))

        return x

In [110]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.fc1 = nn.Linear(256 * 8 * 8, 1)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))

        x = x.view(-1, 256 * 8 * 8)

        x = self.sigmoid(self.fc1(x))

        return x

In [113]:
#Checking for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#Creating a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements] #Extracts All
vocab = {word : idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}
vocab_size = len(vocab)
embed_size = 10

def encode_text(prompt):
    return torch.tensor([vocab[word] for word in prompt.split()])

text_embedding = TextEmbedding(vocab_size=vocab_size, embed_size=embed_size).to(device=device)
netG = Generator(embed_size).to(device=device)
netD = Discriminator().to(device=device)
criterion = nn.BCELoss().to(device=device)
optimiserD = optim.Adam(netD.parameters(),lr=0.0002, betas=(0.5, 0.999))
optimiserG = optim.Adam(netG.parameters(),lr=0.0002, betas=(0.5, 0.999))

cuda


In [115]:
num_epochs = 13
for epoch in range(num_epochs):
    for i, (data, prompts) in enumerate(dataloader):
        real_data = data.to(device)
        prompts = [prompts for prompt in prompts]

        netD.zero_grad()
        batch_size = real_data.size(0)
        labels = torch.ones(batch_size, 1 ).to(device=device)
        output = netD(real_data)
        lossD_real = criterion(output, labels)
        lossD_real.backward()
        noise = torch.randn(batch_size,100).to(device=device)
        text_embeds = torch.stack([text_embedding(encode_text(prompt=prompt).to(device=device)).mean(dim=0)])
        fake_data = netG(noise, text_embeds)
        labels = torch.zeros(batch_size,1).to(device=device)
        output = netD(fake_data.detach())
        lossD_fake = criterion(output, labels)
        lossD_fake.backward()
        optimiserD.step()


        netG.zero_grad()
        labels = torch.ones(batch_size,1).to(device)
        output = netD(fake_data)
        lossG = criterion(output,labels).to(device)
        lossG.backward()
        optimiserG.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")
    

Epoch [1/13] Loss D: 1.3862930536270142, Loss G: 0.6933430433273315
Epoch [2/13] Loss D: 100.0, Loss G: 100.0
Epoch [3/13] Loss D: 100.0, Loss G: 100.0
Epoch [4/13] Loss D: 100.0, Loss G: 100.0
Epoch [5/13] Loss D: 100.0, Loss G: 100.0
Epoch [6/13] Loss D: 100.0, Loss G: 100.0
