In [1]:
 # Imports for data manipulation, processing, data loading, visualization and other utility imports
import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from sklearn.metrics import accuracy_score
import numpy as np
from sklearn.decomposition import PCA
from torchvision import transforms
import sklearn
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
from sklearn.model_selection import StratifiedKFold
import shutil
import os
from torch.nn import functional as F



In [None]:
import os

# Defining the folder paths for both healthy and sick images
base_path = "/content/drive/MyDrive/Dataset/newdb"
healthy_path = os.path.join(base_path, "Healthy")
sick_path = os.path.join(base_path, "Unhealthy")
subfolders = [healthy_path, sick_path]


In [9]:
#Reading the images
# Get all image file paths from Healthy and Sick directories
healthy_images = [os.path.join(healthy_path, fname) for fname in os.listdir(healthy_path) if fname.lower().endswith(('jpg'))]
sick_images = [os.path.join(sick_path, fname) for fname in os.listdir(sick_path) if fname.lower().endswith(('jpg'))]

# Combine them into a single list with labels
img_paths = healthy_images + sick_images
img_labels = ['Healthy'] * len(healthy_images) + ['Unhealthy'] * len(sick_images)

# Print some examples to confirm
print("First 5 Healthy Images:", healthy_images[:5])
print("First 5 Sick Images:", sick_images[:5])
print("Number of images:", len(img_paths))


First 5 Healthy Images: ['/content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0092_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0094_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0055_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0042_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0074_anterior.jpg']
First 5 Sick Images: ['/content/drive/MyDrive/Dataset/newdbBalanced/Unhealthy/IIR0296_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Unhealthy/IIR0326_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Unhealthy/IIR0328_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Unhealthy/IIR0307_anterior.jpg', '/content/drive/MyDrive/Dataset/newdbBalanced/Unhealthy/IIR0299_anterior.jpg']
Number of images: 142


In [10]:
##### EXPERIMENT PARAMETERS #####
# Constants to control grid size
N_IMGS = len(img_paths)  # Total number of valid images
N_ROW_IMGS = 4  # Number of rows in the grid
N_COL_IMGS = (N_IMGS + N_ROW_IMGS - 1) // N_ROW_IMGS  # Calculate columns dynamically



# Define the number of patches along the two dimensions to split each image into
patch_h = 32
patch_w = 32
# Define the size of the model to use for extracting the features of each patch
model_size = 's' # options: s, b, l, g
use_registers = False
use_extended_dinov2 = True

patch_multiplier = 7  # Multiplier for patch count (to get 224x224 from 32x32 patches)

# Define the final image size (this will be calculated dynamically based on patches)
final_size = patch_h * patch_multiplier  # final image size will be 32 * 7 = 224
print(f"Final image size: {final_size}x{final_size}")
print(N_IMGS)


Final image size: 224x224
142


In [None]:
#Preprocessing and Colorization
from PIL import Image
#define the output directory for transformed images
output_dir = "/content/drive/MyDrive/Dataset/DINOv2_images"
os.makedirs(output_dir, exist_ok=True)  # Create folder if it doesn't exist

#define DINOv2-compatible transformations
dinov2_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),  # Convert image to tensor (C, H, W)
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # Normalize using ImageNet values
        std=[0.229, 0.224, 0.225]
    )
])

#initialize tensor to store transformed images for DINOv2
imgs_tensor = torch.zeros(len(img_paths), 3, 224, 224)  # (N, C, H, W)

#function to process image (Colorize + DINOv2 Transform) and save it as ONE file
def process_image(image_path, index):
    # Open and convert to grayscale
    img = Image.open(image_path).convert("L")
    img_np = np.array(img)

    #apply JET colormap
    img_color = cv2.applyColorMap(img_np, cv2.COLORMAP_JET)
    colorized_img = Image.fromarray(img_color)  # Convert to PIL format

    #apply DINOv2 transformations
    img_tensor = dinov2_transform(colorized_img)
    imgs_tensor[index] = img_tensor  # Store the transformed tensor

    #convert the tensor back to a displayable image
    img_numpy = img_tensor.permute(1, 2, 0).mul(torch.tensor([0.229, 0.224, 0.225])).add(torch.tensor([0.485, 0.456, 0.406]))  # Undo normalization
    img_numpy = torch.clamp(img_numpy, 0, 1).cpu().numpy()  # Ensure values are in range [0,1]
    img_resized = Image.fromarray((img_numpy * 255).astype(np.uint8))  # Convert to 8-bit image

    #generate new filename and save the transformed image in the DINOv2 folder
    base_name = os.path.basename(image_path).split('.')[0]
    dinov2_path = os.path.join(output_dir, f"{base_name}_DINOv2.jpg")
    img_resized.save(dinov2_path)

    return dinov2_path  # Return the saved image path

#process all images
for i, img_path in enumerate(img_paths):
    print(f"Processing {img_path}...")
    dinov2_image_path = process_image(img_path, i)

print("All images have been **colorized, transformed, and saved** in 'DINOv2_images'.")
print("Final dataset tensor shape:", imgs_tensor.shape)  # Should be (N, 3, 224, 224)

#extra step to check ** (optional)
for path in os.listdir(output_dir):
    full_path = os.path.join(output_dir, path)
    img = Image.open(full_path)
    if img.size != (224, 224):
        print(f"Not resized: {full_path} | Size: {img.size}")







Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0092_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0094_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0055_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0042_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0074_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0066_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0083_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0077_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0045_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0082_anterior.jpg...
Processing /content/drive/MyDrive/Dataset/newdbBalanced/Healthy/IIR0103_anterior.jpg...
Processing /content/drive/MyDriv

In [None]:
# #Show the images selected
# # List to store valid image paths
# img_paths = []

# # Iterate through subfolders and image numbers
# for subfolder in subfolders:
#     folder_path = os.path.join(base_path, subfolder)
#     for i in range(1, 329):  # Iterate from 1 to 372
#         # Construct the image filename
#         image_name = f"IIR{i:04d}_anterior.jpg"  # Format number with leading zeros
#         image_path = os.path.join(folder_path, image_name)

#         # Check if the file exists
#         if os.path.isfile(image_path):
#             img_paths.append(image_path)


# # Function to display the images
# def plot_original_images():
#     plt.figure(figsize=(2 * N_COL_IMGS, 2 * N_ROW_IMGS), dpi=120)
#     for i in range(N_IMGS-1):
#         plt.subplot(N_ROW_IMGS, N_COL_IMGS, i + 1)
#         plt.xticks([])
#         plt.yticks([])
#         # Read and display the image from the local path
#         plt.imshow(mpimg.imread(img_paths[i]))
#         plt.title(f"Image {i+1}")  # Optional: Add titles for clarity
#     plt.tight_layout()
#     plt.show()
#     plt.close()

