# 1. Setup & Imports

## 1.1 Setup

In [None]:
import os
import wandb
from dotenv import load_dotenv

load_dotenv()  # copy variables from a local .env file (if any) into os.environ

# None when the variable is not set.
wandb_api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_api_key)

## 1.2 Imports

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='pydantic')

import gc
import math
import shutil
import glob
import random
from types import SimpleNamespace
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import cv2
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.swa_utils import update_bn
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import TQDMProgressBar

from torchmetrics import ConfusionMatrix, AUROC
import timm
import ttach as tta
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import get_cosine_schedule_with_warmup

import seaborn as sns
import matplotlib.pyplot as plt

# Allow reduced-precision internal float32 matmuls for speed
# (the original note "L4" presumably refers to the target GPU).
torch.set_float32_matmul_precision('medium')
# 0 disables OpenCV's own threading so it does not compete with DataLoader workers.
cv2.setNumThreads(0)
print(os.cpu_count())  # logical CPU count, handy when choosing CFG.num_workers

## 1.3 Configuration

In [3]:
class CFG:
    # --- Model ---------------------------------------------------------------
    model_arch = 'convnext_small.fb_in22k_ft_in1k'  # timm model name
    is_bn = False  # True -> recompute BatchNorm statistics after checkpoint averaging

    # --- Reproducibility / cross-validation ----------------------------------
    seed = 855
    top_k = 3  # checkpoints kept per fold (and averaged for inference)
    n_folds = 5

    # --- Schedule ------------------------------------------------------------
    epochs = 25  # epochs actually trained
    virtual_epochs = 25  # length of the cosine LR schedule, in epochs
    warmup_multiplier = 2  # LR warm-up length, in epochs

    # --- Data loading --------------------------------------------------------
    batch_size = 32
    accum_iter = 1  # gradient-accumulation steps
    num_workers = 4
    persistent_workers = True

    # --- Optimiser / regularisation ------------------------------------------
    lr = 1e-4
    weight_decay = 0.05
    alpha = 0.5  # weight of the hard labels in the hard/soft label blend
    T = 2.0  # distillation temperature
    drop_path_rate = 0.2  # stochastic depth passed to the timm backbone

    # --- Runtime / logging ---------------------------------------------------
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    project_name = 'PlantPathology2020'
    exp_name = 'Student9_convnext_small_inf_edited'

