# Train IC

In [4]:

%load_ext autoreload
%autoreload 2

import sys
from os.path import join
import joblib
import torch
import pandas as pd

# Folder layout, relative to this notebook:
#   fp_code_folder       -> root of the code tree (contains the `src` package)
#   fp_project_folder    -> one level above the code tree
#   fp_checkpoint_folder -> <project>/checkpoints
fp_code_folder = "../../"
fp_project_folder = join(fp_code_folder, "../")
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")

# Make `src.*` importable for the imports that follow.
sys.path.append(fp_code_folder)

from src.models.resnet_bilateral.inference import get_pred_perf_df
from src.configs.image_config import fp_data_dfs_file, col_info
from src.preprocessing.pytorch_preprocessing import get_pytorch_split_dict, get_ic_image_dl
from src.filepath_manager import FilePath

from src.pytorch_training.misc import set_seed_pytorch
from src.models.resnet_bilateral.ic_trainer import train_resnet_ic
from src.models.resnet_bilateral.ic_inference import evaluate_resnet_ic

# Output locations (model, history, predictions, performance) for this IC run.
fp = FilePath(model_name="resnet_bilateral_ic", fp_checkpoint_folder=fp_checkpoint_folder)

# Autoencoder checkpoint that is loaded as the starting model for IC training.
# Derived from fp_checkpoint_folder instead of a second hard-coded relative path,
# so moving the checkpoint folder only requires changing one place.
fp_ae_model_file = join(fp_checkpoint_folder, "models", "resnet_bilateral_ae", "model.pt")

batch_size = 64
eval_batch_size = batch_size * 4  # evaluation uses 4x the training batch size

seed_no = 2024

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load Data

In [2]:
# joblib files are pickle-based: only load data files from a trusted source.
data_dfs = joblib.load(filename=fp_data_dfs_file)

In [3]:
# Seed before building the dataloaders so any randomness they use (e.g. the
# shuffled training split) is reproducible. This assumes set_seed_pytorch seeds
# the RNGs the loaders draw from; TODO confirm.
set_seed_pytorch(seed_no)
data_dls = get_pytorch_split_dict(
    data_dict=data_dfs,
    col_info=col_info,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    shuffle_train=True,
    dl_func=get_ic_image_dl,
)

## Train IC

In [6]:
# Load the pretrained autoencoder checkpoint. The loaded object is passed directly
# as `model=` to train_resnet_ic, so it is presumably a full pickled nn.Module
# rather than a state_dict.
# weights_only=False is required to unpickle a full module: since PyTorch 2.6,
# torch.load defaults to weights_only=True and would reject this file.
# Only load checkpoints from a trusted source (this runs arbitrary pickle code).
resnet_bilateral = torch.load(fp_ae_model_file, weights_only=False)

In [None]:
# Train the IC model, starting from the loaded autoencoder weights.
# The returned value is this IC run's training history. It was previously named
# `ae_history`, which was misleading because no autoencoder is trained here.
ic_history = train_resnet_ic(
    model=resnet_bilateral,
    **data_dls,  # dataloaders from get_pytorch_split_dict (includes "test_dl")
    # Where to store the trained model and the training history
    fp_model=fp.fp_model_file,
    fp_history=fp.fp_history_file,
    # Training parameters
    max_epochs=500,
    lr=0.001,  # TODO: try a smaller learning rate?
    weight_decay=0.0001,
    # Early stopping: monitor BCE, lower is better (maximise=False)
    patience=10,
    metric_to_monitor="bce",
    maximise=False,
    verbose=True,
    seed=seed_no,
)

## Evaluate

In [4]:
# Reload the IC model that train_resnet_ic saved to fp.fp_model_file, so evaluation
# can run in a fresh session without re-training.
# weights_only=False is required to unpickle a full module: since PyTorch 2.6,
# torch.load defaults to weights_only=True and would reject this file.
# Only load checkpoints from a trusted source (this runs arbitrary pickle code).
resnet_bilateral = torch.load(fp.fp_model_file, weights_only=False)

In [5]:
# Evaluate on the held-out test loader, then persist both the per-sample
# predictions and the summary performance table as CSV.
pred_df, perf_df = evaluate_resnet_ic(
    model=resnet_bilateral,
    dl=data_dls["test_dl"],
    col_info=col_info,
    verbose=True,
)
for result_frame, out_file in (
    (pred_df, fp.fp_prediction_file),
    (perf_df, fp.fp_performance_file),
):
    result_frame.to_csv(out_file)

  0%|          | 0/28 [00:00<?, ?it/s]

In [5]:
# Recompute the performance table from the saved predictions.
# The first CSV column is the index written by DataFrame.to_csv.
pred_df = pd.read_csv(filepath_or_buffer=fp.fp_prediction_file, index_col=0)
get_pred_perf_df(pred_df, col_info)

Unnamed: 0,Acc,Class 0 Acc,Class 1 Acc,Class 2 Acc,Class 3 Acc,Class 4 Acc,Average Class Acc,Acc W/O Class 0
0,0.731068,0.95146,0.037037,0.308929,0.383178,0.012821,0.338685,0.216891


In [8]:
# Show the performance table saved at evaluation time (first column is the saved index).
perf_df = pd.read_csv(index_col=0, filepath_or_buffer=fp.fp_performance_file)
perf_df

Unnamed: 0,Acc,Class 0 Acc,Class 1 Acc,Class 2 Acc,Class 3 Acc,Class 4 Acc,Average Class Acc
0,0.731068,0.95146,0.037037,0.308929,0.383178,0.012821,0.338685
