In [1]:
# !pip install timm==1.0.9
# !pip install albumentations==1.4.14
# !pip install torcheval==0.0.7
# !pip install pandas==2.2.2
# !pip install numpy==1.26.4

In [1]:
import sys, os, time, copy, gc
import glob
import torch
from torch import nn
from torch.utils.data import DataLoader
from pathlib import Path

import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import multiprocessing as mp

from torcheval.metrics.functional import binary_auroc, multiclass_auroc

from sklearn.model_selection import StratifiedGroupKFold

import hashlib
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split

from PIL import Image
import torch.optim as optim

from collections import defaultdict




sys.path.append('../src')
from utils import set_seed, visualize_augmentations_positive, print_trainable_parameters
from models import setup_model
from training import fetch_scheduler, train_one_epoch, valid_one_epoch, run_training, get_nth_test_step
from models import ISICModel, ISICModelEdgnet, setup_model
from datasets import ISICDatasetSamplerW, ISICDatasetSampler, ISICDatasetSimple, ISICDatasetSamplerMulticlass, prepare_loaders
from augmentations import get_augmentations

  check_for_updates()


In [2]:
# Pick the compute device: first CUDA GPU when present, CPU otherwise.
# (Seeding is done per fold via set_seed inside the training helpers.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_count = torch.cuda.device_count()
    print(f"GPU: {gpu_name}")
    print(f"Number of GPUs: {gpu_count}")

Using device: cuda
GPU: NVIDIA RTX A6000
Number of GPUs: 1


In [3]:
# Raw competition files are read from data/original; derived outputs go to data/artifacts.
original_data_path = "../data/original"
original_root = Path(original_data_path)

data_artifacts = "../data/artifacts"
os.makedirs(data_artifacts, exist_ok=True)

In [4]:
# Load the training metadata and attach each lesion's JPEG path.
TRAIN_HDF5_FILE_PATH = original_root / 'train-image.hdf5'

train_path = original_root / 'train-metadata.csv'
# low_memory=False parses the file in a single pass, so pandas infers one dtype per
# column instead of warning about mixed types (the warning shown in this cell's
# output is presumably that DtypeWarning).
df_train = pd.read_csv(train_path, low_memory=False)

# Derive the image directory from original_root instead of hard-coding it again.
train_image_dir = original_root / 'train-image' / 'image'
df_train["path"] = train_image_dir.as_posix() + '/' + df_train['isic_id'] + ".jpg"

original_positive_cases = df_train['target'].sum()
original_total_cases = len(df_train)
original_positive_ratio = original_positive_cases / original_total_cases

print(f"Number of positive cases: {original_positive_cases}")
print(f"Number of negative cases: {original_total_cases - original_positive_cases}")
print(f"Ratio of negative to positive cases: {(original_total_cases - original_positive_cases) / original_positive_cases:.2f}:1")

  df_train = pd.read_csv(train_path)


Number of positive cases: 393
Number of negative cases: 400666
Ratio of negative to positive cases: 1019.51:1


In [5]:
MODEL_NAME = "EVA"  # alternative backbone: "EDGENEXT"

# Hyper-parameters shared by every fold; seed and image size depend on the backbone.
CONFIG = dict(
    seed=42 if MODEL_NAME == 'EVA' else 1997,
    epochs=500,
    # The EVA-02 checkpoint below is the patch14 / 336px variant.
    img_size=336 if MODEL_NAME == 'EVA' else 256,
    train_batch_size=32,
    valid_batch_size=64,
    learning_rate=1e-4,
    scheduler='CosineAnnealingLR',
    min_lr=1e-6,
    # NOTE(review): the training log shows the LR reaching min_lr around epoch 100
    # and then climbing again. CosineAnnealingLR is periodic, so with epochs=500 the
    # LR keeps cycling unless T_max spans the whole run. Confirm this is intended.
    T_max=2000,
    weight_decay=1e-6,
    fold=0,
    n_fold=5,
    n_accumulate=1,
    # Column used as `groups` for StratifiedGroupKFold, so a patient never spans folds.
    group_col='patient_id',
    device=device,
)

