In [None]:
import glob
import os
import pickle
import random
import shutil
import time
import traceback
from copy import deepcopy

import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, models

from denoise_classical import GaussianBlur, MedianBlur
from models_DnCNN import DnCNN_Denoiser
from resotrmer import Restormer_Denoise

# Turn all the randomisation off so that every execution produces the same results.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)       # seeds the RNG on every device
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Reset GPU memory bookkeeping. Guarded because reset_peak_memory_stats() initialises CUDA
# and raises on machines without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

### Define variables

In [None]:
IMG_SIZE = 224
FOLDER_PATH = "Images/100"

# Learned denoisers are instantiated once; the callables below share these instances.
restomer = Restormer_Denoise("blind")
dncnn = DnCNN_Denoiser()

# Denoising method name -> callable used as the first step of build_transform.
# "None" is the no-denoising baseline (identity).
denoise_methods = dict(
    [
        ("None", lambda img: img),
        ("Restormer", restomer.denoise_image),
        ("Gaussian_Blur", GaussianBlur),
        ("Median_Blur", MedianBlur),
        ("DnCNN", dncnn.denoise_image),
    ]
)

classification_models = ["CNN", "KNN", "SVM", "Random Forest"]

NOT_CALCULATED = -1  # placeholder for a metric that has not been computed yet


def _empty_result_table():
    """Return a new, independent {classifier: {denoising method: NOT_CALCULATED}} table."""
    return {clf_name: dict.fromkeys(denoise_methods, NOT_CALCULATED) for clf_name in classification_models}


models_denoising_accuracies = _empty_result_table()
models_denoising_f1_scores = _empty_result_table()
models_denoising_classification_times = _empty_result_table()
models_denoising_confusion_metrics = _empty_result_table()

