In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import sys
import os
import time
import random
import numpy as np
import wandb
from tqdm import tqdm
from utility.utils import AccuracyEarlyStopping
from torchmetrics.classification import MulticlassAveragePrecision
from Dataset.data import OfficeHomeDataset  # 기존 데이터셋 클래스 임포트

# Create the checkpoint directory
os.makedirs('checkpoints', exist_ok=True)

# WandB setup.
# SECURITY: never commit an API key to source control. The key is taken from the
# WANDB_API_KEY environment variable; when that is unset, key=None makes
# wandb.login() fall back to its own credential lookup (e.g. ~/.netrc) or prompt.
# NOTE(review): any key that was ever committed in this file should be treated
# as leaked and rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    project="office-home-classification",
    name="ResNet50_MultiTask_layer3 Extraction - Style, all - Category",
    entity="hh0804352-hanyang-university"
)

# The wandb run name is reused as the stem of the periodic checkpoint file
run_name = wandb.run.name
CHECKPOINT_PATH = os.path.join('checkpoints', f'{run_name}_checkpoint.pth')


# Experiment configuration (hyperparameters and runtime settings)
config = dict(
    # Model
    model="resnet50",
    batch_size=256,
    num_epochs=300,

    learning_rate=0.001,  # Adam's default learning rate
    optimizer="Adam",

    # Training process
    seed=2025,
    deterministic=False,
    patience=30,  # early-stopping patience
    max_epochs_wait=float('inf'),

    # Multi-task
    num_domains=4,
    num_classes=65,
    domain_weight=0.5,  # weight of the domain-classification loss
    class_weight=0.5,   # weight of the class-classification loss

    # System
    num_workers=32,
    pin_memory=True,

    # Checkpointing
    save_every=5,  # write the periodic checkpoint every N epochs

    # Learning-rate scheduler
    scheduler="ReduceLROnPlateau",
    scheduler_mode="max",
    scheduler_factor=0.1,
    scheduler_patience=5,
    scheduler_verbose=True,
)

# Attach the hyperparameters above to the active wandb run
wandb.config.update(config)