model_name = (
    "eva02_small_patch14_336.mim_in22k_ft_in1k"
    if MODEL_NAME == 'EVA'
    else "edgenext_base.in21k_ft_in1k"
)
checkpoint_path = None

# Wrapper class handed to setup_model(..., model_maker=...) in the training loops.
ISICModelPrep = ISICModel if MODEL_NAME == 'EVA' else ISICModelEdgnet

In [6]:
# Build the augmentation pipelines from CONFIG using the project's augmentations module.
# The result is handed to prepare_loaders in every fold.
# NOTE(review): the return structure (e.g. a dict of per-split transforms) is not visible here; confirm.
data_transforms = get_augmentations(CONFIG)

  self.__pydantic_validator__.validate_python(data, self_instance=self)


In [7]:
def criterion(outputs, targets):
    """Mean binary cross-entropy between predicted probabilities and targets.

    ``outputs`` must already be probabilities in [0, 1] (e.g. after a sigmoid);
    binary cross-entropy does not apply a sigmoid itself.
    """
    return nn.functional.binary_cross_entropy(outputs, targets, reduction="mean")

In [8]:
# synthetic_custom_data = f"../data/artifacts/syntetic_custom_base_{CONFIG['seed']}"
# os.makedirs(synthetic_custom_data, exist_ok=True)

# tsp = StratifiedGroupKFold(2, shuffle=True, random_state=CONFIG['seed'])
# metrics_ev_df = []
# test_forecast = []
# val_forecast = []
# for fold_n, (train_index, val_index) in enumerate(tsp.split(df_train, y=df_train.target, groups=df_train[CONFIG["group_col"]])):
#     fold_df_train = df_train.iloc[train_index].reset_index(drop=True)
#     fold_df_valid = df_train.iloc[val_index].reset_index(drop=True)
#     synthetic_custom_data_pr = os.path.join(synthetic_custom_data, str(fold_n))
#     os.makedirs(synthetic_custom_data_pr, exist_ok=True)

#     for fn in fold_df_train[fold_df_train.target==1].isic_id.values:
#         if fn not in images_to_include:
#             continue
#         img = Image.open(os.path.join('../data/original/train-image/image', fn + ".jpg"))
#         img.save(os.path.join(synthetic_custom_data_pr, fn + ".png"))
    

In [9]:
# Per-fold checkpoints for the baseline run (real images only).
# Single braces: `{{...}}` in an f-string is an escaped literal, which previously
# produced a folder literally named "oof_{MODEL_NAME.lower()}_base".
folder_name = f"../models/oof_{MODEL_NAME.lower()}_base"
os.makedirs(folder_name, exist_ok=True)