def build_transform(denoise_method: str) -> transforms.Compose:
    """Return the preprocessing pipeline for ``denoise_method``.

    Pipeline: denoise -> resize to IMG_SIZE x IMG_SIZE -> tensor -> per-channel normalisation.
    An unknown method name falls back to the identity (no denoising).
    """
    steps = [
        transforms.Lambda(denoise_methods.get(denoise_method, lambda x: x)),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std, matching the ImageNet-pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

### Preprocessing: merge labels (**only** execute the following cell if you haven't merged the folders yet)

In [None]:
def merge_folder(list1, list2, list3, source_path='./Images/100/'):
    """Merge fine-grained class sub-folders into three coarse class folders.

    Every ``*.png`` directly inside each listed sub-folder of ``source_path`` is copied
    (not moved) into one destination folder per argument:

    * ``list1`` (immune cells)        -> ``Immune_Cells``
    * ``list2`` (non-invasive tumour) -> ``Non_Invasive_Tumor``
    * ``list3`` (invasive tumour)     -> ``Invasive_Tumor_Set``

    Destination folders are created when missing. A file with the same name already in
    the destination is overwritten, so re-running the merge is harmless. Note, however,
    that two source sub-folders containing equally named files collapse into one file.

    Args:
        list1, list2, list3: names of sub-folders of ``source_path``.
        source_path: root folder containing the sub-folders; the merged folders are
            created next to them. A trailing separator is optional.

    Returns:
        0 (kept for backward compatibility).
    """
    folder_dests = ["Immune_Cells", "Non_Invasive_Tumor", "Invasive_Tumor_Set"]
    source_categories = [list1, list2, list3]

    # make the (possibly empty) destination folders
    for folder in folder_dests:
        dest_path = os.path.join(source_path, folder)
        os.makedirs(dest_path, exist_ok=True)
        print(f"{dest_path} is created.")

    count = 1
    for dest_folder, source_dirs in zip(folder_dests, source_categories):
        dest_path = os.path.join(source_path, dest_folder)
        for src in source_dirs:
            # Escape the directory part so folder names containing glob metacharacters
            # (*, ?, [) are matched literally.
            pattern = os.path.join(glob.escape(os.path.join(source_path, src)), '*.png')
            print(pattern)
            for img_path in sorted(glob.glob(pattern)):
                file_dest_path = os.path.join(dest_path, os.path.basename(img_path))
                # copy the image from the sub-folder into the merged folder (overwrite if it exists)
                shutil.copy2(img_path, file_dest_path)
                print(f'{count}: Copied {img_path} -> {file_dest_path}')
                count += 1
    return 0

merge_folder(
    ['B_Cells', 'CD4+_T_Cells'],                  # -> Immune_Cells
    ['DCIS_1', 'DCIS_2'],                         # -> Non_Invasive_Tumor
    ['Invasive_Tumor', 'Prolif_Invasive_Tumor'],  # -> Invasive_Tumor_Set
)

### Define the image dataset structure and build one dataset per denoising transform

In [None]:
class ImageDataSet(Dataset):
    """Image-classification dataset built from per-class collections of file paths.

    Args:
        image_names: a sequence whose ``k``-th item is an iterable of the file paths of
            every image with label ``k`` (the position in the outer sequence).
        transform: callable applied to each RGB PIL image; ``None`` (the default)
            returns the PIL image untransformed.
    """

    def __init__(self, image_names, transform=None):
        self.file_names = []
        self.labels = []
        for numeric_label, names in enumerate(image_names):
            names = list(names)  # accept any iterable, not only lists
            self.labels.extend([numeric_label] * len(names))
            self.file_names.extend(names)

        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply the transform, and return ``(image, label)``."""
        # The context manager releases the file handle even if decoding fails.
        with Image.open(self.file_names[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index]

    def __len__(self):
        return len(self.file_names)

# Original fine-grained classes ["B_Cells", "CD4+_T_Cells", "DCIS_1", "DCIS_2", "Invasive_Tumor",
# "Prolif_Invasive_Tumor"] are merged into the three folders below (see merge_folder).
labels = ["Immune_Cells", "Non_Invasive_Tumor", "Invasive_Tumor_Set"]
le = LabelEncoder()
numeric_labels = le.fit_transform(labels)  # LabelEncoder sorts names; indices follow le.classes_
# image_names[k] collects the file paths of class k (k as assigned by `le`).
image_names = [[] for _ in numeric_labels]

MIN_IMAGE_SIDE = 100  # images smaller than this in BOTH dimensions carry too little detail

for dir_path, _dir_names, file_names in os.walk(FOLDER_PATH):
    parent_folder = os.path.basename(dir_path)
    if parent_folder not in labels:  # Read only the merged subset of the dataset to reduce training time
        continue
    numeric_label = le.transform([parent_folder])[0]  # identical for every file in this folder
    # os.walk lists files in filesystem order; sort so the dataset order (and therefore the
    # seeded train/val/test split) is the same on every machine.
    for file in sorted(file_names):
        file_path = os.path.join(dir_path, file)
        image = cv2.imread(file_path)
        if image is None:  # not a readable image (e.g. .DS_Store): skip instead of crashing
            continue
        if image.shape[0] < MIN_IMAGE_SIDE and image.shape[1] < MIN_IMAGE_SIDE:  # skip the small image, it doesn't give much info
            continue
        image_names[numeric_label].append(file_path)

print({class_name: len(names) for class_name, names in zip(le.classes_, image_names)})

denoising_datasets = {key: ImageDataSet(image_names, build_transform(key)) for key in denoise_methods.keys()}

## CNN model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # check if the computer has a GPU

# ImageNet-pretrained ResNet-18 with its classifier head resized to the number of classes.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 are the same weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, len(image_names))
model = model.to(device)

# One independent copy of the initial network per denoising method, so every variant starts
# fine-tuning from identical weights (deepcopy keeps the copies on `device`).
denoising_cnn_models = {key: deepcopy(model) for key in denoise_methods.keys()}

In [None]:
# Hyper-parameter settings
num_epochs = 100       # maximum number of training epochs per model
patience = 10          # early stopping: epochs without validation-loss improvement before stopping
batch_size = 128       # images per mini-batch
learning_rate = 1e-3   # Adam step size

In [None]:
def split_dataset(dataset: ImageDataSet, num_workers: int = 4, shuffle_train: bool = True):
    """Split ``dataset`` 70/15/15 into train / validation / test DataLoaders.

    The index split uses a fixed ``random_state``. Every denoising variant wraps the same
    file list, so each variant sees exactly the same train/val/test images.

    Args:
        dataset: the dataset to split.
        num_workers: DataLoader worker processes. Set 0 if the denoising transform cannot
            run in a subprocess (for instance a CUDA-based denoiser with fork-started
            workers -- TODO confirm for Restormer/DnCNN).
        shuffle_train: reshuffle the training subset every epoch (standard for SGD-style
            training). Validation and test loaders are never shuffled.

    Returns:
        ``(train_loader, val_loader, test_loader)``, all using the global ``batch_size``.
    """
    indices = list(range(len(dataset)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=0)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=0)

    def make_loader(idx, shuffle):
        # Pinned memory only speeds up host->GPU copies; skip it on CPU-only machines.
        return DataLoader(Subset(dataset, idx), batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=torch.cuda.is_available())

    train_loader = make_loader(train_idx, shuffle_train)
    val_loader = make_loader(val_idx, False)
    test_loader = make_loader(test_idx, False)
    return train_loader, val_loader, test_loader

### CNN Model training

In [None]:
training_loss_curves = {key: [] for key in denoise_methods.keys()}
val_loss_curves = {key: [] for key in denoise_methods.keys()}

# Mixed precision is only used on CUDA; on CPU autocast and the grad scaler are disabled.
use_amp = device.type == "cuda"


def _train_one_epoch(model, loader, criterion, optimizer, scaler):
    """Run one (optionally mixed-precision) training epoch; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def _validate(model, loader, criterion):
    """Return ``(mean batch loss, accuracy in %)`` of ``model`` on ``loader``."""
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, targets).item()
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    accuracy = 100 * correct / total if total else float("nan")
    return val_loss / max(len(loader), 1), accuracy