# # Call the function to display the images
# plot_original_images()





In [None]:

#Splitting the Dataset (into training, validation and testing set) using StratifiedKFolds
#imports and configuration
import os
import torch
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split
import shutil


#number of k folds to use
n_splits = 5
#folder paths for saving split data and transformed images
output_folder = "/content/drive/MyDrive/Dataset/Split_Folds"
dinov2_images_dir = "/content/drive/MyDrive/Dataset/DINOv2_images"
os.makedirs(output_folder, exist_ok=True)

#check if the folds already exist
folds = []
found_all_folds = True

#try loading all 5 pre-saved fold files
for fold in range(1, n_splits + 1):
    fold_path = os.path.join(output_folder, f"dinov2_fold_{fold}.pt")
    if os.path.exists(fold_path):
        print(f"Fold {fold} found. Loading from file.")
        fold_data = torch.load(fold_path, weights_only=False)
        folds.append(fold_data)
    else:
        print(f"Fold {fold} not found. Will re-run data splitting.")
        found_all_folds = False
        break

#print the class distribution of the existing folds
if found_all_folds:
    print("\nLoaded Fold Stats:")

    for i, fold in enumerate(folds):
        print(f"\nFold {i + 1} - Dataset Breakdown:")

        for split_name, labels in {
            "Train": fold["y_train"],
            "Validate-Test": fold["y_test"],
            "Final Test": fold["y_final"]
        }.items():
            total = len(labels)
            healthy = sum(1 for label in labels if label == "Healthy")
            unhealthy = sum(1 for label in labels if label == "Unhealthy")
            print(f"  {split_name}: {total} images ({healthy} Healthy, {unhealthy} Unhealthy)")


#if any of the folds isn't foundm run stratified k-fold splitting
if not found_all_folds:
    print("\nRe-running StratifiedKFold splitting and saving folds...")

    labels = np.array(img_labels)   # labels as NumPy array
    all_indices = np.arange(len(imgs_tensor)) #original indices for tracking

    #split into 80% train/test and 20% final test and track the original indices
    X_train_test, X_final, y_train_test, y_final, train_test_indices, final_indices = train_test_split(
        imgs_tensor, labels, all_indices, test_size=0.20, stratify=labels, shuffle=True, random_state=42
    )

    print(f"\nAfter split: {X_train_test.shape[0]} train/test, {X_final.shape[0]} final test")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_index, test_index) in enumerate(skf.split(X_train_test, y_train_test)):
        X_train, X_test = X_train_test[train_index], X_train_test[test_index]
        y_train, y_test = y_train_test[train_index], y_train_test[test_index]

        # create the fold dictionary
        fold_data = {
            "X_train": X_train, "y_train": y_train,
            "X_test": X_test, "y_test": y_test,
            "X_final": X_final, "y_final": y_final, #final set shared across all folds
        }
        folds.append(fold_data)

        #save the newly created fold
        fold_path = os.path.join(output_folder, f"dinov2_fold_{fold + 1}.pt")
        torch.save(fold_data, fold_path)
        print(f"Saved fold {fold + 1} -> {fold_path}")

        split_images_root = os.path.join(output_folder, "SPLITS", f"fold_{fold + 1}")
        os.makedirs(split_images_root, exist_ok=True)

        #map split names to indices and labels
        split_mapping = {
            "Train": (train_index, y_train),
            "Validate-Test": (test_index, y_test),
            "Final Test": (final_indices, y_final)
        }

        print(f"\n Fold {fold + 1} - Dataset Breakdown:")

        for split_name, (indices, y_labels) in split_mapping.items():
            split_folder = os.path.join(split_images_root, split_name)
            os.makedirs(split_folder, exist_ok=True)

            #print split summary
            total = len(indices)
            healthy = sum(1 for label in y_labels if label == "Healthy")
            unhealthy = sum(1 for label in y_labels if label == "Unhealthy")
            print(f"  {split_name}: {total} images ({healthy} Healthy, {unhealthy} Unhealthy)")

            #copy image files into the appropriate split folder
            for idx in indices:
                original_index = idx if split_name == "Final Test" else train_test_indices[idx]

                orig_img_path = img_paths[original_index]
                base_name = os.path.basename(orig_img_path).split('.')[0]
                transformed_name = f"{base_name}_DINOv2.jpg"
                transformed_path = os.path.join(dinov2_images_dir, transformed_name)

                if os.path.exists(transformed_path):
                    shutil.copy(transformed_path, os.path.join(split_folder, transformed_name))
                else:
                    print(f"Missing transformed image: {transformed_path}")



Fold 1 not found. Will re-run data splitting.

Re-running StratifiedKFold splitting and saving folds...

After split: 113 train/test, 29 final test
Saved fold 1 -> /content/drive/MyDrive/Dataset/Split_FoldsBalanced/dinov2_fold_1.pt

 Fold 1 - Dataset Breakdown:
  Train: 90 images (45 Healthy, 45 Unhealthy)
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0022_anterior_DINOv2.jpg
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0002_anterior_DINOv2.jpg
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0093_anterior_DINOv2.jpg
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0098_anterior_DINOv2.jpg
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0065_anterior_DINOv2.jpg
Missing transformed image: /content/drive/MyDrive/Dataset/DINOv2_imagesBalanced/IIR0061_anterior_DINOv2.jpg
Missing transformed image: /content/drive/My

In [7]:
#Mount the drive content
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
path='/content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt'
!ls /content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt 

/content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt


In [13]:
# Instructions:
# From https://github.com/vpariza/NeCo?tab=readme-ov-file download the student Dinov2 ViT-S/14 from  https://1drv.ms/u/c/67fac29a77adbae6/EWvXdau9r6NIr-vIc_xDlxAB1sDrljoaPR_A3JhIEeE8dw?e=pOXEXG
# upload to google drive
# and find the path to the file on google drive
path = "/content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt"
!ls /content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt

/content/drive/MyDrive/neco_on_dinov2r_vit14_model.ckpt


In [14]:
#load the small model
#import torch #Import the torch library
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model.load_state_=dict(torch.load(path), strict=False)


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 289MB/s]


In [None]:
print(model)

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [15]:
#feature extraction using dinov2
#tensor = torch.randn(N_IMGS, 3, 224, 224) # 16*14 =224

#use preprocessed image tensor (already normalized, resized to 224x224)
#shape must be [N, 3 channels, 224, 224]
tensor=imgs_tensor
display(tensor.shape)