In [None]:
def set_seed(seed, deterministic=False):
    """
    Seed every RNG used here and choose between reproducible and fast GPU kernels.

    Args:
        seed: Seed applied to Python's `random`, NumPy and torch (CPU and all
            CUDA devices).
        deterministic: If truthy, force deterministic cuDNN/torch algorithms
            (reproducible, slower). Otherwise enable cuDNN autotuning
            (faster, not bit-reproducible).

    Note:
        PYTHONHASHSEED is read at interpreter start-up, so setting it here only
        affects interpreters launched afterwards, not the current process.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Same order as before: Python, NumPy, torch CPU, current GPU, all GPUs.
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    strict = bool(deterministic)
    if strict:
        # Required by cuBLAS for deterministic results on CUDA >= 10.2.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.deterministic = strict
    torch.backends.cudnn.benchmark = not strict
    torch.use_deterministic_algorithms(strict)


# Global seeding with the fast (non-deterministic) kernel settings.
set_seed(CFG.seed)

# Module-level alias of the training device, also read by later cells.
print(device := CFG.device)

In [None]:
def seed_worker(worker_id):
    """
    DataLoader `worker_init_fn`: derive the `random` and NumPy seeds from torch.

    Inside a worker, torch's initial seed is already unique per worker, so
    re-seeding the other libraries from it makes their augmentations differ
    across workers while staying reproducible. `worker_id` is unused.
    """
    derived = torch.initial_seed() % (2 ** 32)  # NumPy only accepts seeds in [0, 2**32)
    random.seed(derived)
    np.random.seed(derived)

# Shared RNG passed to the DataLoaders (`generator=g`); it drives shuffling and
# the base seed of worker processes, making both reproducible.
g = torch.Generator().manual_seed(CFG.seed)

# 2. Data Pipeline

## 2.1 Data Loading

In [None]:
class DataModule:
    """
    Load the competition CSVs and attach blended teacher OOF predictions as soft labels.

    Args:
        data_dir: Root data directory including a trailing separator; paths are
            built by plain string concatenation.
    """
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.img_dir = self.data_dir + 'images/'

    def prepare_data(self):
        """
        Read the train/test/submission CSVs and merge in the averaged teacher predictions.

        The two teacher OOF files are combined by row position. Teacher 1's
        `image_id` column is overwritten with the training ids, and both teachers
        are averaged row by row. This assumes both files are row-aligned with
        `datasets/train_reborn_02.csv` (TODO confirm against how they were written).

        Returns:
            (train_df, test_df, submission). `train_df` gains the columns
            `healthy_pred`, `multiple_diseases_pred`, `rust_pred`, `scab_pred`
            holding the 50/50 teacher blend.

        Raises:
            ValueError: if a teacher OOF file does not have exactly one row per
                training row. Positional blending would otherwise silently
                misalign the labels or fill them with NaN.
        """
        hard_cols = ['healthy', 'multiple_diseases', 'rust', 'scab']

        self.train_df = pd.read_csv(self.data_dir + 'datasets/train_reborn_02.csv')
        self.train_df = self.train_df.reset_index(drop=True)
        self.test_df = pd.read_csv(self.data_dir + 'test.csv')
        self.submission = pd.read_csv(self.data_dir + 'sample_submission.csv')

        oof_name_01 = 'datasets/oof_preds_Student5_EfficientNetB6_reborn.csv'
        # NOTE(review): this teacher file is read from data_dir itself, not from
        # data_dir/datasets/ like the first one -- confirm this is intended.
        oof_name_02 = 'oof_preds_Student8_ResNest101e.csv'
        oof_df_01 = pd.read_csv(self.data_dir + oof_name_01)
        oof_df_02 = pd.read_csv(self.data_dir + oof_name_02)

        n_train = len(self.train_df)
        for name, oof in ((oof_name_01, oof_df_01), (oof_name_02, oof_df_02)):
            if len(oof) != n_train:
                raise ValueError(
                    f'{name} has {len(oof)} rows but the training set has {n_train}; '
                    'OOF rows must be aligned one-to-one with train_reborn_02.csv.'
                )

        # Re-key teacher 1 by position with the training ids.
        oof_df_01['image_id'] = self.train_df['image_id']
        oof_df_01 = oof_df_01[['image_id'] + hard_cols]

        self.oof_df = oof_df_01.copy()
        self.oof_df[hard_cols] = oof_df_01[hard_cols] * 0.5 + oof_df_02[hard_cols] * 0.5
        # Rename by mapping (not by position) so the label/column pairing is explicit.
        self.oof_df = self.oof_df.rename(columns={col: f'{col}_pred' for col in hard_cols})
        self.train_df = self.train_df.merge(self.oof_df, on='image_id', how='left')
        return self.train_df, self.test_df, self.submission

    
train_df, test_df, submission = DataModule('../data/').prepare_data()

# Ground-truth one-hot label columns and the matching teacher-probability columns.
hard_cols = ['healthy', 'multiple_diseases', 'rust', 'scab']
soft_cols = [f'{col}_pred' for col in hard_cols]

train_df.head()

## 2.2 Load Images

In [None]:
# Decode every train/test image once, resize it to 650x450 (cv2 dsize is
# width x height), convert BGR -> RGB and keep it in RAM keyed by image id.
all_images = {}
all_img_ids = np.concatenate([train_df['image_id'].tolist(), test_df['image_id'].tolist()])

print("Loading all images into RAM once...")
for img_id in tqdm(all_img_ids, desc='Loading Images...' ,leave=False):
    img_path = '../data/images/' + img_id + '.jpg'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None instead
        # of raising; fail here with the path rather than inside cv2.resize.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.resize(img, (650, 450))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Read-only so no consumer can modify the shared cache in place.
    img.setflags(write=False)
    all_images[img_id] = img

print("All Images on Ram")

## 2.3 Custom Dataset

In [8]:
class ImageDataset(Dataset):
    """
    Serve images from the module-level `all_images` RAM cache, with optional augmentation.

    Args:
        df: Frame with an `image_id` column. When `is_test` is False it must also
            contain `hard_cols` and `soft_cols`.
        hard_cols: Ground-truth (one-hot) label columns.
        soft_cols: Teacher probability columns used as soft labels.
        transform: Albumentations-style callable, called as `transform(image=...)['image']`.
        is_test: If True, items are images only.

    Items:
        `image` when `is_test`. Otherwise `(image, soft_labels, hard_labels)`,
        where both label tensors are float32 with one entry per column.
    """
    def __init__(self, df, hard_cols=hard_cols, soft_cols=soft_cols, transform=None, is_test=False):
        super().__init__()
        self.df = df
        self.transform = transform
        self.is_test = is_test
        self.hard_cols = hard_cols
        self.soft_cols = soft_cols

        # Extract ids and labels once as plain arrays. Indexing a DataFrame row
        # per sample (`df.iloc[idx]`) builds a mixed-dtype Series on every call,
        # which is slow in the DataLoader hot path.
        self.image_ids = df['image_id'].to_numpy()
        if not is_test:
            self.soft_labels = df[list(soft_cols)].to_numpy(dtype=np.float32)
            self.hard_labels = df[list(hard_cols)].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Copy: the cached arrays are shared (and stored read-only), while
        # transforms may work in place.
        image = all_images[self.image_ids[idx]].copy()

        if self.transform is not None:
            image = self.transform(image=image)['image']

        if self.is_test:
            return image
        # torch.tensor copies the row, so the returned tensors own their memory.
        return image, torch.tensor(self.soft_labels[idx]), torch.tensor(self.hard_labels[idx])

## 2.4 DataModule

In [9]:
class PlantDataModule(pl.LightningDataModule):
    """
    Manage data loading, preprocessing and the train/validation split for one fold.

    Splits `train_df` into training and validation subsets with the `fold`
    column. It builds only the datasets the requested Lightning `stage` needs,
    and it owns the augmentation pipelines, so the data/model interface is
    defined in one place.

    Args:
        train_df: Labelled frame with a `fold` column and the label columns read
            by `ImageDataset`. The validation subset keeps `train_df`'s index so
            OOF predictions can be scattered back by index.
        test_df: Unlabelled frame used for test/predict.
        cfg: Config; reads `batch_size`, `num_workers` and, if present,
            `persistent_workers` (default True).
        fold_idx: Fold held out for validation.
        inference_mode: When True, validation workers are not kept alive between epochs.
    """
    def __init__(self, train_df, test_df, cfg, fold_idx, inference_mode=False):
        super().__init__()
        self.train_df = train_df
        self.test_df = test_df
        self.cfg = cfg
        self.fold_idx = fold_idx
        self.inference_mode = inference_mode

        self.transform_train = A.Compose([
            # Exactly one of: brightness jitter or contrast jitter.
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1.0),
                A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1.0)
            ], p=1.0),

            # A mild blur on half of the samples.
            A.OneOf([
                A.MotionBlur(blur_limit=3, p=1.0),
                A.MedianBlur(blur_limit=3, p=1.0),
                A.GaussianBlur(blur_limit=3, p=1.0),
            ], p=0.5),

            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),

            A.Affine(
                scale=(0.8, 1.2),
                translate_percent=0.2,
                rotate=20,
                interpolation=cv2.INTER_CUBIC,  # bicubic interpolation
                border_mode=cv2.BORDER_REFLECT_101,  # fill exposed borders by reflection
                p=1.0
            ),

            # ImageNet mean/std normalisation.
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

        self.transform_test = A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

    def _setup_valid(self):
        """Select the held-out fold (original index kept) and wrap it in a dataset."""
        self.valid = self.train_df[self.train_df['fold'] == self.fold_idx].copy()
        self.dataset_valid = ImageDataset(self.valid, transform=self.transform_test)

    def setup(self, stage=None):
        """Build only the datasets needed for `stage` ('fit'/None, 'validate', 'test', 'predict')."""
        if stage == 'fit' or stage is None:
            self.train = self.train_df[self.train_df['fold'] != self.fold_idx].reset_index(drop=True).copy()
            self.dataset_train = ImageDataset(self.train, transform=self.transform_train)
            self._setup_valid()
            print(f'[Fit] Train: {len(self.train)}, Valid: {len(self.valid)}')

        elif stage == 'validate':
            self._setup_valid()
            print(f'[Validate] Valid: {len(self.valid)}')

        elif stage == 'test':
            self._setup_valid()
            self.dataset_test = ImageDataset(self.test_df, transform=self.transform_test, is_test=True)
            print(f'[Test] Valid(OOF): {len(self.valid)}, Test: {len(self.test_df)}')

        elif stage == 'predict':
            self.dataset_test = ImageDataset(self.test_df, transform=self.transform_test, is_test=True)

    def _persistent_workers(self, wanted):
        """Return `wanted` only when worker processes exist.

        DataLoader raises if persistent_workers=True is combined with num_workers=0.
        """
        return bool(wanted) and self.cfg.num_workers > 0

    def train_dataloader(self):
        use_persistent = self._persistent_workers(getattr(self.cfg, 'persistent_workers', True))
        loader_train = DataLoader(self.dataset_train, batch_size=self.cfg.batch_size, shuffle=True,
                                  worker_init_fn=seed_worker, generator=g, num_workers=self.cfg.num_workers,
                                  persistent_workers=use_persistent, pin_memory=True)
        return loader_train

    def val_dataloader(self):
        # In inference mode the loader is consumed once, so don't keep workers alive.
        use_persistent = self._persistent_workers(
            not self.inference_mode and getattr(self.cfg, 'persistent_workers', True)
        )
        loader_valid = DataLoader(self.dataset_valid, batch_size=self.cfg.batch_size*4, shuffle=False,
                                  worker_init_fn=seed_worker, generator=g, num_workers=self.cfg.num_workers,
                                  persistent_workers=use_persistent, pin_memory=True)
        return loader_valid

    def predict_dataloader(self):
        loader_test = DataLoader(self.dataset_test, batch_size=self.cfg.batch_size*4, shuffle=False,
                                 worker_init_fn=seed_worker, generator=g, num_workers=self.cfg.num_workers,
                                 persistent_workers=False, pin_memory=True)
        return loader_test

    def test_dataloader(self):
        return self.predict_dataloader()

# 3. Model Architecture

In [10]:
class PlantDiseaseModule(pl.LightningModule):
    """
    Wrap the timm backbone with its loss, optimiser/scheduler and validation metrics.

    Training blends one-hot hard labels with temperature-softened teacher (OOF)
    probabilities, in the style of knowledge distillation, to add robustness to
    label noise. A flip-TTA wrapper around the backbone is also built, but
    `predict_step` currently uses the plain model.

    Args:
        config: A config class, whose public class attributes become
            hyper-parameters, or anything `save_hyperparameters` accepts.
        steps_per_epoch: Optimiser steps per epoch. Training needs it because the
            LR schedule is defined in steps. It may be None for inference-only use.
    """
    def __init__(self, config, steps_per_epoch=None):
        super().__init__()
        if isinstance(config, type):
            config = {k: v for k, v in config.__dict__.items() if not k.startswith('__')}
        self.save_hyperparameters(config)
        self.steps_per_epoch = steps_per_epoch
        self.model = timm.create_model(
            self.hparams.model_arch,
            pretrained=True,
            drop_path_rate=self.hparams.drop_path_rate,
            num_classes=4
            )
        self.criterion = nn.CrossEntropyLoss()

        # TTA (horizontal + vertical flips, averaged) around the same backbone.
        # NOTE: the wrapper is registered as a sub-module, so checkpoints hold the
        # backbone weights twice (`model.*` and, presumably, `tta_model.model.*`).
        # It is kept so existing checkpoints still load with strict key matching.
        transforms = tta.Compose([
                tta.HorizontalFlip(),
                tta.VerticalFlip(),
            ])

        self.tta_model = tta.ClassificationTTAWrapper(self.model, transforms, merge_mode='mean')

        # Validation metrics.
        self.valid_auc = AUROC(task='multiclass', num_classes=4)
        self.valid_cm = ConfusionMatrix(task='multiclass', num_classes=4)
        self.best_score = 0.0  # NOTE(review): not updated anywhere in this module.

        self.top_k_scores = []  # (score, epoch) tuples, best first, at most `top_k` long
        self.top_k = self.hparams.top_k

    def configure_optimizers(self):
        """AdamW plus a per-step cosine schedule with linear warm-up.

        The schedule spans `virtual_epochs` epochs, of which the first
        `warmup_multiplier` are warm-up. Both are converted to optimiser steps
        with `steps_per_epoch`.

        Raises:
            ValueError: if the module was created without `steps_per_epoch`.
        """
        if self.steps_per_epoch is None:
            raise ValueError(
                'steps_per_epoch must be set to build the LR schedule; '
                'pass it to PlantDiseaseModule(...) when training.'
            )
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)

        total_steps = self.steps_per_epoch * self.hparams.virtual_epochs
        warmup_steps = self.steps_per_epoch * self.hparams.warmup_multiplier

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        scheduler_config = {
            'scheduler' : scheduler,
            'interval' : 'step',
            'frequency' : 1
        }

        return [optimizer], [scheduler_config]

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy of the student against a hard/soft label blend.

        The target is `alpha * hard + (1 - alpha) * soft`. When T > 1 the
        teacher probabilities are re-tempered as softmax(log(p) / T), the
        student logits are divided by T, and the loss is multiplied by T**2.
        NOTE(review): the hard-label part also sees temperature-scaled logits,
        unlike the classic formulation where it uses T = 1 -- confirm intended.
        """
        image, soft_labels, hard_labels = batch

        T = self.hparams.T
        if T > 1.0:
            epsilon = 1e-6  # avoids log(0)
            logits_from_oof = torch.log(soft_labels + epsilon)
            soft_labels = torch.softmax(logits_from_oof / T, dim=1)

        label = self.hparams.alpha * hard_labels + (1 - self.hparams.alpha) * soft_labels
        outputs = self.model(image)
        logits = outputs / T if T > 1 else outputs
        # nn.CrossEntropyLoss accepts class-probability targets shaped like the logits.
        loss = self.criterion(logits, label)

        if T > 1.0:
            loss = loss * (T ** 2)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Hard-label loss, plus AUROC and confusion-matrix updates against argmax targets."""
        image, _, hard_labels = batch
        outputs = self.model(image)
        loss = self.criterion(outputs, hard_labels)

        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)
        targets = torch.argmax(hard_labels, dim=1)

        self.valid_cm(preds, targets)
        self.valid_auc(probs, targets)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_roc_auc', self.valid_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        """Print the epoch summary, track the top-k AUCs and log a confusion matrix for new top-k epochs.

        The confusion-matrix state is reset on every exit path. This includes the
        sanity-check pass, so sanity batches don't leak into the first real epoch.
        """
        try:
            if self.trainer.sanity_checking:
                return

            score = self.trainer.callback_metrics.get('val_roc_auc')
            train_loss = self.trainer.callback_metrics.get('train_loss')
            val_loss = self.trainer.callback_metrics.get('val_loss')

            current_epoch = self.current_epoch
            t_loss_str = f"{train_loss:.4f}" if train_loss is not None else "N/A"
            v_loss_str = f"{val_loss:.4f}" if val_loss is not None else "N/A"
            roc_str = f"{score:.4f}" if score is not None else "N/A"
            self.print(f"\n(Epoch {current_epoch}) Train Loss: {t_loss_str} | Val Loss: {v_loss_str} | ROC AUC: {roc_str}")

            if score is None:
                return

            current_score = score.item()
            self.top_k_scores.append((current_score, current_epoch))
            self.top_k_scores.sort(key=lambda x: x[0], reverse=True)
            self.top_k_scores = self.top_k_scores[:self.top_k]
            is_in_top_k = (current_score, current_epoch) in self.top_k_scores

            if is_in_top_k and isinstance(self.logger, WandbLogger):
                rank = self.top_k_scores.index((current_score, current_epoch)) + 1
                self.print(f'New Top-K Score! (Rank {rank})')
                top_k_str = ", ".join([f"(Ep {e}: {s:.4f})" for s, e in self.top_k_scores])
                self.print(f"Current Top-{self.top_k}: {top_k_str}")
                self._log_confusion_matrix(current_epoch)
        finally:
            self.valid_cm.reset()

    def _log_confusion_matrix(self, current_epoch):
        """Render the accumulated validation confusion matrix as a heatmap and log it to W&B."""
        fig = plt.figure(figsize=(10, 8))
        try:
            cm = self.valid_cm.compute().cpu().numpy()
            columns = ['Healthy', 'Multiple', 'Rust', 'Scab']

            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=columns, yticklabels=columns,
                        annot_kws={"size": 12})  # larger annotation font

            plt.ylabel('True Label', fontsize=12)
            plt.xlabel('Predicted Label', fontsize=12)
            plt.title(f'Confusion Matrix (Epoch {current_epoch})', fontsize=14)

            log_key = f"Confusion_Matrix_Ep{current_epoch}"
            self.logger.experiment.log({
                log_key: wandb.Image(plt),
                "global_step": self.global_step
            })
        finally:
            # Always release the figure, even if logging fails.
            plt.close(fig)
        self.print(f"Confusion Matrix saved to WandB key: {log_key}")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return class probabilities (softmax) for a batch of images, or for `(images, ...)` tuples."""
        if isinstance(batch, (list, tuple)):
            x = batch[0]
        else:
            x = batch

        # Plain model; use `self.tta_model(x)` here to enable flip TTA.
        outputs = self.model(x)
        preds = torch.softmax(outputs, dim=1)
        return preds