In [10]:
def get_metrics(drop_path_rate, drop_rate, models_folder, model_maker, n_splits=None):
    """Cross-validate the model on the real training data and collect OOF predictions.

    Folds come from StratifiedGroupKFold over the module-level ``df_train``,
    stratified on ``target`` and grouped by ``CONFIG["group_col"]``. For every fold
    the function:

    1. re-seeds and builds a fresh model from the module-level ``model_name``;
    2. trains it with ``run_training``;
    3. saves its ``state_dict`` to ``models_folder/model__{fold}``;
    4. re-scores the validation fold to get out-of-fold predictions.

    Args:
        drop_path_rate: forwarded to ``setup_model``.
        drop_rate: forwarded to ``setup_model``.
        models_folder: existing directory where each fold's weights are written.
        model_maker: model class/factory forwarded to ``setup_model``.
        n_splits: number of folds; defaults to ``CONFIG['n_fold']``.

    Returns:
        ``(results_list, oof_df)``, where:

        - ``results_list`` holds the best ``'Valid Kaggle metric'`` of each fold.
        - ``oof_df`` is the concatenation of the validation frames, with added
          ``tmp_targets_all``, ``tmp_predictions_all`` and ``fold_n`` columns.
    """
    n_splits = CONFIG['n_fold'] if n_splits is None else n_splits
    splitter = StratifiedGroupKFold(n_splits, shuffle=True, random_state=CONFIG['seed'])
    splits = splitter.split(df_train, y=df_train.target, groups=df_train[CONFIG["group_col"]])

    results_list = []
    fold_df_valid_list = []
    for fold_n, (train_index, val_index) in enumerate(splits):
        fold_df_train = df_train.iloc[train_index].reset_index(drop=True)
        fold_df_valid = df_train.iloc[val_index].reset_index(drop=True)

        set_seed(CONFIG['seed'])
        model = setup_model(model_name, drop_path_rate=drop_path_rate, drop_rate=drop_rate, model_maker=model_maker)
        print_trainable_parameters(model)

        train_loader, valid_loader = prepare_loaders(fold_df_train, fold_df_valid, CONFIG, data_transforms)

        optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'],
                               weight_decay=CONFIG['weight_decay'])
        scheduler = fetch_scheduler(optimizer, CONFIG)

        model, history = run_training(
            train_loader, valid_loader,
            model, optimizer, scheduler,
            device=CONFIG['device'],
            num_epochs=CONFIG['epochs'],
            CONFIG=CONFIG,
            tolerance_max=20,
            test_every_nth_step=lambda x: 5,
            seed=CONFIG['seed'])
        torch.save(model.state_dict(), os.path.join(models_folder, f"model__{fold_n}"))
        results_list.append(np.max(history['Valid Kaggle metric']))

        # Re-score the validation fold with the model returned by run_training.
        _loss, _auroc, _custom_metric, fold_predictions, fold_targets = valid_one_epoch(
            model,
            valid_loader,
            device=CONFIG['device'],
            epoch=1,
            optimizer=optimizer,
            criterion=criterion,
            use_custom_score=True,
            metric_function=binary_auroc,
            num_classes=1,
            return_preds=True)

        # NOTE(review): positional assignment assumes valid_loader yields rows in
        # fold_df_valid order (no shuffling); confirm in prepare_loaders.
        fold_df_valid['tmp_targets_all'] = fold_targets
        fold_df_valid['tmp_predictions_all'] = fold_predictions
        fold_df_valid['fold_n'] = fold_n
        fold_df_valid_list.append(fold_df_valid)

        # Drop this fold's model/optimizer/loaders now. Otherwise they stay referenced
        # until the next fold's model has already been built, and both coexist in memory.
        del model, optimizer, scheduler, train_loader, valid_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    oof_df = pd.concat(fold_df_valid_list).reset_index(drop=True)
    return results_list, oof_df

In [None]:
# Baseline: train on real images only, then persist the out-of-fold predictions.
base_metrics, oof_forecasts = get_metrics(
    drop_path_rate=0,
    drop_rate=0,
    models_folder=folder_name,
    model_maker=ISICModelPrep,
)
oof_forecasts.to_parquet(f'../data/artifacts/oof_forecasts_{MODEL_NAME.lower()}_base.parquet')

trainable params: 21744385 || all params: 21744385 || trainable%: 100.00


100%|██████████| 19/19 [00:08<00:00,  2.14it/s, Epoch=1, LR=0.0001, Train_Auroc=0.582, Train_Loss=0.754]





100%|██████████| 19/19 [00:06<00:00,  2.72it/s, Epoch=2, LR=9.99e-5, Train_Auroc=0.751, Train_Loss=0.676]





100%|██████████| 19/19 [00:07<00:00,  2.71it/s, Epoch=3, LR=9.98e-5, Train_Auroc=0.825, Train_Loss=0.546]





100%|██████████| 19/19 [00:07<00:00,  2.68it/s, Epoch=4, LR=9.96e-5, Train_Auroc=0.853, Train_Loss=0.523]





100%|██████████| 19/19 [00:07<00:00,  2.71it/s, Epoch=5, LR=9.94e-5, Train_Auroc=0.86, Train_Loss=0.551] 
100%|██████████| 1112/1112 [04:03<00:00,  4.57it/s, Epoch=5, LR=9.94e-5, Valid_Auroc=0.526, Valid_Loss=0.583]
  _warn_get_lr_called_within_step(self)


Validation AUROC Improved (-inf ---> 0.10594036364080138)