# ResNet-50 based multi-task model that taps an intermediate activation
class FeatureExtractingMultiTaskModel(nn.Module):
    """Frozen, pretrained ResNet-50 with two trainable classification heads.

    * Domain head: global-average-pooled ``layer3`` output (1024 channels)
      -> Linear(1024, 256) -> ReLU -> Dropout -> Linear(256, num_domains).
    * Class head: the backbone's pooled ``layer4`` output (2048 channels)
      -> Linear(2048, 512) -> ReLU -> Dropout -> Linear(512, num_classes).

    Args:
        num_domains: number of domain logits produced by the domain head.
        num_classes: number of category logits produced by the class head.
    """

    def __init__(self, num_domains=4, num_classes=65):
        super().__init__()
        # Load the pretrained ResNet-50
        self.backbone = models.resnet50(pretrained=True)

        # Freeze the whole backbone; only the two heads below are trained
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Domain head, applied to the layer3 output: [B, 1024, H/16, W/16]
        self.domain_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.domain_classifier = nn.Sequential(
            nn.Linear(1024, 256),  # layer3 has 1024 channels
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_domains)
        )

        # Class head, applied to the final pooled feature
        self.class_classifier = nn.Sequential(
            nn.Linear(2048, 512),  # the final feature has 2048 channels
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def train(self, mode=True):
        """Set train/eval mode, keeping a fully frozen backbone in eval mode.

        Setting ``requires_grad = False`` does not stop BatchNorm layers from
        updating their running statistics while in training mode, so a frozen
        backbone would still drift. When no backbone parameter requires a
        gradient, the backbone is therefore kept in eval mode; the heads
        (including their Dropout layers) follow ``mode`` as usual.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.backbone.parameters()):
            self.backbone.eval()
        return self

    def forward(self, x):
        """Return ``(domain_out, class_out)`` logits for an image batch ``x``."""
        # Stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Residual stages; only the layer3 activation is needed later
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        layer3_feat = x  # [B, 1024, h, w] - input of the domain head
        x = self.backbone.layer4(x)

        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)  # [B, 2048]

        # Category prediction from the final feature
        class_out = self.class_classifier(x)

        # Domain prediction from the pooled layer3 feature
        domain_feat = self.domain_avgpool(layer3_feat)  # [B, 1024, 1, 1]
        domain_feat = torch.flatten(domain_feat, 1)     # [B, 1024]
        domain_out = self.domain_classifier(domain_feat)

        return domain_out, class_out

# Input preprocessing (RandomResizedCrop removed): fixed 224x224 resize,
# tensor conversion and per-channel normalization. Train and test use the same
# pipeline, but each split gets its own Compose instance.
def _make_preprocess():
    """Build a fresh Resize -> ToTensor -> Normalize pipeline."""
    steps = [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


transform_train = _make_preprocess()
transform_test = _make_preprocess()

# Load the datasets
trainset = OfficeHomeDataset(root_dir='./Dataset/train', transform=transform_train)
testset = OfficeHomeDataset(root_dir='./Dataset/test', transform=transform_test)

# Loader options shared by both splits; only the shuffling differs
_loader_options = dict(
    batch_size=config["batch_size"],
    pin_memory=config["pin_memory"],
    num_workers=config["num_workers"],
)
trainloader = DataLoader(trainset, shuffle=True, **_loader_options)
testloader = DataLoader(testset, shuffle=False, **_loader_options)

print(f"Train set size: {len(trainset)}")
print(f"Test set size: {len(testset)}")

def train(model, trainloader, domain_criterion, class_criterion, optimizer, device, epoch):
    """
    Run one multi-task training epoch (domain + class classification).

    Each batch from ``trainloader`` is unpacked as (inputs, domain labels,
    class labels). The optimized loss is the weighted sum of the two criteria,
    using ``config["domain_weight"]`` and ``config["class_weight"]``.

    Returns:
        A 7-tuple: (mean total loss, mean domain loss, mean class loss,
        domain accuracy in percent, class accuracy in percent,
        domain macro mAP, class macro mAP). Losses are averaged over batches.
    """
    model.train()
    start_time = time.time()

    # Macro-averaged mAP metrics, placed on the same device as the data
    domain_map = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    # Per-epoch accumulators
    running_loss = running_domain_loss = running_class_loss = 0.0
    domain_correct = class_correct = total = 0

    for i, batch in enumerate(trainloader):
        inputs, domain_labels, class_labels = batch
        inputs = inputs.to(device)
        domain_labels = domain_labels.to(device)
        class_labels = class_labels.to(device)

        # Reset gradients, then forward pass
        optimizer.zero_grad()
        domain_outputs, class_outputs = model(inputs)

        # Weighted multi-task loss
        domain_loss = domain_criterion(domain_outputs, domain_labels)
        class_loss = class_criterion(class_outputs, class_labels)
        loss = config["domain_weight"] * domain_loss + config["class_weight"] * class_loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Loss bookkeeping (total and per task)
        running_loss += loss.item()
        running_domain_loss += domain_loss.item()
        running_class_loss += class_loss.item()

        # Accuracy bookkeeping: index of the largest logit is the prediction
        domain_preds = domain_outputs.max(1)[1]
        class_preds = class_outputs.max(1)[1]
        domain_correct += (domain_preds == domain_labels).sum().item()
        class_correct += (class_preds == class_labels).sum().item()
        total += inputs.size(0)

        # Feed the mAP metrics
        domain_map.update(domain_outputs, domain_labels)
        class_map.update(class_outputs, class_labels)

        # Progress line every 20 batches
        if (i + 1) % 20 == 0:
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(trainloader)}], Loss: {loss.item():.4f}, '
                  f'Domain Loss: {domain_loss.item():.4f}, Class Loss: {class_loss.item():.4f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

    # Epoch statistics: losses are per-batch means, accuracies are per-sample
    num_batches = len(trainloader)
    epoch_loss = running_loss / num_batches
    epoch_domain_loss = running_domain_loss / num_batches
    epoch_class_loss = running_class_loss / num_batches
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    train_time = time.time() - start_time

    # Report performance on the training set
    print(f'Train set: Epoch: {epoch+1}, Avg loss: {epoch_loss:.4f}, '
          f'Domain Loss: {epoch_domain_loss:.4f}, Class Loss: {epoch_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {train_time:.2f}s')

    return (
        epoch_loss,
        epoch_domain_loss,
        epoch_class_loss,
        domain_accuracy,
        class_accuracy,
        domain_map_value,
        class_map_value,
    )

def evaluate(model, dataloader, domain_criterion, class_criterion, device, epoch):
    """
    Evaluate the multi-task model on ``dataloader`` without gradient tracking.

    The reported total loss uses the same ``config["domain_weight"]`` /
    ``config["class_weight"]`` combination as training. ``epoch`` is only used
    to label the printed summary.

    Returns:
        A 7-tuple: (mean total loss, mean domain loss, mean class loss,
        domain accuracy in percent, class accuracy in percent,
        domain macro mAP, class macro mAP). Losses are averaged over batches.
    """
    model.eval()
    start_time = time.time()

    # Macro-averaged mAP metrics, placed on the same device as the data
    domain_map = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    # Accumulators
    running_loss = running_domain_loss = running_class_loss = 0.0
    domain_correct = class_correct = total = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs, domain_labels, class_labels = batch
            inputs = inputs.to(device)
            domain_labels = domain_labels.to(device)
            class_labels = class_labels.to(device)

            # Forward pass
            domain_outputs, class_outputs = model(inputs)

            # Per-task losses and their weighted sum
            domain_loss = domain_criterion(domain_outputs, domain_labels)
            class_loss = class_criterion(class_outputs, class_labels)
            loss = config["domain_weight"] * domain_loss + config["class_weight"] * class_loss

            running_loss += loss.item()
            running_domain_loss += domain_loss.item()
            running_class_loss += class_loss.item()

            # Count correct predictions (index of the largest logit)
            domain_preds = domain_outputs.max(1)[1]
            class_preds = class_outputs.max(1)[1]
            domain_correct += (domain_preds == domain_labels).sum().item()
            class_correct += (class_preds == class_labels).sum().item()
            total += inputs.size(0)

            # Feed the mAP metrics
            domain_map.update(domain_outputs, domain_labels)
            class_map.update(class_outputs, class_labels)

    # Losses are per-batch means, accuracies are per-sample
    num_batches = len(dataloader)
    eval_loss = running_loss / num_batches
    eval_domain_loss = running_domain_loss / num_batches
    eval_class_loss = running_class_loss / num_batches
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    eval_time = time.time() - start_time

    # Report performance on the evaluated set
    print(f'Test set: Epoch: {epoch+1}, Avg loss: {eval_loss:.4f}, '
          f'Domain Loss: {eval_domain_loss:.4f}, Class Loss: {eval_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {eval_time:.2f}s')
    print()

    return (
        eval_loss,
        eval_domain_loss,
        eval_class_loss,
        domain_accuracy,
        class_accuracy,
        domain_map_value,
        class_map_value,
    )

# Checkpoint saving
def save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, filename=CHECKPOINT_PATH):
    """
    Save the full training state to ``filename``.

    The checkpoint holds the (unwrapped) model weights, optimizer and scheduler
    state, the epoch index, the best mAP values seen so far, the early-stopping
    bookkeeping (counter, best_score, best_epoch, early_stop) and the global
    ``config``.

    When ``filename`` is a path, the file is written to ``<filename>.tmp`` first
    and then moved into place with ``os.replace``, so an interruption in the
    middle of a save cannot leave a truncated checkpoint where a good one used
    to be. Missing parent directories are created.
    """
    # Unwrap DataParallel so the keys carry no "module." prefix
    model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_class_map': best_class_map,
        'best_domain_map': best_domain_map,
        'early_stopping_counter': early_stopping.counter,
        'early_stopping_best_score': early_stopping.best_score,
        'early_stopping_best_epoch': early_stopping.best_epoch,
        'early_stopping_early_stop': early_stopping.early_stop,
        'config': config,  # keep the settings alongside the weights
    }

    if isinstance(filename, (str, os.PathLike)):
        target = os.fspath(filename)
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Write-then-rename: the destination is only replaced by a complete file
        tmp_target = f"{target}.tmp"
        torch.save(checkpoint, tmp_target)
        os.replace(tmp_target, target)
    else:
        # File-like objects cannot be renamed; write to them directly
        torch.save(checkpoint, filename)
    print(f"체크포인트가 {filename}에 저장되었습니다.")

# Checkpoint loading
def load_checkpoint(model, optimizer, scheduler, early_stopping, filename=CHECKPOINT_PATH):
    """
    Restore the training state written by ``save_checkpoint``.

    Returns:
        ``(model, optimizer, scheduler, early_stopping, start_epoch,
        best_class_map, best_domain_map)``. When ``filename`` does not exist the
        objects are returned untouched with ``start_epoch=0`` and both best
        values ``0.0``; otherwise ``start_epoch`` is the saved epoch + 1.
    """
    if not os.path.exists(filename):
        print(f"체크포인트 파일 {filename}이 존재하지 않습니다. 처음부터 학습을 시작합니다.")
        return model, optimizer, scheduler, early_stopping, 0, 0.0, 0.0

    print(f"체크포인트 {filename}을 로드합니다.")
    # Deserialize onto the CPU so loading does not depend on the device layout
    # of the run that saved the file; load_state_dict() below copies the values
    # onto whatever device the live model/optimizer already uses.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True; confirm the early-stopping values stored in the
    # checkpoint are plain Python numbers if that default applies here.
    checkpoint = torch.load(filename, map_location="cpu")

    # Load into the wrapped module when the model is a DataParallel
    target_model = model.module if isinstance(model, nn.DataParallel) else model
    target_model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Scheduler state is optional so older checkpoints remain loadable
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore the early-stopping bookkeeping
    early_stopping.counter = checkpoint['early_stopping_counter']
    early_stopping.best_score = checkpoint['early_stopping_best_score']
    early_stopping.best_epoch = checkpoint['early_stopping_best_epoch']
    early_stopping.early_stop = checkpoint['early_stopping_early_stop']

    # Remaining training state
    start_epoch = checkpoint['epoch'] + 1  # resume with the next epoch
    best_class_map = checkpoint['best_class_map']
    best_domain_map = checkpoint['best_domain_map']

    print(f"체크포인트에서 로드 완료: 에폭 {start_epoch}부터 시작합니다.")
    print(f"이전 최고 성능: Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map:.4f}")

    return model, optimizer, scheduler, early_stopping, start_epoch, best_class_map, best_domain_map

# Main training loop
def _save_best_weights(model, model_path):
    """Save the unwrapped model weights to ``model_path`` and upload the file to wandb."""
    state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    torch.save(state, model_path)
    wandb.save(model_path)


def main_training_loop(model, trainloader, testloader, domain_criterion, class_criterion, optimizer, scheduler, device, num_epochs=None, patience=None, max_epochs_wait=None):
    """
    Train/evaluate for up to ``num_epochs`` epochs with early stopping on class mAP.

    Per epoch: train, evaluate on ``testloader``, step the scheduler on the mean
    of the two test mAPs, log to wandb, save best-class / best-domain weights,
    update early stopping, then write checkpoints. After the loop the best
    class-mAP weights are reloaded (if the file exists) and evaluated once more,
    and a final checkpoint is written.

    ``num_epochs``, ``patience`` and ``max_epochs_wait`` default to the
    corresponding ``config`` entries when left as ``None``.
    """
    # Fall back to the global config for unspecified settings
    if num_epochs is None:
        num_epochs = config["num_epochs"]
    if patience is None:
        patience = config["patience"]
    if max_epochs_wait is None:
        max_epochs_wait = config["max_epochs_wait"]

    # Early stopping driven by the test class mAP
    early_stopping = AccuracyEarlyStopping(patience=patience, verbose=True, path='checkpoint.pt', max_epochs=max_epochs_wait)

    start_epoch = 0
    best_class_map = 0.0
    best_domain_map = 0.0
    # Index of the last epoch that actually ran. Tracked separately so the code
    # after the loop works even if the loop body never executes.
    last_epoch = start_epoch - 1

    # tqdm shows overall progress across epochs
    for epoch in tqdm(range(start_epoch, num_epochs)):
        last_epoch = epoch

        # Train
        train_loss, train_domain_loss, train_class_loss, train_domain_acc, train_class_acc, train_domain_map, train_class_map = train(
            model,
            trainloader,
            domain_criterion,
            class_criterion,
            optimizer,
            device,
            epoch
        )

        # Evaluate on the test data
        test_loss, test_domain_loss, test_class_loss, test_domain_acc, test_class_acc, test_domain_map, test_class_map = evaluate(
            model,
            testloader,
            domain_criterion,
            class_criterion,
            device,
            epoch
        )

        # Adjust the learning rate from the mean of the two test mAPs
        avg_map = (test_domain_map + test_class_map) / 2
        scheduler.step(avg_map)

        # Log to wandb (domain and class losses are logged separately)
        wandb.log({
            "epoch": epoch + 1,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "train_loss": train_loss,
            "train_domain_loss": train_domain_loss,
            "train_class_loss": train_class_loss,
            "train_domain_accuracy": train_domain_acc,
            "train_class_accuracy": train_class_acc,
            "train_domain_map": train_domain_map,
            "train_class_map": train_class_map,
            "test_loss": test_loss,
            "test_domain_loss": test_domain_loss,
            "test_class_loss": test_class_loss,
            "test_domain_accuracy": test_domain_acc,
            "test_class_accuracy": test_class_acc,
            "test_domain_map": test_domain_map,
            "test_class_map": test_class_map
        })

        # New best class mAP: keep these weights
        if test_class_map > best_class_map:
            best_class_map = test_class_map
            best_domain_map_at_best_class = test_domain_map
            print(f'새로운 최고 Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map_at_best_class:.4f}')
            _save_best_weights(model, f'best_model_class_{wandb.run.name}.pth')

        # New best domain mAP: keep these weights
        if test_domain_map > best_domain_map:
            best_domain_map = test_domain_map
            print(f'새로운 최고 Domain mAP: {best_domain_map:.4f}')
            _save_best_weights(model, f'best_model_domain_{wandb.run.name}.pth')

        # Update early stopping (class mAP) BEFORE writing any checkpoint so
        # that both checkpoints below store this epoch's early-stopping state.
        early_stopping(test_class_map, model, epoch)

        # Periodic checkpoint (every config["save_every"] epochs)
        if (epoch + 1) % config["save_every"] == 0:
            save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping)

        # Always refresh the "latest" checkpoint
        save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping,
                       filename=os.path.join('checkpoints', 'latest_checkpoint.pth'))

        if early_stopping.early_stop:
            print(f"에폭 {epoch+1}에서 학습 조기 종료. 최고 성능 에폭: {early_stopping.best_epoch+1}")
            break

    # Reload the best class-mAP weights once training is over
    print("최고 클래스 mAP 모델 로드 중...")
    model_path = f'best_model_class_{wandb.run.name}.pth'
    if os.path.exists(model_path):
        best_state = torch.load(model_path, map_location=device)
        # Load into the wrapped module when the model is a DataParallel
        target_model = model.module if isinstance(model, nn.DataParallel) else model
        target_model.load_state_dict(best_state)
    else:
        print(f"경고: {model_path} 파일이 존재하지 않습니다. 최종 모델을 사용합니다.")

    # Final test evaluation, labelled with the last epoch that really ran
    # (not num_epochs - 1, which is wrong after an early stop)
    final_test_loss, final_test_domain_loss, final_test_class_loss, final_test_domain_acc, final_test_class_acc, final_test_domain_map, final_test_class_map = evaluate(
        model, testloader, domain_criterion, class_criterion, device, last_epoch
    )

    print(f'완료! 최고 Class mAP: {best_class_map:.4f}, 최고 Domain mAP: {best_domain_map:.4f}')

    # Record the final results in the wandb summary
    wandb.run.summary["best_class_map"] = best_class_map
    wandb.run.summary["best_domain_map"] = best_domain_map
    wandb.run.summary["final_test_class_map"] = final_test_class_map
    wandb.run.summary["final_test_domain_map"] = final_test_domain_map

    # Record early-stopping information
    if early_stopping.early_stop:
        wandb.run.summary["early_stopped"] = True
        wandb.run.summary["early_stopped_epoch"] = last_epoch+1
        wandb.run.summary["best_epoch"] = early_stopping.best_epoch+1
    else:
        wandb.run.summary["early_stopped"] = False

    # Final checkpoint. Note: at this point the model holds the reloaded
    # best-class-mAP weights (when that file existed), while the optimizer and
    # scheduler state are those of the last epoch.
    save_checkpoint(model, optimizer, scheduler, last_epoch, best_class_map, best_domain_map, early_stopping,
                   filename=os.path.join('checkpoints', 'final_checkpoint.pth'))

# Script entry point
if __name__ == "__main__":
    # Seed every RNG in use
    seed = config["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels vs. autotuned (faster) kernels
    if config["deterministic"]:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Build the intermediate-feature multi-task model
    model = FeatureExtractingMultiTaskModel(
        num_domains=config["num_domains"],
        num_classes=config["num_classes"]
    ).to(device)

    # Report how many parameters are trainable
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"학습 가능한 파라미터: {trainable_params:,} / {total_params:,} ({trainable_params / total_params:.2%})")

    # Loss functions
    domain_criterion = nn.CrossEntropyLoss()
    class_criterion = nn.CrossEntropyLoss()

    # One parameter group per head, so the learning rates can be set
    # independently (currently both use config["learning_rate"])
    optimizer = optim.Adam([
        {'params': model.domain_classifier.parameters(), 'lr': config["learning_rate"]},  # domain head
        {'params': model.class_classifier.parameters(), 'lr': config["learning_rate"]}    # class head
    ])

    # Learning-rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=config["scheduler_mode"],
        factor=config["scheduler_factor"],
        patience=config["scheduler_patience"],
        verbose=config["scheduler_verbose"]
    )

    # Let wandb track the model
    wandb.watch(model, log="all")

    # Multi-GPU data parallelism
    if torch.cuda.device_count() > 1:
        print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
        model = nn.DataParallel(model)

    # Wall-clock start of training
    start_time = time.time()

    try:
        main_training_loop(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            domain_criterion=domain_criterion,
            class_criterion=class_criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device
        )
    except KeyboardInterrupt:
        # Interrupted with Ctrl+C: save the current weights/optimizer/scheduler.
        # NOTE: the loop's own bookkeeping is not reachable from here, so the
        # epoch (0), best mAP values (0.0) and early-stopping state stored in
        # this file are placeholders, not the real progress of the run.
        print("학습이 사용자에 의해 중단되었습니다. 현재 상태를 저장합니다.")
        early_stopping = AccuracyEarlyStopping(patience=config["patience"], verbose=True)
        save_checkpoint(model, optimizer, scheduler, 0, 0.0, 0.0, early_stopping,
                       filename=os.path.join('checkpoints', 'interrupted_checkpoint.pth'))
    finally:
        # Runs on success, on interrupt and on any other error, so the elapsed
        # time is always logged and the wandb run is always closed.
        end_time = time.time()
        total_time = end_time - start_time
        wandb.log({"total_training_time": total_time})

        print(f"전체 학습 시간: {total_time:.2f} 초")

        # Close the wandb run
        wandb.finish()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/guswls/.netrc


총 11113 이미지, 65 클래스, 4 도메인을 로드했습니다.
총 3213 이미지, 65 클래스, 4 도메인을 로드했습니다.
Train set size: 11113
Test set size: 3213
Using device: cuda




학습 가능한 파라미터: 1,345,861 / 26,902,893 (5.00%)
2개의 GPU를 사용합니다.


  0%|                                                                                                       | 0/300 [00:00<?, ?it/s]

Epoch [1], Batch [20/44], Loss: 1.9314, Domain Loss: 1.2300, Class Loss: 2.6328, LR: 0.001000
Epoch [1], Batch [40/44], Loss: 1.3973, Domain Loss: 0.9719, Class Loss: 1.8228, LR: 0.001000
Train set: Epoch: 1, Avg loss: 1.9138, Domain Loss: 1.1782, Class Loss: 2.6493, Domain Acc: 52.68%, Class Acc: 38.80%, Domain mAP: 0.5320, Class mAP: 0.3419, Time: 74.77s
Test set: Epoch: 1, Avg loss: 1.2668, Domain Loss: 1.0102, Class Loss: 1.5233, Domain Acc: 60.97%, Class Acc: 63.90%, Domain mAP: 0.6613, Class mAP: 0.7282, Time: 23.86s

새로운 최고 Class mAP: 0.7282, Domain mAP: 0.6613
새로운 최고 Domain mAP: 0.6613
Accuracy improved (-inf% --> 0.73%). Saving model ...


  0%|▎                                                                                            | 1/300 [01:39<8:16:25, 99.62s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [2], Batch [20/44], Loss: 1.1512, Domain Loss: 0.8358, Class Loss: 1.4667, LR: 0.001000
Epoch [2], Batch [40/44], Loss: 1.0923, Domain Loss: 0.7778, Class Loss: 1.4069, LR: 0.001000
Train set: Epoch: 2, Avg loss: 1.1306, Domain Loss: 0.8470, Class Loss: 1.4141, Domain Acc: 68.18%, Class Acc: 62.96%, Domain mAP: 0.6671, Class mAP: 0.6378, Time: 69.23s
Test set: Epoch: 2, Avg loss: 0.9854, Domain Loss: 0.8214, Class Loss: 1.1493, Domain Acc: 67.29%, Class Acc: 70.37%, Domain mAP: 0.7116, Class mAP: 0.7778, Time: 23.39s

새로운 최고 Class mAP: 0.7778, Domain mAP: 0.7116
새로운 최고 Domain mAP: 0.7116
Accuracy improved (0.73% --> 0.78%). Saving model ...


  1%|▌                                                                                            | 2/300 [03:13<7:57:54, 96.22s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [3], Batch [20/44], Loss: 0.9164, Domain Loss: 0.7076, Class Loss: 1.1252, LR: 0.001000
Epoch [3], Batch [40/44], Loss: 0.9141, Domain Loss: 0.7086, Class Loss: 1.1196, LR: 0.001000
Train set: Epoch: 3, Avg loss: 0.9569, Domain Loss: 0.7372, Class Loss: 1.1767, Domain Acc: 71.25%, Class Acc: 68.01%, Domain mAP: 0.7076, Class mAP: 0.7099, Time: 69.24s
Test set: Epoch: 3, Avg loss: 0.8989, Domain Loss: 0.7581, Class Loss: 1.0397, Domain Acc: 68.72%, Class Acc: 71.68%, Domain mAP: 0.7450, Class mAP: 0.8022, Time: 22.96s

새로운 최고 Class mAP: 0.8022, Domain mAP: 0.7450
새로운 최고 Domain mAP: 0.7450
Accuracy improved (0.78% --> 0.80%). Saving model ...


  1%|▉                                                                                            | 3/300 [04:46<7:49:42, 94.89s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [4], Batch [20/44], Loss: 0.8314, Domain Loss: 0.6730, Class Loss: 0.9898, LR: 0.001000
Epoch [4], Batch [40/44], Loss: 0.8215, Domain Loss: 0.6415, Class Loss: 1.0016, LR: 0.001000
Train set: Epoch: 4, Avg loss: 0.8659, Domain Loss: 0.6839, Class Loss: 1.0479, Domain Acc: 73.68%, Class Acc: 71.18%, Domain mAP: 0.7352, Class mAP: 0.7526, Time: 73.12s
Test set: Epoch: 4, Avg loss: 0.8324, Domain Loss: 0.7239, Class Loss: 0.9408, Domain Acc: 71.12%, Class Acc: 74.14%, Domain mAP: 0.7656, Class mAP: 0.8172, Time: 22.55s

새로운 최고 Class mAP: 0.8172, Domain mAP: 0.7656
새로운 최고 Domain mAP: 0.7656
Accuracy improved (0.80% --> 0.82%). Saving model ...


  1%|█▏                                                                                           | 4/300 [06:23<7:51:56, 95.66s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [5], Batch [20/44], Loss: 0.8112, Domain Loss: 0.7446, Class Loss: 0.8778, LR: 0.001000
Epoch [5], Batch [40/44], Loss: 0.8253, Domain Loss: 0.7110, Class Loss: 0.9397, LR: 0.001000
Train set: Epoch: 5, Avg loss: 0.7907, Domain Loss: 0.6423, Class Loss: 0.9391, Domain Acc: 75.49%, Class Acc: 74.01%, Domain mAP: 0.7606, Class mAP: 0.7822, Time: 69.34s
Test set: Epoch: 5, Avg loss: 0.8127, Domain Loss: 0.7036, Class Loss: 0.9218, Domain Acc: 71.37%, Class Acc: 74.14%, Domain mAP: 0.7807, Class mAP: 0.8231, Time: 23.13s

새로운 최고 Class mAP: 0.8231, Domain mAP: 0.7807
새로운 최고 Domain mAP: 0.7807
체크포인트가 checkpoints/ResNet50_MultiTask_layer3 Extraction - Style, all - Category_checkpoint.pth에 저장되었습니다.
Accuracy improved (0.82% --> 0.82%). Saving model ...


  2%|█▌                                                                                           | 5/300 [07:57<7:46:59, 94.98s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [6], Batch [20/44], Loss: 0.7708, Domain Loss: 0.6634, Class Loss: 0.8782, LR: 0.001000
Epoch [6], Batch [40/44], Loss: 0.6892, Domain Loss: 0.5349, Class Loss: 0.8436, LR: 0.001000
Train set: Epoch: 6, Avg loss: 0.7424, Domain Loss: 0.6133, Class Loss: 0.8714, Domain Acc: 76.59%, Class Acc: 75.35%, Domain mAP: 0.7767, Class mAP: 0.8054, Time: 70.42s
Test set: Epoch: 6, Avg loss: 0.7848, Domain Loss: 0.6681, Class Loss: 0.9015, Domain Acc: 73.61%, Class Acc: 74.76%, Domain mAP: 0.7915, Class mAP: 0.8287, Time: 23.38s

새로운 최고 Class mAP: 0.8287, Domain mAP: 0.7915
새로운 최고 Domain mAP: 0.7915
Accuracy improved (0.82% --> 0.83%). Saving model ...


  2%|█▊                                                                                           | 6/300 [09:32<7:45:26, 94.99s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [22], Batch [40/44], Loss: 0.4889, Domain Loss: 0.4825, Class Loss: 0.4953, LR: 0.001000
Train set: Epoch: 22, Avg loss: 0.4365, Domain Loss: 0.4592, Class Loss: 0.4138, Domain Acc: 82.68%, Class Acc: 87.47%, Domain mAP: 0.8591, Class mAP: 0.9331, Time: 71.43s
Test set: Epoch: 22, Avg loss: 0.7125, Domain Loss: 0.5985, Class Loss: 0.8265, Domain Acc: 77.81%, Class Acc: 78.24%, Domain mAP: 0.8353, Class mAP: 0.8494, Time: 24.13s

EarlyStopping 카운터: 1 / 30


  7%|██████▋                                                                                     | 22/300 [34:46<7:19:46, 94.92s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [23], Batch [20/44], Loss: 0.3949, Domain Loss: 0.4086, Class Loss: 0.3811, LR: 0.001000
