# *Importing the necessary libraries*

In [1]:
# Core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# For loading and transforming data
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Metrics
from sklearn.metrics import accuracy_score, classification_report, f1_score, recall_score, precision_score

# Additional utilities
from torch.optim import Adam
from torch.nn import Conv2d, ConvTranspose2d, LeakyReLU, BatchNorm2d
from torchvision.utils import save_image

# *Creating the dataset for training the GAN*

In [2]:
import os
import shutil

# Base directory where the original folders are located
base_dir = '/kaggle/input/dataset/Classifier Data'

# New directory where the combined images will be located
new_base_dir = '/kaggle/working/images/'

# Create new directories if they don't exist
os.makedirs(os.path.join(new_base_dir, 'Diseased'), exist_ok=True)
os.makedirs(os.path.join(new_base_dir, 'No_Disease'), exist_ok=True)

# Source splits and the class folders found inside each split
categories = ['train', 'val']
diseases = ['Disease_Present', 'No_Disease']

# Destination paths written during THIS run. Name collisions are detected
# against this set rather than against the filesystem, so re-running the cell
# rewrites the same files instead of sending every file to a '*_duplicate'
# copy (which would overwrite the real duplicates and double the dataset).
written_paths = set()

# Merge the train/val splits into one folder per class (files are copied, not moved)
for cat in categories:
    for disease in diseases:
        # Directory where the current images are located
        old_dir = os.path.join(base_dir, cat, disease)

        # Directory the images are copied into
        new_dir_name = 'Diseased' if disease == 'Disease_Present' else 'No_Disease'
        new_dir = os.path.join(new_base_dir, new_dir_name)

        # sorted() makes the copy order (and therefore which file receives the
        # '_duplicate' suffix) independent of the filesystem's listing order
        for filename in sorted(os.listdir(old_dir)):
            old_file = os.path.join(old_dir, filename)
            if not os.path.isfile(old_file):
                # shutil.copy cannot copy directories; skip anything that is not a file
                continue

            new_file = os.path.join(new_dir, filename)
            if new_file in written_paths:
                # Same name already copied from another split: append an
                # identifier before the extension, adding a counter if that
                # name is taken as well
                base, extension = os.path.splitext(new_file)
                new_file = base + '_duplicate' + extension
                counter = 2
                while new_file in written_paths:
                    new_file = f'{base}_duplicate{counter}{extension}'
                    counter += 1

            shutil.copy(old_file, new_file)
            written_paths.add(new_file)


# **Generator Network**