100%|██████████| 19/19 [00:07<00:00,  2.57it/s, Epoch=6, LR=9.92e-5, Train_Auroc=0.877, Train_Loss=0.466]





100%|██████████| 19/19 [00:07<00:00,  2.66it/s, Epoch=7, LR=9.89e-5, Train_Auroc=0.862, Train_Loss=0.491]





100%|██████████| 19/19 [00:07<00:00,  2.66it/s, Epoch=8, LR=9.86e-5, Train_Auroc=0.88, Train_Loss=0.461] 





100%|██████████| 19/19 [00:07<00:00,  2.68it/s, Epoch=9, LR=9.82e-5, Train_Auroc=0.884, Train_Loss=0.444]





100%|██████████| 19/19 [00:07<00:00,  2.62it/s, Epoch=10, LR=9.78e-5, Train_Auroc=0.892, Train_Loss=0.448]
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=45, LR=5.82e-5, Valid_Auroc=0.521, Valid_Loss=0.581]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.57it/s, Epoch=46, LR=5.67e-5, Train_Auroc=0.924, Train_Loss=0.408]





100%|██████████| 20/20 [00:07<00:00,  2.53it/s, Epoch=47, LR=5.52e-5, Train_Auroc=0.931, Train_Loss=0.357]





100%|██████████| 20/20 [00:07<00:00,  2.65it/s, Epoch=48, LR=5.36e-5, Train_Auroc=0.935, Train_Loss=0.34] 





100%|██████████| 20/20 [00:07<00:00,  2.56it/s, Epoch=49, LR=5.21e-5, Train_Auroc=0.936, Train_Loss=0.343]





100%|██████████| 20/20 [00:07<00:00,  2.60it/s, Epoch=50, LR=5.05e-5, Train_Auroc=0.94, Train_Loss=0.324] 
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=50, LR=5.05e-5, Valid_Auroc=0.521, Valid_Loss=0.207]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.63it/s, Epoch=51, LR=4.89e-5, Train_Auroc=0.918, Train_Loss=0.381]





100%|██████████| 20/20 [00:07<00:00,  2.59it/s, Epoch=52, LR=4.74e-5, Train_Auroc=0.942, Train_Loss=0.337]





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=53, LR=4.58e-5, Train_Auroc=0.931, Train_Loss=0.348]





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=54, LR=4.43e-5, Train_Auroc=0.947, Train_Loss=0.304]





100%|██████████| 20/20 [00:07<00:00,  2.64it/s, Epoch=55, LR=4.28e-5, Train_Auroc=0.938, Train_Loss=0.344]
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=55, LR=4.28e-5, Valid_Auroc=0.52, Valid_Loss=0.204] 
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.61it/s, Epoch=56, LR=4.12e-5, Train_Auroc=0.942, Train_Loss=0.316]





100%|██████████| 20/20 [00:07<00:00,  2.60it/s, Epoch=57, LR=3.97e-5, Train_Auroc=0.931, Train_Loss=0.347]





100%|██████████| 20/20 [00:07<00:00,  2.62it/s, Epoch=58, LR=3.82e-5, Train_Auroc=0.933, Train_Loss=0.333]





100%|██████████| 20/20 [00:07<00:00,  2.65it/s, Epoch=59, LR=3.67e-5, Train_Auroc=0.939, Train_Loss=0.319]





100%|██████████| 20/20 [00:07<00:00,  2.56it/s, Epoch=60, LR=3.52e-5, Train_Auroc=0.952, Train_Loss=0.295]
100%|██████████| 1214/1214 [04:27<00:00,  4.54it/s, Epoch=60, LR=3.52e-5, Valid_Auroc=0.521, Valid_Loss=0.308]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:07<00:00,  2.52it/s, Epoch=61, LR=3.37e-5, Train_Auroc=0.945, Train_Loss=0.302]





100%|██████████| 20/20 [00:07<00:00,  2.58it/s, Epoch=62, LR=3.23e-5, Train_Auroc=0.947, Train_Loss=0.3]  





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=63, LR=3.08e-5, Train_Auroc=0.942, Train_Loss=0.315]





