In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision

In [None]:
# Hyperparameters and runtime configuration shared by the cells below.
latent_dim = 100   # length of the noise vector sampled for the generator
batch_size = 64    # samples per DataLoader batch
epochs = 100       # passes over the training set
lr = 2e-4          # Adam learning rate (same value as 0.0002)
beta1 = 0.5        # first entry of Adam's betas
image_size = 32    # side length of the square generated images

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Ensure the folder that receives the per-epoch sample grids exists
# (no error if it is already there).
os.makedirs(name="gan_images", exist_ok=True)

In [None]:
# ==== Dataset Loader (MNIST)
# MNIST images are 28x28, but the discriminator is built for
# image_size x image_size inputs, so each image is resized first.
# Normalize((0.5,), (0.5,)) then maps ToTensor's [0, 1] range onto [-1, 1],
# the same range the generator's Tanh produces, so real and generated
# samples live on the same scale.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# To train on FashionMNIST instead, swap datasets.MNIST for datasets.FashionMNIST
# (same arguments).


In [None]:
# Generator
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to 1-channel images.

    Args:
        latent_dim: Length of the input noise vector. ``None`` (default)
            falls back to the notebook-level ``latent_dim`` setting.
        image_size: Side length of the square output image. ``None``
            (default) falls back to the notebook-level ``image_size`` setting.
    """

    def __init__(self, latent_dim=None, image_size=None):
        super().__init__()
        # Unspecified sizes come from the notebook configuration, so a bare
        # ``Generator()`` keeps working as before.
        if latent_dim is None:
            latent_dim = globals()["latent_dim"]
        if image_size is None:
            image_size = globals()["image_size"]
        self.latent_dim = latent_dim
        self.image_size = image_size

        self.model = nn.Sequential(
            # Must emit 256 features: the next Linear layer takes 256 inputs.
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, image_size * image_size),
            nn.Tanh()  # bounds every pixel to [-1, 1]
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to images.

        Returns:
            Tensor of shape (batch, 1, image_size, image_size).
        """
        img = self.model(z)
        # Unflatten the pixel vector into a single-channel square image.
        img = img.view(img.size(0), 1, self.image_size, self.image_size)
        return img

In [None]:
img_size = 32  # side length of the square images the discriminator consumes


class Discriminator(nn.Module):
    """Fully connected discriminator that scores a flattened image in (0, 1)."""

    def __init__(self):
        super().__init__()
        n_pixels = img_size * img_size
        stack = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the single score into (0, 1)
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return one score per sample, shape (batch, 1)."""
        # Collapse every non-batch dimension into one row per sample.
        n_samples = img.size(0)
        return self.model(img.view(n_samples, -1))


In [None]:
# Build both networks, then move each one onto the selected device.
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [None]:

# Binary cross-entropy is the loss for both networks; each network gets its
# own Adam optimizer configured with the same lr and betas.
adversarial_loss = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (generator, discriminator)
)

In [None]:
# Training
# Each batch performs one generator step followed by one discriminator step.
# After every epoch a 5x5 grid of fresh samples is written to gan_images/.
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        # Use the actual batch length so the label tensors always match it.
        batch_size_curr = real_imgs.size(0)

        # Targets for BCELoss: ones label a sample "real", zeros label it "generated".
        valid = torch.ones((batch_size_curr, 1), device=device)
        fake = torch.zeros((batch_size_curr, 1), device=device)

        # Train Generator: reward it when the discriminator scores its
        # samples as real.
        optimizer_G.zero_grad()
        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator: real images should score 1, generated ones 0.
        # detach() keeps this loss from back-propagating into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Show progress every 300 batches
        if i % 300 == 0:
            print(f" Epoch [{epoch+1}/{epochs}], Batch [{i}/{len(dataloader)}] — "
                  f"G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}")

    # Generate and save image after each epoch.
    # no_grad: these samples are only for inspection, so no autograd graph is built.
    z = torch.randn(25, latent_dim, device=device)
    with torch.no_grad():
        gen_imgs = generator(z).cpu()
    gen_imgs = (gen_imgs + 1) / 2  # rescale Tanh output from [-1, 1] to [0, 1]
    grid = torchvision.utils.make_grid(gen_imgs, nrow=5, padding=2, normalize=False)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.savefig(f"gan_images/epoch_{epoch+1}.png")
    plt.close()  # release the figure so they do not pile up across epochs

    # Just one statement after each epoch
    print(f" Epoch {epoch+1} complete — Check the image in 'gan_images/epoch_{epoch+1}.png'")


In [None]:
# Once training is done, persist only the generator's weights (its state_dict)
# so they can be reloaded for sampling later.
generator_weights = generator.state_dict()
torch.save(generator_weights, "generator.pth")
print(" Generator model saved as generator.pth")



In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

# SAME settings as training. These must equal the values the saved generator
# was trained with (latent_dim = 100, image_size = 32); otherwise the layer
# shapes rebuilt below will not match the weights in generator.pth.
latent_dim = 100
image_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator definition (MUST match training): the layer order below fixes the
# state_dict keys that load_state_dict looks up.
class Generator(torch.nn.Module):
    """Fully connected generator: latent vector -> 1-channel square image."""

    def __init__(self):
        super().__init__()
        nn = torch.nn
        widths = (latent_dim, 256, 512, 1024)

        layers = []
        # Hidden stack: Linear + LeakyReLU for each consecutive pair of widths.
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Output layer emits one value per pixel, bounded to [-1, 1] by Tanh.
        layers.append(nn.Linear(widths[-1], image_size * image_size))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return images shaped (batch, 1, image_size, image_size)."""
        flat = self.model(z)
        return flat.view(flat.size(0), 1, image_size, image_size)

