## CNN-Based Classification of Neutrophils in Microscopy Images (ResNet34)

In [None]:
# Import libraries
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils import shuffle
from sklearn.utils.class_weight import compute_class_weight
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.models import ResNet34_Weights
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths to directories
healthy_dir = '/Users/irdynaumaira/Downloads/CompSci_MSc-project/Images/Healthy'
nets_dir = '/Users/irdynaumaira/Downloads/CompSci_MSc-project/Images/NETs'
class_names = ['Healthy', 'NETs']
image_size = (224, 224)

# CLAHE (Contrast Limited Adaptive Histogram Equalization)
def apply_clahe(image):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(image)

# Load and preprocess data with CLAHE
def load_data():
    images, labels = [], []
    for folder, label in zip([healthy_dir, nets_dir], [0, 1]):  # 0: Healthy, 1: NETs
        for file in tqdm(os.listdir(folder)):
            img_path = os.path.join(folder, file)
            image = cv2.imread(img_path)
            if image is None:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, image_size)
            gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            enhanced_image = apply_clahe(gray_image)
            images.append(enhanced_image)
            labels.append(label)
    images = np.array(images, dtype='float32') / 255.0  # Normalize
    labels = np.array(labels, dtype='int64')
    return images, labels

# Load the data
# Runs once at module level; the per-seed experiment loop further down
# re-splits these same in-memory arrays for every seed.
images, labels = load_data()
print("Data is successfully loaded.")

# Dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images, self.labels = images, labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert grayscale to RGB
        image = cv2.cvtColor(self.images[idx], cv2.COLOR_GRAY2RGB)
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)
        return image, label

# Data augmentation and transformation
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(108),
    transforms.RandomResizedCrop(image_size, scale=(0.7, 1.3)),
    transforms.ColorJitter(contrast=(0.7, 1.3)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # Standard normalization for ImageNet
])

# Function to set random seeds for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Function to load pre-trained ResNet34 model and modify for binary classification
def load_resnet34():
    weights = ResNet34_Weights.IMAGENET1K_V1
    model = models.resnet34(weights=weights)
    

    # Modify the final fully connected layer for binary classification
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Freeze all layers except the final fully connected layer
    for name, param in model.named_parameters():
        if "fc" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    return model.to(device)

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100):
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for images_batch, labels_batch in train_loader:
            images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)
            optimizer.zero_grad()
            outputs = model(images_batch)
            loss = criterion(outputs, labels_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels_batch.size(0)
            correct += (predicted == labels_batch).sum().item()

        train_losses.append(running_loss / len(train_loader))
        train_accuracies.append(100 * correct / total)

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images_batch, labels_batch in val_loader:
                images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)
                outputs = model(images_batch)
                loss = criterion(outputs, labels_batch)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels_batch.size(0)
                val_correct += (predicted == labels_batch).sum().item()

        val_losses.append(val_loss / len(val_loader))
        val_accuracies.append(100 * val_correct / val_total)

        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_losses[-1]:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.2f}%, "
              f"Val Acc: {val_accuracies[-1]:.2f}%")

    return train_losses, val_losses, train_accuracies, val_accuracies

# Seeds for experiments
seeds = [42, 256, 282, 382, 404]

# Initialize metrics lists
all_train_accuracies = []
all_val_accuracies = []
all_test_accuracies = []
all_test_precisions = []
all_test_recalls = []
all_test_f1s = []

for seed in seeds:
    print(f"\nRunning with seed: {seed}")
    set_seed(seed)

    # Stratified splitting
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, test_size=0.2, random_state=seed, stratify=labels)
    val_images, test_images, val_labels, test_labels = train_test_split(
        test_images, test_labels, test_size=0.5, random_state=seed, stratify=test_labels)

    # Create datasets and dataloaders
    train_dataset = CustomDataset(train_images, train_labels, transform)
    val_dataset = CustomDataset(val_images, val_labels, transform)
    test_dataset = CustomDataset(test_images, test_labels, transform)

    train_loader = DataLoader(train_dataset, batch_size=30, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Load ResNet34 model
    model = models.resnet34(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)

    # Freeze all layers except the final fully connected layer
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True

    # Send model to device
    model = model.to(device)

    # Compute class weights
    class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    # Criterion and optimizer
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.0015, weight_decay=0.0001)

    # Train the model
    train_losses, val_losses, train_accuracies, val_accuracies = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs=100)

    # Test evaluation
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images_batch, labels_batch in test_loader:
            images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)
            outputs = model(images_batch)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels_batch.cpu().numpy())

    # Accuracy and metrics
    test_accuracy = 100 * np.sum(np.array(all_preds) == np.array(all_targets)) / len(all_targets)
    all_test_accuracies.append(test_accuracy)

    report = classification_report(all_targets, all_preds, target_names=class_names, output_dict=True)
    all_test_precisions.append(report['weighted avg']['precision'])
    all_test_recalls.append(report['weighted avg']['recall'])
    all_test_f1s.append(report['weighted avg']['f1-score'])

    print(f"Test Accuracy for seed {seed}: {test_accuracy:.2f}%")

    # Confusion Matrix
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(6, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix for Seed {seed}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
    
    # Plot training and validation accuracy
    epochs_range = range(1, 101)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, train_accuracies, label='Train Accuracy')
    plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
    plt.title(f'Accuracy for Seed {seed}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    # Plot training and validation loss
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.plot(epochs_range, val_losses, label='Validation Loss')
    plt.title(f'Loss for Seed {seed}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

# Calculate and print average metrics
def calculate_mean_and_std(values):
    mean = np.mean(values)
    std = np.std(values)
    return mean, std

metrics = {
    "Test Accuracy (%)": all_test_accuracies,
    "Precision": all_test_precisions,
    "Recall": all_test_recalls,
    "F1-Score": all_test_f1s,
}

for metric_name, values in metrics.items():
    mean, std = calculate_mean_and_std(values)
    print(f'Average {metric_name}: {mean:.2f} ± {std:.2f}')

# Plot box plot for Test Accuracy across different seeds 
plt.figure(figsize=(8, 6))
plt.boxplot(all_test_accuracies, tick_labels=['Test Accuracy'], showmeans=True)
plt.title('Test Accuracy Distribution Across Different Seeds')
plt.ylabel('Accuracy (%)')
plt.grid(True)
plt.show()

# Plot box plots for Precision, Recall, F1-Score across Different Seeds
plt.figure(figsize=(8, 6))
metrics_data = [all_test_precisions, all_test_recalls, all_test_f1s]
plt.boxplot(metrics_data, tick_labels=['Precision', 'Recall', 'F1-Score'], showmeans=True)
plt.title('Performance Metrics Distribution Across Different Seeds')
plt.ylabel('Score')
plt.grid(True)
plt.show()