In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import sys
import os
import torch
import time
import random
import numpy as np
import wandb
from tqdm import tqdm

from models.resnet import resnet18, resnet34, resnet50

# Authenticate with Weights & Biases.
# SECURITY: the API key must never be committed to source. With no explicit
# key, wandb.login() resolves credentials from the WANDB_API_KEY environment
# variable or from the entry already stored in ~/.netrc.
# The key that used to be hard-coded here should be revoked and regenerated.
wandb.login()

# Start the run that all later wandb.log / wandb.config calls attach to.
# NOTE(review): the run name mentions "cutmix", but no CutMix augmentation is
# visible in this file -- confirm the name matches the experiment.
wandb.init(project="PBL-2", name="resnet18-cutmix")


"""
# 학습 재현성 고정
def fix_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True  # 성능 향상을 위해 True로 변경

deterministic=True와 benchmark=False는 확실히 학습 속도를 저하시킬 수 있습니다.
특히 torch.backends.cudnn.benchmark=False는 CUDA가 최적의 알고리즘을 찾기 위한 
벤치마킹을 수행하지 않게 만들어 성능이 떨어질 수 있습니다.
# 속도 우선 설정 -> 완벽한 재현성은 보장되지 않음 

fix_seed(2025, deterministic=False)
"""

# Hyperparameters and run settings; mirrored into the W&B run config below.
config = dict(
    model="resnet18",
    batch_size=128,
    num_epochs=100,
    learning_rate=0.001,
    optimizer="Adam",
    seed=2025,
    deterministic=False,
)
# Record the settings on the active W&B run.
wandb.config.update(config)

# Load the CIFAR-100 dataset.
# Per-channel mean / std handed to Normalize (presumably the CIFAR-100
# training-set statistics -- TODO confirm).
_NORM_MEAN = (0.5071, 0.4867, 0.4408)
_NORM_STD = (0.2675, 0.2565, 0.2761)

# Train and test currently share the same preprocessing: tensor conversion
# followed by normalization. They are kept as two separate objects.
transform_train = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]
)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]
)

# Settings shared by both loaders.
_loader_kwargs = {"batch_size": config["batch_size"], "num_workers": 16}

# Training split, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = DataLoader(trainset, shuffle=True, **_loader_kwargs)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = DataLoader(testset, shuffle=False, **_loader_kwargs)

def train(model, trainloader, criterion, optimizer, device, epoch):
    """Train ``model`` for one epoch and report loss and accuracy.

    Args:
        model: Network to optimize; switched to training mode here.
        trainloader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. The epoch
            average assumes it returns the per-batch *mean* loss (the
            default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index, used only for log messages.

    Returns:
        Tuple ``(epoch_loss, accuracy_top1, accuracy_top5)``: the
        sample-weighted mean loss and the top-1 / top-5 accuracies in
        percent. All three are NaN if the loader yields no samples.
    """
    model.train()  # enable training-mode behaviour (dropout, batch norm)
    start_time = time.time()
    loss_sum = 0.0  # sum of per-sample losses across the epoch
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch mean by its size so that a smaller final batch
        # does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        total += batch_size

        # Top-1 accuracy.
        _, predicted = outputs.max(1)
        correct_top1 += predicted.eq(labels).sum().item()

        # Top-5 accuracy; k is capped so models with fewer than five output
        # classes do not make topk() raise.
        k = min(5, outputs.size(1))
        _, top5_idx = outputs.topk(k, 1, largest=True, sorted=True)
        correct_top5 += top5_idx.eq(labels.view(-1, 1).expand_as(top5_idx)).sum().item()

        if (i + 1) % 50 == 0:  # progress line every 50 batches
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(trainloader)}], Loss: {loss.item():.4f}')

    if total:
        epoch_loss = loss_sum / total
        accuracy_top1 = 100.0 * correct_top1 / total
        accuracy_top5 = 100.0 * correct_top5 / total
    else:
        # No samples were seen: report NaN instead of dividing by zero.
        epoch_loss = accuracy_top1 = accuracy_top5 = float("nan")

    train_time = time.time() - start_time

    # Epoch summary for the training set.
    print(f'Train set: Epoch: {epoch+1}, Average loss:{epoch_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f} '
          f'Top-1 Accuracy: {accuracy_top1:.4f}%, Top-5 Accuracy: {accuracy_top5:.4f}%, Time consumed:{train_time:.2f}s')

    return epoch_loss, accuracy_top1, accuracy_top5

