## Instructions before running this notebook
1. In the `dataloader` module, change line 17 from `wsi_id = os.path.splitext(wsi_file)[0]` to `wsi_id = wsi_file[:12]`, so that WSI ids are truncated to 12 characters and match the ids read from the k-folds CSV.
2. Edit the configuration values in the second cell to match your paths.
3. If you have five-crop-level features that are not yet averaged, you can average them first using the (commented-out) code in the third cell.

In [13]:
import torch
import torchvision
from torch.utils.data import DataLoader
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from dataloader import WSIDataset
from eval_patch_features.logistic import eval_linear
from eval_patch_features.ann import eval_ANN
from eval_patch_features.knn import eval_knn
from eval_patch_features.protonet import eval_protonet
from eval_patch_features.metrics import get_eval_metrics, print_metrics
from utility import calculate_metric_averages, average_confusion_matrices, write_data_in_excel, build_probs_df
import warnings
warnings.filterwarnings("ignore")
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Configurations

In [None]:
from openpyxl import load_workbook
from openpyxl.utils.dataframe import dataframe_to_rows
# --- Experiment configuration ------------------------------------------------
# Edit these values to match your environment before running the notebook.
VECTOR_DIM = 1280  # size of each input feature vector (must match the saved features)
HIDDEN_DIM = 768   # size of the ANN hidden layer
BATCH_SIZE = 8     # DataLoader batch size; train_and_evaluate concatenates all batches anyway
FM_MODEL = "virchow2"    # foundation-model name; used as the model sub-folder name
RUNS_RESULT = "average"  # how eval_ANN aggregates its repeated runs: "average" or "best"
ANN_RUNS = 20            # number of repeated ANN runs per fold
K_FOLDS_PATH = r"E:\KSA Project\dataset\splits\kfolds.csv"
# NOTE: the feature folder name is hard-coded and must be kept in sync with FM_MODEL.
DATA_PATH = r"E:\KSA Project\dataset\virchow2_features\all_data"
OUTPUT_SAVE_PATH = r"E:\KSA Project\KSAproject_pipeline1\WSI_Classification\Averaging\TCGA-CV"
# Build the per-model folder with os.path.join rather than a non-raw f-string:
# "\K", "\W", "\A", ... are invalid escape sequences (SyntaxWarning on Python 3.12+),
# and a path segment starting with e.g. "n" or "t" would be silently mangled.
MODEL_SAVE_PATH = os.path.join(OUTPUT_SAVE_PATH, FM_MODEL)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
# Paths of the Excel workbooks that the results are written to later on.
EVAL_METRICS_EXCEL = os.path.join(OUTPUT_SAVE_PATH, "TCGA-CV_avg_eval_metrics.xlsx")
PROBS_ALL_EXCEL = os.path.join(OUTPUT_SAVE_PATH, "TCGA-CV_avg_probs_all.xlsx")
PROBS_ALL_EXCEL = os.path.join(OUTPUT_SAVE_PATH, "TCGA-CV_avg_probs_all.xlsx")

In [15]:
# data = r"E:\Aamir Gulzar\dataset\CAIMAN_Fivecrop_NormFeatures"
# data_save = r"E:\Aamir Gulzar\dataset\CAIMAN_Features"
# # load the data
# for wsi in os.listdir(data):
#     wsi_data = []
#     for patch in os.listdir(j_(data, wsi)):
#         patch_data = torch.load(j_(data, wsi, patch))
#         # check if the loaded feature vector is five crop then average it first then append to the wsi_data
#         if patch_data.shape[0] > 1:
#             patch_data = patch_data.mean(dim=0)
#         wsi_data.append(patch_data)
#     wsi_data = torch.stack(wsi_data).mean(dim=0)
#     save_path = j_(data_save, wsi + ".pt")
#     torch.save(wsi_data, save_path)
# print("done")

## Trainer Function

