In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import pickle

# Load preprocessed text data.
# The file may hold one or several consecutive pickle records ("chunks"); read all of
# them instead of silently keeping only the last one.
# SECURITY: pickle.load can execute arbitrary code -- only load files you created/trust.
captions_path = 'D:\\Emogen\\coco2017\\preprocessed_captions.pkl'
preprocessed_chunks = []
with open(captions_path, 'rb') as handle:
    while True:
        try:
            preprocessed_chunks.append(pickle.load(handle))
        except EOFError:
            break

if not preprocessed_chunks:
    raise ValueError(f"No pickled data found in {captions_path}")

# Kept for backward compatibility: the last record, as the original loop left it.
preprocessed_data = preprocessed_chunks[-1]
if len(preprocessed_chunks) == 1:
    text_embeddings = preprocessed_data['padded_captions']
else:
    # NOTE(review): assumes every chunk is a dict whose 'padded_captions' rows all have
    # the same width (same padding length) -- TODO confirm against the preprocessing step.
    text_embeddings = torch.cat(
        [torch.as_tensor(chunk['padded_captions']) for chunk in preprocessed_chunks]
    )


In [2]:
# Hyperparameters
num_epochs = 5
batch_size = 32
noise_dim = 100
text_embedding_dim = 300  # must equal the width of each caption row fed to the models
img_channels = 3
learning_rate = 0.0002
beta1 = 0.5
seed = 42

# Seed torch's global RNG before the models are built, so weight initialisation and
# noise sampling are reproducible from run to run.
torch.manual_seed(seed)

# Normalization parameters: ToTensor yields values in [0, 1]; Normalize with
# mean = std = 0.5 maps them to [-1, 1], the same range as the generator's Tanh output.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Custom Dataset class (normalization is applied through `transform`)
class CustomDataset(Dataset):
    """Image-folder dataset whose items are chunks of consecutive images.

    The entries of ``root_dir`` are sorted by name. Item ``idx`` holds the images
    ``image_list[idx * chunk_size:(idx + 1) * chunk_size]``; the final chunk may be
    shorter. Each image is opened with PIL, converted to RGB and, if given, passed
    through ``transform``.

    Args:
        root_dir: directory whose entries are loaded as images (every entry is
            assumed to be an image file -- nothing is filtered).
        transform: optional callable applied to each PIL image.
        chunk_size: number of images per item; must be >= 1.

    Raises:
        ValueError: if ``chunk_size`` < 1.
    """

    def __init__(self, root_dir, transform=None, chunk_size=1000):
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so chunk contents are deterministic regardless of listdir order.
        self.image_list = sorted(os.listdir(root_dir))
        self.chunk_size = chunk_size

    def __len__(self):
        # Number of chunks (ceiling division), not number of images, so that every
        # index in range(len(self)) maps to a non-empty chunk.
        return (len(self.image_list) + self.chunk_size - 1) // self.chunk_size

    def __getitem__(self, idx):
        """Return the list of (transformed) images in chunk ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` does not address an existing chunk.
        """
        num_chunks = len(self)
        if idx < 0:
            idx += num_chunks
        if not 0 <= idx < num_chunks:
            raise IndexError(f"chunk index out of range for {num_chunks} chunks")

        chunk_start = idx * self.chunk_size
        chunk_end = min(chunk_start + self.chunk_size, len(self.image_list))

        images = []
        for name in self.image_list[chunk_start:chunk_end]:
            img_path = os.path.join(self.root_dir, name)
            # The context manager closes the file once the RGB copy has been made.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            images.append(image)

        return images


In [3]:
# Generator Class
class Generator(nn.Module):
    """Text-conditioned transposed-convolution generator.

    The text vector goes through a Linear layer and is concatenated with the noise
    vector. The result is reshaped to a 1 x 1 spatial map and upsampled
    1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels by five ConvTranspose2d layers. The final
    Tanh keeps the output in [-1, 1].

    Args:
        noise_dim: length of each noise vector.
        text_embedding_dim: length of each text vector.
        img_channels: number of channels of the generated image.
    """

    def __init__(self, noise_dim, text_embedding_dim, img_channels):
        super().__init__()
        self.img_channels = img_channels
        self.noise_dim = noise_dim
        self.text_embedding_dim = text_embedding_dim
        # NOTE(review): stored but never read by this class.
        self.init_size = 256 // 8

        self.label_emb = nn.Linear(text_embedding_dim, text_embedding_dim)

        # Channel widths of the four upsampling stages; the first consumes noise + text.
        widths = [noise_dim + text_embedding_dim, 128, 256, 512, 1024]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.extend(self._block(c_in, c_out, initial=(stage == 0)))
        layers.append(nn.ConvTranspose2d(widths[-1], img_channels, 4, stride=2, padding=1))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, initial=False):
        """Return [ConvTranspose2d, BatchNorm2d, ReLU] for one upsampling stage.

        The initial stage (stride 1, padding 0) turns 1 x 1 into 4 x 4; later stages
        (stride 2, padding 1) double the spatial size.
        """
        stride, padding = (1, 0) if initial else (2, 1)
        return [
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, noise, texts):
        text_features = self.label_emb(texts)
        joint = torch.cat((noise, text_features), -1)
        joint = joint.view(joint.size(0), -1, 1, 1)
        return self.model(joint)