for denoise, model in denoising_cnn_models.items():
    print(f"Start training {denoise} model.")
    best_val_loss = float('inf')
    epoch_no_improvement = 0
    best_model_parameters = None
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_loader, val_loader, _ = split_dataset(denoising_datasets.get(denoise))
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
        for epoch in range(num_epochs):
            training_loss = _train_one_epoch(model, train_loader, criterion, optimizer, scaler)
            training_loss_curves[denoise].append(training_loss)
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {training_loss:.4f}")

            avg_val_loss, val_accuracy = _validate(model, val_loader, criterion)
            print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
            val_loss_curves[denoise].append(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # state_dict() holds references to the live parameter tensors, which later
                # optimizer steps overwrite in place; snapshot a copy to really keep the best epoch.
                best_model_parameters = deepcopy(model.state_dict())
                epoch_no_improvement = 0
            else:
                epoch_no_improvement += 1
                if epoch_no_improvement == patience:
                    print(f"No improvement for {patience} epoches. Early stopping.")
                    break
    except Exception:
        # Best effort: report the failure and continue with the next denoiser's model.
        traceback.print_exc()
    finally:
        # Restore the best checkpoint, even when training was interrupted by an error.
        if best_model_parameters is not None:
            model.load_state_dict(best_model_parameters)

In [None]:
# Persist each fine-tuned CNN; create the output folder first so a fresh checkout works.
os.makedirs("denoised_models", exist_ok=True)
for denoise, model in denoising_cnn_models.items():
    torch.save(model.state_dict(), os.path.join("denoised_models", f"CNN_{denoise}.pth"))

### CNN model evaluation

In [None]:
# Evaluate each fine-tuned CNN on the test split of its own denoised dataset.
for denoise, model in denoising_cnn_models.items():
    _, _, test_loader = split_dataset(denoising_datasets.get(denoise))
    model.eval()
    y_true = []
    y_pred = []

    # The timed region covers data loading (which runs the denoiser inside the dataset
    # transform) as well as the forward passes.
    start = time.time()
    with torch.no_grad():
        for images, targets in test_loader:
            outputs = model(images.to(device))
            y_true.extend(targets.cpu().numpy())
            y_pred.extend(outputs.argmax(dim=1).cpu().numpy())
    end = time.time()

    accuracy = accuracy_score(y_true, y_pred)
    # More than two classes: the default average='binary' raises a ValueError, so use the
    # macro average (every class weighted equally).
    f1 = f1_score(y_true, y_pred, average="macro")
    cm = confusion_matrix(y_true, y_pred)
    elapsed_time = end - start

    models_denoising_accuracies["CNN"][denoise] = accuracy
    models_denoising_f1_scores["CNN"][denoise] = f1
    models_denoising_confusion_metrics["CNN"][denoise] = cm
    models_denoising_classification_times["CNN"][denoise] = elapsed_time

    print(f"{denoise} accuracy: {accuracy}")
    print(f"{denoise} f1 score: {f1}")
    print(f"{denoise} classification time: {elapsed_time}")

### Feature extractor for traditional machine learning methods

In [None]:
def datasets_feature_extractor(model, dataset):
    """Embed the train and test splits of ``dataset`` with ``model`` minus its last layer.

    The last child module (the ``fc`` classifier for the ResNet used in this notebook) is
    dropped. The remaining network's output is flattened to one feature vector per image.

    Returns:
        ``X_train, y_train, X_test, y_test`` as NumPy arrays (2-D features, 1-D labels).
    """
    model.eval()
    feature_extractor = nn.Sequential(*list(model.children())[:-1])  # remove the last layer
    feature_extractor.eval()
    feature_extractor.to(device)

    def extract(loader):
        # Collect (features, labels) batch by batch, keeping both in loader order.
        features, batch_label_arrays = [], []
        with torch.no_grad():
            for images, batch_labels in loader:
                output = feature_extractor(images.to(device))
                # flatten(1) always keeps the batch axis; .squeeze() would also drop it
                # for a final batch of size 1.
                features.append(torch.flatten(output, 1).cpu().numpy())
                batch_label_arrays.append(batch_labels.cpu().numpy())
        return np.vstack(features), np.hstack(batch_label_arrays)

    train_loader, _, test_loader = split_dataset(dataset)
    X_train, y_train = extract(train_loader)
    X_test, y_test = extract(test_loader)

    return X_train, y_train, X_test, y_test

In [None]:
# Train SVM / Random Forest / KNN on the penultimate-layer features of each denoiser's CNN.
# "Classification time" is the time to predict the test set, mirroring how the CNN cell times
# inference (feature extraction and denoising are not included here).

KNN_CANDIDATE_KS = range(1, 32, 2)  # k = 1, 3, ..., 31


def _score_classifier(clf, X_test, y_test):
    """Predict ``X_test`` with the fitted ``clf``.

    Returns:
        ``(accuracy, macro-averaged F1, confusion matrix, prediction time in seconds)``.
    """
    start = time.time()
    y_pred = clf.predict(X_test)
    elapsed_time = time.time() - start
    accuracy = float(accuracy_score(y_test, y_pred))
    # Multi-class labels: the default average='binary' raises, so macro-average over classes.
    f1 = f1_score(y_test, y_pred, average="macro")
    return accuracy, f1, confusion_matrix(y_test, y_pred), elapsed_time


def _record(model_name, denoise_method, accuracy, f1, cm, elapsed_time):
    """Store one classifier's test metrics in the shared result tables."""
    models_denoising_accuracies[model_name][denoise_method] = accuracy
    models_denoising_f1_scores[model_name][denoise_method] = f1
    models_denoising_confusion_metrics[model_name][denoise_method] = cm
    models_denoising_classification_times[model_name][denoise_method] = elapsed_time


def _report(prefix, accuracy, f1, elapsed_time):
    """Print the three headline metrics, each line starting with ``prefix``."""
    print(f"{prefix} accuracy: {accuracy}")
    print(f"{prefix} f1 score: {f1}")
    print(f"{prefix} classification time: {elapsed_time}")


os.makedirs("denoised_models", exist_ok=True)

for denoise_method, dataset in denoising_datasets.items():
    # Use the CNN fine-tuned on this same denoising method as the feature extractor.
    X_train, y_train, X_test, y_test = datasets_feature_extractor(
        denoising_cnn_models[denoise_method], dataset)

    # SVM and Random Forest
    svm = SVC().fit(X_train, y_train)
    rf = RandomForestClassifier(random_state=seed).fit(X_train, y_train)
    for model_name, clf in (("SVM", svm), ("Random Forest", rf)):
        accuracy, f1, cm, elapsed_time = _score_classifier(clf, X_test, y_test)
        _record(model_name, denoise_method, accuracy, f1, cm, elapsed_time)
        _report(f"{model_name} {denoise_method}", accuracy, f1, elapsed_time)

    # KNN: choose k on a validation split carved from the TRAINING features, so the test
    # set is never used for model selection; then refit on all training features.
    X_fit, X_val, y_fit, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=seed)
    candidate_ks = [k for k in KNN_CANDIDATE_KS if k <= len(y_fit)] or [1]
    val_accuracies = [
        accuracy_score(y_val, KNeighborsClassifier(n_neighbors=k).fit(X_fit, y_fit).predict(X_val))
        for k in candidate_ks
    ]
    best_k = candidate_ks[int(np.argmax(val_accuracies))]
    knn = KNeighborsClassifier(n_neighbors=best_k).fit(X_train, y_train)
    accuracy, f1, cm, elapsed_time = _score_classifier(knn, X_test, y_test)
    _record("KNN", denoise_method, accuracy, f1, cm, elapsed_time)
    _report(f"Best {best_k}NN {denoise_method}", accuracy, f1, elapsed_time)

    # NOTE: pickle files execute code when loaded; only reload files you produced yourself.
    fitted_models = (
        (f"SVM_{denoise_method}.pkl", svm),
        (f"RF_{denoise_method}.pkl", rf),
        (f"{best_k}NN_{denoise_method}.pkl", knn),
    )
    for file_name, fitted in fitted_models:
        with open(os.path.join("denoised_models", file_name), "wb") as f:
            pickle.dump(fitted, f)

