# RADDINO Patch Feature Extraction Pipeline

This notebook implements a medical image feature extraction pipeline designed to process chest X-rays from the SIIM-ACR Pneumothorax dataset. The core purpose is to extract deep learning features from these images using a pre-trained RADDINO model.

## Key Components
- Uses PyTorch and PyTorch Lightning for the deep learning framework
- Employs MONAI (Medical Open Network for AI) for medical imaging-specific data handling
- Implements efficient data processing with parallel execution and persistent caching
- Configures a feature extraction model that outputs embeddings to a specified directory

## Workflow
The workflow follows a standard machine learning pipeline:
1. Data loading from CSV files containing image paths
2. Data transformation using specialized medical imaging preprocessing
3. Dataset and dataloader configuration with persistent caching for performance
4. Model initialization with appropriate parameters
5. Validation through visual spot-checking of processed images
6. Feature extraction execution using PyTorch Lightning's prediction mode

The extracted features are saved to disk and can be used for downstream tasks such as classification, clustering, or further analysis. The notebook is optimized for performance with GPU acceleration.

In [None]:
import os
import itertools
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from typing import List
from typing_extensions import override

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.callbacks import RichProgressBar

import subprocess

from tqdm import tqdm

import monai as mn
from transforms.Transform4RADDINO import Transform4RADDINO
from models.RADDINO import Extractor_patch

# Fixed seed so that random draws made later in the notebook are repeatable.
SEED = 5566
# Lightning helper that seeds the Python, NumPy and PyTorch random generators.
pl.seed_everything(SEED)
# Allow float32 matrix multiplications to use lower-precision internal math
# ('medium'), trading some numerical precision for speed.
torch.set_float32_matmul_precision('medium')

In [None]:
def get_data_dict_part(df_part):
    """Build the per-image data dictionaries for one slice of the input table.

    Important! Modify ``BASE_PATH`` below so that it points at the image root.

    Args:
        df_part: DataFrame (slice) with an ``ImagePath`` column whose string
            values are appended to ``BASE_PATH``.

    Returns:
        A list with one dict per row, in row order. Both the ``'img'`` and the
        ``'paths'`` key hold the same full path string. Presumably the
        transforms replace ``'img'`` with the loaded image while ``'paths'``
        stays a string -- confirm against ``Transform4RADDINO``.
    """
    BASE_PATH = '/MODIFY_THIS_PATH/' # modify

    data_dict = []
    # Iterate the column directly: indexing with .iloc[i] inside a
    # range(len(...)) loop materialises a full row Series for every image.
    for image_path in tqdm(df_part["ImagePath"], desc="Processing part"):
        # Built once and shared by both keys; the separator is kept exactly as
        # before so the resulting strings are unchanged.
        full_path = BASE_PATH + '/' + image_path
        data_dict.append({'img': full_path, "paths": full_path})

    return data_dict

def get_data_dict(df, num_cores=32):
    """Build the data dictionaries for ``df`` using a pool of worker processes.

    The table is cut into contiguous row chunks, each chunk is handed to
    ``get_data_dict_part`` in a separate process, and the per-chunk lists are
    concatenated in the original row order.

    Args:
        df: DataFrame with one row per image (see ``get_data_dict_part``).
        num_cores: Upper bound on the number of worker processes. Never more
            workers than rows are started.

    Returns:
        A flat list with one dict per row of ``df``; empty if ``df`` is empty.

    Raises:
        ValueError: If ``num_cores`` is smaller than 1.
    """
    if num_cores < 1:
        raise ValueError("num_cores must be at least 1")

    n_rows = len(df)
    if n_rows == 0:
        # Nothing to do -- avoid starting a process pool for no work.
        return []

    # More workers than rows would only spawn processes for empty chunks.
    n_workers = min(num_cores, n_rows)

    # Split row *positions* rather than the DataFrame itself: calling
    # np.array_split directly on a DataFrame emits a deprecation warning on
    # recent pandas releases. Chunk sizes and order are the same as before.
    index_chunks = np.array_split(np.arange(n_rows), n_workers)
    parts = [df.iloc[chunk] for chunk in index_chunks]

    with ProcessPoolExecutor(n_workers) as executor:
        # Materialise inside the ``with`` so worker errors surface here.
        data_dicts = list(executor.map(get_data_dict_part, parts))

    return list(itertools.chain.from_iterable(data_dicts))

### Set parameters

In [None]:
# IMPORTANT, decide before proceeding: when True, the cache-clearing cell
# below deletes MONAI_CACHE_DIR before the dataset is built.
DELETE_CACHE = True

# CSV file read with pandas; its "ImagePath" column is used to build the
# image paths.
INPUT = 'input_example.csv'

