## Importing Needed Libraries

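<p align="justify">This cell imports PyTorch and the torchvision utilities used throughout the notebook. <code>DataLoader</code>, <code>datasets</code> and <code>transforms</code> are used for preprocessing (Part 1). <code>torch.nn</code> and <code>torch.optim</code> are used to build and train the DCGAN (Part 2). The linked tutorials are the references consulted for the DCGAN architecture and training loop.</p>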

In [47]:
# Part 1: Preprocessing

import torch

from torch.utils.data import DataLoader

from torchvision import datasets, transforms

# Part 2: Creating DCGAN Generator and Discriminator

import torch.nn as nn

import torch.optim as optim

# https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://pyimagesearch.com/2021/10/25/training-a-dcgan-in-pytorch/
# https://medium.com/@manoharmanok/implementing-dcgan-in-pytorch-using-the-celeba-dataset-a-comprehensive-guide-660e6e8e29d2


## Part 1: Data Preprocessing

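<p align="justify">The raw images are read from <code>RAW_DATA_DIR</code>. This directory must contain one sub-folder per class, which is the layout <code>torchvision.datasets.ImageFolder</code> expects. Images are resized to 128&times;128 for the DCGAN and to 224&times;224 for the CNN. The four classes are cordana, healthy, pestalotiopsis and sigatoka.</p>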

In [48]:
# Data-loading configuration.

# Root folder read by datasets.ImageFolder: one sub-directory per class.
RAW_DATA_DIR = "../training_data"

# Target (height, width) for images fed to the DCGAN and to the CNN, respectively.
GAN_SIZE = (128,) * 2
CNN_SIZE = (224,) * 2

# Class names present in the banana dataset.
BANANA_CLASSES = "cordana healthy pestalotiopsis sigatoka".split()

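<p align="justify">Two preprocessing pipelines are defined. The DCGAN pipeline resizes to <code>GAN_SIZE</code> and converts the image to a tensor. It then normalizes each channel to [-1, 1], so real images share the range of the generator's <code>Tanh</code> output. The CNN pipeline only resizes to <code>CNN_SIZE</code> and converts to a tensor, which leaves pixel values in [0, 1].</p>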

In [49]:
# Preprocessing pipeline for the DCGAN: fixed-size tensors scaled to [-1, 1].
transform_gan = transforms.Compose(
    [
        transforms.Resize(GAN_SIZE),
        transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.5] * 3, [0.5] * 3),  # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1]
    ]
)

# Preprocessing pipeline for the CNN: resize only, values stay in [0, 1].
transform_cnn = transforms.Compose(
    [
        transforms.Resize(CNN_SIZE),
        transforms.ToTensor(),
    ]
)

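<p align="justify">The helper functions below wrap the image folder in a shuffled <code>DataLoader</code>. One applies the GAN transform and the other applies the CNN transform.</p>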

In [50]:
def load_gan_data(batch_size = 32, workers = 2):
    # Load dataset with GAN transformations
    dataset_gan = datasets.ImageFolder(root=RAW_DATA_DIR, transform = transform_gan)
    
    # Create DataLoader for GAN data
    dataloader_gan = DataLoader(dataset_gan, batch_size=batch_size, shuffle = True, num_workers = workers)

    return dataloader_gan

def load_cnn_data(batch_size = 32, workers = 2):
    # Load dataset with CNN transformations
    dataset_cnn = datasets.ImageFolder(root=RAW_DATA_DIR, transform = transform_cnn)
    
    # Create DataLoader for CNN data
    dataloader_cnn = DataLoader(dataset_cnn, batch_size=batch_size, shuffle = True, num_workers = workers)

    return dataloader_cnn

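<p align="justify">Weights are initialized as in the PyTorch DCGAN tutorial. Convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02. BatchNorm weights are drawn from a normal distribution with mean 1 and standard deviation 0.02, and BatchNorm biases are set to 0.</p>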

In [51]:
def initialize_weights(model):
    """Apply DCGAN-style initialization to a single module, in place.

    Typically passed to ``nn.Module.apply`` so it visits every sub-module.

    - Modules whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d):
      weight ~ N(0, 0.02).
    - Modules whose class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    - Any other module is left untouched.

    Parameters that are absent or None (e.g. ``BatchNorm2d(affine=False)``, or a
    "Conv"-named container without its own weight) are skipped instead of raising.

    Args:
        model: The ``nn.Module`` to initialize.
    """
    classname = model.__class__.__name__
    weight = getattr(model, "weight", None)
    bias = getattr(model, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)

    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

## Part 2: Creating DCGAN Generator and Discriminator

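<p align="justify">The hyperparameters below follow the DCGAN recommendations: Adam with learning rate 0.0002 and &beta;<sub>1</sub> = 0.5. The generator maps a 100-dimensional noise vector to a 128&times;128 RGB image. The discriminator maps a 128&times;128 RGB image to the probability that the image is real.</p>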

In [None]:
# DCGAN hyperparameters and runtime configuration.

BATCH_SIZE = 128        # Number of images per training batch

INPUT_DIMENSION = 100   # Length of the latent noise vector fed to the generator

NC = 3                  # Number of channels in the training images

NGF = 64                # Base number of feature maps in the generator

NDF = 64                # Base number of feature maps in the discriminator

EPOCHS = 100            # Number of training epochs

LEARNING_RATE = 0.0002  # Learning rate for both optimizers

BETA1 = 0.5             # Beta1 value for the Adam optimizer to help stabilize DCGAN training

SEED = 42               # Seed for torch's RNGs so training runs are reproducible

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

# Number of GPUs to use (0 means CPU only). Derived from DEVICE so the two never disagree
# (previously hard-coded to 1 even on CPU-only machines).
NGPU = 1 if DEVICE.type == "cuda" else 0

torch.manual_seed(SEED)


In [53]:
# Generator

class Generator(nn.Module):
    """DCGAN generator: upsamples a latent noise tensor to a 128x128 image.

    Input shape is (batch, INPUT_DIMENSION, 1, 1). Each transposed-conv stage doubles
    the spatial size after the first (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128). The output
    has shape (batch, NC, 128, 128) with values in [-1, 1] from the final Tanh.

    Args:
        ngpu: Number of GPUs requested; stored as ``self.ngpu`` but not otherwise used here.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()

        self.ngpu = ngpu

        # Stack of ConvTranspose2d -> BatchNorm2d -> ReLU blocks, then an RGB output layer.
        self.main = nn.Sequential(
            self._block(INPUT_DIMENSION, NGF * 16, 4, 1, 0, False),  # 1x1 latent -> 4x4 feature map
            self._block(NGF * 16, NGF * 8, 4, 2, 1, False),          # Upsample to 8 x 8
            self._block(NGF * 8, NGF * 4, 4, 2, 1, False),           # Upsample to 16 x 16
            self._block(NGF * 4, NGF * 2, 4, 2, 1, False),           # Upsample to 32 x 32
            self._block(NGF * 2, NGF, 4, 2, 1, False),               # Upsample to 64 x 64

            nn.ConvTranspose2d(NGF, NC, 4, 2, 1, bias = False),      # Final upsample to 128x128 with NC channels
            nn.Tanh()                                                # Output pixel values in [-1, 1]
        )

    def _block(self, i_channels, o_channels, kernel_size, stride, padding, bias):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU block.

        ``stride``, ``padding`` and ``bias`` are passed by keyword. The sixth positional
        parameter of ``nn.ConvTranspose2d`` is ``output_padding``, not ``bias``. Passing
        ``bias`` positionally therefore left the layer's bias enabled.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                i_channels,
                o_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            ),
            nn.BatchNorm2d(o_channels),
            nn.ReLU(True)
        )

    def forward(self, input):
        """Map noise of shape (batch, INPUT_DIMENSION, 1, 1) to images (batch, NC, 128, 128)."""
        return self.main(input)


In [54]:
# Discriminator

class Discriminator(nn.Module):
    """DCGAN discriminator: maps a 128x128 image to the probability that it is real.

    Each strided conv block halves the spatial size (128 -> 64 -> 32 -> 16 -> 8 -> 4). A
    final 4x4 conv reduces this to 1x1. For an input of shape (batch, NC, 128, 128) the
    output has shape (batch, 1, 1, 1) with values in [0, 1] from the Sigmoid.

    Args:
        ngpu: Number of GPUs requested; stored as ``self.ngpu`` but not otherwise used here.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()

        self.ngpu = ngpu

        # Stack of Conv2d -> (BatchNorm2d) -> LeakyReLU blocks, then a 1-channel output layer.
        self.main = nn.Sequential(
            self._block(NC, NDF, 4, 2, 1, False, use_batchNorm2D=False), # First block (to 64x64): no BatchNorm
            self._block(NDF,     NDF * 2, 4, 2, 1, False),               # Downsample to 32x32
            self._block(NDF * 2, NDF * 4, 4, 2, 1, False),               # Downsample to 16x16
            self._block(NDF * 4, NDF * 8, 4, 2, 1, False),               # Downsample to 8x8
            self._block(NDF * 8, NDF * 16, 4, 2, 1, False),              # Downsample to 4x4

            nn.Conv2d(NDF * 16, 1, 4, 1, 0, bias = False),               # Final layer: reduce to 1x1
            nn.Sigmoid()                                                 # Output probability [0, 1] of real vs fake
        )

    def _block(self, i_channels, o_channels, kernel_size, stride, padding, bias, use_batchNorm2D = True):
        """Return a Conv2d -> (optional) BatchNorm2d -> LeakyReLU(0.2) block.

        ``stride``, ``padding`` and ``bias`` are passed by keyword. The sixth positional
        parameter of ``nn.Conv2d`` is ``dilation``. Passing ``bias`` there produced
        ``dilation=False`` (0), which is invalid, and left the bias enabled.
        """
        layers = [nn.Conv2d(
            i_channels,
            o_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )]

        if use_batchNorm2D:
            layers.append(nn.BatchNorm2d(o_channels))

        layers.append(nn.LeakyReLU(0.2, inplace = True))

        return nn.Sequential(*layers)

    def forward(self, input):
        """Map images (batch, NC, 128, 128) to real-probabilities of shape (batch, 1, 1, 1)."""
        return self.main(input)


## Part 3: Train a DCGAN for Each Underrepresented Class (Cordana, Healthy, Pestalotiopsis)

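<p align="justify">A separate DCGAN is trained on the images of each underrepresented class. Each generator therefore learns to synthesize additional images of exactly one class.</p>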

In [None]:
def filter_data_by_class(dataset, class_name):
    """Return every ``(image, label)`` sample of ``dataset`` whose label is ``class_name``.

    If the dataset exposes a ``targets`` list and has no ``target_transform`` (true for a
    plain torchvision ImageFolder), only the matching samples are loaded. This avoids
    decoding and transforming every image of the other classes. Otherwise the whole
    dataset is iterated, as before.

    Args:
        dataset: Dataset exposing ``class_to_idx`` and yielding ``(image, label)`` pairs.
        class_name: Name of the class to keep; must be a key of ``dataset.class_to_idx``.

    Returns:
        List of ``(image, label)`` tuples, in dataset order.

    Raises:
        KeyError: If ``class_name`` is not one of the dataset's classes.
    """
    if class_name not in dataset.class_to_idx:
        raise KeyError(
            f"Unknown class {class_name!r}; expected one of {sorted(dataset.class_to_idx)}"
        )
    class_index = dataset.class_to_idx[class_name]

    targets = getattr(dataset, "targets", None)
    if targets is None or getattr(dataset, "target_transform", None) is not None:
        # No usable label cache: fall back to loading every sample.
        return [(image, label) for image, label in dataset if label == class_index]

    # Use the cached raw labels so only images of the requested class are loaded.
    selected = []
    for index, target in enumerate(targets):
        if target == class_index:
            image, label = dataset[index]
            selected.append((image, label))
    return selected

In [None]:
# to push soon

def train_dcgan_per_class(dataloader_gan, class_name, epochs = EPOCHS):
    """Train a fresh DCGAN only on the images of ``class_name``.

    The class's samples are pulled out of ``dataloader_gan.dataset`` and stacked into
    an in-memory TensorDataset. A Generator/Discriminator pair with DCGAN weight
    initialization is then trained with BCE loss and Adam.

    Args:
        dataloader_gan: DataLoader whose ``dataset`` yields ``(image_tensor, label)`` pairs and
            exposes ``class_to_idx`` (e.g. the one returned by ``load_gan_data``).
        class_name: Class to train on; must be a key of ``dataset.class_to_idx``.
        epochs: Number of passes over the class's images.

    Returns:
        The trained generator (previously nothing was returned, so the model was lost).

    Raises:
        ValueError: If the class contains no images.
    """
    # Filter the dataset for the specific class
    filtered_data = filter_data_by_class(dataloader_gan.dataset, class_name)
    if not filtered_data:
        raise ValueError(f"No images found for class {class_name!r}")

    # TensorDataset requires tensors, not tuples: stack images and labels explicitly.
    images, labels = zip(*filtered_data)
    dataset_class = torch.utils.data.TensorDataset(torch.stack(images), torch.tensor(labels))
    dataloader_class = DataLoader(dataset_class, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize the generator and discriminator for the class
    gen = Generator(ngpu = NGPU).to(DEVICE)
    dis = Discriminator(ngpu = NGPU).to(DEVICE)
    gen.apply(initialize_weights)
    dis.apply(initialize_weights)

    # Optimizers and loss function
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(gen.parameters(), lr = LEARNING_RATE, betas = (BETA1, 0.999))
    optimizer_d = optim.Adam(dis.parameters(), lr = LEARNING_RATE, betas = (BETA1, 0.999))

    # Training loop for the class-specific DCGAN
    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(dataloader_class):
            real_images = real_images.to(DEVICE)
            # The last batch can be smaller than BATCH_SIZE, so size labels/noise per batch.
            batch_size = real_images.size(0)

            # Labels are 1-D to match the flattened (batch,) discriminator output below.
            real_labels = torch.ones(batch_size, device = DEVICE)
            fake_labels = torch.zeros(batch_size, device = DEVICE)

            # ========================
            # Train the Discriminator
            # ========================
            dis.zero_grad()

            # Forward pass with real images; flatten (batch, 1, 1, 1) -> (batch,)
            output_real = dis(real_images).view(-1)
            d_loss_real = criterion(output_real, real_labels)
            d_loss_real.backward()

            # Generate fake images from the generator
            noise = torch.randn(batch_size, INPUT_DIMENSION, 1, 1, device = DEVICE)
            fake_images = gen(noise)

            # Forward pass with fake images; detach so no gradient reaches the generator
            output_fake = dis(fake_images.detach()).view(-1)
            d_loss_fake = criterion(output_fake, fake_labels)
            d_loss_fake.backward()

            # Update discriminator
            d_loss = d_loss_real + d_loss_fake
            optimizer_d.step()

            # ========================
            # Train the Generator
            # ========================
            gen.zero_grad()

            # Re-score the fakes with the updated discriminator
            output_fake = dis(fake_images).view(-1)
            g_loss = criterion(output_fake, real_labels)  # We want the generator to fool the discriminator
            g_loss.backward()

            # Update generator
            optimizer_g.step()

            # Print losses every few iterations
            if i % 50 == 0:
                print(f"Class: {class_name}, Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader_class)}], "
                      f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

        # Save generated images and models periodically (optional)
        # NOTE(review): save_generated_images / save_models are not defined in this section;
        # confirm they are defined before this cell runs.
        if epoch % 10 == 0:
            save_generated_images(epoch, fake_images.detach(), class_name)
            save_models(epoch, gen, dis, class_name)

    return gen