100%|██████████| 20/20 [00:07<00:00,  2.61it/s, Epoch=64, LR=2.94e-5, Train_Auroc=0.945, Train_Loss=0.315]





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=65, LR=2.8e-5, Train_Auroc=0.958, Train_Loss=0.268] 
100%|██████████| 1214/1214 [04:26<00:00,  4.55it/s, Epoch=65, LR=2.8e-5, Valid_Auroc=0.521, Valid_Loss=0.343]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.58it/s, Epoch=66, LR=2.67e-5, Train_Auroc=0.953, Train_Loss=0.295]





100%|██████████| 20/20 [00:07<00:00,  2.65it/s, Epoch=67, LR=2.53e-5, Train_Auroc=0.948, Train_Loss=0.294]





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=68, LR=2.4e-5, Train_Auroc=0.948, Train_Loss=0.287] 





100%|██████████| 20/20 [00:07<00:00,  2.59it/s, Epoch=69, LR=2.27e-5, Train_Auroc=0.96, Train_Loss=0.267] 





100%|██████████| 20/20 [00:07<00:00,  2.58it/s, Epoch=70, LR=2.14e-5, Train_Auroc=0.956, Train_Loss=0.264]
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=70, LR=2.14e-5, Valid_Auroc=0.521, Valid_Loss=0.403]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:07<00:00,  2.57it/s, Epoch=71, LR=2.02e-5, Train_Auroc=0.954, Train_Loss=0.278]





100%|██████████| 20/20 [00:07<00:00,  2.63it/s, Epoch=72, LR=1.89e-5, Train_Auroc=0.957, Train_Loss=0.273]





100%|██████████| 20/20 [00:07<00:00,  2.59it/s, Epoch=73, LR=1.78e-5, Train_Auroc=0.971, Train_Loss=0.245]





100%|██████████| 20/20 [00:07<00:00,  2.56it/s, Epoch=74, LR=1.66e-5, Train_Auroc=0.959, Train_Loss=0.27] 





100%|██████████| 20/20 [00:07<00:00,  2.55it/s, Epoch=75, LR=1.55e-5, Train_Auroc=0.967, Train_Loss=0.241]
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=75, LR=1.55e-5, Valid_Auroc=0.521, Valid_Loss=0.334]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.56it/s, Epoch=76, LR=1.44e-5, Train_Auroc=0.962, Train_Loss=0.246]





100%|██████████| 20/20 [00:07<00:00,  2.58it/s, Epoch=77, LR=1.34e-5, Train_Auroc=0.966, Train_Loss=0.23] 





100%|██████████| 20/20 [00:07<00:00,  2.65it/s, Epoch=78, LR=1.24e-5, Train_Auroc=0.972, Train_Loss=0.217]





100%|██████████| 20/20 [00:07<00:00,  2.64it/s, Epoch=79, LR=1.14e-5, Train_Auroc=0.968, Train_Loss=0.238]





100%|██████████| 20/20 [00:07<00:00,  2.54it/s, Epoch=80, LR=1.05e-5, Train_Auroc=0.965, Train_Loss=0.241]
100%|██████████| 1214/1214 [04:26<00:00,  4.56it/s, Epoch=80, LR=1.05e-5, Valid_Auroc=0.521, Valid_Loss=0.379]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.64it/s, Epoch=81, LR=9.56e-6, Train_Auroc=0.97, Train_Loss=0.221] 





100%|██████████| 20/20 [00:07<00:00,  2.53it/s, Epoch=82, LR=8.71e-6, Train_Auroc=0.961, Train_Loss=0.256]





100%|██████████| 20/20 [00:07<00:00,  2.54it/s, Epoch=83, LR=7.89e-6, Train_Auroc=0.96, Train_Loss=0.254] 





100%|██████████| 20/20 [00:07<00:00,  2.64it/s, Epoch=84, LR=7.12e-6, Train_Auroc=0.98, Train_Loss=0.191] 





