In [None]:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
import torch.nn.functional as F
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from nmc.models import *
from nmc.datasets import * 
from nmc.augmentations import get_train_augmentation, get_val_augmentation
from nmc.losses import get_loss
from nmc.schedulers import get_scheduler
from nmc.optimizers import get_optimizer
from nmc.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from tools.val import evaluate_epi
from nmc.utils.episodic_utils import * 
from scipy.cluster import hierarchy
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
from torch.optim import lr_scheduler
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mutual_info_score
from scipy.cluster import hierarchy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data import Subset
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import cv2

# Swin Transformer model definition (7-class classification)
from torchvision.models import swin_t

In [None]:
# Load the experiment configuration. SafeLoader only builds plain Python
# objects, so the YAML file cannot trigger arbitrary object construction.
with open('../configs/NMC.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)
print(cfg)

# Fixed seed for run-to-run reproducibility.
fix_seeds(3407)
setup_cudnn()
# NOTE(review): setup_ddp()/cleanup_ddp() come from nmc.utils.utils; `gpu` is
# not referenced by any other cell visible in this notebook.
gpu = setup_ddp()

# Create the output directory. parents=True lets a nested SAVE_DIR
# (e.g. "output/run1") be created on a fresh checkout instead of raising
# FileNotFoundError.
save_dir = Path(cfg['SAVE_DIR'])
save_dir.mkdir(parents=True, exist_ok=True)
cleanup_ddp()

In [None]:
# Early Stopping
class EarlyStopping:
    """Signal when a "higher is better" validation score has stopped improving.

    Call the instance once per epoch with the latest score. A score counts as
    an improvement only if it is at least ``best_score + min_delta``; after
    ``patience`` consecutive non-improving calls ``early_stop`` becomes True.

    Attributes:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum gain over ``best_score`` that counts as improvement.
        counter: Current run of non-improving calls.
        best_score: Best score accepted so far (None before the first call).
        early_stop: Set to True once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience=7, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_score):
        # The very first score just initialises the baseline.
        if self.best_score is None:
            self.best_score = val_score
            return

        improved = not (val_score < self.best_score + self.min_delta)
        if improved:
            self.best_score = val_score
            self.counter = 0
            return

        # No sufficient improvement: extend the stall streak.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# ImageNet channel statistics used by Normalize in both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _to_unit_float(x):
    """Convert an image tensor to floating point in the [0, 1] range.

    uint8 tensors are always divided by 255. The previous max-based check
    skipped the division for uint8 images whose brightest pixel was <= 1,
    which turned a near-black pixel value of 1 into full intensity (1.0).
    Non-uint8 tensors keep the original heuristic: they are divided by 255
    only when their maximum exceeds 1.0.
    """
    if x.dtype == torch.uint8:
        return x.float() / 255.0
    return x / 255.0 if x.max() > 1.0 else x


def _scale_and_normalize():
    """Return the transform tail shared by the train and val/test pipelines."""
    return [
        transforms.Lambda(_to_unit_float),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]


def get_train_augmentation(size):
    """Build the training-time transform pipeline.

    NOTE: this definition shadows the `get_train_augmentation` imported from
    `nmc.augmentations` at the top of the notebook.

    Args:
        size: Target size forwarded to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that resizes, applies random horizontal flip,
        random rotation (10 degrees) and colour jitter, then scales to [0, 1]
        and normalises with ImageNet statistics.
        Assumes the input is an image tensor (the scaling step reads
        ``x.dtype``) -- TODO confirm against the dataset class.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        *_scale_and_normalize(),
    ])


def get_val_test_transform(size):
    """Build the deterministic transform pipeline for validation and test.

    Args:
        size: Target size forwarded to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that resizes, scales to [0, 1] and normalises
        with ImageNet statistics (no random augmentation).
    """
    return transforms.Compose([
        transforms.Resize(size),
        *_scale_and_normalize(),
    ])


