## Instructions before running this notebook
1. Change line 29 of the `dataloader_clustering` module from `wsi_id = wsi_file[:12]` to `wsi_id = wsi_folder`.
2. Edit the configuration in the second cell to match your paths.
3. Use patch-level features for this notebook (FiveCrop or patch-level averaged).

In [None]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from dataloader_clustering import WSIDataset
from eval_patch_features.logistic import eval_linear
from eval_patch_features.ann import eval_ANN
from eval_patch_features.knn import eval_knn
from eval_patch_features.protonet import eval_protonet
from eval_patch_features.metrics import get_eval_metrics, print_metrics
from utility import calculate_metric_averages, average_confusion_matrices, write_data_in_excel, build_probs_df
import warnings
warnings.filterwarnings("ignore")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Configurations

In [None]:
# configs
VECTOR_DIM = 1280              # size of one feature vector; the ANN input is VECTOR_DIM * NUM_CLUSTERS
HIDDEN_DIM = 768               # hidden width passed to eval_ANN
FM_MODEL = "baseline"          # tag used in the data/model folder names and passed to write_data_in_excel
RUNS_RESULT = "average"        # how eval_ANN aggregates repeated runs: "average" or "best"
ANN_RUNS = 20                  # number of repeated ANN trainings per split
CLUSTERING_METHOD = 'kmeans'
NUM_CLUSTERS = 2
NUM_PATCHES_PER_CLUSTER = 0
BATCH_SIZE = 4
# Windows paths: raw (f-)strings keep every backslash literal. In a plain string,
# "\n"/"\t"/... would be turned into control characters, and invalid escapes such
# as "\A" or "\{" emit (Syntax)Warnings.
K_FOLDS_PATH = r"E:\Aamir Gulzar\dataset\paip_data\labels\TrainTest_paip.csv"
DATA_PATH = rf"E:\Aamir Gulzar\dataset\paip_data\{FM_MODEL}_FiveCrop_Features"
MODEL_SAVE_PATH = rf"E:\KSA Project\KSAproject_pipeline1\WSI_Classification\Clustering\PAIP-IV\{FM_MODEL}_{NUM_CLUSTERS}Cluster_Classifiers"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
OUTPUT_SAVE_PATH = r"E:\KSA Project\KSAproject_pipeline1\WSI_Classification\Clustering\PAIP-IV"
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
# Excel workbooks for the results in the output folder; the cluster count in the
# file name follows NUM_CLUSTERS instead of being hard-coded.
EVAL_METRICS_EXCEL = os.path.join(OUTPUT_SAVE_PATH, f"PAIP-IV_{NUM_CLUSTERS}cluster_eval_metrics.xlsx")
PROBS_ALL_EXCEL = os.path.join(OUTPUT_SAVE_PATH, f"PAIP-IV_{NUM_CLUSTERS}cluster_probs_all.xlsx")


### Trainer Function

In [None]:
def _collect_batches(loader):
    """Concatenate a loader's features and labels and flatten its batched WSI IDs.

    Each batch is expected to be a ``(features, label, wsi_id)`` triple.
    Returns ``(features, labels, ids)``, where ``ids`` is a flat Python list.
    Raises ValueError if the loader yields no batches.
    """
    feats, labels, ids = [], [], []
    for features, label, wsi_id in loader:
        feats.append(features)
        labels.append(label)
        # Default collation turns a batch of string IDs into a list and a batch
        # of numeric IDs into a tensor; store individual IDs in both cases.
        if isinstance(wsi_id, torch.Tensor):
            ids.extend(wsi_id.tolist())
        elif isinstance(wsi_id, (list, tuple)):
            ids.extend(wsi_id)
        else:
            ids.append(wsi_id)
    if not feats:
        raise ValueError("Loader yielded no batches; cannot build feature tensors.")
    # torch.cat already allocates a new tensor, so no extra clone is needed.
    return torch.cat(feats), torch.cat(labels).detach(), ids


