# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils

# ==== SETTINGS ====
# Dataset root. Set the DCGAN_IMAGE_DIR environment variable to point elsewhere
# instead of editing this machine-specific default path.
image_dir = os.environ.get(
    "DCGAN_IMAGE_DIR",
    "/Users/dhanalakshmijothi/Desktop/python/med_cnn_classifier/data/images",
)
batch_size = 128
image_size = 64  # Resize images to 64x64; the Generator below outputs 64x64 images
nz = 100         # Size of latent vector
num_epochs = 50
lr = 0.0002
beta1 = 0.5
save_dir = "./generated_images"
os.makedirs(save_dir, exist_ok=True)

# ==== DEVICE ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== TRANSFORM + DATA ====
transform = transforms.Compose([
    transforms.Resize(image_size),      # scale the shorter side to image_size
    transforms.CenterCrop(image_size),  # then take a square image_size x image_size crop
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    # Map [0, 1] -> [-1, 1] so real images match the range of the generator's Tanh output.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects the images to sit inside subdirectories of image_dir; the class
# labels it yields are ignored by the GAN training loop.
dataset = datasets.ImageFolder(root=image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# ==== MODELS ====
class Generator(nn.Module):
    """DCGAN generator mapping a latent vector to an ``nc``-channel 64x64 image.

    The latent input is expected as ``(batch, nz, 1, 1)``. The first transposed
    convolution expands it to 4x4, and each following stride-2 transposed
    convolution doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64). The output
    passes through ``Tanh``, so pixel values lie in ``[-1, 1]``.

    Args:
        nz: Number of channels in the latent input.
        ngf: Base number of generator feature maps (widths are 8x, 4x, 2x, 1x ngf).
        nc: Number of channels in the generated image.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first stage projects 1x1 -> 4x4; later stages upsample by 2.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            stack.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output head: no BatchNorm; Tanh squashes pixels into [-1, 1].
        stack.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Generate a batch of images from latent vectors ``input``."""
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN discriminator scoring 64x64 images as real (1) or fake (0).

    Four stride-2 convolutions shrink the input 64 -> 32 -> 16 -> 8 -> 4. A final
    4x4 convolution with stride 1 and no padding then collapses the 4x4 map to
    a single score per image. Its output has shape ``(batch, 1, 1, 1)``, so
    ``.view(-1)`` yields one probability per sample, matching the label vector.

    Args:
        nc: Number of channels in the input images.
        ndf: Base number of discriminator feature maps (widths are 1x, 2x, 4x, 8x ndf).
    """

    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 64x64 -> 32x32 (no BatchNorm on the first layer)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4. Without this stage the head received 8x8 maps and emitted
            # 5x5 scores per image, which made BCELoss fail on a size mismatch.
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 -> 1x1 score, squashed to a probability
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return per-image real/fake probabilities of shape ``(batch, 1, 1, 1)``."""
        return self.main(input)


# ==== INIT MODELS ====
def _weights_init(module):
    """Apply DCGAN-style initialization to one submodule (meant for ``nn.Module.apply``).

    Conv / transposed-conv weights are drawn from N(0, 0.02). BatchNorm scales are
    drawn from N(1, 0.02) and BatchNorm shifts are zeroed. The conv layers in this
    script are built with ``bias=False``, so they have no bias to initialize.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 1.0, 0.02)
        nn.init.zeros_(module.bias)


netG = Generator(nz=nz).to(device)
netD = Discriminator().to(device)
# The optimizer settings below (lr=2e-4, beta1=0.5) are the DCGAN values, which go
# together with this N(0, 0.02) initialization rather than PyTorch's default init.
netG.apply(_weights_init)
netD.apply(_weights_init)

# ==== LOSS + OPTIM ====
criterion = nn.BCELoss()  # netD ends in Sigmoid, so it outputs probabilities, not logits
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Fixed noise to track training progress (same latents every epoch -> comparable grids)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1.
fake_label = 0.

# ==== TRAINING ====
steps_per_epoch = len(dataloader)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        ##########################
        # (1) Train Discriminator: push D(real) -> 1 and D(fake) -> 0
        ##########################
        netD.zero_grad()
        real_images = real_images.to(device)
        b_size = real_images.size(0)  # the last batch may be smaller than batch_size
        labels = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_images).view(-1)
        lossD_real = criterion(output, labels)
        lossD_real.backward()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_images = netG(noise)
        labels.fill_(fake_label)

        # detach(): the discriminator update must not backpropagate into the generator
        output = netD(fake_images.detach()).view(-1)
        lossD_fake = criterion(output, labels)
        lossD_fake.backward()
        optimizerD.step()
        lossD = lossD_real + lossD_fake

        ##########################
        # (2) Train Generator: push D(G(z)) -> 1
        ##########################
        netG.zero_grad()
        labels.fill_(real_label)  # flip labels: G is rewarded when D calls its fakes real
        output = netD(fake_images).view(-1)
        lossG = criterion(output, labels)
        lossG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(
                f"[{epoch+1}/{num_epochs}] Step {i}/{steps_per_epoch} "
                f"| Loss_D: {lossD.item():.4f} | Loss_G: {lossG.item():.4f}"
            )

    # Save generated images. no_grad: this sample is only visualized, so skip
    # building an autograd graph for it.
    with torch.no_grad():
        fake_grid = vutils.make_grid(netG(fixed_noise), padding=2, normalize=True)
    vutils.save_image(fake_grid, os.path.join(save_dir, f"fake_epoch_{epoch+1}.png"))

# ==== SAVE MODELS ====
# Persist only the learned parameters (state_dict) of each network, generator first.
for trained_net, checkpoint_path in ((netG, "generator.pth"), (netD, "discriminator.pth")):
    torch.save(trained_net.state_dict(), checkpoint_path)
print("Models saved.")