# Sub-folder name appended to ./cache/ for this run.
# NOTE(review): with the empty default MONAI_CACHE_DIR is './cache/' itself,
# so DELETE_CACHE = True removes the whole ./cache/ directory.
TEST_NAME = '' 
MONAI_CACHE_DIR = f'./cache/{TEST_NAME}' 
# Passed to Transform4RADDINO when the transforms are created.
IMG_SIZE = 518
# Used for the DataLoader and passed to Extractor_patch.
BATCH_SIZE = 16
# NOTE(review): PRECISION is not referenced anywhere else in the visible
# notebook (it is not passed to the Trainer) -- confirm whether it should be.
PRECISION = 'bf16-mixed' 
# Created if missing and passed to Extractor_patch as OUTPUT_DIR.
OUTPUT_FOLDER = './features_RADDINO/'

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous;
# it is a debugging aid and slows execution -- confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = '5' # index of the only GPU exposed to this process

In [None]:
# Optionally wipe MONAI's persistent cache so that stale pre-processed items
# are not reused by the dataset created further down.
import shutil  # local to this cell; only needed for the cache clean-up

if DELETE_CACHE:
    if os.path.exists(MONAI_CACHE_DIR):
        # Remove in-process instead of shelling out to `rm -rf`: this is
        # portable and raises on failure, so the success message below is only
        # printed when the removal actually happened.
        if os.path.isdir(MONAI_CACHE_DIR) and not os.path.islink(MONAI_CACHE_DIR):
            shutil.rmtree(MONAI_CACHE_DIR)
        else:
            # A plain file or a symlink: delete just that entry.
            os.remove(MONAI_CACHE_DIR)
        print(f"MONAI's {MONAI_CACHE_DIR} cache directory removed successfully!")
    else:
        print(f"MONAI's {MONAI_CACHE_DIR} cache directory does not exist!")

### Read input file

In [None]:
# Load the input table and keep only its first three rows.
# NOTE(review): the 3-row limit looks like an example/debug setting -- remove
# the .head(3) to process the whole CSV.
df = pd.read_csv(INPUT).head(3)

In [None]:
# Display the loaded rows as a quick visual check of the input table.
df

In [None]:
# Build the list of per-image dictionaries ('img' / 'paths') consumed by the
# MONAI datasets below. Uses get_data_dict's default num_cores=32.

eval_dict = get_data_dict(df)

### Model setup

In [None]:
# Pre-processing pipeline used at prediction time, built for IMG_SIZE.
eval_transforms = Transform4RADDINO(IMG_SIZE).predict

# Dataset that applies the transforms and persists results in MONAI_CACHE_DIR.
eval_ds = mn.data.PersistentDataset(
    data=eval_dict,
    transform=eval_transforms,
    cache_dir=MONAI_CACHE_DIR,
)

# Loader over the dataset: original order is kept and no sample is dropped,
# so every input image is seen exactly once.
eval_dl = DataLoader(
    eval_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=16,
    persistent_workers=True,
)

# The output folder is created before the model is built with it as OUTPUT_DIR.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
model = Extractor_patch(BATCH_SIZE=BATCH_SIZE, OUTPUT_DIR=OUTPUT_FOLDER)

### Spot check the pre-processed images

In [None]:
# SPOT CHECK: show three randomly drawn, fully pre-processed samples.
# A plain (non-caching) dataset is used so the persistent cache is not touched.
test_ds = mn.data.Dataset(data=eval_dict, transform=eval_transforms)

for _ in range(3):
    random_i = np.random.randint(0, len(test_ds))
    # Index the sample directly; `data_` stays defined after the loop and is
    # inspected again in the next cell.
    data_ = test_ds[random_i]

    print(f"{data_['paths']}")
    # Plot the first entry along the leading axis of the image as grayscale.
    first_plane = np.array(data_['img'])[0, :, :]
    plt.imshow(first_plane, cmap='gray')
    plt.show()

In [None]:
# Shape of the last spot-checked image (`data_` is left over from the loop in
# the previous cell).
data_['img'].shape

### Define Callbacks

In [None]:
# Rich-based progress bar callback; handed to the Trainer below.
progress_bar = RichProgressBar()

### Feature extraction (prediction)

In [None]:
# Instantiate the Lightning trainer used only for prediction.
# inference_mode=True runs prediction under torch.inference_mode.
# NOTE(review): the PRECISION setting defined in the parameter cell is not
# passed here (no `precision=` argument), so the Trainer's default precision
# is used -- confirm whether that is intended.

trainer = pl.Trainer(callbacks=[progress_bar], inference_mode=True)

In [None]:
# Run the extractor over every batch of eval_dl in prediction mode.
# The returned predictions are discarded here; presumably Extractor_patch
# writes the features to OUTPUT_FOLDER itself -- confirm in models/RADDINO.

_ = trainer.predict(model, dataloaders=eval_dl)