In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.v2 as transforms_v2  # CutMix를 위한 v2 transforms 추가
import sys
import os
import torch
import time
import random
import numpy as np
import wandb
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from tools.tool import EarlyStopping
from models.resnet import resnet18, resnet34, resnet50

# Experiment configuration; logged to WandB below.
config = {
    "model": "resnet18",
    "batch_size": 128,
    "num_epochs": 100,
    "learning_rate": 0.001,
    "optimizer": "Adam",
    "seed": 2025,
    "deterministic": False,
    "patience": 10,  # early stopping patience (epochs without val-loss improvement)
    "train_ratio": 0.8,
    "val_ratio": 0.1,
    "test_ratio": 0.1,
    "cutmix_alpha": 1.0,  # alpha passed to transforms_v2.CutMix
    "cutmix_prob": 0.5,   # probability that a given training batch is CutMix-ed
}

# Apply the configured seed to every RNG the script uses: `random` (CutMix
# on/off draw in train()), NumPy, and torch (weight init, DataLoader shuffling,
# CutMix box sampling). Previously the seed was only logged, not applied.
random.seed(config["seed"])
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])  # also seeds all CUDA devices
if config["deterministic"]:
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Authenticate without embedding a secret in the notebook: wandb.login() with no
# key reads the WANDB_API_KEY environment variable or the ~/.netrc entry.
# NOTE(review): the API key that used to be hard-coded here was committed in
# plain text and should be revoked/rotated.
wandb.login()
# Run name fixed (it previously ended in a stray ")"); config is attached at init.
wandb.init(project="PBL-2", name="resnet18_cutmix", config=config)

# Per-channel CIFAR-100 statistics shared by every preprocessing pipeline.
_CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
_CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Training preprocessing: tensor conversion + normalisation only. No per-image
# augmentation here; CutMix is applied to whole batches inside train().
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
    ]
)

# Evaluation preprocessing (currently identical to the training pipeline).
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
    ]
)

_DATA_ROOT = './data'

# Official CIFAR-100 training split; it is further split into train/val/test below.
full_trainset = torchvision.datasets.CIFAR100(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train
)

# Official CIFAR-100 test split. NOTE(review): it is downloaded but the loaders
# below evaluate on a stratified hold-out of the training split instead.
testset = torchvision.datasets.CIFAR100(
    root=_DATA_ROOT, train=False, download=True, transform=transform_test
)

# Stratified train/val/test split of the official CIFAR-100 *training* split,
# using the ratios from config (8:1:1 by default). Stratification keeps the
# class distribution the same across the three subsets.
targets = np.array(full_trainset.targets)

# Derive split sizes from config instead of hard-coding 0.2 / 0.5.
holdout_ratio = config["val_ratio"] + config["test_ratio"]
if abs(config["train_ratio"] + holdout_ratio - 1.0) > 1e-9:
    raise ValueError(
        "train/val/test ratios must sum to 1, got "
        f"{config['train_ratio']}, {config['val_ratio']}, {config['test_ratio']}"
    )
# Fraction of the hold-out that becomes the test split (0.1 / 0.2 = 0.5 by default).
test_share_of_holdout = config["test_ratio"] / holdout_ratio

# Step 1: train vs. hold-out (val + test).
train_val_split = StratifiedShuffleSplit(n_splits=1, test_size=holdout_ratio, random_state=config["seed"])
train_idx, temp_idx = next(train_val_split.split(np.zeros(len(targets)), targets))

# Step 2: hold-out -> validation vs. test, stratified on the hold-out labels.
val_test_targets = targets[temp_idx]
val_test_split = StratifiedShuffleSplit(n_splits=1, test_size=test_share_of_holdout, random_state=config["seed"])
val_idx_temp, test_idx_temp = next(val_test_split.split(np.zeros(len(val_test_targets)), val_test_targets))

# Step 2 returned positions inside the hold-out; map them back to indices into full_trainset.
val_idx = temp_idx[val_idx_temp]
test_idx = temp_idx[test_idx_temp]

# NOTE(review): all three subsets share full_trainset's transform_train; this is
# only safe while transform_train contains no random augmentation.
trainset = Subset(full_trainset, train_idx)
valset = Subset(full_trainset, val_idx)
# The stratified hold-out is used as the test set instead of the official test split.
testset_split = Subset(full_trainset, test_idx)

