# 1. Setup & Imports

## 1.1 Setup

In [None]:
import os
import wandb
from dotenv import load_dotenv

# Pull WANDB_API_KEY from a local .env file (if any) so no secret is hard-coded here.
load_dotenv()
wandb_api_key = os.environ.get("WANDB_API_KEY")

wandb.login(key=wandb_api_key)

## 1.2 Imports

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='pydantic')

import gc
import math
import heapq
import shutil
import glob
import random
from types import SimpleNamespace
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import cv2
from sklearn.metrics import roc_auc_score, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
import timm
import ttach as tta
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import get_cosine_schedule_with_warmup

# Disable OpenCV's internal threading; presumably to avoid CPU oversubscription
# alongside the DataLoader worker processes (CFG.num_workers).
cv2.setNumThreads(0)
# Report available CPU cores (useful when choosing CFG.num_workers).
print(os.cpu_count())

## 1.3 Configuration

In [3]:
class CFG:
    """Experiment hyper-parameters, read as class attributes (the class is never instantiated)."""

    # --- reproducibility / cross-validation ---
    seed = 855
    n_folds = 5

    # --- schedule ---
    epochs = 25             # epochs actually trained per fold
    virtual_epochs = 25     # length of the cosine LR schedule, in epochs
    warmup_multiplier = 2   # LR warm-up length, in epochs

    # --- data loading ---
    batch_size = 32
    accum_iter = 1          # batches accumulated per optimizer step
    num_workers = 4
    persistent_workers = True

    # --- optimisation ---
    lr = 1e-4
    weight_decay = 0.05

    # --- distillation ---
    alpha = 0.5             # weight of the hard labels in the blended training target
    T = 2.0                 # softmax temperature applied to soft labels / logits

    # --- model ---
    drop_path_rate = 0.2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- experiment tracking ---
    project_name = 'PlantPathology2020'
    exp_name = 'Student9_convnext_small'

In [None]:
def set_seed(seed, deterministic=False):
    """Seed Python, NumPy and PyTorch (CPU, current GPU and all GPUs); configure determinism.

    Args:
        seed: integer seed applied to every RNG.
        deterministic: when True, force deterministic cuDNN/cuBLAS kernels (slower);
            when False, let cuDNN benchmark and pick the fastest algorithms.
    """
    # Note: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocesses), not the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    strict = bool(deterministic)
    if strict:
        # cuBLAS requires a fixed workspace configuration for deterministic GEMMs.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.deterministic = strict
    torch.backends.cudnn.benchmark = not strict
    torch.use_deterministic_algorithms(strict)


# Seed everything once for the notebook (non-deterministic cuDNN for speed).
set_seed(CFG.seed)

# Module-level alias read by get_model() when moving the network.
device = CFG.device
print(device)

# 2. Data Pipeline

## 2.1 Load Raw Data

In [None]:
img_dir = '../data/images/'

# Ground-truth one-hot label columns and the matching soft-label (OOF probability) columns.
hard_cols = ['healthy', 'multiple_diseases', 'rust', 'scab']
soft_cols = [f'{col}_pred' for col in hard_cols]

train_df = pd.read_csv('../data/datasets/train_reborn_02.csv').reset_index(drop=True)

# Out-of-fold predictions of two earlier models; ensembled below into soft targets.
oof_df_01 = pd.read_csv('../data/datasets/oof_preds_Student5_EfficientNetB6_reborn.csv')
oof_df_02 = pd.read_csv('../data/oof_preds_Student8_ResNest101e.csv')

test_df = pd.read_csv('../data/test.csv')
submission = pd.read_csv('../data/sample_submission.csv')

train_df.head()

In [None]:
# Number of training rows assigned to each cross-validation fold.
train_df['fold'].value_counts()

In [None]:
# Peek at both raw OOF prediction files before aligning them.
display(oof_df_01.head(), oof_df_02.head())

In [None]:
# Give the first OOF file train_df's image ids by row position.
# NOTE(review): this assumes the CSV rows are in the same order as
# train_reborn_02.csv; verify against how the file was produced.
assert len(oof_df_01) == len(train_df), 'oof_df_01 and train_df have different row counts'
oof_df_01['image_id'] = train_df['image_id'].to_numpy()
oof_df_01 = oof_df_01[['image_id'] + hard_cols]
oof_df_01.head()

In [None]:
# Equal-weight ensemble of the two OOF predictions. The arithmetic aligns on the
# pandas index, so both frames must list the same images in the same order.
assert len(oof_df_02) == len(oof_df_01), 'OOF files have different row counts'
if 'image_id' in oof_df_02.columns:
    assert (oof_df_02['image_id'].to_numpy() == oof_df_01['image_id'].to_numpy()).all(), \
        'oof_df_02 rows are not in the same image order as oof_df_01'

