# 1. Setup & Imports

## 1.1 Setup

In [1]:
import os
import wandb
from dotenv import load_dotenv

# Pull variables from a local .env file into the environment, then authenticate W&B
# with the WANDB_API_KEY found there (None if it is not set).
load_dotenv()
wandb_api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /teamspace/studios/this_studio/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mparaise[0m ([33mparaise-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## 1.2 Imports

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='pydantic')

import gc
import math
import shutil
import glob
import random
from types import SimpleNamespace
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import cv2
from sklearn.metrics import roc_auc_score

import torch
from torch.optim.swa_utils import update_bn
from torch.utils.data import Dataset, DataLoader

import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import TQDMProgressBar, Callback

from torchmetrics import ConfusionMatrix, AUROC
import timm
import ttach as tta
import albumentations as A
from albumentations.pytorch import ToTensorV2

import seaborn as sns
import matplotlib.pyplot as plt

cv2.setNumThreads(0)  # disable OpenCV's own thread pool; presumably to avoid oversubscription with DataLoader workers
print(os.cpu_count())  # number of logical CPUs, useful when choosing CFG.num_workers
torch.set_float32_matmul_precision('high') # L4 -- NOTE(review): presumably the target GPU; 'high' presumably enables TF32 matmuls there

4


In [3]:
# Record the installed timm version alongside the results for reproducibility.
timm.__version__

'1.0.22'

## 1.3 Configuration

In [22]:
class CFG:
    # Experiment configuration. Read via attribute access, and its non-dunder attributes
    # are also flattened into a dict for Lightning hparams and the W&B run config.
    model_arch = 'resnest101e'  # timm model name passed to timm.create_model
    is_bn = True  # recompute BatchNorm running stats on the averaged weights before inference
    seed = 938
    lr = 0.0001  # AdamW learning rate
    weight_decay = 0.05  # AdamW weight decay
    alpha = 0.7  # initial hard-label weight in the KD loss; FoldAlphaCallback replaces it at train start
    weak_alpha = 0.3  # alpha used for the folds FoldAlphaCallback treats as weak
    strong_alpha = 0.7  # alpha used for all other folds
    # NOTE(review): weak_alpha < strong_alpha gives weak folds LESS hard-label weight, the opposite of
    # FoldAlphaCallback's description and its own defaults (0.8 / 0.5) -- confirm this is intended.
    T = 1.25  # distillation softmax temperature
    drop_path_rate = 0.2  # stochastic-depth rate passed to timm.create_model
    top_k = 3  # checkpoints kept per fold (by val_roc_auc) and later weight-averaged
    n_folds = 5
    epochs = 24
    batch_size = 32  # training batch size; validation/prediction loaders use 4x this
    accum_iter = 1  # gradient-accumulation steps
    num_workers = 4  # DataLoader worker processes
    persistent_workers=True  # NOTE(review): confirm the DataLoaders actually read this flag
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    project_name = 'PlantPathology2020'  # W&B project
    exp_name = 's14_resnest'  # W&B group name; also names the output directory and result files

In [None]:
def set_seed(seed, deterministic=False):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``.

    Args:
        seed: integer seed applied to every RNG (also exported as PYTHONHASHSEED).
        deterministic: when True, force deterministic cuDNN/cuBLAS kernels (slower);
            when False, enable cuDNN autotuning and allow non-deterministic algorithms.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    strict = bool(deterministic)
    if strict:
        # Required by cuBLAS for reproducible results once deterministic algorithms are enforced.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.deterministic = strict
    torch.backends.cudnn.benchmark = not strict
    torch.use_deterministic_algorithms(strict)

# Seed every RNG once at notebook start-up, then expose the configured device globally.
set_seed(seed=CFG.seed)
device = CFG.device
print(device)

In [None]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` in each worker.

    The seed is derived from ``torch.initial_seed()`` (which PyTorch sets per worker),
    folded into the 32-bit range NumPy accepts. ``worker_id`` is not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Shared RNG passed as `generator=` to the DataLoaders so shuffling order is reproducible.
g = torch.Generator()
g.manual_seed(CFG.seed)

In [None]:
import ipynbname
# File name of this notebook; W&B records it as the program of each training run.
current_file = f"{ipynbname.name()}.ipynb"
print(current_file)

# 2. Data Pipeline

## 2.1 Data Loading

### Data & Knowledge Source
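- Teacher logits (distillation targets): out-of-fold logits from two teacher runs, averaged with equal weights.

    - Teacher 1: ConvNeXt-Small (`s10_convnext_small_T_scheduler`)

    - Teacher 2: ConvNeXt-Small variant (`s12_conv_teacher_s10_v2`)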

In [None]:
class DataModule:
    """Load the competition tables and the blended out-of-fold teacher logits.

    Args:
        data_dir: dataset root; paths are built by string concatenation, so it must end
            with a path separator (e.g. ``'../data/'``).
        logit_path1: ``.npy`` file with the first teacher's logits.
        logit_path2: ``.npy`` file with the second teacher's logits.
        teacher_weights: blend weights ``(w1, w2)`` for the two teachers; the default is
            an equal 50/50 average.
    """
    def __init__(self, data_dir, logit_path1, logit_path2, teacher_weights=(0.5, 0.5)):
        self.data_dir = data_dir
        self.img_dir = self.data_dir + 'images/'
        self.logit_path1 = logit_path1
        self.logit_path2 = logit_path2
        self.teacher_weights = teacher_weights

    def prepare_data(self):
        """Read the CSVs and blend the two teachers' logits.

        Returns:
            ``(train_df, teacher_logit, test_df, submission)``, where ``teacher_logit`` is
            ``w1 * logits1 + w2 * logits2`` and is row-aligned with ``train_df``.

        Raises:
            ValueError: if the two logit arrays differ in shape or their row count does not
                match ``train_df`` (later fold masks index them by row position).
        """
        train_df = pd.read_csv(self.data_dir + 'datasets/train_reborn_02.csv')
        train_df = train_df.reset_index(drop=True)
        test_df = pd.read_csv(self.data_dir + 'test.csv')
        submission = pd.read_csv(self.data_dir + 'sample_submission.csv')

        teacher_logit1 = np.load(self.logit_path1)
        teacher_logit2 = np.load(self.logit_path2)
        if teacher_logit1.shape != teacher_logit2.shape:
            raise ValueError(f'Teacher logit shapes differ: {teacher_logit1.shape} vs {teacher_logit2.shape}')
        if len(teacher_logit1) != len(train_df):
            raise ValueError(f'Teacher logits have {len(teacher_logit1)} rows but train_df has {len(train_df)}')

        w1, w2 = self.teacher_weights
        teacher_logit = teacher_logit1 * w1 + teacher_logit2 * w2
        print(f"Logit1 Range: {teacher_logit1.min():.2f} ~ {teacher_logit1.max():.2f}")
        print(f"Logit2 Range: {teacher_logit2.min():.2f} ~ {teacher_logit2.max():.2f}")
        return train_df, teacher_logit, test_df, submission

    
# Load the CSVs and the 50/50-blended out-of-fold logits of the two teacher runs.
# NOTE(review): logit_path1's file name reads 'oof_ogit_...' while this notebook saves
# 'oof_logit_{exp_name}.npy' -- confirm the file on disk really carries that spelling.
train_df, teacher_logit, test_df, submission = DataModule(data_dir='../data/',
 logit_path1='../data/models/s10_convnext_small_T_scheduler/oof_ogit_s10_convnext_small_T_scheduler.npy',
 logit_path2='../data/models/s12_conv_teacher_s10_v2/oof_logit_s12_conv_teacher_s10_v2.npy').prepare_data()

# One-hot label columns; their order defines the class index used by the labels and predictions.
hard_cols = ['healthy', 'multiple_diseases', 'rust', 'scab']

train_df.head()

## 2.2 Load Images

In [None]:
# Decode every train/test image once, resize it to 650x450 (cv2 takes (width, height)),
# convert BGR -> RGB and cache it read-only in RAM so Dataset workers never touch the disk.
all_images = {}
all_img_ids = np.concatenate([train_df['image_id'].tolist(), test_df['image_id'].tolist()])

print("Loading all images into RAM once...")
for img_id in tqdm(all_img_ids, desc='Loading Images...' ,leave=False):
    img_path = '../data/images/' + img_id + '.jpg'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None rather than raising;
        # fail here with the offending path instead of inside cv2.resize.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.resize(img, (650, 450))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img.setflags(write=False)  # shared cache: consumers must copy before modifying
    all_images[img_id] = img

print("All Images on Ram")

## 2.3 Custom Dataset

In [9]:
class ImageDataset(Dataset):
    """Dataset over images pre-loaded into the global ``all_images`` dict.

    In train/valid mode (``is_test=False``) each item is
    ``(image, teacher_logit_row, hard_label_row)``; in test mode only the image is returned.

    Args:
        df: frame with an ``'image_id'`` column, plus the ``hard_cols`` one-hot columns
            unless ``is_test``.
        hard_cols: names of the one-hot label columns.
        teacher_logit: teacher logits row-aligned with ``df``; required unless ``is_test``.
        transform: albumentations-style callable used as ``transform(image=...)['image']``.
        is_test: when True, labels and teacher logits are neither read nor returned.

    Raises:
        ValueError: if ``is_test`` is False and ``teacher_logit`` is missing or its row
            count differs from ``len(df)``.
    """
    def __init__(self, df, hard_cols=hard_cols, teacher_logit=None, transform=None, is_test=False):
        super().__init__()
        self.df = df
        self.transform = transform
        self.is_test = is_test
        self.hard_cols = hard_cols
        self.teacher_logit = teacher_logit
        # Resolve per-row lookups once instead of building a pandas Series in every __getitem__.
        self.image_ids = df['image_id'].to_numpy()
        if not is_test:
            if teacher_logit is None:
                raise ValueError('teacher_logit is required when is_test=False')
            if len(teacher_logit) != len(df):
                raise ValueError(f'teacher_logit has {len(teacher_logit)} rows but df has {len(df)}')
            # float32 keeps the distillation targets in the same dtype as the hard labels.
            self.teacher_logit = np.asarray(teacher_logit, dtype=np.float32)
            self.hard_labels = df[hard_cols].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Cached arrays are read-only, so augment a private copy.
        image = all_images[self.image_ids[idx]].copy()

        if self.transform is not None:
            image = self.transform(image=image)['image']

        if self.is_test:
            return image
        return image, torch.tensor(self.teacher_logit[idx]), torch.tensor(self.hard_labels[idx])

## 2.4 DataModule

In [10]:
class PlantDataModule(pl.LightningDataModule):
    """Handles data loading, preprocessing and the train/validation split for one fold.

    Rows whose ``fold`` equals ``fold_idx`` form the validation split and the rest form
    the training split. ``setup`` only builds the datasets the requested stage needs.
    The augmentation pipelines live here, so the interface between data and model is
    explicit.

    Args:
        train_df: training frame with a ``'fold'`` column.
        teacher_logit: teacher logits row-aligned with ``train_df``.
        test_df: test frame.
        cfg: config object providing ``batch_size``, ``num_workers`` and optionally
            ``persistent_workers`` (defaults to True for the training loader).
        fold_idx: index of the held-out fold.
        inference_mode: when True, the validation loader does not keep workers alive.
    """
    def __init__(self, train_df, teacher_logit, test_df, cfg, fold_idx, inference_mode=False):
        super().__init__()
        self.train_df = train_df
        self.teacher_logit = teacher_logit
        self.test_df = test_df
        self.cfg = cfg
        self.fold_idx = fold_idx
        self.inference_mode = inference_mode

        self.transform_train = A.Compose([
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1.0),
                A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1.0)
            ], p=1.0),

            A.OneOf([
                A.MotionBlur(blur_limit=5, p=1.0),
                A.MedianBlur(blur_limit=5, p=1.0),
                A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            ], p=0.5),

            A.OneOf([
                A.Affine(
                    scale=(0.8, 1.2),
                    translate_percent=0.2,
                    rotate=45,              # wider rotation range
                    shear=20,               # add shear (tilt)
                    interpolation=cv2.INTER_CUBIC,
                    border_mode=cv2.BORDER_REFLECT_101,
                    p=1.0
                ),
                A.Perspective(scale=(0.05, 0.1), p=1.0), # perspective transform
            ], p=0.8),

            # texture / noise
            A.OneOf([
                A.ISONoise(p=1.0),
                A.GaussNoise(p=1.0),
            ], p=0.3),

            A.CoarseDropout(
                num_holes_range=(8, 16),
                hole_height_range=(8, 16),
                hole_width_range=(8, 16),
                fill=[103, 131, 82],
                p=0.5
            ),

            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),

            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

        self.transform_test = A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

    def _setup_valid(self):
        """Build the held-out split of ``fold_idx`` (un-augmented) with its teacher logits."""
        val_mask = (self.train_df['fold'] == self.fold_idx).values
        self.valid = self.train_df[val_mask].reset_index(drop=True)
        self.dataset_valid = ImageDataset(self.valid, teacher_logit=self.teacher_logit[val_mask], transform=self.transform_test)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            train_mask = (self.train_df['fold'] != self.fold_idx).values
            self.train = self.train_df[train_mask].reset_index(drop=True)
            self.dataset_train = ImageDataset(self.train, teacher_logit=self.teacher_logit[train_mask], transform=self.transform_train)
            self._setup_valid()
            print(f'[Fit] Train: {len(self.train)}, Valid: {len(self.valid)}')

        elif stage == 'validate':
            # Used by trainer.validate(); only the held-out split is needed.
            self._setup_valid()

        elif stage == 'test':
            self._setup_valid()
            self.dataset_test = ImageDataset(self.test_df, transform=self.transform_test, is_test=True)
            print(f'[Test] Valid(OOF): {len(self.valid)}, Test: {len(self.test_df)}')

        elif stage == 'predict':
            self.dataset_test = ImageDataset(self.test_df, transform=self.transform_test, is_test=True)

    def _build_loader(self, dataset, batch_size, shuffle, persistent):
        """Create a seeded DataLoader; worker persistence is only requested when workers exist."""
        num_workers = self.cfg.num_workers
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          worker_init_fn=seed_worker, generator=g, num_workers=num_workers,
                          # DataLoader rejects persistent_workers=True when num_workers == 0.
                          persistent_workers=bool(persistent) and num_workers > 0,
                          pin_memory=True)

    def train_dataloader(self):
        persistent = getattr(self.cfg, 'persistent_workers', True)
        return self._build_loader(self.dataset_train, self.cfg.batch_size, shuffle=True, persistent=persistent)

    def val_dataloader(self):
        # During one-shot inference there is no next epoch, so do not keep workers alive.
        return self._build_loader(self.dataset_valid, self.cfg.batch_size * 4, shuffle=False,
                                  persistent=not self.inference_mode)

    def predict_dataloader(self):
        return self._build_loader(self.dataset_test, self.cfg.batch_size * 4, shuffle=False, persistent=False)

    def test_dataloader(self):
        return self.predict_dataloader()

# 3. Model Architecture

In [12]:
class FoldAlphaCallback(Callback):
    """Set a fold-specific knowledge-distillation alpha at the start of training.

    ``alpha`` weights the hard-label loss and ``1 - alpha`` weights the teacher loss.
    The intended use is:
    - lower-performing ("weak") folds: a higher alpha, relying less on the teacher;
    - higher-performing ("strong") folds: a larger teacher share (1 - alpha) for
      stronger regularization.
    The actual values are whatever ``weak_alpha`` / ``strong_alpha`` are passed in.

    Args:
        current_fold: 0-based index of the fold being trained.
        weak_folds: 0-based fold indices treated as weak.
        weak_alpha: alpha assigned to weak folds.
        strong_alpha: alpha assigned to every other fold.
    """
    def __init__(self, current_fold, weak_folds=(0, 1), weak_alpha=0.8, strong_alpha=0.5):
        super().__init__()
        self.current_fold = current_fold
        # Immutable copy: avoids a shared mutable default and caller-side aliasing.
        self.weak_folds = tuple(weak_folds)
        self.weak_alpha = weak_alpha
        self.strong_alpha = strong_alpha

    def on_train_start(self, trainer, pl_module):
        # Overwrite the module's alpha once, before the first training step.
        if self.current_fold in self.weak_folds:
            pl_module.current_alpha = self.weak_alpha
            strategy = "GT Focus (Weak Fold)"
        else:
            pl_module.current_alpha = self.strong_alpha
            strategy = "Regularization (Strong Fold)"

        # Print once (rank 0 only).
        if trainer.global_rank == 0:
            print(f"\n[FoldAlphaCallback] Fold {self.current_fold+1}: "
                  f"Alpha set to {pl_module.current_alpha} ({strategy})")

In [14]:
class PlantDiseaseModule(pl.LightningModule):
    """Wraps the forward pass, loss, optimizer and metric tracking of a timm classifier.

    Training mixes the hard-label cross-entropy with a knowledge-distillation term:
    ``loss = alpha * CE(student, hard_labels) + (1 - alpha) * T**2 * KL(teacher_T || student_T)``,
    where ``teacher_T = softmax(teacher_logit / T)``. This soft-label mixing is intended to
    make training more robust to label noise. Validation tracks multiclass AUROC and a
    confusion matrix. Prediction applies horizontal/vertical-flip TTA and averages the logits.

    Args:
        config: config class (its non-dunder attributes are converted to a dict) or a
            mapping accepted by ``save_hyperparameters``.
        steps_per_epoch: accepted for interface compatibility; not used by this module.
    """
    def __init__(self, config, steps_per_epoch=None):
        super().__init__()
        if isinstance(config, type):
            config = {k: v for k, v in config.__dict__.items() if not k.startswith('__')}
        self.save_hyperparameters(config)
        # Initial KD weight; FoldAlphaCallback may overwrite it at train start.
        self.current_alpha = self.hparams.alpha
        self.model = timm.create_model(
            self.hparams.model_arch,
            pretrained=True,
            drop_path_rate=self.hparams.drop_path_rate,
            num_classes=4
            )
        self.criterion = torch.nn.CrossEntropyLoss()

        # TTA
        self.tta_transforms = tta.Compose([
                tta.HorizontalFlip(),
                tta.VerticalFlip(),
            ])

        # metrics
        self.valid_auc = AUROC(task='multiclass', num_classes=4)
        self.valid_cm = ConfusionMatrix(task='multiclass', num_classes=4)
        self.best_score = 0.0

        self.top_k_scores = []  # (score, epoch) tuples, best first
        self.top_k = self.hparams.top_k

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)

        # Warm restarts after 8, then 16 more epochs (T_0=8, T_mult=2), stepped once per epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=8, T_mult=2, eta_min=1e-6)

        scheduler_config = {
            'scheduler' : scheduler,
            'interval' : 'epoch',
            'frequency' : 1
        }

        return [optimizer], [scheduler_config]

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        T = self.hparams.T
        alpha = self.current_alpha

        image, logit_from_oof, hard_labels = batch
        outputs = self.model(image)

        # Hard-label loss (one-hot float targets -> class-probability cross-entropy).
        loss_hard = self.criterion(outputs, hard_labels)

        # Teacher distribution; cast guards against float64 logits coming from .npy files.
        teacher_probs = torch.softmax(logit_from_oof.float() / T, dim=1)
        # Student log-distribution at the same temperature.
        student_log_probs = torch.nn.functional.log_softmax(outputs / T, dim=1)

        kl_loss = torch.nn.functional.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        )
        # T**2 keeps the soft-loss gradient scale comparable across temperatures.
        loss_soft = kl_loss * (T**2)
        loss = alpha * loss_hard + (1 - alpha) * loss_soft

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('hard_loss', loss_hard, on_step=False, on_epoch=True, logger=True)
        self.log('soft_loss', loss_soft, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        image, _, hard_labels = batch
        outputs = self.model(image)
        loss = self.criterion(outputs, hard_labels)

        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)
        targets = torch.argmax(hard_labels, dim=1)

        self.valid_cm(preds, targets)
        self.valid_auc(probs, targets)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_roc_auc', self.valid_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking:
            self._track_top_k()
        # Reset unconditionally: skipping this during the sanity check would leak the
        # sanity-check batches into the first real epoch's confusion matrix.
        self.valid_cm.reset()

    def _track_top_k(self):
        """Update the running top-k (score, epoch) list and log a confusion matrix for new entries."""
        score = self.trainer.callback_metrics.get('val_roc_auc')
        if score is None:
            return

        entry = (score.item(), self.current_epoch)
        self.top_k_scores.append(entry)
        self.top_k_scores.sort(key=lambda x: x[0], reverse=True)
        self.top_k_scores = self.top_k_scores[:self.top_k]

        if entry in self.top_k_scores and isinstance(self.logger, WandbLogger):
            top_k_str = ", ".join([f"(Ep {e}: {s:.4f})" for s, e in self.top_k_scores])
            self.print(f"Current Top-{self.top_k}: {top_k_str}")
            self._log_confusion_matrix(self.current_epoch)

    def _log_confusion_matrix(self, epoch):
        """Render the accumulated confusion matrix as a heatmap and log it to W&B."""
        cm = self.valid_cm.compute().cpu().numpy()
        columns = ['Healthy', 'Multiple', 'Rust', 'Scab']
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=columns, yticklabels=columns,
                        annot_kws={"size": 14}, ax=ax)
            ax.set_ylabel('True Label', fontsize=12)
            ax.set_xlabel('Predicted Label', fontsize=12)
            ax.set_title(f'Confusion Matrix (Epoch {epoch})', fontsize=14)

            self.logger.experiment.log({
                "val/confusion_matrix": wandb.Image(fig, caption=f"Epoch {epoch}"),
                "global_step": self.global_step
            })
        finally:
            # Always release the figure, even if logging fails.
            plt.close(fig)

    def on_train_epoch_end(self):
        score = self.trainer.callback_metrics.get('val_roc_auc')
        train_loss = self.trainer.callback_metrics.get('train_loss_epoch')
        val_loss = self.trainer.callback_metrics.get('val_loss')

        current_epoch = self.current_epoch
        t_loss_str = f"{train_loss:.4f}" if train_loss is not None else "N/A"
        v_loss_str = f"{val_loss:.4f}" if val_loss is not None else "N/A"
        roc_str = f"{score:.4f}" if score is not None else "N/A"
        self.print(f"\n(Epoch {current_epoch}) Train Loss: {t_loss_str} | Val Loss: {v_loss_str} | ROC AUC: {roc_str}")

    def on_predict_start(self):
        # Wrap the backbone once per predict run; outputs are the mean of the TTA logits.
        self.tta_model = tta.ClassificationTTAWrapper(
            self.model,
            self.tta_transforms,
            merge_mode='mean'
        )

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Validation batches are (image, logit, labels) tuples; test batches are bare images.
        if isinstance(batch, (list, tuple)):
            x = batch[0]
        else:
            x = batch

        outputs = self.tta_model(x)
        return outputs

In [None]:
# Free memory: collect unreachable Python objects first so the CUDA tensors they held are
# released, then let the caching allocator hand those now-unused blocks back.
gc.collect()
torch.cuda.empty_cache()

# 4. Training Components

## 4.1 Metrics & Utils

In [16]:
class MetricHandler:
    def __init__(self):
        self.reset()

    def reset(self):
        self.preds_list = []
        self.actual_list = []

    def update(self, preds, actual):
        self.preds_list.extend(preds)
        self.actual_list.extend(actual)

    def compute_roc_auc(self):
        return roc_auc_score(self.actual_list, self.preds_list)
    
    
class BackupHandler:
    """Write result files to a local directory and optionally mirror them to a backup directory.

    Args:
        local_dir: directory that files are written to first (created on demand).
        backup_dir: optional mirror directory; backups are disabled when it is None.
        active: master switch; backups run only when this is truthy AND ``backup_dir`` is set.
    """
    def __init__(self, local_dir, backup_dir=None, active=True):
        self.local_dir = local_dir
        self.backup_dir = backup_dir
        self.active = active and (backup_dir is not None)

        if self.active:
            os.makedirs(self.backup_dir, exist_ok=True)
            print(f'Backup Active : {self.local_dir} -> {self.backup_dir}')

    def backup(self, filename):
        """Copy ``local_dir/filename`` into ``backup_dir`` if backups are on and the file exists."""
        if not self.active or self.backup_dir is None:
            return

        src_path = os.path.join(self.local_dir, filename)
        dst_path = os.path.join(self.backup_dir, filename)

        if os.path.exists(src_path):
            shutil.copy(src_path, dst_path)

    def save_file(self, data, filename, logit=False):
        """Save ``data`` under ``local_dir`` and back it up.

        Args:
            data: a NumPy-compatible array when ``logit`` is True, otherwise a DataFrame.
            filename: target file name inside ``local_dir``.
            logit: save with ``np.save`` instead of ``DataFrame.to_csv``.
        """
        os.makedirs(self.local_dir, exist_ok=True)

        if logit:
            # np.save appends '.npy' when missing; mirror that here so backup() finds the file.
            if not filename.endswith('.npy'):
                filename += '.npy'
            local_path = os.path.join(self.local_dir, filename)
            np.save(local_path, data)
            print(f'Logit saved at {local_path}')
        else:
            local_path = os.path.join(self.local_dir, filename)
            data.to_csv(local_path, index=False)
            print(f'CSV saved at {local_path}')

        self.backup(filename)

## 4.2 Experiment Orchestrator

In [17]:
class ExperimentRunner:
    """Orchestrates K-fold cross-validation and the full experiment workflow.

    It resolves output paths for the detected environment (Kaggle, Colab, local) and
    trains one W&B-logged Lightning job per fold with top-k checkpointing. Final inference
    then covers the OOF and test sets: it averages each fold's checkpoint weights, can
    refresh BatchNorm statistics, and calibrates the temperature per fold. Memory is
    released explicitly (GC, then CUDA cache) after every fold so consecutive folds and
    experiments can run in one kernel.

    Args:
        config: configuration class such as ``CFG``; read via attribute access.
        train_df: training frame with a ``'fold'`` column and the one-hot label columns.
        teacher_logit: teacher logits row-aligned with ``train_df``.
        test_df: test frame with an ``'image_id'`` column.
    """
    def __init__(self, config, train_df, teacher_logit, test_df):
        self.config = config
        self.train_df = train_df
        self.teacher_logit = teacher_logit
        self.test_df = test_df
        self.paths = self._setup_env()
        self.backup_handler = BackupHandler(local_dir=self.paths.local_path, backup_dir=self.paths.drive_path, active=False)

    @staticmethod
    def _free_memory():
        """Collect garbage first so freed CUDA tensors can then be released by ``empty_cache``."""
        gc.collect()
        torch.cuda.empty_cache()

    def _setup_env(self):
        """Detect the runtime and return ``SimpleNamespace(local_path=..., drive_path=...)``."""
        exp_name = self.config.exp_name
        is_kaggle = os.path.exists('/kaggle/')
        # Google Drive is mounted at /content/drive/MyDrive; the existence check is case-sensitive.
        is_colab = os.path.exists('/content/drive/MyDrive') and not is_kaggle

        if is_kaggle:
            print("Environment: Kaggle")
            drive_path = None
            local_path = '/kaggle/working/'
        elif is_colab:
            print("Environment: Google Colab")
            drive_path = f'/content/drive/MyDrive/Kaggle_Save/{exp_name}/'
            local_path = '/content/models/'
        else:
            print("Environment: Local")
            drive_path = None
            local_path = f'../data/models/{exp_name}/'

        print(f"Save Path: {local_path}")
        return SimpleNamespace(local_path=local_path, drive_path=drive_path)

    def run(self):
        """Train one model per fold; each fold is its own W&B run and keeps its top-k checkpoints."""
        for fold in range(self.config.n_folds):
            print('='*30, f'FOLD {fold+1}', '='*30)

            wandb_logger = WandbLogger(
                project=self.config.project_name,
                group=self.config.exp_name,
                settings=wandb.Settings(program=current_file),
                name=f"Fold_{fold+1}",
                job_type="train",
                save_code=True,
                config={k: v for k, v in self.config.__dict__.items() if not k.startswith('__')}
            )

            train_len = int((self.train_df['fold'] != fold).sum())
            steps_per_epoch = math.ceil(train_len / self.config.batch_size / self.config.accum_iter)

            datamodule = PlantDataModule(train_df=self.train_df, teacher_logit=self.teacher_logit, test_df=self.test_df, cfg=self.config, fold_idx=fold)
            model = PlantDiseaseModule(self.config, steps_per_epoch=steps_per_epoch)

            ckpt_callback = pl.callbacks.ModelCheckpoint(
                monitor='val_roc_auc',
                mode='max',
                save_top_k=self.config.top_k,
                save_weights_only=True,
                save_last=False,
                dirpath=self.paths.local_path,
                filename=f'Fold{fold+1}-Ep{{epoch:02d}}-{{val_roc_auc:.4f}}',
                auto_insert_metric_name=False,
            )
            progress_bar = TQDMProgressBar(refresh_rate=1)
            alpha_callback = FoldAlphaCallback(
                current_fold=fold,
                weak_folds=[0, 1],
                weak_alpha=self.config.weak_alpha,
                strong_alpha=self.config.strong_alpha
            )

            trainer = pl.Trainer(
                max_epochs=self.config.epochs,
                accelerator='auto',
                precision='16-mixed',
                accumulate_grad_batches=self.config.accum_iter,
                callbacks=[ckpt_callback, progress_bar, alpha_callback],
                logger=wandb_logger,
                log_every_n_steps=10
            )

            trainer.fit(model, datamodule=datamodule)

            print(f'\n Top-{ckpt_callback.save_top_k} Models in this Fold:')
            for path in ckpt_callback.best_k_models:
                print(f'> {os.path.basename(path)}')

            wandb.finish()

            # Release this fold's objects before building the next one.
            del datamodule, trainer, model
            self._free_memory()

    @staticmethod
    def _average_checkpoints(paths):
        """Uniformly average the floating-point tensors of the Lightning checkpoints in ``paths``.

        Floating tensors are accumulated in float32; non-floating entries (e.g. BatchNorm
        ``num_batches_tracked``) are taken unchanged from the first checkpoint.
        """
        first_state = torch.load(paths[0], map_location='cpu')['state_dict']
        avg_state_dict = {k: (v.float() if v.is_floating_point() else v) for k, v in first_state.items()}

        for path in paths[1:]:
            state_dict = torch.load(path, map_location='cpu')['state_dict']
            for key, value in avg_state_dict.items():
                if value.is_floating_point():
                    value += state_dict[key].float()

        n_models = len(paths)
        return {k: (v / n_models if v.is_floating_point() else v) for k, v in avg_state_dict.items()}

    def _refresh_bn_stats(self, model, fold):
        """Recompute BatchNorm running statistics on the fold's training split without augmentation."""
        model.train()
        train_subset = self.train_df[self.train_df['fold'] != fold].reset_index(drop=True)
        transform_test = A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])
        dataset_bn = ImageDataset(train_subset, transform=transform_test, is_test=True)
        loader_bn = DataLoader(
            dataset_bn,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
            drop_last=True
        )
        update_bn(loader_bn, model, device=self.config.device)
        model.eval()

    def _load_averaged_model(self, fold):
        """Return the fold's weight-averaged model, moved to ``config.device``.

        Reuses ``best_score_model_{fold+1}.pth`` when it exists. Otherwise it averages the
        fold's top-k checkpoints, saves the average and only then deletes the checkpoints.
        When ``config.is_bn`` is set, BatchNorm statistics are recomputed and the file is
        re-saved.

        Raises:
            FileNotFoundError: if neither an averaged model nor any checkpoint of the fold exists.
        """
        save_path = os.path.join(self.paths.local_path, f'best_score_model_{fold+1}.pth')
        model = PlantDiseaseModule(self.config)

        if os.path.exists(save_path):
            print(f'Found existing averaged model for Fold {fold+1}. Loading directly...')
            state_dict = torch.load(save_path, map_location=self.config.device)
            model.load_state_dict(state_dict)
        else:
            print(f'Merging Top-K Models for Fold {fold+1} ...')
            score_pattern = os.path.join(self.paths.local_path, f'Fold{fold+1}-Ep*.ckpt')
            # Sorted so the (float-sensitive) summation order does not depend on the filesystem.
            score_files = sorted(glob.glob(score_pattern))
            print(f'Found {len(score_files)} score models : {[os.path.basename(f) for f in score_files]}')
            if not score_files:
                raise FileNotFoundError(f'No checkpoints matching {score_pattern} and no {save_path}')

            model.load_state_dict(self._average_checkpoints(score_files))

            # Persist the average before deleting its sources so a failed save cannot lose them.
            torch.save(model.state_dict(), save_path)
            print('Save Avg Model : ', save_path)
            for remove_path in score_files:
                if os.path.exists(remove_path):
                    os.remove(remove_path)

        model = model.to(self.config.device)
        if self.config.is_bn:
            print('Update BN stats ... ')
            self._refresh_bn_stats(model, fold)
            torch.save(model.state_dict(), save_path)
        else:
            print('Skipping BN update for LayerNorm')
        return model

    def find_optimal_temperature(self, logits, labels):
        """Grid-search the softmax temperature that minimises NLL on OOF logits.

        Args:
            logits: ``(n_samples, n_classes)`` array of raw logits.
            labels: class indices of shape ``(n_samples,)`` or one-hot rows (argmax is taken).

        Returns:
            The temperature from the grid 0.5, 0.6, ..., 2.5 with the lowest cross-entropy.
        """
        if labels.ndim > 1:
            labels = np.argmax(labels, axis=1)

        logits_tensor = torch.tensor(logits, dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        nll = torch.nn.CrossEntropyLoss()
        t_candidates = np.arange(0.5, 2.6, 0.1)
        with torch.no_grad():
            best_t = min(t_candidates, key=lambda t: nll(logits_tensor / t, labels_tensor).item())
        print(f"    > Best T: {best_t:.1f}")
        return best_t

    def run_inference(self):
        """Predict the OOF and test sets with each fold's weight-averaged model.

        Returns:
            ``(oof_preds, oof_preds_og, oof_preds_logit, final_preds, final_preds_og)``:
            temperature-calibrated OOF probabilities, uncalibrated OOF probabilities, raw
            OOF logits, and the fold-averaged calibrated / uncalibrated test probabilities.
        """
        n_classes = len(hard_cols)
        oof_preds = np.zeros((len(self.train_df), n_classes))
        oof_preds_og = np.zeros((len(self.train_df), n_classes))
        oof_preds_logit = np.zeros((len(self.train_df), n_classes))

        final_preds = np.zeros((len(self.test_df), n_classes))
        final_preds_og = np.zeros((len(self.test_df), n_classes))

        for fold in range(self.config.n_folds):
            print(f'=== Inference Fold {fold+1} ===')
            avg_model = self._load_averaged_model(fold)

            infer_module = PlantDataModule(train_df=self.train_df, teacher_logit=self.teacher_logit, test_df=self.test_df, cfg=self.config, fold_idx=fold, inference_mode=True)
            infer_module.setup(stage='test')

            progress_bar = TQDMProgressBar(refresh_rate=1)
            infer_trainer = pl.Trainer(
                accelerator='auto',
                precision='16-mixed',
                logger=False,
                enable_checkpointing=False,
                callbacks=[progress_bar]
            )

            # Positional row indices of this fold. The datamodule selects the same rows with a
            # boolean mask in the same order, so predictions line up row-for-row.
            valid_indices = np.flatnonzero((self.train_df['fold'] == fold).values)
            valid_labels = self.train_df.iloc[valid_indices][hard_cols].values

            # OOF inference
            oof_list = infer_trainer.predict(avg_model, dataloaders=infer_module.val_dataloader())
            current_oof_logits = torch.cat(oof_list).cpu().numpy()
            print(f"Max: {current_oof_logits.max()}, Min: {current_oof_logits.min()}")

            # Per-fold calibration temperature
            optimal_t = self.find_optimal_temperature(current_oof_logits, valid_labels)

            calibrated_oof_probs = torch.softmax(torch.tensor(current_oof_logits) / optimal_t, dim=1).numpy()
            og_oof_probs = torch.softmax(torch.tensor(current_oof_logits), dim=1).numpy()
            oof_preds[valid_indices] = calibrated_oof_probs
            oof_preds_og[valid_indices] = og_oof_probs
            oof_preds_logit[valid_indices] = current_oof_logits

            # Test inference with the same fold's temperature
            sub_list = infer_trainer.predict(avg_model, dataloaders=infer_module.predict_dataloader())
            current_test_logits = torch.cat(sub_list).cpu().numpy()
            calibrated_test_probs = torch.softmax(torch.tensor(current_test_logits) / optimal_t, dim=1).numpy()
            test_probs = torch.softmax(torch.tensor(current_test_logits), dim=1).numpy()
            final_preds += calibrated_test_probs
            final_preds_og += test_probs

            del avg_model, infer_trainer, infer_module
            self._free_memory()

        final_preds /= self.config.n_folds
        final_preds_og /= self.config.n_folds

        metric_handler = MetricHandler()
        metric_handler.update(oof_preds, self.train_df[hard_cols].values)
        oof_roc = metric_handler.compute_roc_auc()
        print(f'\n>>> Final OOF ROC AUC : {oof_roc:.5f}')

        return oof_preds, oof_preds_og, oof_preds_logit, final_preds, final_preds_og

In [None]:
# Free memory: collect unreachable Python objects first so the CUDA tensors they held are
# released, then let the caching allocator hand those now-unused blocks back.
gc.collect()
torch.cuda.empty_cache()

# 5. Training Execution

In [None]:
# Build the fold orchestrator once; the training and inference cells below reuse it.
runner = ExperimentRunner(CFG, train_df, teacher_logit, test_df)

In [None]:
%%time
# Train every fold (one W&B run per fold); top-k checkpoints are written to runner.paths.local_path.
runner.run()

# 6. Inference & Save

In [None]:
%%time
# Average each fold's checkpoints, predict OOF + test with flip TTA and calibrate per fold.
# Returns calibrated/uncalibrated OOF probs, raw OOF logits, and calibrated/uncalibrated test probs.
oof_preds, oof_preds_og, oof_preds_logit, final_preds, final_preds_og = runner.run_inference()

In [None]:
def _build_pred_frame(ids_frame, probs):
    """Return ``ids_frame[['image_id']]`` plus one probability column per label in ``hard_cols``."""
    frame = ids_frame[['image_id']].copy()
    frame[hard_cols] = probs
    return frame

# OOF arrays follow train_df row order.
result_oof = _build_pred_frame(train_df, oof_preds)
result_oof_og = _build_pred_frame(train_df, oof_preds_og)
runner.backup_handler.save_file(result_oof, f'oof_preds_{CFG.exp_name}.csv')
runner.backup_handler.save_file(result_oof_og, f'oof_preds_og_{CFG.exp_name}.csv')
runner.backup_handler.save_file(oof_preds_logit, f'oof_logit_{CFG.exp_name}.npy', logit=True)

# Test predictions were produced in test_df order (the predict loader does not shuffle), so take
# the ids from test_df instead of assuming sample_submission.csv lists them in the same order.
result_sub = _build_pred_frame(test_df, final_preds)
result_sub_og = _build_pred_frame(test_df, final_preds_og)
runner.backup_handler.save_file(result_sub, f'submission_{CFG.exp_name}.csv')
runner.backup_handler.save_file(result_sub_og, f'submission_og_{CFG.exp_name}.csv')

In [None]:
# Quick sanity check of the saved OOF and submission frames.
display(result_oof.head(), result_sub.head())

In [None]:
# Paths of the calibrated (sub_file1) and uncalibrated (sub_file2) submission CSVs.
sub_file1, sub_file2 = (
    os.path.join(runner.backup_handler.local_dir, f'submission_{CFG.exp_name}.csv'),
    os.path.join(runner.backup_handler.local_dir, f'submission_og_{CFG.exp_name}.csv'),
)
print(sub_file1)

## 6.1 Submission

In [None]:
# Submit both the temperature-calibrated and the uncalibrated predictions to Kaggle.
# NOTE(review): absolute, machine-specific config dir -- consider reading it from an env var or the config cell.
os.environ['KAGGLE_CONFIG_DIR'] = "/teamspace/studios/this_studio/"
!kaggle competitions submit -c plant-pathology-2020-fgvc7 -f {sub_file1} -m "{CFG.exp_name}_cali"
!kaggle competitions submit -c plant-pathology-2020-fgvc7 -f {sub_file2} -m "{CFG.exp_name}_og"