In [None]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Lambda
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from uni import get_encoder
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe
from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot
from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote
from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics
from uni.downstream.utils import concat_images
# Select the compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Sanity check: load one saved feature file from each split (train / valid / test)
# and print the shape of a single entry taken from each of them.
train, valid, test = (
    torch.load(sample_path)
    for sample_path in (
        r"E:\KSA Project\dataset\uni_fivecrop_features\train\TCGA-3L-AA1B_nonMSIH\TCGA-3L-AA1B_nonMSIH_0.pt",
        r"E:\KSA Project\dataset\uni_fivecrop_features\valid\TCGA-A6-5661_MSIH\TCGA-A6-5661_MSIH_0.pt",
        r"E:\KSA Project\dataset\uni_fivecrop_features\test\TCGA-A6-2686_MSIH\TCGA-A6-2686_MSIH_0.pt",
    )
)
# One shape per line: entry 0 of train, entry 1 of valid, entry 2 of test.
print(train[0].shape, valid[1].shape, test[2].shape, sep="\n")

torch.Size([1024])
torch.Size([1024])
torch.Size([1024])


## Load pretrained tissue-type classifier (ANN / multi-layer perceptron)

In [5]:
# Load the complete pickled model object (architecture + weights), not just a state_dict.
model_save_path = r"E:\KSA Project\new_feature_extraction_approach\ann_tissue_classifier_uni.pth"
# NOTE(security): unpickling a whole model can execute arbitrary code stored in the
# file, so only load checkpoints that come from a trusted source.
# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle full nn.Module objects.
# map_location="cpu" keeps the classifier on the CPU, which matches the default
# device ('cpu') that load_features_and_labels uses when loading patch features.
tissue_classifier = torch.load(model_save_path, map_location="cpu", weights_only=False)
# Switch to evaluation mode (disables dropout); the module repr is shown as the cell output.
tissue_classifier.eval()

Sequential(
  (0): Linear(in_features=1024, out_features=224, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=224, out_features=128, bias=True)
  (4): ReLU()
  (5): Dropout(p=0.3, inplace=False)
  (6): Linear(in_features=128, out_features=9, bias=True)
)

### Read FiveCrop patch-level features, average the features of each tissue type at WSI level, and save the stacked per-tissue features

In [33]:
def load_features_and_labels(save_dir, save_feat, device='cpu', classifier=None):
    """Aggregate five-crop patch features into one per-tissue feature matrix per WSI.

    Every sub-directory of ``save_dir`` is treated as one WSI. Each ``.pt`` file
    inside it is loaded and iterated crop by crop; every crop vector is assigned
    a tissue class (argmax over the classifier output) and collected in the
    bucket of that class. For each of the 9 tissue classes the collected crops
    are averaged (an all-zero vector of length 1024 is used when a class received
    no crops), the 9 vectors are stacked, and the result is saved as
    ``<save_feat>/<wsi_folder>.pt``.

    Despite its name, this function returns nothing; its only effect is the
    files it writes.

    Args:
        save_dir: Directory containing one sub-directory of ``.pt`` files per WSI.
        save_feat: Output directory for the aggregated tensors; created if missing.
        device: ``map_location`` used when loading patches and the device of the
            saved tensors. The classifier must live on the same device.
        classifier: Optional callable mapping a ``(1, feat_dim)`` tensor to class
            scores. Defaults to the notebook-level ``tissue_classifier``.
    """
    num_tissue_types = 9
    feat_dim = 1024
    model = tissue_classifier if classifier is None else classifier

    # torch.save does not create missing parent directories.
    os.makedirs(save_feat, exist_ok=True)

    for wsi_folder in os.listdir(save_dir):
        wsi_folder_path = os.path.join(save_dir, wsi_folder)
        if not os.path.isdir(wsi_folder_path):
            continue  # only directories are WSI folders

        # One bucket of crop features per predicted tissue class.
        feats_by_tissue = [[] for _ in range(num_tissue_types)]

        # Inference only: skip autograd bookkeeping for the classifier forward passes.
        with torch.no_grad():
            for file_name in os.listdir(wsi_folder_path):
                if not file_name.endswith('.pt'):
                    continue
                # Each patch file holds the feature vectors of its crops.
                patch = torch.load(os.path.join(wsi_folder_path, file_name), map_location=device)
                for patch_crop in patch:
                    # The classifier expects a leading batch dimension.
                    scores = model(patch_crop.unsqueeze(0))
                    tissue_type = int(torch.argmax(scores, dim=1).item())
                    # Predictions outside 0..8 are ignored, as in the original if/elif chain.
                    if 0 <= tissue_type < num_tissue_types:
                        feats_by_tissue[tissue_type].append(patch_crop)

        # Mean feature per tissue class; zero vector for classes with no crops.
        wsi_features = torch.stack([
            torch.stack(crops).mean(dim=0).to(device) if crops
            else torch.zeros(feat_dim, device=device)
            for crops in feats_by_tissue
        ])
        torch.save(wsi_features, os.path.join(save_feat, f'{wsi_folder}.pt'))
    

