## Initial Setup:
Install the dependencies, fetch the RETFound code, and configure the Python import path.

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel,
# so the packages are importable here without guessing which pip is on PATH.
%pip install torch torchvision torchaudio

In [None]:
# Fetch the eye-ai-compatible RETFound fork. The Repos directory is created if
# missing, the clone happens only on the first run (re-running the cell is safe),
# and the checkout is updated BEFORE installing so requirements match the code.
!mkdir -p Repos
!cd Repos && (test -d RETFound_MAE || git clone --branch eye-ai-compatible https://github.com/huynguyentran/RETFound_MAE.git)
!cd Repos/RETFound_MAE && git pull
# %pip installs into the running kernel's environment.
%pip install -r Repos/RETFound_MAE/requirements.txt

In [None]:
# Reload edited modules automatically so changes in the cloned repos take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Update the load path so Python can find modules for the model.
# Set REPO_ROOT to the directory where your GitHub repos are cloned.
import sys
from pathlib import Path

REPO_ROOT = Path("Repos")

# Inserted in this order so RETFound_MAE ends up first on sys.path,
# followed by eye-ai-ml (same priority as before).
for _repo in ("eye-ai-ml", "RETFound_MAE"):
    _repo_path = str(REPO_ROOT / _repo)
    # Drop any earlier copies so re-running the cell does not pile up
    # duplicate entries, then re-insert at the highest priority.
    sys.path[:] = [p for p in sys.path if p != _repo_path]
    sys.path.insert(0, _repo_path)

In [None]:
# Prerequisites
import json
import os
from eye_ai.eye_ai import EyeAI

import pandas as pd
from pathlib import Path, PurePath
import logging
import torch

from deriva_ml import DatasetBag, Workflow, ExecutionConfiguration, DatasetVersion
from deriva_ml import MLVocab as vc
# force=True removes any handlers already attached to the root logger
# (including ones from a previous run of this cell), so this INFO-level
# format is the one that takes effect.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [None]:
from deriva.core.utils.globus_auth_utils import GlobusNativeLogin

catalog_id = "eye-ai" #@param
host = 'www.eye-ai.org'

# Authenticate with Globus only when there is no existing session for the host.
gnl = GlobusNativeLogin(host=host)
already_logged_in = gnl.is_logged_in([host])
if not already_logged_in:
    gnl.login(
        [host],
        no_local_server=True,
        no_browser=True,
        refresh_tokens=True,
        update_bdbag_keychain=True,
    )
print("You are already logged in." if already_logged_in else "Login Successful")

In [None]:
# The EyeAI cache and working directories both point at /data.
cache_dir = working_dir = '/data'
EA = EyeAI(
    hostname=host,
    catalog_id=catalog_id,
    cache_dir=cache_dir,
    working_dir=working_dir,
)

## Downloading Dataset:
Download the datasets. We work with three datasets: 2-A5T0 (train), 2-A5T2 (val), and 2-A5T4 (test). After download, the datasets appear in `execution.datasets` in the same order as in the list below, so index 0 is train, 1 is val, and 2 is test. This code always downloads the latest version of each dataset.

In [None]:
# Dataset RIDs, in the order they will appear in execution.datasets.
datasets = [
    '2-A5T0',  # train
    '2-A5T2',  # val
    '2-A5T4',  # test
]

# One download spec per dataset: materialize the files and request the
# version reported by EA.dataset_version for that RID.
to_be_download = [
    {
        'rid': dataset_rid,
        'materialize': True,
        'version': EA.dataset_version(dataset_rid=dataset_rid),
    }
    for dataset_rid in datasets
]

workflow_instance = EA.add_workflow(Workflow(
    name="RETFound Model train",
    url="https://github.com/informatics-isi-edu/eye-ai-exec/blob/main/notebooks/RETFound_Huy/RETFOUND_DATA_TEMPLATE.ipynb",
    workflow_type="RETFound Model Train",
))

# Set to False to skip downloading the datasets. Note that the preprocessing
# cells index execution.datasets[0], [1] and [2], so they need the datasets.
download_assets = True

# RETFound pre-trained weight asset(s). You should always have at least one
# when training. Alternative weight: '4-S3KP'.
pretrained_weight_assets = ['4-S3KR']

config = ExecutionConfiguration(
    datasets=to_be_download if download_assets else [],
    assets=pretrained_weight_assets,
    workflow=workflow_instance,
    description="Instance of training RETFound model",
)

# Initialize execution
execution = EA.create_execution(config)

In [None]:
# Show a summary of the execution that was just created.
print(str(execution))

## Preprocessing:
Crop the images to the eye region and organize them into the train, validation, and test folders that RETFound expects.

In [None]:
# Bags follow the order of the `datasets` download list: train, val, test.
_bags = execution.datasets
ds_bag_train, ds_bag_val, ds_bag_test = _bags[0], _bags[1], _bags[2]

# First configured asset: the RETFound pre-trained weight.
retfound_pretrained_weight = execution.asset_paths[0]

In [None]:
# Directory passed as output_dir to create_retfound_image_directory below.
# NOTE(review): `_working_dir` is a private attribute of the execution object;
# prefer a public accessor if deriva-ml provides one.
output_dir = execution._working_dir

In [None]:
# Wrap each bag under the "ds_bag" key, the form create_retfound_image_directory takes.
ds_bag_train_dict = dict(ds_bag=ds_bag_train)
ds_bag_val_dict = dict(ds_bag=ds_bag_val)
ds_bag_test_dict = dict(ds_bag=ds_bag_test)

In [None]:
"""
If the following function returns an error, it means that it has not been updated in Eye-AI.
Instead, your dataset directory should follow the format below for the pipeline to work.

├── data folder
    ├──train
        ├──class_a
        ├──class_b
        ├──class_c
    ├──val
        ├──class_a
        ├──class_b
        ├──class_c
    ├──test
        ├──class_a
        ├──class_b
        ├──class_c
"""
dataset_dir = EA.create_retfound_image_directory(ds_bag_train_dict =  ds_bag_train_dict, 
                                ds_bag_val_dict = ds_bag_val_dict, 
                                ds_bag_test_dict =  ds_bag_test_dict, 
                                output_dir =output_dir, 
                                crop_to_eye = True)[0]

In [None]:
# Per-asset-type output directories for this execution: models, predictions, logs.
asset_path_models, asset_path_output, asset_path_logs = (
    execution.execution_asset_path(asset_type)
    for asset_type in ("Diagnosis_Model", "Model_Prediction", "Training_Log")
)

In [None]:
from datetime import datetime

# Date stamp such as "Jan_05_2025" (abbreviated month, day, year; local time).
current_date = f"{datetime.now():%b_%d_%Y}"
print(current_date)

In [None]:
# Task directory handed to RETFound via --task; created up front.
RETFound_output = "./RETFound_output/task"
Path(RETFound_output).mkdir(parents=True, exist_ok=True)

## Train and Evaluate:

In [None]:
from main_finetune import main, get_args_parser

# Fine-tune RETFound from the pre-trained weight inside the deriva-ml
# execution context. The context object is not needed, so it is not bound
# (binding it to `exec` would shadow the builtin).
with execution.execute():
    args_list = [
        "--model", "RETFound_mae",
        "--savemodel",
        "--global_pool",
        "--batch_size", "16",
        "--world_size", "1",
        "--epochs", "100",
        "--blr", "5e-3", "--layer_decay", "0.65",
        "--weight_decay", "0.05", "--drop_path", "0.2",
        # Should match the number of class folders under the train/ split.
        "--nb_classes", "2",
        "--data_path", str(dataset_dir),
        "--input_size", "224",
        "--task", str(RETFound_output),
        "--output_dir", str(asset_path_output),
        "--finetune", str(retfound_pretrained_weight),
    ]

    args = get_args_parser().parse_args(args_list)
    criterion = torch.nn.CrossEntropyLoss()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args, criterion)

## Evaluate Only:
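If you already have a trained RETFound model, set `path_to_model` below to its checkpoint path to evaluate it without retraining. The setup and preprocessing cells above must still be run first, because this cell uses `dataset_dir`, `RETFound_output`, and the execution's asset paths.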

In [None]:
# Imported here as well so this cell works even when the training cell was skipped.
from main_finetune import main, get_args_parser

path_to_model = "path/to/model.pth"  # set to your trained RETFound checkpoint

# Fail fast, before starting an execution, if a local checkpoint path is wrong.
# Values containing "://" are passed through unchecked in case --resume is given a URL.
if "://" not in path_to_model and not Path(path_to_model).is_file():
    raise FileNotFoundError(f"RETFound checkpoint not found: {path_to_model}")

# The context object is not needed, so it is not bound (binding it to
# `exec` would shadow the builtin).
with execution.execute():
    args_list = [
        "--model", "RETFound_mae",
        "--eval",
        "--savemodel",
        "--global_pool",
        "--batch_size", "16",
        "--world_size", "1",
        "--epochs", "100",
        "--blr", "5e-3", "--layer_decay", "0.65",
        "--weight_decay", "0.05", "--drop_path", "0.2",
        "--nb_classes", "2",
        "--data_path", str(dataset_dir),
        "--input_size", "224",
        # Same task directory as the training cell. Pointing --task at the image
        # directory would mix RETFound's task outputs into the dataset folders.
        # NOTE(review): confirm how main_finetune uses --task in this fork.
        "--task", str(RETFound_output),
        "--output_dir", str(asset_path_output),
        "--resume", path_to_model,
    ]

    args = get_args_parser().parse_args(args_list)
    criterion = torch.nn.CrossEntropyLoss()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args, criterion)

## Upload results:

In [None]:
# Upload this execution's outputs to the catalog.
# NOTE(review): clean_folder=True presumably removes the local copies after
# upload — confirm in deriva-ml before re-running cells that read those files.
execution.upload_execution_outputs(clean_folder=True)