In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image

# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class Generator(nn.Module):
    """Fully connected conditional generator.

    A noise vector and a text embedding are concatenated and passed through
    a three-layer MLP whose output is reshaped into a square 3-channel image.
    The final ``Tanh`` bounds every output value to [-1, 1].

    Args:
        noise_dim: Width of the noise input.
        text_dim: Width of the text-embedding input.
        img_size: Height and width of the generated image.
    """

    def __init__(self, noise_dim, text_dim, img_size):
        super(Generator, self).__init__()
        # Keep the output size on the instance so forward() reshapes with the
        # value this model was built for, not with a module-level global.
        self.img_size = img_size
        self.net = nn.Sequential(
            nn.Linear(noise_dim + text_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, img_size * img_size * 3),
            nn.Tanh()
        )

    def forward(self, noise, text):
        """Generate images from noise conditioned on text embeddings.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.
            text: Tensor of shape ``(batch, text_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)``.
        """
        x = torch.cat((noise, text), dim=1)
        return self.net(x).view(-1, 3, self.img_size, self.img_size)

class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores an (image, text embedding) pair: the image is flattened,
    concatenated with the embedding and passed through a three-layer MLP
    ending in a ``Sigmoid``, so each score lies in [0, 1].

    Args:
        img_size: Height and width of the square 3-channel input image.
        text_dim: Width of the text-embedding input.
    """

    def __init__(self, img_size, text_dim):
        super().__init__()
        flat_img_features = img_size * img_size * 3
        layers = [
            nn.Linear(flat_img_features + text_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, img, text):
        """Return one score per sample, shape ``(batch, 1)``.

        Args:
            img: Image batch; everything after the first dimension is flattened.
            text: Tensor of shape ``(batch, text_dim)``.
        """
        batch = img.size(0)
        flat_img = img.view(batch, -1)
        joint = torch.cat((flat_img, text), dim=1)
        return self.net(joint)

# Model dimensions.
noise_dim = 100   # width of the generator's noise input
text_dim = 128    # width of the text embedding fed to both networks
img_size = 64     # generated images are img_size x img_size, 3 channels

# Optimisation settings.
batch_size = 32
lr = 0.0002
epochs = 100

# Build both networks (generator first, then discriminator) and move them
# to the selected device.
G = Generator(noise_dim=noise_dim, text_dim=text_dim, img_size=img_size)
G = G.to(device)
D = Discriminator(img_size=img_size, text_dim=text_dim)
D = D.to(device)

# Binary cross-entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network sharing the same learning rate.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)

def generate_noise(batch_size, dim):
    """Sample a batch of standard-normal noise vectors on ``device``.

    Args:
        batch_size: Number of vectors (rows).
        dim: Width of each vector.

    Returns:
        Tensor of shape ``(batch_size, dim)`` located on the module-level
        ``device``.
    """
    # Sample directly on the target device instead of sampling on the CPU
    # and then copying the result over.
    return torch.randn(batch_size, dim, device=device)

def generate_text_embeddings(batch_size, dim):
    """Return a batch of random stand-in text embeddings on ``device``.

    The values are drawn from a standard normal distribution; no text
    encoder is involved.

    Args:
        batch_size: Number of embeddings (rows).
        dim: Width of each embedding.

    Returns:
        Tensor of shape ``(batch_size, dim)`` located on the module-level
        ``device``.
    """
    # Sample directly on the target device instead of sampling on the CPU
    # and then copying the result over.
    return torch.randn(batch_size, dim, device=device)

# The BCE targets never change, so build them once, directly on the device.
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

for epoch in range(epochs):
    # Stand-in "real" batch: random images paired with random embeddings.
    real_imgs = torch.randn(batch_size, 3, img_size, img_size, device=device)
    real_texts = generate_text_embeddings(batch_size, text_dim)

    # Condition the generator on the SAME embeddings the discriminator is
    # shown alongside the fakes; otherwise D would judge each fake image
    # against a text it was not generated from.
    fake_imgs = G(generate_noise(batch_size, noise_dim), real_texts)

    # --- Discriminator step ---
    # detach() keeps this loss from back-propagating into the generator.
    d_loss_real = criterion(D(real_imgs, real_texts), real_labels)
    d_loss_fake = criterion(D(fake_imgs.detach(), real_texts), fake_labels)
    d_loss = d_loss_real + d_loss_fake

    optimizer_D.zero_grad()
    d_loss.backward()
    optimizer_D.step()

    # --- Generator step ---
    # The generator is rewarded when D labels its output as real.
    g_loss = criterion(D(fake_imgs, real_texts), real_labels)
    optimizer_G.zero_grad()
    g_loss.backward()
    optimizer_G.step()

    # Every 10th epoch: report the losses and save a 5x5 grid of samples.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}] | d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")
        # detach() replaces the legacy .data accessor.
        save_image(fake_imgs.detach()[:25], f"generated_{epoch + 1}.png", nrow=5, normalize=True)

# Persist the parameters of both networks (state dicts only, not the modules)
# to files in the current working directory.
torch.save(obj=G.state_dict(), f="generator.pth")
torch.save(obj=D.state_dict(), f="discriminator.pth")

Epoch [10/100] | d_loss: 1.0295, g_loss: 0.6851
Epoch [20/100] | d_loss: 0.7939, g_loss: 0.7324
Epoch [30/100] | d_loss: 0.6648, g_loss: 0.8607
Epoch [40/100] | d_loss: 0.6500, g_loss: 0.8795
Epoch [50/100] | d_loss: 0.6706, g_loss: 0.8191
Epoch [60/100] | d_loss: 0.7020, g_loss: 0.8238
Epoch [70/100] | d_loss: 0.6308, g_loss: 0.8818
Epoch [80/100] | d_loss: 0.5489, g_loss: 0.9909
Epoch [90/100] | d_loss: 0.4374, g_loss: 1.2419
Epoch [100/100] | d_loss: 0.3093, g_loss: 1.9461
