In [None]:
from datasets import load_dataset
import torch
# Load the dataset published under the id below. Later cells index the result
# with 'train' and read its 'text', 'image' and 'label' columns.
# NOTE(review): presumably fetched from the Hugging Face Hub on first use - confirm
# network access / caching expectations for this notebook.
ds = load_dataset("pranked03/flowers-blip-captions")

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from transformers import AutoTokenizer, DistilBertModel

# Text encoder used to turn captions into fixed-size vectors. It is only used
# for feature extraction, so it is moved to the target device and switched to
# inference mode right away.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)
model.eval()

# Max Pooling - Take the max value over time for every dimension.
def max_pooling(model_output, attention_mask):
    """Max-pool token embeddings over the sequence axis, ignoring padding.

    Args:
        model_output: Indexable model output whose first element holds the
            token embeddings (pooled along dimension 1).
        attention_mask: Mask with one entry per token; positions equal to 0
            are treated as padding and cannot win the max.

    Returns:
        Tensor with the per-feature maximum over dimension 1 of the token
        embeddings.
    """
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    padding = attention_mask.unsqueeze(-1).expand(token_embeddings.size()) == 0
    # masked_fill returns a new tensor. Assigning into token_embeddings in place
    # would silently overwrite the hidden states owned by the caller.
    masked = token_embeddings.masked_fill(padding, -1e9)
    return torch.max(masked, 1)[0]

