In [1]:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
import torch.nn.functional as F
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from nmc.models import *
from nmc.datasets import * 
from nmc.augmentations import get_train_augmentation, get_val_augmentation
from nmc.losses import get_loss
from nmc.schedulers import get_scheduler
from nmc.optimizers import get_optimizer
from nmc.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from tools.val import evaluate_epi
from nmc.utils.episodic_utils import * 
from scipy.cluster import hierarchy
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
from torch.optim import lr_scheduler
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mutual_info_score
from scipy.cluster import hierarchy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data import Subset
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import cv2

# Swin Transformer 모델 정의 (7클래스 분류)
from torchvision.models import swin_t

In [2]:
with open('../configs/NMC.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)
print(cfg)
fix_seeds(3407)
setup_cudnn()
gpu = setup_ddp()
save_dir = Path(cfg['SAVE_DIR'])
save_dir.mkdir(exist_ok=True)
cleanup_ddp()

{'DEVICE': 'cuda:0', 'SAVE_DIR': 'output', 'MODEL': {'NAME': 'EfficientNetV2MModelMulti', 'BACKBONE': 'EfficientNetV2', 'PRETRAINED': '/workspace/jhmoon/nmc_2024/checkpoints/pretrained/tf_efficientnetv2_m_weights.pth', 'UNFREEZE': 'full', 'VERSION': "384_32_loss'"}, 'DATASET': {'NAME': 'NMCDataset', 'ROOT': '/data/nmc/processed_image', 'TRAIN_RATIO': 0.7, 'VALID_RATIO': 0.15, 'TEST_RATIO': 0.15}, 'TRAIN': {'IMAGE_SIZE': [384, 384], 'BATCH_SIZE': 32, 'EPOCHS': 100, 'EVAL_INTERVAL': 1, 'AMP': False, 'DDP': False}, 'LOSS': {'NAME': 'BCEWithLogitsLoss', 'CLS_WEIGHTS': False}, 'OPTIMIZER': {'NAME': 'adamw', 'LR': 0.1, 'WEIGHT_DECAY': 0.01}, 'SCHEDULER': {'NAME': 'warmuppolylr', 'POWER': 0.9, 'WARMUP': 10, 'WARMUP_RATIO': 0.1}, 'EVAL': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.NMC.pth', 'IMAGE_SIZE': [384, 384]}, 'TEST': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.NMC.pth', 'FILE': 'assests/ade', 'IMAGE_SIZE': [384, 384], 'OVERLAY': True}}


In [3]:
# Early Stopping
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score):
        if self.best_score is None:
            self.best_score = val_score
        elif val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_score
            self.counter = 0

