## Importing Needed Libraries

<p align="justify">insert spiel</p>

In [None]:
# Part 1: Data Preprocessing

import torch

from torch.utils.data import DataLoader, Subset

from torchvision import datasets, transforms

# Part 2: Creating the DCGAN Generator and Discriminator Classes

import torch.nn as nn

# Part 3: Training a DCGAN for Each Underrepresented Class (Cordana, Healthy, Pestalotiopsis)

import shutil

from pathlib import Path

import torch.optim as optim

from torchvision.utils import save_image

import os

# https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://pyimagesearch.com/2021/10/25/training-a-dcgan-in-pytorch/
# https://medium.com/@manoharmanok/implementing-dcgan-in-pytorch-using-the-celeba-dataset-a-comprehensive-guide-660e6e8e29d2


## Part 1: Data Preprocessing

<p align="justify">insert spiel</p>

In [13]:
# constants

RAW_DATA_DIR = "../training_data"
GAN_SIZE = (128, 128)
CNN_SIZE = (224, 224)
BANANA_CLASSES  = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]

<p align="justify">insert spiel</p>

In [None]:
transform_gan_b = transforms.Compose([
    transforms.Resize(GAN_SIZE), # Resize for DCGAN
    transforms.ToTensor(),       # To tensor
    transforms.Normalize(
        [0.5, 0.5, 0.5], 
        [0.5, 0.5, 0.5],
    )  # Normalize to [-1, 1] for DCGAN
])

transform_gan_p = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(
        brightness = 0.2, 
        contrast   = 0.2, 
        saturation = 0.2, 
        hue        = 0.2,
    ),
    transforms.RandomAffine(10),
    transforms.RandomResizedCrop(GAN_SIZE, scale = (0.8, 1.0)),
])

transform_cnn = transforms.Compose([
    transforms.Resize(CNN_SIZE), # Resize for CNN
    transforms.ToTensor(),       # To tensor
])

<p align="justify">insert spiel</p>

In [None]:
def load_gan_data(batch_size = 32, workers = 4, target_class = None, num_variants = 10):
    # Load full dataset with base GAN transformations
    dataset_gan = datasets.ImageFolder(root = RAW_DATA_DIR, transform = transform_gan_b)

    if target_class:
        # Get class index from the class name
        class_index = dataset_gan.class_to_idx[target_class]

        # Filter indices where target matches
        indices = [i for i, (_, label) in enumerate(dataset_gan.samples) if label == class_index]

        # Wrap in a Subset
        dataset_gan = Subset(dataset_gan, indices)

    # Create a list to store augmented images
    augmented_images = []

    # Apply augmentations to each image in the loaded dataset
    for i in range(len(dataset_gan)):
        image, label = dataset_gan[i]

        # Generate num_variants augmented versions of image
        for _ in range(num_variants):
            augmented_image = transform_gan_p(image)

            augmented_images.append((augmented_image, label))

    # Create new dataset with augmented images
    augmented_dataset = torch.utils.data.TensorDataset(
        torch.stack([image[0] for image in augmented_images]),  # Stack all augmented images
        torch.tensor([image[1] for image in augmented_images])  # Stack all labels
    )

    # Create DataLoader for the GAN data
    dataloader_gan = DataLoader(augmented_dataset, batch_size = batch_size, shuffle = True, num_workers = workers)

    return dataloader_gan

def load_cnn_data(batch_size = 32, workers = 4):
    # Load dataset with CNN transformations
    dataset_cnn = datasets.ImageFolder(root=RAW_DATA_DIR, transform = transform_cnn)
    
    # Create DataLoader for the CNN data
    dataloader_cnn = DataLoader(dataset_cnn, batch_size=batch_size, shuffle = True, num_workers = workers)

    return dataloader_cnn

<p align="justify">insert spiel</p>

In [16]:
def initialize_weights(model):
    classname = model.__class__.__name__

    if classname.find("Conv") != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)

    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)

## Part 2: Creating the DCGAN Generator and Discriminator Classes

<p align="justify">insert spiel</p>

In [None]:
# constants

BATCH_SIZE = 128        # Number of images per training batch

INPUT_DIMENSION = 100   # Dimensionality of the generator input

NC = 3                  # Number of channels in the training images

NGF = 64                # Base number of feature maps in the generator

NDF = 64                # Base number of feature maps in Discriminator

EPOCHS = 200            # Number of training epochs

LEARNING_RATE = 0.0002  # Learning rate for both optimizers

BETA1 = 0.5             # Beta1 value for the Adam optimizer to help stabilize DCGAN training

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

