## 1. Setup and Imports

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder

In [2]:
# RETFound encoder
!cd RETFound_MAE
!pip install -r requirements.txt

import sys
sys.path.append('D:/AI_Project_BME/Multimodal-Classification/RETFound_MAE')
from models_vit import RETFound_mae

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu124


## 2. Load Images and Labels

Use a custom Dataset class to parse the OCT image directory and match each image with its diagnosis label.

- Total number of images: 28696 
- Total number of instances for classification: 173

In [7]:
def find_b_scans_directory(root_path):
    for dirpath, _, filenames in os.walk(root_path):
        if any(fname.lower().endswith(('.jpg', '.png')) for fname in filenames):
            return dirpath
    return None  

class OCTDataset(Dataset):
    def __init__(self, root_dir, label_path):
        self.root_dir = root_dir
        self.image_paths = []
        self.labels = []
        self.loaded_volume_ids = set() 
        self.expected_volume_ids = set() 

        if not os.path.isfile(label_path):
            print("Not a valid label path")
            return

        labels_df = pd.read_excel(label_path)
        le = LabelEncoder()
        labels_df['label'] = le.fit_transform(labels_df["Diagnosis Label"])
        self.class_mapping = le.classes_
        self.labels_df = labels_df

        # Compute expected volume_ids from label file
        self.expected_volume_ids = set(
            labels_df.apply(
                lambda row: f"{int(row['Patient Number'])}_{row['Laterality']}_{int(row['Diagnosis Date'])}",
                axis=1
            )
        )

        for patient_id in os.listdir(root_dir):
            if labels_df['Patient Number'].isin([int(patient_id)]).any():
                patient_df = labels_df[labels_df['Patient Number'] == int(patient_id)]
                patient_path = os.path.join(root_dir, patient_id)
                if not os.path.isdir(patient_path):
                    continue
                for eye in ["L", "R"]:
                    if patient_df['Laterality'].isin([eye]).any():
                        eye_path = os.path.join(patient_path, eye)
                        eye_df = patient_df[patient_df['Laterality'] == eye]
                        if not os.path.isdir(eye_path):
                            continue
                        for scan_date in os.listdir(eye_path):
                            if eye_df['Diagnosis Date'].isin([int(scan_date)]).any():
                                scan_date_path = os.path.join(eye_path, scan_date)
                                scan_date_df = eye_df[eye_df['Diagnosis Date'] == int(scan_date)]
                                if not os.path.isdir(scan_date_path):
                                    continue
                                b_scans_path = find_b_scans_directory(scan_date_path)
                                if b_scans_path and os.path.isdir(b_scans_path):
                                    volume_id = f"{int(patient_id)}_{eye}_{scan_date}"
                                    self.loaded_volume_ids.add(volume_id) 
                                    for img_name in os.listdir(b_scans_path):
                                        img_path = os.path.join(b_scans_path, img_name)
                                        if img_path.endswith(".jpg") or img_path.endswith(".png"):
                                            self.image_paths.append(img_path)
                                            self.labels.append(scan_date_df['label'].iloc[0])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        return self.image_paths[idx], self.labels[idx]

    def get_label_map(self):
        return self.class_mapping

    def report_missing_volumes(self):
        """Print and save volumes that are in label file but not loaded due to missing images"""
        missing = self.expected_volume_ids - self.loaded_volume_ids
        print(f" Total expected volumes in label file: {len(self.expected_volume_ids)}")
        print(f" Volumes successfully loaded from image folders: {len(self.loaded_volume_ids)}")
        print(f" Missing volumes: {len(missing)}")
        for vid in list(sorted(missing)):
            print(" -", vid)

In [8]:
root_dir="D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data"
label_path="D:/AI_Project_BME/annotations_modified.xlsx"

dataset = OCTDataset(root_dir, label_path)
dataset.report_missing_volumes()

 Total expected volumes in label file: 173
 Volumes successfully loaded from image folders: 173
 Missing volumes: 0


In [9]:
print(len(dataset))

28696


## 3. Verify Loaded Image & Label Consistency

Manually inspect a few samples to ensure images are correctly matched to labels.

In [12]:
# Reload the label file 
labels_df = pd.read_excel(label_path)
le = LabelEncoder()
labels_df['label'] = le.fit_transform(labels_df["Diagnosis Label"])
label_map = le.classes_

mismatch_count = 0
N = 100 # Number of images to verify

