In [1]:
import os
import time
import random
import copy

from scipy import ndimage
import optuna, math
import timm
import torch
import cv2
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler  # Mixed Precision용

from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split, StratifiedKFold

from collections import Counter
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import wandb
from datetime import datetime

# 한글 폰트 설정 (시각화용)
plt.rcParams['font.family'] = ['DejaVu Sans']

  warn(


In [2]:
# 시드를 고정합니다.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

In [3]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data config
data_path = '../data/'

# model config
# model_name = 'tf_efficientnetv2_b3' # 'resnet50' 'efficientnet-b0', ...
# model_name = 'swin_base_patch4_window12_384_in22k'
model_name = 'convnext_base_384_in22ft1k'
# model_name = 'convnextv2_base.fcmae_ft_in22k_in1k_384'
# model_name = 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k' # openclip
# model_name = 'vit_base_patch16_384.augreg_in1k' # augreg
# model_name = 'eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k' # eva-02 멀티모달
# model_name = 'eva02_large_patch14_448.mim_in22k_ft_in1k' #448 테스트용
# model_name = 'vit_base_patch14_reg4_dinov2.lvd142m' # dinov2 reg4

# model_name = 'eva02_large_patch14_448.mim_in22k_ft_in1k' #448 테스트용

# training config
img_size = 384
LR = 2e-4
EPOCHS = 10
BATCH_SIZE = 24
num_workers = 12
EMA = True  # Exponential Moving Average 사용 여부

In [4]:
# 5-Fold 앙상블 모델 준비
ensemble_models = []
for i in range(5):  # fold 개수만큼
    fold_model = timm.create_model(model_name, pretrained=False, num_classes=17).to(device)  # pretrained=False로 변경
    
    # fold별 저장된 파일 로드
    checkpoint = torch.load(f'models/fold_{i+1}_best.pth')  # fold별 파일 경로
    fold_model.load_state_dict(checkpoint)
    fold_model.eval()
    
    ensemble_models.append(fold_model)
    print(f"✓ Fold {i+1} model loaded from models/fold_{i+1}_best.pth")

print(f"Using ensemble of all {len(ensemble_models)} fold models for inference")

✓ Fold 1 model loaded from models/fold_1_best.pth
✓ Fold 2 model loaded from models/fold_2_best.pth
✓ Fold 3 model loaded from models/fold_3_best.pth
✓ Fold 4 model loaded from models/fold_4_best.pth
✓ Fold 5 model loaded from models/fold_5_best.pth
Using ensemble of all 5 fold models for inference


In [5]:
# Temperature Scaling 클래스 정의
class TemperatureScaling(nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
    
    def forward(self, logits):
        return logits / self.temperature

In [6]:
# 기본 TTA 변형들 (모든 클래스 공통)
essential_tta_transforms = [
    # 원본
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    # 90도 회전들
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.Rotate(limit=[90, 90], p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.Rotate(limit=[180, 180], p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.Rotate(limit=[-90, -90], p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    # 밝기 개선
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.RandomBrightnessContrast(brightness_limit=[0.3, 0.3], contrast_limit=[0.3, 0.3], p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
]

# 취약 클래스(3,4,7,14) 전용 강화 TTA
enhanced_tta_transforms = essential_tta_transforms + [
    # 문서 특화 변형들 추가
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),  # 대비 향상
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1.0),  # 선명화
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    A.Compose([
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=0),
        A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=1.0),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
]

print(f"Basic TTA transforms: {len(essential_tta_transforms)}")
print(f"Enhanced TTA transforms: {len(enhanced_tta_transforms)}")

Basic TTA transforms: 5
Enhanced TTA transforms: 8


In [7]:
def assess_image_quality(img):
    """
    이미지 품질을 평가하여 전처리 필요성 판단
    Args:
        img: RGB numpy array (H, W, 3)
    Returns:
        quality_score: 0~1 (낮을수록 오염됨)
        metrics: 개별 품질 지표들
    """
    # RGB to GRAY 변환
    if len(img.shape) == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        gray = img
    
    # 1. 블러 정도 측정 (Laplacian variance)
    blur_score = cv2.Laplacian(gray, cv2.CV_64F).var()
    blur_normalized = min(blur_score / 500.0, 1.0)  # 500 이상이면 선명
    
    # 2. 대비 측정 (표준편차)
    contrast_score = gray.std()
    contrast_normalized = min(contrast_score / 50.0, 1.0)  # 50 이상이면 대비 좋음
    
    # 3. 밝기 분포 (히스토그램 엔트로피)
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
    hist_norm = hist / hist.sum()
    hist_entropy = -np.sum(hist_norm * np.log(hist_norm + 1e-7))
    brightness_normalized = min(hist_entropy / 6.0, 1.0)  # 엔트로피 정규화
    
    # 4. 전체 품질 점수 계산 (가중 평균)
    quality_score = (blur_normalized * 0.4 + 
                    contrast_normalized * 0.4 + 
                    brightness_normalized * 0.2)
    
    metrics = {
        'blur': blur_normalized,
        'contrast': contrast_normalized, 
        'brightness': brightness_normalized,
        'overall': quality_score
    }
    
    return quality_score, metrics

In [8]:
def adaptive_preprocessing(img):
    """
    품질 점수에 따른 적응형 전처리
    Args:
        img: RGB numpy array (H, W, 3)
    Returns:
        processed_img: 전처리된 RGB numpy array
    """
    quality_score, metrics = assess_image_quality(img)
    
    # 임계값 설정
    CLEAN_THRESHOLD = 0.6    # 깨끗한 이미지
    MODERATE_THRESHOLD = 0.3  # 중간 오염
    
    if quality_score > CLEAN_THRESHOLD:
        # 정상 이미지 - 원본 그대로 반환
        return img
    
    elif quality_score > MODERATE_THRESHOLD:
        # 중간 오염 - 가벼운 복원
        processed_img = img.copy()
        
        # 가벼운 블러 제거
        if metrics['blur'] < 0.5:
            processed_img = cv2.bilateralFilter(processed_img, 5, 50, 50)
        
        # 대비 개선
        if metrics['contrast'] < 0.5:
            processed_img = cv2.convertScaleAbs(processed_img, alpha=1.2, beta=10)
        
        return np.clip(processed_img, 0, 255).astype(np.uint8)
    
    else:
        # 심각한 오염 - 강력한 복원
        processed_img = img.copy()
        
        # 노이즈 제거 (Non-local Means Denoising)
        processed_img = cv2.fastNlMeansDenoisingColored(processed_img, None, 10, 10, 7, 21)
        
        # 선명화 (Unsharp Mask)
        if metrics['blur'] < 0.3:
            gaussian = cv2.GaussianBlur(processed_img, (0, 0), 2.0)
            processed_img = cv2.addWeighted(processed_img, 1.5, gaussian, -0.5, 0)
        
        # 대비 향상 (CLAHE)
        if metrics['contrast'] < 0.3:
            lab = cv2.cvtColor(processed_img, cv2.COLOR_RGB2LAB)
            lab[:,:,0] = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8)).apply(lab[:,:,0])
            processed_img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        
        return np.clip(processed_img, 0, 255).astype(np.uint8)

In [9]:
# 테스트 함수
def test_preprocessing_functions():
    """구현한 함수들의 기본 동작 테스트"""
    print("Testing adaptive preprocessing functions...")
    
    # 더미 이미지로 테스트
    test_img = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
    
    try:
        quality_score, metrics = assess_image_quality(test_img)
        print(f"Quality assessment: {quality_score:.3f}")
        print(f"Metrics: {metrics}")
        
        processed_img = adaptive_preprocessing(test_img)
        print(f"Preprocessing completed. Input shape: {test_img.shape}, Output shape: {processed_img.shape}")
        print("✅ All functions working correctly")
        return True
        
    except Exception as e:
        print(f"❌ Error in functions: {e}")
        return False

In [10]:
# 함수 테스트 실행
test_preprocessing_functions()

Testing adaptive preprocessing functions...
Quality assessment: 0.971
Metrics: {'blur': 1.0, 'contrast': 0.9860672067540966, 'brightness': 0.8811639944712321, 'overall': 0.9706596815958852}
Preprocessing completed. Input shape: (384, 384, 3), Output shape: (384, 384, 3)
✅ All functions working correctly


True

In [11]:
# TTA 추론을 위한 Dataset 클래스
class TTAImageDataset(Dataset):
    def __init__(self, data, path, transforms):
        if isinstance(data, str):
            self.df = pd.read_csv(data).values
        else:
            self.df = data.values
        self.path = path
        self.transforms = transforms  # 여러 transform을 리스트로 받음

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name, target = self.df[idx]
        img = np.array(Image.open(os.path.join(self.path, name)).convert('RGB'))
        
        # ★ 적응형 전처리 적용
        img = adaptive_preprocessing(img)
        
        # 모든 transform을 적용한 결과를 리스트로 반환
        augmented_images = []
        for transform in self.transforms:
            aug_img = transform(image=img)['image']
            augmented_images.append(aug_img)
        
        return augmented_images, target

In [None]:
# TTA Dataset 생성
tta_dataset = TTAImageDataset(
    "../data/sample_submission.csv",
    "../data/test/",
    enhanced_tta_transforms
)

# TTA DataLoader (배치 크기를 줄여서 메모리 절약)
tta_loader = DataLoader(
    tta_dataset,
    batch_size=96,  # TTA는 메모리를 많이 사용하므로 배치 크기 줄임
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

print(f"TTA Dataset size: {len(tta_dataset)}")

TTA Dataset size: 3140


In [None]:
def enhanced_ensemble_tta_inference(models, loader, weak_classes=[3,4,7,14], confidence_threshold=0.40):
    """취약 클래스 대응 강화 앙상블 TTA 추론"""
    temp_scaling = TemperatureScaling().to(device)
    all_predictions = []
    all_confidences = []
    weak_class_count = 0
    enhanced_count = 0
    
    for batch_idx, (images_list, _) in enumerate(tqdm(loader, desc="Enhanced Ensemble TTA")):
        batch_size = images_list[0].size(0)
        batch_predictions = []
        batch_confidences = []
        
        for i in range(batch_size):
            # 각 이미지별로 개별 처리
            single_images = [img_tensor[i:i+1] for img_tensor in images_list]  # 각 TTA에서 i번째 이미지
            
            # 1차: 기본 TTA (5개 변형) 앙상블
            basic_probs = torch.zeros(1, 17).to(device)
            for model in models:
                with torch.no_grad():
                    for j, images in enumerate(single_images[:5]):  # 기본 TTA만 (처음 5개)
                        images = images.to(device)
                        preds = model(images)
                        scaled_preds = temp_scaling(preds)
                        probs = torch.softmax(scaled_preds, dim=1)
                        basic_probs += probs / (len(models) * 5)
            
            basic_pred = torch.argmax(basic_probs, dim=1).item()
            basic_conf = torch.max(basic_probs, dim=1)[0].item()
            
            # 2차: 취약 클래스 + 낮은 신뢰도인 경우 강화 TTA 적용
            if basic_pred in weak_classes and basic_conf < confidence_threshold:
                weak_class_count += 1
                enhanced_count += 1
                
                # 강화 TTA (8개 변형) 앙상블
                enhanced_probs = torch.zeros(1, 17).to(device)
                for model in models:
                    with torch.no_grad():
                        for images in single_images:  # 모든 TTA 변형 (8개)
                            images = images.to(device)
                            preds = model(images)
                            scaled_preds = temp_scaling(preds)
                            probs = torch.softmax(scaled_preds, dim=1)
                            enhanced_probs += probs / (len(models) * len(single_images))
                
                enhanced_pred = torch.argmax(enhanced_probs, dim=1).item()
                enhanced_conf = torch.max(enhanced_probs, dim=1)[0].item()
                
                # 더 높은 신뢰도 선택
                if enhanced_conf > basic_conf + 0.1:  # 10% 이상 향상 시 채택
                    final_pred = enhanced_pred
                    final_conf = enhanced_conf
                else:
                    final_pred = basic_pred
                    final_conf = basic_conf
            else:
                # 정상 클래스 또는 높은 신뢰도 → 기본 TTA 결과 사용
                final_pred = basic_pred
                final_conf = basic_conf
                if basic_pred in weak_classes:
                    weak_class_count += 1
            
            batch_predictions.append(final_pred)
            batch_confidences.append(final_conf)
        
        all_predictions.extend(batch_predictions)
        all_confidences.extend(batch_confidences)
    
    print(f"취약 클래스 예측: {weak_class_count}개")
    print(f"강화 TTA 적용: {enhanced_count}개")
    print(f"평균 신뢰도: {np.mean(all_confidences):.3f}")
    
    return all_predictions

In [14]:
# 앙상블 TTA 실행
print("Starting Ensemble TTA inference...")
tta_predictions = enhanced_ensemble_tta_inference(
    models=ensemble_models, 
    loader=tta_loader, 
    weak_classes=[3,4,7,14],
    confidence_threshold=0.40
)

Starting Ensemble TTA inference...


Enhanced Ensemble TTA:   0%|          | 0/33 [00:00<?, ?it/s]

Enhanced Ensemble TTA: 100%|██████████| 33/33 [15:24<00:00, 28.02s/it]

취약 클래스 예측: 692개
강화 TTA 적용: 174개
평균 신뢰도: 0.644





In [15]:
# TTA 결과로 submission 파일 생성
tta_pred_df = pd.DataFrame(tta_dataset.df, columns=['ID', 'target'])
tta_pred_df['target'] = tta_predictions

In [16]:
# 기존 submission과 동일한 순서인지 확인
sample_submission_df = pd.read_csv("../data/sample_submission.csv")
assert (sample_submission_df['ID'] == tta_pred_df['ID']).all()

In [17]:
# TTA 결과 저장
tta_pred_df.to_csv("../submission/choice.csv", index=False)
print("TTA predictions saved")

print("TTA Prediction sample:")

TTA predictions saved
TTA Prediction sample:


In [18]:
tta_pred_df.head()

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,12
2,00396fbc1f6cc21d.jpg,5
3,00471f8038d9c4b6.jpg,12
4,00901f504008d884.jpg,2
