In [None]:
import copy
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, models

from resotrmer import Restormer_Denoise

# Fix every random seed so that each execution of the notebook produces the same results.
seed = 0
np.random.seed(seed)          # NumPy global RNG
torch.manual_seed(seed)       # PyTorch RNG
torch.cuda.manual_seed(seed)  # RNG of the current CUDA device
# cuDNN flags set for reproducible runs (see the PyTorch reproducibility notes).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Global variables

In [None]:
# Side length (in pixels) every image is resized to by build_transform.
IMG_SIZE = 224
# Root directory that is walked recursively to collect the image files.
FOLDER_PATH = "Images/100"

### Denoising Strategy

In [None]:
# Single Restormer wrapper instance whose denoise_image method is registered below.
# NOTE(review): "blind" presumably selects the blind-denoising variant — confirm
# against the Restormer_Denoise implementation.
restomer = Restormer_Denoise("blind")

def denoise_none(img):
    """Identity denoiser: return the image unchanged (the no-denoising baseline)."""
    return img

# Registry mapping a method name to a callable that takes an image and returns
# the denoised image. The keys also name the datasets and models built later.
denoise_methods = {
    "None": denoise_none,
    "Restormer": restomer.denoise_image,
    # "Gaussian_Blur": g_blur,
    # "Median_Blur": m_blur,
    # "DnCNN": ...
}

def build_transform(denoise_method: str) -> transforms.Compose:
    """Build the preprocessing pipeline for one denoising strategy.

    Args:
        denoise_method: Key into ``denoise_methods``. A name that is not
            registered falls back to ``denoise_none`` (no denoising).

    Returns:
        A ``transforms.Compose`` that denoises the image, resizes it to
        ``IMG_SIZE`` x ``IMG_SIZE``, converts it to a tensor and normalises
        each channel.
    """
    if denoise_method in denoise_methods:
        denoise_fn = denoise_methods[denoise_method]
    else:
        denoise_fn = denoise_none

    # Per-channel statistics; presumably the ImageNet values expected by the
    # pretrained backbone — TODO confirm.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Lambda(denoise_fn),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

### Define Image Dataset structure and image transform

