In [42]:
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# RMSprop (the optimizer used for both WGAN networks) lives in torch.optim
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [43]:
def get_device():
    """Return the best torch device that is actually usable on this machine.

    Preference order is Apple-Silicon MPS, then CUDA, then CPU. The original
    always returned ``torch.device("mps")``, which constructs fine everywhere
    but fails on the first tensor allocation when no MPS backend is present.

    Returns:
        torch.device: "mps", "cuda" or "cpu".
    """
    # `torch.backends.mps` is missing on old PyTorch builds, hence the getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print(device)

mps


In [44]:
class Generator(nn.Module):
    """Fully-connected generator.

    Maps a latent vector of size ``latent_dim`` to a flat 784-value output.
    Three hidden stages (Linear -> BatchNorm1d -> LeakyReLU(0.2)) of width
    256, 512 and 1024 are followed by a Linear(1024, 784) + Tanh head, so the
    output lies in [-1, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        hidden_widths = (256, 512, 1024)

        # Layers are created in the same order as they are applied so that
        # the Sequential indices (and state_dict keys) stay gen.0 ... gen.10.
        layers = []
        in_features = latent_dim
        for out_features in hidden_widths:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.BatchNorm1d(out_features))
            layers.append(nn.LeakyReLU(0.2))
            in_features = out_features
        layers.append(nn.Linear(in_features, 784))
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def forward(self, z):
        """Run the latent batch ``z`` through the generator stack."""
        return self.gen(z)
    
class Critic(nn.Module):
    """Fully-connected critic producing one unbounded score per sample.

    Architecture: Linear(img_dim, 512) -> LeakyReLU(0.2) -> Linear(512, 256)
    -> LeakyReLU(0.2) -> Linear(256, 1). There is no final activation.
    """

    def __init__(self, img_dim):
        super().__init__()
        widths = [img_dim, 512, 256]

        # Build the hidden stages pairwise, then append the scalar head; the
        # resulting Sequential indices match critic.0 ... critic.4.
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.LeakyReLU(0.2))
        blocks.append(nn.Linear(widths[-1], 1))

        self.critic = nn.Sequential(*blocks)

    def forward(self, x):
        """Score the flattened batch ``x``; returns a tensor of shape (N, 1)."""
        return self.critic(x)
    

In [45]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's RNG before any weights are initialised so that
# "Restart Kernel -> Run All" yields the same initial weights and noise.
SEED = 42
torch.manual_seed(SEED)

# --- Optimisation hyperparameters -------------------------------------------
lr = 0.0002             # RMSprop learning rate (shared by generator and critic)
batch_size = 64         # samples per mini-batch
epochs = 200            # full passes over the training set
weight_clip = 0.01      # critic weights are clamped to [-weight_clip, weight_clip]
critic_iterations = 5   # critic updates per generator update

# --- Data / model dimensions --------------------------------------------------
img_size = 28           # side length the loader resizes images to
img_dim = 784           # flattened image size (28 * 28); critic input width
latent_dim = 100        # size of the generator's noise vector

# --- Models and optimisers ----------------------------------------------------
# Generator is built before the critic, matching the original creation order.
generator = Generator(latent_dim).to(device)
critic = Critic(img_dim).to(device)

optim_g = optim.RMSprop(generator.parameters(), lr=lr)
optim_c = optim.RMSprop(critic.parameters(), lr=lr)


In [46]:
def get_data_loader(img_size=32, batch_size=32):
    """Build a shuffled DataLoader over the MNIST training split.

    Images are resized to ``img_size``, converted to tensors and then
    normalised from [0, 1] to [-1, 1]. The normalisation was missing in the
    original: the generator ends in ``Tanh`` (range [-1, 1]), so the critic
    was comparing real samples in [0, 1] against fakes on a different scale.

    Args:
        img_size: Target size passed to ``transforms.Resize``.
        batch_size: Number of samples per mini-batch.

    Returns:
        DataLoader yielding ``(image, label)`` batches.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1] to match the Tanh output range.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST("../../data/mnist", train=True, download=True, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        # Pinned host memory only helps CUDA transfers; requesting it on
        # MPS/CPU just triggers a warning.
        pin_memory=torch.cuda.is_available(),
    )

data_loader = get_data_loader(img_size=img_size, batch_size=batch_size)