NGPU = 1  # Number of GPUs to use (0 means CPU only)


In [18]:
# Generator

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()

        self.ngpu = ngpu

        # Generator network composed of a stack of transposed conv blocks
        self.main = nn.Sequential(
            self._block(INPUT_DIMENSION, NGF * 16, 4, 1, 0, bias = False),  # First layer: latent vector -> feature map
            self._block(NGF * 16, NGF * 8, 4, 2, 1, bias = False),          # Upsample to 8 x 8
            self._block(NGF * 8, NGF * 4, 4, 2, 1, bias = False),           # Upsample to 16 x 16
            self._block(NGF * 4, NGF * 2, 4, 2, 1, bias = False),           # Upsample to 32 x 32
            self._block(NGF * 2, NGF, 4, 2, 1, bias = False),               # Upsample to 64 x 64

            nn.ConvTranspose2d(NGF, NC, 4, 2, 1, bias = False),      # Final upsample to 128x128 with RGB output
            nn.Tanh()                                                # Output pixel values in [-1, 1]
        )

    # Helper function to define a generator block:

    # ConvTranspose2d -> InstanceNorm2d -> ReLU -> Dropout

    def _block(self, i_channels, o_channels, kernel_size, stride, padding, bias):
        return nn.Sequential(
            nn.ConvTranspose2d(
                i_channels, 
                o_channels, 
                kernel_size, 
                stride, 
                padding, 
                bias = bias),
            nn.InstanceNorm2d(o_channels),
            nn.ReLU(True),
            nn.Dropout2d(0.3) # Dropout to help regularize on small data
        )

    def forward(self, input):
        return self.main(input)


In [19]:
# Discriminator

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()

        self.ngpu = ngpu

        # Discriminator network composed of downsampling conv blocks
        self.main = nn.Sequential(
            self._block(NC, NDF, 4, 2, 1, bias = False, use_batchNorm2D = False), # First block: no BatchNorm
            self._block(NDF,     NDF *  2, 4, 2, 1, bias = False),                # Downsample to 32 x 32
            self._block(NDF * 2, NDF *  4, 4, 2, 1, bias = False),                # Downsample to 16 x 16
            self._block(NDF * 4, NDF *  8, 4, 2, 1, bias = False),                # Downsample to 8 x 8
            self._block(NDF * 8, NDF * 16, 4, 2, 1, bias = False),                # Downsample to 4 x 4

            nn.Conv2d(NDF * 16, 1, 4, 1, 0, bias = False),                        # Final layer: reduce to 1 x 1
        )

    # Helper function to define a discriminator block:

    # Conv2d -> (optional) InstanceNorm2d -> LeakyReLU

    def _block(self, i_channels, o_channels, kernel_size, stride, padding, bias, use_batchNorm2D = True):
        layers = [nn.Conv2d(
            i_channels, 
            o_channels, 
            kernel_size, 
            stride, 
            padding, 
            bias = bias)]
        
        if use_batchNorm2D:
            layers.append(nn.InstanceNorm2d(o_channels))
        
        layers.append(nn.LeakyReLU(0.2, inplace = True))
        layers.append(nn.Dropout2d(0.3)) # Dropout to help regularize on small data

        return nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


## Part 3: Training a DCGAN for Each Underrepresented Class (Cordana, Healthy, Pestalotiopsis)

<p align="justify">insert spiel</p>

In [20]:
GAN_OUTPUT_DIRECTORY_TEST = "../model2/gan_test" # for debugging while training

def prepare_output_directory():

    for cls in BANANA_CLASSES:
        if cls != "sigatoka":
            full_path = Path(GAN_OUTPUT_DIRECTORY_TEST) / cls

            # If the directory exists, remove and recreate it
            if full_path.exists():
                shutil.rmtree(full_path)

            full_path.mkdir(parents = True, exist_ok = True)

prepare_output_directory()