def train_and_evaluate(fold, train_loader, test_loader, model_type='lin'):
    """Gather features from both loaders, fit one classifier and evaluate it.

    Args:
        fold: Fold index forwarded to the selected eval_* helper.
        train_loader: Iterable of ``(features, label, wsi_id)`` batches used for fitting.
        test_loader: Iterable of ``(features, label, wsi_id)`` batches used for evaluation.
        model_type: ``'lin'``, ``'ann'``, ``'knn'`` or ``'proto'``.

    Returns:
        ``(eval_metrics, eval_dump, all_test_ids)``. The first two items come from
        the selected eval_* helper; ``all_test_ids`` lists the test WSI IDs in
        loader order.

    Raises:
        ValueError: If ``model_type`` is unsupported. This is checked before any
            loader is iterated. Also raised if a loader yields no batches.

    Side effects:
        Sets the module globals ``train_feats``, ``train_labels``, ``test_feats``
        and ``test_labels`` so they can be inspected from the notebook afterwards.
    """
    # Fail fast: iterating the loaders triggers the expensive data loading.
    if model_type not in ('lin', 'ann', 'knn', 'proto'):
        raise ValueError(f"Unsupported model type: {model_type}")

    global train_feats, train_labels, test_feats, test_labels
    train_feats, train_labels, _ = _collect_batches(train_loader)
    test_feats, test_labels, all_test_ids = _collect_batches(test_loader)

    if model_type == 'lin':
        eval_metrics, eval_dump = eval_linear(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=None,  # no separate validation set is used
            valid_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            max_iter=350,
            save_path=MODEL_SAVE_PATH,
            verbose=False,
        )
    elif model_type == 'ann':
        eval_metrics, eval_dump = eval_ANN(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=None,
            valid_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            # presumably one VECTOR_DIM vector per cluster, concatenated — confirm against WSIDataset
            input_dim=VECTOR_DIM * NUM_CLUSTERS,
            hidden_dim=HIDDEN_DIM,
            model_save_path=MODEL_SAVE_PATH,
            max_iter=350,
            num_runs=ANN_RUNS,  # number of repeated trainings (see configs)
            runs_results=RUNS_RESULT,  # "average" or "best"
            metric_weights={"bacc": 0.5, "auroc": 0.5},  # equal weight on balanced accuracy and AUROC
            verbose=False,
        )
    elif model_type == 'knn':
        eval_metrics, eval_dump = eval_knn(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=None,
            val_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            n_neighbors=5,
            normalize_feats=True,
            model_save_path=MODEL_SAVE_PATH,
            verbose=False,
        )
    else:  # 'proto' — the only remaining value after validation above
        eval_metrics, eval_dump = eval_protonet(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=None,
            val_labels=None,
            test_feats=test_feats,
            test_labels=test_labels,
            normalize_feats=True,
            model_save_path=MODEL_SAVE_PATH,
        )
    return eval_metrics, eval_dump, all_test_ids

### K-Fold Cross-Validation

In [None]:
from collections import Counter
from typing import List
import pandas as pd

def count_classes(dataset):
    """
    Tally how many samples of each label a dataset contains.

    Items are drawn one at a time (unshuffled) through a DataLoader. Tensor
    labels are converted to plain Python scalars before counting, and any
    other label is counted as-is.
    """
    single_item_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    return Counter(
        lbl.item() if isinstance(lbl, torch.Tensor) else lbl
        for _, lbl, _ in single_item_loader
    )

# Cross-validation function
def run_k_fold_cross_validation(save_dir: str, folds: List[List[str]], model_type: str = 'lin',
                                *, num_splits: int = 1):
    """Train on fold ``i``, test on the next fold (wrapping), and collect the results.

    By default only the first split runs: train on ``folds[0]`` and test on
    ``folds[1]``. With two folds, ``num_splits=2`` also runs the swapped split.

    Args:
        save_dir: Feature root directory passed to ``WSIDataset``.
        folds: One list of WSI IDs per fold; at least two folds are required.
        model_type: Classifier key forwarded to ``train_and_evaluate``
            (``'lin'``, ``'ann'``, ``'knn'`` or ``'proto'``).
        num_splits: Number of train/test splits to run (1..len(folds)).

    Returns:
        One dict per split. Each dict merges the eval metrics, the eval dump,
        the test ``"wsi_ids"`` and the 1-based ``"fold"`` number.

    Raises:
        ValueError: If fewer than two folds are given or ``num_splits`` is out of range.
    """
    num_folds = len(folds)
    if num_folds < 2:
        raise ValueError(f"Need at least two folds (one to train, one to test), got {num_folds}.")
    if not 1 <= num_splits <= num_folds:
        raise ValueError(f"num_splits must be between 1 and {num_folds}, got {num_splits}.")

    results_per_fold = []
    for i in range(num_splits):
        # Train on fold i; the next fold in sequence (wrapping) is the held-out test set.
        train_ids = folds[i]
        test_ids = folds[(i + 1) % num_folds]

        # Build both datasets, apply the same clustering to each, and wrap them in loaders.
        train_dataset = WSIDataset(save_dir, train_ids)
        train_dataset.apply_clustering(clustering_algorithm=CLUSTERING_METHOD, num_clusters=NUM_CLUSTERS,
                                       num_selected_patches=NUM_PATCHES_PER_CLUSTER)
        test_dataset = WSIDataset(save_dir, test_ids)
        test_dataset.apply_clustering(clustering_algorithm=CLUSTERING_METHOD, num_clusters=NUM_CLUSTERS,
                                      num_selected_patches=NUM_PATCHES_PER_CLUSTER)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        print(f"Running Fold {i + 1} with model {model_type}...")
        eval_metrics, eval_dump, all_test_ids = train_and_evaluate(i, train_loader, test_loader,
                                                                   model_type=model_type)
        print_metrics(eval_metrics)
        results_per_fold.append({
            **eval_metrics,
            **eval_dump,
            "wsi_ids": all_test_ids,  # test WSI IDs in loader order, as returned by train_and_evaluate
            "fold": i + 1,
        })

    return results_per_fold