for i in range(N):
    img_path = dataset.image_paths[i] # get image path
    label_from_dataset = dataset.labels[i] # get label

    # Extract "patient_id, eye, scan_date" from img_path
    parts = img_path.split(os.sep)
    try:
        scan_date = int(parts[-8])      # e.g., 20080818
        eye = parts[-9]                 # e.g., L
        patient_id = int(parts[-10])    # e.g., 000003162
    except (IndexError, ValueError):
        print(f"[!] Skipping due to unexpected path format: {img_path}")
        continue

    # Match against label file
    match_df = labels_df[
        (labels_df["Patient Number"] == patient_id) &
        (labels_df["Laterality"] == eye) &
        (labels_df["Diagnosis Date"] == scan_date)
    ]

    if match_df.empty:
        print(f"[!] No match found for: Patient {patient_id}, Eye {eye}, Date {scan_date}")
        mismatch_count += 1
        continue

    label_from_excel = match_df['label'].iloc[0]
    label_name_from_excel = label_map[label_from_excel]
    label_name_from_dataset = label_map[label_from_dataset]

    if label_from_dataset != label_from_excel:
        print(f"[✗] Label mismatch: {img_path}")
        print(f"    Dataset label: {label_from_dataset} ({label_name_from_dataset})")
        print(f"    Excel label:   {label_from_excel} ({label_name_from_excel})")
        mismatch_count += 1
    else:
        print(f"[✓] Label match: {img_path} -> {label_from_dataset} ({label_name_from_dataset})")

print(f"\n Verification complete: Compared {N} samples, found {mismatch_count} mismatches.")

[✓] Label match: D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data\000003162\L\20080818\124239\OPT\Carl_Zeiss_Meditec\200X1024X200\Original\B-Scans\000003162_L_20080818_124239_200X1024X200_ORG_IMG_JPG_001.jpg -> 1 (GA)
[✓] Label match: D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data\000003162\L\20080818\124239\OPT\Carl_Zeiss_Meditec\200X1024X200\Original\B-Scans\000003162_L_20080818_124239_200X1024X200_ORG_IMG_JPG_002.jpg -> 1 (GA)
[✓] Label match: D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data\000003162\L\20080818\124239\OPT\Carl_Zeiss_Meditec\200X1024X200\Original\B-Scans\000003162_L_20080818_124239_200X1024X200_ORG_IMG_JPG_003.jpg -> 1 (GA)
[✓] Label match: D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data\000003162\L\20080818\124239\OPT\Carl_Zeiss_Meditec\200X1024X200\Original\B-Scans\000003162_L_20080818_124239_200X1024X200_ORG_IMG_JPG_004.jpg -> 1 (GA)
[✓] Label match: D:/cleaning_GUI_annotated_Data/Cirrus_OCT_Imaging_Data\000003162\L\20080818\124239\OPT\

## 4. Load RETFound Encoder & Pretrained Weights

Initialize the ViT-Large encoder from RETFound and load the pretrained weights (natureOCT MAE).

In [13]:
# Load RETFound encoder (ViT-Large) & Pre-trained weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RETFound_mae(img_size=224, num_classes=0).to(device)

checkpoint = torch.load("D:/AI_Project_BME/Multimodal-Classification/RETFound_MAE/RETFound_mae_natureOCT.pth", map_location=device, weights_only=False)
state_dict = {k: v for k, v in checkpoint['model'].items() if not k.startswith("decoder") and "mask_token" not in k}
model.load_state_dict(state_dict, strict=True)
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Id

## 5. Define RETFound Transform & Image Encoding Function

Define a function that applies the RETFound and outputs a 1024-dim encoding vector for each image.

In [14]:
retfound_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


def get_image_encoding(image_path, model, transform, device):
    image = Image.open(image_path).convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model(img_tensor)
    return features.squeeze(0).cpu().numpy()  # return [1024]

## 6. Run Encoding & Save Encoded Features

- `images_encoding.npy`: Contains the 1024-dimensional RETFound encoding for each individual OCT image.
- `diagnosis_labels.npy`: Contains the integer-encoded disease classification labels for each image. These are the targets used to train the image-level classifier.


In [15]:
all_features, all_labels = [], []
for img_path, label in tqdm(dataset):
    encoding = get_image_encoding(img_path, model, retfound_transform, device)
    all_features.append(encoding)
    all_labels.append(label)

# Save encoding features & diagnosis labels to disk
np.save("D:/AI_Project_BME/Multimodal-Classification/outputs/images_encoding.npy", np.stack(all_features))
np.save("D:/AI_Project_BME/Multimodal-Classification/outputs/diagnosis_labels.npy", np.array(all_labels))

100%|█████████████████████████████████████████████████████| 28696/28696 [27:45<00:00, 17.23it/s]