# Cap worker processes at the available CPUs (previously a fixed 16).
num_workers = min(16, os.cpu_count() or 1)
trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True, num_workers=num_workers)
valloader = DataLoader(valset, batch_size=config["batch_size"], shuffle=False, num_workers=num_workers)
testloader = DataLoader(testset_split, batch_size=config["batch_size"], shuffle=False, num_workers=num_workers)

print(f"Train set size: {len(trainset)}")
print(f"Validation set size: {len(valset)}")
print(f"Test set size: {len(testset_split)}")

# Batch-level CutMix transform from torchvision.transforms.v2. train() calls it
# on whole batches, and it turns the integer labels into mixed soft labels.
cutmix = transforms_v2.CutMix(num_classes=100, alpha=config["cutmix_alpha"])  # CIFAR-100 has 100 classes

# Loss for CutMix batches, whose labels are soft per-class target vectors.
def cutmix_criterion(outputs, targets):
    """Cross-entropy between model logits and (possibly soft) targets.

    CutMix does not produce one-hot labels. It produces mixed labels whose mass
    is split between two classes in proportion to the pasted area. PyTorch's
    ``cross_entropy`` accepts such class-probability targets directly, and also
    still accepts plain integer class indices.

    Args:
        outputs: raw model logits.
        targets: class-probability targets with the same shape as ``outputs``
            (e.g. CutMix output), or integer class indices.

    Returns:
        Scalar tensor: batch-mean cross-entropy.
    """
    soft_cross_entropy = torch.nn.functional.cross_entropy
    return soft_cross_entropy(outputs, targets)

def train(model, trainloader, criterion, optimizer, device, epoch, *, cutmix_fn=None, cutmix_prob=None):
    """Run one training epoch, applying CutMix to a random subset of batches.

    Args:
        model: network to optimise; switched to train mode.
        trainloader: iterable of ``(inputs, labels)`` batches with integer class
            labels; must support ``len()`` (used for progress and averaging).
        criterion: loss used for batches that are *not* CutMix-ed.
        optimizer: optimiser stepped once per batch.
        device: device each batch is moved to.
        epoch: zero-based epoch index, used only for printed logs.
        cutmix_fn: callable ``(inputs, labels) -> (inputs, soft_labels)``;
            defaults to the module-level ``cutmix`` transform.
        cutmix_prob: per-batch probability of applying CutMix; defaults to
            ``config["cutmix_prob"]``.

    Returns:
        ``(epoch_loss, top1_percent, top5_percent)``. ``epoch_loss`` is the
        mean of the per-batch losses. On CutMix batches, accuracy is measured
        against the dominant (argmax) class of the mixed label.
    """
    if cutmix_prob is None:
        cutmix_prob = config["cutmix_prob"]

    model.train()
    start_time = time.time()
    running_loss = 0.0
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # One uniform draw per batch decides whether CutMix is applied.
        use_cutmix = random.random() < cutmix_prob
        if use_cutmix:
            mix = cutmix if cutmix_fn is None else cutmix_fn
            # After mixing, labels are per-class soft targets rather than indices.
            inputs, labels = mix(inputs, labels)

        optimizer.zero_grad()
        outputs = model(inputs)

        if use_cutmix:
            # Soft-target cross-entropy; accuracy uses the dominant mixed class.
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            label_idx = labels.argmax(1)
        else:
            loss = criterion(outputs, labels)
            label_idx = labels

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Top-1 / top-5 accuracy, fully vectorised (the former per-sample
        # Python loop forced a device sync for every sample).
        _, predicted = outputs.max(1)
        total += inputs.size(0)
        correct_top1 += predicted.eq(label_idx).sum().item()

        _, top5_idx = outputs.topk(5, 1, largest=True, sorted=True)
        correct_top5 += top5_idx.eq(label_idx.view(-1, 1).expand_as(top5_idx)).sum().item()

        if (i + 1) % 50 == 0:  # progress every 50 batches
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(trainloader)}], Loss: {loss.item():.4f}')

    epoch_loss = running_loss / len(trainloader)
    accuracy_top1 = 100.0 * correct_top1 / total
    accuracy_top5 = 100.0 * correct_top5 / total

    train_time = time.time() - start_time

    print(f'Train set: Epoch: {epoch+1}, Average loss:{epoch_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f} '
          f'Top-1 Accuracy: {accuracy_top1:.4f}%, Top-5 Accuracy: {accuracy_top5:.4f}%, Time consumed:{train_time:.2f}s')

    return epoch_loss, accuracy_top1, accuracy_top5