In [3]:
# Creating the generator architecture
class Generator(nn.Module):
    """Conditional generator: maps a noise vector plus a class label to an image.

    The label is embedded into an ``n_classes``-dimensional vector, concatenated
    with the noise along the channel axis, and upsampled by a stack of
    transposed convolutions. The first layer turns the 1x1 input into a 4x4 map
    and each of the six following stride-2 layers doubles the spatial size, so
    the stack itself always emits 256x256 maps. The final ``Tanh`` bounds the
    output to [-1, 1].

    Args:
        z_dim: Length of the noise vector.
        img_shape: ``(channels, height, width)`` of the images to produce;
            ``img_shape[0]`` sets the number of output channels and the whole
            tuple is used for the final ``view``.
        n_classes: Number of distinct labels (also the embedding width).
    """

    def __init__(self, z_dim, img_shape, n_classes):
        super().__init__()
        self.img_shape = img_shape
        self.label_embedding = nn.Embedding(n_classes, n_classes)

        # Channel widths, from the concatenated (noise + label) input to the
        # last hidden feature map.
        widths = [z_dim + n_classes, 512, 256, 128, 64, 32, 16]

        layers = []
        for position, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the very first layer uses stride 1 / no padding (1x1 -> 4x4);
            # every later one uses stride 2 / padding 1 (doubles H and W).
            stride, padding = (1, 0) if position == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # Output layer: project to image channels and squash into [-1, 1].
        layers += [
            nn.ConvTranspose2d(widths[-1], self.img_shape[0], 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: Noise of shape ``(batch, z_dim)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor reshaped to ``(-1, *img_shape)``.
        """
        # (batch, n_classes) -> (batch, n_classes, 1, 1) so it can be stacked
        # with the noise as extra channels of a 1x1 "image".
        condition = self.label_embedding(labels)[:, :, None, None]
        noise = z[:, :, None, None]

        generator_input = torch.cat([noise, condition], dim=1)
        return self.model(generator_input).view(-1, *self.img_shape)


# **Discriminator Network**

In [4]:
# Creating the discriminator architecture
class Discriminator(nn.Module):
    """Conditional critic: scores an (image, label) pair with one unbounded value.

    The label is embedded, tiled over the image plane, and appended to the
    image as ``n_classes`` extra channels. Six stride-2 convolutions each halve
    the spatial size; a final 4x4 convolution with no padding then reduces the
    map to a single value per sample (for 256x256 inputs the map is 4x4 at that
    point). There is no sigmoid, so the output is a raw score.

    NOTE(review): the stack uses BatchNorm2d while the notebook trains it with
    a per-sample gradient penalty (``compute_gradient_penalty``). The WGAN-GP
    paper advises against batch normalisation in the critic; confirm this is
    intentional.

    Args:
        img_shape: ``(channels, height, width)``; only ``img_shape[0]`` is used.
        n_classes: Number of distinct labels (also the embedding width).
    """

    def __init__(self, img_shape, n_classes):
        super().__init__()
        image_channels = img_shape[0]

        self.label_embedding = nn.Embedding(n_classes, n_classes)

        widths = [64, 128, 256, 512, 1024, 2048]

        # First block has no normalisation layer.
        layers = [
            nn.Conv2d(image_channels + n_classes, widths[0], 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling blocks: conv -> batch norm -> leaky ReLU.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend((
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        # Collapse the remaining map to one score and flatten to (batch, 1).
        layers.extend((
            nn.Conv2d(widths[-1], 1, 4, stride=1, padding=0, bias=False),
            nn.Flatten(),
        ))
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score a batch of images conditioned on their labels.

        Args:
            img: Images of shape ``(batch, channels, H, W)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Flattened output of the convolutional stack, one row per sample.
        """
        embedded = self.label_embedding(labels)
        height, width = img.shape[2], img.shape[3]
        # (batch, n_classes) -> (batch, n_classes, H, W): one constant plane
        # per embedding component.
        label_planes = embedded.view(-1, embedded.size(1), 1, 1).repeat(1, 1, height, width)

        return self.model(torch.cat((img, label_planes), 1))


# **Setting of hyperparameters and initialising the networks**

In [7]:
import torch.optim as optim

# ---- Model / data dimensions ----
z_dim = 100        # Length of the generator's noise vector
img_size = 256     # Images are img_size x img_size
img_channels = 3   # RGB images
n_classes = 2      # Diseased or not

# ---- Optimisation settings ----
lr_d = 2e-6        # Discriminator: lower learning rate for better stability
lr_g = 2e-4        # Generator
batch_size = 32
epochs = 100

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ---- Networks ----
img_shape = (img_channels, img_size, img_size)
generator = Generator(z_dim, img_shape, n_classes).to(device)
discriminator = Discriminator(img_shape, n_classes).to(device)

# ---- Optimizers (Adam with beta1 = 0.0, beta2 = 0.9 for both networks) ----
optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.0, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.0, 0.9))

# Binary cross-entropy criterion. The training loop in this notebook builds its
# losses from mean critic scores and does not call this object.
adversarial_loss = nn.BCELoss()

# **Image Transformations**

In [8]:
# Image transformations
from torchvision import datasets, transforms
transform = transforms.Compose([
    # Pass (height, width) explicitly: with a single int, Resize only scales
    # the shorter edge and keeps the aspect ratio, so non-square source images
    # would come out with differing sizes and could not be batched or fed to
    # the fixed 256x256 network geometry.
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    # Per-channel (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the
    # generator's Tanh output.
    transforms.Normalize([0.5] * img_channels, [0.5] * img_channels),
])

# Data loader over the combined image folder.
# ImageFolder assigns class indices by sorted sub-folder name, so with the
# folders created earlier in this notebook: 'Diseased' -> 0, 'No_Disease' -> 1.
dataloader = DataLoader(
    datasets.ImageFolder('/kaggle/working/images/', transform=transform),
    batch_size=batch_size,
    shuffle=True,
)


