In [1]:
# Core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# For loading and transforming data
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Metrics
from sklearn.metrics import accuracy_score, classification_report, f1_score, recall_score, precision_score

# Additional utilities
from torch.optim import Adam
from torch.nn import Conv2d, ConvTranspose2d, LeakyReLU, BatchNorm2d
from torchvision.utils import save_image

In [2]:
import os
import shutil
import filecmp

# Base directory where the original train/val split folders are located (read-only Kaggle input)
base_dir = '/kaggle/input/cv-project-detector/Classifier Data'

# New directory where train and val images are pooled into one folder per class
new_base_dir = '/kaggle/working/images/'

# Create new directories if they don't exist
os.makedirs(os.path.join(new_base_dir, 'Diseased'), exist_ok=True)
os.makedirs(os.path.join(new_base_dir, 'No_Disease'), exist_ok=True)

# Categories (dataset splits) and diseases (source class-folder names)
categories = ['train', 'val']
diseases = ['Disease_Present', 'No_Disease']


def _copy_without_clobbering(src, dst):
    """Copy ``src`` to ``dst`` without ever overwriting a different file that uses that name.

    If ``dst`` -- or one of its ``_duplicate`` variants -- already holds a byte-identical
    copy of ``src``, nothing is copied. This keeps the cell idempotent: re-running it no
    longer adds a second copy of every image. A name clash with *different* content is
    resolved as ``<name>_duplicate<ext>``, then ``<name>_duplicate_2<ext>``, ``_3``, ...

    Args:
        src: path of the file to copy.
        dst: preferred destination path.

    Returns:
        The path that now holds the contents of ``src``.
    """
    stem, extension = os.path.splitext(dst)
    candidate = dst
    clash_count = 0
    while os.path.exists(candidate):
        if filecmp.cmp(src, candidate, shallow=False):
            return candidate  # identical content already present (e.g. the cell was re-run)
        clash_count += 1
        # The first clash keeps the historical '_duplicate' name; later ones get a counter
        tag = '_duplicate' if clash_count == 1 else f'_duplicate_{clash_count}'
        candidate = stem + tag + extension
    shutil.copy(src, candidate)
    return candidate


# Copy the files
for cat in categories:
    for disease in diseases:
        # Directory where the current images are located
        old_dir = os.path.join(base_dir, cat, disease)

        # Directory where the images are going to be copied to
        new_dir_name = 'Diseased' if disease == 'Disease_Present' else 'No_Disease'
        new_dir = os.path.join(new_base_dir, new_dir_name)

        # sorted(): deterministic order, so the same file always gets the same name
        for filename in sorted(os.listdir(old_dir)):
            old_file = os.path.join(old_dir, filename)
            if not os.path.isfile(old_file):
                continue  # skip sub-directories; shutil.copy cannot copy them
            _copy_without_clobbering(old_file, os.path.join(new_dir, filename))


### Conditional GAN (cGAN) — 256 × 256 images

In [4]:
import torch.nn as nn