def evaluate(model, testloader, criterion, device, epoch):
    """Evaluate ``model`` on ``testloader`` without updating its weights.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        testloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. The average
            assumes it returns the per-batch *mean* loss (the default
            reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index, used only for log messages.

    Returns:
        Tuple ``(test_loss, accuracy_top1, accuracy_top5)``: the
        sample-weighted mean loss and the top-1 / top-5 accuracies in
        percent. All three are NaN if the loader yields no samples.
    """
    model.eval()  # enable evaluation-mode behaviour (dropout, batch norm)
    start_time = time.time()

    loss_sum = 0.0  # sum of per-sample losses across the whole set
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass.
            outputs = model(inputs)

            # Weight each batch mean by its size so that a small final batch
            # is not over-represented in the reported average.
            batch_size = labels.size(0)
            loss = criterion(outputs, labels)
            loss_sum += loss.item() * batch_size
            total += batch_size

            # Top-1 accuracy.
            _, predicted = outputs.max(1)
            correct_top1 += (predicted == labels).sum().item()

            # Top-5 accuracy; k is capped so models with fewer than five
            # output classes do not make topk() raise.
            k = min(5, outputs.size(1))
            _, top5_idx = outputs.topk(k, 1, largest=True, sorted=True)
            correct_top5 += top5_idx.eq(labels.view(-1, 1).expand_as(top5_idx)).sum().item()

    if total:
        test_loss = loss_sum / total
        accuracy_top1 = 100.0 * correct_top1 / total
        accuracy_top5 = 100.0 * correct_top5 / total
    else:
        # No samples were seen: report NaN instead of dividing by zero.
        test_loss = accuracy_top1 = accuracy_top5 = float("nan")

    eval_time = time.time() - start_time

    # Summary for the test set.
    print(f'Test set: Epoch: {epoch+1}, Average loss:{test_loss:.4f}, '
          f'Top-1 Accuracy: {accuracy_top1:.4f}%, Top-5 Accuracy: {accuracy_top5:.4f}%, Time consumed:{eval_time:.2f}s')
    print()

    return test_loss, accuracy_top1, accuracy_top5


# 메인 학습 루프
def main_training_loop(model, trainloader, testloader, criterion, optimizer, device, num_epochs=10):
    """Run the train / evaluate cycle, log to W&B, and checkpoint the best model.

    Each epoch calls ``train`` then ``evaluate``, logs both sets of metrics
    to the active W&B run, and saves the model weights whenever the test
    top-1 accuracy improves.

    Args:
        model: Network to train; may be wrapped in ``nn.DataParallel``.
        trainloader: Loader passed to ``train``.
        testloader: Loader passed to ``evaluate``.
        criterion: Loss callable shared by training and evaluation.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.
    """
    best_acc_top1 = 0.0
    best_acc_top5 = 0.0

    # tqdm shows overall progress across epochs.
    for epoch in tqdm(range(num_epochs)):
        # Train for one epoch.
        train_loss, train_acc_top1, train_acc_top5 = train(model, trainloader, criterion, optimizer, device, epoch)

        # Evaluate on the test set.
        test_loss, test_acc_top1, test_acc_top5 = evaluate(model, testloader, criterion, device, epoch)

        # Log this epoch's metrics to W&B.
        wandb.log({
            "epoch": epoch + 1,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "train_loss": train_loss,
            "train_accuracy_top1": train_acc_top1,
            "train_accuracy_top5": train_acc_top5,
            "test_loss": test_loss,
            "test_accuracy_top1": test_acc_top1,
            "test_accuracy_top5": test_acc_top5
        })

        # Checkpoint the model with the best top-1 test accuracy so far.
        if test_acc_top1 > best_acc_top1:
            best_acc_top1 = test_acc_top1
            best_acc_top5_at_best_top1 = test_acc_top5
            print(f'New best top-1 accuracy: {best_acc_top1:.2f}%, top-5 accuracy: {best_acc_top5_at_best_top1:.2f}%')

            # Save the weights of the underlying network. A (Distributed)
            # DataParallel wrapper prefixes every state_dict key with
            # "module.", which would make the checkpoint fail to load into
            # a plain, unwrapped model.
            is_wrapped = isinstance(
                model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
            )
            net = model.module if is_wrapped else model
            model_path = f'best_model_{wandb.run.name}.pth'
            torch.save(net.state_dict(), model_path)

            # Upload the checkpoint file to the W&B run.
            wandb.save(model_path)

        # Track the best top-5 accuracy independently of top-1.
        if test_acc_top5 > best_acc_top5:
            best_acc_top5 = test_acc_top5
            print(f'New best top-5 accuracy: {best_acc_top5:.2f}%')

    print(f'Finish! Best top-1 accuracy: {best_acc_top1:.2f}%, Best top-5 accuracy: {best_acc_top5:.2f}%')
    wandb.run.summary["best_accuracy_top1"] = best_acc_top1
    wandb.run.summary["best_accuracy_top5"] = best_acc_top5

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the model, loss function, and optimizer.
model = resnet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])