## Visualisation

### Accuracy, F1 score and classification time

In [None]:
# Test accuracy per classifier, grouped by denoising method. Entries still holding the -1
# "not yet calculated" sentinel become NaN so they are omitted instead of drawn as negative bars.
accuracy_long = (
    pd.DataFrame(models_denoising_accuracies)  # rows: denoising method, columns: model
    .T
    .replace(-1, np.nan)
    .rename_axis('Model')
    .reset_index()
    .melt(id_vars='Model', var_name='Denoising', value_name='Accuracy')
)
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(data=accuracy_long, x='Model', y='Accuracy', hue='Denoising', ax=ax)
ax.set(title='Accuracy by Model and Denoising Methods', ylabel='Accuracy', xlabel='Model')
ax.legend(title='Denoising Methods')
fig.tight_layout()
plt.show()

In [None]:
# Macro F1 score per classifier, grouped by denoising method. Entries still holding the -1
# "not yet calculated" sentinel become NaN so they are omitted instead of drawn as negative bars.
f1_long = (
    pd.DataFrame(models_denoising_f1_scores)  # rows: denoising method, columns: model
    .T
    .replace(-1, np.nan)
    .rename_axis('Model')
    .reset_index()
    .melt(id_vars='Model', var_name='Denoising', value_name='F1 Score')
)
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(data=f1_long, x='Model', y='F1 Score', hue='Denoising', ax=ax)
ax.set(title='F1 Score by Model and Denoising Methods', ylabel='F1 Score', xlabel='Model')
ax.legend(title='Denoising Methods')
fig.tight_layout()
plt.show()