In [4]:
def get_train_augmentation(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def get_val_test_transform(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


In [5]:
class BalancedBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        
        # 데이터셋에서 레이블 추출
        if hasattr(dataset, 'labels'):
            self.labels = dataset.labels
            if isinstance(self.labels, np.ndarray):
                self.labels = torch.from_numpy(self.labels)
        elif hasattr(dataset, 'targets'):
            self.labels = dataset.targets
            if isinstance(self.labels, np.ndarray):
                self.labels = torch.from_numpy(self.labels)
        else:
            try:
                self.labels = [sample[1] for sample in dataset]
                if isinstance(self.labels[0], np.ndarray):
                    self.labels = torch.from_numpy(np.array(self.labels))
                else:
                    self.labels = torch.tensor(self.labels)
            except:
                raise ValueError("Cannot access labels from dataset")
        
        self.n_classes = self.labels.shape[1] if len(self.labels.shape) > 1 else len(torch.unique(self.labels))
        self.samples_per_class = batch_size // self.n_classes
        
        # 클래스별 인덱스 저장
        self.class_indices = []
        for i in range(self.n_classes):
            if len(self.labels.shape) > 1:
                idx = torch.where(self.labels[:, i] == 1)[0]
            else:
                idx = torch.where(self.labels == i)[0]
            self.class_indices.append(idx)
        
        self.n_batches = len(self.dataset) // batch_size
        if len(self.dataset) % batch_size != 0:
            self.n_batches += 1
    
    def __iter__(self):
        for _ in range(self.n_batches):
            batch_indices = []
            for class_idx in range(self.n_classes):
                class_samples = self.class_indices[class_idx]
                if len(class_samples) == 0:
                    continue
                
                # 랜덤 선택
                selected = class_samples[torch.randint(len(class_samples), 
                                                     (self.samples_per_class,))]
                batch_indices.extend(selected.tolist())
            
            # 배치 크기에 맞게 자르기
            if len(batch_indices) > self.batch_size:
                batch_indices = batch_indices[:self.batch_size]
            
            # 중요: 리스트로 yield
            yield batch_indices
    
    def __len__(self):
        return self.n_batches

In [6]:
start = time.time()
best_mf1 = 0.0
device = torch.device(cfg['DEVICE'])
print("device : ", device)
num_workers = mp.cpu_count()
train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

image_size = [256,256]
image_dir = Path(dataset_cfg['ROOT']) / 'train_images'
train_transform = get_train_augmentation(image_size)
val_test_transform = get_val_test_transform(image_size)
batch_size = 32


dataset = eval(dataset_cfg['NAME'])(
    dataset_cfg['ROOT'] + '/cropped_images_1424x1648',
    dataset_cfg['TRAIN_RATIO'],
    dataset_cfg['VALID_RATIO'],
    dataset_cfg['TEST_RATIO'],
    transform=None
)
trainset, valset, testset = dataset.get_splits()
trainset.transform = train_transform
valset.transform = val_test_transform
testset.transform = val_test_transform



# DataLoader 수정
trainloader = DataLoader(
    trainset, 
    batch_sampler=BalancedBatchSampler(trainset, batch_size=batch_size),
    num_workers=num_workers,
    pin_memory=True
)
# trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True)
valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
testloader = DataLoader(testset, batch_size=1, num_workers=1, pin_memory=True)

device :  cuda:0
/data/nmc/processed_image/cropped_images_1424x1648
(0,)               1935
(3,)                542
(1, 2, 3)           532
(1, 2)              286
(2,)                233
(1, 2, 3, 4)        190
(2, 3)              163
(1,)                155
(4,)                 47
(1, 3)               31
(1, 2, 4)            27
(3, 4)               22
(1, 2, 3, 4, 5)      11
(5,)                  9
(2, 3, 4)             9
(1, 4)                9
(1, 2, 3, 5)          8
(1, 2, 5)             7
(2, 4)                7
(1, 2, 3, 5, 6)       5
(1, 2, 3, 6)          4
(1, 3, 4)             2
(1, 3, 6)             1
(6,)                  1
(1, 2, 6)             1
(1, 2, 3, 4, 6)       1
Name: label, dtype: int64
train size: 4238
(0,)               415
(3,)               118
(1, 2, 3)          113
(1, 2)              65
(2,)                50
(1, 2, 3, 4)        46
(2, 3)              35
(1,)                32
(4,)                13
(1, 3)               7
(5,)                 4
(3, 4)      

In [22]:
# Model definition (changed to binary classification)
model = models.efficientnet_v2_m(pretrained=True)
num_ftrs = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.BatchNorm1d(num_ftrs),
    nn.Linear(num_ftrs, 7)
)
model = model.to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [8]:
# model for swin S
model = models.swin_s(pretrained=True)
num_ftrs = model.head.in_features
model.head = nn.Sequential(
    nn.BatchNorm1d(num_ftrs),
    nn.Linear(num_ftrs, 7)
)
model = model.to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [23]:
# L2 regularization
weight_decay = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=weight_decay)
criterion = nn.BCEWithLogitsLoss()
scaler = GradScaler(enabled=train_cfg['AMP'])
# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)


In [24]:

def train_epoch(model, dataloader, criterion, optimizer, scaler, device):
    model.train()
    total_loss = 0
    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast(enabled=scaler is not None):
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels.float())
        
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [25]:
def evaluate(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            # 각 클래스에 대해 시그모이드 적용 후 임계값 처리
            preds = (torch.sigmoid(outputs) > 0.5).int()
            
            # 배치 단위로 예측값과 라벨 저장
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    
    # 배치 데이터를 하나의 배열로 결합
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    # 멀티라벨 F1 score 계산
    f1 = f1_score(all_labels, all_preds, average='samples')  # or 'micro', 'macro', 'weighted'
    
    return f1

In [26]:
def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, scaler, device, epochs):
    best_f1 = 0.0
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)
    
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_f1 = evaluate(model, val_loader, device)
        
        print(f"Training Loss: {train_loss:.4f}")
        print(f"Validation F1 Score: {val_f1:.4f}")
        
        scheduler.step(val_f1)
        
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), 'model/best_model_nmc.pth')
            print("New best model saved!")
        
        early_stopping(val_f1)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
        
        print()
    
    return best_f1

In [27]:
# Main execution code
# 정규화, lr스케쥴링, 데이터 증강, 조기종료, 배치정규화
epochs = 100
best_f1 = train_and_evaluate(model, trainloader, valloader, criterion, optimizer, scaler, device, epochs)

print(f"Training completed. Best F1 Score: {best_f1:.4f}")

# Final evaluation on test set
model.load_state_dict(torch.load('model/best_model_nmc.pth'))
test_f1 = evaluate(model, testloader, device)
print(f"Test F1 Score: {test_f1:.4f}")

Epoch 1/100


Training:   0%|          | 0/133 [00:00<?, ?it/s]