def evaluate(model, dataloader, criterion, device, epoch, phase="val"):
    """Evaluate *model* on *dataloader*, reporting loss and top-1/top-5 accuracy.

    Runs in eval mode under ``torch.no_grad()`` and prints a one-line summary.

    Args:
        model: network to evaluate; switched to (and left in) eval mode.
        dataloader: iterable of ``(inputs, labels)`` batches with integer labels.
        criterion: per-batch loss, assumed to use mean reduction (the default of
            ``nn.CrossEntropyLoss``) so it can be re-weighted by batch size.
        device: device each batch is moved to.
        epoch: zero-based epoch index, used only in the printed summary.
        phase: label for the printed summary, e.g. ``"val"`` or ``"test"``.

    Returns:
        ``(avg_loss, top1_percent, top5_percent)``. ``avg_loss`` is the mean
        loss per *sample* over the whole dataset, so a short final batch is no
        longer weighted like a full one.

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    model.eval()
    start_time = time.time()

    loss_sum = 0.0
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            batch_size = labels.size(0)
            # criterion returns the batch mean; scale it back to a sum so the
            # final division yields a true per-sample average.
            loss_sum += criterion(outputs, labels).item() * batch_size
            total += batch_size

            _, predicted = outputs.max(1)
            correct_top1 += (predicted == labels).sum().item()

            _, top5_idx = outputs.topk(5, 1, largest=True, sorted=True)
            correct_top5 += top5_idx.eq(labels.view(-1, 1).expand_as(top5_idx)).sum().item()

    if total == 0:
        raise ValueError(f"{phase} dataloader yielded no samples")

    eval_loss = loss_sum / total
    accuracy_top1 = 100.0 * correct_top1 / total
    accuracy_top5 = 100.0 * correct_top5 / total

    eval_time = time.time() - start_time

    print(f'{phase.capitalize()} set: Epoch: {epoch+1}, Average loss:{eval_loss:.4f}, '
          f'Top-1 Accuracy: {accuracy_top1:.4f}%, Top-5 Accuracy: {accuracy_top5:.4f}%, Time consumed:{eval_time:.2f}s')
    print()

    return eval_loss, accuracy_top1, accuracy_top5


# Main training loop
def main_training_loop(model, trainloader, valloader, testloader, criterion, optimizer, device, num_epochs, patience):
    """Train with per-epoch validation and early stopping, then run a final test pass.

    Each epoch runs ``train`` and then ``evaluate`` on *valloader*, and logs the
    metrics to WandB. When validation top-1 accuracy improves, it saves a
    checkpoint and uploads it with ``wandb.save``. It also feeds validation loss
    to ``EarlyStopping``.
    After the loop, the best-top-1 checkpoint written here is reloaded,
    whether or not early stopping fired, and evaluated on *testloader*. Test
    metrics and best validation metrics are written to the WandB log/summary.

    Args:
        model: network to train (may be wrapped in ``nn.DataParallel``).
        trainloader: training batches.
        valloader: validation batches (model selection and early stopping).
        testloader: held-out batches for the final evaluation.
        criterion: loss passed to ``train`` and ``evaluate``.
        optimizer: optimiser stepped by ``train``.
        device: device batches, and the reloaded checkpoint, are mapped to.
        num_epochs: maximum number of epochs.
        patience: EarlyStopping patience, in epochs.
    """
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    model_path = f'best_model_{wandb.run.name}.pth'
    checkpoint_saved = False

    best_acc_top1 = 0.0
    best_acc_top5 = 0.0
    # Index of the last epoch actually run. It is used after the loop, where
    # `epoch` would be undefined if num_epochs == 0.
    last_epoch = -1

    for epoch in tqdm(range(num_epochs)):
        last_epoch = epoch

        train_loss, train_acc_top1, train_acc_top5 = train(model, trainloader, criterion, optimizer, device, epoch)
        val_loss, val_acc_top1, val_acc_top5 = evaluate(model, valloader, criterion, device, epoch, phase="val")

        wandb.log({
            "epoch": epoch + 1,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "train_loss": train_loss,
            "train_accuracy_top1": train_acc_top1,
            "train_accuracy_top5": train_acc_top5,
            "val_loss": val_loss,
            "val_accuracy_top1": val_acc_top1,
            "val_accuracy_top5": val_acc_top5
        })

        # Checkpoint whenever validation top-1 accuracy improves.
        if val_acc_top1 > best_acc_top1:
            best_acc_top1 = val_acc_top1
            best_acc_top5_at_best_top1 = val_acc_top5
            print(f'New best top-1 accuracy: {best_acc_top1:.2f}%, top-5 accuracy: {best_acc_top5_at_best_top1:.2f}%')
            torch.save(model.state_dict(), model_path)
            checkpoint_saved = True
            wandb.save(model_path)

        # Track best validation top-5 accuracy independently (reporting only).
        if val_acc_top5 > best_acc_top5:
            best_acc_top5 = val_acc_top5
            print(f'New best top-5 accuracy: {best_acc_top5:.2f}%')

        # Early stopping is driven by validation loss.
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered. Training stopped.")
            break

    # Reload the best checkpoint in BOTH cases. Previously the early-stopping
    # branch only printed a message and tested the last-epoch weights.
    # NOTE(review): EarlyStopping also saves its own checkpoint on val-loss
    # improvement, but its path lives in tools.tool and is not visible here, so
    # the best-val-accuracy checkpoint written above is used instead.
    if checkpoint_saved:
        print("Loading best model based on validation accuracy...")
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    else:
        print("No checkpoint was saved; evaluating the current model weights...")

    # Final evaluation, labelled with the epoch training actually stopped at.
    test_loss, test_acc_top1, test_acc_top5 = evaluate(model, testloader, criterion, device, last_epoch, phase="test")

    # Log test metrics alongside the last (or early-stopped) epoch.
    wandb.log({
        "epoch": last_epoch + 1,
        "test_loss": test_loss,
        "test_accuracy_top1": test_acc_top1,
        "test_accuracy_top5": test_acc_top5
    })

    print(f'Finish! Best validation top-1 accuracy: {best_acc_top1:.2f}%, Best validation top-5 accuracy: {best_acc_top5:.2f}%')
    print(f'Final test top-1 accuracy: {test_acc_top1:.2f}%, Final test top-5 accuracy: {test_acc_top5:.2f}%')

    wandb.run.summary["best_val_accuracy_top1"] = best_acc_top1
    wandb.run.summary["best_val_accuracy_top5"] = best_acc_top5
    wandb.run.summary["test_accuracy_top1"] = test_acc_top1
    wandb.run.summary["test_accuracy_top5"] = test_acc_top5

    # Record whether, and at which epoch, early stopping fired.
    if early_stopping.early_stop:
        wandb.run.summary["early_stopped"] = True
        wandb.run.summary["early_stopped_epoch"] = last_epoch + 1
    else:
        wandb.run.summary["early_stopped"] = False


# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model, loss and optimiser
model = resnet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])

# Have WandB track the model's gradients and parameters.
wandb.watch(model, log="all")

# Replicate across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
    model = nn.DataParallel(model)

start_time = time.time()

try:
    main_training_loop(
        model=model,
        trainloader=trainloader,
        valloader=valloader,
        testloader=testloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=config["num_epochs"],
        patience=config["patience"]
    )
except BaseException:
    # Includes KeyboardInterrupt from the notebook: mark the run as failed and
    # close it before propagating, instead of leaving it open.
    wandb.finish(exit_code=1)
    raise

end_time = time.time()
total_time = end_time - start_time
wandb.log({"total_training_time": total_time})

print(f"Total training time: {total_time:.2f} seconds")

wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/guswls/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msokjh1310[0m ([33msokjh1310-hanyang-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Files already downloaded and verified
Files already downloaded and verified
Train set size: 40000
Validation set size: 5000
Test set size: 5000
Using device: cuda
2개의 GPU를 사용합니다.


  0%|                                                                                                       | 0/100 [00:00<?, ?it/s]

Epoch [1], Batch [50/313], Loss: 4.3788
Epoch [1], Batch [100/313], Loss: 4.3641
Epoch [1], Batch [150/313], Loss: 3.8305
Epoch [1], Batch [200/313], Loss: 4.2847
Epoch [1], Batch [250/313], Loss: 3.6260
Epoch [1], Batch [300/313], Loss: 3.6727
Train set: Epoch: 1, Average loss:4.0287, LR: 0.001000 Top-1 Accuracy: 9.8200%, Top-5 Accuracy: 29.0200%, Time consumed:39.15s
Val set: Epoch: 1, Average loss:3.4809, Top-1 Accuracy: 15.7000%, Top-5 Accuracy: 42.5800%, Time consumed:4.57s

New best top-1 accuracy: 15.70%, top-5 accuracy: 42.58%
New best top-5 accuracy: 42.58%
Validation loss decreased (inf --> 3.480859). Saving model ...


  1%|▉                                                                                            | 1/100 [00:44<1:12:36, 44.00s/it]

Epoch [2], Batch [50/313], Loss: 4.1062
Epoch [2], Batch [100/313], Loss: 3.7713
Epoch [2], Batch [150/313], Loss: 4.2300
Epoch [2], Batch [200/313], Loss: 3.0753
Epoch [2], Batch [250/313], Loss: 2.6882
Epoch [2], Batch [300/313], Loss: 2.8606
Train set: Epoch: 2, Average loss:3.4741, LR: 0.001000 Top-1 Accuracy: 19.4725%, Top-5 Accuracy: 46.7300%, Time consumed:39.74s
Val set: Epoch: 2, Average loss:2.8862, Top-1 Accuracy: 26.4400%, Top-5 Accuracy: 59.2200%, Time consumed:4.57s

New best top-1 accuracy: 26.44%, top-5 accuracy: 59.22%
New best top-5 accuracy: 59.22%
Validation loss decreased (3.480859 --> 2.886189). Saving model ...


  2%|█▊                                                                                           | 2/100 [01:28<1:12:24, 44.33s/it]

Epoch [3], Batch [50/313], Loss: 2.5937
Epoch [3], Batch [100/313], Loss: 2.9314
Epoch [3], Batch [150/313], Loss: 3.2919
Epoch [3], Batch [200/313], Loss: 2.5725
Epoch [3], Batch [250/313], Loss: 2.4483
Epoch [3], Batch [300/313], Loss: 2.8945
Train set: Epoch: 3, Average loss:3.0831, LR: 0.001000 Top-1 Accuracy: 28.1750%, Top-5 Accuracy: 58.3200%, Time consumed:39.28s
Val set: Epoch: 3, Average loss:2.5775, Top-1 Accuracy: 32.9400%, Top-5 Accuracy: 65.7200%, Time consumed:4.66s

New best top-1 accuracy: 32.94%, top-5 accuracy: 65.72%
New best top-5 accuracy: 65.72%
Validation loss decreased (2.886189 --> 2.577469). Saving model ...


  3%|██▊                                                                                          | 3/100 [02:12<1:11:35, 44.28s/it]

Epoch [4], Batch [50/313], Loss: 3.1261
Epoch [4], Batch [100/313], Loss: 2.2756
Epoch [4], Batch [150/313], Loss: 3.3604
Epoch [4], Batch [200/313], Loss: 1.6864
Epoch [4], Batch [250/313], Loss: 2.0429
Epoch [4], Batch [300/313], Loss: 2.0292
Train set: Epoch: 4, Average loss:2.7348, LR: 0.001000 Top-1 Accuracy: 35.3275%, Top-5 Accuracy: 66.2675%, Time consumed:39.00s
Val set: Epoch: 4, Average loss:2.2645, Top-1 Accuracy: 40.2000%, Top-5 Accuracy: 72.7400%, Time consumed:4.74s

New best top-1 accuracy: 40.20%, top-5 accuracy: 72.74%
New best top-5 accuracy: 72.74%
Validation loss decreased (2.577469 --> 2.264508). Saving model ...


  4%|███▋                                                                                         | 4/100 [02:56<1:10:41, 44.18s/it]

Epoch [5], Batch [50/313], Loss: 1.8189
Epoch [5], Batch [100/313], Loss: 3.6395
Epoch [5], Batch [150/313], Loss: 1.6487
Epoch [5], Batch [200/313], Loss: 1.9601
Epoch [5], Batch [250/313], Loss: 1.7333
Epoch [5], Batch [300/313], Loss: 3.7499
Train set: Epoch: 5, Average loss:2.4645, LR: 0.001000 Top-1 Accuracy: 42.5075%, Top-5 Accuracy: 72.7300%, Time consumed:38.37s
Val set: Epoch: 5, Average loss:2.0878, Top-1 Accuracy: 44.0600%, Top-5 Accuracy: 76.2800%, Time consumed:4.60s

New best top-1 accuracy: 44.06%, top-5 accuracy: 76.28%
New best top-5 accuracy: 76.28%
Validation loss decreased (2.264508 --> 2.087789). Saving model ...


  5%|████▋                                                                                        | 5/100 [03:40<1:09:24, 43.84s/it]

Epoch [6], Batch [50/313], Loss: 1.6278
Epoch [6], Batch [100/313], Loss: 3.5830
Epoch [6], Batch [150/313], Loss: 3.4693
Epoch [6], Batch [200/313], Loss: 1.4417
Epoch [6], Batch [250/313], Loss: 1.5466
Epoch [6], Batch [300/313], Loss: 3.0426
Train set: Epoch: 6, Average loss:2.2061, LR: 0.001000 Top-1 Accuracy: 48.2200%, Top-5 Accuracy: 77.5925%, Time consumed:40.61s
Val set: Epoch: 6, Average loss:1.9221, Top-1 Accuracy: 48.4000%, Top-5 Accuracy: 79.1000%, Time consumed:4.70s

New best top-1 accuracy: 48.40%, top-5 accuracy: 79.10%
New best top-5 accuracy: 79.10%
Validation loss decreased (2.087789 --> 1.922123). Saving model ...


  6%|█████▌                                                                                       | 6/100 [04:25<1:09:35, 44.42s/it]

Epoch [7], Batch [50/313], Loss: 3.4184
Epoch [7], Batch [100/313], Loss: 3.3505
Epoch [7], Batch [150/313], Loss: 1.2417
Epoch [7], Batch [200/313], Loss: 1.3599
Epoch [7], Batch [250/313], Loss: 2.8352
Epoch [7], Batch [300/313], Loss: 2.0834
Train set: Epoch: 7, Average loss:2.2507, LR: 0.001000 Top-1 Accuracy: 49.8150%, Top-5 Accuracy: 78.1175%, Time consumed:38.77s
Val set: Epoch: 7, Average loss:1.7858, Top-1 Accuracy: 51.8800%, Top-5 Accuracy: 82.0600%, Time consumed:5.30s

New best top-1 accuracy: 51.88%, top-5 accuracy: 82.06%
New best top-5 accuracy: 82.06%
Validation loss decreased (1.922123 --> 1.785751). Saving model ...


  7%|██████▌                                                                                      | 7/100 [05:09<1:08:48, 44.40s/it]

Epoch [8], Batch [50/313], Loss: 1.2304
Epoch [8], Batch [100/313], Loss: 3.2994
Epoch [8], Batch [150/313], Loss: 3.0182
Epoch [8], Batch [200/313], Loss: 3.6045
Epoch [8], Batch [250/313], Loss: 1.1144
Epoch [8], Batch [300/313], Loss: 1.1483
Train set: Epoch: 8, Average loss:1.9045, LR: 0.001000 Top-1 Accuracy: 58.6000%, Top-5 Accuracy: 84.7175%, Time consumed:41.17s


  8%|███████▍                                                                                     | 8/100 [05:55<1:08:48, 44.88s/it]

Val set: Epoch: 8, Average loss:1.8569, Top-1 Accuracy: 50.7600%, Top-5 Accuracy: 80.1400%, Time consumed:4.73s

EarlyStopping 카운터: 1 / 10
Epoch [9], Batch [50/313], Loss: 0.8296
Epoch [9], Batch [100/313], Loss: 0.9050
Epoch [9], Batch [150/313], Loss: 1.0436
Epoch [9], Batch [200/313], Loss: 1.0836
Epoch [9], Batch [250/313], Loss: 1.0492
Epoch [9], Batch [300/313], Loss: 1.0139
Train set: Epoch: 9, Average loss:1.7874, LR: 0.001000 Top-1 Accuracy: 61.5225%, Top-5 Accuracy: 85.6975%, Time consumed:39.23s


  9%|████████▎                                                                                    | 9/100 [06:40<1:07:44, 44.66s/it]

Val set: Epoch: 9, Average loss:1.7823, Top-1 Accuracy: 51.5200%, Top-5 Accuracy: 82.5000%, Time consumed:4.79s

New best top-5 accuracy: 82.50%
Validation loss decreased (1.785751 --> 1.782291). Saving model ...
Epoch [10], Batch [50/313], Loss: 1.4383
Epoch [10], Batch [100/313], Loss: 1.1799