In [None]:
# Classification time per classifier, grouped by denoising method (measured with time.time(),
# i.e. seconds). Entries still holding the -1 "not yet calculated" sentinel become NaN.
time_long = (
    pd.DataFrame(models_denoising_classification_times)  # rows: denoising method, columns: model
    .T
    .replace(-1, np.nan)
    .rename_axis('Model')
    .reset_index()
    .melt(id_vars='Model', var_name='Denoising', value_name='Classification Time')
)
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(data=time_long, x='Model', y='Classification Time', hue='Denoising', ax=ax)
ax.set(title='Classification Time by Model and Denoising Methods',
       ylabel='Classification Time (s)', xlabel='Model')
ax.legend(title='Denoising Methods')
fig.tight_layout()
plt.show()

### Validation Loss Curve

In [None]:
# One validation-loss curve per denoising method. Early stopping ends each run at a different
# epoch, so the curves can have different lengths; plot them directly (a DataFrame would
# require equal-length columns).
fig, ax = plt.subplots()
for denoise, val_losses in val_loss_curves.items():
    ax.plot(range(1, len(val_losses) + 1), val_losses, label=denoise)

ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss Curve For each Denoiser')
ax.legend()
ax.grid(True)
plt.show()

### Confusion matrices (TODO: decide how to visualise one matrix per model and denoising method)