# Baseline Model for Diabetic Retinopathy Image Classification

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import sys
import joblib
from os.path import join
import pandas as pd

# Paths are relative to the kernel's working directory (presumably this notebook's folder).
# fp_code_folder must contain the `src` package imported below; the project root
# (which holds `checkpoints/`) is one level above it.
fp_code_folder = "../../"
fp_project_folder = join(fp_code_folder, "../")

# Make `src` importable. The guard keeps the cell idempotent: re-running it
# (e.g. while iterating with autoreload) no longer appends duplicate entries.
if fp_code_folder not in sys.path:
    sys.path.append(fp_code_folder)

fp_checkpoint_folder = join(fp_project_folder, "checkpoints")

from src.configs.image_config import fp_data_dfs_file, col_info
from src.preprocessing.pytorch_preprocessing import get_pytorch_split_dict, get_baseline_image_dl
from src.filepath_manager import FilePath

from src.models.resnet.resnet_model import ResNet18
from src.models.resnet.resnet_trainer import train_resnet
from src.models.resnet.resnet_evaluator import evaluate_resnet
from src.pytorch_training.misc import set_seed_pytorch
from src.models.resnet_bilateral.inference import get_pred_perf_df

# Run configuration.
seed_no = 2024  # single seed shared by data loading and training

batch_size = 64
eval_batch_size = 4 * batch_size  # 256: evaluation loaders use 4x the training batch size

# Provides the model / history / prediction / performance file paths used in later cells.
fp = FilePath(model_name="resnet_baseline", fp_checkpoint_folder=fp_checkpoint_folder)

2025-02-19 11:49:39.951053: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-19 11:49:39.962213: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-19 11:49:39.965602: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 11:49:39.975005: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Load Data

In [10]:
# Presumably a dict of per-split DataFrames: it is passed as `data_dict` to get_pytorch_split_dict.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only point fp_data_dfs_file at trusted, locally generated artefacts.
data_dfs = joblib.load(filename=fp_data_dfs_file)

In [11]:
# Seed before constructing the loaders, using the same seed that training receives.
set_seed_pytorch(seed_no)
data_dls = get_pytorch_split_dict(
    data_dict=data_dfs,
    col_info=col_info,
    dl_func=get_baseline_image_dl,
    # Loader settings: training vs. evaluation batch sizes, shuffle the train split.
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    shuffle_train=True,
)

## Train Model

In [None]:
# Re-seed immediately before building the model. Otherwise the initial weights
# depend on how much RNG state earlier cells (or earlier runs of this cell)
# consumed, and re-running this cell alone would not reproduce the same
# initialisation. set_seed_pytorch presumably seeds torch's RNG used by ResNet18's init.
set_seed_pytorch(seed_no)
resnet_model = ResNet18(num_classes=col_info["num_classes"])
resnet_history = train_resnet(
    model=resnet_model,
    **data_dls,  # loaders built in the Load Data section
    # Where to store the trained model and the history of training
    fp_model=fp.fp_model_file,
    fp_history=fp.fp_history_file,
    # Training parameters
    max_epochs=500,
    lr=0.001,
    weight_decay=0.0001,
    # Early stopping.
    # NOTE(review): overall accuracy is dominated by class 0 in the results below
    # (class 0 acc ~0.98, class 1 acc 0.0); a class-balanced metric may be a
    # better early-stopping target.
    patience=10,
    metric_to_monitor="acc",
    maximise=True,
    verbose=True,
    seed=seed_no,
)

## Evaluate Model

In [8]:
# Reload the model file written by train_resnet (fp_model=fp.fp_model_file), so
# evaluation does not depend on the in-memory model from the training cell.
# The loaded object is used directly as the model, so the file holds a pickled
# nn.Module rather than a state_dict. That requires weights_only=False: PyTorch
# >= 2.6 defaults to weights_only=True and would refuse to load it (the argument
# exists since 1.13). Only load checkpoints you produced yourself — unpickling
# can run arbitrary code.
resnet_model = torch.load(fp.fp_model_file, weights_only=False)

In [None]:
# Evaluate the reloaded model on the test loader, then persist both result tables.
pred_df, perf_df = evaluate_resnet(
    model=resnet_model,
    dl=data_dls["test_dl"],
    col_info=col_info,
    verbose=True,
)
pred_df.to_csv(fp.fp_prediction_file)   # predictions (index written as first column)
perf_df.to_csv(fp.fp_performance_file)  # performance summary

  0%|          | 0/28 [00:00<?, ?it/s]

In [2]:
# Reload the saved predictions and recompute the performance summary from them.
# DataFrame.to_csv wrote the index as the first column, so restore it with
# index_col=0 instead of leaving a spurious "Unnamed: 0" column in pred_df.
pred_df = pd.read_csv(fp.fp_prediction_file, index_col=0)
get_pred_perf_df(pred_df, col_info)

Unnamed: 0,Acc,Class 0 Acc,Class 1 Acc,Class 2 Acc,Class 3 Acc,Class 4 Acc,Average Class Acc,Acc W/O Class 0
0,0.761877,0.979432,0.0,0.376786,0.280374,0.307692,0.388857,0.254319


In [14]:
# Reload and display the performance table saved by the evaluation cell.
perf_df = pd.read_csv(
    fp.fp_performance_file,
    index_col=0,  # first column is the index written by DataFrame.to_csv
)
perf_df

Unnamed: 0,Acc,Class 0 Acc,Class 1 Acc,Class 2 Acc,Class 3 Acc,Class 4 Acc
0,0.761877,0.979432,0.0,0.376786,0.280374,0.307692