oof_df = oof_df_01.copy()
oof_df[hard_cols] = oof_df_01[hard_cols] * 0.5 + oof_df_02[hard_cols] * 0.5
oof_df.head()

In [None]:
# Rename the ensembled predictions to the soft-label columns and attach them to train_df.
# Previously merged soft columns are dropped first so re-running this cell does not
# produce *_x / *_y duplicates.
oof_df.columns = ['image_id'] + soft_cols
train_df = train_df.drop(columns=soft_cols, errors='ignore').merge(
    oof_df, on='image_id', how='left', validate='one_to_one'
)
# A left merge leaves NaN for images without an OOF row; NaN targets would poison the loss.
assert train_df[soft_cols].notna().all().all(), 'some training images have no OOF prediction'
train_df.head()

## 2.2 Load Images

In [None]:
# Decode, resize and RGB-convert every train + test image exactly once and cache it
# in `all_images` (keyed by image_id); ImageDataset copies from this cache per sample.
all_images = {}
all_img_ids = np.concatenate([train_df['image_id'].tolist(), test_df['image_id'].tolist()])

print("Loading all images into RAM once...")
for img_id in tqdm(all_img_ids, desc='Loading Images...', leave=False):
    img_path = img_dir + img_id + '.jpg'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for missing/unreadable files instead of raising,
        # which would otherwise surface as an opaque error inside cv2.resize.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.resize(img, (650, 450))  # cv2 takes (width, height)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img.setflags(write=False)  # read-only: consumers must copy before modifying
    all_images[img_id] = img

print("All Images on Ram")

## 2.3 Custom Dataset Class

In [11]:
class ImageDataset(Dataset):
    """Serve cached RGB images from the global ``all_images`` dict, with optional labels.

    Args:
        df: frame with an ``image_id`` column (falls back to the first column) and, unless
            ``is_test``, the ``hard_cols`` and ``soft_cols`` label columns.
        hard_cols: ground-truth one-hot label columns.
        soft_cols: soft-target (OOF probability) columns.
        transform: optional callable invoked as ``transform(image=...)['image']``.
        is_test: if True, items are just the image.

    Each item is ``image`` when ``is_test``; otherwise ``(image, soft_labels, hard_labels)``
    with float32 label tensors.
    """

    def __init__(self, df, hard_cols=hard_cols, soft_cols=soft_cols, transform=None, is_test=False):
        super().__init__()
        self.df = df
        self.transform = transform
        self.is_test = is_test
        self.hard_cols = hard_cols
        self.soft_cols = soft_cols

        # Extract ids and labels once. Building a pandas row Series with df.iloc[idx]
        # for every sample is slow and is repeated in every DataLoader worker.
        id_column = df['image_id'] if 'image_id' in df.columns else df.iloc[:, 0]
        self.image_ids = id_column.to_numpy()
        if not is_test:
            self.soft_labels = df[soft_cols].to_numpy(dtype=np.float32)
            self.hard_labels = df[hard_cols].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        # Cached arrays are read-only and shared; augment a private copy.
        image = all_images[img_id].copy()

        if self.transform is not None:
            image = self.transform(image=image)['image']

        if self.is_test:
            return image
        return image, torch.tensor(self.soft_labels[idx]), torch.tensor(self.hard_labels[idx])

## 2.4 Augmentations

In [None]:
# mean_fill_value = [103, 131, 82]

# ImageNet normalisation statistics; should match the backbone's pretrained
# data config printed in section 3.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: photometric jitter -> optional blur -> flips -> affine -> normalise.
transform_train = A.Compose([
    # Always apply exactly one of: brightness-only or contrast-only jitter.
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1.0),
        A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1.0),
    ], p=1.0),

    # Half of the time, apply one small-kernel blur.
    A.OneOf([
        A.MotionBlur(blur_limit=3, p=1.0),
        A.MedianBlur(blur_limit=3, p=1.0),
        A.GaussianBlur(blur_limit=3, p=1.0),
    ], p=0.5),

    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),

    A.Affine(
        scale=(0.8, 1.2),
        translate_percent=0.2,
        rotate=20,
        interpolation=cv2.INTER_CUBIC,       # bicubic interpolation
        border_mode=cv2.BORDER_REFLECT_101,  # fill exposed borders by reflection
        p=1.0,
    ),

    # Disabled experiment:
    # A.CoarseDropout(
    #     num_holes_range=(4, 8),       # replaces min_holes=4, max_holes=8
    #     hole_height_range=(8, 16),    # replaces min_height=8, max_height=16
    #     hole_width_range=(8, 16),     # replaces min_width=8, max_width=16
    #     fill=mean_fill_value,
    #     p=0.5
    # ),

    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Evaluation pipeline: normalisation only.