## 7. Aggregate Volume-wise Encodings
Aggregate the image-level encodings using **mean pooling** to form a single 1024-dimensional feature vector for each volume.

- `volume_features.npy`: Contains the mean-pooled RETFound encodings for each OCT volume (i.e., each patient-eye-date group). Shape: `[#volumes, 1024]`
- `volume_labels.npy`: Contains the integer-encoded disease classification labels for each volume. These labels are the same as in the original dataset but aggregated to match the volume level.

In [16]:
from collections import defaultdict

# Group features by volume
volume_features = defaultdict(list)
volume_labels = dict()

for img_path, feature, label in zip(dataset.image_paths, all_features, all_labels):
    parts = img_path.split(os.sep)
    try:
        patient_id = parts[-10]
        eye = parts[-9]
        scan_date = parts[-8]
        volume_id = f"{patient_id}_{eye}_{scan_date}"
    except IndexError:
        print(f"[!] Failed to parse path: {img_path}")
        continue

    volume_features[volume_id].append(feature)
    volume_labels[volume_id] = label  # All images from the same volume share the same label

# Aggregate (mean-pooling)
aggregated_features = []
aggregated_labels = []

for vid, feats in volume_features.items():
    pooled = np.mean(feats, axis=0)  # [num_images, 1024] → [1024]
    aggregated_features.append(pooled)
    aggregated_labels.append(volume_labels[vid])

# Save volume-wise features
np.save("D:/AI_Project_BME/Multimodal-Classification/outputs/volume_features.npy", np.stack(aggregated_features))
np.save("D:/AI_Project_BME/Multimodal-Classification/outputs/volume_labels.npy", np.array(aggregated_labels))

print(f" Saved {len(aggregated_features)} volume-level encodings.")

 Saved 173 volume-level encodings.


## 8. Test

In [31]:
import numpy as np
image_encodings = np.load("D:/AI_Project_BME/Multimodal-Classification/outputs/images_encoding.npy")           # shape: [N, 1024]
diagnosis_labels = np.load("D:/AI_Project_BME/Multimodal-Classification/outputs/diagnosis_labels.npy")         # shape: [N]

# === Print shapes ===
print("Image-Level Encodings:")
print(f"  Shape: {image_encodings.shape}")
print(f"  First vector (truncated): {image_encodings[0][:5]} ...")

print("Diagnosis Labels:")
print(f"  Shape: {diagnosis_labels.shape}")
print(f"  First 10 labels: {diagnosis_labels[:10]}")

# === Label statistics ===
unique_labels, counts = np.unique(diagnosis_labels, return_counts=True)
print(f"Found {len(unique_labels)} unique classes.")
for label, count in zip(unique_labels, counts):
    print(f"  Label {label}: {count} samples")

# === Display label name mapping  ===
try:
    from sklearn.preprocessing import LabelEncoder
    import pandas as pd

    label_path = "D:/AI_Project_BME/annotations_modified.xlsx"
    labels_df = pd.read_excel(label_path)
    le = LabelEncoder()
    le.fit(labels_df["Diagnosis Label"])
    label_map = le.classes_ 

    print("Label Mapping:")
    for i, name in enumerate(label_map):
        print(f"  {i} → {name}")
except Exception as e:
    print("Could not recover label names. Skipping label map.")

Image-Level Encodings:
  Shape: (28696, 1024)
  First vector (truncated): [ 0.53131574  0.00161412 -0.18659723 -0.40004224 -1.0517174 ] ...
Diagnosis Labels:
  Shape: (28696,)
  First 10 labels: [1 1 1 1 1 1 1 1 1 1]
Found 6 unique classes.
  Label 0: 4952 samples
  Label 1: 2592 samples
  Label 2: 5592 samples
  Label 3: 1056 samples
  Label 4: 4392 samples
  Label 5: 10112 samples
Label Mapping:
  0 → Early AMD
  1 → GA
  2 → Int AMD
  3 → Not AMD
  4 → Scar
  5 → Wet


In [32]:
volume_features = np.load("D:/AI_Project_BME/Multimodal-Classification/outputs/volume_features.npy")
volume_labels = np.load("D:/AI_Project_BME/Multimodal-Classification/outputs/volume_labels.npy")
print("Volumn-Level Encodings shape:", volume_features.shape)  #  (num_volumes, 1024)
print("Volumn-level Labels shape:  ", volume_labels.shape)   #  (num_volumes,)

Volumn-Level Encodings shape: (173, 1024)
Volumn-level Labels shape:   (173,)
