In [None]:
repo_dir = "Repos"   # Set this to be where your github repos are located.
%load_ext autoreload
%autoreload 2

# Update the load path so python can find modules for the model
import sys
from pathlib import Path
sys.path.insert(0, str(Path.home() / repo_dir / "eye-ai-ml"))

In [None]:
# Prerequisites
import json
import os
from eye_ai.eye_ai import EyeAI

import pandas as pd
from pathlib import Path, PurePath
import logging

from deriva_ml import DatasetBag, Workflow, ExecutionConfiguration, DatasetVersion
from deriva_ml import MLVocab as vc
# force=True replaces handlers already attached to the root logger, so
# re-running this cell re-applies the configuration instead of being ignored.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)

In [None]:
from deriva.core.utils.globus_auth_utils import GlobusNativeLogin
# Catalog and server this notebook works against.
catalog_id = "eye-ai" #@param
host = 'www.eye-ai.org'

# Authenticate against the host through Globus; an existing session is
# reused instead of logging in again.
gnl = GlobusNativeLogin(host=host)
if not gnl.is_logged_in([host]):
    gnl.login(
        [host],
        no_local_server=True,
        no_browser=True,
        refresh_tokens=True,
        update_bdbag_keychain=True,
    )
    login_status = "Login Successful"
else:
    login_status = "You are already logged in."
print(login_status)

In [None]:
# The dataset cache and the execution working files share the same directory.
cache_dir = working_dir = '/data'

# Client object used for every catalog operation in this notebook.
EA = EyeAI(
    hostname=host,
    catalog_id=catalog_id,
    cache_dir=cache_dir,
    working_dir=working_dir,
)

In [None]:
# RIDs of the source datasets. Order matters: later cells read
# execution.datasets by position (0-4 training, 5 validation, 6 test).
datasets = [
    '4-N9XE',
    '4-NAPT',
    '4-NBG6',
    '4-NC9J',
    '4-ND2Y',
    '2-39FY',
    '2-277M',
]

# Set to False if you only need the metadata from the bag, and not the assets.
download_assets = True

# One download spec per dataset, pinned to the version the catalog reports.
# 'materialize' follows download_assets so that switching the flag off still
# yields the bags instead of an execution with no datasets at all (which
# would make every later execution.datasets[i] lookup fail).
# NOTE(review): assumes materialize=False downloads the bag without fetching
# its assets - confirm against the deriva-ml dataset spec.
to_be_download = [
    {
        'rid': dataset,
        'materialize': download_assets,
        'version': EA.dataset_version(dataset_rid=dataset),
    }
    for dataset in datasets
]

# Register the workflow type term before creating a workflow that refers to it.
EA.add_term(vc.workflow_type, "RETFound Model Train", description="A workflow to train RETFound model")

# Workflow instance
workflow_instance = EA.add_workflow(Workflow(
    name="RETFound Model train - 200 images",
    url="https://github.com/informatics-isi-edu/eye-ai-exec/blob/main/notebooks/RETFound_Huy/RETFOUND_DATA_200.ipynb",
    workflow_type="RETFound Model Train",
))

# Configuration instance.
config = ExecutionConfiguration(
    datasets=to_be_download,
    workflow=workflow_instance,
    description="Instance of training RETFound model - 200 images")

# Initialize execution
execution = EA.create_execution(config)

In [None]:
# Print the execution object returned by EA.create_execution as a quick check.
print(execution)

In [None]:
# Positions 0-4 are the training bags, 5 is validation and 6 is test.
# NOTE(review): assumes execution.datasets keeps the order of the dataset
# list given in the execution configuration - TODO confirm.
ds_bag_0, ds_bag_1, ds_bag_2, ds_bag_3, ds_bag_4 = (
    execution.datasets[position] for position in range(5)
)
ds_bag_val, ds_bag_test = execution.datasets[5], execution.datasets[6]

In [None]:
# Training bags only; the validation and test bags are handled separately.
ds_bag_list = [
    ds_bag_0,
    ds_bag_1,
    ds_bag_2,
    ds_bag_3,
    ds_bag_4,
]

In [None]:
# CSV inputs, resolved relative to the notebook's current working directory.
VAL_EXCLUDED_CSV = "valid_no_optic_disc_image_ids.csv"
TRAIN_EXCLUDED_CSV = "train_no_optic_disc_image_ids.csv"
TEST_INCLUDED_CSV = "Graded_Test_Dataset_2-277M_With_Demographics_CDR_Diagnosis_Image_Quality_Model_Diagnosis_Predicitons_with_Jiun_Do_June8_2024_with_Catalog_model_predictions.csv"

val_excluded_df = pd.read_csv(VAL_EXCLUDED_CSV)
train_excluded_df = pd.read_csv(TRAIN_EXCLUDED_CSV)
test_included_df = pd.read_csv(TEST_INCLUDED_CSV)

# Image IDs to leave out of the validation and training sets ("ID" column).
val_excluded = val_excluded_df["ID"].tolist()
train_excluded = train_excluded_df["ID"].tolist()

# Image IDs the test set is restricted to ("Image_cd" column).
test_included = test_included_df["Image_cd"].tolist()