transform_test = A.Compose([
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Test-time augmentation: combinations of horizontal and vertical flips.
transform_tta = tta.Compose([
    tta.HorizontalFlip(),
    tta.VerticalFlip(),
    # tta.Rotate90(angles=[0, 90, 180, 270]),
])

In [None]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's per-worker seed.

    ``worker_id`` is unused; ``torch.initial_seed()`` already differs per worker.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(derived_seed)

# Shared RNG passed to every DataLoader so shuffling and worker base seeds
# are reproducible from CFG.seed.
g = torch.Generator()
g.manual_seed(CFG.seed)

In [None]:
# Test dataset / loader, built once and reused for every fold at inference time.
batch_size = CFG.batch_size

dataset_test = ImageDataset(test_df, transform=transform_test, is_test=True)
loader_test = DataLoader(
    dataset_test,
    batch_size=batch_size * 4,  # inference only, so a larger batch presumably fits
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=CFG.num_workers,
    # Honour the config flag; DataLoader rejects persistent_workers=True with num_workers == 0.
    persistent_workers=CFG.persistent_workers and CFG.num_workers > 0,
    pin_memory=True,
)

# 3. Model Architecture

In [None]:
def get_model():
    """Create a fresh pretrained ConvNeXt-Small (fb_in22k_ft_in1k) with a 4-class head on ``device``."""
    backbone_name = 'convnext_small.fb_in22k_ft_in1k'
    return timm.create_model(
        backbone_name,
        pretrained=True,
        drop_path_rate=CFG.drop_path_rate,
        num_classes=4,
    ).to(device)

# Instantiate once to inspect the backbone's expected input normalisation.
# NOTE(review): `model` stays referenced (and on the GPU) after this cell.
model = get_model()
config = timm.data.resolve_model_data_config(model)

for stat_key in ('mean', 'std'):
    print(f"{stat_key.capitalize()}: {config[stat_key]}")

In [None]:
# Collect unreachable Python objects first so their CUDA tensors are freed,
# then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

# 4. Training Components

## 4.1 Metrics & Trackers

In [None]:
class AvgMeter:
    """Running weighted average of a scalar (e.g. per-sample loss over an epoch).

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen.
        count: total weight seen.
        avg: ``sum / count`` (0 before the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (typically the batch size)."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count


class MetricHandler:
    """Accumulate per-sample predictions and targets across batches and report metrics.

    ``preds_list`` / ``actual_list`` collect one row per sample: class probabilities
    and one-hot targets respectively.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.preds_list = []
        self.actual_list = []

    def update(self, preds, actual):
        """Append a batch of prediction rows and the matching target rows."""
        self.preds_list.extend(preds)
        self.actual_list.extend(actual)

    def compute_roc_auc(self):
        """ROC AUC over everything accumulated (sklearn's default macro averaging)."""
        return roc_auc_score(self.actual_list, self.preds_list)

    def print_confusion_matrix(self):
        """Print a 4x4 confusion matrix (argmax of targets vs predictions) and per-class accuracy."""
        class_names = ['Healthy', 'Multiple', 'Rust', 'Scab']
        y_true = np.argmax(np.array(self.actual_list), axis=1)
        y_pred = np.argmax(np.array(self.preds_list), axis=1)
        # Pin the label set so the matrix is always 4x4, even when a class never
        # occurs in y_true/y_pred (otherwise cm[i][3] below can raise IndexError).
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

        class_total = cm.sum(axis=1)
        class_correct = cm.diagonal()
        class_acc = np.divide(class_correct, class_total,
                              out=np.zeros_like(class_correct, dtype=float),
                              where=class_total != 0)

        # Built outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12.
        corner_label = 'True \\ Pred'
        print('\n' + '=' * 55)
        print(f"{corner_label:<12} | {'Healthy':^7} {'Multiple':^8} {'Rust':^7} {'Scab':^7}")
        print('-' * 55)
        for i, label in enumerate(class_names):
            print(f"{label:<12} | {cm[i][0]:^7} {cm[i][1]:^8} {cm[i][2]:^7} {cm[i][3]:^7}")
        print('-' * 55)

        print('🎯 Class-wise Accuracy:')
        for i, name in enumerate(class_names):
            acc_percent = class_acc[i] * 100
            print(f' > {name:<9} : {acc_percent:6.2f}%  ({class_correct[i]:3d} / {class_total[i]:3d})')
        print('=' * 55 + '\n')

## 4.2 Checkpoint

In [None]:
class ModelCheckpoint:
    """Keep only the top-k model checkpoints (by score) on disk.

    ``best_scores`` is a min-heap of ``(compare_score, save_path)``; its root is the
    weakest kept checkpoint, which is evicted (file deleted) when a better score
    arrives. With ``mode='min'`` scores are negated so that lower is better.
    """

    def __init__(self, output_dir, mode='max', top_k=3):
        self.output_dir = output_dir
        self.mode = mode
        self.top_k = top_k
        self.best_scores = []  # min-heap of (compare_score, save_path)
        os.makedirs(self.output_dir, exist_ok=True)

    def update(self, model, score, filename):
        """Save ``model.state_dict()`` as ``filename`` if ``score`` makes the top-k.

        Returns:
            True if the checkpoint was written, False otherwise.
        """
        save_path = os.path.join(self.output_dir, filename)
        compare_score = -score if self.mode != 'max' else score

        if len(self.best_scores) >= self.top_k:
            if not compare_score > self.best_scores[0][0]:
                return False
            # Make room by evicting the weakest kept checkpoint.
            _, evicted_path = heapq.heappop(self.best_scores)
            if os.path.exists(evicted_path):
                os.remove(evicted_path)

        heapq.heappush(self.best_scores, (compare_score, save_path))
        torch.save(model.state_dict(), save_path)
        print(f'Top-{self.top_k} Model Saved : {filename} (score: {score:.4f})')
        return True


class BackupHandler:
    """Write artefacts to ``local_dir`` and, when enabled, mirror them into ``backup_dir``.

    Mirroring is enabled only when ``active`` is truthy *and* ``backup_dir`` is given.
    """

    def __init__(self, local_dir, backup_dir=None, active=True):
        self.local_dir = local_dir
        self.backup_dir = backup_dir
        self.active = active and (backup_dir is not None)

        if self._mirroring():
            os.makedirs(self.backup_dir, exist_ok=True)
            print(f'Backup Active : {self.local_dir} -> {self.backup_dir}')

    def _mirroring(self):
        """True when files should also be written to ``backup_dir``."""
        return bool(self.active) and self.backup_dir is not None

    def backup(self, filename):
        """Copy ``local_dir/filename`` into ``backup_dir`` if mirroring and the file exists."""
        if not self._mirroring():
            return
        source = os.path.join(self.local_dir, filename)
        if os.path.exists(source):
            shutil.copy(source, os.path.join(self.backup_dir, filename))

    def save_csv(self, df, filename):
        """Write ``df`` (without its index) locally and, if mirroring, into ``backup_dir``."""
        directories = [self.local_dir]
        if self._mirroring():
            directories.append(self.backup_dir)
        for directory in directories:
            target = os.path.join(directory, filename)
            df.to_csv(target, index=False)
            print(f'CSV saved at {target}')

## 4.3 Trainer Engine

In [None]:
class Trainer:
    """Train and validate a single CV fold with blended hard/soft (distillation) targets.

    The training target is ``alpha * hard_labels + (1 - alpha) * soft_labels``, where the
    soft labels (OOF probabilities) are re-tempered by ``config.T`` when ``T > 1``.
    Validation uses hard labels only; the top-3 checkpoints by ROC AUC are kept in
    ``local_model_dir`` through ``ModelCheckpoint``.

    Args:
        model: network to train (already on the target device).
        loader_train / loader_valid: loaders yielding ``(image, soft_labels, hard_labels)``.
        fold: zero-based fold index (used in progress bars and checkpoint names).
        config: CFG-like object (lr, weight_decay, alpha, T, accum_iter, virtual_epochs,
            seed; optionally warmup_multiplier and device).
        local_model_dir: checkpoint directory (created if missing).
    """

    def __init__(self, model, loader_train, loader_valid, fold, config, local_model_dir):
        self.model = model
        self.loader_train = loader_train
        self.loader_valid = loader_valid
        self.fold = fold
        self.config = config
        self.alpha = config.alpha
        self.local_model_dir = local_model_dir
        os.makedirs(self.local_model_dir, exist_ok=True)

        # Use the configured device so inputs land where get_model() placed the model.
        fallback_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = getattr(config, 'device', fallback_device)

        self.loss_function = nn.CrossEntropyLoss()  # accepts class-probability targets
        self.scaler = GradScaler('cuda')
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr, weight_decay=self.config.weight_decay)
        self.accum_iter = config.accum_iter

        # Optimizer steps per epoch: one per `accum_iter` batches plus the partial group
        # flushed at the end of the epoch (hence ceil). len(loader) is the real batch count.
        optimizer_steps_per_epoch = math.ceil(len(loader_train) / self.accum_iter)
        total_steps = optimizer_steps_per_epoch * config.virtual_epochs
        # Warm-up length is given in epochs; default to 2 epochs (the old fallback was 2 *steps*).
        warmup_steps = optimizer_steps_per_epoch * getattr(config, 'warmup_multiplier', 2)
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        self.metric_handler = MetricHandler()
        self.score_checkpoint = ModelCheckpoint(output_dir=self.local_model_dir, mode='max')

    def _train_one_epoch(self, epoch):
        """Run one training epoch; returns the per-sample average (un-divided) loss."""
        self.model.train()
        epoch_train_loss = AvgMeter()
        pbar = tqdm(self.loader_train, desc=f'Train {self.fold+1} Ep {epoch+1}', leave=False)

        T = self.config.T
        use_temperature = T > 1.0
        num_batches = len(self.loader_train)
        self.optimizer.zero_grad()
        for i, (image, soft_labels, hard_labels) in enumerate(pbar):
            image = image.to(self.device)
            soft_labels = soft_labels.to(self.device)
            hard_labels = hard_labels.to(self.device)

            if use_temperature:
                # Treat log-probabilities as logits and re-soften them with temperature T.
                logits_from_oof = torch.log(soft_labels + 1e-6)
                soft_labels = torch.softmax(logits_from_oof / T, dim=1)

            label = self.alpha * hard_labels + (1 - self.alpha) * soft_labels

            with autocast('cuda'):
                outputs = self.model(image)
                # NOTE(review): the student logits are divided by T for the hard-label part
                # of the blended target too; classic KD uses T=1 for the hard term.
                logits = outputs / T if use_temperature else outputs
                loss = self.loss_function(logits, label)
                if use_temperature:
                    # T**2 keeps gradient magnitudes comparable across temperatures.
                    loss = loss * (T ** 2)
                loss = loss / self.accum_iter

            self.scaler.scale(loss).backward()
            if (i + 1) % self.accum_iter == 0 or (i + 1) == num_batches:
                scale_before = self.scaler.get_scale()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # A reduced scale means inf/NaN gradients were found and the optimizer
                # step was skipped; don't advance the LR schedule in that case.
                if self.scaler.get_scale() >= scale_before:
                    self.scheduler.step()
                self.optimizer.zero_grad()

            epoch_train_loss.update(loss.item() * self.accum_iter, n=image.size(0))
            pbar.set_postfix({'train_loss': epoch_train_loss.avg})

        return epoch_train_loss.avg

    @torch.no_grad()
    def _val_one_epoch(self, epoch):
        """Evaluate against hard labels.

        Returns:
            (val_loss, roc_auc, val_conf): mean CE loss, ROC AUC, and mean max class probability.
        """
        self.model.eval()
        epoch_val_loss = AvgMeter()
        epoch_val_conf = AvgMeter()
        self.metric_handler.reset()

        for image, _soft_labels, hard_labels in tqdm(self.loader_valid, desc=f'Val {self.fold+1} Ep {epoch+1}', leave=False):
            image = image.to(self.device)
            hard_labels = hard_labels.to(self.device)

            with autocast('cuda'):
                outputs = self.model(image)
                loss = self.loss_function(outputs, hard_labels)

            # Outputs may be fp16 under autocast; upcast so probabilities keep enough
            # precision (fp16 rounding creates ties that distort ROC AUC).
            probs = torch.softmax(outputs.float().cpu(), dim=1)
            epoch_val_conf.update(probs.max(dim=1)[0].mean().item(), n=image.size(0))
            epoch_val_loss.update(loss.item(), n=image.size(0))
            self.metric_handler.update(probs.numpy(), hard_labels.cpu().numpy())

        return epoch_val_loss.avg, self.metric_handler.compute_roc_auc(), epoch_val_conf.avg

    def fit(self, epochs):
        """Train for ``epochs`` epochs, logging to wandb and keeping the top-k checkpoints.

        Returns:
            The best validation ROC AUC among kept checkpoints (``nan`` if none were saved,
            e.g. ``epochs == 0``).
        """
        for epoch in range(epochs):
            set_seed(self.config.seed + epoch)

            train_loss = self._train_one_epoch(epoch)
            val_loss, roc_auc, val_conf = self._val_one_epoch(epoch)

            print(f'EPOCH [{epoch+1}/{epochs}] || TRAIN LOSS : {train_loss:.4f} || VAL LOSS : {val_loss:.4f} / ROC AUC : {roc_auc:.4f} | Avg Conf: {val_conf:.4f} ')

            current_lr = self.optimizer.param_groups[0]['lr']
            wandb.log({
                "train/loss": train_loss,
                "val/loss": val_loss,
                "val/auc": roc_auc,
                "learning_rate": current_lr,
                "epoch": epoch + 1
            })

            best_score_name = f'best_score_model_{self.fold+1}_ep{epoch+1}_roc_{roc_auc:.4f}.pth'
            if self.score_checkpoint.update(self.model, roc_auc, best_score_name):
                self.metric_handler.print_confusion_matrix()
                if wandb.run is not None:
                    wandb.run.summary["best_auc"] = max(self.score_checkpoint.best_scores)[0]

        if not self.score_checkpoint.best_scores:
            return float('nan')
        # Scalar, not a 1-tuple (the original had a stray trailing comma).
        return max(self.score_checkpoint.best_scores)[0]

## 4.4 Predictor

In [None]:
class Predictor:
    """Batched softmax inference, optionally wrapping the model with ttach TTA.

    Args:
        device: device that input batches are moved to.
        model_arch: model to predict with directly, or to load weights into via
            ``load_weights``.
        tta_transform: optional ``ttach`` transform composition; model outputs are
            mean-merged over its augmentations.
    """

    def __init__(self, device, model_arch=None,  tta_transform=None):
        self.model_arch = model_arch
        self.device = device
        self.tta_transform = tta_transform
        # Always define self.model: previously it was only set when TTA was enabled,
        # so predicting with a plain model_arch raised AttributeError.
        self.model = self._wrap(model_arch) if model_arch is not None else None

    def _wrap(self, model):
        """Return ``model`` wrapped in a mean-merging TTA wrapper when TTA is configured."""
        if self.tta_transform:
            return tta.ClassificationTTAWrapper(model, self.tta_transform, merge_mode='mean')
        return model

    def load_weights(self, weight_path):
        """Load a state dict into ``model_arch``, move it to ``device`` and switch to eval mode."""
        state_dict = torch.load(weight_path, map_location=self.device)
        self.model_arch.load_state_dict(state_dict)
        self.model = self._wrap(self.model_arch.to(self.device))
        self.model.eval()

    @torch.no_grad()
    def predict(self, loader):
        """Return an ``(N, num_classes)`` array of softmax probabilities in loader order.

        Batches may be bare image tensors or tuples/lists whose first element is the image.

        Raises:
            RuntimeError: if no model was given and ``load_weights`` was not called.
        """
        if self.model is None:
            raise RuntimeError('No model available: pass model_arch or call load_weights() first.')

        preds_list = []
        for batch in tqdm(loader, leave=False):
            image = batch[0] if isinstance(batch, (list, tuple)) else batch
            image = image.to(self.device)

            with autocast('cuda'):
                outputs = self.model(image)

            # Upcast before softmax: outputs may be fp16 under autocast.
            preds_list.append(torch.softmax(outputs.float().cpu(), dim=1).numpy())

        return np.concatenate(preds_list, axis=0)

In [None]:
# Collect unreachable Python objects first so their CUDA tensors are freed,
# then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

## 4.5 Experiment Orchestrator

In [None]:
class ExperimentRunner:
    """Orchestrate K-fold training and fold-ensembled inference for one experiment.

    Relies on notebook globals defined earlier: ``get_model``, ``ImageDataset``,
    ``transform_train`` / ``transform_test`` / ``transform_tta``, ``seed_worker``, ``g``,
    ``loader_test`` and ``hard_cols``.
    """

    def __init__(self, config, train_df, test_df):
        self.config = config
        self.train_df = train_df
        self.test_df = test_df
        self.paths = self._setup_env()
        self.backup_handler = BackupHandler(local_dir=self.paths.local_path, backup_dir=self.paths.drive_path, active=False)

    def _setup_env(self):
        """Choose save directories for Kaggle, Colab (Drive mounted) or a local run."""
        is_kaggle = os.path.exists('/kaggle/')
        # Google Drive mounts at /content/drive/MyDrive; the path is case-sensitive.
        is_colab = os.path.exists('/content/drive/MyDrive') and not is_kaggle
        exp_name = self.config.exp_name

        if is_kaggle:
            print("Environment: Kaggle")
            drive_path = None
            local_path = '/kaggle/working/'
        elif is_colab:
            print("Environment: Google Colab")
            drive_path = f'/content/drive/MyDrive/Kaggle_Save/{exp_name}/'
            local_path = '/content/models/'
        else:
            print("Environment: Local")
            drive_path = None
            local_path = f'../data/models/{exp_name}/'

        print(f"Save Path: {local_path}")
        return SimpleNamespace(local_path=local_path, drive_path=drive_path)

    def _make_loader(self, dataset, batch_size, shuffle, persistent):
        """Build a DataLoader with the notebook's shared seeding and worker settings.

        ``persistent`` keeps worker processes alive between iterations; it only pays off
        for loaders iterated repeatedly, and is dropped when ``num_workers == 0``
        (DataLoader rejects that combination).
        """
        num_workers = self.config.num_workers
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          worker_init_fn=seed_worker, generator=g, num_workers=num_workers,
                          persistent_workers=bool(persistent) and num_workers > 0, pin_memory=True)

    def _run_fold(self, fold, loader_train, loader_valid):
        """Train a fresh model on one fold; returns ``(trainer, best_score)``."""
        model = get_model()
        trainer = Trainer(model, loader_train, loader_valid, fold, self.config, self.paths.local_path)
        best_score = trainer.fit(self.config.epochs)
        return trainer, best_score

    def run_experiment(self):
        """Train every fold, each logged as its own wandb run in the experiment group."""
        persistent = getattr(self.config, 'persistent_workers', True)
        for fold in range(self.config.n_folds):
            print('='*30, f'FOLD {fold+1}', '='*30)

            # Start a fresh wandb run for this fold (close any run left open).
            if wandb.run is not None:
                wandb.finish()

            wandb.init(
                project=self.config.project_name,
                group=self.config.exp_name,
                name=f"Fold_{fold+1}",
                job_type="train",
                config={k: v for k, v in self.config.__dict__.items() if not k.startswith('__')}
            )

            # Split into this fold's train / validation sets and build the loaders.
            is_valid = self.train_df['fold'] == fold
            train = self.train_df[~is_valid].reset_index(drop=True)
            valid = self.train_df[is_valid].reset_index(drop=True)

            print(f'train size : {len(train)}')
            print(f'valid size : {len(valid)}')

            dataset_train = ImageDataset(train, transform=transform_train)
            dataset_valid = ImageDataset(valid, transform=transform_test)

            # Both loaders are iterated every epoch, so persistent workers are worthwhile.
            loader_train = self._make_loader(dataset_train, self.config.batch_size, shuffle=True, persistent=persistent)
            loader_valid = self._make_loader(dataset_valid, self.config.batch_size*4, shuffle=False, persistent=persistent)

            trainer, best_score = self._run_fold(fold=fold, loader_train=loader_train, loader_valid=loader_valid)

            # End this fold's logging.
            wandb.finish()

            # Release this fold's loaders/model: collect first, then return cached CUDA blocks.
            del loader_train, loader_valid, trainer, dataset_train, dataset_valid
            gc.collect()
            torch.cuda.empty_cache()

    def _average_state_dicts(self, paths):
        """Element-wise mean of the state dicts at ``paths`` (floor-mean for integer tensors)."""
        device = self.config.device
        avg_state_dict = torch.load(paths[0], map_location=device)
        if len(paths) == 1:
            return avg_state_dict
        for path in paths[1:]:
            state_dict = torch.load(path, map_location=device)
            for key in avg_state_dict:
                avg_state_dict[key] += state_dict[key]
        for key in avg_state_dict:
            if avg_state_dict[key].is_floating_point():
                avg_state_dict[key] = avg_state_dict[key] / len(paths)
            else:
                avg_state_dict[key] = avg_state_dict[key] // len(paths)
        return avg_state_dict

    def _recompute_bn_stats(self, model, fold):
        """Re-estimate BatchNorm running stats on the fold's training images (no-op without BN).

        Averaged weights make the running mean/var gathered during training stale, so they
        are recomputed from scratch with ``torch.optim.swa_utils.update_bn``.
        """
        from torch.optim.swa_utils import update_bn

        has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in model.modules())
        if not has_bn:
            print('No BatchNorm layers - skipping BN stats update.')
            return

        print('Update BN stats ... ')
        train_subset = self.train_df[self.train_df['fold'] != fold]
        dataset_bn = ImageDataset(train_subset, transform=transform_test)
        # One-shot pass: no persistent workers to leak after this call.
        loader_bn = self._make_loader(dataset_bn, self.config.batch_size, shuffle=False, persistent=False)
        update_bn(tqdm(loader_bn, desc=f'Update BN stats for Fold {fold+1}', leave=False),
                  model, device=self.config.device)

    def _load_averaged_model(self, fold, model_arch):
        """Return this fold's weight-averaged model in eval mode.

        On first use the ``best_score_model_{fold}_*.pth`` top-k checkpoints are averaged,
        saved as ``best_score_model_{fold}.pth``, and the per-epoch files deleted; later
        calls load the averaged file directly. BN statistics (if any) are then recomputed.

        Raises:
            FileNotFoundError: if neither the averaged file nor any top-k checkpoint exists.
        """
        device = self.config.device
        save_path = os.path.join(self.paths.local_path, f'best_score_model_{fold+1}.pth')
        model = model_arch.to(device)

        if os.path.exists(save_path):
            print(f"✅ Found existing averaged model for Fold {fold+1}. Loading directly...")
            model.load_state_dict(torch.load(save_path, map_location=device))
        else:
            print(f'Merging Top-K Models for Fold {fold+1} ...')
            score_pattern = os.path.join(self.paths.local_path, f'best_score_model_{fold+1}_*.pth')
            score_files = sorted(glob.glob(score_pattern))  # sorted: deterministic averaging order
            print(f'Found {len(score_files)} score models : {[os.path.basename(f) for f in score_files]}')
            if not score_files:
                raise FileNotFoundError(f'No checkpoints found for Fold {fold+1} (pattern: {score_pattern})')

            model.load_state_dict(self._average_state_dicts(score_files))

            # Replace the per-epoch checkpoints with the single averaged one.
            for remove_path in score_files:
                if os.path.exists(remove_path):
                    os.remove(remove_path)
            torch.save(model.state_dict(), save_path)
            print('Save Avg Model : ', save_path)

        self._recompute_bn_stats(model, fold)
        model.eval()
        return model

    def run_inference(self):
        """Predict OOF (per-fold validation) and fold-averaged test probabilities with TTA.

        Returns:
            (oof_preds, final_preds): arrays of shape (len(train_df), 4) and (len(test_df), 4),
            row-aligned with ``train_df`` and ``test_df`` respectively.
        """
        oof_preds = np.zeros((len(self.train_df), 4))
        final_preds = np.zeros((len(self.test_df), 4))

        for fold in range(self.config.n_folds):
            print(f"=== Inference Fold {fold+1} ===")
            avg_model = self._load_averaged_model(fold, get_model())
            predictor = Predictor(device=self.config.device, model_arch=avg_model, tta_transform=transform_tta)

            # Row positions (not index labels), so OOF rows land correctly even if
            # train_df does not carry a 0..N-1 RangeIndex.
            valid_mask = (self.train_df['fold'] == fold).to_numpy()
            valid_positions = np.flatnonzero(valid_mask)
            dataset_valid = ImageDataset(self.train_df[valid_mask], transform=transform_test)
            loader_valid = self._make_loader(dataset_valid, self.config.batch_size*4, shuffle=False, persistent=False)

            print("OOF Inference ... ")
            oof_preds[valid_positions] = predictor.predict(loader_valid)

            print("Test Inference ... ")
            final_preds += predictor.predict(loader_test) / self.config.n_folds

            del predictor, avg_model, loader_valid, dataset_valid
            gc.collect()
            torch.cuda.empty_cache()

        metric_handler = MetricHandler()
        metric_handler.update(oof_preds, self.train_df[hard_cols].values)
        oof_roc = metric_handler.compute_roc_auc()
        print(f'OOF ROC AUC : {oof_roc:.4f}')

        return oof_preds, final_preds