# **Training Loop**

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from PIL import Image
import os

# Function to display and save generated images
def show_and_save_generated(imgs, labels, epoch, generated_images_dir, num_images=10,
                            class_names=('Diseased', 'No_Disease')):
    """Display a grid of generated images and save each image under its class folder.

    Args:
        imgs: Batch of generated images scaled to [-1, 1], shape (N, C, H, W).
        labels: Integer class index for each image in ``imgs``.
        epoch: Epoch number, used in the saved file names.
        generated_images_dir: Root directory; one sub-folder per class name is
            created inside it.
        num_images: Maximum number of images shown in the grid (all images are
            saved regardless).
        class_names: Folder name for each class index. The default follows the
            index order ImageFolder gives the training data (sorted folder
            names: 'Diseased' = 0, 'No_Disease' = 1), so a sample conditioned
            on label k is saved under the class the critic saw as label k.
    """
    # Rescale from [-1, 1] to [0, 1]; the clamp only guards against values
    # marginally outside the range.
    imgs = ((imgs + 1) / 2).clamp(0, 1)
    grid = make_grid(imgs[:num_images], nrow=5).detach().cpu().numpy()
    grid = np.transpose(grid, (1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow

    # Display the grid of images
    fig = plt.figure(figsize=(10, 5))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    plt.close(fig)  # Release the figure; this function is called repeatedly during training

    # Create the class folders once rather than once per image
    for class_name in class_names:
        os.makedirs(os.path.join(generated_images_dir, class_name), exist_ok=True)

    # Save individual images
    for i, img in enumerate(imgs):
        class_label = class_names[int(labels[i].item())]
        image_path = os.path.join(generated_images_dir, class_label, f'epoch_{epoch}_image_{i}.png')
        save_image(img, image_path)

def compute_gradient_penalty(D, real_samples, fake_samples, labels):
    """Gradient penalty term of WGAN-GP.

    Evaluates ``D`` at random per-sample blends of real and fake inputs and
    penalises the squared deviation of the input-gradient L2 norm from 1.

    Args:
        D: Callable ``D(samples, labels)`` returning one score per sample.
        real_samples: Batch of real inputs.
        fake_samples: Batch of generated inputs, combined element-wise with
            ``real_samples``.
        labels: Labels passed to ``D`` together with the blended samples.

    Returns:
        Scalar tensor, the mean over the batch of ``(||grad||_2 - 1) ** 2``.
        It is built with ``create_graph=True`` so it can be back-propagated.
    """
    batch = real_samples.size(0)

    # One uniform mixing weight per sample, shared by all of that sample's elements
    mix = torch.rand(batch, 1, 1, 1, device=real_samples.device).expand_as(real_samples)

    # Blend of real and fake samples; must track gradients w.r.t. itself
    blended = mix * real_samples + ((1 - mix) * fake_samples)
    blended = blended.requires_grad_(True)

    scores = D(blended, labels)

    # d(scores)/d(blended); the all-ones grad_outputs weights every score equally
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores, device=real_samples.device),
        create_graph=True,
        retain_graph=True,
    )

    # One L2 norm per sample over all of its elements
    per_sample_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


# Directory for periodic model checkpoints
checkpoint_dir = '/kaggle/working/checkpoints_wgans/'
os.makedirs(checkpoint_dir, exist_ok=True)

# Directory for the sample images written by show_and_save_generated
generated_images_dir = '/kaggle/working/generated_images_wgans/'
os.makedirs(generated_images_dir, exist_ok=True)


# The critic is updated on every batch; the generator only on every
# `critic_iterations`-th batch.
critic_iterations = 3
lambda_gp = 8  # Weight of the gradient penalty term in the critic loss

# epoch -> {'generator_loss': mean over that epoch's generator updates,
#           'discriminator_loss': mean over that epoch's batches}
epoch_losses = {}