# Let W&B hook into the model (log="all" is the mode requested here).
wandb.watch(model, log="all")

# Wrap the model for multi-GPU data parallelism when several GPUs exist.
if torch.cuda.device_count() > 1:
    print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
    model = nn.DataParallel(model)

# Wall-clock start of training.
start_time = time.time()

# Run the main training loop. If it fails or is interrupted, close the W&B
# run as failed before re-raising; otherwise the run would stay open in
# this session because the final wandb.finish() would never be reached.
try:
    main_training_loop(
        model=model,
        trainloader=trainloader,
        testloader=testloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=config["num_epochs"]
    )
except BaseException:
    wandb.finish(exit_code=1)
    raise

# Record and report the total training time.
end_time = time.time()
total_time = end_time - start_time
wandb.log({"total_training_time": total_time})

print(f"Total training time: {total_time:.2f} seconds")

# Close the W&B run.
wandb.finish()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/guswls/.netrc


Files already downloaded and verified
Files already downloaded and verified
Using device: cuda
2개의 GPU를 사용합니다.


  0%|                                                                                                       | 0/100 [00:00<?, ?it/s]

Epoch [1], Batch [50/391], Loss: 4.0091
Epoch [1], Batch [100/391], Loss: 3.9891
Epoch [1], Batch [150/391], Loss: 3.7799
Epoch [1], Batch [200/391], Loss: 3.6455
Epoch [1], Batch [250/391], Loss: 3.7078
Epoch [1], Batch [300/391], Loss: 3.6388
Epoch [1], Batch [350/391], Loss: 3.1598
Train set: Epoch: 1, Average loss:3.6887, LR: 0.001000 Top-1 Accuracy: 13.2660%, Top-5 Accuracy: 36.3280%, Time consumed:138.04s


  1%|▉                                                                                           | 1/100 [02:42<4:27:52, 162.35s/it]

Test set: Epoch: 1, Average loss:3.3946, Top-1 Accuracy: 19.6200%, Top-5 Accuracy: 46.6900%, Time consumed:24.21s

New best top-1 accuracy: 19.62%, top-5 accuracy: 46.69%
New best top-5 accuracy: 46.69%
Epoch [2], Batch [50/391], Loss: 3.0589
Epoch [2], Batch [100/391], Loss: 3.1191
Epoch [2], Batch [150/391], Loss: 2.7881
Epoch [2], Batch [200/391], Loss: 2.8302
Epoch [2], Batch [250/391], Loss: 2.6513
Epoch [2], Batch [300/391], Loss: 2.7857
Epoch [2], Batch [350/391], Loss: 2.3993
Train set: Epoch: 2, Average loss:2.7871, LR: 0.001000 Top-1 Accuracy: 28.6640%, Top-5 Accuracy: 60.4860%, Time consumed:133.60s


  2%|█▊                                                                                          | 2/100 [05:21<4:21:57, 160.38s/it]

Test set: Epoch: 2, Average loss:2.5855, Top-1 Accuracy: 33.0800%, Top-5 Accuracy: 65.4000%, Time consumed:25.28s

New best top-1 accuracy: 33.08%, top-5 accuracy: 65.40%
New best top-5 accuracy: 65.40%
Epoch [3], Batch [50/391], Loss: 2.4285
Epoch [3], Batch [100/391], Loss: 2.2038
Epoch [3], Batch [150/391], Loss: 2.2922
Epoch [3], Batch [200/391], Loss: 2.1802
Epoch [3], Batch [250/391], Loss: 2.0384
Epoch [3], Batch [300/391], Loss: 2.2779
Epoch [3], Batch [350/391], Loss: 2.3236
Train set: Epoch: 3, Average loss:2.1784, LR: 0.001000 Top-1 Accuracy: 41.3240%, Top-5 Accuracy: 74.3320%, Time consumed:142.59s


  3%|██▊                                                                                         | 3/100 [08:07<4:23:46, 163.16s/it]

Test set: Epoch: 3, Average loss:2.3836, Top-1 Accuracy: 38.2300%, Top-5 Accuracy: 70.1200%, Time consumed:23.74s

New best top-1 accuracy: 38.23%, top-5 accuracy: 70.12%
New best top-5 accuracy: 70.12%
Epoch [4], Batch [50/391], Loss: 1.7502
Epoch [4], Batch [100/391], Loss: 1.8550
Epoch [4], Batch [150/391], Loss: 1.6262
Epoch [4], Batch [200/391], Loss: 1.8721
Epoch [4], Batch [250/391], Loss: 1.6512
Epoch [4], Batch [300/391], Loss: 1.7207
Epoch [4], Batch [350/391], Loss: 1.6392
Train set: Epoch: 4, Average loss:1.7704, LR: 0.001000 Top-1 Accuracy: 50.8460%, Top-5 Accuracy: 81.9060%, Time consumed:144.95s


  4%|███▋                                                                                        | 4/100 [10:58<4:25:37, 166.01s/it]