# 5. Training Execution

In [None]:
# Build the experiment runner; resolves save directories for the current environment.
runner = ExperimentRunner(config=CFG, train_df=train_df, test_df=test_df)

In [None]:
%%time
# Train every CV fold; checkpoints are written under runner.paths.local_path.
runner.run_experiment()

# 6. Inference & Save

In [None]:
# Checkpoint files currently in the model directory (sorted: glob order is arbitrary).
models_list = sorted(glob.glob(os.path.join(runner.paths.local_path, "*.pth")))
print([os.path.basename(m) for m in models_list])

In [None]:
%%time
# Per-fold weight averaging + TTA inference: OOF predictions and fold-averaged test predictions.
oof_preds, final_preds = runner.run_inference()

In [None]:
# OOF predictions are row-aligned with train_df.
result_oof = train_df[['image_id']].copy()
result_oof[hard_cols] = oof_preds
runner.backup_handler.save_csv(result_oof, f'oof_preds_{CFG.exp_name}.csv')

# Test predictions follow test_df row order, but the ids come from the sample
# submission; both files must list the images in the same order.
assert submission['image_id'].tolist() == test_df['image_id'].tolist(), \
    'sample_submission.csv and test.csv are not in the same image order'
result_sub = submission[['image_id']].copy()
result_sub[hard_cols] = final_preds
runner.backup_handler.save_csv(result_sub, f'submission_{CFG.exp_name}.csv')

In [None]:
# Quick look at the saved OOF and submission tables.
display(result_oof.head(), result_sub.head())