In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import numpy as np

# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the Generator network
class Generator(nn.Module):
    """Fully connected generator conditioned on a text embedding.

    The noise vector and the text embedding are concatenated and passed
    through three linear layers; the result is reshaped into a square
    3-channel image whose values lie in [-1, 1] (tanh output).

    Args:
        noise_dim: Number of features in each noise vector.
        text_embed_dim: Number of features in each text embedding.
        img_size: Height and width of the generated square image.
    """

    def __init__(self, noise_dim, text_embed_dim, img_size):
        super(Generator, self).__init__()
        # Keep the image size on the instance so forward() reshapes with the
        # value this model was built for, not a module-level global.
        self.img_size = img_size
        self.fc1 = nn.Linear(noise_dim + text_embed_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, img_size * img_size * 3)  # 3 for RGB channels
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, noise, text_embedding):
        """Generate a batch of images.

        Args:
            noise: Tensor concatenated with ``text_embedding`` along dim 1;
                the combined width must equal ``noise_dim + text_embed_dim``.
            text_embedding: Conditioning tensor with the same batch size.

        Returns:
            Tensor viewed as ``(-1, 3, img_size, img_size)`` with values
            in [-1, 1].
        """
        x = torch.cat((noise, text_embedding), dim=1)  # Concatenate noise and text embedding
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))  # Output between -1 and 1 for images
        # Reshape the flat fc3 output to image dimensions
        x = x.view(-1, 3, self.img_size, self.img_size)
        return x

# Define the Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator scoring (image, text embedding) pairs.

    The image is flattened, joined with the text embedding, and pushed
    through three linear layers; a sigmoid maps the final value into (0, 1).

    Args:
        img_size: Height and width of the square 3-channel input image.
        text_embed_dim: Number of features in each text embedding.
    """

    def __init__(self, img_size, text_embed_dim):
        super().__init__()
        in_features = img_size * img_size * 3 + text_embed_dim
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, img, text_embedding):
        """Return one score in (0, 1) per sample, shaped ``(batch, 1)``."""
        # Collapse everything after the batch dimension into one feature axis
        flat_img = img.view(img.shape[0], -1)
        # Append the conditioning features to the flattened pixels
        joint = torch.cat([flat_img, text_embedding], dim=1)
        hidden = self.fc1(joint)
        hidden = self.relu(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu(hidden)
        score = self.fc3(hidden)
        return self.sigmoid(score)

# Hyperparameters
noise_dim, text_embed_dim = 100, 128  # noise vector length; assumed text-embedding length
img_size = 64  # side length of the square generated images (64x64)
batch_size, num_epochs = 32, 100
lr = 2e-4  # learning rate shared by both optimizers

# Build both networks (generator first, then discriminator) and move them to the device
G = Generator(noise_dim=noise_dim, text_embed_dim=text_embed_dim, img_size=img_size)
G = G.to(device)
D = Discriminator(img_size=img_size, text_embed_dim=text_embed_dim)
D = D.to(device)

# Binary cross-entropy on the discriminator's sigmoid output; one Adam optimizer per network
criterion = nn.BCELoss()
optimizer_G = optim.Adam(params=G.parameters(), lr=lr)
optimizer_D = optim.Adam(params=D.parameters(), lr=lr)

# Function to generate random noise and fake text embeddings
def generate_noise(batch_size, noise_dim):
    """Return standard-normal noise of shape ``(batch_size, noise_dim)``.

    The tensor is allocated on the module-level ``device``.
    """
    # Sample directly on the target device rather than sampling on the CPU
    # and copying the result over with .to(device).
    return torch.randn(batch_size, noise_dim, device=device)

def generate_text_embeddings(batch_size, text_embed_dim):
    """Return placeholder text embeddings of shape ``(batch_size, text_embed_dim)``.

    The values are standard-normal random numbers, not real text encodings.
    The tensor is allocated on the module-level ``device``.
    """
    # Fake text embeddings, sampled directly on the target device to avoid
    # a CPU allocation followed by a copy.
    return torch.randn(batch_size, text_embed_dim, device=device)

# Training the GAN
# NOTE: no dataset is loaded here; the "real" images and every text embedding
# are random placeholders, so this loop only exercises the training mechanics.

# The inner loop originally ran `batch_size` iterations per epoch. The count is
# kept, but under its own name so it is not confused with the batch size.
steps_per_epoch = batch_size

# The BCE targets never change, so build them once, directly on the device.
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        # Train Discriminator
        # Fake real images (normally you'd load these), sampled on the device
        real_imgs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        real_text_embeddings = generate_text_embeddings(batch_size, text_embed_dim)

        # Generate fake images
        noise = generate_noise(batch_size, noise_dim)
        fake_text_embeddings = generate_text_embeddings(batch_size, text_embed_dim)
        fake_imgs = G(noise, fake_text_embeddings)

        # Discriminator loss on real images
        outputs_real = D(real_imgs, real_text_embeddings)
        d_loss_real = criterion(outputs_real, real_labels)

        # Discriminator loss on fake images; detach() keeps this backward
        # pass from reaching the generator's parameters.
        outputs_fake = D(fake_imgs.detach(), fake_text_embeddings)
        d_loss_fake = criterion(outputs_fake, fake_labels)

        # Total Discriminator loss
        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: score the same fakes again (not detached this time)
        # and push the discriminator's output toward the "real" label.
        # This backward pass also leaves gradients on D's parameters; they are
        # cleared by optimizer_D.zero_grad() before D's next update.
        outputs_fake = D(fake_imgs, fake_text_embeddings)
        g_loss = criterion(outputs_fake, real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Print progress (losses from the last step of the epoch)
    print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

    # Save generated images every few epochs
    if (epoch + 1) % 10 == 0:
        # detach() is the supported way to drop the autograd graph; .data is discouraged
        save_image(fake_imgs.detach()[:25], f'generated_images_{epoch + 1}.png', nrow=5, normalize=True)

# Save model
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')


Epoch [1/100], d_loss: 0.6171801090240479, g_loss: 0.9658297300338745
Epoch [2/100], d_loss: 0.684040904045105, g_loss: 0.8820551633834839
Epoch [3/100], d_loss: 0.3821849226951599, g_loss: 1.2799551486968994
Epoch [4/100], d_loss: 1.1392452716827393, g_loss: 1.321393609046936
Epoch [5/100], d_loss: 2.8805367946624756, g_loss: 8.115911483764648
Epoch [6/100], d_loss: 1.7252289056777954, g_loss: 2.943157911300659
Epoch [7/100], d_loss: 0.7031729817390442, g_loss: 1.7058354616165161
Epoch [8/100], d_loss: 1.0193713903427124, g_loss: 2.8527371883392334
Epoch [9/100], d_loss: 1.259008526802063, g_loss: 2.8984405994415283
Epoch [10/100], d_loss: 0.38769012689590454, g_loss: 2.5549330711364746


KeyboardInterrupt: 