In [None]:
class ImageDataSet(Dataset):
    """Dataset of image files grouped by class.

    Args:
        image_names: Sequence of per-class lists of file paths. The position
            of a list in the sequence is used as the integer label of every
            file it contains.
        transform: Callable applied to each loaded RGB image.
    """

    def __init__(self, image_names, transform):
        self.transform = transform
        # Flatten the per-class groups, keeping the class order.
        self.file_names = [path for group in image_names for path in group]
        # One label per file, equal to the index of the group it came from.
        self.labels = [
            class_id
            for class_id, group in enumerate(image_names)
            for _ in range(len(group))
        ]

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(transformed_image, label)``."""
        rgb_image = Image.open(self.file_names[index]).convert('RGB')
        return self.transform(rgb_image), self.labels[index]

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.file_names)

# labels = ["B_Cells", "CD4+_T_Cells", "DCIS_1", "DCIS_2", "Invasive_Tumor", "Prolif_Invasive_Tumor"]
labels = ["Immune_Cells", "Non_Invasive_Tumor", "Invasive_Tumor_Set"]
le = LabelEncoder()
numeric_labels = le.fit_transform(labels)

# image_names[i] collects the file paths whose encoded class label is i.
image_names = [[] for _ in numeric_labels]

for (dir_path, dir_names, file_names) in os.walk(FOLDER_PATH):
    # os.walk yields entries in filesystem order, which is not guaranteed to be
    # stable; sorting keeps the dataset order (and therefore the seeded
    # train/val/test split) identical across machines and runs.
    dir_names.sort()

    parent_folder = os.path.basename(dir_path)
    if parent_folder not in labels: # Read the subset of dataset to reduce training time
        continue

    # The label only depends on the folder, so encode it once per folder.
    numeric_label = le.transform([parent_folder])[0]

    for file in sorted(file_names):
        file_path = os.path.join(dir_path, file)
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread returns None for files it cannot decode
            # (e.g. hidden or non-image files); skip them instead of crashing.
            continue
        if image.shape[0] < 100 and image.shape[1] < 100: #skip the small image, it doesn't give much info
            continue
        image_names[numeric_label].append(file_path)


# One dataset per denoising strategy; all share the same file list and differ
# only in the transform applied when an item is loaded.
denoising_datasets = {key : ImageDataSet(image_names, build_transform(key)) for key in denoise_methods.keys()}

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained ResNet-18 whose final fully connected layer is replaced by one
# with a single output per image class.
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(image_names))
model = model.to(device)

# NOTE(review): the same network object is stored under every key, so the
# entries are aliases of one model rather than independent networks.
denoising_models = dict.fromkeys(denoise_methods, model)

In [None]:
# Hyper-parameter settings
num_epochs = 100       # maximum number of epochs per training run
patience = 10          # early stopping: epochs without validation-loss improvement before stopping
batch_size = 32        # batch size used by every DataLoader
learning_rate = 0.001  # Adam learning rate

# Loss function shared by training, validation and evaluation.
criterion = nn.CrossEntropyLoss()
# Adam optimizer bound to the parameters of `model` as it exists when this cell runs.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def split_dataset(dataset: ImageDataSet) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Split a dataset 70/15/15 and wrap each part in a DataLoader.

    The split uses a fixed ``random_state`` so repeated calls on datasets of
    the same length return the same partition; the training and evaluation
    cells rely on this to see matching train/validation/test subsets.

    Args:
        dataset: Dataset to split.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Each loader uses the
        global ``batch_size`` and iterates its subset in a fixed order.
    """
    all_indices = list(range(len(dataset)))
    # 70% train, then the remaining 30% is halved into validation and test.
    train_idx, temp_idx = train_test_split(all_indices, test_size=0.3, random_state=0)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=0)

    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)
    test_subset = Subset(dataset, test_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

### Model training

In [None]:
# Validation-loss history per denoising method, one value per completed epoch.
loss_curves = {key : [] for key in denoise_methods.keys()}

# Train a private copy of every registered network. Entries of
# `denoising_models` may refer to one and the same object; copying up front
# guarantees that each method starts from the same untrained weights and that
# training one method cannot modify another.
untrained_models = {key: copy.deepcopy(net) for key, net in denoising_models.items()}

for denoise, model in untrained_models.items():
    print(f"Start training {denoise} model.")
    train_loader, val_loader, _ = split_dataset(denoising_datasets.get(denoise))

    # Per-method state: a fresh optimizer (no Adam moments carried over from a
    # previous method) and fresh early-stopping bookkeeping.
    model_optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_val_loss = float('inf')
    epoch_no_improvement = 0
    best_model_parameters = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0

        # `targets` (not `labels`) so the module-level list of class names is
        # not overwritten by a tensor.
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)

            model_optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            model_optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        avg_val_loss = val_loss/len(val_loader)
        val_accuracy = 100 * correct / total
        print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
        loss_curves[denoise].append(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors;
            # deep-copy so later optimizer steps cannot alter the snapshot.
            best_model_parameters = copy.deepcopy(model.state_dict())
            epoch_no_improvement = 0
        else:
            epoch_no_improvement += 1
            if epoch_no_improvement == patience:
                print(f"No improvement for {patience} epoches. Early stopping.")
                break

    # Restore the weights of the epoch with the lowest validation loss.
    if best_model_parameters is not None:
        model.load_state_dict(best_model_parameters)

    # Publish the trained network so the saving and evaluation cells use it.
    denoising_models[denoise] = model

In [None]:
# Write one checkpoint per denoising method. The output directory is created
# on demand so the cell also works on a fresh checkout.
os.makedirs('denoise_models', exist_ok=True)

# .items() is required: iterating the dict directly yields only the keys.
for denoise, model in denoising_models.items():
    torch.save(model.state_dict(), f'denoise_models/{denoise}.pth')

### Evaluation

In [None]:
# Per-method results, pre-filled with sentinels until the method is evaluated.
accuracies = {key : -1 for key in denoise_methods.keys()}
f1_scores = {key : -1 for key in denoise_methods.keys()}
confusion_metrics = {key : None for key in denoise_methods.keys()}

# Fixed label set so every confusion matrix has one row/column per class,
# even when a class happens to be missing from the test subset.
class_ids = list(range(len(le.classes_)))

for denoise, model in denoising_models.items():
    _, _, test_loader = split_dataset(denoising_datasets.get(denoise))
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        # `targets` (not `labels`) so the module-level list of class names is
        # not overwritten by a tensor.
        for images, targets in test_loader:
            images = images.to(device)
            predicted = model(images).argmax(dim=1)

            y_true.extend(targets.cpu().tolist())
            y_pred.extend(predicted.cpu().tolist())

    accuracy = accuracy_score(y_true, y_pred)
    # The default average="binary" raises for more than two classes; the macro
    # average weights every class equally.
    f1 = f1_score(y_true, y_pred, average="macro")
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    accuracies[denoise] = accuracy
    f1_scores[denoise] = f1
    # Store under the method name instead of replacing the whole dict.
    confusion_metrics[denoise] = cm

    print(f"{denoise} accuracy: {accuracy}")
    print(f"{denoise} f1 score: {f1}")

In [None]:
# One panel per evaluated denoising method. squeeze=False keeps `axes`
# two-dimensional even when there is only a single method.
fig, axes = plt.subplots(1, len(confusion_metrics), squeeze=False)

# .items() is required to get (name, matrix) pairs; each matrix is drawn on
# its own axis rather than all on the first one.
for ax, (denoise, cm) in zip(axes[0], confusion_metrics.items()):
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_).plot(ax=ax)
    ax.set_title(f"{denoise} CNN Confusion Matrix")

plt.tight_layout()
plt.show()