#load dinov2 from torchhub
#depending on configuration, load DINOv2 ViT model with or without registers
if use_registers:
  model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vit{}14_reg'.format(model_size))
else:
  model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vit{}14'.format(model_size))
#forward pass: get dictionary of intermediate features
features_dict = model.forward_features(tensor)
#display available keys in the feature dictionary
display(features_dict.keys())
#show the shape of each important feature tensor
display(features_dict['x_norm_clstoken'].shape)
display(features_dict['x_norm_regtokens'].shape)
display(features_dict['x_norm_patchtokens'].shape)
display(features_dict['x_prenorm'].shape)


#print the final name of the model used to see if its with registers
model_name = f'dinov2_vit{model_size}14' + ('_reg' if use_registers else '')
print("Loaded model:", model_name)


torch.Size([142, 3, 224, 224])

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


dict_keys(['x_norm_clstoken', 'x_norm_regtokens', 'x_norm_patchtokens', 'x_prenorm', 'masks'])

torch.Size([142, 384])

torch.Size([142, 0, 384])

torch.Size([142, 256, 384])

torch.Size([142, 257, 384])

Loaded model: dinov2_vits14


In [None]:

# select feature dimensions based on the model used
feat_dims = {
    's': 384,    # ViT-S/14 small
    'b': 768,    # ViT-B/14 base
    'l': 1024,   # ViT-L/14 large
    'g': 1536,   # ViT-G/14 giant
}
feat_dim = feat_dims[model_size]  # model_size = 's', 'b', 'l', or 'g'

# Transform (only normalize because images are already resized & colorized)
dinov2_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

# load the dinov2 model with or without registers
if use_registers:
    model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_size}14_reg')
else:
    model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_size}14')

if not use_extended_dinov2:
    model.load_state_dict(torch.load(path), strict=False)

#load the transformed image paths
dinov2_img_paths = [
    os.path.join("/content/drive/MyDrive/Dataset/DINOv2_images", os.path.basename(p).split('.')[0] + "_DINOv2.jpg")
    for p in img_paths
]

#load and normalize images
N_IMGS = len(dinov2_img_paths)
imgs_tensor = torch.zeros(N_IMGS, 3, 224, 224)

for i, path in enumerate(dinov2_img_paths):
    img = Image.open(path).convert("RGB")
    img_tensor = dinov2_transform(img)

    #check for safety
    if img_tensor.shape != (3, 224, 224):
        print(f"Shape mismatch: {path} | Got: {img_tensor.shape}")
    else:
        print(f"Loaded and normalized: {path}")

    imgs_tensor[i] = img_tensor

#move to device and extract [CLS] features
with torch.no_grad():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    imgs_tensor = imgs_tensor.to(device)

    features_dict = model.forward_features(imgs_tensor)
    cls_features = features_dict['x_norm_clstoken'].cpu()  # This is your [CLS] token

#check
print(f" Extracted [CLS] features | Shape: {cls_features.shape}")


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0092_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0094_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0055_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0042_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0074_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0066_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0083_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0077_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_imagesBALANCED/IIR0045_anterior_DINOv2.jpg
Loaded and normalized: /content/drive/MyDrive/Dataset/DINOv2_ima

In [None]:
# !nvidia-smi


In [None]:
# # Apply PCA on the features of each patch across all the images in our set
# features = features.reshape(N_IMGS * patch_h * patch_w, feat_dim)
# print('Features Shape: ',features.shape)
# pca = PCA(n_components=3)
# pca.fit(features)
# pca_features = pca.transform(features)

In [None]:
# # visualize PCA components for finding a proper threshold
# plt.subplot(1, 3, 1)
# plt.hist(pca_features[:, 0])
# plt.subplot(1, 3, 2)
# plt.hist(pca_features[:, 1])
# plt.subplot(1, 3, 3)
# plt.hist(pca_features[:, 2])
# plt.show()
# plt.close()

In [None]:
# minmax_scale = lambda pca_features: (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())

In [None]:
# # visualize PCA components for finding a proper threshold
# plt.subplot(1, 3, 1)
# plt.hist(minmax_scale(pca_features[:, 0]))
# plt.subplot(1, 3, 2)
# plt.hist(minmax_scale(pca_features[:, 1]))
# plt.subplot(1, 3, 3)
# plt.hist(minmax_scale(pca_features[:, 2]))
# plt.show()
# plt.close()

In [None]:
# # The threshold to be used for separating the foreground from the background,
# # using the major PCA component (i.e., the first one at position 0) of each
# # patch
# PCA_THRESHOLD = 10
# PCA_THRESHOLD_RATIO = 0.6

In [None]:
# # plot the first pca component
# # Min - Max normalization of first PCA component for all images
# pca_features_norm = pca_features.copy()
# pca_features_norm[:, 0] = minmax_scale(pca_features_norm[:, 0])

# plt.figure(figsize=(2 * (N_IMGS / N_ROW_IMGS), 2 * (N_IMGS / N_COL_IMGS)), dpi=80)
# for i in range(N_IMGS):
#     # plot the first pca component for eacg image
#     plt.subplot(N_ROW_IMGS, N_COL_IMGS, i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.imshow(pca_features_norm[i * patch_h * patch_w: (i+1) * patch_h * patch_w, 0].reshape(patch_h, patch_w))
# plt.show()
# plt.close()

In [None]:

# ######## APPROACH 1 ########
# # Threshold with an absolute value

# # segment foreground from background using the first component
# # Background identified by the components below the threshold
# # pca_features_bg_mask = pca_features[:, 0] < PCA_THRESHOLD
# # # Foreground identified by the components above the threshold
# # pca_features_fg_mask = ~pca_features_bg_mask

# ######## APPROACH 2 ########
# # Threshold with an absolute value from the normalized values

# # segment foreground from background using the first component
# # Background identified by the components below the threshold
# pca_features_bg_mask = pca_features_norm[:, 0] < PCA_THRESHOLD_RATIO
# # Foreground identified by the components above the threshold
# pca_features_fg_mask = ~pca_features_bg_mask

In [None]:
# # plot the background using the pca_features_bg
# plt.figure(figsize=(2 * (N_IMGS / N_ROW_IMGS), 2 * (N_IMGS / N_COL_IMGS)), dpi=80)
# for i in range(N_IMGS):
#     plt.subplot(N_ROW_IMGS, N_COL_IMGS, i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.imshow(pca_features_bg_mask[i * patch_h * patch_w: (i+1) * patch_h * patch_w].reshape(patch_h, patch_w))
# plt.show()