# Training loop
for epoch in range(epochs):
    epoch_g_loss = 0.0  # Sum of generator losses over this epoch's generator updates
    epoch_d_loss = 0.0  # Sum of critic losses over this epoch's batches
    g_updates = 0       # Generator updates performed this epoch
    n_batches = 0       # Batches processed this epoch

    for i, (imgs, labels) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        labels = labels.to(device)
        batch = real_imgs.size(0)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Sample noise and random target labels as generator input
        z = torch.randn(batch, z_dim, device=device)
        gen_labels = torch.randint(0, n_classes, (batch,), device=device)

        # The critic step needs no gradient through the generator, so skip
        # building its graph for this forward pass.
        with torch.no_grad():
            fake_imgs = generator(z, gen_labels)

        # Critic scores for real and generated images
        real_validity = discriminator(real_imgs, labels)
        fake_validity = discriminator(fake_imgs, gen_labels)
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, labels)

        # Critic loss: Wasserstein estimate plus weighted gradient penalty
        d_loss = fake_validity.mean() - real_validity.mean() + lambda_gp * gradient_penalty

        d_loss.backward()
        optimizer_D.step()

        epoch_d_loss += d_loss.item()
        n_batches += 1

        if i % critic_iterations == 0:
            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Re-generate the batch, this time tracking gradients
            gen_imgs = generator(z, gen_labels)
            # Generator tries to raise the critic's score on its samples
            g_loss = -discriminator(gen_imgs, gen_labels).mean()

            g_loss.backward()
            optimizer_G.step()

            epoch_g_loss += g_loss.item()
            g_updates += 1

        # g_loss is the most recent generator loss (it is refreshed only on
        # batches where the generator was updated).
        print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    # Average losses for this epoch. The generator average is taken over the
    # batches on which it was actually updated; max(..., 1) avoids dividing by
    # zero if the dataloader yielded nothing.
    avg_g_loss = epoch_g_loss / max(g_updates, 1)
    avg_d_loss = epoch_d_loss / max(n_batches, 1)

    # Store the epoch number and losses in the dictionary
    epoch_losses[epoch] = {'generator_loss': avg_g_loss, 'discriminator_loss': avg_d_loss}

    # Checkpoint and sample every 10th epoch
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth'))
        torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))
        # Generate and save example images
        with torch.no_grad():
            z_example = torch.randn(10, z_dim).to(device)  # 10 random noise vectors
            gen_labels_example = torch.randint(0, n_classes, (10,)).to(device)  # Random labels
            gen_imgs_example = generator(z_example, gen_labels_example)
            show_and_save_generated(gen_imgs_example, gen_labels_example, epoch, generated_images_dir, num_images=10)

In [10]:
# Leftover shape check. `valid` and `validity` are not assigned anywhere in this
# notebook, so on a fresh kernel the bare prints raise NameError and stop
# "Run All". Look the names up defensively instead.
# NOTE(review): if nothing depends on this check, the cell can be deleted.
def describe_shape(name, namespace):
    """Return ``str(namespace[name].shape)``, or a notice if ``name`` is absent.

    Args:
        name: Variable name to look up.
        namespace: Mapping of names to objects (e.g. ``globals()``).
    """
    obj = namespace.get(name)
    if obj is None:
        return f"{name}: not defined in this session"
    return str(obj.shape)


print(describe_shape('valid', globals()))
print(describe_shape('validity', globals()))

torch.Size([32, 1])
torch.Size([32, 1])


In [None]:
import plotly.graph_objects as go

# Unpack the per-epoch records into parallel series (dict order = insertion order)
epochs_list = [epoch_key for epoch_key in epoch_losses]
generator_losses_list = [epoch_losses[epoch_key]['generator_loss'] for epoch_key in epochs_list]
discriminator_losses_list = [epoch_losses[epoch_key]['discriminator_loss'] for epoch_key in epochs_list]

# One line trace per network, generator first
fig = go.Figure()
for trace_values, trace_name in (
    (generator_losses_list, 'Generator Loss'),
    (discriminator_losses_list, 'Discriminator Loss'),
):
    fig.add_trace(go.Scatter(x=epochs_list, y=trace_values, mode='lines', name=trace_name))

# Titles, axis labels, and legend anchored at the top-left of the plot area
layout_options = {
    'title': 'Generator and Discriminator Losses Over Epochs',
    'xaxis_title': 'Epoch',
    'yaxis_title': 'Loss',
    'legend': {'x': 0, 'y': 1},
}
fig.update_layout(**layout_options)

# Render the figure
fig.show()