In [None]:
# Use the execution's working directory as the root for generated files.
# NOTE(review): _working_dir is a private attribute of the execution object;
# prefer a public accessor if the library provides one.
output_dir = execution._working_dir
output_dir

In [None]:
# Arguments shared by both variants (eye-cropped and uncropped) of each split.
validation_options = dict(output_dir=output_dir, exclude_list=val_excluded)
test_options = dict(output_dir=output_dir, include_only_list=test_included)

# Validation split: drops the image IDs listed in val_excluded.
validation_image_path_cropped, validation_csv_cropped = EA.create_cropped_images(
    ds_bag_val, crop_to_eye=True, **validation_options
)
validation_image_path_uncropped, validation_csv_uncropped = EA.create_cropped_images(
    ds_bag_val, crop_to_eye=False, **validation_options
)

# Test split: keeps only the image IDs listed in test_included.
test_image_path_cropped, test_csv_cropped = EA.create_cropped_images(
    ds_bag_test, crop_to_eye=True, **test_options
)
test_image_path_uncropped, test_csv_uncropped = EA.create_cropped_images(
    ds_bag_test, crop_to_eye=False, **test_options
)

In [None]:
# Path of the JSON file holding the best hyperparameters.
# NOTE(review): in the cells visible here this path is only displayed, never
# read - confirm whether it is still needed.
best_hyper_parameters_json_path = "best_hyperparameters_exluding_no_optic_disc_images_june_24_2024.json"

In [None]:
# Display the configured hyperparameter file path.
best_hyper_parameters_json_path

In [None]:
# Create the execution asset paths: one for model files, one for prediction
# outputs and one for training logs.
asset_path_models, asset_path_output, asset_path_logs = (
    execution.execution_asset_path(asset_type)
    for asset_type in ("Diagnosis_Model", "Model_Prediction", "Training_Log")
)

In [None]:
# Display the asset path used for the "Diagnosis_Model" files.
asset_path_models

In [None]:
# Display the asset path used for the "Model_Prediction" files.
asset_path_output

In [None]:
# Display the asset path used for the "Training_Log" files.
asset_path_logs

In [None]:
# Display the working directory that generated files are written under.
output_dir

In [None]:
import shutil
def create_retfound_ds(output, train_dir, val_dir, test_dir, ds_bag_name, crop):
    """Assemble a train/val/test image-folder dataset from three source directories.

    Every class sub-directory of the three source directories is copied to
    ``<output>/<ds_bag_name>_RETFound[_cropped]/{train,val,test}/<class>/``.
    Entries that are not directories at the class level, and entries that are
    not regular files inside a class directory, are skipped so that stray
    files (CSVs, hidden files) or nested folders do not abort the copy.

    Args:
        output: Directory under which the dataset folder is created.
        train_dir: Directory with one sub-directory of training images per class.
        val_dir: Same layout, validation images.
        test_dir: Same layout, test images.
        ds_bag_name: Prefix of the dataset folder name.
        crop: When truthy the folder name ends in ``_RETFound_cropped``,
            otherwise in ``_RETFound``.

    Returns:
        The path of the dataset folder, as produced by ``os.path.join``.
    """
    folder_name = f"{ds_bag_name}_RETFound_cropped" if crop else f"{ds_bag_name}_RETFound"
    ds_bag_out_path = os.path.join(output, folder_name)

    splits = [(train_dir, 'train'), (val_dir, 'val'), (test_dir, 'test')]

    # Create all split folders up front so they exist even for empty sources.
    for _, subdir in splits:
        os.makedirs(os.path.join(ds_bag_out_path, subdir), exist_ok=True)

    for source_dir, subdir in splits:
        for class_dir in os.listdir(source_dir):
            class_path = os.path.join(source_dir, class_dir)
            if not os.path.isdir(class_path):
                # Not a class folder (e.g. a CSV written next to the images).
                continue
            target_class_dir = os.path.join(ds_bag_out_path, subdir, class_dir)
            os.makedirs(target_class_dir, exist_ok=True)
            for file_name in os.listdir(class_path):
                source_file = os.path.join(class_path, file_name)
                if not os.path.isfile(source_file):
                    # shutil.copy cannot copy directories; skip nested folders.
                    continue
                shutil.copy(source_file, os.path.join(target_class_dir, file_name))
    return ds_bag_out_path

In [None]:
import subprocess

# Local RETFound_MAE checkout that torchrun is launched from. Built from
# repo_dir so it follows the checkout folder configured at the top of the
# notebook instead of a second hard-coded "Repos" location.
repo_path = str(Path.home() / repo_dir / "RETFound_MAE")