class Discriminator(nn.Module):
    """Label-conditioned convolutional discriminator for 256x256 images.

    The class label is embedded into an ``n_classes``-dimensional vector. That vector is
    broadcast over every pixel and appended to the image as extra input channels. Six
    stride-2 4x4 convolutions shrink the spatial size 256 -> 4. A final 4x4 convolution
    without padding reduces it to one score per sample, which goes through a sigmoid.

    Args:
        img_shape: (channels, height, width). Only the channel count is read; the layer
            stack is sized for 256x256 inputs.
        n_classes: number of conditioning classes.
    """

    def __init__(self, img_shape, n_classes):
        super(Discriminator, self).__init__()
        in_channels = img_shape[0]

        # Built before the conv stack; keep this order if seeded initialisation must be reproducible
        self.label_embedding = nn.Embedding(n_classes, n_classes)

        # (nc + n_classes) x 256 x 256 -> 64 x 128 x 128; the first block has no BatchNorm
        stack = [
            nn.Conv2d(in_channels + n_classes, 64, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Each block doubles the channels and halves the resolution:
        # 64 x 128² -> 128 x 64² -> 256 x 32² -> 512 x 16² -> 1024 x 8² -> 2048 x 4²
        widths = [64, 128, 256, 512, 1024, 2048]
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            stack.extend([
                nn.Conv2d(width_in, width_out, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # 2048 x 4 x 4 -> 1 x 1 x 1 -> (batch, 1) probability
        stack.extend([
            nn.Conv2d(2048, 1, 4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Score a batch of (image, label) pairs.

        Args:
            img: image batch with 4 dims, (batch, channels, H, W).
            labels: integer class indices, one per image.

        Returns:
            The output of ``self.model``: shape (batch, 1) with values in (0, 1) for 256x256 input.
        """
        height, width = img.shape[2], img.shape[3]
        # (batch, n_classes) -> (batch, n_classes, H, W): the same label vector at every pixel
        label_planes = self.label_embedding(labels)[:, :, None, None].expand(-1, -1, height, width)
        conditioned = torch.cat((img, label_planes), 1)
        return self.model(conditioned)


In [6]:
# Hyperparameters
import torch.optim as optim

# Fixed seed: makes weight init, data shuffling and noise sampling repeatable across
# Restart & Run All (GPU kernels may still be nondeterministic). Seeds CPU and all CUDA devices.
seed = 42
torch.manual_seed(seed)

z_dim = 100  # length of the latent noise vector fed to the generator
img_size = 256  # the Discriminator's layer stack is built for 256x256 inputs
img_channels = 3  # 3 for RGB images (1 for grayscale)
n_classes = 2  # Diseased or not
lr = 3e-4
batch_size = 64
epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Initialize generator and discriminator
# NOTE(review): `Generator` is not defined in any visible cell (execution count jumps from
# In [4] to In [6]). Restore that cell, or this cell fails with NameError on a fresh kernel.
img_shape = (img_channels, img_size, img_size)
generator = Generator(z_dim=z_dim, img_shape=img_shape, n_classes=n_classes).to(device)
# Presumably redundant, since Module.to() already moves registered sub-modules. Kept in case
# Generator holds label_embedding outside its registered modules -- TODO confirm and remove.
generator.label_embedding = generator.label_embedding.to(device)
discriminator = Discriminator(img_shape=img_shape, n_classes=n_classes).to(device)

# Optimizers (same lr and beta1 = 0.5 for both networks)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))


In [7]:
# Image transformations
from torchvision import datasets, transforms
transform = transforms.Compose([
    # Resize(int) only scales the shorter side. CenterCrop then makes every image exactly
    # img_size x img_size, so non-square images can still be batched and fed to the
    # 256x256 Discriminator. This is a no-op for images that are already square.
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1], broadcast over channels
])

# Data loader over the pooled folder built earlier: one class per sub-folder
dataloader = DataLoader(
    datasets.ImageFolder('/kaggle/working/images/', transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

# Loss function: the Discriminator ends in Sigmoid, so BCELoss is applied to probabilities
adversarial_loss = torch.nn.BCELoss()

In [8]:
# Accessing class names (ImageFolder assigns indices in sorted folder-name order)
class_names = dataloader.dataset.classes
print("Class Names:", class_names)

# The generated-image saving code maps label 0 -> 'Diseased' and 1 -> 'No_Disease', and
# training samples labels from range(n_classes). Fail fast if the dataset disagrees.
assert dataloader.dataset.class_to_idx == {'Diseased': 0, 'No_Disease': 1}, dataloader.dataset.class_to_idx
assert len(class_names) == n_classes, (class_names, n_classes)

Class Names: ['Diseased', 'No_Disease']


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from PIL import Image
import os

# Function to display and save generated images
def show_and_save_generated(imgs, labels, epoch, generated_images_dir, num_images=10,
                            class_names=('Diseased', 'No_Disease'), nrow=5):
    """Display the first ``num_images`` generated images as a grid and save every image.

    Args:
        imgs: batch of generated images. It is assumed to be in [-1, 1], the range the
            rescale below undoes -- TODO confirm against the Generator's output layer.
        labels: integer class index for each image in ``imgs``.
        epoch: epoch number; used only to build the saved file names.
        generated_images_dir: root folder. Image i is written to
            ``<generated_images_dir>/<class name>/epoch_<epoch>_image_<i>.png``.
        num_images: number of images shown in the grid (every image in ``imgs`` is saved).
        class_names: folder name for each class index. The default matches this dataset's
            ImageFolder order (0 -> 'Diseased', 1 -> 'No_Disease').
        nrow: number of images per grid row.
    """
    # [-1, 1] -> [0, 1]. Clamping gives imshow valid RGB floats instead of letting it clip
    # with a warning; save_image clamps anyway, so saved PNGs are unchanged.
    imgs = ((imgs + 1) / 2).clamp(0, 1)
    grid = make_grid(imgs[:num_images], nrow=nrow).detach().cpu().numpy()
    grid = np.transpose(grid, (1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib

    # Display the grid of images
    fig = plt.figure(figsize=(10, 5))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    plt.close(fig)  # this runs every checkpoint epoch; don't accumulate open figures

    # Save individual images into one sub-folder per class
    for i, img in enumerate(imgs):
        class_label = class_names[int(labels[i])]
        class_dir = os.path.join(generated_images_dir, class_label)
        os.makedirs(class_dir, exist_ok=True)  # Create the class directory if it doesn't exist
        image_path = os.path.join(class_dir, f'epoch_{epoch}_image_{i}.png')
        save_image(img, image_path)

# Define a directory to save model checkpoints
checkpoint_dir = '/kaggle/working/checkpoints_cgans/'
os.makedirs(checkpoint_dir, exist_ok=True)

# Define a directory to save generated images
generated_images_dir = '/kaggle/working/generated_images_cgans/'
os.makedirs(generated_images_dir, exist_ok=True)

# Checkpoints and sample grids are written every `checkpoint_every` epochs AND after the
# final epoch, so the fully trained weights are always on disk.
checkpoint_every = 10
num_sample_images = 10

# Per-epoch mean losses keyed by 0-based epoch: {epoch: {'generator_loss': ..., 'discriminator_loss': ...}}
epoch_losses = {}
# Training Loop
for epoch in range(epochs):

    epoch_g_loss = 0.0  # sum of generator batch losses this epoch
    epoch_d_loss = 0.0  # sum of discriminator batch losses this epoch
    for i, (imgs, labels) in enumerate(dataloader):
        batch = imgs.size(0)  # the last batch may be smaller than batch_size

        # Adversarial targets shaped like the Discriminator's (batch, 1) output: 1 = real, 0 = fake
        valid = torch.ones(batch, 1, device=device)
        fake = torch.zeros(batch, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise and random class labels as generator input
        z = torch.randn(batch, z_dim).to(device)
        gen_labels = torch.randint(0, n_classes, (batch,)).to(device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # The generator's loss is low when the discriminator scores its images as real
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Also clears the discriminator gradients that g_loss.backward() accumulated above
        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images; detach() keeps this update from backpropagating into the generator
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    # Calculate average losses for this epoch
    avg_g_loss = epoch_g_loss / len(dataloader)
    avg_d_loss = epoch_d_loss / len(dataloader)

    epoch_losses[epoch] = {'generator_loss': avg_g_loss, 'discriminator_loss': avg_d_loss}

    # Report the epoch means (what epoch_losses stores and the loss plot shows), not one noisy batch
    print(f"[Epoch {epoch}/{epochs}] [avg D loss: {avg_d_loss:.4f}] [avg G loss: {avg_g_loss:.4f}]")

    # Save model checkpoints and display a grid of samples periodically and after the last epoch
    is_last_epoch = epoch == epochs - 1
    if epoch % checkpoint_every == 0 or is_last_epoch:
        torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth'))
        torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))

        # Generate and save example images.
        # NOTE(review): the generator stays in train mode here. If it contains BatchNorm layers,
        # their running stats are updated from this small sample batch; consider
        # generator.eval() / generator.train() around sampling -- confirm against Generator.
        with torch.no_grad():
            z_example = torch.randn(num_sample_images, z_dim).to(device)
            gen_labels_example = torch.randint(0, n_classes, (num_sample_images,)).to(device)
            gen_imgs_example = generator(z_example, gen_labels_example)
            show_and_save_generated(gen_imgs_example, gen_labels_example, epoch, generated_images_dir,
                                    num_images=num_sample_images)


In [10]:
import plotly.graph_objects as go

# Extract generator and discriminator losses.
# Only epochs that actually finished are in epoch_losses, because training may have been
# interrupted. Indexing by range(epochs) would raise KeyError, so use the recorded keys.
recorded_epochs = sorted(epoch_losses)
generator_losses = [epoch_losses[e]['generator_loss'] for e in recorded_epochs]
discriminator_losses = [epoch_losses[e]['discriminator_loss'] for e in recorded_epochs]

# Create Plotly figure with one line per network
fig = go.Figure()
fig.add_trace(go.Scatter(x=recorded_epochs, y=generator_losses, mode='lines', name='Generator Loss'))
fig.add_trace(go.Scatter(x=recorded_epochs, y=discriminator_losses, mode='lines', name='Discriminator Loss'))

# Update layout
fig.update_layout(title='Generator and Discriminator Losses vs Epoch',
                  xaxis_title='Epoch',
                  yaxis_title='Loss')

# Show plot
fig.show()
