In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, Dataset,Subset
from torchvision.utils import save_image
import os


In [3]:
# Seed NumPy and PyTorch up front so that, on a fresh kernel, the random subset
# drawn below, the model weight initialisation, the batch shuffling and the
# sampled noise/labels are the same on every "Restart & Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ToTensor yields values in [0, 1]; Normalize((0.5,), (0.5,)) computes
# (x - 0.5) / 0.5, i.e. rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

batch_size = 64
train_dataset = FashionMNIST(root='data', train=True, download=True, transform=transform)
subset_size = 8000
# Randomly select a subset of indices (replace=False: no image is picked twice)
subset_indices = np.random.choice(len(train_dataset), subset_size, replace=False)

# Create the subset dataset
train_subset = Subset(train_dataset, subset_indices)
train_loader = DataLoader(dataset=train_subset, batch_size=batch_size, shuffle=True)
len(train_subset)

8000

In [4]:
# defining generator
class Generator(nn.Module):
    """Class-conditional convolutional generator.

    The label embedding is concatenated with the noise vector, projected by a
    linear layer to a ``128 x init_size x init_size`` feature map, and then
    upsampled twice (x2 each time) by the convolutional stack. The final
    ``Tanh`` keeps every output value in ``[-1, 1]``.

    Args:
        latent_dim: Length of the noise vector passed to ``forward``.
        num_classes: Number of distinct labels; also the label-embedding width.
        img_shape: ``(channels, height, width)`` of the images to generate.
            Only ``img_shape[0]`` and ``img_shape[1]`` are read: the output is
            square with side ``4 * (img_shape[1] // 4)``, so it matches
            ``img_shape`` only for square images whose side is a multiple of 4.
    """

    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()

        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Two x2 upsampling stages follow, so start at a quarter of the target side.
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim + num_classes, 128 * self.init_size ** 2))

        # `momentum` is passed by keyword on purpose: the second positional
        # parameter of BatchNorm2d is `eps`, so `BatchNorm2d(128, 0.8)` would add
        # 0.8 to the variance and largely cancel the normalisation.
        # NOTE(review): in PyTorch `momentum` is the weight given to the newest
        # batch statistic, so 0.8 makes the running stats follow recent batches
        # closely; this only matters in eval mode - confirm 0.8 is the value wanted.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, img_shape[0], 4 * init_size, 4 * init_size)``
            with values in ``[-1, 1]``.
        """
        # Condition the generator by prepending the label embedding to the noise.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        out = self.l1(gen_input)
        # Reshape the flat projection into a 128-channel feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


In [5]:
# defining discriminator
class Discriminator(nn.Module):
    """Class-conditional convolutional discriminator.

    The label is embedded into one extra image-sized channel, stacked onto the
    input image, passed through four stride-2 convolutions and finally mapped
    by a linear + sigmoid head to a single "real" probability per sample.

    Args:
        num_classes: Number of distinct labels.
        img_shape: ``(channels, height, width)`` of the input images.
    """

    @staticmethod
    def _conv_out_size(size, num_layers=4):
        """Return a spatial size after ``num_layers`` 3x3 / stride-2 / padding-1 convs."""
        for _ in range(num_layers):
            # Conv2d output size: floor((size + 2*1 - 3) / 2) + 1
            size = (size - 1) // 2 + 1
        return size

    def __init__(self, num_classes, img_shape):
        super().__init__()

        # One embedding value per pixel, so the label can be reshaped into a channel.
        self.label_embedding = nn.Embedding(num_classes, img_shape[2] * img_shape[1])

        # `momentum` is passed by keyword on purpose: the second positional
        # parameter of BatchNorm2d is `eps`, so `BatchNorm2d(32, 0.8)` would add
        # 0.8 to the variance and largely cancel the normalisation.
        # NOTE(review): in PyTorch `momentum` is the weight given to the newest
        # batch statistic; it only matters in eval mode - confirm 0.8 is wanted.
        # The shape comments below are for a 28x28 input.
        self.model = nn.Sequential(
            nn.Conv2d(img_shape[0] + 1, 16, 3, 2, 1),  # Output: [batch_size, 16, 14, 14]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(16, 32, 3, 2, 1),  # Output: [batch_size, 32, 7, 7]
            nn.BatchNorm2d(32, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(32, 64, 3, 2, 1),  # Output: [batch_size, 64, 4, 4]
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(64, 128, 3, 2, 1),  # Output: [batch_size, 128, 2, 2]
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
        )

        # Size the head from what the conv stack really produces: 128 channels at
        # the downsampled height x width. (`512 * (H // 16) ** 2` equals this for
        # 28x28 only by coincidence and is wrong for e.g. 32x32.)
        ds_h = self._conv_out_size(img_shape[1])
        ds_w = self._conv_out_size(img_shape[2])
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_h * ds_w, 1), nn.Sigmoid())

    def forward(self, img, labels):
        """Score a batch of images given their labels.

        Args:
            img: Float tensor of shape ``(batch, channels, height, width)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        # Embed labels and reshape them into a single image-sized channel
        labels = self.label_embedding(labels)
        labels = labels.view(labels.size(0), 1, img.size(2), img.size(3))

        # Concatenate image and label channel along the channel axis
        d_in = torch.cat((img, labels), 1)

        out = self.model(d_in)
        # Flatten everything except the batch dimension for the linear head
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity




