In [1]:
# Uncomment and run the appropriate command for your operating system, if required
# No installation is reqiured on Google Colab / Kaggle notebooks

# Linux / Binder / Windows (No GPU)
# !pip install numpy matplotlib torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# Linux / Windows (GPU)
# pip install numpy matplotlib torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 
# MacOS (NO GPU)
# !pip install numpy matplotlib torch torchvision torchaudio

# Installing necessary packages

In [2]:
# Install necessary libraries if not already installed
import subprocess
import sys

def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

required_packages = ['numpy', 'matplotlib', 'torch', 'torchvision', 'torchaudio', 'transformers', 'pillow']

for package in required_packages:
    try:
        __import__(package)
    except ImportError:
        install(package)



In [3]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt

# Data Preprocessing

In [4]:
import re

image_dir = 'dataset/images'
files = os.listdir(image_dir) # list of all files in the directory

counter = 1

# pattern to match files already in the `artwork003.ext` format
pattern = re.compile(r'artwork\d{3}\.\w+')

for file_name in files:
    file_ext = os.path.splitext(file_name)[1]
    
    if not pattern.match(file_name):
        # Create the new file name
        new_file_name = f'artwork{counter:03d}{file_ext}'
        
        # Get the full file paths
        old_file_path = os.path.join(image_dir, file_name)
        new_file_path = os.path.join(image_dir, new_file_name)
        
        # Rename the file
        os.rename(old_file_path, new_file_path)
        
        # Increment the counter
        counter += 1

print("Renaming completed.")


Renaming completed.


In [5]:
# custom dataset
class TextImageDataset(Dataset):
    def __init__(self, text_dir, image_dir, transform=None):
        self.text_dir = text_dir
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))
        self.text_files = sorted(os.listdir(text_dir))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        txt_name = os.path.join(self.text_dir, self.text_files[idx])
        with open(txt_name, 'r') as f:
            description = f.read()
        
        tokens = self.tokenizer(description, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        return image, tokens['input_ids'].squeeze(), tokens['attention_mask'].squeeze()

# Data transforms
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize expects 3 channels (RGB)
])

# Creating dataset and dataloader
dataset = TextImageDataset(text_dir='dataset/text', image_dir='dataset/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Feature Engineering

In [6]:
# defining the Generator Network
class Generator(nn.Module):
    def __init__(self, latent_dim, text_dim, img_channels):
        super(Generator, self).__init__()
        self.text_embedding = nn.Linear(text_dim, latent_dim)
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, noise, text):
        text_embedding = self.text_embedding(text)
        x = torch.cat([noise, text_embedding], dim=1)
        x = x.unsqueeze(2).unsqueeze(3)
        return self.gen(x)

# Define the Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, img_channels, text_dim):
        super(Discriminator, self).__init__()
        self.img_dis = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid()
        )
        self.text_embedding = nn.Linear(text_dim, 512)

    def forward(self, img, text):
        img_out = self.img_dis(img).view(img.size(0), -1)
        text_embedding = self.text_embedding(text)
        x = torch.cat([img_out, text_embedding], dim=1)
        return torch.sigmoid(x)

# Initialize models
text_dim = 768  # BERT base model output dimension
latent_dim = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(latent_dim, text_dim, 3).to(device)
discriminator = Discriminator(3, text_dim).to(device)

# Training the Model
The generator and discriminator are trained here using a loop.
## Optimization
The optimization is embedded within the training loop. The code uses Adam optimizer to update the weights of both the generator and the discriminator.
## Evaluation
Evaluation happens within the training loop where loss values for both the generator and the discriminator are printed every 100 steps. Additionally, generated sample images are saved at the end of each epoch to visualize the generatorâ€™s progress.

In [7]:
# training Function
def train_gan(generator, discriminator, dataloader, epochs, lr, device):
    # optimization
    criterion = nn.BCELoss()
    optim_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    
    for epoch in range(epochs):
        for i, (images, text_ids, text_mask) in enumerate(dataloader):
            batch_size = images.size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # moving data to the appropriate device
            images, text_ids, text_mask = images.to(device), text_ids.to(device), text_mask.to(device)
            
            # Discriminator training
            optim_d.zero_grad()
            outputs = discriminator(images, text_ids)
            real_loss = criterion(outputs, real_labels)
            real_loss.backward()

            noise = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(noise, text_ids)
            outputs = discriminator(fake_images.detach(), text_ids)
            fake_loss = criterion(outputs, fake_labels)
            fake_loss.backward()
            optim_d.step()
            
            # Generator training
            optim_g.zero_grad()
            outputs = discriminator(fake_images, text_ids)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optim_g.step()
            

            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(dataloader)}], '
                      f'D Loss: {real_loss.item()+fake_loss.item():.4f}, G Loss: {g_loss.item():.4f}')
        
        # saving some sample images at each epoch
        with torch.no_grad():
            sample_noise = torch.randn(batch_size, latent_dim).to(device)
            sample_images = generator(sample_noise, text_ids)
            os.makedirs('samples', exist_ok=True)
            utils.save_image(sample_images, f'samples/sample_epoch_{epoch+1}.png', nrow=8, normalize=True)
    
    torch.save(generator.state_dict(), 'generator.pth')
    torch.save(discriminator.state_dict(), 'discriminator.pth')

# setting parameters
image_size = 64
batch_size = 128
epochs = 25
lr = 0.0002

# training the models
train_gan(generator, discriminator, dataloader, epochs, lr, device)

RuntimeError: output with shape [1, 64, 64] doesn't match the broadcast shape [3, 64, 64]