# No "as" target: the context object is not used, and binding it to the name
# "exec" would shadow the builtin for the rest of the notebook.
with execution.execute():
    for index, ds_bag in enumerate(ds_bag_list):
        # Only the bag at index 4 is trained in this run; earlier bags are skipped.
        if index < 4:
            continue

        # Build the eye-cropped and the uncropped image sets for this bag,
        # leaving out the image IDs in train_excluded.
        image_path_ds_bag_path_cropped, csv_ds_bag_cropped = EA.create_cropped_images(
            ds_bag,
            output_dir,
            crop_to_eye=True,
            exclude_list=train_excluded,
        )
        # NOTE(review): the uncropped set is still generated although the
        # uncropped training run below is disabled.
        image_path_ds_bag_path_uncropped, csv_ds_bag_uncropped = EA.create_cropped_images(
            ds_bag,
            output_dir,
            crop_to_eye=False,
            exclude_list=train_excluded,
        )

        print("Dataset: ", ds_bag.dataset_rid)
        # retfound_ds_bag_path_uncropped = create_retfound_ds(output=output_dir,
        #                                                     train_dir=image_path_ds_bag_path_uncropped,
        #                                                     val_dir=validation_image_path_uncropped,
        #                                                     test_dir=test_image_path_uncropped,
        #                                                     ds_bag_name=ds_bag.dataset_rid, crop=False)
        retfound_ds_bag_path_cropped = create_retfound_ds(
            output=output_dir,
            train_dir=image_path_ds_bag_path_cropped,
            val_dir=validation_image_path_cropped,
            test_dir=test_image_path_cropped,
            ds_bag_name=ds_bag.dataset_rid,
            crop=True,
        )

        # Output locations handed to main_finetune.py through --task.
        retfound_out_uncropped = output_dir / f"{ds_bag.dataset_rid}/RETFound_task/Uncropped_"
        os.makedirs(retfound_out_uncropped, exist_ok=True)

        retfound_out_cropped = output_dir / f"{ds_bag.dataset_rid}/RETFound_task/Cropped_"
        os.makedirs(retfound_out_cropped, exist_ok=True)

        # Each dataset folder is paired explicitly with its own output
        # location. Zipping two separately maintained lists would hand the
        # cropped data the "Uncropped_" location whenever the uncropped entry
        # is disabled, and the results would then be labelled as uncropped.
        training_runs = [
            # (retfound_ds_bag_path_uncropped, retfound_out_uncropped),
            (retfound_ds_bag_path_cropped, retfound_out_cropped),
        ]

        for data_path, retfound_output_dir in training_runs:
            os.makedirs(retfound_output_dir, exist_ok=True)

            command = [
                "torchrun",
                "--nproc_per_node=1", "--master_port=48798", "main_finetune.py",
                "--batch_size", "16",
                "--world_size", "1",
                "--model", "vit_large_patch16",
                "--epochs", "50",
                "--blr", "5e-3", "--layer_decay", "0.65",
                "--weight_decay", "0.05", "--drop_path", "0.2",
                "--nb_classes", "5",
                "--data_path", str(data_path),
                "--task", str(retfound_output_dir),
                "--finetune", "RETFound_cfp_weights.pth",
                "--input_size", "224"
            ]

            # Run the command inside the RETFound_MAE repository; check=True
            # raises CalledProcessError if training exits with a non-zero status.
            subprocess.run(command, check=True, cwd=repo_path)

        # Remove the copied dataset folders once training for this bag is done.
        for data_path, _ in training_runs:
            if os.path.exists(data_path):
                shutil.rmtree(data_path)
                print(f"Deleted folder: {data_path}")
            else:
                print(f"Folder does not exist: {data_path}")

In [None]:
import os
import shutil

# Move each training bag's RETFound outputs into the execution asset
# directories, renaming them to "<dataset RID>_<cropped|uncropped>_...".
for ds_bag in ds_bag_list:
    rid = ds_bag.dataset_rid
    source_dir = output_dir / rid / "RETFound_task"
    if not source_dir.exists():
        print(f"Skipping: {source_dir} does not exist.")
        continue

    for item in os.listdir(source_dir):
        item_path = source_dir / item
        # The variant is read from the entry name ("Uncropped" marks uncropped runs).
        suffix = "cropped" if "Uncropped" not in item else "uncropped"
        is_csv = item.endswith(".csv")

        if item.endswith(".pth"):
            # Model weights.
            shutil.move(item_path, asset_path_models / f"{rid}_{suffix}.pth")
        elif "test" in item and (is_csv or item.endswith(".jpg")):
            # Test outputs: a metrics CSV or a confusion-matrix image.
            if is_csv:
                new_file_name = f"{rid}_{suffix}_metrics_test.csv"
            else:
                new_file_name = f"{rid}_{suffix}_conf_matrix.jpg"
            shutil.move(item_path, asset_path_output / new_file_name)
        elif "val" in item:
            # Validation metrics go with the training logs.
            shutil.move(item_path, asset_path_logs / f"{rid}_{suffix}_metrics_val.csv")
        elif item_path.is_dir():
            # From sub-folders only the entries with "roc_" in their name are kept.
            for sub_item in item_path.iterdir():
                if "roc_" not in sub_item.name:
                    continue
                shutil.move(sub_item, asset_path_output / f"{rid}_{suffix}_{sub_item.name}")
        
     
        

In [None]:
# Upload the files collected under the execution asset paths.
# NOTE(review): clean_folder=True presumably removes the local copies after
# the upload - confirm before relying on the files afterwards.
execution.upload_execution_outputs(clean_folder=True)