In [None]:
def train_dcgan_per_class(target_class = None, resume = False, checkpoint_path = None):
    if target_class:
        dataloader = load_gan_data(batch_size = BATCH_SIZE, target_class = target_class)
    
    else:
        dataloader = load_gan_data(batch_size = BATCH_SIZE)

    # Initialize Generator and Discriminator
    netG = Generator(ngpu = NGPU).to(DEVICE)
    netD = Discriminator(ngpu = NGPU).to(DEVICE)

    # Initialize weights
    netG.apply(initialize_weights)
    netD.apply(initialize_weights)

    # Handle multi-GPU setup if applicable
    if (DEVICE.type == "cuda") and (NGPU > 1):
        netG = nn.DataParallel(netG, list(range(NGPU)))
        netD = nn.DataParallel(netD, list(range(NGPU)))

    # Loss function
    criterion = nn.BCEWithLogitsLoss()

    # Fixed noise for generating sample outputs and tracking progress during training
    fixed_noise = torch.randn(64, INPUT_DIMENSION, 1, 1, device = DEVICE)

    # Optimizers for Generator and Discriminato
    optimizerD = optim.Adam(netD.parameters(), lr = LEARNING_RATE, betas = (BETA1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr = LEARNING_RATE, betas = (BETA1, 0.999))

    # Labels for real and fake images
    real_label = 0.9 # Slightly less than 1
    fake_label = 0.1 # Slightly more than 0

    # Default starting epoch
    start_epoch = 0

    # Allows training models from a checkpoint para hindi paulit-ulit

    if resume and checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location = DEVICE)
        netG.load_state_dict(checkpoint["netG"])
        netD.load_state_dict(checkpoint["netD"])
        optimizerG.load_state_dict(checkpoint["optimizerG"])
        optimizerD.load_state_dict(checkpoint["optimizerD"])

        start_epoch = checkpoint["epoch"] + 1

    SAVE_EVERY = 10  # Save checkpoint every 10 epochs

    # Actual training
    for epoch in range(EPOCHS):
        for i, (real_images, _) in enumerate(dataloader): # Iterate through batches in the dataset
            # 1. Update Discriminator: 
            #    maximize log(D(x)) + log(1 - D(G(z)))

            # 1.A. Train Discriminator on real images
            netD.zero_grad()

            # Format real batch
            real_images = real_images.to(DEVICE)

            # Well, train the Discriminator on noisy real images
            noise = torch.randn_like(real_images) * 0.1 # 0.1 controls the magnitude of noise

            noisy_real_images = real_images + noise

            size = real_images.size(0)

            label = torch.full((size,), real_label, dtype = torch.float, device = DEVICE)

            # Forward pass noisy real images through Discriminator
            output = netD(noisy_real_images).view(-1)

            # Calculate Discriminator loss for noisy real images
            errD_real = criterion(output, label)

            # Backpropagate error for noisy real images
            errD_real.backward()

            # Mean output for noisy real images
            D_x = output.mean().item()

            # 1.B. Train Discriminator on batch of all fake images

            # Generate batch of latent vectors
            noise = torch.randn(size, INPUT_DIMENSION, 1, 1, device = DEVICE)

            # Generate fake images with Generator
            fake = netG(noise)

            # Classify fake images with Discriminator
            label.fill_(fake_label)

            # Forward pass fake images through Discriminator
            output = netD(fake.detach()).view(-1)

            # Calculate Discriminator loss for fake images
            errD_fake = criterion(output, label)

            # Backpropagate error for fake images
            errD_fake.backward()

            # Clip Discriminator gradients for stability
            torch.nn.utils.clip_grad_norm_(netD.parameters(), max_norm = 1.0)

            # Mean output for fake images
            D_G_z1 = output.mean().item()

            # Compute total Discriminator error = real error + fake error
            errD = errD_real + errD_fake

            # Finally update Discriminator
            optimizerD.step()

            # 2. Update Generator: 
            #    maximize log(D(G(z)))

            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for Generator cost

            # Pass fake images through Discriminator
            output = netD(fake).view(-1)

            # Calculate Generator loss based on Discriminator's output
            errG = criterion(output, label)

            # Backpropagate error for Generator
            errG.backward()

            # Mean output for fake images after Generator update
            D_G_z2 = output.mean().item()

            # Finally update Generator
            optimizerG.step()

            # Debugging: Print losses and monitor training progress

            if i % 50 == 0:
                print(
                  f"Epoch [{epoch}/{EPOCHS}] Batch {i}/{len(dataloader)} \
                    Loss D: {errD.item():.4f}, loss G: {errG.item():.4f} \
                    D(x): {D_x:.4f}, \
                    D(G(z))_real: {D_G_z1:.4f}, D(G(z))_fake: {D_G_z2:.4f}"
                )

        if epoch % SAVE_EVERY == 0:
            fake_images = netG(fixed_noise).detach()

            if target_class:
                path = f"{GAN_OUTPUT_DIRECTORY_TEST}/{target_class}"
            
            else:
                path = GAN_OUTPUT_DIRECTORY_TEST

            # Save images Generator could produce during checkpoints
            save_image(
                fake_images,
                os.path.join(path, f"sample_epoch_{epoch}.png"),
                normalize = True
            )

            # Save model version
            save_dict = {
                "epoch": epoch,
                "netG": netG.state_dict(),
                "netD": netD.state_dict(),
                "optimizerG": optimizerG.state_dict(),
                "optimizerD": optimizerD.state_dict(),
            }

            torch.save(save_dict, os.path.join(path, f"checkpoint_epoch_{epoch}.pth"))

    return netG, netD