In [None]:
# Free memory: collect unreachable Python objects first so the CUDA tensors they
# own are released, then return the now-unused cached blocks to the driver.
gc.collect()
torch.cuda.empty_cache()

# 4. Training Components

## 4.1 Metrics

In [12]:
class MetricHandler:
    """Accumulate predictions and targets across calls, then score them with ROC AUC."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything collected so far."""
        self.preds_list, self.actual_list = [], []

    def update(self, preds, actual):
        """Append the items of `preds` and `actual`, iterating over their first axis."""
        self.preds_list += list(preds)
        self.actual_list += list(actual)

    def compute_roc_auc(self):
        """Return sklearn's `roc_auc_score` over everything accumulated."""
        return roc_auc_score(y_true=self.actual_list, y_score=self.preds_list)
    
    
class BackupHandler:
    """
    Write artefacts to a local directory and optionally mirror them to a backup directory.

    Args:
        local_dir: Directory that CSVs are written to and that files are copied from.
        backup_dir: Mirror directory; None disables mirroring.
        active: Master switch. Mirroring happens only when this is truthy AND
            `backup_dir` is given.
    """
    def __init__(self, local_dir, backup_dir=None, active=True):
        self.local_dir = local_dir
        self.backup_dir = backup_dir
        self.active = active and (backup_dir is not None)

        if self._mirroring():
            os.makedirs(self.backup_dir, exist_ok=True)
            print(f'Backup Active : {self.local_dir} -> {self.backup_dir}')

    def _mirroring(self):
        """True when backups should be written."""
        return bool(self.active) and self.backup_dir is not None

    def backup(self, filename):
        """Copy `local_dir/filename` into the backup directory; no-op if disabled or the file is missing."""
        if not self._mirroring():
            return
        source = os.path.join(self.local_dir, filename)
        if not os.path.exists(source):
            return
        shutil.copy(source, os.path.join(self.backup_dir, filename))

    def save_csv(self, df, filename):
        """Write `df` (without index) locally and, when mirroring, to the backup directory too."""
        folders = [self.local_dir]
        if self._mirroring():
            folders.append(self.backup_dir)
        for folder in folders:
            target = os.path.join(folder, filename)
            df.to_csv(target, index=False)
            print(f'CSV saved at {target}')

