In [1]:
from dataloader import CelebALoad
import os
from torch.utils.data import DataLoader

# Directory holding the training images (relative to the notebook's working directory).
root_dir = '../dataset/images'
# os.listdir returns every entry in the folder without filtering by extension, so this
# assumes the directory contains only image files — TODO confirm.
image_list = os.listdir(root_dir)

# CelebALoad is a project-local dataset class; presumably `resize` is the target size
# of each sample — verify against dataloader.py.
dataset = CelebALoad(root_dir, image_list, resize=(64,64))
# Shuffled mini-batches of 32; `dataloader` is read as a global by the training functions.
dataloader = DataLoader(dataset, batch_size=32, pin_memory=False, shuffle=True)

In [2]:
import torch
from gan import Generator, Discriminator
from torch import nn

# Length of the latent vector; noise is sampled elsewhere as (batch, Z_DIM, 1, 1).
Z_DIM = 256
G = Generator(z_dim=Z_DIM, img_channels=3, feature_g=256)
D = Discriminator(img_channels=3, feature_d=256)
# Both networks are moved to the GPU; the rest of the notebook hard-codes "cuda" too.
G = G.to("cuda")
D = D.to("cuda")
# First-moment decay rate shared by both Adam optimizers.
beta1 = 0.5

# The discriminator is given a 4x larger learning rate than the generator.
optimizer_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(beta1, 0.999))

In [3]:
from torchsummary import summary
# Sanity check: push a single random latent vector through G without tracking
# gradients, print the output shape, then print a per-layer summary of G.
with torch.no_grad():
    dummy_input = torch.randn(1, Z_DIM, 1, 1).to("cuda")
    output = G(dummy_input)

print(output.shape)
summary(G, (Z_DIM, 1, 1),device="cuda")

torch.Size([1, 3, 64, 64])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1           [-1, 2048, 4, 4]       8,390,656
              ReLU-2           [-1, 2048, 4, 4]               0
   ConvTranspose2d-3           [-1, 1024, 8, 8]      33,555,456
              ReLU-4           [-1, 1024, 8, 8]               0
   ConvTranspose2d-5          [-1, 512, 16, 16]       8,389,120
              ReLU-6          [-1, 512, 16, 16]               0
   ConvTranspose2d-7          [-1, 256, 32, 32]       2,097,408
              ReLU-8          [-1, 256, 32, 32]               0
   ConvTranspose2d-9            [-1, 3, 64, 64]          12,291
             Tanh-10            [-1, 3, 64, 64]               0
Total params: 52,444,931
Trainable params: 52,444,931
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 7.69
P

In [4]:
# Sanity check: score one random 3x64x64 "image" with D without tracking gradients,
# print the output shape, then print a per-layer summary of D.
# Note: this rebinds the globals `dummy_input` and `output` from the previous cell.
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 64, 64).to("cuda")
    output = D(dummy_input)
print(output.shape)
summary(D, (3, 64, 64),device="cuda")

torch.Size([1])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 32, 32]          12,544
         LeakyReLU-2          [-1, 256, 32, 32]               0
            Conv2d-3          [-1, 512, 16, 16]       2,097,664
         LeakyReLU-4          [-1, 512, 16, 16]               0
            Conv2d-5           [-1, 1024, 8, 8]       8,389,632
         LeakyReLU-6           [-1, 1024, 8, 8]               0
            Conv2d-7           [-1, 2048, 4, 4]      33,556,480
         LeakyReLU-8           [-1, 2048, 4, 4]               0
            Conv2d-9              [-1, 1, 1, 1]          32,769
Total params: 44,089,089
Trainable params: 44,089,089
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 7.50
Params size (MB): 168.19
Estimated Total Size (MB): 175.73
-----------------

In [5]:
from PIL import Image
import numpy as np
import torchvision.utils as vutils

# Convert and save

