# In-Context Learning for Ophthalmology

## Import libraries and define useful things

In [None]:
import sys

import numpy as np
import pandas as pd
from PIL import Image as PIL_Image
# import h5py
# import typing

import json

import os, glob

from collections import OrderedDict

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, f1_score
from sklearn.preprocessing import OneHotEncoder

import matplotlib
from matplotlib import pyplot as plt

# from tqdm import tqdm
from tqdm.notebook import tqdm

from scipy.special import softmax
# from scipy.special import expit
from scipy.spatial import distance

In [None]:
# Binary labelling: grades below `onset_level` map to dr_stages[0], all others
# to dr_stages[1].
dr_stages = ["Normal", "Diabetic Retinopathy (DR)"]
onset_level = 1


def _load_idrid_split(csv_file, img_dir, split_name, keep_columns=None):
    """Read one IDRiD grading CSV and add label, path and split columns.

    Args:
        csv_file: Path of the ground-truth CSV for this split.
        img_dir: Directory prefix (including the trailing slash) that is
            prepended to every image name.
        split_name: Value written to the ``split`` column of every row.
        keep_columns: Optional list of CSV columns to retain; ``None`` keeps
            every column of the CSV.

    Returns:
        The metadata DataFrame with the added columns ``label_text``,
        ``file_path``, ``file_path_224`` and ``split`` (in that order).
    """
    df = pd.read_csv(csv_file, low_memory=False)
    if keep_columns is not None:
        # Explicit copy: the assignments below must act on an independent
        # frame, not on a selection of the raw CSV frame.
        df = df[keep_columns].copy()

    # Vectorised replacement of the former per-row loop.
    is_dr = df['Retinopathy grade'].astype(int) >= onset_level
    image_names = df['Image name'].astype(str)

    df['label_text'] = [dr_stages[int(flag)] for flag in is_dr]
    df['file_path'] = img_dir + image_names + '.jpg'
    df['file_path_224'] = img_dir + 'idrid_224/' + image_names + '.png'
    df['split'] = split_name

    print(f'Metadata shape : {df.shape}')
    print(df.columns)
    return df


# IDRiD
img_dir_tr = '../../EVAL_DATASETS/IDRiD/grading/OriginalImages/training/'
csv_file_tr = '../../EVAL_DATASETS/IDRiD/grading/Groundtruths/IDRiD_TrainingLabels.csv'
df_metadata_tr = _load_idrid_split(
    csv_file_tr, img_dir_tr, 'train',
    keep_columns=['Image name', 'Retinopathy grade', 'Risk of macular edema '],
)

img_dir_te = '../../EVAL_DATASETS/IDRiD/grading/OriginalImages/test/'
csv_file_te = '../../EVAL_DATASETS/IDRiD/grading/Groundtruths/IDRiD_TestLabels.csv'
df_metadata_te = _load_idrid_split(csv_file_te, img_dir_te, 'test')

# Row-wise concatenation; the per-split index values are kept as they are,
# so positional access (.iloc) is the reliable way to address a row.
df_metadata = pd.concat([df_metadata_tr, df_metadata_te], axis=0)
print(f'Metadata shape : {df_metadata.shape}')
print(df_metadata.columns)

del df_metadata_tr, df_metadata_te

## Prepare RETFound and extract feature embeddings

In [None]:
# Extend the module search path before the imports below.
# NOTE(review): `models_vit` and `util.pos_embed` are presumably resolved from
# this directory, so this line has to stay ahead of them — confirm the path.
sys.path.insert(1, '../RETFound_MAE/')

import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import transforms as T
# from torchvision.transforms import v2 as T

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Environment report: float limit, torch / CUDA / timm versions and GPU count.
print(f'Max float : {sys.float_info.max}')
print(torch.__version__)
print(f'Cuda available : {torch.cuda.is_available()}')
print(f'Number of GPUs : {torch.cuda.device_count()}')
print(f'CUDA Version : {torch.version.cuda}')
print(f'timm Version : {timm.__version__}')

def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    """Build a model from ``models_vit`` and load checkpoint weights into it.

    Args:
        chkpt_dir: Path of the checkpoint file. The loaded object must have a
            ``'model'`` entry holding the state dict.
        arch: Name of the constructor to look up in ``models_vit``.

    Returns:
        The constructed model after a non-strict ``load_state_dict``.
    """
    # build model
    model = models_vit.__dict__[arch](
        img_size=224,
        num_classes=5,
        drop_path_rate=0,
        global_pool=True,
    )
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not match the architecture cannot go unnoticed.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    preview = 10  # cap on how many key names are printed per category
    if msg.missing_keys:
        print(f'{len(msg.missing_keys)} model keys not found in checkpoint, '
              f'e.g. {list(msg.missing_keys)[:preview]}')
    if msg.unexpected_keys:
        print(f'{len(msg.unexpected_keys)} checkpoint keys not used by model, '
              f'e.g. {list(msg.unexpected_keys)[:preview]}')
    return model

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Device : {device}')