## 4.2 Experiment Orchestrator

In [13]:
class ExperimentRunner:
    """
    Orchestrate K-fold training and the final OOF/test inference.

    It does the following:
    - Resolves save paths for the current environment (Kaggle, Colab or local).
    - Runs one W&B-logged Lightning fit per fold with top-k checkpointing.
    - Averages each fold's top-k checkpoints into one model for OOF and test
      prediction.

    After every fold it collects garbage and then empties the CUDA cache, so
    consecutive folds don't accumulate GPU memory.

    Args:
        config: Config class (e.g. `CFG`).
        train_df: Labelled frame with a `fold` column and a 0..n-1 RangeIndex.
            OOF predictions are scattered back by index.
        test_df: Unlabelled test frame.
    """
    def __init__(self, config, train_df, test_df):
        super().__init__()
        self.config = config
        self.train_df = train_df
        self.test_df = test_df
        self.paths = self._setup_env()
        self.backup_handler = BackupHandler(local_dir=self.paths.local_path , backup_dir=self.paths.drive_path, active=False)

    def _setup_env(self):
        """Detect the runtime environment and return `SimpleNamespace(local_path, drive_path)`.

        `drive_path` is a Google Drive backup folder on Colab and None elsewhere.
        """
        exp_name = self.config.exp_name
        is_kaggle = os.path.exists('/kaggle/')
        # Google Drive mounts at /content/drive/MyDrive (capital "D"); the check is case-sensitive.
        is_colab = os.path.exists('/content/drive/MyDrive') and not is_kaggle

        if is_kaggle:
            print("Environment: Kaggle")
            drive_path = None
            local_path = '/kaggle/working/'
        elif is_colab:
            print("Environment: Google Colab")
            drive_path = f'/content/drive/MyDrive/Kaggle_Save/{exp_name}/'
            local_path = '/content/models/'
        else:
            print("Environment: Local")
            drive_path = None
            local_path = f'../data/models/{exp_name}/'

        print(f"Save Path: {local_path}")
        return SimpleNamespace(local_path=local_path, drive_path=drive_path)

    def run(self):
        """Train every fold in turn, each logged to its own W&B run."""
        for fold in range(self.config.n_folds):
            print('='*30, f'FOLD {fold+1}', '='*30)

            wandb_logger = WandbLogger(
                project=self.config.project_name,
                group=self.config.exp_name,
                name=f"Fold_{fold+1}",
                job_type="train",
                config={k: v for k, v in self.config.__dict__.items() if not k.startswith('__')}
            )
            try:
                ckpt_callback = self._fit_fold(fold, wandb_logger)
                print(f'\n Top-{ckpt_callback.save_top_k} Models in this Fold:')
                for path in ckpt_callback.best_k_models:
                    print(f'> {os.path.basename(path)}')
                del ckpt_callback
            finally:
                # Finish the run even if training fails. Otherwise the next
                # fold's WandbLogger would attach to the unfinished run.
                wandb.finish()

            # Collect dropped objects first so the caching allocator can then
            # actually release the CUDA blocks they held.
            gc.collect()
            torch.cuda.empty_cache()

    def _fit_fold(self, fold, wandb_logger):
        """Train one fold and return its ModelCheckpoint callback, which records the top-k paths."""
        train_len = len(self.train_df[self.train_df['fold'] != fold])
        # Optimiser steps per epoch, used by the per-step LR schedule.
        steps_per_epoch = math.ceil(train_len / self.config.batch_size / self.config.accum_iter)

        datamodule = PlantDataModule(train_df=self.train_df, test_df=self.test_df, cfg=self.config, fold_idx=fold)
        model = PlantDiseaseModule(self.config, steps_per_epoch=steps_per_epoch)

        ckpt_callback = pl.callbacks.ModelCheckpoint(
            monitor='val_roc_auc',
            mode='max',
            save_top_k=self.config.top_k,
            save_last=False,
            dirpath=self.paths.local_path,
            filename=f'Fold{fold+1}-Ep{{epoch:02d}}-{{val_roc_auc:.4f}}',
            auto_insert_metric_name=False,
        )
        progress_bar = TQDMProgressBar(refresh_rate=1)

        trainer = pl.Trainer(
            max_epochs=self.config.epochs,
            accelerator='auto',
            precision='16-mixed',
            accumulate_grad_batches=self.config.accum_iter,
            callbacks=[ckpt_callback, progress_bar],
            logger=wandb_logger,
            log_every_n_steps=10
        )

        trainer.fit(model, datamodule=datamodule)
        return ckpt_callback

    @staticmethod
    def _average_checkpoints(score_files):
        """Return the element-wise mean of the Lightning `state_dict`s stored in `score_files`.

        Floating-point tensors are averaged in float32. Non-floating tensors,
        such as integer counters, are taken from the first checkpoint: they
        cannot be meaningfully averaged, and in-place float addition into them raises.

        Raises:
            KeyError: if a later checkpoint lacks a key present in the first one.
        """
        first_state = torch.load(score_files[0], map_location='cpu')['state_dict']
        avg_state_dict = {}
        for key, value in first_state.items():
            if value.is_floating_point():
                # clone() is required. Entries can share storage (the TTA
                # wrapper re-registers the backbone), and `.float()` is a no-op
                # on float32. Without a private copy, the in-place `+=` below
                # would hit the same storage once per alias.
                avg_state_dict[key] = value.float().clone()
            else:
                avg_state_dict[key] = value

        for path in score_files[1:]:
            state_dict = torch.load(path, map_location='cpu')['state_dict']
            for key, acc in avg_state_dict.items():
                if acc.is_floating_point():
                    acc += state_dict[key].float()

        n_files = len(score_files)
        return {
            key: (acc / n_files if acc.is_floating_point() else acc)
            for key, acc in avg_state_dict.items()
        }

    def _load_averaged_model(self, fold):
        """Return the fold's averaged model on `config.device`.

        A saved `best_score_model_<fold>.pth` is loaded if present; otherwise the
        fold's top-k checkpoints are averaged and saved there. When
        `config.is_bn` is set, BatchNorm statistics are then recomputed on the
        fold's training images and the model is saved again.

        Raises:
            FileNotFoundError: if neither the averaged model nor any checkpoint exists.
        """
        save_path = os.path.join(self.paths.local_path, f'best_score_model_{fold+1}.pth')
        model = PlantDiseaseModule(self.config)

        if os.path.exists(save_path):
            print(f'Found existing averaged model for Fold {fold+1}. Loading directly...')
            state_dict = torch.load(save_path, map_location=self.config.device)
            model.load_state_dict(state_dict)
        else:
            print(f'Merging Top-K Models for Fold {fold+1} ...')
            score_pattern = os.path.join(self.paths.local_path, f'Fold{fold+1}-Ep*.ckpt')
            # Sorted: glob order is filesystem-dependent, and a fixed order keeps
            # the floating-point sum reproducible.
            score_files = sorted(glob.glob(score_pattern))
            print(f'Found {len(score_files)} score models : {[os.path.basename(f) for f in score_files]}')
            if not score_files:
                raise FileNotFoundError(f'No checkpoints match {score_pattern}')

            model.load_state_dict(self._average_checkpoints(score_files))
            torch.save(model.state_dict(), save_path)
            print('Save Avg Model : ', save_path)

        if self.config.is_bn:
            print('Update BN stats ... ')
            model = model.to(self.config.device)
            model.train()

            train_subset = self.train_df[self.train_df['fold'] != fold].reset_index(drop=True)
            transform_test = A.Compose([
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])
            dataset_bn = ImageDataset(train_subset, transform=transform_test)
            loader_bn = DataLoader(
                dataset_bn,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=self.config.num_workers,
                pin_memory=True
            )

            # update_bn feeds the first element of each (image, soft, hard) batch.
            update_bn(loader_bn, model, device=self.config.device)
            model.eval()
            torch.save(model.state_dict(), save_path)
        else:
            print('Skipping BN update for LayerNorm')
            model = model.to(self.config.device)
        return model

    def run_inference(self):
        """Predict OOF and test probabilities with every fold's averaged model.

        Returns:
            (oof_preds, final_preds).
            - `oof_preds` has one row per `train_df` row. Each fold fills its
              validation rows by index, which needs a 0..n-1 RangeIndex.
            - `final_preds` is the mean of the folds' test predictions, in
              `test_df` row order.
            The OOF ROC AUC against `hard_cols` is printed.
        """
        oof_preds = np.zeros((len(self.train_df), 4))
        final_preds = np.zeros((len(self.test_df), 4))

        for fold in range(self.config.n_folds):
            print(f'=== Inference Fold {fold+1} ===')
            avg_model = self._load_averaged_model(fold)

            infer_module = PlantDataModule(self.train_df, self.test_df, self.config, fold_idx=fold, inference_mode=True)
            infer_module.setup(stage='test')

            progress_bar = TQDMProgressBar(refresh_rate=1)
            infer_trainer = pl.Trainer(
                accelerator='auto',
                logger=False,
                enable_checkpointing=False,
                callbacks=[progress_bar]
            )

            # OOF Inference
            print('OOF Inference ... ')
            oof_list = infer_trainer.predict(avg_model, dataloaders=infer_module.val_dataloader())
            oof_temp = torch.cat(oof_list).cpu().numpy()
            valid_indices = infer_module.valid.index.values
            oof_preds[valid_indices] = oof_temp

            # Test Inference
            print('Test Inference ... ')
            sub_list = infer_trainer.predict(avg_model, dataloaders=infer_module.predict_dataloader())
            sub_temp = torch.cat(sub_list).cpu().numpy()
            final_preds += (sub_temp / self.config.n_folds)

            del infer_trainer, infer_module, avg_model
            gc.collect()
            torch.cuda.empty_cache()

        metric_handler = MetricHandler()
        metric_handler.update(oof_preds, self.train_df[hard_cols].values)
        oof_roc = metric_handler.compute_roc_auc()
        print(f'OOF ROC AUC : {oof_roc:.4f}')

        return oof_preds, final_preds