def convert_text_to_feature(sentences, max_length=50):
    """Encode a batch of sentences into max-pooled encoder embeddings.

    Args:
        sentences: List of strings to encode.
        max_length: Every sentence is padded or truncated to this many tokens.

    Returns:
        Tensor on `device` with one row per sentence, produced by
        `max_pooling` over the encoder's token embeddings. The result carries
        no autograd history.
    """
    # Calling the tokenizer directly is the documented batch-encoding entry
    # point and takes the same padding/truncation arguments.
    inputs = tokenizer(
        sentences, padding='max_length', max_length=max_length, truncation=True, return_tensors='pt'
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    # The encoder is used purely as a frozen feature extractor, so skip
    # building an autograd graph that every caller would throw away.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        sentence_embeddings = max_pooling(outputs, attention_mask)
    return sentence_embeddings

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class Text2ImageDataset(Dataset):
    """Pairs every image with its caption embedding and a mismatched image.

    Each sample is a dict with:
        'right_images': the image at the requested index,
        'right_embed':  the precomputed embedding of its caption,
        'wrong_images': a random image whose label differs from the sample's,
        'txt':          the caption as a string.

    Args:
        dataset: Either a mapping of split name -> split, or a single split
            exposing 'text', 'image' and 'label' columns.
        transform: Optional callable applied to both images of a sample.
        split: When a string naming a key of a mapping-style `dataset`, that
            split is used. A string with a non-mapping `dataset` uses
            `dataset` as is. Any non-string value (the default) selects
            `dataset['train']`.
    """

    def __init__(self, dataset, transform=None, split=0):
        self.transform = transform
        self.dataset = self._select_split(dataset, split)
        self.split = split

        self.texts = self.dataset['text']
        self.images = self.dataset['image']
        self.labels = self.dataset['label']
        # Cached once so negative sampling is a vectorised comparison instead
        # of a Python-level scan over every label on each __getitem__ call.
        self._label_array = np.asarray(list(self.labels))
        self.embeddings = self._compute_embeddings(self.texts)

    @staticmethod
    def _select_split(dataset, split):
        """Resolve the split to read, as described in the class docstring."""
        if isinstance(split, str):
            # Previously a string split always kept `dataset` untouched, so a
            # mapping of splits then failed on the 'text' column lookup.
            if isinstance(dataset, dict) and split in dataset:
                return dataset[split]
            return dataset
        return dataset['train']  # Default to 'train' split

    def _compute_embeddings(self, texts, batch_size=8):
        """Embed all captions in chunks of `batch_size` and stack them on CPU.

        Returns an empty tensor when `texts` is empty.
        """
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch_texts = list(texts[start:start + batch_size])
            chunks.append(convert_text_to_feature(batch_texts).detach().cpu())

        if not chunks:
            return torch.empty(0)
        return torch.cat(chunks, dim=0)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_array(image):
        """Convert PIL images to numpy arrays; pass anything else through."""
        if isinstance(image, Image.Image):
            return np.array(image)
        return image

    def __getitem__(self, idx):
        right_image = self._to_array(self.images[idx])
        if self.transform is not None:
            right_image = self.transform(right_image)

        wrong_image = self.find_wrong_image(self.labels[idx])
        if self.transform is not None:
            wrong_image = self.transform(wrong_image)

        return {
            # as_tensor accepts tensors and arrays alike and avoids the legacy
            # torch.FloatTensor constructor.
            'right_images': torch.as_tensor(right_image, dtype=torch.float32),
            'right_embed': self.embeddings[idx],
            'wrong_images': torch.as_tensor(wrong_image, dtype=torch.float32),
            'txt': str(self.texts[idx]),
        }

    def _sample_wrong_index(self, category):
        """Return a uniformly random index whose label differs from `category`.

        Raises:
            ValueError: if every sample carries the label `category`.
        """
        candidates = np.flatnonzero(self._label_array != category)
        if candidates.size == 0:
            raise ValueError(
                f"no sample with a label other than {category!r}; cannot draw a mismatched image"
            )
        return int(np.random.choice(candidates))

    def find_wrong_image(self, category):
        """Return a random (untransformed) image whose label is not `category`."""
        return self._to_array(self.images[self._sample_wrong_index(category)])

In [None]:
from torchvision import transforms
# Image preprocessing: arrays -> PIL -> 64x64 -> float tensor.
# ToTensor yields values in [0, 1], while the Generator ends in Tanh and emits
# values in [-1, 1]. Normalising with mean=std=0.5 maps real images onto the
# same [-1, 1] range, so the discriminator and the L1 reconstruction term
# compare like with like.
# NOTE(review): assumes 3-channel inputs, matching channels=3 used for the
# Generator/Discriminator - confirm the dataset has no grayscale/RGBA images.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Build the training dataset. The constructor also embeds every caption
# (Text2ImageDataset.__init__ calls _compute_embeddings), so this cell does
# more than wrap `ds` and may be slow.
dataset = Text2ImageDataset(ds, transform=transform)

In [None]:
from torch.utils.data import DataLoader

# Mini-batch size for training; samples are reshuffled every epoch.
batch_size = 128
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Text-conditioned image generator built from transposed convolutions.

    The text embedding is projected to `embed_out_dim` features, joined with
    the noise vector along the channel axis and upsampled by five
    transposed-convolution stages. The last stage outputs `channels` feature
    maps squashed by Tanh.

    Args:
        channels: Number of channels of the generated image.
        noise_dim: Channel count of the noise input.
        embed_dim: Width of the incoming text embedding.
        embed_out_dim: Width of the projected text embedding.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super().__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Projection applied to the raw text embedding.
        self.text_embedding = nn.Sequential(
            nn.Linear(embed_dim, embed_out_dim),
            nn.BatchNorm1d(embed_out_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # One row per stage: (in_channels, out_channels, stride, padding, is_output).
        plan = (
            (noise_dim + embed_out_dim, 512, 1, 0, False),
            (512, 256, 2, 1, False),
            (256, 128, 2, 1, False),
            (128, 64, 2, 1, False),
            (64, channels, 2, 1, True),
        )
        stages = []
        for size_in, size_out, stride, padding, is_output in plan:
            stages.extend(
                self._create_layer(size_in, size_out, 4, stride=stride, padding=padding, output=is_output)
            )
        self.model = nn.Sequential(*stages)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Return the layers of one upsampling stage as a list.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the output
        stage is ConvTranspose2d -> Tanh.
        """
        deconv = nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)
        tail = [nn.Tanh()] if output else [nn.BatchNorm2d(size_out), nn.ReLU(True)]
        return [deconv, *tail]

    def forward(self, noise, text):
        """Generate images from `noise` conditioned on the `text` embedding."""
        projected = self.text_embedding(text)
        # Give the projected text a 1x1 spatial extent so it can be stacked
        # with the noise along the channel axis.
        projected = projected.view(projected.shape[0], projected.shape[1], 1, 1)
        latent = torch.cat([projected, noise], 1)
        return self.model(latent)


# The Embedding model
class Embedding(nn.Module):
    """Projects a text embedding and appends it to a feature map as channels.

    Args:
        size_in: Width of the incoming text embedding.
        size_out: Width of the projected embedding, i.e. the number of
            channels appended to the feature map.
    """

    def __init__(self, size_in, size_out):
        super(Embedding, self).__init__()
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            nn.BatchNorm1d(size_out),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x, text):
        """Return `x` with the projected `text` tiled over its spatial grid.

        Args:
            x: Feature map indexed as (batch, channels, height, width).
            text: Text embeddings, one row per sample of `x`.

        Returns:
            Tensor of shape (batch, channels + size_out, height, width).
        """
        embed_out = self.text_embedding(text)
        # Tile over whatever spatial size `x` has instead of a hard-coded 4x4
        # grid; for 4x4 feature maps the result is unchanged.
        height, width = x.shape[2], x.shape[3]
        embed_out_resize = embed_out[:, :, None, None].expand(-1, -1, height, width)
        return torch.cat([x, embed_out_resize], 1)  # Concatenate text embedding with the input feature map


# The Discriminator model
class Discriminator(nn.Module):
    """Scores how well an image matches a text embedding.

    Four stride-2 convolution stages reduce the image to a 512-channel
    feature map. The projected text embedding is appended as extra channels
    and a final 4x4 convolution followed by a sigmoid produces the score.

    Args:
        channels: Number of channels of the input images.
        embed_dim: Width of the incoming text embedding.
        embed_out_dim: Width of the projected text embedding.
    """

    def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Discriminator architecture
        self.model = nn.Sequential(
            *self._create_layer(self.channels, 64, 4, 2, 1, normalize=False),
            *self._create_layer(64, 128, 4, 2, 1),
            *self._create_layer(128, 256, 4, 2, 1),
            *self._create_layer(256, 512, 4, 2, 1)
        )
        self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim)  # Text embedding module
        self.output = nn.Sequential(
            nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
        """Return one stage as a list: Conv2d, optional BatchNorm2d, LeakyReLU."""
        layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
        if normalize:
            layers.append(nn.BatchNorm2d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, x, text):
        """Score `x` against `text`.

        Returns:
            A tuple `(scores, features)`: `scores` is the sigmoid output with
            its singleton channel/spatial axes removed, `features` is the
            feature map from the convolution stack before text conditioning.
        """
        x_out = self.model(x)  # Extract features from the input using the discriminator architecture
        out = self.text_embedding(x_out, text)  # Apply text embedding and concatenate with the input features
        out = self.output(out)  # Final discriminator output
        # Remove singleton axes one by one, from the last inwards, and never
        # touch axis 0. A bare squeeze() would also drop the batch axis for a
        # batch of one and return a 0-d tensor.
        out = out.squeeze(3).squeeze(2).squeeze(1)
        return out, x_out

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for `Module.apply`.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. All other modules are left alone.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
# Model hyper-parameters shared by the generator and the discriminator.
# NOTE(review): embed_dim must equal the width of the vectors returned by
# convert_text_to_feature - presumably 768 for distilbert-base-uncased; confirm.
embed_dim = 768
noise_dim = 100
embed_out_dim = 64

generator = Generator(channels=3, noise_dim=noise_dim, embed_dim=embed_dim, embed_out_dim=embed_out_dim)
generator = generator.to(device)
generator.apply(weights_init)

In [None]:
# Discriminator built with the same embedding sizes as the generator, then
# re-initialised with the shared weight scheme.
discriminator = Discriminator(channels=3, embed_dim=embed_dim, embed_out_dim=embed_out_dim)
discriminator = discriminator.to(device)
discriminator.apply(weights_init)

In [None]:
# One Adam optimizer per network, both with the same learning rate and betas.
learning_rate = 0.0002
optimizer_G = torch.optim.Adam(params=generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(params=discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Losses: BCE for the adversarial terms, MSE for feature matching, and L1 for
# the pixel-wise reconstruction term.
criterion = nn.BCELoss()
l2_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

In [None]:
import time

num_epochs = 200
real_label = 1.
fake_label = 0.
l1_coef = 50   # weight of the pixel-wise L1 term in the generator loss
l2_coef = 100  # weight of the feature-matching (MSE on discriminator activations) term

D_losses = []  # mean discriminator loss per epoch
G_losses = []  # mean generator loss per epoch

# Make train mode explicit so re-running this cell after an eval() call still
# trains with batch statistics.
generator.train()
discriminator.train()

for epoch in range(num_epochs):
    epoch_D_loss = []
    epoch_G_loss = []
    batch_time = time.time()  # start of the epoch; used for the per-epoch timing below

    for idx, batch in enumerate(train_loader):

        images = batch['right_images'].to(device)
        wrong_images = batch['wrong_images'].to(device)
        embeddings = batch['right_embed'].to(device)
        # The last batch may be smaller than the loader's batch size. Use a
        # local name so the notebook-level `batch_size` is not overwritten.
        n_samples = images.size(0)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Clear gradients for the discriminator
        optimizer_D.zero_grad()

        # Generate random noise
        noise = torch.randn(n_samples, noise_dim, 1, 1, device=device)

        # Generate fake image batch with the generator
        fake_images = generator(noise, embeddings)

        # Real image + matching text -> should be classified as real
        real_out, real_act = discriminator(images, embeddings)
        d_loss_real = criterion(real_out, torch.full_like(real_out, real_label, device=device))

        # Real image + mismatching text -> should be classified as fake
        wrong_out, wrong_act = discriminator(wrong_images, embeddings)
        d_loss_wrong = criterion(wrong_out, torch.full_like(wrong_out, fake_label, device=device))

        # Generated image + text -> should be classified as fake. detach() keeps
        # the discriminator update from back-propagating into the generator.
        fake_out, fake_act = discriminator(fake_images.detach(), embeddings)
        d_loss_fake = criterion(fake_out, torch.full_like(fake_out, fake_label, device=device))

        # Compute total discriminator loss
        d_loss = d_loss_real + d_loss_wrong + d_loss_fake

        # Backpropagate the gradients
        d_loss.backward()

        # Update the discriminator
        optimizer_D.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Clear gradients for the generator
        optimizer_G.zero_grad()

        # Generate new random noise
        noise = torch.randn(n_samples, noise_dim, 1, 1, device=device)

        # Generate new fake images using Generator
        fake_images = generator(noise, embeddings)

        # Get discriminator output for the new fake images
        out_fake, act_fake = discriminator(fake_images, embeddings)

        # Real activations are only a fixed target for feature matching, so
        # compute them without building an autograd graph.
        with torch.no_grad():
            _, act_real = discriminator(images, embeddings)

        # Calculate losses
        g_bce = criterion(out_fake, torch.full_like(out_fake, real_label, device=device))
        g_l1 = l1_coef * l1_loss(fake_images, images)
        g_l2 = l2_coef * l2_loss(torch.mean(act_fake, 0), torch.mean(act_real, 0))

        # Compute total generator loss
        g_loss = g_bce + g_l1 + g_l2

        # Backpropagate the gradients
        g_loss.backward()

        # Update the generator
        optimizer_G.step()

        epoch_D_loss.append(d_loss.item())
        epoch_G_loss.append(g_loss.item())

    # Fail with a clear message instead of a NameError / ZeroDivisionError
    # below when the loader is empty.
    if not epoch_D_loss:
        raise RuntimeError('train_loader yielded no batches; nothing to train on')

    # The printed losses are those of the last batch of the epoch; the epoch
    # means are what gets stored in D_losses / G_losses.
    print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
        epoch+1, idx+1, len(train_loader),
        d_loss.mean().item(),
        g_loss.mean().item(),
        time.time() - batch_time)
    )
    D_losses.append(sum(epoch_D_loss)/len(epoch_D_loss))
    G_losses.append(sum(epoch_G_loss)/len(epoch_G_loss))

In [None]:
import os
# Persist the trained weights (state dicts only) for both networks.
# NOTE(review): the target directory is hard-coded; presumably a Kaggle
# working directory - adjust when running elsewhere.
model_save_path = '/kaggle/working/'
generator_ckpt = os.path.join(model_save_path, 'generator_bert.pth')
discriminator_ckpt = os.path.join(model_save_path, 'discriminator_bert.pth')
torch.save(generator.state_dict(), generator_ckpt)
torch.save(discriminator.state_dict(), discriminator_ckpt)

In [None]:
# generator.eval()

In [None]:
# import matplotlib.pyplot as plt

In [None]:
# sentence = 'a purple passion flower'
# embeddings = convert_text_to_feature([str(sentence)])
# noise = torch.randn(1, noise_dim, 1, 1, device=device)
# pred = generator(noise, embeddings)
# plt.imshow(pred[0].cpu().detach().permute(1, 2, 0))
# plt.show()