# Build the encoder from its checkpoint and move it onto `device`.
chkpt_dir = '../../Projects/RETFound_MAE/RETFound_cfp_weights.pth'
vision_encoder = prepare_model(chkpt_dir, arch='vit_large_patch16')
vision_encoder.to(device)
print('Vision encoder model loaded.')

# Image pipeline: convert to tensor, then normalise with the ImageNet
# default mean / std constants from timm.
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])


class IDRiD_ImageDataset(Dataset):
    """Dataset yielding ``(image, binary_label)`` pairs from a metadata frame.

    Each item opens the image named in the row's ``file_path_224`` column and
    derives a 0/1 label by comparing the row's target column with the
    module-level ``onset_level``.

    Args:
        metadata: DataFrame with a ``file_path_224`` column and the target
            column; rows are addressed positionally.
        target_column: Column holding the integer grade to binarise.
        transforms: Optional callable applied to the RGB PIL image.
        target_transforms: Optional callable applied to the binary label.
    """

    def __init__(self, metadata, target_column='Retinopathy grade',
                 transforms=None, target_transforms=None):
        self.metadata = metadata
        self.target_column = target_column
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        """Return the number of metadata rows."""
        return self.metadata.shape[0]

    def __getitem__(self, idx):
        """Return the (transformed) image and binary label at position ``idx``."""
        row = self.metadata.iloc[idx]  # single positional lookup per item

        with PIL_Image.open(row['file_path_224']) as img:
            # PIL's ``Image.size`` is always (width, height), so it cannot be
            # used to detect single-channel images. Convert unconditionally:
            # this guarantees 3 channels and yields a loaded image that stays
            # usable after the file is closed.
            img = img.convert(mode='RGB')

            if self.transforms:
                img = self.transforms(img)

        # Grades below the onset level are class 0, everything else class 1.
        label_bin = 0 if int(row[self.target_column]) < onset_level else 1

        if self.target_transforms is not None:
            label_bin = self.target_transforms(label_bin)

        return img, label_bin #, os.path.basename(filepath), self.metadata.iloc[idx]['split']

    # def get_labels(self):
    #     # return as series for ImbalancedDatasetSampler to read into a Pandas dataframe
    #     return self.metadata[self.target_column]


# In[5]:


# Loader settings.
batch_size = 32
num_workers = 8

# Dataset over the metadata frame, using the image transforms defined earlier.
idrid_dataset = IDRiD_ImageDataset(
    df_metadata,
    transforms=transforms,
    target_transforms=None,
)

# No shuffling and no sampler: batches follow the metadata row order.
dataloader = DataLoader(
    idrid_dataset,
    batch_size=batch_size,
    shuffle=False,
    sampler=None,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
def extract_features(vision_encoder, dataloader):
    """Run the encoder over a dataloader and collect features and labels.

    The encoder is put in eval mode and called through ``forward_features``
    under ``torch.no_grad()``. Inputs are moved to the module-level ``device``.

    Args:
        vision_encoder: Model exposing ``eval()`` and ``forward_features()``.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        OrderedDict with ``'features'`` and ``'labels'``, each a NumPy array
        built by concatenating the per-batch results along axis 0.
    """
    out_data = OrderedDict()
    out_data['features'] = []
    out_data['labels'] = []

    vision_encoder.eval()

    with torch.no_grad():
        # Passing the loader itself (not iter(...)) lets tqdm show a total.
        for inputs, labels in tqdm(dataloader):

            inputs = inputs.to(device)

            outputs = vision_encoder.forward_features(inputs)
            features = outputs.cpu().detach().numpy()

            # Squeeze singleton axes but never the batch axis (axis 0): a
            # final batch holding one sample must stay 2-D, otherwise the
            # concatenation below fails or silently flattens samples.
            singleton_axes = tuple(
                axis for axis in range(1, features.ndim)
                if features.shape[axis] == 1
            )
            if singleton_axes:
                features = np.squeeze(features, axis=singleton_axes)

            out_data['features'].append(features)
            out_data['labels'].append(np.asarray(labels))

    # list to numpy array
    out_data['features'] = np.concatenate(out_data['features'], axis=0)
    out_data['labels'] = np.concatenate(out_data['labels'], axis=0)

    print(f'Features : {out_data["features"].shape}')
    print(f'Labels : {out_data["labels"].shape}, Unique labels : {np.unique(out_data["labels"], return_counts=True)}')

    return out_data

# Extract features for every image served by the dataloader.
out_data = extract_features(vision_encoder, dataloader)

# Working copies: feature matrix and int32 label vector.
X = out_data['features']
y = np.asarray(out_data['labels'], dtype=np.int32)

# Both arrays are written back-to-back into one file: features, then labels.
# Reading them back needs two successive np.load calls on the same handle.
with open('IDRiD_Binary_Features.npy', 'wb') as handle:
    np.save(handle, out_data['features'])
    np.save(handle, out_data['labels'])

# X and y hold everything needed from here on.
del out_data

In [None]:
# Fraction of samples per class, in the sorted label order used by np.unique.
class_counts = np.unique(y, return_counts=True)[1]
print(f'{class_counts/np.sum(class_counts)}')