In [16]:
def train_and_evaluate(fold,train_loader,val_loader, test_loader, model_type='linear'):
    all_train_feats, all_train_labels,all_val_feats,all_val_labels, all_test_feats, all_test_labels = [], [], [], [], [], []
    all_test_ids = []
    # Prepare training and testing data
    for features, label, _ in train_loader:
        all_train_feats.append(features)
        all_train_labels.append(label)
    for features, label, _ in val_loader:
        all_val_feats.append(features)
        all_val_labels.append(label)
    for features, label, wsi_id in test_loader:
        all_test_feats.append(features)
        all_test_labels.append(label)
        # Store as single WSI IDs from the batch 
        if isinstance(wsi_id, (list, tuple)):
            all_test_ids.extend(wsi_id)
        else:
            all_test_ids.append(wsi_id)

    # Convert lists to tensors
    global train_feats, train_labels, val_feats, val_labels, test_feats, test_labels
    train_feats = torch.cat(all_train_feats)
    train_labels = torch.cat([labels.clone().detach() for labels in all_train_labels])
    val_feats = torch.cat(all_val_feats)
    val_labels = torch.cat([labels.clone().detach() for labels in all_val_labels])
    test_feats = torch.cat(all_test_feats)
    test_labels = torch.cat([labels.clone().detach() for labels in all_test_labels])
    
    # Select the model based on the input argument
    if model_type == 'lin':
        eval_metrics, eval_dump = eval_linear(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,  # Optionally, use a separate validation set
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            max_iter=350,
            save_path = MODEL_SAVE_PATH,
            verbose=False,
        )
    elif model_type == 'ann':
        eval_metrics, eval_dump = eval_ANN(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            valid_feats=val_feats,
            valid_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            input_dim=VECTOR_DIM,
            hidden_dim = HIDDEN_DIM,
            max_iter=350,
            model_save_path = MODEL_SAVE_PATH,
            num_runs=ANN_RUNS,  # Run the function 5 times
            runs_results=RUNS_RESULT,  # Choose "average" or "best"
            metric_weights={"bacc": 0.5, "auroc": 0.5},  # Prioritize balanced accuracy
            verbose=False,
        )
    elif model_type == 'knn':
        eval_metrics, eval_dump = eval_knn(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            n_neighbors=5,
            normalize_feats=True,
            model_save_path = MODEL_SAVE_PATH,
            verbose=False
        )
    elif model_type == 'proto':
        eval_metrics, eval_dump = eval_protonet(
            fold=fold,
            train_feats=train_feats,
            train_labels=train_labels,
            val_feats=val_feats,
            val_labels=val_labels,
            test_feats=test_feats,
            test_labels=test_labels,
            normalize_feats=True,
            model_save_path = MODEL_SAVE_PATH,
        )
        
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    return eval_metrics, eval_dump, all_test_ids

### K-Folds

In [17]:
from typing import List
def run_k_fold_cross_validation(save_dir: str, folds: List[List[str]], model_type: str = 'lin'):
    results_per_fold = []
    num_folds = len(folds)

    for i in range(num_folds):
        # Define test and validation folds
        test_ids = folds[i]
        val_ids = folds[(i + 1) % num_folds]  # The next fold in sequence is used as validation

        # Use remaining folds as training
        train_ids = []
        for j in range(num_folds):
            if j != i and j != (i + 1) % num_folds:
                train_ids.extend(folds[j])
        print(f"Running Fold {i + 1} with model {model_type}...")
        # Create datasets and loaders
        train_dataset = WSIDataset(save_dir, train_ids)
        val_dataset = WSIDataset(save_dir, val_ids)
        test_dataset = WSIDataset(save_dir, test_ids)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        eval_metrics, eval_dump, all_test_ids = train_and_evaluate(i,train_loader,val_loader, test_loader, model_type=model_type)
        print_metrics(eval_metrics)
        result = {
            **eval_metrics,
            **eval_dump,
            "wsi_ids":all_test_ids ,  # You already have this in train_and_evaluate
            "fold": i + 1
        }
        results_per_fold.append(result)
    return results_per_fold

### Main Runner Function

In [18]:
# Load the k-fold split table: each "FoldN" column lists the slide ids of one
# fold (columns of unequal length come back NaN-padded, hence dropna()).
folds_df = pd.read_csv(K_FOLDS_PATH)
# Truncate every id to its first 12 characters so it matches the ids produced by
# the patched dataloader (see step 1 of the instructions at the top).
# NOTE(review): slides sharing the same 12-character prefix become
# indistinguishable here -- confirm this is intended.
fold1_ids = folds_df['Fold1'].dropna().apply(lambda x: x[:12]).tolist()
fold2_ids = folds_df['Fold2'].dropna().apply(lambda x: x[:12]).tolist()
fold3_ids = folds_df['Fold3'].dropna().apply(lambda x: x[:12]).tolist()
fold4_ids = folds_df['Fold4'].dropna().apply(lambda x: x[:12]).tolist()
folds = [fold1_ids, fold2_ids, fold3_ids, fold4_ids]