def show_generated_images(fake_images, epoch, batch):
    """Save a preview grid of generated images under ``generated_images/``.

    Up to the first 32 images are tiled 8 per row and written to
    ``generated_images/output_image.png``, which is overwritten on every call.
    When ``batch`` is a multiple of 50, a second copy named
    ``e{epoch}_b{batch}.png`` is written as a permanent snapshot.

    Args:
        fake_images: Batch of generated images as a tensor
            (assumed (N, C, H, W) — TODO confirm against the generator).
        epoch: Epoch index; used only in the snapshot file name.
        batch: Batch index; decides whether a snapshot is kept and names it.
    """
    # torch tensors have no ``.copy()`` method (the original call raised
    # AttributeError); ``.clone()`` is the tensor copy operation.
    preview = fake_images[:32].detach().cpu().clone()
    # normalize=True lets make_grid rescale the values into [0, 1].
    grid = vutils.make_grid(preview, normalize=True, nrow=8)
    # CHW float in [0, 1] -> HWC uint8 in [0, 255] for PIL.
    pixels = (grid.permute(1, 2, 0).numpy() * 255).astype('uint8')
    img = Image.fromarray(pixels)
    img.save("generated_images/output_image.png")
    if batch % 50 == 0:
        img.save(f"generated_images/e{epoch}_b{batch}.png")

In [6]:
import sys
import gc
from collections import defaultdict
import psutil
def print_largest_variables(n=5):
    """Print the ``n`` largest named objects known to the garbage collector.

    Every object returned by ``gc.get_objects()`` is measured with
    ``sys.getsizeof`` and its referrers are searched for dictionaries that hold
    it as a value; the matching key is treated as the object's name. Sizes are
    summed per ``(name, type)`` pair and the ``n`` largest totals are printed
    in MB.

    ``gc.get_referrers`` is called once per tracked object, so this is slow on
    a large heap and meant for ad-hoc debugging only.

    Args:
        n: Number of entries to print.
    """
    size_by_name = defaultdict(list)

    # Track variable names by object identity
    for obj in gc.get_objects():
        try:
            size = sys.getsizeof(obj)
            for ref in gc.get_referrers(obj):
                # Only dicts can give us a name for the object.
                if isinstance(ref, dict):
                    for name, val in ref.items():
                        if val is obj:
                            size_by_name[(name, type(obj))].append(size)
                            break  # one name per referring dict is enough
        except Exception:
            # Best-effort scan: skip objects that cannot be measured or whose
            # referrers change under us. Deliberately NOT a bare ``except`` so
            # that Ctrl-C (KeyboardInterrupt) can still abort the scan.
            continue

    # Aggregate sizes per (name, type), largest total first
    aggregated = [
        (sum(sizes), name, typ) for (name, typ), sizes in size_by_name.items()
    ]
    aggregated.sort(reverse=True, key=lambda entry: entry[0])

    print(f"Top {n} largest variables in RAM:")
    for i, (size, name, typ) in enumerate(aggregated[:n], 1):
        print(f"{i}. {name} ({typ}): {size / (1024 ** 2):.2f} MB")
         
def cleanup_memory():
    """Run Python garbage collection, then empty PyTorch's CUDA cache."""
    # Same order as before: collect Python garbage first, then empty the CUDA cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [15]:
from torch.utils.tensorboard import SummaryWriter
from IPython.display import clear_output
from torch.autograd import grad
import torch
from time import time, strftime, localtime
import cv2
import logging as log

