In [None]:
# Download COCO 2017 train images
!wget http://images.cocodataset.org/zips/train2017.zip
# Download COCO 2017 val images
!wget http://images.cocodataset.org/zips/val2017.zip
# Download COCO 2017 annotations
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

--2024-05-18 15:16:27--  http://images.cocodataset.org/zips/train2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 16.182.103.9, 52.217.171.177, 52.216.61.249, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|16.182.103.9|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 19336861798 (18G) [application/zip]
Saving to: ‘train2017.zip’


2024-05-18 15:34:27 (17.1 MB/s) - ‘train2017.zip’ saved [19336861798/19336861798]

--2024-05-18 15:34:27--  http://images.cocodataset.org/zips/val2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.28.218, 52.217.107.36, 3.5.25.102, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.28.218|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 815585330 (778M) [application/zip]
Saving to: ‘val2017.zip’


2024-05-18 15:35:17 (16.0 MB/s) - ‘val2017.zip’ saved [815585330/815585330]

--2024-05-18 15:35:17--  http://images.cocodataset.org/anno

In [None]:
import os
import zipfile

In [None]:
def unzip_file(zip_path, extract_to):
    """Extract every member of the zip archive at ``zip_path`` into ``extract_to``."""
    archive = zipfile.ZipFile(zip_path, 'r')
    try:
        archive.extractall(extract_to)
    finally:
        # Mirror the context-manager behaviour: always release the file handle.
        archive.close()

In [None]:
# Extract the train images, val images and annotations (in that order)
# into the shared dataset root.
for _archive_name in ('train2017.zip', 'val2017.zip', 'annotations_trainval2017.zip'):
    unzip_file(_archive_name, '/content/coco')

In [None]:
# Verify the extraction: show the first five entries of each image folder
# and list the annotation files.
!ls /content/coco/train2017 | head -n 5
!ls /content/coco/val2017 | head -n 5
!ls /content/coco/annotations

000000000009.jpg
000000000025.jpg
000000000030.jpg
000000000034.jpg
000000000036.jpg
000000000139.jpg
000000000285.jpg
000000000632.jpg
000000000724.jpg
000000000776.jpg
captions_train2017.json  instances_train2017.json  person_keypoints_train2017.json
captions_val2017.json	 instances_val2017.json    person_keypoints_val2017.json


In [None]:
import os

def numfiles(folder_path):
    """Return how many directory entries ``os.listdir`` reports for ``folder_path``."""
    return len(os.listdir(folder_path))

# Image folders extracted under /content/coco
folder_path_train = '/content/coco/train2017'
folder_path_val = '/content/coco/val2017'

# Report the entry count of the train folder, then of the val folder.
for _folder in (folder_path_train, folder_path_val):
    print("Number of files in the folder:", numfiles(_folder))


Number of files in the folder: 118287
Number of files in the folder: 5000


In [None]:
# A bare `TF_ENABLE_ONEDNN_OPTS=0` only creates a Python variable; libraries read
# the setting from the process environment, so export it there (as a string).
# NOTE(review): presumably meant to silence TensorFlow's oneDNN notice when the
# imports in the next cell pull TensorFlow in — confirm it is still needed.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