In [None]:
# # plot the background using the pca_features_bg
# plt.figure(figsize=(2 * (N_IMGS / N_ROW_IMGS), 2 * (N_IMGS / N_COL_IMGS)), dpi=80)
# for i in range(N_IMGS):
#     plt.subplot(N_ROW_IMGS, N_COL_IMGS, i+1)
#     plt.xticks([])
#     plt.yticks([])
#     mask = pca_features_bg_mask[i * patch_h * patch_w: (i+1) * patch_h * patch_w].reshape(patch_h, patch_w)
#     mask = np.repeat(np.expand_dims(mask,axis=-1),3,axis=-1)
#     c = np.zeros((patch_h, patch_w,3))
#     c[:,:,0] = 1.0

#     img = np.where(mask, np.zeros((patch_h, patch_w,3)), c)
#     plt.imshow(img)
# plt.show()

In [None]:
# # PCA for only foreground patches
# pca.fit(features[pca_features_fg_mask])
# pca_features_rem = pca.transform(features[pca_features_fg_mask])

# for i in range(3):
#     # pca_features_rem[:, i] = (pca_features_rem[:, i] - pca_features_rem[:, i].min()) / (pca_features_rem[:, i].max() - pca_features_rem[:, i].min())
#     # transform using mean and std, I personally, this transformation seems to give a better visualization
#     pca_features_rem[:, i] = (pca_features_rem[:, i] - pca_features_rem[:, i].mean()) / (pca_features_rem[:, i].std() ** 2) + 0.5

# pca_features_rgb = pca_features.copy()
# pca_features_rgb[pca_features_bg_mask] = 0
# pca_features_rgb[pca_features_fg_mask] = pca_features_rem

# pca_features_rgb = pca_features_rgb.reshape(N_IMGS, patch_h, patch_w, 3)
# plt.figure(figsize=(2 * (N_IMGS / N_ROW_IMGS), 2 * (N_IMGS / N_COL_IMGS)), dpi=80)
# for i in range(N_IMGS):
#     plt.subplot(N_ROW_IMGS, N_COL_IMGS, i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.imshow(pca_features_rgb[i][..., ::-1])
# plt.savefig('features.png')
# plt.show()
# plt.close()

In [None]:
#Pure Linear HEAD https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb?fbclid=IwY2xjawIq_NxleHRuA2FlbQIxMAABHaJ21bD3gYnAURGck-pM06xz8sNKEeC8JOTbOeSpow2dO99Qww8he1Fz1Q_aem_ZWCgSLLZnJld1lOzhsJ3Qw
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, f1_score
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
import glob


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#labeling the healthy and unhealthy class to numerical values
label_mapping = {"Healthy": 0, "Unhealthy": 1}
folds_folder = "/content/drive/MyDrive/Dataset/Split_Folds"

#freeze the entire model
for param in model.parameters():
    param.requires_grad = False

#define and attach linear head
num_features = model.norm.weight.shape[0]
model.fc = nn.Linear(num_features, 1).to(device)