In [None]:
# Free memory: collect unreachable Python objects first so the CUDA tensors they
# own are released, then return the now-unused cached blocks to the driver.
gc.collect()
torch.cuda.empty_cache()

# 5. Training Execution

In [None]:
# Build the orchestrator; construction resolves the save paths for this environment.
runner = ExperimentRunner(CFG, train_df, test_df)

In [None]:
%%time
# Train all folds; each fold's top-k checkpoints are written to runner.paths.local_path.
runner.run()

# 6. Inference & Save

In [None]:
%%time
# Average each fold's top-k checkpoints, then predict OOF (held-out fold) and test probabilities.
oof_preds, final_preds = runner.run_inference()

In [None]:
# OOF predictions are row-aligned with train_df (filled by train_df index).
result_oof = train_df[['image_id']].copy()
result_oof[hard_cols] = oof_preds
runner.backup_handler.save_csv(result_oof, f'oof_preds_{CFG.exp_name}.csv')

# Test predictions follow test_df's row order (the predict loader does not
# shuffle). Take the ids from test_df instead of assuming sample_submission.csv
# lists them in the same order.
result_sub = test_df[['image_id']].copy()
result_sub[hard_cols] = final_preds
runner.backup_handler.save_csv(result_sub, f'submission_{CFG.exp_name}.csv')

In [None]:
# Notebook preview of the saved tables (`display` is provided by IPython/Jupyter).
display(result_oof.head())
display(result_sub.head())