In [None]:
import os
import json
import requests
from PIL import Image
from io import BytesIO
from pycocotools.coco import COCO
import torchvision.transforms as transforms
from torchvision.datasets import CocoCaptions
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [None]:
class COCODataset(Dataset):
    """Image/caption pairs read from a COCO-style captions annotation file.

    Each item is ``(image, caption)`` where ``caption`` is the first caption
    listed in the annotation file for that image, so ``img_ids[i]`` and
    ``captions[i]`` always describe the same picture.

    Args:
        data_dir: Dataset root containing ``annotations/`` and the image folder.
        data_type: Split name, used both as the image sub-folder and in the
            annotation file name ``captions_<data_type>.json``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, data_dir, data_type, transform=None):
        self.data_dir = data_dir
        self.data_type = data_type
        self.transform = transform
        self.img_ids, self.captions = self.load_images_and_captions()

    def load_images_and_captions(self):
        """Return two parallel lists: image ids and one caption per image.

        The annotation file lists several captions per image in its own order,
        so captions must be joined to images through ``image_id``; pairing the
        two lists by position would attach captions to unrelated images.
        Images with no caption at all are skipped.
        """
        captions_path = os.path.join(self.data_dir, 'annotations', f'captions_{self.data_type}.json')
        with open(captions_path, 'r') as f:
            captions_data = json.load(f)

        # Keep the first caption encountered for every image id (deterministic).
        first_caption = {}
        for annotation in captions_data['annotations']:
            first_caption.setdefault(annotation['image_id'], annotation['caption'])

        img_ids = []
        captions = []
        for img_info in captions_data['images']:
            img_id = img_info['id']
            if img_id in first_caption:
                img_ids.append(img_id)
                captions.append(first_caption[img_id])

        return img_ids, captions

    def __len__(self):
        """Number of captioned images."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return it with its caption."""
        img_id = self.img_ids[idx]
        # Image files are named by the zero-padded 12-digit image id.
        image_path = os.path.join(self.data_dir, self.data_type, f'{img_id:012d}.jpg')
        # convert() returns a loaded copy, so the file handle can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        caption = self.captions[idx]
        if self.transform:
            image = self.transform(image)
        return image, caption

In [None]:
# Define transformations.
# Real images are resized to 128x128 to match the Generator's output
# (16x16 feature map upsampled by three stride-2 transposed convolutions);
# at 256x256 the discriminator would see real and fake images at different
# resolutions. Normalizing to [-1, 1] matches the Generator's Tanh range.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
# Build the training dataset and a shuffled loader over it
data_type = 'train2017'
data_dir = '/content/coco'  # Update this path
dataset = COCODataset(
    data_dir=data_dir,
    data_type=data_type,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=512, shuffle=True)

In [None]:
class TextEncoder(nn.Module):
    """Encode a batch of strings into one vector per string with a pretrained BERT.

    Args:
        model_name: Checkpoint name passed to both ``BertTokenizer`` and
            ``BertModel`` ``from_pretrained``.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super(TextEncoder, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, text):
        """Return the mean of the last-layer hidden states over the real tokens.

        Args:
            text: A list of strings to tokenize as one padded batch.

        Returns:
            Tensor with one row per input string, on the same device as the
            BERT weights.
        """
        # Follow the module's own device instead of hard-coding 'cuda'.
        device = next(self.bert.parameters()).device
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.bert(**inputs)
        hidden = outputs.last_hidden_state

        # Average only over real tokens: a plain mean would also average the
        # padding positions, making a sentence's embedding depend on how long
        # the other sentences in the batch are.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        return (hidden * mask).sum(dim=1) / token_counts

class Generator(nn.Module):
    """Map a 768-dim text feature vector to a 3x128x128 image in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # Project the text vector to a 256-channel 16x16 feature map.
        self.fc = nn.Linear(768, 256 * 16 * 16)
        # Three stride-2 transposed convolutions: 16 -> 32 -> 64 -> 128 pixels.
        upsampling_layers = [
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*upsampling_layers)

    def forward(self, text_features):
        """Return the generated image batch for ``text_features`` of shape (N, 768)."""
        feature_map = self.fc(text_features).view(-1, 256, 16, 16)
        return self.deconv(feature_map)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores an image batch with one probability per image."""

    def __init__(self):
        super().__init__()
        # First stage has no batch norm; the next three stages do.
        layers = [nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2)]
        for in_channels, out_channels in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ])
        # Final 4x4 convolution collapses channels to a single score map.
        layers.append(nn.Conv2d(512, 1, 4, stride=1, padding=0))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch_size, 1) tensor: the mean sigmoid over the score map."""
        score_map = self.conv(x)
        flat_probabilities = torch.sigmoid(score_map.view(score_map.size(0), -1))
        # Averaging over positions guarantees the output size is [batch_size, 1].
        return flat_probabilities.mean(1, keepdim=True)


In [None]:
# Instantiate the three networks and move each one to the GPU
text_encoder = TextEncoder().to('cuda')
generator = Generator().to('cuda')
discriminator = Discriminator().to('cuda')