100%|██████████| 20/20 [00:07<00:00,  2.58it/s, Epoch=85, LR=6.4e-6, Train_Auroc=0.971, Train_Loss=0.211] 
100%|██████████| 1214/1214 [04:48<00:00,  4.20it/s, Epoch=85, LR=6.4e-6, Valid_Auroc=0.521, Valid_Loss=0.266]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:08<00:00,  2.44it/s, Epoch=86, LR=5.71e-6, Train_Auroc=0.971, Train_Loss=0.225]





100%|██████████| 20/20 [00:08<00:00,  2.30it/s, Epoch=87, LR=5.07e-6, Train_Auroc=0.97, Train_Loss=0.207] 





100%|██████████| 20/20 [00:08<00:00,  2.44it/s, Epoch=88, LR=4.48e-6, Train_Auroc=0.971, Train_Loss=0.224]





100%|██████████| 20/20 [00:08<00:00,  2.26it/s, Epoch=89, LR=3.93e-6, Train_Auroc=0.974, Train_Loss=0.198]





100%|██████████| 20/20 [00:08<00:00,  2.46it/s, Epoch=90, LR=3.42e-6, Train_Auroc=0.972, Train_Loss=0.207]
100%|██████████| 1214/1214 [05:02<00:00,  4.01it/s, Epoch=90, LR=3.42e-6, Valid_Auroc=0.521, Valid_Loss=0.36] 
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:07<00:00,  2.52it/s, Epoch=91, LR=2.97e-6, Train_Auroc=0.971, Train_Loss=0.221]





100%|██████████| 20/20 [00:08<00:00,  2.28it/s, Epoch=92, LR=2.56e-6, Train_Auroc=0.978, Train_Loss=0.2]  





100%|██████████| 20/20 [00:08<00:00,  2.38it/s, Epoch=93, LR=2.19e-6, Train_Auroc=0.984, Train_Loss=0.163]





100%|██████████| 20/20 [00:08<00:00,  2.29it/s, Epoch=94, LR=1.88e-6, Train_Auroc=0.96, Train_Loss=0.251] 





100%|██████████| 20/20 [00:08<00:00,  2.45it/s, Epoch=95, LR=1.61e-6, Train_Auroc=0.983, Train_Loss=0.174]
100%|██████████| 1214/1214 [05:01<00:00,  4.03it/s, Epoch=95, LR=1.61e-6, Valid_Auroc=0.521, Valid_Loss=0.289]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:08<00:00,  2.48it/s, Epoch=96, LR=1.39e-6, Train_Auroc=0.973, Train_Loss=0.209]





100%|██████████| 20/20 [00:08<00:00,  2.34it/s, Epoch=97, LR=1.22e-6, Train_Auroc=0.97, Train_Loss=0.221] 





100%|██████████| 20/20 [00:08<00:00,  2.40it/s, Epoch=98, LR=1.1e-6, Train_Auroc=0.972, Train_Loss=0.211] 





100%|██████████| 20/20 [00:08<00:00,  2.26it/s, Epoch=99, LR=1.02e-6, Train_Auroc=0.973, Train_Loss=0.213]





100%|██████████| 20/20 [00:07<00:00,  2.51it/s, Epoch=100, LR=1e-6, Train_Auroc=0.971, Train_Loss=0.213]   
100%|██████████| 1214/1214 [05:03<00:00,  4.00it/s, Epoch=100, LR=1e-6, Valid_Auroc=0.521, Valid_Loss=0.304]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:08<00:00,  2.42it/s, Epoch=101, LR=1.02e-6, Train_Auroc=0.967, Train_Loss=0.229]





100%|██████████| 20/20 [00:08<00:00,  2.32it/s, Epoch=102, LR=1.1e-6, Train_Auroc=0.979, Train_Loss=0.186] 





100%|██████████| 20/20 [00:08<00:00,  2.39it/s, Epoch=103, LR=1.22e-6, Train_Auroc=0.965, Train_Loss=0.228]





100%|██████████| 20/20 [00:08<00:00,  2.33it/s, Epoch=104, LR=1.39e-6, Train_Auroc=0.977, Train_Loss=0.202]





