In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image

# Select the compute device: CUDA when PyTorch reports it available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define the Generator network
class Generator(nn.Module):
    """Fully connected conditional generator.

    Concatenates a noise vector with a text embedding and maps the result
    through three linear layers to an RGB image of shape
    ``(3, img_size, img_size)`` with values in ``[-1, 1]`` (tanh output).

    Args:
        noise_dim: Number of features in each noise vector.
        text_embed_dim: Number of features in each text embedding.
        img_size: Height and width of the generated square image.
    """

    def __init__(self, noise_dim, text_embed_dim, img_size):
        super(Generator, self).__init__()
        # Keep the image size on the instance so ``forward`` reshapes with the
        # value this model was built for rather than a module-level global.
        self.img_size = img_size
        self.fc1 = nn.Linear(noise_dim + text_embed_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, img_size * img_size * 3)  # 3 for RGB channels
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, noise, text_embedding):
        """Generate a batch of images.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.
            text_embedding: Tensor of shape ``(batch, text_embed_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)`` in ``[-1, 1]``.
        """
        x = torch.cat((noise, text_embedding), dim=1)  # Concatenate noise and text embedding
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))  # Output between -1 and 1 for images
        # Reshape the flat fc3 output to image dimensions
        x = x.view(-1, 3, self.img_size, self.img_size)
        return x


# Define the Discriminator network
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores an (image, text embedding) pair with a single sigmoid output.

    Args:
        img_size: Height and width of the square RGB input image.
        text_embed_dim: Number of features in each text embedding.
    """

    def __init__(self, img_size, text_embed_dim):
        super().__init__()
        # Flattened RGB image features plus the conditioning embedding
        in_features = img_size * img_size * 3 + text_embed_dim
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, img, text_embedding):
        """Return a ``(batch, 1)`` tensor of scores in ``[0, 1]``.

        Args:
            img: Image batch; everything after the first dimension is flattened.
            text_embedding: Tensor of shape ``(batch, text_embed_dim)``.
        """
        batch = img.size(0)
        flat_img = img.view(batch, -1)
        # Condition on the text by appending it to the pixel features
        joint = torch.cat((flat_img, text_embedding), dim=1)
        hidden = self.relu(self.fc1(joint))
        hidden = self.relu(self.fc2(hidden))
        score = self.sigmoid(self.fc3(hidden))
        return score


# Hyperparameters
noise_dim = 100       # Length of the generator's noise vector
text_embed_dim = 128  # Text embedding dimension
img_size = 64         # Image size (64x64)
batch_size = 32
lr = 0.0002
num_epochs = 100

# Build the generator first, then the discriminator, and move both to the device
G = Generator(noise_dim=noise_dim, text_embed_dim=text_embed_dim, img_size=img_size)
G = G.to(device)
D = Discriminator(img_size=img_size, text_embed_dim=text_embed_dim)
D = D.to(device)

# Binary cross-entropy loss and one Adam optimizer per network
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)

# Function to generate random noise and text embeddings
def generate_noise(batch_size, noise_dim):
    """Draw standard-normal noise of shape ``(batch_size, noise_dim)``.

    The tensor is moved to the module-level ``device`` before being returned.
    """
    sample = torch.randn((batch_size, noise_dim))
    return sample.to(device)

def generate_text_embeddings(batch_size, text_embed_dim):
    """Return fake text embeddings of shape ``(batch_size, text_embed_dim)``.

    The values are standard-normal random numbers, not encoded text. The
    tensor is moved to the module-level ``device`` before being returned.
    """
    embeddings = torch.randn((batch_size, text_embed_dim))
    return embeddings.to(device)


# Training the GAN
# NOTE: no dataset is loaded here. The "real" images and all text embeddings
# are random tensors, so this loop exercises the training mechanics only.

# Optimisation steps per epoch. The value equals batch_size, but it is named
# separately because it counts steps, not samples.
steps_per_epoch = batch_size

# Target labels never change between steps, so build them once on the device
# instead of allocating and copying them on every iteration.
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        # Simulated real images with their text embeddings
        real_imgs = torch.randn(batch_size, 3, img_size, img_size).to(device)
        real_text_embeddings = generate_text_embeddings(batch_size, text_embed_dim)

        # Generate fake images
        noise = generate_noise(batch_size, noise_dim)
        fake_text_embeddings = generate_text_embeddings(batch_size, text_embed_dim)
        fake_imgs = G(noise, fake_text_embeddings)

        # Train Discriminator: real pairs should score 1, generated pairs 0
        outputs_real = D(real_imgs, real_text_embeddings)
        d_loss_real = criterion(outputs_real, real_labels)

        # detach() keeps the discriminator loss from back-propagating into G
        outputs_fake = D(fake_imgs.detach(), fake_text_embeddings)
        d_loss_fake = criterion(outputs_fake, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: score the same fakes with the updated discriminator
        outputs_fake = D(fake_imgs, fake_text_embeddings)
        g_loss = criterion(outputs_fake, real_labels)  # Generator tries to fool the discriminator
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Print progress (losses are those of the last step of the epoch)
    print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    # Save generated images every 10 epochs
    if (epoch + 1) % 10 == 0:
        # detach() drops the autograd graph; .data is the discouraged legacy way to do this
        save_image(fake_imgs.detach()[:25], f'generated_images_{epoch + 1}.png', nrow=5, normalize=True)

# Save the trained model
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')


Epoch [1/100], d_loss: 0.6463, g_loss: 0.8963
Epoch [2/100], d_loss: 0.6301, g_loss: 0.8505
Epoch [3/100], d_loss: 0.3155, g_loss: 1.5057
Epoch [4/100], d_loss: 1.3964, g_loss: 1.0284
Epoch [5/100], d_loss: 2.3029, g_loss: 11.3853
Epoch [6/100], d_loss: 2.3182, g_loss: 2.4406
Epoch [7/100], d_loss: 1.5308, g_loss: 0.8348
Epoch [8/100], d_loss: 1.0777, g_loss: 1.5070
Epoch [9/100], d_loss: 0.8694, g_loss: 1.7199
Epoch [10/100], d_loss: 1.8924, g_loss: 0.5584
Epoch [11/100], d_loss: 1.9785, g_loss: 0.9841
Epoch [12/100], d_loss: 0.2918, g_loss: 2.3639
Epoch [13/100], d_loss: 1.4044, g_loss: 0.5039
Epoch [14/100], d_loss: 0.6987, g_loss: 1.6821
Epoch [15/100], d_loss: 1.4220, g_loss: 0.5217
Epoch [16/100], d_loss: 1.5823, g_loss: 1.0236
Epoch [17/100], d_loss: 1.3948, g_loss: 6.3302
Epoch [18/100], d_loss: 1.9811, g_loss: 3.6895
Epoch [19/100], d_loss: 0.8004, g_loss: 2.9090
Epoch [20/100], d_loss: 0.6083, g_loss: 3.2473
Epoch [21/100], d_loss: 0.8178, g_loss: 4.0133
Epoch [22/100], d_los