In [None]:
# Binary cross-entropy adversarial loss; generator and discriminator each get
# their own Adam optimizer with identical hyper-parameters.
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)

# TensorBoard writer used to log the per-epoch losses
writer = SummaryWriter()

In [None]:
# Fixed prompts rendered by inference_pipeline() at the end of every training
# epoch, so sample images can be compared across epochs.
eval_texts = [
    "A cat sitting on a bench",
    "A beautiful landscape with mountains",
    "A group of people playing football",
    "A close-up of a colorful bird",
    "A city skyline at night"
]

In [None]:
def inference_pipeline(text_prompts,epoch):
    """Generate one image per prompt and save each as a PNG.

    Files are written to ``/content/coco/generated_outs`` as
    ``epoch_<epoch>_image_<i>.png``.

    Args:
        text_prompts: Iterable of prompt strings; a single string is treated
            as a one-element list.
        epoch: Value embedded in the output file names.

    Returns:
        The last generated image as an HxWx3 NumPy array scaled to [0, 1],
        or ``None`` when ``text_prompts`` is empty.
    """
    output_dir = "/content/coco/generated_outs"
    os.makedirs(output_dir, exist_ok=True)

    # A bare string would otherwise be iterated character by character.
    if isinstance(text_prompts, str):
        text_prompts = [text_prompts]

    # This runs in the middle of training, so remember the current modes and
    # restore them afterwards instead of leaving the models stuck in eval mode.
    encoder_was_training = text_encoder.training
    generator_was_training = generator.training
    text_encoder.eval()
    generator.eval()

    device = next(generator.parameters()).device
    generated_image = None
    try:
        with torch.no_grad():
            for i, text in enumerate(text_prompts):
                text_features = text_encoder([text]).to(device)
                generated_image = generator(text_features)
                # (1, 3, H, W) tensor -> (H, W, 3) array for matplotlib
                generated_image = generated_image.squeeze().cpu().numpy().transpose(1, 2, 0)
                generated_image = (generated_image + 1) / 2  # Denormalize

                plt.imshow(generated_image)
                plt.axis('off')
                plt.savefig(os.path.join(output_dir, f'epoch_{epoch}_image_{i}.png'))
                plt.close()
    finally:
        text_encoder.train(encoder_was_training)
        generator.train(generator_was_training)

    return generated_image

In [None]:
def train_model(epochs, dataloader):
    """Train the generator and discriminator adversarially on image/caption batches.

    Each step updates the generator first (to make the discriminator label
    its samples as real), then the discriminator on real images and on the
    detached generated ones. After every epoch the mean losses are logged to
    TensorBoard, a checkpoint ``checkpoint_epoch_<n>.pth`` is written and
    sample images are rendered for ``eval_texts``.

    The text encoder is not part of either optimizer, so it is used as a
    frozen feature extractor: eval mode (no dropout noise) and no autograd
    graph, which avoids back-propagating through BERT for gradients that are
    never applied.

    Args:
        epochs: Number of passes over ``dataloader``.
        dataloader: Iterable yielding ``(images, captions)`` batches.

    Returns:
        ``(train_losses_G, train_losses_D)``: per-epoch mean generator and
        discriminator losses.
    """
    device = next(generator.parameters()).device

    train_losses_G = []
    train_losses_D = []

    for epoch in range(epochs):
        # Set the modes at the start of every epoch: the end-of-epoch call to
        # inference_pipeline() switches the generator and text encoder to eval.
        generator.train()
        discriminator.train()
        text_encoder.eval()

        epoch_loss_G = 0
        epoch_loss_D = 0

        for i, (images, captions) in enumerate(dataloader):
            batch_size = images.size(0)
            images = images.to(device)
            # Target labels: 1 for "real", 0 for "fake"
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()

            with torch.no_grad():
                text_features = text_encoder(list(captions)).to(device)
            generated_images = generator(text_features)

            g_loss = adversarial_loss(discriminator(generated_images), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Also clears the discriminator gradients left by g_loss.backward().
            optimizer_D.zero_grad()

            real_loss = adversarial_loss(discriminator(images), valid)
            fake_loss = adversarial_loss(discriminator(generated_images.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            epoch_loss_G += g_loss.item()
            epoch_loss_D += d_loss.item()

            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(dataloader)}], G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}')

        # Guard against an empty loader so the average never divides by zero.
        num_batches = max(len(dataloader), 1)
        epoch_loss_G /= num_batches
        epoch_loss_D /= num_batches
        train_losses_G.append(epoch_loss_G)
        train_losses_D.append(epoch_loss_D)

        # Logging the epoch losses
        writer.add_scalar('Loss/Generator', epoch_loss_G, epoch)
        writer.add_scalar('Loss/Discriminator', epoch_loss_D, epoch)

        # Save model checkpoints
        torch.save({
            'epoch': epoch + 1,
            'text_encoder_state_dict': text_encoder.state_dict(),
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'loss_G': epoch_loss_G,
            'loss_D': epoch_loss_D,
        }, f'checkpoint_epoch_{epoch+1}.pth')

        inference_pipeline(eval_texts,epoch)

    return train_losses_G, train_losses_D