Training: 100%|██████████| 133/133 [00:53<00:00,  2.49it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.24it/s]


Training Loss: 0.4608
Validation F1 Score: 0.6501
New best model saved!

Epoch 2/100


Training: 100%|██████████| 133/133 [00:58<00:00,  2.27it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.38it/s]


Training Loss: 0.2707
Validation F1 Score: 0.6586
New best model saved!

Epoch 3/100


Training: 100%|██████████| 133/133 [00:53<00:00,  2.51it/s]
Evaluating: 100%|██████████| 914/914 [00:40<00:00, 22.58it/s]


Training Loss: 0.2093
Validation F1 Score: 0.7346
New best model saved!

Epoch 4/100


Training: 100%|██████████| 133/133 [00:44<00:00,  2.96it/s]
Evaluating: 100%|██████████| 914/914 [00:52<00:00, 17.37it/s]


Training Loss: 0.1732
Validation F1 Score: 0.7418
New best model saved!

Epoch 5/100


Training: 100%|██████████| 133/133 [00:44<00:00,  2.97it/s]
Evaluating: 100%|██████████| 914/914 [00:43<00:00, 21.17it/s]


Training Loss: 0.1467
Validation F1 Score: 0.7641
New best model saved!

Epoch 6/100


Training: 100%|██████████| 133/133 [00:52<00:00,  2.53it/s]
Evaluating: 100%|██████████| 914/914 [00:38<00:00, 23.92it/s]


Training Loss: 0.1343
Validation F1 Score: 0.7205

Epoch 7/100


Training: 100%|██████████| 133/133 [00:55<00:00,  2.41it/s]
Evaluating: 100%|██████████| 914/914 [00:38<00:00, 23.64it/s]


Training Loss: 0.1038
Validation F1 Score: 0.7235

Epoch 8/100


Training: 100%|██████████| 133/133 [00:42<00:00,  3.13it/s]
Evaluating: 100%|██████████| 914/914 [00:50<00:00, 18.15it/s]


Training Loss: 0.0937
Validation F1 Score: 0.7694
New best model saved!

Epoch 9/100


Training: 100%|██████████| 133/133 [00:43<00:00,  3.07it/s]
Evaluating: 100%|██████████| 914/914 [00:47<00:00, 19.44it/s]


Training Loss: 0.0900
Validation F1 Score: 0.7438

Epoch 10/100


Training: 100%|██████████| 133/133 [00:48<00:00,  2.74it/s]
Evaluating: 100%|██████████| 914/914 [00:38<00:00, 23.55it/s]


Training Loss: 0.0909
Validation F1 Score: 0.7955
New best model saved!

Epoch 11/100


Training: 100%|██████████| 133/133 [00:58<00:00,  2.27it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.09it/s]


Training Loss: 0.0750
Validation F1 Score: 0.7873

Epoch 12/100


Training: 100%|██████████| 133/133 [00:44<00:00,  2.98it/s]
Evaluating: 100%|██████████| 914/914 [00:49<00:00, 18.64it/s]


Training Loss: 0.0658
Validation F1 Score: 0.8021
New best model saved!

Epoch 13/100


Training: 100%|██████████| 133/133 [00:43<00:00,  3.03it/s]
Evaluating: 100%|██████████| 914/914 [00:51<00:00, 17.70it/s]


Training Loss: 0.0597
Validation F1 Score: 0.8080
New best model saved!

Epoch 14/100


Training: 100%|██████████| 133/133 [00:43<00:00,  3.04it/s]
Evaluating: 100%|██████████| 914/914 [00:38<00:00, 23.64it/s]


Training Loss: 0.0535
Validation F1 Score: 0.8173
New best model saved!

Epoch 15/100


Training: 100%|██████████| 133/133 [01:00<00:00,  2.21it/s]
Evaluating: 100%|██████████| 914/914 [00:38<00:00, 23.46it/s]


Training Loss: 0.0458
Validation F1 Score: 0.8055

Epoch 16/100


Training: 100%|██████████| 133/133 [00:53<00:00,  2.46it/s]
Evaluating: 100%|██████████| 914/914 [00:40<00:00, 22.39it/s]


Training Loss: 0.0464
Validation F1 Score: 0.8253
New best model saved!

Epoch 17/100


Training: 100%|██████████| 133/133 [00:44<00:00,  3.02it/s]
Evaluating: 100%|██████████| 914/914 [00:48<00:00, 18.66it/s]


Training Loss: 0.0484
Validation F1 Score: 0.8062

Epoch 18/100


Training: 100%|██████████| 133/133 [00:43<00:00,  3.03it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.13it/s]


Training Loss: 0.0399
Validation F1 Score: 0.8117

Epoch 19/100


Training: 100%|██████████| 133/133 [00:59<00:00,  2.24it/s]
Evaluating: 100%|██████████| 914/914 [00:37<00:00, 24.51it/s]