Test set: Epoch: 4, Average loss:1.8603, Top-1 Accuracy: 48.7100%, Top-5 Accuracy: 79.7200%, Time consumed:25.27s

New best top-1 accuracy: 48.71%, top-5 accuracy: 79.72%
New best top-5 accuracy: 79.72%
Epoch [5], Batch [50/391], Loss: 1.2719
Epoch [5], Batch [100/391], Loss: 1.3769
Epoch [5], Batch [150/391], Loss: 1.5925
Epoch [5], Batch [200/391], Loss: 1.4389
Epoch [5], Batch [250/391], Loss: 1.4264
Epoch [5], Batch [300/391], Loss: 1.4523
Epoch [5], Batch [350/391], Loss: 1.3304
Train set: Epoch: 5, Average loss:1.4339, LR: 0.001000 Top-1 Accuracy: 58.9460%, Top-5 Accuracy: 87.4760%, Time consumed:145.63s


  5%|████▌                                                                                       | 5/100 [13:48<4:25:32, 167.71s/it]

Test set: Epoch: 5, Average loss:1.9187, Top-1 Accuracy: 48.6400%, Top-5 Accuracy: 79.3500%, Time consumed:25.08s

Epoch [6], Batch [50/391], Loss: 1.1000
Epoch [6], Batch [100/391], Loss: 0.9617
Epoch [6], Batch [150/391], Loss: 1.3693
Epoch [6], Batch [200/391], Loss: 1.1276
Epoch [6], Batch [250/391], Loss: 0.8863
Epoch [6], Batch [300/391], Loss: 1.3932
Epoch [6], Batch [350/391], Loss: 1.0701
Train set: Epoch: 6, Average loss:1.1362, LR: 0.001000 Top-1 Accuracy: 66.4320%, Top-5 Accuracy: 91.8860%, Time consumed:138.36s


  6%|█████▌                                                                                      | 6/100 [16:31<4:20:06, 166.03s/it]

Test set: Epoch: 6, Average loss:1.8652, Top-1 Accuracy: 51.3600%, Top-5 Accuracy: 81.3000%, Time consumed:24.30s

New best top-1 accuracy: 51.36%, top-5 accuracy: 81.30%
New best top-5 accuracy: 81.30%
Epoch [7], Batch [50/391], Loss: 0.7025
Epoch [7], Batch [100/391], Loss: 0.7820
Epoch [7], Batch [150/391], Loss: 0.6128
Epoch [7], Batch [200/391], Loss: 0.8296
Epoch [7], Batch [250/391], Loss: 0.6264
Epoch [7], Batch [300/391], Loss: 1.0520
Epoch [7], Batch [350/391], Loss: 0.7147
Train set: Epoch: 7, Average loss:0.8296, LR: 0.001000 Top-1 Accuracy: 74.7740%, Top-5 Accuracy: 95.7240%, Time consumed:133.71s


  7%|██████▍                                                                                     | 7/100 [19:09<4:13:10, 163.34s/it]

Test set: Epoch: 7, Average loss:1.8084, Top-1 Accuracy: 52.6700%, Top-5 Accuracy: 82.4500%, Time consumed:23.97s

New best top-1 accuracy: 52.67%, top-5 accuracy: 82.45%
New best top-5 accuracy: 82.45%
Epoch [8], Batch [50/391], Loss: 0.3586
Epoch [8], Batch [100/391], Loss: 0.4930
Epoch [8], Batch [150/391], Loss: 0.5492
Epoch [8], Batch [200/391], Loss: 0.5007
Epoch [8], Batch [250/391], Loss: 0.4338
Epoch [8], Batch [300/391], Loss: 0.4716
Epoch [8], Batch [350/391], Loss: 0.6598
Train set: Epoch: 8, Average loss:0.5549, LR: 0.001000 Top-1 Accuracy: 82.9240%, Top-5 Accuracy: 98.1340%, Time consumed:132.51s


  8%|███████▎                                                                                    | 8/100 [21:46<4:07:16, 161.26s/it]

Test set: Epoch: 8, Average loss:1.9458, Top-1 Accuracy: 53.2900%, Top-5 Accuracy: 82.3400%, Time consumed:24.16s

New best top-1 accuracy: 53.29%, top-5 accuracy: 82.34%
Epoch [9], Batch [50/391], Loss: 0.1705
Epoch [9], Batch [100/391], Loss: 0.3146
Epoch [9], Batch [150/391], Loss: 0.2883
Epoch [9], Batch [200/391], Loss: 0.3016