In [6]:
# Define training procedure
# Hyper-parameters read by the models and by the training loop.
latent_dim, num_classes = 100, 10
img_shape = (1, 28, 28)  # (channels, height, width)

# Initialize models (generator first, then discriminator)
generator = Generator(latent_dim, num_classes, img_shape)
discriminator = Discriminator(num_classes, img_shape)

# Loss function: binary cross-entropy on the discriminator's sigmoid output
adversarial_loss = nn.BCELoss()

# Optimizers: one Adam instance per network, both with the same settings
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Legacy CPU float tensor type used by the training loop
Tensor = torch.FloatTensor


In [7]:
def train(generator, discriminator, optimizer_G, optimizer_D, dataloader, epochs):
    """Alternately train the conditional generator and discriminator.

    For every batch the generator is updated first (it is rewarded when the
    discriminator labels its samples as real), then the discriminator is
    updated on the real batch and on the detached generated batch. Losses are
    printed after every batch. Every 10th epoch (0, 10, 20, ...) up to 25
    samples from the epoch's last batch are written to ``images/{epoch}.png``.

    Reads the notebook-level names ``latent_dim``, ``num_classes`` and
    ``adversarial_loss``.

    Args:
        generator: Module called as ``generator(noise, labels)``.
        discriminator: Module called as ``discriminator(images, labels)``.
        optimizer_G: Optimizer over the generator's parameters.
        optimizer_D: Optimizer over the discriminator's parameters.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of passes over ``dataloader``.
    """
    # Create every tensor on the device the models already live on.
    device = next(generator.parameters()).device

    for epoch in range(epochs):
        gen_imgs = None  # stays None when the dataloader yields no batches
        for i, (imgs, labels) in enumerate(dataloader):

            batch_size = imgs.size(0)

            # Targets for the BCE loss: 1 = "real", 0 = "fake"
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            real_imgs = imgs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.long)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise and random target classes for the generated images
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)

            gen_imgs = generator(z, gen_labels)

            # The generator wants its samples to be classified as real
            g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Also clears the discriminator gradients accumulated by g_loss.backward()
            optimizer_D.zero_grad()

            real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
            # detach(): the discriminator update must not backpropagate into the generator
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

        # Save one sample grid per qualifying epoch, after the batch loop, instead
        # of rewriting the same file on every batch; the file holds the last
        # batch's samples either way.
        if epoch % 10 == 0 and gen_imgs is not None:
            os.makedirs("images", exist_ok=True)
            save_image(gen_imgs.detach()[:25], f"images/{epoch}.png", nrow=5, normalize=True)

# Run the training loop for 100 epochs over the subset loader.
train(
    generator=generator,
    discriminator=discriminator,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    dataloader=train_loader,
    epochs=100,
)


[Epoch 0/100] [Batch 0/125] [D loss: 0.6924914121627808] [G loss: 0.6869462728500366]
[Epoch 0/100] [Batch 1/125] [D loss: 0.6918516159057617] [G loss: 0.6864961981773376]
[Epoch 0/100] [Batch 2/125] [D loss: 0.6912848353385925] [G loss: 0.6854057312011719]
[Epoch 0/100] [Batch 3/125] [D loss: 0.6910380125045776] [G loss: 0.6860911846160889]
[Epoch 0/100] [Batch 4/125] [D loss: 0.6899421215057373] [G loss: 0.6855213642120361]
[Epoch 0/100] [Batch 5/125] [D loss: 0.6901485919952393] [G loss: 0.6855944991111755]
[Epoch 0/100] [Batch 6/125] [D loss: 0.6891896724700928] [G loss: 0.6833280324935913]
[Epoch 0/100] [Batch 7/125] [D loss: 0.6892015933990479] [G loss: 0.6830828189849854]
[Epoch 0/100] [Batch 8/125] [D loss: 0.6873176097869873] [G loss: 0.6826400756835938]
[Epoch 0/100] [Batch 9/125] [D loss: 0.6861150860786438] [G loss: 0.6808255314826965]
[Epoch 0/100] [Batch 10/125] [D loss: 0.6858386993408203] [G loss: 0.6810575127601624]
[Epoch 0/100] [Batch 11/125] [D loss: 0.6844223141670

In [8]:
# Size of the full FashionMNIST training split, for comparison with the subset used for training.
len(train_dataset)

60000