# Input directories holding the five-crop patch features of each split
train_dir = r"E:\Aamir Gulzar\dataset\\uni_fivecrop_features\\train"
valid_dir = r"E:\Aamir Gulzar\dataset\\uni_fivecrop_features\\valid"
test_dir = r"E:\Aamir Gulzar\dataset\\uni_fivecrop_features\\test"

# Output directories receiving the aggregated per-WSI tissue features of each split
save_train_feat = r"E:\Aamir Gulzar\dataset\\uni_tissue_features\\train"
save_valid_feat = r"E:\Aamir Gulzar\dataset\\uni_tissue_features\\valid"
save_test_feat = r"E:\Aamir Gulzar\dataset\\uni_tissue_features\\test"

# Aggregate and save the features split by split (train, then valid, then test)
for split_src, split_dst in (
    (train_dir, save_train_feat),
    (valid_dir, save_valid_feat),
    (test_dir, save_test_feat),
):
    load_features_and_labels(split_src, split_dst)


#### Load Tissue Classes Saved Features and select two best performing tissue types for linear models training
#### Load all tissue classes features of each WSI and train linear models.

In this approach, all five-crop features of every patch in a WSI are read and grouped into nine lists, one per tissue class.

## Reload saved feature for KFolds Evaluation

In [8]:
import pandas as pd
import torch
import os

# Load the CSV file that lists the patient IDs belonging to each fold.
# The path is a raw string: it contains backslash sequences (\K, \d, \s) that are
# invalid escapes in a normal string literal and trigger a SyntaxWarning.
folds_df = pd.read_csv(r"E:\KSA Project\dataset\splits\kfolds.csv")

# Patient IDs of each fold, truncated to their first 12 characters.
# Columns may have different lengths, so missing cells are dropped first;
# astype(str) keeps the slicing safe even if an ID was not parsed as a string.
fold1_ids, fold2_ids, fold3_ids, fold4_ids = (
    folds_df[fold_column].dropna().astype(str).str[:12].tolist()
    for fold_column in ("Fold1", "Fold2", "Fold3", "Fold4")
)

# (patient IDs, display name) pairs consumed by the cross-validation cells
folds = [
    (fold1_ids, "Fold 1"),
    (fold2_ids, "Fold 2"),
    (fold3_ids, "Fold 3"),
    (fold4_ids, "Fold 4")
]

def load_features_and_labels(save_dir, fold_ids, tissue_indices=(4, 8)):
    """Load the saved per-WSI tissue features and labels of one fold.

    A ``.pt`` file in ``save_dir`` belongs to the fold when the first 12
    characters of its file name are in ``fold_ids``. The label comes from the
    file name: ``'_nonMSI'`` -> 0, otherwise ``'_MSI'`` -> 1. Files matching
    neither pattern are skipped with a warning so that features and labels can
    never get out of step.

    Args:
        save_dir: Directory containing one ``.pt`` feature file per WSI.
        fold_ids: Iterable of 12-character WSI/patient IDs belonging to the fold.
        tissue_indices: Rows of the loaded feature tensor to keep, concatenated
            in the given order. The default ``(4, 8)`` selects the two tissue
            types used so far (MUC and TUM). Pass ``None`` to use all rows.

    Returns:
        ``(features, labels)``: a ``(num_wsi, dim)`` tensor of flattened WSI
        features and a 1-D tensor of integer labels, in sorted file-name order.

    Raises:
        ValueError: If no labelled feature file of the fold is found.
    """
    import warnings  # local import keeps this cell self-contained

    fold_id_set = set(fold_ids)  # O(1) membership tests
    all_wsi_feature_list = []
    all_wsi_label_list = []

    # Sorted so that the sample order does not depend on the file system.
    for wsi_file in sorted(os.listdir(save_dir)):
        if not wsi_file.endswith('.pt'):
            continue
        if wsi_file[:12] not in fold_id_set:
            continue  # WSI is not part of the requested fold

        # Resolve the label before loading so that an unlabeled file adds
        # neither a feature nor a label. '_nonMSI' must be tested first
        # because it also contains 'MSI'.
        if '_nonMSI' in wsi_file:
            label = 0
        elif '_MSI' in wsi_file:
            label = 1
        else:
            warnings.warn(f"Skipping {wsi_file}: no '_nonMSI' or '_MSI' tag in the file name")
            continue

        wsi_features = torch.load(os.path.join(save_dir, wsi_file))
        if tissue_indices is not None:
            # Keep only the selected tissue-type rows.
            wsi_features = torch.cat([wsi_features[idx] for idx in tissue_indices])

        all_wsi_feature_list.append(wsi_features.reshape(-1))
        all_wsi_label_list.append(label)

    if not all_wsi_feature_list:
        raise ValueError(f"No labelled feature files for the requested fold were found in {save_dir!r}")

    features = torch.stack(all_wsi_feature_list)
    labels = torch.tensor(all_wsi_label_list)

    return features, labels