# Load trained generator
generator = Generator().to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you
# produced yourself or otherwise trust.
generator.load_state_dict(torch.load("generator.pth", map_location=device))
generator.eval()

# Generate new fake images (no_grad: inference only, no autograd graph needed)
z = torch.randn(25, latent_dim, device=device)
with torch.no_grad():
    gen_imgs = generator(z).cpu()

# Rescale from [-1, 1] to [0, 1]
gen_imgs = (gen_imgs + 1) / 2

# Display in a grid: make_grid returns (C, H, W), imshow wants (H, W, C)
grid = torchvision.utils.make_grid(gen_imgs, nrow=5, padding=2, normalize=False)
plt.imshow(grid.permute(1, 2, 0).numpy(), cmap="gray")
plt.axis("off")
plt.show()


In [None]:
# Testing: generate and display one image from a user-supplied seed number
import torch
import matplotlib.pyplot as plt

def generate_image_by_number(generator, number, latent_dim, device):
    """Display the single image the generator produces for a seed number.

    ``torch.manual_seed(number)`` is called right before the latent vector is
    sampled, so calling this twice with the same number draws the same ``z``.
    Note that this reseeds PyTorch's global random generator as a side effect.

    Args:
        generator: Module mapping a (1, latent_dim) tensor to a
            (1, C, H, W) image tensor with values in [-1, 1].
        number: Integer used as the random seed.
        latent_dim: Length of the latent vector to sample.
        device: Device on which the latent vector is created.

    Returns:
        None. The image is shown with matplotlib.
    """
    generator.eval()
    torch.manual_seed(number)  # same number → same random z
    z = torch.randn(1, latent_dim, device=device)

    with torch.no_grad():
        img = generator(z).cpu()

    # Scale [-1, 1] → [0, 1]
    img = 0.5 * (img + 1)
    # (1, C, H, W) → (H, W, C)
    img = img.squeeze(0).permute(1, 2, 0).numpy()
    if img.shape[-1] == 1:
        # Single-channel output: drop the channel axis so imshow treats it as
        # scalar data rendered in grayscale on a fixed [0, 1] scale.
        img = img[..., 0]

    # cmap/vmin/vmax only apply to 2-D scalar data; RGB(A) input ignores them.
    plt.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
    plt.axis("off")
    plt.show()

# Prompt for the seed number, convert it to an integer, and show the image
# generated for it.
raw_number = input("Enter a number: ")
num = int(raw_number)
generate_image_by_number(generator, num, latent_dim, device)