# WGAN-GP Gradient Penalty
def gradient_penalty(D, real_images, fake_images, device="cuda"):
    """Compute the WGAN-GP gradient penalty for one batch.

    Each real image is blended with the fake image at the same index using a
    random per-sample coefficient; the critic is evaluated on the blend and the
    2-norm of its input gradient is pushed towards 1.

    Args:
        D: Callable critic mapping an image batch to one score per sample.
        real_images: 4-D tensor of real images; its first dimension is the batch size.
        fake_images: Tensor of generated images, broadcast-compatible with
            ``real_images``.
        device: Device on which the mixing coefficients and the ``grad_outputs``
            tensor are placed.

    Returns:
        Scalar tensor, the batch mean of ``(||grad||_2 - 1) ** 2``. It is built
        with ``create_graph=True`` so it can itself be back-propagated.
    """
    # Unpacking into four names keeps the implicit "must be 4-D" check.
    batch_size, _, _, _ = real_images.shape

    # One mixing coefficient per sample, broadcast over the remaining dims.
    mix = torch.rand(batch_size, 1, 1, 1).to(device)
    blended = mix * real_images + (1 - mix) * fake_images
    blended.requires_grad_(True)

    critic_out = D(blended)
    (input_grads,) = grad(
        outputs=critic_out,
        inputs=blended,
        grad_outputs=torch.ones(critic_out.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    # Flatten each sample's gradient before taking its norm.
    per_sample_norm = input_grads.view(batch_size, -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def save_best_model(g_loss, best_G_loss):
    """Checkpoint the global ``D`` and ``G`` when the generator loss improves.

    Args:
        g_loss: Generator loss of the current step (a plain number).
        best_G_loss: Lowest generator loss seen so far.

    Returns:
        The new best loss: ``g_loss`` if it is strictly lower than
        ``best_G_loss`` (both networks are saved in that case), otherwise
        ``best_G_loss`` unchanged.
    """
    # Guard clause: nothing to do unless this is a strict improvement.
    if not g_loss < best_G_loss:
        return best_G_loss

    # Save discriminator first, then generator, then report.
    torch.save(D.state_dict(), 'models/best_model_D.pth')
    torch.save(G.state_dict(), 'models/best_model_G.pth')
    print(f"✅ Saved best model with g_loss: {g_loss:.4f}")
    return g_loss

def denormalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Args:
        tensor: Image tensor whose channel axis is third from the end,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means (one value per channel, or a single value
            that is broadcast to all channels).
        std: Per-channel standard deviations, same layout as ``mean``.

    Returns:
        ``tensor * std + mean`` with ``mean``/``std`` broadcast over H and W.
    """
    # view(-1, 1, 1) instead of the former view(3, 1, 1): identical for RGB,
    # but also accepts grayscale or any other channel count.
    mean = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean

def save_sample(fake_imgs, epoch, batch):
    """Write a grid of generated images to ``generated_images/e{epoch}_b{batch}.png``.

    The images are tiled 8 per row, mapped from the generator's [-1, 1] range
    to [0, 255] uint8, converted from RGB to BGR and written with OpenCV.

    Args:
        fake_imgs: Batch of generated images as a tensor in the [-1, 1] range.
        epoch: Epoch index, used in the file name.
        batch: Batch index, used in the file name.
    """
    out_dir = "generated_images"
    # cv2.imwrite does not create missing directories and signals failure only
    # through its return value, so a missing folder used to drop every sample
    # silently. Create the folder and check the result instead.
    os.makedirs(out_dir, exist_ok=True)

    grid = vutils.make_grid(fake_imgs.detach(), nrow=8)
    grid = denormalize(grid, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [-1, 1] -> [0, 1]
    # CHW float -> HWC uint8 on the CPU.
    img = grid.clamp(0, 1).mul(255).permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR channel order

    out_path = f"{out_dir}/e{epoch}_b{batch}.png"
    if not cv2.imwrite(out_path, img_bgr):
        # Keep training going, but make the failure visible.
        log.warning(f"Could not write sample image to {out_path}")

In [16]:
def _log_ram(stage):
    """Log current system RAM usage at DEBUG level, tagged with ``stage``."""
    # Skip the psutil query entirely unless DEBUG logging is actually enabled.
    if log.getLogger().isEnabledFor(log.DEBUG):
        log.debug(f"RAM Used {stage}: {psutil.virtual_memory().used / (1024 ** 3):.2f} GB")


def train_batch(real_imgs, epochs, epoch, batch, best_G_loss, d_loss_record, g_loss_record, writer, train_d,):
    """Run one WGAN-GP training step on a single batch.

    Uses the notebook globals ``G``, ``D``, ``optimizer_G``, ``optimizer_D``,
    ``Z_DIM`` and ``dataloader``.

    Args:
        real_imgs: Batch of real images, already on the GPU.
        epochs: Total number of epochs (only used in the progress line).
        epoch: Current epoch index.
        batch: 1-based index of this batch within the epoch.
        best_G_loss: Lowest generator loss seen so far.
        d_loss_record: Last recorded discriminator loss (carried over when
            the discriminator is not trained on this call).
        g_loss_record: Last recorded generator loss (carried over on batches
            where the generator is not trained).
        writer: TensorBoard ``SummaryWriter`` receiving the loss scalars.
        train_d: Whether to update the discriminator on this batch.

    Returns:
        Tuple ``(next_batch, best_G_loss, d_loss_record, g_loss_record)`` where
        ``next_batch`` is ``batch + 1``.
    """
    _log_ram(1)
    cleanup_memory()
    G.train()
    D.train()
    batch_size = real_imgs.size(0)

    ### Train Discriminator ###
    if train_d:
        # Fakes for the critic update need no generator graph; under no_grad
        # the output is already detached, so no extra detach/clone is needed.
        with torch.no_grad():
            z = torch.randn(batch_size, Z_DIM, 1, 1, device="cuda")
            fake_imgs = G(z)
            del z
        _log_ram(2)

        # WGAN loss for Discriminator
        d_loss_real = torch.mean(D(real_imgs))   # Expect real images to have high score
        d_loss_fake = torch.mean(D(fake_imgs))   # Expect fake images to have low score
        d_loss = d_loss_fake - d_loss_real
        gp = gradient_penalty(D, real_imgs, fake_imgs, device="cuda")
        lambda_gp = 20  # Typically set to 10, increased to 20 after trial and error (the disc learns fast but not adapting)
        d_loss += lambda_gp * gp
        _log_ram(4)

        # Backward pass and discriminator update
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()
        d_loss_record = d_loss.item()
        _log_ram(5)
        # Drop references early so the memory can be reused by the generator step.
        del gp
        del real_imgs, fake_imgs
    ### END Train Discriminator ###

    ### Train Generator ###
    # The generator is updated on every 5th batch only.
    if batch % 5 == 0:
        z = torch.randn(batch_size, Z_DIM, 1, 1, device="cuda")
        fake_imgs = G(z)
        pred = D(fake_imgs)

        _log_ram(6)

        g_loss = -torch.mean(pred)  # fool discriminator

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()
        g_loss_record = g_loss.item()
        best_G_loss = save_best_model(g_loss_record, best_G_loss)
        _log_ram(7)

        del pred, z, g_loss
        _log_ram(8)
    ### END Train Generator ###

    ### Tensorboard and results
    # Every multiple of 10 is also a multiple of 5, so `fake_imgs` from the
    # generator step above is always available here. Keep the two cadences
    # compatible if either number is changed.
    if batch % 10 == 0:
        writer.add_scalars("gan/losses", {
            "D_loss": float(d_loss_record),
            "G_loss": float(g_loss_record),
        }, batch + epoch * len(dataloader))
        writer.flush()  # Ensure data is written to disk
        with torch.no_grad():
            save_sample(fake_imgs, epoch, batch)

    # End-of-epoch checkpoint (batch is 1-based, so the last one equals len(dataloader)).
    if batch == len(dataloader):
        torch.save(D.state_dict(), f'models/model_{epoch}_D.pth')
        torch.save(G.state_dict(), f'models/model_{epoch}_G.pth')

    _log_ram(9)

    clear_output(wait=True)

    cleanup_memory()
    _log_ram(10)
    # Report the batch that was just processed, BEFORE advancing the counter;
    # printing after the increment produced lines such as "[Batch 6333/6332]".
    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, epochs, batch, len(dataloader), d_loss_record, g_loss_record)
    )
    return batch + 1, best_G_loss, d_loss_record, g_loss_record


In [17]:
def train(epochs, epoch, batch, best_G_loss, d_loss_record, g_loss_record, writer, train_d = True):
    """Train for one pass over the global ``dataloader``.

    Each batch is moved to the GPU and handed to ``train_batch``; the per-batch
    work lives in its own function so its locals are released after every
    iteration.

    Args:
        epochs: Total number of epochs (forwarded for the progress line).
        epoch: Current epoch index.
        batch: Index to assign to the first batch of this pass.
        best_G_loss: Lowest generator loss seen so far.
        d_loss_record: Last recorded discriminator loss.
        g_loss_record: Last recorded generator loss.
        writer: TensorBoard ``SummaryWriter`` forwarded to ``train_batch``.
        train_d: Whether the discriminator is updated on each batch.

    Returns:
        Tuple ``(best_G_loss, d_loss_record, g_loss_record)`` after the pass.
    """
    for real_imgs in dataloader:
        real_imgs = real_imgs.to("cuda")
        batch, best_G_loss, d_loss_record, g_loss_record = train_batch(
            real_imgs, epochs, epoch, batch, best_G_loss,
            d_loss_record, g_loss_record, writer,
            # Forward the caller's choice; this was hard-coded to True, which
            # silently ignored the `train_d` argument.
            train_d=train_d,
        )
    return best_G_loss, d_loss_record, g_loss_record

In [18]:
# First training run: 5 epochs, logging to a fresh timestamped TensorBoard directory.
# NOTE(review): the original comment here read "-500 cutoff is picked after watching logs",
# but no cutoff is applied anywhere in this cell — it looks stale; confirm and remove.
log.basicConfig(level=log.INFO)

try:
    # Timestamp (month-day_hour_minute) that names this run's log directory.
    now = localtime(time())
    now = strftime("%m-%d_%H_%M", now)
    writer = SummaryWriter(f"Logs/{now}/")
    # Running state threaded through train(): best generator loss so far and the
    # most recently recorded D / G losses.
    best_G_loss = float("inf")
    d_loss_record = 0
    g_loss_record = 0
    epochs = 5
    for epoch in range(1, epochs+1):
        # The batch counter is 1-based and restarts at every epoch.
        batch = 1
        print("Training...")
        # NOTE(review): critic_n is assigned but never read; train_batch hard-codes its own
        # "every 5th batch" generator cadence — presumably this was meant to drive it.
        critic_n = 5
        best_G_loss, d_loss_record, g_loss_record = train(epochs, epoch, batch, best_G_loss, d_loss_record, g_loss_record, writer)
        cleanup_memory()
        
except KeyboardInterrupt:
    # Lets the run be stopped from the notebook UI without a traceback.
    print("Training stopped by user.") 


[Epoch 5/5] [Batch 6333/6332] [D loss: -15.845795] [G loss: -6.875937]


In [20]:
# Second training run: 10 more epochs on the same G / D. This cell reuses `writer`,
# `best_G_loss`, `d_loss_record` and `g_loss_record` left in the kernel by the previous
# training cell, so it fails on a fresh kernel unless that cell has been run first.
# NOTE(review): epoch numbering restarts at 1, so the per-epoch checkpoints
# (models/model_{epoch}_*.pth), the sample images (e{epoch}_b{batch}.png) and the
# TensorBoard step values of the first run's epochs 1-5 are reused/overwritten —
# confirm this is intended.
# NOTE(review): the original comment here read "-500 cutoff is picked after watching logs",
# but no cutoff is applied anywhere in this cell — it looks stale; confirm and remove.
log.basicConfig(level=log.INFO)

try:
    epochs = 10
    for epoch in range(1, epochs+1):
        # The batch counter is 1-based and restarts at every epoch.
        batch = 1
        print("Training...")
        # NOTE(review): critic_n is assigned but never read.
        critic_n = 5
        best_G_loss, d_loss_record, g_loss_record = train(epochs, epoch, batch, best_G_loss, d_loss_record, g_loss_record, writer)
        cleanup_memory()
        
except KeyboardInterrupt:
    # Lets the run be stopped from the notebook UI without a traceback.
    print("Training stopped by user.") 


[Epoch 10/10] [Batch 6333/6332] [D loss: -12.857266] [G loss: -1.324729]


In [21]:
# Persist the final weights of both networks (state_dicts only, not the full modules).
torch.save(D.state_dict(), 'models/final_D.pth')
torch.save(G.state_dict(), 'models/final_G.pth')