100%|██████████| 20/20 [00:07<00:00,  2.53it/s, Epoch=105, LR=1.61e-6, Train_Auroc=0.98, Train_Loss=0.183] 
100%|██████████| 1214/1214 [05:02<00:00,  4.01it/s, Epoch=105, LR=1.61e-6, Valid_Auroc=0.521, Valid_Loss=0.293]





  _warn_get_lr_called_within_step(self)
100%|██████████| 20/20 [00:08<00:00,  2.49it/s, Epoch=106, LR=1.88e-6, Train_Auroc=0.977, Train_Loss=0.189]





100%|██████████| 20/20 [00:08<00:00,  2.34it/s, Epoch=107, LR=2.19e-6, Train_Auroc=0.979, Train_Loss=0.191]





100%|██████████| 20/20 [00:08<00:00,  2.42it/s, Epoch=108, LR=2.56e-6, Train_Auroc=0.974, Train_Loss=0.205]





100%|██████████| 20/20 [00:07<00:00,  2.51it/s, Epoch=109, LR=2.97e-6, Train_Auroc=0.984, Train_Loss=0.178]





100%|██████████| 20/20 [00:08<00:00,  2.32it/s, Epoch=110, LR=3.42e-6, Train_Auroc=0.973, Train_Loss=0.205]
100%|██████████| 1214/1214 [05:03<00:00,  4.01it/s, Epoch=110, LR=3.42e-6, Valid_Auroc=0.521, Valid_Loss=0.387]
  _warn_get_lr_called_within_step(self)





100%|██████████| 20/20 [00:08<00:00,  2.47it/s, Epoch=111, LR=3.93e-6, Train_Auroc=0.979, Train_Loss=0.194]





100%|██████████| 20/20 [00:07<00:00,  2.56it/s, Epoch=112, LR=4.48e-6, Train_Auroc=0.983, Train_Loss=0.181]





100%|██████████| 20/20 [00:07<00:00,  2.50it/s, Epoch=113, LR=5.07e-6, Train_Auroc=0.972, Train_Loss=0.21] 





  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Summarise the baseline CV run here: `base_metrics` is reassigned by the synthetic-data run below.
print(f"Baseline best validation metric per fold: {base_metrics}")
print(f"Baseline mean over folds: {np.mean(base_metrics):.4f}")

# Train with synthetic data