# Train the model; close the TensorBoard writer even if training fails or is
# interrupted, so the events logged so far are flushed to disk.
try:
    train_losses_G, train_losses_D = train_model(epochs=50, dataloader=dataloader)
finally:
    # Close the writer
    writer.close()


Epoch [1/50], Step [1/232], G Loss: 0.7063, D Loss: 0.6684
Epoch [1/50], Step [101/232], G Loss: 4.0336, D Loss: 0.0205
Epoch [1/50], Step [201/232], G Loss: 3.6305, D Loss: 0.4282
Epoch [2/50], Step [1/232], G Loss: 3.4248, D Loss: 0.4274


In [None]:
def generate_and_display(text):
    """Generate an image for the prompt ``text`` and show it with matplotlib.

    The text encoder and generator are switched to eval mode for the forward
    pass and then returned to whatever mode they were in, so calling this
    between training runs does not silently change their state.
    """
    encoder_was_training = text_encoder.training
    generator_was_training = generator.training
    text_encoder.eval()
    generator.eval()

    try:
        with torch.no_grad():
            # Put the features on the generator's device rather than assuming CUDA.
            text_features = text_encoder([text]).to(next(generator.parameters()).device)
            generated_image = generator(text_features)
    finally:
        text_encoder.train(encoder_was_training)
        generator.train(generator_was_training)

    # (1, 3, H, W) tensor -> (H, W, 3) array for matplotlib
    generated_image = generated_image.squeeze().cpu().numpy().transpose(1, 2, 0)
    generated_image = (generated_image + 1) / 2  # Denormalize
    plt.imshow(generated_image)
    plt.axis('off')
    plt.show()

def plot_losses(train_losses_G, train_losses_D):
    """Plot the generator and discriminator loss series on one labelled figure."""
    _, ax = plt.subplots(figsize=(10, 5))
    labelled_series = (
        (train_losses_G, 'Generator Loss'),
        (train_losses_D, 'Discriminator Loss'),
    )
    for losses, label in labelled_series:
        ax.plot(losses, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Over Epochs')
    ax.legend()
    plt.show()

# Example usage: render one sample prompt with the generator, then plot the
# per-epoch loss lists returned by train_model().
generate_and_display("A cat sitting on a bench")
plot_losses(train_losses_G, train_losses_D)


In [None]:
# Example usage: generate and show an image for a single prompt.
# generate_and_display() already runs the encoder and generator, denormalizes
# the result and displays it. inference_pipeline() is not suitable here: it
# requires an `epoch` argument, expects a list of prompts, and returns a NumPy
# array rather than a tensor.
text_prompt = "a cat under a tree"
generate_and_display(text_prompt)

In [None]:
def count_parameters(model):
    """Return the total number of elements in ``model``'s parameters that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Trainable parameter counts of the generator and discriminator, and their sum
total_params_gen, total_params_dis = (
    count_parameters(network) for network in (generator, discriminator)
)
total_params = total_params_gen + total_params_dis
print(f"Total number of parameters in the model: {total_params}")