In [47]:
def save_generated_images(generator, batch_size, fixed_noise, epoch, batch_idx, save_dir="images"):
    """Render the generator's output for ``fixed_noise`` and save it as a PNG grid.

    The generator is switched to eval mode while sampling and is put back
    into whatever mode it was in before the call, even if sampling or saving
    raises.

    Args:
        generator: The generator model; its output is reshaped to (-1, 1, 28, 28).
        batch_size: Number of images; the grid uses ``int(sqrt(batch_size))``
            images per row (at least 1).
        fixed_noise: Noise batch fed to the generator, reused across calls so
            successive grids are comparable.
        epoch: Current epoch number (used in the file name).
        batch_idx: Current batch index (used in the file name).
        save_dir: Directory to save images into; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Guard against nrow == 0 for batch_size < 1, which make_grid rejects.
    images_per_row = max(1, int(batch_size ** 0.5))

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake = generator(fixed_noise).reshape(-1, 1, 28, 28)
            # Un-normalize from [-1, 1] (Tanh range) to [0, 1] for image export.
            fake = (fake * 0.5 + 0.5).clamp(0, 1)

            img_grid_fake = torchvision.utils.make_grid(
                fake,
                normalize=False,
                padding=2,
                nrow=images_per_row,
            )

        # Timestamp keeps files from different runs from overwriting each other.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{save_dir}/fake_images_epoch{epoch}_batch{batch_idx}_{timestamp}.png"

        torchvision.utils.save_image(img_grid_fake, filename, normalize=False)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        generator.train(was_training)

# Example use inside the training loop. This is intentionally a comment: as
# live cell-level code it ran before `batch_idx`, `epoch`, `loss_critic`, ...
# existed (NameError on a fresh kernel) and it omitted the required
# `batch_size` argument.
#
#     if batch_idx % 100 == 0:
#         print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(data_loader)}] "
#               f"Loss_C: {loss_critic:.4f} Loss_G: {loss_gen:.4f}")
#         save_generated_images(generator, fixed_noise.shape[0], fixed_noise, epoch, batch_idx)

In [48]:
# Number of mini-batches the loader yields per epoch.
len(data_loader)

938

In [49]:
# Start from a clean output folder: create it if missing, then remove any
# files left over from a previous run. The original only emptied an existing
# folder, so the first save_image call failed when "images" did not exist.
os.makedirs("images", exist_ok=True)
for file in os.listdir("images"):
    file_path = os.path.join("images", file)
    if os.path.isfile(file_path):  # leave sub-directories alone instead of crashing
        os.remove(file_path)

# Fixed noise reused for every snapshot so the saved grids are comparable.
fixed_noise = torch.randn(batch_size, latent_dim).to(device)

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(data_loader):
        real = real.view(-1, img_dim).to(device)
        # Use a local name: the last batch of an epoch can be smaller, and
        # assigning to `batch_size` silently overwrote the global hyperparameter.
        cur_batch_size = real.shape[0]

        # ---- Critic: maximise E[critic(real)] - E[critic(fake)] ----
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size, latent_dim).to(device)
            # The critic update needs no generator gradients, so skip building
            # the generator's autograd graph.
            with torch.no_grad():
                fake = generator(noise)
            critic_real = critic(real).view(-1)
            critic_fake = critic(fake).view(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            optim_c.step()

            # Weight clipping keeps every critic parameter in a bounded range.
            for p in critic.parameters():
                p.data.clamp_(-weight_clip, weight_clip)

        # ---- Generator: maximise E[critic(generator(z))] ----
        noise = torch.randn(cur_batch_size, latent_dim).to(device)
        fake = generator(noise)
        output = critic(fake).view(-1)
        loss_gen = -torch.mean(output)
        generator.zero_grad()
        loss_gen.backward()
        optim_g.step()

        if batch_idx % 300 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(data_loader)}] Loss_C: {loss_critic.item():.4f} Loss_G: {loss_gen.item():.4f}")
            # Save a snapshot grid of the generator's output for the fixed noise.
            with torch.no_grad():
                fake = generator(fixed_noise).reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                torchvision.utils.save_image(img_grid_fake, f"images/fake_images_{epoch}_{batch_idx}.png")





Epoch [0/200] Batch [0/938] Loss_C: -0.2873 Loss_G: 0.2625