def print_metrics(metrics):
    """Print every entry of ``metrics`` on its own line as ``name: value``.

    Values that are ``int`` or ``float`` instances are shown with four decimal
    places; every other value is printed using its default string form.
    """
    for name, val in metrics.items():
        # Fixed-precision text for numbers, the object itself for everything else
        shown = f"{val:.4f}" if isinstance(val, (int, float)) else val
        print(f"{name}: {shown}")

# Directory that the cross-validation cells pass to load_features_and_labels as `save_dir`.
save_dir = r"E:\KSA Project\dataset\uni_tissue_features\all_data"

def average_metrics(results_per_fold):
    """Average numeric metrics across folds, descending into nested dicts.

    Numeric leaves (``int``/``float``) found under the same key path are summed
    over all folds and divided by the number of folds. Dictionaries are handled
    recursively at any depth, so the per-class entries of a classification
    report (dicts nested inside the report dict) are averaged as well.
    Non-numeric, non-dict values are ignored.

    Args:
        results_per_fold: List of metric dictionaries, one per fold.

    Returns:
        A new dictionary with the same nesting that holds the averaged values.
        An empty input list yields an empty dictionary.
    """
    def _accumulate(totals, metrics):
        # Add the numeric leaves of `metrics` into `totals`, mirroring its nesting.
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                totals[key] = totals.get(key, 0) + value
            elif isinstance(value, dict):
                _accumulate(totals.setdefault(key, {}), value)

    def _scale(totals, divisor):
        # Divide every numeric leaf of `totals` by `divisor`.
        return {
            key: _scale(value, divisor) if isinstance(value, dict) else value / divisor
            for key, value in totals.items()
        }

    totals = {}
    for fold_result in results_per_fold:
        _accumulate(totals, fold_result)

    # With no folds `totals` is empty, so no division by zero can occur.
    return _scale(totals, len(results_per_fold))

  folds_df = pd.read_csv("E:\KSA Project\dataset\splits\kfolds.csv")


### Linear Model

In [9]:
# Function to run cross-validation
def run_k_fold_cross_validation():
    """Run K-fold cross-validation with the linear probe.

    Each entry of ``folds`` serves once as the test set while the remaining
    folds are concatenated, in their original order, into the training set.

    Returns:
        List with one metrics dictionary per fold, in the order of ``folds``.
    """
    from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe

    # Read every fold from disk exactly once instead of once per train/test
    # combination; the tensors are then reused for all splits.
    fold_data = [load_features_and_labels(save_dir, fold_ids) for fold_ids, _ in folds]

    results_per_fold = []
    for test_idx, (_, fold_name) in enumerate(folds):
        # The current fold is the test set ...
        test_features, test_labels = fold_data[test_idx]
        # ... and all other folds together form the training set.
        train_features = torch.cat(
            [feats for idx, (feats, _) in enumerate(fold_data) if idx != test_idx]
        )
        train_labels = torch.cat(
            [labels for idx, (_, labels) in enumerate(fold_data) if idx != test_idx]
        )

        print(f"Running for {fold_name} as test set")
        print(f"Training on {train_features.shape[0]} samples, Testing on {test_features.shape[0]} samples")

        # Train and evaluate the model on the current fold
        linprobe_eval_metrics, linprobe_dump = eval_linear_probe(
            train_feats=train_features,
            train_labels=train_labels,
            valid_feats=None,  # no separate validation set is passed
            valid_labels=None,
            test_feats=test_features,
            test_labels=test_labels,
            max_iter=50,
            verbose=True,
        )

        print(f"Results for {fold_name}:")
        print_metrics(linprobe_eval_metrics)
        results_per_fold.append(linprobe_eval_metrics)

    return results_per_fold