In [None]:
def get_metrics_synth(drop_path_rate, drop_rate, models_folder, model_maker, synth_path = "../data/artifacts/syntetic_base_folds_final", n_splits=None):
    """Cross-validate with synthetic images added to each fold's training set.

    Uses the same folds as the baseline (StratifiedGroupKFold on the module-level
    ``df_train``, seeded with ``CONFIG['seed']``). Each fold's training frame is
    extended with the PNGs found under
    ``{synth_path}/{seed}/{fold}/hr`` and ``{synth_path}/{seed}/{fold}/lr``.
    A synthetic image's target is parsed from its file name
    ``<anything>___<target>.png``. The validation fold stays real-only.

    Args:
        drop_path_rate: forwarded to ``setup_model``.
        drop_rate: forwarded to ``setup_model``.
        models_folder: existing directory where each fold's weights are written.
        model_maker: model class/factory forwarded to ``setup_model``.
        synth_path: root of the per-seed / per-fold synthetic image folders.
        n_splits: number of folds; defaults to ``CONFIG['n_fold']``.

    Returns:
        ``(results_list, oof_df)``, with the same meaning as in ``get_metrics``.

    Raises:
        FileNotFoundError: if a fold has no synthetic PNGs. Without this check the
            run would silently degrade to the real-data-only baseline.
    """
    n_splits = CONFIG['n_fold'] if n_splits is None else n_splits
    splitter = StratifiedGroupKFold(n_splits, shuffle=True, random_state=CONFIG['seed'])
    splits = splitter.split(df_train, y=df_train.target, groups=df_train[CONFIG["group_col"]])

    results_list = []
    fold_df_valid_list = []
    for fold_n, (train_index, val_index) in enumerate(splits):
        fold_df_train = df_train.iloc[train_index].reset_index(drop=True)
        fold_df_valid = df_train.iloc[val_index].reset_index(drop=True)
        set_seed(CONFIG['seed'])
        model = setup_model(model_name, drop_path_rate=drop_path_rate, drop_rate=drop_rate, model_maker=model_maker)
        print_trainable_parameters(model)

        # Sorted so the row order (and thus seeded training) is reproducible;
        # glob.glob returns files in arbitrary order.
        fold_synth_dir = f"{synth_path}/{CONFIG['seed']}/{fold_n}"
        synth_paths = (sorted(glob.glob(f"{fold_synth_dir}/hr/*.png"))
                       + sorted(glob.glob(f"{fold_synth_dir}/lr/*.png")))
        if not synth_paths:
            raise FileNotFoundError(f"No synthetic PNGs found under {fold_synth_dir}/hr or {fold_synth_dir}/lr")

        synthetic_df = pd.DataFrame({"path": synth_paths})
        synthetic_df['weight'] = 1
        # File names look like "<id>___<target>.png"; parse only the base name so
        # directory names cannot interfere with the split.
        synthetic_df['target'] = [
            int(os.path.basename(p).split('___')[1].split('.')[0]) for p in synth_paths
        ]

        # Nothing in this notebook adds a 'weight' column to df_train (it is read
        # straight from train-metadata.csv), so default real images to weight 1 too.
        real_df = fold_df_train[['path', 'target']].copy()
        real_df['weight'] = fold_df_train['weight'] if 'weight' in fold_df_train.columns else 1

        train_df = pd.concat([synthetic_df, real_df[['path', 'target', 'weight']]], ignore_index=True)

        train_loader, valid_loader = prepare_loaders(train_df, fold_df_valid, CONFIG, data_transforms)

        optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'],
                               weight_decay=CONFIG['weight_decay'])
        scheduler = fetch_scheduler(optimizer, CONFIG)

        model, history = run_training(
            train_loader, valid_loader,
            model, optimizer, scheduler,
            device=CONFIG['device'],
            num_epochs=CONFIG['epochs'],
            CONFIG=CONFIG,
            tolerance_max=20,
            test_every_nth_step=lambda x: 5,
            seed=CONFIG['seed'])
        torch.save(model.state_dict(), os.path.join(models_folder, f"model__{fold_n}"))
        results_list.append(np.max(history['Valid Kaggle metric']))

        # Re-score the (real-only) validation fold with the model returned by run_training.
        _loss, _auroc, _custom_metric, fold_predictions, fold_targets = valid_one_epoch(
            model,
            valid_loader,
            device=CONFIG['device'],
            epoch=1,
            optimizer=optimizer,
            criterion=criterion,
            use_custom_score=True,
            metric_function=binary_auroc,
            num_classes=1,
            return_preds=True)

        # NOTE(review): positional assignment assumes valid_loader yields rows in
        # fold_df_valid order (no shuffling); confirm in prepare_loaders.
        fold_df_valid['tmp_targets_all'] = fold_targets
        fold_df_valid['tmp_predictions_all'] = fold_predictions
        fold_df_valid['fold_n'] = fold_n
        fold_df_valid_list.append(fold_df_valid)

        # Free this fold's model/optimizer/loaders before the next fold builds its own.
        del model, optimizer, scheduler, train_loader, valid_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    oof_df = pd.concat(fold_df_valid_list).reset_index(drop=True)
    return results_list, oof_df

In [None]:
# Per-fold checkpoints for the run trained with synthetic images added.
# Single braces so MODEL_NAME is actually interpolated (`{{...}}` is an escaped literal).
folder_name = f"../models/oof_{MODEL_NAME.lower()}_base__synth"
os.makedirs(folder_name, exist_ok=True)

In [None]:
# Synthetic-data run. This must call get_metrics_synth: get_metrics trains on real
# images only, so calling it here would just repeat the baseline under a "__synth" name.
base_metrics, oof_forecasts = get_metrics_synth(
    drop_path_rate=0,
    drop_rate=0,
    models_folder=folder_name,
    model_maker=ISICModelPrep,
)
oof_forecasts.to_parquet(f'../data/artifacts/oof_forecasts_{MODEL_NAME.lower()}_base__synth.parquet')