# Classifiers to evaluate, in this order.
model_types = ['lin', 'ann', 'knn', 'proto']
# Metric names (and their positions) passed to calculate_metric_averages.
metric_indices = {
    'acc': 0,
    'bacc': 1,
    'macro_f1': 2,
    'weighted_f1': 3,
    'auroc': 4,
}
# One row per (model, metric): [name, fold_1, ..., fold_k, average over folds].
eval_metrics__for_excel = []
# Per-WSI outputs of every model, outer-merged on (Fold, WSI_ID, Target).
probs_all_for_excel = None
for model in model_types:
    print(f"\n\n ********* Training with model: {model}********* \n\n")
    k_folds_results = run_k_fold_cross_validation(DATA_PATH, folds, model_type=model)
    model_df = build_probs_df(k_folds_results, model_name=model)
    # === Merge predictions across models ===
    if probs_all_for_excel is None:
        probs_all_for_excel = model_df
    else:
        probs_all_for_excel = pd.merge(
            probs_all_for_excel, model_df, on=["Fold", "WSI_ID", "Target"], how="outer"
        )

    # === Average metrics (only pass the metric entries of each result dict)
    metric_keys = {f"{model}_{m}" for m in metric_indices}
    average_results = calculate_metric_averages(
        [{k: v for k, v in result.items() if k in metric_keys} for result in k_folds_results],
        metric_indices,
        model_prefix=model,
    )
    # === Confusion matrices (folds without one are skipped)
    confusion_matrices = [
        np.array(result[f"{model}_conf_matrix"])
        for result in k_folds_results
        if f"{model}_conf_matrix" in result
    ]
    avg_conf_matrix = average_confusion_matrices(confusion_matrices)
    print("\n\n Average results for all folds:")
    for metric, value in average_results.items():
        print(f"{metric}: {value:.4f}")

    # Append one row per metric: per-fold values followed by the average.
    for metric in metric_indices:
        row = [f"{model}_{metric}"]
        for result in k_folds_results:
            row.append(result.get(f"{model}_{metric}", 'N/A'))
        row.append(average_results.get(f"{model}_{metric}", 'N/A'))
        eval_metrics__for_excel.append(row)

    # Append the confusion matrices as strings (per fold, then the average).
    row = [f"{model}_conf_matrix"]
    for result in k_folds_results:
        row.append(str(result.get(f"{model}_conf_matrix", "N/A")))
    row.append(str(avg_conf_matrix))
    eval_metrics__for_excel.append(row)

# Column headers follow the number of folds instead of assuming exactly four.
eval_metrics_df = pd.DataFrame(
    eval_metrics__for_excel,
    columns=["Metric", *[f"Fold{i}" for i in range(1, len(folds) + 1)], "AvgFolds"],
)
write_data_in_excel(EVAL_METRICS_EXCEL, eval_metrics_df, FM_MODEL)
write_data_in_excel(PROBS_ALL_EXCEL, probs_all_for_excel, FM_MODEL)



 ********* Training with model: lin********* 


Running Fold 1 with model lin...
lin_acc: 0.8500
lin_bacc: 0.6098
lin_macro_f1: 0.6315
lin_weighted_f1: 0.8301
lin_auroc: 0.9184
lin_conf_matrix: [[81  4]
 [11  4]]
Running Fold 2 with model lin...
lin_acc: 0.9149
lin_bacc: 0.7569
lin_macro_f1: 0.7941
lin_weighted_f1: 0.9082
lin_auroc: 0.9098
lin_conf_matrix: [[79  2]
 [ 6  7]]
Running Fold 3 with model lin...
lin_acc: 0.8739
lin_bacc: 0.6695
lin_macro_f1: 0.7139
lin_weighted_f1: 0.8507
lin_auroc: 0.8489
lin_conf_matrix: [[90  1]
 [13  7]]
Running Fold 4 with model lin...
lin_acc: 0.9000
lin_bacc: 0.6808
lin_macro_f1: 0.7222
lin_weighted_f1: 0.8867
lin_auroc: 0.8479
lin_conf_matrix: [[85  2]
 [ 8  5]]


 Average results for all folds:
lin_acc: 0.8847
lin_bacc: 0.6793
lin_macro_f1: 0.7154
lin_weighted_f1: 0.8689
lin_auroc: 0.8813


 ********* Training with model: ann********* 


Running Fold 1 with model ann...
ann_acc: 0.7970
ann_bacc: 0.7886
ann_macro_f1: 0.7005
ann_weighted_f1: 0.8189