# Run k-fold cross-validation
all_fold_results = run_k_fold_cross_validation()
average_results = average_metrics(all_fold_results)
print("\nAverage Results across all folds:")
print_metrics(average_results)

Running for Fold 1 as test set
Training on 305 samples, Testing on 100 samples
Linear Probe Evaluation: Train shape torch.Size([305, 2048])
Linear Probe Evaluation: Test shape torch.Size([100, 2048])
Linear Probe Evaluation (Train Time): Best cost = 40.960
Linear Probe Evaluation (Train Time): Using only train set for evaluation. Train Shape:  torch.Size([305, 2048])
(Before Training) Loss: 0.693
(After Training) Loss: 0.036
Linear Probe Evaluation (Test Time): Test Shape torch.Size([100, 2048])
Confusion Matrix:
[[80  5]
 [ 8  7]]
Linear Probe Evaluation: Time taken 0.70
Results for Fold 1:
lin_acc: 0.8700
lin_bacc: 0.7039
lin_kappa: 0.4444
lin_weighted_f1: 0.8639
lin_report: {'0': {'precision': 0.9090909090909091, 'recall': 0.9411764705882353, 'f1-score': 0.9248554913294798, 'support': 85.0}, '1': {'precision': 0.5833333333333334, 'recall': 0.4666666666666667, 'f1-score': 0.5185185185185185, 'support': 15.0}, 'accuracy': 0.87, 'macro avg': {'precision': 0.7462121212121212, 'recall': 

### ANN Model

In [10]:
# Function to run cross-validation
def run_k_fold_cross_validation():
    """Run K-fold cross-validation with the ANN probe.

    Each entry of ``folds`` serves once as the test set while the remaining
    folds are concatenated, in their original order, into the training set.

    Returns:
        List with one metrics dictionary per fold, in the order of ``folds``.
    """
    from uni.downstream.eval_patch_features.ann import eval_ANN_probe

    # Read every fold from disk exactly once instead of once per train/test
    # combination; the tensors are then reused for all splits.
    fold_data = [load_features_and_labels(save_dir, fold_ids) for fold_ids, _ in folds]

    results_per_fold = []
    for test_idx, (_, fold_name) in enumerate(folds):
        # The current fold is the test set ...
        test_features, test_labels = fold_data[test_idx]
        # ... and all other folds together form the training set.
        train_features = torch.cat(
            [feats for idx, (feats, _) in enumerate(fold_data) if idx != test_idx]
        )
        train_labels = torch.cat(
            [labels for idx, (_, labels) in enumerate(fold_data) if idx != test_idx]
        )

        print(f"Running for {fold_name} as test set")
        print(f"Training on {train_features.shape[0]} samples, Testing on {test_features.shape[0]} samples")

        # Train and evaluate the model on the current fold
        ann_eval_metrics, ann_dump = eval_ANN_probe(
            train_feats=train_features,
            train_labels=train_labels,
            valid_feats=None,  # no separate validation set is passed
            valid_labels=None,
            test_feats=test_features,
            test_labels=test_labels,
            # Derived from the data (2048 for the two-tissue selection) so the cell
            # keeps working when a different set of tissue types is used.
            input_dim=train_features.shape[1],
            max_iter=120,
            verbose=True,
        )

        print(f"Results for {fold_name}:")
        print_metrics(ann_eval_metrics)
        results_per_fold.append(ann_eval_metrics)

    return results_per_fold

# Run k-fold cross-validation
all_fold_results = run_k_fold_cross_validation()
average_results = average_metrics(all_fold_results)
print("\nAverage Results across all folds:")
print_metrics(average_results)

Running for Fold 1 as test set
Training on 305 samples, Testing on 100 samples
ANN Probe Evaluation: Train shape torch.Size([305, 2048])
ANN Probe Evaluation: Test shape torch.Size([100, 2048])
ANN Probe Evaluation (Train Time): Best cost = 40.960
ANN Probe Evaluation (Train Time): Using only train set for evaluation. Train Shape:  torch.Size([305, 2048])
(After Training) Loss: 27.647
ANN Probe Evaluation (Test Time): Test Shape torch.Size([100, 2048])
Confusion Matrix:
[[68 17]
 [ 2 13]]
ANN Probe Evaluation: Time taken 0.86
Results for Fold 1:
ann_acc: 0.8100
ann_bacc: 0.8333
ann_kappa: 0.4722
ann_weighted_f1: 0.8325
ann_report: {'0': {'precision': 0.9714285714285714, 'recall': 0.8, 'f1-score': 0.8774193548387097, 'support': 85.0}, '1': {'precision': 0.43333333333333335, 'recall': 0.8666666666666667, 'f1-score': 0.5777777777777777, 'support': 15.0}, 'accuracy': 0.81, 'macro avg': {'precision': 0.7023809523809523, 'recall': 0.8333333333333334, 'f1-score': 0.7275985663082437, 'support'

### KNN and ProtoNet

In [11]:
# Function to run cross-validation
def run_k_fold_cross_validation():
    """Run K-fold cross-validation with the KNN and ProtoNet evaluations.

    Each entry of ``folds`` serves once as the test set while the remaining
    folds are concatenated, in their original order, into the training set.

    Returns:
        ``(knn_results_per_fold, protonet_results_per_fold)``: two lists with one
        metrics dictionary per fold each, in the order of ``folds``.
    """
    from uni.downstream.eval_patch_features.fewshot import eval_knn

    # Read every fold from disk exactly once instead of once per train/test
    # combination; the tensors are then reused for all splits.
    fold_data = [load_features_and_labels(save_dir, fold_ids) for fold_ids, _ in folds]

    knn_results_per_fold = []
    protonet_results_per_fold = []
    for test_idx, (_, fold_name) in enumerate(folds):
        # The current fold is the test set ...
        test_features, test_labels = fold_data[test_idx]
        # ... and all other folds together form the training set.
        train_features = torch.cat(
            [feats for idx, (feats, _) in enumerate(fold_data) if idx != test_idx]
        )
        train_labels = torch.cat(
            [labels for idx, (_, labels) in enumerate(fold_data) if idx != test_idx]
        )

        print(f"Running for {fold_name} as test set")
        print(f"Training on {train_features.shape[0]} samples, Testing on {test_features.shape[0]} samples")

        # A single eval_knn call returns both the KNN and the prototype metrics
        knn_eval_metrics, knn_dump, proto_eval_metrics, proto_dump = eval_knn(
            train_feats=train_features,
            train_labels=train_labels,
            test_feats=test_features,
            test_labels=test_labels,
            center_feats=True,
            normalize_feats=True,
            n_neighbors=5
        )

        print(f"Results for {fold_name}:")
        print_metrics(knn_eval_metrics)
        print_metrics(proto_eval_metrics)
        knn_results_per_fold.append(knn_eval_metrics)
        protonet_results_per_fold.append(proto_eval_metrics)

    return knn_results_per_fold, protonet_results_per_fold

# Run k-fold cross-validation
knn_all_fold_results, protonet_all_fold_results = run_k_fold_cross_validation()
knn_average_results = average_metrics(knn_all_fold_results)
proto_average_results = average_metrics(protonet_all_fold_results)
print("\nKNN Average Results across all folds:")
print_metrics(knn_average_results)
print("\nProtoNet Average Results across all folds:")
print_metrics(proto_average_results)



Running for Fold 1 as test set
Training on 305 samples, Testing on 100 samples
Results for Fold 1:
knn5_acc: 0.8400
knn5_bacc: 0.5490
knn5_kappa: 0.1351
knn5_weighted_f1: 0.8044
knn5_report: {'0': {'precision': 0.8631578947368421, 'recall': 0.9647058823529412, 'f1-score': 0.9111111111111111, 'support': 85.0}, '1': {'precision': 0.4, 'recall': 0.13333333333333333, 'f1-score': 0.2, 'support': 15.0}, 'accuracy': 0.84, 'macro avg': {'precision': 0.631578947368421, 'recall': 0.5490196078431373, 'f1-score': 0.5555555555555556, 'support': 100.0}, 'weighted avg': {'precision': 0.7936842105263158, 'recall': 0.84, 'f1-score': 0.8044444444444444, 'support': 100.0}}
proto_acc: 0.8100
proto_bacc: 0.7510
proto_kappa: 0.4025
proto_weighted_f1: 0.8266
proto_report: {'0': {'precision': 0.9342105263157895, 'recall': 0.8352941176470589, 'f1-score': 0.8819875776397516, 'support': 85.0}, '1': {'precision': 0.4166666666666667, 'recall': 0.6666666666666666, 'f1-score': 0.5128205128205128, 'support': 15.0}, '

### 