In [None]:
class BalancedBatchSampler(Sampler):
    """Batch sampler that draws the same number of samples from every class.

    Each batch contains ``samples_per_class`` indices per non-empty class,
    drawn uniformly *with replacement* from that class. Intended to be passed
    as ``batch_sampler=`` to a ``DataLoader``; every iteration yields a list
    of dataset indices.

    Labels are read from ``dataset.labels``, else ``dataset.targets``, else by
    iterating the dataset and taking element ``[1]`` of every sample.
    Two label layouts are supported:

    * 2-D (multi-hot): column ``c`` equal to 1 marks membership of class ``c``.
    * 1-D (single label): every distinct value is its own class (values do
      not have to be the contiguous range ``0..n-1``).

    Args:
        dataset: Sized dataset exposing labels as described above.
        batch_size: Upper bound on the number of indices per batch.

    Raises:
        ValueError: If labels cannot be read from the dataset or there are no
            classes at all.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

        # Extract labels from the dataset and normalise them to a tensor so
        # that lists, numpy arrays and tensors are all handled the same way.
        if hasattr(dataset, 'labels'):
            self.labels = self._as_label_tensor(dataset.labels)
        elif hasattr(dataset, 'targets'):
            self.labels = self._as_label_tensor(dataset.targets)
        else:
            try:
                self.labels = self._as_label_tensor([sample[1] for sample in dataset])
            except Exception as exc:
                # Keep the original cause attached instead of swallowing it.
                raise ValueError("Cannot access labels from dataset") from exc

        # Per-class index pools.
        if self.labels.dim() > 1:
            self.n_classes = self.labels.shape[1]
            self.class_indices = [
                torch.where(self.labels[:, c] == 1)[0] for c in range(self.n_classes)
            ]
        else:
            # Use the label values actually present rather than assuming 0..n-1.
            class_values = torch.unique(self.labels)
            self.n_classes = len(class_values)
            self.class_indices = [
                torch.where(self.labels == value)[0] for value in class_values
            ]

        if self.n_classes == 0:
            raise ValueError("Cannot build balanced batches: dataset has no labels")

        # At least one sample per class, otherwise every batch would be empty
        # whenever batch_size < n_classes.
        self.samples_per_class = max(1, batch_size // self.n_classes)

        # Number of batches per epoch: ceil(len(dataset) / batch_size).
        self.n_batches = len(self.dataset) // batch_size
        if len(self.dataset) % batch_size != 0:
            self.n_batches += 1

    @staticmethod
    def _as_label_tensor(labels):
        """Convert tensor / ndarray / sequence labels into a ``torch.Tensor``."""
        if isinstance(labels, torch.Tensor):
            return labels
        if isinstance(labels, (list, tuple)) and len(labels) > 0 \
                and isinstance(labels[0], torch.Tensor):
            return torch.stack(list(labels))
        return torch.as_tensor(np.asarray(labels))

    def __iter__(self):
        for _ in range(self.n_batches):
            batch_indices = []
            for class_samples in self.class_indices:
                if len(class_samples) == 0:
                    continue

                # Random choice with replacement inside the class.
                picks = torch.randint(len(class_samples), (self.samples_per_class,))
                batch_indices.extend(class_samples[picks].tolist())

            # Only reachable when batch_size < n_classes: drop a random subset
            # so that truncation does not always favour the first classes.
            if len(batch_indices) > self.batch_size:
                keep = torch.randperm(len(batch_indices))[:self.batch_size].tolist()
                batch_indices = [batch_indices[k] for k in keep]

            # Important: DataLoader expects a plain list of indices per batch.
            yield batch_indices

    def __len__(self):
        return self.n_batches

In [None]:
# Data pipeline setup: unpack the config, build the dataset splits and wrap
# them in DataLoaders. Everything defined here is a notebook-level global used
# by later cells.
# NOTE(review): `start`, `best_mf1`, `image_dir`, `lr` and `eval_cfg`/`model_cfg`/
# `loss_cfg`/`sched_cfg` are not referenced by the other cells visible here.
start = time.time()
best_mf1 = 0.0
device = torch.device(cfg['DEVICE'])
print("device : ", device)
# One DataLoader worker per CPU core reported by the OS.
num_workers = mp.cpu_count()
train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

# Input resolution and transforms (defined in an earlier cell of this notebook).
image_size = [256,256]
image_dir = Path(dataset_cfg['ROOT']) / 'train_images'
train_transform = get_train_augmentation(image_size)
val_test_transform = get_val_test_transform(image_size)
batch_size = 32


# NOTE(review): eval() resolves the dataset class from a config string. This
# executes arbitrary code if the YAML is not trusted; an explicit
# name -> class mapping would be safer.
dataset = eval(dataset_cfg['NAME'])(
    dataset_cfg['ROOT'] + '/cropped_images_1424x1648',
    dataset_cfg['TRAIN_RATIO'],
    dataset_cfg['VALID_RATIO'],
    dataset_cfg['TEST_RATIO'],
    transform=None
)
# The dataset is built without a transform; each split then gets its own
# pipeline (augmentation only for training).
trainset, valset, testset = dataset.get_splits()
trainset.transform = train_transform
valset.transform = val_test_transform
testset.transform = val_test_transform



# DataLoader setup (modified): training batches come from the class-balanced
# batch sampler, so batch_size/shuffle/drop_last are not passed to DataLoader.
trainloader = DataLoader(
    trainset, 
    batch_sampler=BalancedBatchSampler(trainset, batch_size=batch_size),
    num_workers=num_workers,
    pin_memory=True
)
# trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True)
# Validation and test are evaluated one sample at a time.
valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
testloader = DataLoader(testset, batch_size=1, num_workers=1, pin_memory=True)

In [None]:
def get_label_order(dataset, num_labels=7):
    """Return label indices sorted by descending frequency in ``dataset``.

    Labels are read directly from ``dataset.dataframe['label']``. Each entry
    is used as an index (or list of indices) into a row of a one-hot matrix,
    so an entry such as ``[0, 3]`` marks labels 0 and 3 as present.
    NOTE(review): this assumes the column stores label *indices*, not an
    already multi-hot vector -- confirm against the dataset class.

    Entries may have different lengths; they are processed one by one instead
    of being packed into a single ``np.array`` (which raises for ragged input
    on recent NumPy versions).

    Args:
        dataset: Object exposing a pandas ``dataframe`` with a ``'label'`` column.
        num_labels: Total number of classes (defaults to 7, the value that was
            previously hard-coded).

    Returns:
        1-D integer ``np.ndarray`` of length ``num_labels``: label indices from
        most to least frequent. Ties keep ascending index order.
    """
    label_entries = dataset.dataframe['label'].values

    # Convert the (multi-)labels into a one-hot matrix, one row per sample.
    label_matrix = np.zeros((len(label_entries), num_labels))
    for row, entry in enumerate(label_entries):
        indices = np.asarray(entry, dtype=np.intp).ravel()
        label_matrix[row, indices] = 1

    # Frequency of each label across the dataset.
    label_counts = np.sum(label_matrix, axis=0)

    # Sort by descending frequency; stable sort makes tie order deterministic.
    order = np.argsort(-label_counts, kind='stable')

    print("Label frequencies:")
    for idx, count in enumerate(label_counts[order]):
        print(f"Label {order[idx]}: {count} samples")

    return order

In [None]:
class ClassifierChainModel(nn.Module):
    """Classifier chain on top of a shared feature extractor.

    One small MLP head is created per label. Head ``i`` receives the backbone
    features concatenated with the raw outputs (logits) of heads ``0..i-1``,
    so later labels are predicted conditioned on earlier ones.

    Args:
        base_model: Backbone whose ``classifier[1]`` exposes ``in_features``.
            Its ``classifier`` attribute is replaced in place by
            ``nn.Identity`` so the backbone returns features.
        num_labels: Number of labels, i.e. number of heads in the chain.
        label_order: Stored as-is on the module; the forward pass does not
            read it (callers use it to reorder targets/predictions).
    """

    def __init__(self, base_model, num_labels, label_order):
        super().__init__()
        self.num_labels = num_labels
        self.label_order = label_order

        # Read the feature width before stripping the original classifier.
        num_ftrs = base_model.classifier[1].in_features
        self.feature_extractor = base_model
        self.feature_extractor.classifier = nn.Identity()

        # One head per chain position; head i sees i extra inputs.
        heads = []
        for position in range(num_labels):
            heads.append(
                nn.Sequential(
                    nn.Linear(num_ftrs + position, 256),
                    nn.BatchNorm1d(256),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(256, 1),
                )
            )
        self.chains = nn.ModuleList(heads)

    def forward(self, x):
        """Run the chain and return logits of shape [batch_size, num_labels].

        Column ``j`` of the result is the output of chain position ``j``.
        """
        chain_input = self.feature_extractor(x)  # [batch_size, num_ftrs]
        step_outputs = []

        for position in range(self.num_labels):
            if position > 0:
                # Grow the input by the logit produced at the previous step.
                chain_input = torch.cat([chain_input, step_outputs[-1]], dim=1)
            step_outputs.append(self.chains[position](chain_input))  # [batch_size, 1]

        return torch.cat(step_outputs, dim=1)

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, scaler, device):
    """Train ``model`` for one pass over ``dataloader`` and return the mean loss.

    Targets are reordered with ``model.label_order`` so that target column
    ``j`` lines up with the output of chain position ``j``.

    Mixed precision is used only when ``scaler`` is an *enabled* GradScaler.
    Previously autocast was switched on for any non-None scaler, so a scaler
    built with ``enabled=False`` still ran the forward pass in reduced
    precision but without loss scaling.

    Args:
        model: Module exposing a ``label_order`` attribute.
        dataloader: Yields ``(images, labels)`` batches.
            Assumes ``labels`` is a [batch, num_labels] multi-hot tensor --
            TODO confirm against the dataset class.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping the model parameters.
        scaler: ``GradScaler`` or None.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    use_amp = scaler is not None and scaler.is_enabled()
    # Build the column permutation once, on the same device as the targets.
    label_order = torch.as_tensor(model.label_order, dtype=torch.long, device=device)
    total_loss = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(images)
            # Compute the loss with targets in chain order; BCE-style losses
            # need floating-point targets, so cast integer multi-hot labels.
            labels_reordered = labels[:, label_order].float()
            loss = criterion(outputs, labels_reordered)

        if scaler is not None:
            # With a disabled scaler these calls fall through to plain
            # backward()/optimizer.step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [None]:
def evaluate(model, dataloader, device):
    """Return the sample-averaged F1 score of ``model`` on ``dataloader``.

    Logits are passed through a sigmoid and thresholded at 0.5. Because the
    model emits columns in chain order, predictions are permuted back to the
    original label order (inverse of ``model.label_order``) before scoring.

    NOTE: a later cell of this notebook defines another ``evaluate`` with an
    extra ``num_classes`` argument and a dict return value; once that cell
    has run, this definition is shadowed.

    Args:
        model: Module exposing ``label_order``.
        dataloader: Yields ``(images, labels)`` batches.
        device: Device the images are moved to.

    Returns:
        ``f1_score(..., average='samples')`` over all batches.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            logits = model(images.to(device))
            binary = (torch.sigmoid(logits) > 0.5).int()

            # Undo the chain ordering: column k becomes original label k.
            restore = np.argsort(model.label_order)
            binary = binary[:, restore]

            pred_chunks.append(binary.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    y_pred = np.vstack(pred_chunks)
    y_true = np.vstack(label_chunks)

    return f1_score(y_true, y_pred, average='samples')

In [None]:
def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, scaler, device, epochs,
                       lr_scheduler=None, save_path='model/best_model_nmc_chain.pth'):
    """Train with per-epoch validation, checkpointing and early stopping.

    After every epoch the validation F1 is fed to the LR scheduler, the model
    is saved whenever the F1 beats the best so far, and training stops early
    after 10 epochs without an improvement of at least 0.001.

    Args:
        model, train_loader, criterion, optimizer, scaler, device: Forwarded to
            ``train_epoch``.
        val_loader: Forwarded to ``evaluate`` (three-argument form returning a
            single F1 value).
        epochs: Maximum number of epochs.
        lr_scheduler: Scheduler stepped with the validation F1. When omitted,
            the notebook-level global ``scheduler`` is used, which is what the
            function always relied on before this parameter existed.
        save_path: Where the best ``state_dict`` is written. The parent
            directory is created if missing, so the first checkpoint no longer
            fails on a fresh checkout.

    Returns:
        Best validation F1 observed (0.0 if no epoch exceeded it).
    """
    if lr_scheduler is None:
        # Backward-compatible fallback to the global defined in the main cell.
        lr_scheduler = scheduler

    checkpoint_path = Path(save_path)
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    best_f1 = 0.0
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_f1 = evaluate(model, val_loader, device)

        print(f"Training Loss: {train_loss:.4f}")
        print(f"Validation F1 Score: {val_f1:.4f}")

        lr_scheduler.step(val_f1)

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), str(checkpoint_path))
            print("New best model saved!")

        early_stopping(val_f1)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        print()

    return best_f1

In [None]:
# Main execution code
# Techniques used: regularization, LR scheduling, data augmentation,
# early stopping, batch normalization.

# Decide the chain order (most frequent label first).
label_order = get_label_order(trainset)

# Model initialisation. `weights=` replaces the deprecated `pretrained=True`;
# IMAGENET1K_V1 is the checkpoint `pretrained=True` resolved to.
base_model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1)
model = ClassifierChainModel(base_model, num_labels=7, label_order=label_order)
model = model.to(device)

# L2 regularization (decoupled weight decay in AdamW)
weight_decay = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=weight_decay)
criterion = nn.BCEWithLogitsLoss()
scaler = GradScaler(enabled=train_cfg['AMP'])
# Learning rate scheduler: cut the LR by 10x when validation F1 stalls for 5 epochs.
# This global is also read by train_and_evaluate().
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)

print('--')
# NOTE: overrides the `epochs` value read from the config in the data-setup cell.
epochs = 100
best_f1 = train_and_evaluate(model, trainloader, valloader, criterion, optimizer, scaler, device, epochs)

print(f"Training completed. Best F1 Score: {best_f1:.4f}")

# Final evaluation on test set.
# map_location keeps the load working when the checkpoint was written from a
# different device than the one currently selected.
model.load_state_dict(torch.load('model/best_model_nmc_chain.pth', map_location=device))
test_f1 = evaluate(model, testloader, device)
print(f"Test F1 Score: {test_f1:.4f}")

In [None]:
def evaluate(model, dataloader, device, num_classes):
    """Evaluate ``model`` and report overall and per-class metrics.

    NOTE: this redefines the three-argument ``evaluate`` from an earlier cell
    and returns a dict instead of a single F1 value.

    Predictions are thresholded at 0.5 after a sigmoid and permuted back to
    the original label order, so every per-class array below is indexed by
    the *original* class id. The printed report walks the chain order and
    looks each class up by its original id; previously it indexed the arrays
    by chain position, which attached the numbers to the wrong class.

    Args:
        model: Module exposing ``label_order``.
        dataloader: Yields ``(images, labels)`` batches.
        device: Device the images are moved to.
        num_classes: Number of chain positions to include in the report.

    Returns:
        dict with keys ``'overall_f1'`` (sample-averaged F1) and
        ``'class_f1_scores'``, ``'class_precision'``, ``'class_recall'``
        (per-class arrays in original label order).
    """
    model.eval()
    all_preds = []
    all_labels = []

    # Inverse permutation of the chain order; identical for every batch.
    inv_order = np.argsort(model.label_order)

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)

            outputs = model(images)
            preds = (torch.sigmoid(outputs) > 0.5).int()

            # Put the predictions back into the original label order.
            preds = preds[:, inv_order]

            all_preds.append(preds.cpu().numpy())
            # Labels are only needed on the host for scikit-learn.
            all_labels.append(labels.cpu().numpy())

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    # Overall F1 score
    overall_f1 = f1_score(all_labels, all_preds, average='samples')

    # Per-class F1 score
    class_f1_scores = f1_score(all_labels, all_preds, average=None)

    # Per-class precision and recall
    class_precision = precision_score(all_labels, all_preds, average=None)
    class_recall = recall_score(all_labels, all_preds, average=None)

    # Collect the results in a dictionary
    results = {
        'overall_f1': overall_f1,
        'class_f1_scores': class_f1_scores,
        'class_precision': class_precision,
        'class_recall': class_recall
    }

    # Print per-class metrics, listed in chain order but looked up by the
    # original class id.
    print("\nPer-class Performance Metrics:")
    print("-" * 50)
    for i in range(num_classes):
        orig_idx = model.label_order[i]  # original label index at chain position i
        print(f"Class {orig_idx} (순서: {i}):")
        print(f"  F1-Score: {class_f1_scores[orig_idx]:.4f}")
        print(f"  Precision: {class_precision[orig_idx]:.4f}")
        print(f"  Recall: {class_recall[orig_idx]:.4f}")
    print("-" * 50)
    print(f"Overall F1-Score: {overall_f1:.4f}")

    return results

In [None]:
# Reload the best checkpoint and run the detailed per-class evaluation.
# map_location keeps the load working when the checkpoint was written from a
# different device than the one currently selected.
model.load_state_dict(torch.load('model/best_model_nmc_chain.pth', map_location=device))
# NOTE: despite its name, `test_f1` holds the results dict returned by the
# four-argument evaluate() (overall F1 plus per-class arrays).
test_f1 = evaluate(model, testloader, device, 7)