### Main Runner Function

In [None]:
# Example usage:
folds_df = pd.read_csv(K_FOLDS_PATH)
# One list of WSI IDs per fold column; dropna() removes the NaN padding of the shorter column.
fold1_ids = folds_df['Fold1'].dropna().tolist()
fold2_ids = folds_df['Fold2'].dropna().tolist()
folds = [fold1_ids, fold2_ids]
# Classifier keys accepted by train_and_evaluate.
model_types = ['lin', 'ann', 'knn', 'proto']
# Metric suffix -> index, passed to calculate_metric_averages. Result keys are
# "<model>_<metric>", e.g. 'lin_acc' corresponds to index 0.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 4,
}
# NOTE(review): earlier output of this cell printed ANN metrics without the
# "ann_" prefix ("acc: ..."). If eval_ANN still returns unprefixed keys, the
# ANN rows below will contain 'N/A' — confirm against eval_ANN.

eval_metrics__for_excel = []
probs_all_for_excel = None
fold_columns = []
for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_fold_cross_validation(DATA_PATH, folds, model_type=model)
    # Column headers follow the folds actually run instead of a hard-coded "Fold1".
    fold_columns = [f"Fold{result['fold']}" for result in k_folds_results]
    model_df = build_probs_df(k_folds_results, model_name=model)
    # === Merge predictions across models (outer join keeps WSIs missing from any model)
    if probs_all_for_excel is None:
        probs_all_for_excel = model_df
    else:
        probs_all_for_excel = pd.merge(probs_all_for_excel, model_df,
                                       on=["Fold", "WSI_ID", "Target"], how="outer")

    # === Average metrics (only pass the metric entries of each result dict)
    metric_keys = {f"{model}_{m}" for m in metric_indices}  # built once per model
    average_results = calculate_metric_averages(
        [{k: v for k, v in result.items() if k in metric_keys} for result in k_folds_results],
        metric_indices,
        model_prefix=model,
    )
    # === Confusion matrices
    conf_key = f"{model}_conf_matrix"
    confusion_matrices = [np.array(result[conf_key]) for result in k_folds_results if conf_key in result]
    avg_conf_matrix = average_confusion_matrices(confusion_matrices)
    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")

    # One row per metric: name, the value for each fold, then the cross-fold average.
    for metric in metric_indices:
        key = f"{model}_{metric}"
        eval_metrics__for_excel.append(
            [key, *(result.get(key, 'N/A') for result in k_folds_results), average_results.get(key, 'N/A')]
        )
    # Confusion matrices are stored as strings so each fits in a single cell.
    eval_metrics__for_excel.append(
        [conf_key, *(str(result.get(conf_key, "N/A")) for result in k_folds_results), str(avg_conf_matrix)]
    )

eval_metrics_df = pd.DataFrame(eval_metrics__for_excel,
                               columns=["Metric", *fold_columns, "AvgFolds"])
write_data_in_excel(EVAL_METRICS_EXCEL, eval_metrics_df, FM_MODEL)
if probs_all_for_excel is not None:  # nothing to write when model_types is empty
    write_data_in_excel(PROBS_ALL_EXCEL, probs_all_for_excel, FM_MODEL)



 ********* Training with model: linear********* 


Running Fold 1 with model linear...
lin_acc: 0.6774
lin_bacc: 0.5387
lin_macro_f1: 0.5387
lin_weighted_f1: 0.6774
lin_auroc: 0.5595
lin_conf_matrix: [[19  5]
 [ 5  2]]


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
acc: 0.6129
bacc: 0.6488
macro_f1: 0.5773
weighted_f1: 0.6446
auroc: 0.6190
conf_matrix: [[14 10]
 [ 2  5]]


 ********* Training with model: knn********* 


Running Fold 1 with model knn...
knn_acc: 0.4839
knn_bacc: 0.4137
knn_macro_f1: 0.4095
knn_weighted_f1: 0.5244
knn_auroc: 0.4643
knn_conf_matrix: [[13 11]
 [ 5  2]]


 ********* Training with model: protonet********* 


Running Fold 1 with model protonet...
proto_acc: 0.6774
proto_bacc: 0.5387
proto_macro_f1: 0.5387
proto_weighted_f1: 0.6774
proto_auroc: 0.7262
proto_conf_matrix: [[19  5]
 [ 5  2]]
