<a href="https://colab.research.google.com/github/mirahhamid/GAN-medical-image-augmentation/blob/main/GAN_PYTORCH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Simple Implementation of a GAN on MRI Images Using PyTorch**

A total of 70 MRI image samples are used in this implementation. The data were acquired from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database (https://adni.loni.usc.edu/).

Input image size: 256 x 256 x 3 (each slice is read as grayscale and replicated across the three channels)

A GAN has two main components: a Discriminator (D) and a Generator (G).

We build a simple convolutional neural network for each component:

- **Generator**: generates a 'fake' image from random noise (a vector of numbers).
- **Discriminator**: performs binary classification: 'fake' image = 0, 'real' image = 1.

**Note: enable the GPU via Runtime → Change runtime type.**

## **Mount the drive**

In [1]:
# Mount Google Drive at /content/drive so that files stored there
# (MAIN_DIR points into this mount) can be read by path.
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

## **Import the libraries and packages**

In [None]:

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim


## **Initialization of the parameters and configuration**

In [None]:

# --- Reproducibility / latent space ---
SEED = 40        # seed fed to both the torch and NumPy RNGs below
NOISE_DIM = 100  # length of the noise vector given to the generator

# --- Image geometry ---
CHANNELS = 3
HEIGHT = 256
WIDTH = HEIGHT

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Training schedule ---
EPOCHS = 10
BATCH_SIZE = 10
STEPS_PER_EPOCH = 2000

# Folder holding the MRI images; change to your own path.
MAIN_DIR = "/content/drive/MyDrive/Colab Notebooks/AD-NEW"

# Seed both random number generators used in this notebook.
np.random.seed(SEED)
torch.manual_seed(SEED)


def set_requires_grad(model, flag: bool):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: ``True`` to let the parameters receive gradients,
            ``False`` to freeze them.
    """
    for param in model.parameters():
        param.requires_grad_(flag)

## **Visualize the generated image**

In [None]:

@torch.no_grad()
def sample_images_torch(generator, noise_np, nrow=2, ncol=5, figsize=(22, 8), save=False, save_prefix="gen"):
    """Generate images from ``noise_np`` and show them in an ``nrow`` x ``ncol`` grid.

    Args:
        generator: Model that maps a batch of noise vectors to images in
            channel-first layout with values in [-1, 1].
        noise_np: NumPy array with one noise vector per row, shape (N, NOISE_DIM).
        nrow: Number of grid rows.
        ncol: Number of grid columns.
        figsize: Matplotlib figure size.
        save: When true, every displayed image is also written to
            ``f"{save_prefix}_{i}.png"``.
        save_prefix: File-name prefix used when ``save`` is true.

    At most ``nrow * ncol`` images are drawn; if fewer than that were
    generated, only the available ones are shown.
    """
    # Remember the caller's mode so sampling does not silently flip an
    # eval-mode generator back to training mode.
    was_training = generator.training
    generator.eval()
    try:
        z = torch.from_numpy(noise_np).float().to(DEVICE)    # (N, NOISE_DIM)
        fake = generator(z)                                  # (N, C, H, W) in [-1, 1]
        fake = ((fake + 1) / 2).clamp(0, 1).cpu()            # -> [0, 1] for display
    finally:
        generator.train(was_training)

    # Never index past the number of images that were actually generated.
    n_show = min(nrow * ncol, fake.shape[0])

    plt.figure(figsize=figsize)
    for i in range(n_show):
        img = fake[i].permute(1, 2, 0).numpy()               # CHW -> HWC
        plt.subplot(nrow, ncol, i + 1)
        plt.imshow(img)
        plt.axis("off")
        if save:
            # Write this single image; savefig here would dump the whole,
            # still incomplete, figure once per subplot.
            plt.imsave(f"{save_prefix}_{i}.png", img)
    plt.tight_layout()
    plt.show()

## **Load the data**

In [None]:
def load_images(folder):
    """Load every readable image in ``folder`` as a 3-channel grayscale array.

    Each file is read as grayscale, resized to WIDTH x HEIGHT and replicated
    across three channels. Entries OpenCV cannot decode are skipped.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        ``np.uint8`` array of shape (N, HEIGHT, WIDTH, 3). N is 0, with the
        same trailing dimensions, when no image could be read.
    """
    imgs = []
    # os.listdir returns names in arbitrary order; sorting keeps the image
    # indices stable so seeded batch sampling picks the same images every run.
    for fname in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, fname)

        # Force grayscale; imread returns None for anything it cannot decode.
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue

        img = cv2.resize(img, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)

        # Repeat grayscale to 3 channels -> (H, W, 3)
        imgs.append(np.stack([img, img, img], axis=-1))

    if not imgs:
        # Keep a 4-D shape so callers can still inspect .shape consistently.
        return np.empty((0, HEIGHT, WIDTH, 3), dtype=np.uint8)

    return np.array(imgs, dtype=np.uint8)

# Read all images from MAIN_DIR into one uint8 array and report its shape.
data = load_images(MAIN_DIR)
print("Loaded data:", data.shape)

# Refuse to continue when fewer images than one batch were loaded.
if data.shape[0] < BATCH_SIZE:
    raise ValueError(f"Not enough images ({data.shape[0]}) for batch size {BATCH_SIZE}.")

X_train = data

# Map uint8 pixel values 0..255 to [-1, 1], the range of the generator's tanh output
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # float32

# Sanity check of shape, dtype and value range after normalization.
print("X_train:", X_train.shape, X_train.dtype, "min/max:", X_train.min(), X_train.max())

# Visualize up to 10 real images in a 2 x 5 grid
plt.figure(figsize=(20, 8))
for i in range(min(10, X_train.shape[0])):
    plt.subplot(2, 5, i + 1)
    img = (X_train[i] + 1) / 2  # back to [0, 1] for imshow
    plt.imshow(np.clip(img, 0, 1))
    plt.axis("off")
plt.tight_layout()
plt.show()

# Convert to a torch tensor in channel-first layout (N,3,256,256)
X_train_torch = torch.from_numpy(X_train).float().permute(0, 3, 1, 2)  # NHWC -> NCHW

In [None]:

# Disabled earlier variant of load_images that kept colour (BGR -> RGB) instead
# of forcing grayscale. It is wrapped in a string literal, so it is never
# executed as code; the active loader is the load_images defined above.
'''def load_images(folder):
    imgs = []
    for fname in os.listdir(folder):
        img_path = os.path.join(folder, fname)
        try:
            img = cv2.imread(img_path)
            if img is None:
                continue

            # BGR -> RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # resize to 256x256
            img = cv2.resize(img, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)

            imgs.append(img)
        except Exception as e:
            # print(img_path, e)
            continue

    imgs = np.array(imgs, dtype=np.uint8)  # (N,256,256,3)
    return imgs'''



## **Build the Generator and Discriminator Model**

In [None]:

class Generator(nn.Module):
    """
    Maps a batch of noise vectors (B, noise_dim) to images (B, out_channels, 256, 256).

    A linear layer produces a 256-channel 32x32 feature map, three stride-2
    transposed convolutions double the resolution each time
    (32 -> 64 -> 128 -> 256), and a final 3x3 convolution followed by tanh
    yields pixel values in [-1, 1].
    """
    def __init__(self, noise_dim=NOISE_DIM, out_channels=CHANNELS):
        super().__init__()

        # Project the noise vector to a flattened 256 x 32 x 32 feature map.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 32 * 32 * 256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # (in_channels, out_channels) of each x2 upsampling stage.
        upsample_plan = ((256, 128), (128, 128), (128, 64))

        stages = []
        for c_in, c_out in upsample_plan:
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Resolution-preserving projection to the output channels, squashed to [-1, 1].
        stages.append(nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1))
        stages.append(nn.Tanh())

        self.up = nn.Sequential(*stages)

    def forward(self, z):
        hidden = self.fc(z)                           # (B, 32*32*256)
        feature_map = hidden.view(-1, 256, 32, 32)    # (B, 256, 32, 32)
        return self.up(feature_map)                   # (B, out_channels, 256, 256)

class Discriminator(nn.Module):
    """
    Scores a batch of 256x256 images (B, in_channels, 256, 256) as real or fake.

    Five 3x3 convolution stages (each followed by LeakyReLU and dropout)
    reduce the resolution 256 -> 256 -> 128 -> 64 -> 32 -> 16, then a linear
    layer with a sigmoid outputs one probability per image, shape (B, 1).
    """
    def __init__(self, in_channels=CHANNELS):
        super().__init__()

        # (in_channels, out_channels, stride) of each convolution stage.
        conv_plan = (
            (in_channels, 64, 1),   # 256 -> 256
            (64, 128, 2),           # 256 -> 128
            (128, 128, 2),          # 128 -> 64
            (128, 256, 2),          # 64  -> 32
            (256, 512, 2),          # 32  -> 16
        )

        layers = []
        for c_in, c_out, stride in conv_plan:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
        self.conv = nn.Sequential(*layers)

        # Flatten the 512 x 16 x 16 features and map them to one probability.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 16 * 16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.fc(self.conv(x))

# ================= BUILD =================
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

print(generator)
print(discriminator)

## **Set the loss function and optimizer**

In [None]:
# Binary cross-entropy on probabilities; the discriminator ends in a Sigmoid.
criterion = nn.BCELoss()
# One Adam optimizer per network, both with lr=2e-4 and betas=(0.5, 0.999).
# D_OPT updates only the discriminator's parameters.
D_OPT   = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
# GAN_OPT updates only the generator's parameters.
GAN_OPT = optim.Adam(generator.parameters(),     lr=2e-4, betas=(0.5, 0.999))

## **Start Training**

In [None]:

# Target labels are the same for every step (real=1, fake=0), so build them once.
real_y = torch.ones(BATCH_SIZE, 1, device=DEVICE)
fake_y = torch.zeros(BATCH_SIZE, 1, device=DEVICE)

for epoch in range(EPOCHS):
    # Losses of the most recent step; reported at the end of the epoch.
    last_d_loss, last_g_loss = None, None

    for _ in tqdm(range(STEPS_PER_EPOCH), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        # Put both networks in training mode before any forward pass.
        discriminator.train()
        generator.train()

        # ----- Sample real images (random indices, with replacement)
        idx = torch.randint(0, X_train_torch.size(0), (BATCH_SIZE,))
        real_X = X_train_torch[idx].to(DEVICE)  # (B,3,256,256)

        # ----- Generate fake images for the discriminator step.
        # No generator gradients are needed here, so skip building the
        # autograd graph instead of building it and detaching afterwards.
        z = torch.randn(BATCH_SIZE, NOISE_DIM, device=DEVICE)
        with torch.no_grad():
            fake_X = generator(z)  # (B,3,256,256)

        # ----- Train Discriminator (real=1, fake=0)
        set_requires_grad(discriminator, True)
        D_OPT.zero_grad()

        pred_real = discriminator(real_X)
        loss_real = criterion(pred_real, real_y)

        pred_fake = discriminator(fake_X)
        loss_fake = criterion(pred_fake, fake_y)

        # Average the two halves so the discriminator loss is on the same
        # scale as a single BCE term.
        d_loss = 0.5 * (loss_real + loss_fake)
        d_loss.backward()
        D_OPT.step()

        # ----- Train Generator (freeze D): want D(G(z)) = 1
        # Freezing D avoids accumulating gradients in its parameters while
        # the loss is back-propagated through it into the generator.
        set_requires_grad(discriminator, False)
        GAN_OPT.zero_grad()

        z2 = torch.randn(BATCH_SIZE, NOISE_DIM, device=DEVICE)
        gen_X = generator(z2)
        pred = discriminator(gen_X)

        # Use the "real" labels so the generator is rewarded for fooling D.
        g_loss = criterion(pred, real_y)
        g_loss.backward()
        GAN_OPT.step()

        last_d_loss = d_loss.item()
        last_g_loss = g_loss.item()

    print(f"EPOCH: {epoch + 1} | Generator Loss: {last_g_loss:.4f} | Discriminator Loss: {last_d_loss:.4f}")

    # Show 10 freshly sampled images after every epoch to monitor progress.
    noise_np = np.random.normal(0, 1, size=(10, NOISE_DIM)).astype(np.float32)
    sample_images_torch(generator, noise_np, nrow=2, ncol=5, figsize=(22, 8), save=False)

# Final samples: draw 100 standard-normal noise vectors and display the
# generated images in a 10 x 10 grid (nothing is written to disk: save=False).
noise_np = np.random.normal(0, 1, size=(100, NOISE_DIM)).astype(np.float32)
sample_images_torch(generator, noise_np, nrow=10, ncol=10, figsize=(24, 20), save=False)