Training Loss: 0.0426
Validation F1 Score: 0.8031

Epoch 20/100


Training: 100%|██████████| 133/133 [00:50<00:00,  2.64it/s]
Evaluating: 100%|██████████| 914/914 [00:43<00:00, 21.20it/s]


Training Loss: 0.0419
Validation F1 Score: 0.8097

Epoch 21/100


Training: 100%|██████████| 133/133 [00:43<00:00,  3.02it/s]
Evaluating: 100%|██████████| 914/914 [00:51<00:00, 17.84it/s]


Training Loss: 0.0337
Validation F1 Score: 0.7746

Epoch 22/100


Training: 100%|██████████| 133/133 [00:42<00:00,  3.10it/s]
Evaluating: 100%|██████████| 914/914 [00:37<00:00, 24.23it/s]


Training Loss: 0.0323
Validation F1 Score: 0.8139
Epoch 00022: reducing learning rate of group 0 to 1.0000e-05.

Epoch 23/100


Training: 100%|██████████| 133/133 [00:58<00:00,  2.29it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.36it/s]


Training Loss: 0.0271
Validation F1 Score: 0.8184

Epoch 24/100


Training: 100%|██████████| 133/133 [00:57<00:00,  2.32it/s]
Evaluating: 100%|██████████| 914/914 [00:39<00:00, 23.43it/s]


Training Loss: 0.0231
Validation F1 Score: 0.8178

Epoch 25/100


Training: 100%|██████████| 133/133 [00:44<00:00,  2.96it/s]
Evaluating: 100%|██████████| 914/914 [00:48<00:00, 18.84it/s]


Training Loss: 0.0220
Validation F1 Score: 0.8257
New best model saved!

Epoch 26/100


Training: 100%|██████████| 133/133 [00:44<00:00,  2.98it/s]
Evaluating: 100%|██████████| 914/914 [00:51<00:00, 17.65it/s]


Training Loss: 0.0206
Validation F1 Score: 0.8217
Early stopping triggered
Training completed. Best F1 Score: 0.8257


Evaluating: 100%|██████████| 902/902 [00:50<00:00, 17.85it/s]

Test F1 Score: 0.8321





In [28]:
def evaluate(model, dataloader, device, num_classes):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            # 각 클래스에 대해 시그모이드 적용 후 임계값 처리
            preds = (torch.sigmoid(outputs) > 0.5).int()
            
            # 배치 단위로 예측값과 라벨 저장
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    
    # 배치 데이터를 하나의 배열로 결합
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    # 전체 F1 score 계산
    overall_f1 = f1_score(all_labels, all_preds, average='samples')
    
    # 각 클래스별 F1 score 계산
    class_f1_scores = f1_score(all_labels, all_preds, average=None)
    
    # 각 클래스별 정밀도(Precision)와 재현율(Recall) 계산
    class_precision = precision_score(all_labels, all_preds, average=None)
    class_recall = recall_score(all_labels, all_preds, average=None)
    
    # 결과를 딕셔너리로 정리
    results = {
        'overall_f1': overall_f1,
        'class_f1_scores': class_f1_scores,
        'class_precision': class_precision,
        'class_recall': class_recall
    }
    
    # 각 클래스별 메트릭 출력
    print("\nPer-class Performance Metrics:")
    print("-" * 50)
    for i in range(num_classes):
        print(f"Class {i}:")
        print(f"  F1-Score: {class_f1_scores[i]:.4f}")
        print(f"  Precision: {class_precision[i]:.4f}")
        print(f"  Recall: {class_recall[i]:.4f}")
    print("-" * 50)
    print(f"Overall F1-Score: {overall_f1:.4f}")
    
    return results

In [29]:
model.load_state_dict(torch.load('model/best_model_nmc.pth'))
test_f1 = evaluate(model, testloader, device, 7)

Evaluating: 100%|██████████| 902/902 [00:38<00:00, 23.31it/s]



Per-class Performance Metrics:
--------------------------------------------------
Class 0:
  F1-Score: 0.8984
  Precision: 0.9129
  Recall: 0.8843
Class 1:
  F1-Score: 0.8358
  Precision: 0.8485
  Recall: 0.8235
Class 2:
  F1-Score: 0.8671
  Precision: 0.8726
  Recall: 0.8616
Class 3:
  F1-Score: 0.8563
  Precision: 0.8726
  Recall: 0.8405
Class 4:
  F1-Score: 0.6102
  Precision: 0.7500
  Recall: 0.5143
Class 5:
  F1-Score: 0.6667
  Precision: 0.8333
  Recall: 0.5556
Class 6:
  F1-Score: 0.5000
  Precision: 1.0000
  Recall: 0.3333
--------------------------------------------------
Overall F1-Score: 0.8321