#loop through each fold
for fold in range(1, 6):
    print(f"\n Fold {fold}")

    fold_path = os.path.join(folds_folder, f"dinov2_fold_{fold}.pt")
    data = torch.load(fold_path, weights_only=False)
    final_test_folder = os.path.join("/content/drive/MyDrive/Dataset/Split_Folds/SPLITS", f"fold_{fold}", "Final Test")
    final_filenames = sorted([os.path.basename(f) for f in glob.glob(os.path.join(final_test_folder, "*"))])

    #prepare list to collect misclassified samples
    misclassified_final_samples = []

    def encode_labels(y):
        return torch.tensor([label_mapping[label] for label in y], dtype=torch.float32)

    X_train = data["X_train"].to(device)
    y_train = encode_labels(data["y_train"]).to(device)
    X_test = data["X_test"].to(device)
    y_test = encode_labels(data["y_test"]).to(device)
    X_final = data["X_final"].to(device)
    y_final = encode_labels(data["y_final"]).to(device)

    #feature extraction with debugging
    try:
        print("Extracting DINOv2 features...")
        with torch.no_grad():
            scaler = StandardScaler()

            f_train = model.forward_features(X_train)["x_norm_clstoken"]
            f_test = model.forward_features(X_test)["x_norm_clstoken"]
            f_final = model.forward_features(X_final)["x_norm_clstoken"]

            f_train = scaler.fit_transform(f_train.cpu())
            f_test = scaler.transform(f_test.cpu())
            f_final = scaler.transform(f_final.cpu())

            f_train = torch.tensor(f_train, device=device).float()
            f_test = torch.tensor(f_test, device=device).float()
            f_final = torch.tensor(f_final, device=device).float()

        print("Feature extraction successful.")
    except Exception as e:
        print(f"Feature extraction failed: {e}")
        continue

    #prepare data loaders of each set
    train_loader = DataLoader(TensorDataset(f_train, y_train), batch_size=32, shuffle=True)
    test_loader = DataLoader(TensorDataset(f_test, y_test), batch_size=32)
    final_loader = DataLoader(TensorDataset(f_final, y_final), batch_size=32)

    #loss & optimizer
    pos_weight = torch.tensor([1.35], device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.AdamW(model.fc.parameters(), lr=1e-3)

    #training loop
    n_epochs = 50
    for epoch in range(n_epochs):

        #train
        model.fc.train()
        train_loss = 0.0
        train_preds_raw, train_labels_list = [], []

        for features, labels in train_loader:
            optimizer.zero_grad()
            outputs = model.fc(features).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_preds_raw.extend(outputs.detach().cpu().numpy())
            train_labels_list.extend(labels.cpu().numpy())

        #validate
        model.fc.eval()
        val_preds_raw, val_labels_list = [], []

        with torch.no_grad():
            for features, labels in test_loader:
                outputs = model.fc(features).squeeze(1)
                val_preds_raw.extend(outputs.cpu().numpy())
                val_labels_list.extend(labels.cpu().numpy())

        fpr, tpr, thresholds = roc_curve(val_labels_list, val_preds_raw)
        high_sens_mask = tpr >= 0.90
        candidate_thresholds = thresholds[high_sens_mask]

        if candidate_thresholds.size > 0:
            f1_scores = []
            for i, thresh in enumerate(candidate_thresholds):
                preds = (np.array(val_preds_raw) > thresh).astype(int)
                if (tpr[high_sens_mask][i] >= 0.90):
                    f1_scores.append(f1_score(val_labels_list, preds))
                else:
                    f1_scores.append(0)
            best_thresh = candidate_thresholds[np.argmax(f1_scores)]
        else:
            print("No threshold satisfies sensitivity ≥ 0.90 — using 0.5 fallback.")
            best_thresh = 0.5

        #metrics: Training
        train_preds = (np.array(train_preds_raw) > best_thresh).astype(int)
        train_acc = accuracy_score(train_labels_list, train_preds)
        tn, fp, fn, tp = confusion_matrix(train_labels_list, train_preds).ravel()
        train_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        train_spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        train_f1 = f1_score(train_labels_list, train_preds)

        #metrics: Validation
        val_preds = (np.array(val_preds_raw) > best_thresh).astype(int)
        val_acc = accuracy_score(val_labels_list, val_preds)
        tn, fp, fn, tp = confusion_matrix(val_labels_list, val_preds).ravel()
        val_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        val_spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        val_f1 = f1_score(val_labels_list, val_preds)

             #validation loss
        validation_loss = 0.0
        with torch.no_grad():
            for features, labels in test_loader:
                outputs = model.fc(features).squeeze(1)
                loss = criterion(outputs, labels)
                validation_loss += loss.item()
        validation_loss /= len(test_loader)


        #metrics: Final Test
        final_preds_raw, final_labels = [], []
        test_loss = 0.0
        with torch.no_grad():
            for features, labels in final_loader:
                outputs = model.fc(features).squeeze(1)
                final_preds_raw.extend(outputs.cpu().numpy())
                final_labels.extend(labels.cpu().numpy())
                loss = criterion(outputs, labels)
                test_loss += loss.item()
        test_loss /= len(final_loader)

        final_preds = (np.array(final_preds_raw) > best_thresh).astype(int)
        final_acc = accuracy_score(final_labels, final_preds)
        tn, fp, fn, tp = confusion_matrix(final_labels, final_preds).ravel()
        final_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        final_spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        final_f1 = f1_score(final_labels, final_preds)
        #debug
        #print(f"Fold {fold}: {len(final_preds)} predictions vs {len(final_filenames)} filenames")
        #identify and print misclassified final test images
        for global_idx, (pred, label) in enumerate(zip(final_preds, final_labels)):
         if pred != label:
            filename = final_filenames[global_idx]
            misclassified_final_samples.append((fold, epoch+1, filename, int(label), int(pred)))


        #epoch summary
        print(f"Epoch {epoch+1}/{n_epochs} | "
      f"Train | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Sens: {train_sens:.4f} | Spec: {train_spec:.4f} | F1: {train_f1:.4f} | "
      f"Validate | Acc: {val_acc:.4f} | Sens: {val_sens:.4f} | Spec: {val_spec:.4f} | F1: {val_f1:.4f} || "
      f"Final Test | Acc: {final_acc:.4f} | Sens: {final_sens:.4f} | Spec: {final_spec:.4f} | F1: {final_f1:.4f} | "
      f"TP: {tp} | FP: {fp} | FN: {fn} | TN: {tn} | "
      f"Validation Loss: {validation_loss:.4f} | Test Loss: {test_loss:.4f}")


      #after all epochs of the fold finished ===
        #print(f"\nMisclassified images for Fold {fold}:")
        for fold_id, epoch_id, filename, true_label, pred_label in misclassified_final_samples:
         if epoch_id == n_epochs:  # only print misclassifications from final epoch
            print(f"[Fold {fold_id}] {filename} | True: {true_label} → Predicted: {pred_label}")

#save misclassified Final Test images for this fold ===
output_path = f"/content/drive/MyDrive/Dataset/misclassified_fold_{fold}.txt"
with open(output_path, "w") as f:
  for fold_id, epoch_id, filename, true_label, pred_label in misclassified_final_samples:
    f.write(f"Fold {fold_id}, Epoch {epoch_id}, File: {filename}, True Label: {true_label}, Predicted: {pred_label}\n")





 Fold 1
Extracting DINOv2 features...
Feature extraction successful.
Epoch 1/50 | Train | Loss: 2.2939 | Acc: 0.5667 | Sens: 0.8889 | Spec: 0.2444 | F1: 0.6723 | Validate | Acc: 0.6522 | Sens: 0.8333 | Spec: 0.4545 | F1: 0.7143 || Final Test | Acc: 0.5172 | Sens: 0.8571 | Spec: 0.2000 | F1: 0.6316 | TP: 12 | FP: 12 | FN: 2 | TN: 3 | Validation Loss: 0.8059 | Test Loss: 0.8023
Epoch 2/50 | Train | Loss: 2.0155 | Acc: 0.6444 | Sens: 0.9778 | Spec: 0.3111 | F1: 0.7333 | Validate | Acc: 0.6957 | Sens: 0.9167 | Spec: 0.4545 | F1: 0.7586 || Final Test | Acc: 0.5172 | Sens: 0.7857 | Spec: 0.2667 | F1: 0.6111 | TP: 11 | FP: 11 | FN: 3 | TN: 4 | Validation Loss: 0.7539 | Test Loss: 0.7628
Epoch 3/50 | Train | Loss: 1.8671 | Acc: 0.6667 | Sens: 0.9556 | Spec: 0.3778 | F1: 0.7414 | Validate | Acc: 0.6957 | Sens: 0.9167 | Spec: 0.4545 | F1: 0.7586 || Final Test | Acc: 0.5862 | Sens: 0.8571 | Spec: 0.3333 | F1: 0.6667 | TP: 12 | FP: 10 | FN: 2 | TN: 5 | Validation Loss: 0.7155 | Test Loss: 0.7361


In [None]:
# Classification using MLP head (no block unfreezing + sensitivity-prioritized threshold)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, f1_score
import os
import glob

# Custom dataset class with optional augmentation
class AugmentedTensorDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        y = self.y[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

# Ensure model is loaded
assert 'model' in globals(), "DINOv2 model not loaded!"

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Freeze all ViT parameters (no unfreezing)
for param in model.parameters():
    param.requires_grad = False

# Define MLP head (trainable by default)
num_features = model.norm.weight.shape[0]
model.fc = nn.Sequential(
    nn.Linear(num_features, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 1)
).to(device)

folds_folder = "/content/drive/MyDrive/Dataset/Split_FoldsBalanced"
label_mapping = {"Healthy": 0, "Unhealthy": 1}
n_epochs = 50
batch_size = 32
learning_rate = 2e-4
weight_decay = 1e-4

# Loop through each fold
for fold in range(1, 6):
    print(f"\nTraining on Fold {fold}...")
    final_test_folder = os.path.join("/content/drive/MyDrive/Dataset/Split_Folds/SPLITSBalanced", f"fold_{fold}", "Final Test")
    final_filenames = sorted([os.path.basename(f) for f in glob.glob(os.path.join(final_test_folder, "*"))])

    fold_path = os.path.join(folds_folder, f"dinov2_fold_{fold}.pt")
    data = torch.load(fold_path, map_location=device, weights_only=False)

    X_train = data["X_train"].float().to(device)
    y_train = torch.tensor([label_mapping[l] for l in data["y_train"]], dtype=torch.float32, device=device)
    X_test = data["X_test"].float().to(device)
    y_test = torch.tensor([label_mapping[l] for l in data["y_test"]], dtype=torch.float32, device=device)
    X_final = data["X_final"].float().to(device)
    y_final = torch.tensor([label_mapping[l] for l in data["y_final"]], dtype=torch.float32, device=device)

    pos_weight = torch.tensor([len(y_train) / (2 * y_train.sum())], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.AdamW(model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay)

    mean = X_train.mean(dim=(0, 2, 3), keepdim=True)
    std = X_train.std(dim=(0, 2, 3), keepdim=True) + 1e-6

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(45),
        transforms.ColorJitter(brightness=0.6, contrast=0.6),
        transforms.RandomErasing(p=0.7, scale=(0.02, 0.25)),
        transforms.Normalize(mean=mean.squeeze(), std=std.squeeze()),
    ])
    test_transform = transforms.Compose([
        transforms.Normalize(mean=mean.squeeze(), std=std.squeeze())
    ])

    train_dataset = AugmentedTensorDataset(X_train, y_train, train_transform)
    test_dataset = AugmentedTensorDataset(X_test, y_test, test_transform)
    final_dataset = AugmentedTensorDataset(X_final, y_final, test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    final_loader = DataLoader(final_dataset, batch_size=batch_size)

    misclassified_final_samples = []

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        train_preds_raw, train_labels = [], []

        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_preds_raw.extend(outputs.detach().cpu().numpy())
            train_labels.extend(batch_y.cpu().numpy())

        model.eval()
        val_preds_raw, val_labels = [], []
        val_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in test_loader:
                outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
                loss = criterion(outputs, batch_y)
                val_loss += loss.item()
                val_preds_raw.extend(outputs.cpu().numpy())
                val_labels.extend(batch_y.cpu().numpy())

        # Threshold selection: Sens ≥ 0.90, then max F1
        fpr, tpr, thresholds = roc_curve(val_labels, val_preds_raw)
        high_sens_mask = tpr >= 0.90
        candidate_thresholds = thresholds[high_sens_mask]

        if candidate_thresholds.size > 0:
            f1_scores = []
            for i, thresh in enumerate(candidate_thresholds):
                preds = (np.array(val_preds_raw) > thresh).astype(int)
                if (tpr[high_sens_mask][i] >= 0.90):
                    f1_scores.append(f1_score(val_labels, preds))
                else:
                    f1_scores.append(0)
            best_threshold = candidate_thresholds[np.argmax(f1_scores)]
        else:
            print("No threshold satisfies sensitivity ≥ 0.90 — using 0.5 fallback.")
            best_threshold = 0.5

        # Train metrics
        train_preds = (np.array(train_preds_raw) > best_threshold).astype(int)
        train_acc = accuracy_score(train_labels, train_preds)
        tn, fp, fn, tp = confusion_matrix(train_labels, train_preds).ravel()
        train_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        train_spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        train_f1 = f1_score(train_labels, train_preds)

        # Validation metrics
        val_preds = (np.array(val_preds_raw) > best_threshold).astype(int)
        val_f1 = f1_score(val_labels, val_preds)
        val_acc = accuracy_score(val_labels, val_preds)
        tn, fp, fn, tp = confusion_matrix(val_labels, val_preds).ravel()
        val_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        val_spec = tn / (tn + fp) if (tn + fp) > 0 else 0

        # Final test metrics
        final_preds_raw, final_labels = [], []
        final_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in final_loader:
                outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
                loss = criterion(outputs, batch_y)
                final_loss += loss.item()
                final_preds_raw.extend(outputs.cpu().numpy())
                final_labels.extend(batch_y.cpu().numpy())

        final_preds = (np.array(final_preds_raw) > best_threshold).astype(int)
        final_f1 = f1_score(final_labels, final_preds)
        final_acc = accuracy_score(final_labels, final_preds)
        tn_f, fp_f, fn_f, tp_f = confusion_matrix(final_labels, final_preds).ravel()
        final_sens = tp_f / (tp_f + fn_f) if (tp_f + fn_f) > 0 else 0
        final_spec = tn_f / (tn_f + fp_f) if (tn_f + fp_f) > 0 else 0

        print(f"Epoch {epoch+1}/{n_epochs} |  Train | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Sens: {train_sens:.4f} | Spec: {train_spec:.4f} | F1: {train_f1:.4f} | "
        f"Validate | Loss : {val_loss:.4f} | Acc: {val_acc:.4f} | Sens: {val_sens:.4f} | Spec: {val_spec:.4f} | F1: {val_f1:.4f} || "
        f"Final Test | Loss : {final_loss:.4f} | Acc: {final_acc:.4f} | Sens: {final_sens:.4f} | Spec: {final_spec:.4f} | F1: {final_f1:.4f} | "
       f"TP: {tp_f} | FP: {fp_f} | FN: {fn_f} | TN: {tn_f}")

        if epoch == n_epochs - 1:
            for global_idx, (pred, label) in enumerate(zip(final_preds, final_labels)):
                if pred != label:
                    filename = final_filenames[global_idx]
                    misclassified_final_samples.append((fold, epoch+1, filename, int(label), int(pred)))

    print(f"\nMisclassified images for Fold {fold}:")
    for fold_id, epoch_id, filename, true_label, pred_label in misclassified_final_samples:
        print(f"[Fold {fold_id}] {filename} | True: {true_label} → Predicted: {pred_label}")

In [None]:


#Classification using MLP head using Youden's Index and unfreezing block 11
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, f1_score
import os
import glob

#custom dataset class with optional augmentation
class AugmentedTensorDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.X = X #input features
        self.y = y #labels 1 or 0
        self.transform = transform #optional data augm

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        y = self.y[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

#ensure model is loaded
assert 'model' in globals(), "DINOv2 model not loaded!"
#set the device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define final_filenames for tracking misclassified images
final_test_folder = os.path.join("/content/drive/MyDrive/Dataset/Split_FoldsBalanced/SPLITS", f"fold_{fold}", "Final Test")
final_filenames = sorted([os.path.basename(f) for f in glob.glob(os.path.join(final_test_folder, "*"))])

#freeze all model parameters except for final transformer block and classfication head
for param in model.parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if "block.11" in name or "fc" in name:
        param.requires_grad = True

#MLP HEAD
num_features = model.norm.weight.shape[0]
model.fc = nn.Sequential(
    nn.Linear(num_features, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 1)
).to(device)


folds_folder = "/content/drive/MyDrive/Dataset/Split_FoldsBalanced"
label_mapping = {"Healthy": 0, "Unhealthy": 1} #map classes
n_epochs = 50
batch_size = 32
learning_rate_fc = 2e-4
learning_rate_vit = 5e-7
weight_decay = 1e-4

#loop over each fold
for fold in range(1, 6):
    print(f"\nTraining on Fold {fold}...")
 #load fold data
    fold_path = os.path.join(folds_folder, f"dinov2_fold_{fold}.pt")
    data = torch.load(fold_path, map_location=device, weights_only=False)
#extract features and labels and move to device
    X_train = data["X_train"].float().to(device)
    y_train = torch.tensor([label_mapping[l] for l in data["y_train"]], dtype=torch.float32, device=device)
    X_test = data["X_test"].float().to(device)
    y_test = torch.tensor([label_mapping[l] for l in data["y_test"]], dtype=torch.float32, device=device)
    X_final = data["X_final"].float().to(device)
    y_final = torch.tensor([label_mapping[l] for l in data["y_final"]], dtype=torch.float32, device=device)
    #handle class imbalance by using a positive class weight
    pos_weight = torch.tensor([len(y_train) / (2 * y_train.sum())], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    #optimizer for mlp and last block
    optimizer = optim.AdamW([
        {'params': model.fc.parameters(), 'lr': learning_rate_fc},
        {'params': model.blocks[11:].parameters(), 'lr': learning_rate_vit},
    ], weight_decay=weight_decay)

    #learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    #compute mean and std for normalization
    mean = X_train.mean(dim=(0, 2, 3), keepdim=True)
    std = X_train.std(dim=(0, 2, 3), keepdim=True) + 1e-6

    #define data augmentation
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(45),
        transforms.ColorJitter(brightness=0.6, contrast=0.6),
        transforms.RandomErasing(p=0.7, scale=(0.02, 0.25)),
        transforms.Normalize(mean=mean.squeeze(), std=std.squeeze()),
    ])
    test_transform = transforms.Compose([transforms.Normalize(mean=mean.squeeze(), std=std.squeeze())])

    #create datasets and data loaders
    train_dataset = AugmentedTensorDataset(X_train, y_train, train_transform)
    test_dataset = AugmentedTensorDataset(X_test, y_test, test_transform)
    final_dataset = AugmentedTensorDataset(X_final, y_final, test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    final_loader = DataLoader(final_dataset, batch_size=batch_size, shuffle=False)
    #create a list to track misclassified samples
    misclassified_final_samples = []

    #training loop
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        train_preds_raw, train_labels = [], []

        #TRAINING
        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            #forward pass through vit and mlp
            outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_preds_raw.extend(outputs.detach().cpu().numpy())
            train_labels.extend(batch_y.cpu().numpy())

        #optimal threshold using roc curve
        fpr, tpr, thresholds = roc_curve(train_labels, train_preds_raw)
        best_threshold = thresholds[np.argmax(tpr - fpr)]
        train_preds = (np.array(train_preds_raw) > best_threshold).astype(int)
        train_f1 = f1_score(train_labels, train_preds)
        train_acc = accuracy_score(train_labels, train_preds)
        tn, fp, fn, tp = confusion_matrix(train_labels, train_preds).ravel()
        train_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        train_spec = tn / (tn + fp) if (tn + fp) > 0 else 0

        #VALIDATION
        model.eval()
        val_preds_raw, val_labels = [], []
        val_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in test_loader:
                outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
                loss = criterion(outputs, batch_y)
                val_loss += loss.item()
                val_preds_raw.extend(outputs.cpu().numpy())
                val_labels.extend(batch_y.cpu().numpy())

        val_preds = (np.array(val_preds_raw) > best_threshold).astype(int)
        val_f1 = f1_score(val_labels, val_preds)
        val_acc = accuracy_score(val_labels, val_preds)
        tn, fp, fn, tp = confusion_matrix(val_labels, val_preds).ravel()
        val_sens = tp / (tp + fn) if (tp + fn) > 0 else 0
        val_spec = tn / (tn + fp) if (tn + fp) > 0 else 0

        #FINAL TESTING
        final_preds_raw, final_labels = [], []
        final_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in final_loader:
                outputs = model.fc(model.forward_features(batch_X)["x_norm_clstoken"]).squeeze(1)
                loss = criterion(outputs, batch_y)
                final_loss += loss.item()
                final_preds_raw.extend(outputs.cpu().numpy())
                final_labels.extend(batch_y.cpu().numpy())

        final_preds = (np.array(final_preds_raw) > best_threshold).astype(int)
        final_f1 = f1_score(final_labels, final_preds)
        final_acc = accuracy_score(final_labels, final_preds)
        tn_f, fp_f, fn_f, tp_f = confusion_matrix(final_labels, final_preds).ravel()
        final_sens = tp_f / (tp_f + fn_f) if (tp_f + fn_f) > 0 else 0
        final_spec = tn_f / (tn_f + fp_f) if (tn_f + fp_f) > 0 else 0


        print(f"Epoch {epoch+1}/{n_epochs} |  Train | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Sens: {train_sens:.4f} | Spec: {train_spec:.4f} | F1: {train_f1:.4f} | "
      f"Validate | Loss : {val_loss:.4f} | Acc: {val_acc:.4f} | Sens: {val_sens:.4f} | Spec: {val_spec:.4f} | F1: {val_f1:.4f} || "
      f"Final Test | Loss : {final_loss:.4f} | Acc: {final_acc:.4f} | Sens: {final_sens:.4f} | Spec: {final_spec:.4f} | F1: {final_f1:.4f} | "
      f"TP: {tp_f} | FP: {fp_f} | FN: {fn_f} | TN: {tn_f}")


  #       #track misclassified images in final test
  #   if epoch == n_epochs - 1:  # only track misclassified images after the final epoch
  #       for global_idx, (pred, label) in enumerate(zip(final_preds, final_labels)):
  #           if pred != label:
  #               filename = final_filenames[global_idx]
  #               misclassified_final_samples.append((fold, epoch+1, filename, int(label), int(pred)))
  #       scheduler.step()
  #  #after all epochs for this fold, print and save the misclassified images
  #   print(f"\nMisclassified images for Fold {fold}:")
  #   for fold_id, epoch_id, filename, true_label, pred_label in misclassified_final_samples:
  #      if epoch_id == n_epochs:  # Only print misclassifications from final epoch
  #       print(f"[Fold {fold_id}] {filename} | True: {true_label} → Predicted: {pred_label}")







Training on Fold 1...
Epoch 1/50 |  Train | Loss: 2.2133 | Acc: 0.5556 | Sens: 0.3556 | Spec: 0.7556 | F1: 0.4444 | Validate | Loss : 0.7279 | Acc: 0.4783 | Sens: 0.0000 | Spec: 1.0000 | F1: 0.0000 || Final Test | Loss : 0.6926 | Acc: 0.5172 | Sens: 0.0000 | Spec: 1.0000 | F1: 0.0000 | TP: 0 | FP: 0 | FN: 14 | TN: 15
Epoch 2/50 |  Train | Loss: 2.2972 | Acc: 0.5000 | Sens: 0.0000 | Spec: 1.0000 | F1: 0.0000 | Validate | Loss : 0.7012 | Acc: 0.4783 | Sens: 0.0000 | Spec: 1.0000 | F1: 0.0000 || Final Test | Loss : 0.6793 | Acc: 0.5172 | Sens: 0.0000 | Spec: 1.0000 | F1: 0.0000 | TP: 0 | FP: 0 | FN: 14 | TN: 15
Epoch 3/50 |  Train | Loss: 2.0384 | Acc: 0.6000 | Sens: 0.7111 | Spec: 0.4889 | F1: 0.6400 | Validate | Loss : 0.6719 | Acc: 0.4783 | Sens: 0.5000 | Spec: 0.4545 | F1: 0.5000 || Final Test | Loss : 0.6636 | Acc: 0.5862 | Sens: 0.7857 | Spec: 0.4000 | F1: 0.6471 | TP: 11 | FP: 9 | FN: 3 | TN: 6
Epoch 4/50 |  Train | Loss: 1.8987 | Acc: 0.6778 | Sens: 0.5778 | Spec: 0.7778 | F1: 0.

In [None]:
# import torch

# # Flatten images for easier comparison
# X_train_tensor = fold_data["X_train"].reshape(fold_data["X_train"].shape[0], -1)  # (N, Flattened)
# X_test_tensor = fold_data["X_test"].reshape(fold_data["X_test"].shape[0], -1)  # (M, Flattened)

# #  Use broadcasting for efficient comparison
# tolerance = 1e-6  # Allow small differences due to floating point precision
# matches = (torch.abs(X_test_tensor.unsqueeze(1) - X_train_tensor) < tolerance).all(dim=2)

# #  Count duplicates
# overlapping_samples = matches.any(dim=1).sum().item()

# #  Print the number of overlapping images
# if overlapping_samples > 0:
#     print(f"⚠ WARNING: {overlapping_samples} test images are nearly identical to training images! Possible data leakage detected!")
# else:
#     print(" No identical test images found in the training set. No data leakage detected.")


In [None]:
# test_duplicates = matches.any(dim=1).sum().item()
# print(f"Test Images Detected as Duplicates: {test_duplicates}")

# # If duplicates exist, print a few to inspect manually
# duplicate_indices = torch.where(matches.any(dim=1))[0].cpu().numpy()
# for idx in duplicate_indices[:5]:  # Show first few duplicates
#     print(f"Test Image {idx} matches a training image.")


In [None]:
# import torch

# # Find training indices where the matches occur
# matching_train_indices = torch.where(matches)[1].cpu().numpy()  # Get training image indices

# # Print which training images match each duplicated test image
# for test_idx, train_idx in zip(duplicate_indices, matching_train_indices):
#     print(f"Test Image {test_idx} is identical to Training Image {train_idx}.")


In [None]:
# CHECK FOR DUPLICATES IN ORIGINAL DATASET

# Define a transformation to resize images and convert to tensor
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize to 224x224 (DINOv2 input size)
#     transforms.ToTensor(),  # Convert to tensor
# ])

# # Function to load all images from a folder
# def load_images_from_folder(folder):
#     images = []
#     image_filenames = []

#     for filename in os.listdir(folder):
#         file_path = os.path.join(folder, filename)
#         if filename.lower().endswith(('.png', '.jpg', '.jpeg')):  # Ensure it's an image
#             try:
#                 img = Image.open(file_path).convert("RGB")  # Ensure 3-channel RGB
#                 img_tensor = transform(img)  # Apply transformations
#                 images.append(img_tensor)
#                 image_filenames.append(filename)  # Store filenames for debugging
#             except Exception as e:
#                 print(f"Error loading {filename}: {e}")

#     return images, image_filenames

# # Load images from both classes
# healthy_images, healthy_filenames = load_images_from_folder(healthy_path)
# sick_images, sick_filenames = load_images_from_folder(sick_path)

# # Combine into a full dataset
# full_images = healthy_images + sick_images
# full_filenames = healthy_filenames + sick_filenames  # Keep filenames for debugging

# print(f"Total images loaded: {len(full_images)}")
# #  Step 2: Convert Images to Hashes for Duplicate Detection
# # Now that we have loaded all images as tensors, we can check for duplicates by hashing each image.

# # python
# # Copy

# def hash_image(image_tensor):
#     """Generate a hash for an image tensor to detect duplicates."""
#     return hashlib.md5(image_tensor.numpy().tobytes()).hexdigest()

# # Generate hashes for all images
# image_hashes = [hash_image(img) for img in full_images]

# # Find duplicates
# hash_counts = {}
# duplicates = []

# for i, hash_value in enumerate(image_hashes):
#     if hash_value in hash_counts:
#         duplicates.append((full_filenames[i], full_filenames[hash_counts[hash_value]]))  # Store duplicate filenames
#     else:
#         hash_counts[hash_value] = i

# # Print duplicate results
# if duplicates:
#     print(f"⚠ Found {len(duplicates)} duplicate images!")
#     for dup1, dup2 in duplicates:
#         print(f"Duplicate: {dup1} and {dup2}")
# else:
#     print(" No duplicates found before splitting.")
# def hash_image(image_tensor):
#     """Generate a hash for an image tensor to detect duplicates."""
#     return hashlib.md5(image_tensor.numpy().tobytes()).hexdigest()

# # Generate hashes for all images
# image_hashes = [hash_image(img) for img in full_images]

# # Find duplicates
# hash_counts = {}
# duplicates = []

# for i, hash_value in enumerate(image_hashes):
#     if hash_value in hash_counts:
#         duplicates.append((full_filenames[i], full_filenames[hash_counts[hash_value]]))  # Store duplicate filenames
#     else:
#         hash_counts[hash_value] = i

# # Print duplicate results
# if duplicates:
#     print(f"⚠ Found {len(duplicates)} duplicate images!")
#     for dup1, dup2 in duplicates:
#         print(f"Duplicate: {dup1} and {dup2}")
# else:
#     print(" No duplicates found before splitting.")