In [None]:
# Train DCGAN for "cordana" class
trained_generator_cordana, _ = train_dcgan_per_class(target_class = "cordana")

# train_dcgan_per_class(target_class = "cordana", resume = True, checkpoint_path = "<insert>/checkpoint_epoch_<insert>.pth")


Epoch [0/100] Batch 0/12                     Loss D: 2.2262, loss G: 5.4314                     D(x): 0.8105,                     D(G(z))_real: 0.8789, D(G(z))_fake: -6.0181
Epoch [1/100] Batch 0/12                     Loss D: 1.6429, loss G: 2.9868                     D(x): 1.9622,                     D(G(z))_real: -1.0652, D(G(z))_fake: -3.0778
Epoch [2/100] Batch 0/12                     Loss D: 1.4284, loss G: 4.7325                     D(x): 2.3362,                     D(G(z))_real: -0.6259, D(G(z))_fake: -5.2165
Epoch [3/100] Batch 0/12                     Loss D: 1.6087, loss G: 5.6744                     D(x): 3.0469,                     D(G(z))_real: -0.5782, D(G(z))_fake: -6.2853
Epoch [4/100] Batch 0/12                     Loss D: 1.3591, loss G: 8.9547                     D(x): 4.0753,                     D(G(z))_real: -1.2167, D(G(z))_fake: -9.9492
Epoch [5/100] Batch 0/12                     Loss D: 1.1862, loss G: 0.6901                     D(x): 3.7823,                 

In [None]:
# Train DCGAN for "healthy" class
# trained_generator_healthy, _ = train_dcgan_per_class(target_class = "healthy")

In [None]:
# Train DCGAN for "pestalotiopsis" class
# trained_generator_pestalotiopsis, _ = train_dcgan_per_class(target_class = "pestalotiopsis")

## Part 4: Generating Images for Each Underrepresented Class (Cordana, Healthy, Pestalotiopsis)

<p align="justify">insert spiel</p>

In [24]:
def generate_synthetic_images(dcgan_generator, amount_to_generate, class_label, output_directory):

    # Set the generator to evaluation mode to disable Dropout and InstanceNorm2d updates
    dcgan_generator.eval()

    # Construct the path to the class-specific output directory
    class_output_directory = os.path.join(output_directory, class_label)

    # Create the output directory if it does not exist just in case
    os.makedirs(class_output_directory, exist_ok = True)

    # Disable gradient computation for efficiency during inference
    with torch.no_grad():
        for i in range(0, amount_to_generate, 16): # Batches of 16
            batch_size = min(16, amount_to_generate - i) # Adjusts batch size if near the end of generation

            # Sample random noise vectors as generator input
            noise = torch.randn(batch_size, INPUT_DIMENSION, 1, 1, device = DEVICE)

            # Generate a batch of fake images from the noise
            fake = dcgan_generator(noise)

            # Save each generated image to the output directory
            for j in range(batch_size):
                save_image(
                    fake[j], # Single image tensor
                    os.path.join(class_output_directory, f"gen_{i + j}.png"),
                    normalize = True
                )

In [None]:
GAN_OUTPUT_DIRECTORY_BALANCED = "../model2/balanced"

# Target count based on the dominant Sigatoka class
TARGET_COUNT = 424

# Dictionary of class names and their real image counts
real_image_counts = {
    "cordana"        : 145,
    "healthy"        : 115,
    "pestalotiopsis" : 155,
}

# Dictionary mapping class labels to their corresponding trained generators
trained_generators = {
    "cordana"        : trained_generator_cordana,
    "healthy"        : trained_generator_healthy,
    "pestalotiopsis" : trained_generator_pestalotiopsis,
}

# Generate synthetic images for each underrepresented class
for label, real_count in real_image_counts.items():
    amount_to_generate = TARGET_COUNT - real_count

    generator = trained_generators[label]

    generate_synthetic_images(
        dcgan_generator = generator, amount_to_generate = amount_to_generate, class_label = label, output_directory = GAN_OUTPUT_DIRECTORY_BALANCED
    )