In [4]:
# Discriminator Class
class Discriminator(nn.Module):
    """Text-conditioned convolutional discriminator.

    Four stride-2 convolutions shrink the image by a factor of 16 into 512 feature
    maps. The text vector is projected by a Linear layer to a 512 x 4 x 4 map and
    concatenated with the image features along the channel axis. The image features
    must therefore be 4 x 4, i.e. inputs are expected to be 64 x 64 images (the size
    the Generator in this notebook produces).

    Args:
        img_channels: number of channels of the input images.
        text_embedding_dim: length of each text vector passed to ``forward``.

    ``forward(img, texts)`` returns one raw logit per sample, shape ``(batch,)``, for
    use with ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, img_channels, text_embedding_dim):
        super(Discriminator, self).__init__()
        self.img_channels = img_channels
        self.text_embedding_dim = text_embedding_dim

        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Final layer: forward() concatenates 512 image-feature channels with 512
        # projected-text channels, so it must accept 1024 input channels.
        self.final_layer = nn.Conv2d(512 * 2, 1, 4, stride=1, padding=0)

        # Text projection to a 512 x 4 x 4 map, matching the image feature map of a 64 x 64 input
        self.text_projection = nn.Linear(text_embedding_dim, 512 * 4 * 4)

    def forward(self, img, texts):
        # Process image through model
        feature_rep = self.model(img)

        # Process text embeddings
        text_projection = self.text_projection(texts).view(-1, 512, 4, 4)

        # Integrate text embeddings with image features before the final decision
        combined = torch.cat((feature_rep, text_projection), dim=1)

        validity = self.final_layer(combined)
        # view(-1) instead of squeeze(): a batch of one stays shape (1,), matching targets.
        return validity.view(-1)


In [5]:
# Build both networks from the hyperparameters; keyword arguments make the two
# constructors' different parameter orders explicit.
generator = Generator(
    noise_dim=noise_dim,
    text_embedding_dim=text_embedding_dim,
    img_channels=img_channels,
)
discriminator = Discriminator(
    img_channels=img_channels,
    text_embedding_dim=text_embedding_dim,
)


In [6]:
# Loss and optimizers.
# BCEWithLogitsLoss applies the sigmoid itself; the discriminator's last layer is a
# plain Conv2d, so it emits raw logits.
criterion = nn.BCEWithLogitsLoss()

# Both networks use Adam with the same learning rate and (beta1, 0.999) momentum terms.
optimizer_generator = optim.Adam(
    generator.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)
optimizer_discriminator = optim.Adam(
    discriminator.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)


In [7]:
# DataLoader
# chunk_size=1: each dataset item is a one-image list, so the DataLoader's own batching
# yields `batch_size` images per step (collated into a list holding one (N, C, H, W)
# tensor). With chunk_size=batch_size the loader would additionally bundle batch_size
# chunks per step, which the training loop does not expect.
# shuffle=False: the training loop pairs captions with images by position, which a
# shuffled order would scramble.
# num_workers=0: on Windows, worker processes are started with "spawn" and cannot
# re-import classes defined inside a notebook, so multi-process loading fails there.
dataset = CustomDataset(root_dir='D:\\Emogen\\coco2017\\Train_Resizenew', transform=normalize, chunk_size=1)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [8]:
if __name__ == '__main__':
    # Training loop for the text-conditioned GAN (everything runs on the CPU).
    #
    # NOTE(review): captions are paired with images purely by position -- the k-th image
    # the loader yields in an epoch is paired with text_embeddings[k]. That is only
    # meaningful if the loader does not shuffle and the caption order matches the sorted
    # image order -- TODO confirm against the preprocessing step.
    for epoch in range(num_epochs):
        caption_offset = 0  # index into text_embeddings of this batch's first caption
        for batch_idx, image_chunk in enumerate(dataloader):
            # The collated item is a list of image tensors. Stacking gives (N, C, H, W)
            # for a list of single images, or (chunks, N, C, H, W) for a list of batches;
            # merge the two leading dims in the latter case.
            real_images = torch.stack(list(image_chunk))
            if real_images.dim() == 5:
                real_images = real_images.flatten(0, 1)
            n_real = real_images.size(0)

            # Captions for exactly the images in this batch. as_tensor accepts lists,
            # NumPy arrays or tensors; the models' Linear layers need float input.
            text_embeddings_batch = torch.as_tensor(
                text_embeddings[caption_offset:caption_offset + n_real], dtype=torch.float32
            )
            if text_embeddings_batch.size(0) != n_real:
                raise ValueError(
                    f"Only {text_embeddings_batch.size(0)} captions left for a batch of "
                    f"{n_real} images (batch {batch_idx}, epoch {epoch})"
                )
            caption_offset += n_real

            # Adversarial ground truths
            valid = torch.ones(n_real)
            fake = torch.zeros(n_real)

            # ---------------------
            # Train Discriminator
            # ---------------------

            optimizer_discriminator.zero_grad()

            # Measure discriminator's ability to classify real images.
            # view(-1) keeps one logit per sample even for a batch of a single image.
            validity_real = discriminator(real_images, text_embeddings_batch).view(-1)
            loss_real = criterion(validity_real, valid)

            # Sample noise and generate a batch of fake images
            noise = torch.randn(n_real, noise_dim)
            gen_images = generator(noise, text_embeddings_batch)

            # detach(): the discriminator update must not backpropagate into the generator
            validity_fake = discriminator(gen_images.detach(), text_embeddings_batch).view(-1)
            loss_fake = criterion(validity_fake, fake)

            # Total discriminator loss
            loss_discriminator = 0.5 * (loss_real + loss_fake)

            loss_discriminator.backward()
            optimizer_discriminator.step()

            # -----------------
            # Train Generator
            # -----------------

            optimizer_generator.zero_grad()

            # Reuse gen_images: its graph back to the generator is still intact, and the
            # generator's weights have not changed, so a second forward pass with the same
            # noise would only recompute the same images.
            validity_gen = discriminator(gen_images, text_embeddings_batch).view(-1)
            loss_generator = criterion(validity_gen, valid)

            loss_generator.backward()
            optimizer_generator.step()

            # Print progress
            if batch_idx % 5 == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(dataloader)}] "
                    f"Loss D: {loss_discriminator.item():.4f} Loss G: {loss_generator.item():.4f}"
                )


In [None]:
# Save generated samples from the trained generator.
# NOTE(review): the original intent was "at the end of each epoch", which requires this
# code to run inside the epoch loop of the training cell. As a standalone cell it runs
# once, after training, and tags the files with the number of epochs trained.
num_eval_samples = 16  # images to generate (capped by the number of available captions)
eval_epoch = num_epochs

generator.eval()  # BatchNorm uses its running statistics while sampling
try:
    with torch.no_grad():
        n_samples = min(num_eval_samples, len(text_embeddings))
        eval_noise = torch.randn(n_samples, noise_dim)
        eval_text_embeddings = torch.as_tensor(text_embeddings[:n_samples], dtype=torch.float32)
        eval_samples = generator(eval_noise, eval_text_embeddings).cpu()
        # The generator ends in Tanh ([-1, 1]); ToPILImage expects floats in [0, 1],
        # otherwise negative values wrap around when converted to bytes.
        eval_samples = ((eval_samples + 1) / 2).clamp(0, 1)

    save_dir = 'D:\\Emogen\\coco2017\\generated_samples'
    os.makedirs(save_dir, exist_ok=True)
    to_pil = transforms.ToPILImage()
    for i in range(n_samples):
        img = to_pil(eval_samples[i])
        img.save(os.path.join(save_dir, f'generated_sample_{i + 1}_epoch_{eval_epoch}.png'))
finally:
    generator.train()  # restore training mode even if sampling or saving fails

In [None]:
# Save the trained models' parameters (state_dict only; rebuild the classes to reload).
# Dedented: as a standalone cell, the indented original raised IndentationError.
torch.save(generator.state_dict(), 'D:\\Emogen\\coco2017\\generator.pth')
torch.save(discriminator.state_dict(), 'D:\\Emogen\\coco2017\\discriminator.pth')
