In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import sys
import os
import time
import random
import numpy as np
import wandb
from tqdm import tqdm
from utility.utils import DualMAPEarlyStopping  # DualMAPEarlyStopping으로 변경
from torchmetrics.classification import MulticlassAveragePrecision
from Dataset.data import OfficeHomeDataset  # 기존 데이터셋 클래스 임포트
from models.resnet_dilated import ResnetDilated  # ResnetDilated 임포트

# Create the checkpoint directory
os.makedirs('checkpoints', exist_ok=True)

# WandB setup.
# SECURITY: the API key must never be committed to source control. It is read
# from the WANDB_API_KEY environment variable; when that is unset, `key` is
# None and wandb falls back to its own credential lookup (netrc / prompt).
# Any key that was previously hard-coded here should be treated as leaked
# and rotated in the WandB account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    project="office-home-classification", 
    name="Dilated_ResNet50_MultiTask_layer3 - Style, all_Category",
    entity="hh0804352-hanyang-university"
)

# The WandB run name is reused to build the resume-checkpoint path
run_name = wandb.run.name
CHECKPOINT_PATH = os.path.join('checkpoints', f'{run_name}_checkpoint.pth')


# Run configuration: every hyperparameter of the experiment lives here.
config = dict(
    # --- model ---
    model="resnet50",
    batch_size=256,
    num_epochs=300,

    learning_rate=0.001,  # Adam's default learning rate
    optimizer="Adam",

    # --- training procedure ---
    seed=2025,
    deterministic=False,
    patience=30,  # early-stopping patience
    max_epochs_wait=float('inf'),

    # --- multi-task setup ---
    num_domains=4,
    num_classes=65,
    domain_weight=0.5,  # weight of the domain-classification loss
    class_weight=0.5,   # weight of the class-classification loss

    # --- system ---
    num_workers=24,
    pin_memory=True,

    # --- checkpointing ---
    save_every=5,  # save the resume checkpoint every N epochs

    # --- learning-rate scheduler ---
    scheduler="ReduceLROnPlateau",
    scheduler_mode="max",
    scheduler_factor=0.1,
    scheduler_patience=5,
    scheduler_verbose=True,
)

# Record the hyperparameters on the active WandB run.
wandb.config.update(config)

class ImprovedMultiTaskModel(nn.Module):
    """Two-headed classifier on a shared dilated ResNet-50 backbone.

    The domain head branches off the backbone's ``layer3`` output and the
    class head off its ``layer4`` output. Each branch is a 1x1 convolution
    (+ BatchNorm + ReLU) that reduces the feature map to 512 channels,
    followed by global average pooling and a small MLP classifier.

    Args:
        num_domains: Number of outputs of the domain head.
        num_classes: Number of outputs of the class head.
        freeze_backbone: When True (the default, which matches the previous
            always-frozen behaviour) the backbone parameters get
            ``requires_grad=False`` and the backbone is kept in eval mode
            even while the rest of the model trains.
    """

    def __init__(self, num_domains=4, num_classes=65, freeze_backbone=True):
        super().__init__()

        # Shared backbone: ImageNet-pretrained ResNet-50 wrapped by ResnetDilated.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is for compatibility with the
        # installed version.
        pretrained_resnet = models.resnet50(pretrained=True)
        self.backbone = ResnetDilated(pretrained_resnet, dilate_scale=8)

        # Freeze the backbone (optional)
        self._freeze_backbone = freeze_backbone
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            # requires_grad=False alone does not stop BatchNorm from updating
            # its running statistics; eval mode does (see train() below).
            self.backbone.eval()

        # Branch for domain classification (fed by the layer3 output)
        self.domain_branch = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=1),  # reduce channel count
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Branch for class classification (fed by the layer4 output)
        self.class_branch = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1),  # reduce channel count
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Global pooling shared by both branches
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification heads
        self.domain_classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_domains)
        )

        self.class_classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone in eval mode.

        Without this override ``model.train()`` would switch the frozen
        backbone's BatchNorm layers back to training mode, so their running
        mean/variance would keep changing on every forward pass even though
        the backbone weights are not being optimised.

        Returns:
            self, like ``nn.Module.train``.
        """
        super().train(mode)
        if self._freeze_backbone:
            self.backbone.eval()
        return self

    def forward(self, x):
        """Return ``(domain_out, class_out)`` logits for the input batch ``x``."""
        # Run the backbone stem stage by stage so intermediate outputs can be tapped
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu1(x)  # `relu1` matches the attribute name defined by ResnetDilated
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)

        # Keep the layer3 output (used for domain classification)
        layer3_output = self.backbone.layer3(x)

        # layer4 output (used for class classification)
        layer4_output = self.backbone.layer4(layer3_output)

        # Domain-classification path
        domain_features = self.domain_branch(layer3_output)
        domain_features = self.avgpool(domain_features)
        domain_features = torch.flatten(domain_features, 1)
        domain_out = self.domain_classifier(domain_features)

        # Class-classification path
        class_features = self.class_branch(layer4_output)
        class_features = self.avgpool(class_features)
        class_features = torch.flatten(class_features, 1)
        class_out = self.class_classifier(class_features)

        return domain_out, class_out


# Input preprocessing (RandomResizedCrop was removed, so the train and test
# pipelines are the same sequence of operations).
def _build_transform():
    """Return a fresh resize -> to-tensor -> normalize pipeline."""
    steps = [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


transform_train = _build_transform()
transform_test = _build_transform()

# Load the two dataset splits
trainset = OfficeHomeDataset(root_dir='./Dataset/train', transform=transform_train)
testset = OfficeHomeDataset(root_dir='./Dataset/test', transform=transform_test)

# Build the DataLoaders; only the shuffle flag differs between the splits
_loader_kwargs = dict(
    batch_size=config["batch_size"],
    pin_memory=config["pin_memory"],
    num_workers=config["num_workers"],
)
trainloader = DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = DataLoader(testset, shuffle=False, **_loader_kwargs)

for _split_name, _split in (("Train", trainset), ("Test", testset)):
    print(f"{_split_name} set size: {len(_split)}")

def train(model, trainloader, domain_criterion, class_criterion, optimizer, device, epoch):
    """Run one training epoch of the multi-task (domain + class) model.

    Args:
        model: Module returning ``(domain_outputs, class_outputs)`` logits.
        trainloader: Iterable yielding ``(inputs, domain_labels, class_labels)``.
        domain_criterion: Loss for the domain head.
        class_criterion: Loss for the class head.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches and metrics are moved to.
        epoch: Zero-based epoch index (used only in the printed messages).

    Returns:
        Tuple ``(epoch_loss, epoch_domain_loss, epoch_class_loss,
        domain_accuracy, class_accuracy, domain_map_value, class_map_value)``.
        Losses are per-sample means over the epoch, accuracies are
        percentages, and the mAP values are macro-averaged.
    """
    model.train()
    start_time = time.time()
    # Loss sums are weighted by batch size so that a smaller final batch does
    # not get the same weight as a full one in the epoch mean.
    # Assumes both criteria return a per-batch mean (the default reduction).
    running_loss = 0.0
    running_domain_loss = 0.0
    running_class_loss = 0.0
    domain_correct = 0
    class_correct = 0
    total = 0

    # mAP accumulators, reset every epoch
    domain_map = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    for i, (inputs, domain_labels, class_labels) in enumerate(trainloader):
        # non_blocking lets the copy overlap with compute when the loader pins memory
        inputs = inputs.to(device, non_blocking=True)
        domain_labels = domain_labels.to(device, non_blocking=True)
        class_labels = class_labels.to(device, non_blocking=True)
        batch_size = inputs.size(0)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        domain_outputs, class_outputs = model(inputs)

        # Weighted sum of the two task losses
        domain_loss = domain_criterion(domain_outputs, domain_labels)
        class_loss = class_criterion(class_outputs, class_labels)
        loss = config["domain_weight"] * domain_loss + config["class_weight"] * class_loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_size
        running_domain_loss += domain_loss.item() * batch_size
        running_class_loss += class_loss.item() * batch_size

        # Accuracy bookkeeping
        _, domain_preds = domain_outputs.max(1)
        domain_correct += domain_preds.eq(domain_labels).sum().item()

        _, class_preds = class_outputs.max(1)
        class_correct += class_preds.eq(class_labels).sum().item()

        total += batch_size

        # mAP update; detach so the metric state can never hold on to the autograd graph
        domain_map.update(domain_outputs.detach(), domain_labels)
        class_map.update(class_outputs.detach(), class_labels)

        if (i + 1) % 20 == 0:
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(trainloader)}], Loss: {loss.item():.4f}, '
                  f'Domain Loss: {domain_loss.item():.4f}, Class Loss: {class_loss.item():.4f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

    # Epoch statistics (per-sample means)
    epoch_loss = running_loss / total
    epoch_domain_loss = running_domain_loss / total
    epoch_class_loss = running_class_loss / total
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    # Final mAP values
    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    train_time = time.time() - start_time

    # Report performance on the training set
    print(f'Train set: Epoch: {epoch+1}, Avg loss: {epoch_loss:.4f}, '
          f'Domain Loss: {epoch_domain_loss:.4f}, Class Loss: {epoch_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {train_time:.2f}s')

    return epoch_loss, epoch_domain_loss, epoch_class_loss, domain_accuracy, class_accuracy, domain_map_value, class_map_value

def evaluate(model, dataloader, domain_criterion, class_criterion, device, epoch):
    """Evaluate the multi-task model on ``dataloader`` without gradient tracking.

    Args:
        model: Module returning ``(domain_outputs, class_outputs)`` logits.
        dataloader: Iterable yielding ``(inputs, domain_labels, class_labels)``.
        domain_criterion: Loss for the domain head.
        class_criterion: Loss for the class head.
        device: Device the batches and metrics are moved to.
        epoch: Zero-based epoch index (used only in the printed message).

    Returns:
        Tuple ``(eval_loss, eval_domain_loss, eval_class_loss,
        domain_accuracy, class_accuracy, domain_map_value, class_map_value)``.
        Losses are per-sample means, accuracies are percentages, and the
        mAP values are macro-averaged.
    """
    model.eval()
    start_time = time.time()
    # Loss sums are weighted by batch size so that a smaller final batch does
    # not get the same weight as a full one in the overall mean.
    # Assumes both criteria return a per-batch mean (the default reduction).
    running_loss = 0.0
    running_domain_loss = 0.0
    running_class_loss = 0.0
    domain_correct = 0
    class_correct = 0
    total = 0

    # mAP accumulators
    domain_map = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    with torch.no_grad():
        for inputs, domain_labels, class_labels in dataloader:
            # non_blocking lets the copy overlap with compute when the loader pins memory
            inputs = inputs.to(device, non_blocking=True)
            domain_labels = domain_labels.to(device, non_blocking=True)
            class_labels = class_labels.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            # Forward pass
            domain_outputs, class_outputs = model(inputs)

            # Weighted sum of the two task losses
            domain_loss = domain_criterion(domain_outputs, domain_labels)
            class_loss = class_criterion(class_outputs, class_labels)
            loss = config["domain_weight"] * domain_loss + config["class_weight"] * class_loss

            running_loss += loss.item() * batch_size
            running_domain_loss += domain_loss.item() * batch_size
            running_class_loss += class_loss.item() * batch_size

            # Accuracy bookkeeping
            _, domain_preds = domain_outputs.max(1)
            domain_correct += domain_preds.eq(domain_labels).sum().item()

            _, class_preds = class_outputs.max(1)
            class_correct += class_preds.eq(class_labels).sum().item()

            total += batch_size

            # mAP update
            domain_map.update(domain_outputs, domain_labels)
            class_map.update(class_outputs, class_labels)

    # Per-sample mean losses and accuracies
    eval_loss = running_loss / total
    eval_domain_loss = running_domain_loss / total
    eval_class_loss = running_class_loss / total
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    # Final mAP values
    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    # Wall-clock time of the evaluation pass
    eval_time = time.time() - start_time

    # Report performance on the evaluated set
    print(f'Test set: Epoch: {epoch+1}, Avg loss: {eval_loss:.4f}, '
          f'Domain Loss: {eval_domain_loss:.4f}, Class Loss: {eval_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {eval_time:.2f}s')
    print()

    return eval_loss, eval_domain_loss, eval_class_loss, domain_accuracy, class_accuracy, domain_map_value, class_map_value

# Checkpoint saving
def save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, filename=CHECKPOINT_PATH):
    """Persist the full training state to ``filename``.

    The checkpoint holds the model, optimizer and scheduler state dicts, the
    epoch index, the best mAP values, the early-stopping bookkeeping
    attributes and the module-level ``config``.

    The file is first written to ``<filename>.tmp`` and then moved into place
    with ``os.replace``, so an interruption during the write cannot leave a
    truncated checkpoint under the final name.

    Args:
        model: Model to save; a ``nn.DataParallel`` wrapper is unwrapped so
            the stored keys carry no ``module.`` prefix.
        optimizer: Optimizer whose state is stored.
        scheduler: LR scheduler whose state is stored.
        epoch: Zero-based index of the epoch just completed.
        best_class_map: Best class mAP seen so far.
        best_domain_map: Best domain mAP seen so far.
        early_stopping: Early-stopping object whose attributes are stored.
        filename: Destination path of the checkpoint.
    """
    # Unwrap DataParallel so the checkpoint loads into a bare model as well
    model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_class_map': best_class_map,
        'best_domain_map': best_domain_map,
        'early_stopping_counter': early_stopping.counter,
        'early_stopping_domain_best_score': early_stopping.domain_best_score,
        'early_stopping_class_best_score': early_stopping.class_best_score,
        'early_stopping_best_epoch': early_stopping.best_epoch,
        'early_stopping_domain_best_epoch': early_stopping.domain_best_epoch,
        'early_stopping_class_best_epoch': early_stopping.class_best_epoch,
        'early_stopping_early_stop': early_stopping.early_stop,
        'early_stopping_domain_map_max': early_stopping.domain_map_max,
        'early_stopping_class_map_max': early_stopping.class_map_max,
        'config': config,  # keep the hyperparameters alongside the weights
    }

    # Make sure the destination directory exists
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Write-then-rename: os.replace swaps the finished file in atomically
    tmp_filename = f"{filename}.tmp"
    torch.save(checkpoint, tmp_filename)
    os.replace(tmp_filename, filename)
    print(f"체크포인트가 {filename}에 저장되었습니다.")

# Checkpoint loading
def load_checkpoint(model, optimizer, scheduler, early_stopping, filename=CHECKPOINT_PATH):
    """Restore the training state from ``filename`` if it exists.

    Args:
        model: Model to load weights into (``nn.DataParallel`` is handled).
        optimizer: Optimizer whose state is restored.
        scheduler: LR scheduler whose state is restored when present.
        early_stopping: Early-stopping object whose attributes are restored.
        filename: Path of the checkpoint to load.

    Returns:
        Tuple ``(model, optimizer, scheduler, early_stopping, start_epoch,
        best_class_map, best_domain_map)``. When the file does not exist the
        inputs are returned unchanged with ``start_epoch=0`` and both best
        mAP values set to ``0.0``.
    """
    if not os.path.exists(filename):
        print(f"체크포인트 파일 {filename}이 존재하지 않습니다. 처음부터 학습을 시작합니다.")
        return model, optimizer, scheduler, early_stopping, 0, 0.0, 0.0

    print(f"체크포인트 {filename}을 로드합니다.")
    # Load onto the CPU first so the checkpoint can be restored regardless of
    # which device it was saved from; load_state_dict then copies the values
    # onto the devices of the existing parameters.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
    # if the stored early-stopping scores are NumPy scalars this call may
    # need weights_only=False - confirm against the installed version.
    checkpoint = torch.load(filename, map_location='cpu')

    # Load into the wrapped module when the model is a DataParallel wrapper
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Scheduler state is optional (checkpoints may lack it)
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore the early-stopping bookkeeping.
    # float('-inf') replaces `-np.Inf`: that alias was removed in NumPy 2.0,
    # and because it was evaluated eagerly as the `.get` default it raised
    # AttributeError even when the key was present.
    neg_inf = float('-inf')
    early_stopping.counter = checkpoint['early_stopping_counter']
    early_stopping.domain_best_score = checkpoint.get('early_stopping_domain_best_score')
    early_stopping.class_best_score = checkpoint.get('early_stopping_class_best_score')
    early_stopping.best_epoch = checkpoint['early_stopping_best_epoch']
    early_stopping.domain_best_epoch = checkpoint.get('early_stopping_domain_best_epoch', early_stopping.best_epoch)
    early_stopping.class_best_epoch = checkpoint.get('early_stopping_class_best_epoch', early_stopping.best_epoch)
    early_stopping.early_stop = checkpoint['early_stopping_early_stop']
    early_stopping.domain_map_max = checkpoint.get('early_stopping_domain_map_max', neg_inf)
    early_stopping.class_map_max = checkpoint.get('early_stopping_class_map_max', neg_inf)

    # Remaining training state
    start_epoch = checkpoint['epoch'] + 1  # resume from the epoch after the saved one
    best_class_map = checkpoint['best_class_map']
    best_domain_map = checkpoint['best_domain_map']

    print(f"체크포인트에서 로드 완료: 에폭 {start_epoch}부터 시작합니다.")
    print(f"이전 최고 성능: Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map:.4f}")

    return model, optimizer, scheduler, early_stopping, start_epoch, best_class_map, best_domain_map

# Main training loop
def _save_best_weights(model, model_path):
    """Save the DataParallel-unwrapped model weights to ``model_path`` and upload the file to WandB."""
    model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    torch.save(model_state_dict, model_path)
    wandb.save(model_path)


def main_training_loop(model, trainloader, testloader, domain_criterion, class_criterion, optimizer, scheduler, device, num_epochs=None, patience=None, max_epochs_wait=None):
    """Train with per-epoch evaluation, LR scheduling, checkpointing and mAP-based early stopping.

    Each epoch runs ``train`` and ``evaluate``, steps the scheduler on the
    mean of the two test mAPs, logs to WandB, saves the weights whenever the
    test class mAP or test domain mAP reaches a new best, updates early
    stopping, and writes checkpoints. After the loop the best-class-mAP
    weights are loaded (when their file exists), evaluated once more, and a
    final checkpoint is written with those weights.

    Args:
        model: Model to train (optionally wrapped in ``nn.DataParallel``).
        trainloader: Training data loader.
        testloader: Evaluation data loader.
        domain_criterion: Loss for the domain head.
        class_criterion: Loss for the class head.
        optimizer: Optimizer.
        scheduler: LR scheduler stepped with the averaged test mAP.
        device: Device used for training and evaluation.
        num_epochs: Total number of epochs; defaults to ``config["num_epochs"]``.
        patience: Early-stopping patience; defaults to ``config["patience"]``.
        max_epochs_wait: Passed to the early stopper as ``max_epochs``;
            defaults to ``config["max_epochs_wait"]``.
    """
    # Fall back to the global config for unspecified values
    if num_epochs is None:
        num_epochs = config["num_epochs"]
    if patience is None:
        patience = config["patience"]
    if max_epochs_wait is None:
        max_epochs_wait = config["max_epochs_wait"]

    # Early stopping driven by both mAP values
    early_stopping = DualMAPEarlyStopping(
        patience=patience, 
        verbose=True, 
        path='checkpoint.pt', 
        max_epochs=max_epochs_wait
    )

    # Try to resume from the run's checkpoint
    model, optimizer, scheduler, early_stopping, start_epoch, best_class_map, best_domain_map = load_checkpoint(
        model, optimizer, scheduler, early_stopping
    )

    # Nothing was loaded: start the best-score tracking from zero
    if start_epoch == 0:
        best_class_map = 0.0
        best_domain_map = 0.0

    # Bind `epoch` before the loop so the post-loop code (summary, final
    # checkpoint) still works when the loop body never runs, e.g. when
    # resuming a run that already reached num_epochs.
    epoch = start_epoch - 1

    # Progress is displayed with tqdm
    for epoch in tqdm(range(start_epoch, num_epochs)):
        # Train for one epoch
        train_loss, train_domain_loss, train_class_loss, train_domain_acc, train_class_acc, train_domain_map, train_class_map = train(
            model, 
            trainloader, 
            domain_criterion, 
            class_criterion, 
            optimizer, 
            device, 
            epoch
        )

        # Evaluate on the test data
        test_loss, test_domain_loss, test_class_loss, test_domain_acc, test_class_acc, test_domain_map, test_class_map = evaluate(
            model, 
            testloader, 
            domain_criterion, 
            class_criterion, 
            device, 
            epoch
        )

        # Step the scheduler on the mean of the two test mAP values
        avg_map = (test_domain_map + test_class_map) / 2
        scheduler.step(avg_map)

        # Log to WandB (domain and class losses are logged separately)
        wandb.log({
            "epoch": epoch + 1,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "train_loss": train_loss,
            "train_domain_loss": train_domain_loss,
            "train_class_loss": train_class_loss,
            "train_domain_accuracy": train_domain_acc,
            "train_class_accuracy": train_class_acc,
            "train_domain_map": train_domain_map,
            "train_class_map": train_class_map,
            "test_loss": test_loss,
            "test_domain_loss": test_domain_loss,
            "test_class_loss": test_class_loss,
            "test_domain_accuracy": test_domain_acc,
            "test_class_accuracy": test_class_acc,
            "test_domain_map": test_domain_map,
            "test_class_map": test_class_map
        })

        # Save the weights on a new best class mAP
        if test_class_map > best_class_map:
            best_class_map = test_class_map
            best_domain_map_at_best_class = test_domain_map
            print(f'새로운 최고 Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map_at_best_class:.4f}')
            _save_best_weights(model, f'best_model_class_{wandb.run.name}.pth')

        # Save the weights on a new best domain mAP
        if test_domain_map > best_domain_map:
            best_domain_map = test_domain_map
            print(f'새로운 최고 Domain mAP: {best_domain_map:.4f}')
            _save_best_weights(model, f'best_model_domain_{wandb.run.name}.pth')

        # Update early stopping BEFORE writing any checkpoint, so the saved
        # early-stopping state already includes this epoch's result.
        early_stopping(test_domain_map, test_class_map, model, epoch)

        # Periodic resume checkpoint (every `save_every` epochs)
        if (epoch + 1) % config["save_every"] == 0:
            save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping)

        # Latest-state checkpoint, rewritten after every epoch
        save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, 
                       filename=os.path.join('checkpoints', 'latest_checkpoint.pth'))

        if early_stopping.early_stop:
            print(f"에폭 {epoch+1}에서 학습 조기 종료. 최고 성능 에폭: {early_stopping.best_epoch+1}")
            break

    # After training, load the best-class-mAP weights
    print("최고 클래스 mAP 모델 로드 중...")
    model_path = f'best_model_class_{wandb.run.name}.pth'
    if os.path.exists(model_path):
        best_state_dict = torch.load(model_path, map_location=device)
        # Load into the wrapped module when the model is a DataParallel wrapper
        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(best_state_dict)
        else:
            model.load_state_dict(best_state_dict)
    else:
        print(f"경고: {model_path} 파일이 존재하지 않습니다. 최종 모델을 사용합니다.")

    # Final evaluation on the test set
    final_test_loss, final_test_domain_loss, final_test_class_loss, final_test_domain_acc, final_test_class_acc, final_test_domain_map, final_test_class_map = evaluate(
        model, testloader, domain_criterion, class_criterion, device, num_epochs-1
    )

    print(f'완료! 최고 Class mAP: {best_class_map:.4f}, 최고 Domain mAP: {best_domain_map:.4f}')

    # Record the final results in the WandB run summary
    wandb.run.summary["best_class_map"] = best_class_map
    wandb.run.summary["best_domain_map"] = best_domain_map
    wandb.run.summary["final_test_class_map"] = final_test_class_map
    wandb.run.summary["final_test_domain_map"] = final_test_domain_map

    # Record early-stopping information
    if early_stopping.early_stop:
        wandb.run.summary["early_stopped"] = True
        wandb.run.summary["early_stopped_epoch"] = epoch+1
        wandb.run.summary["best_epoch"] = early_stopping.best_epoch+1
        wandb.run.summary["domain_best_epoch"] = early_stopping.domain_best_epoch+1
        wandb.run.summary["class_best_epoch"] = early_stopping.class_best_epoch+1
    else:
        wandb.run.summary["early_stopped"] = False

    # Final checkpoint (model weights are the ones loaded above)
    save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, 
                   filename=os.path.join('checkpoints', 'final_checkpoint.pth'))

# Script entry point
if __name__ == "__main__":
    # Seed every random number generator used by the run
    seed = config["seed"]
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic kernels on request, autotuned kernels otherwise
    if not config["deterministic"]:
        torch.backends.cudnn.benchmark = True
    else:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Pick the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Build the multi-task model
    model = ImprovedMultiTaskModel(
        num_domains=config["num_domains"], 
        num_classes=config["num_classes"]
    ).to(device)

    # Report how many parameters are trainable
    trainable_params = 0
    total_params = 0
    for p in model.parameters():
        n_elements = p.numel()
        total_params += n_elements
        if p.requires_grad:
            trainable_params += n_elements
    print(f"학습 가능한 파라미터: {trainable_params:,} / {total_params:,} ({trainable_params / total_params:.2%})")

    # Loss functions, one per task
    domain_criterion = nn.CrossEntropyLoss()
    class_criterion = nn.CrossEntropyLoss()

    # One parameter group per trainable sub-module (branches and classifiers),
    # so their learning rates can be set independently
    base_lr = config["learning_rate"]
    trainable_modules = (
        model.domain_branch,      # domain branch
        model.class_branch,       # class branch
        model.domain_classifier,  # domain classifier
        model.class_classifier,   # class classifier
    )
    optimizer = optim.Adam(
        [{'params': module.parameters(), 'lr': base_lr} for module in trainable_modules]
    )

    # Learning-rate scheduler
    scheduler_kwargs = dict(
        mode=config["scheduler_mode"],
        factor=config["scheduler_factor"],
        patience=config["scheduler_patience"],
        verbose=config["scheduler_verbose"],
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    # Let WandB track the model
    wandb.watch(model, log="all")

    # Use every available GPU
    if torch.cuda.device_count() > 1:
        print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
        model = nn.DataParallel(model)

    # Start of the wall-clock measurement
    start_time = time.time()

    # Run the main training loop
    try:
        main_training_loop(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            domain_criterion=domain_criterion,
            class_criterion=class_criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device
        )
    except KeyboardInterrupt:
        # Interrupted with Ctrl+C: save what is reachable from here.
        # NOTE: the loop's progress is not visible at this level, so this
        # checkpoint records epoch 0, best mAPs of 0.0 and a fresh
        # early-stopping state next to the current model/optimizer/scheduler
        # state.
        print("학습이 사용자에 의해 중단되었습니다. 현재 상태를 저장합니다.")
        early_stopping = DualMAPEarlyStopping(patience=config["patience"], verbose=True)
        save_checkpoint(model, optimizer, scheduler, 0, 0.0, 0.0, early_stopping, 
                       filename=os.path.join('checkpoints', 'interrupted_checkpoint.pth'))

    # Total elapsed time
    end_time = time.time()
    total_time = end_time - start_time
    wandb.log({"total_training_time": total_time})

    print(f"전체 학습 시간: {total_time:.2f} 초")

    # Close the WandB run
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/guswls/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msokjh1310[0m ([33msokjh1310-hanyang-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


총 11113 이미지, 65 클래스, 4 도메인을 로드했습니다.
총 3213 이미지, 65 클래스, 4 도메인을 로드했습니다.
Train set size: 11113
Test set size: 3213
Using device: cuda




학습 가능한 파라미터: 1,856,325 / 25,364,357 (7.32%)
2개의 GPU를 사용합니다.
체크포인트 파일 checkpoints/Dilated_ResNet50_MultiTask_layer3 - Style, all_Category_checkpoint.pth이 존재하지 않습니다. 처음부터 학습을 시작합니다.


  0%|                                                                                                       | 0/300 [00:00<?, ?it/s]

Epoch [1], Batch [20/44], Loss: 1.6300, Domain Loss: 0.6882, Class Loss: 2.5718, LR: 0.001000
Epoch [1], Batch [40/44], Loss: 1.1044, Domain Loss: 0.5634, Class Loss: 1.6454, LR: 0.001000
Train set: Epoch: 1, Avg loss: 1.6961, Domain Loss: 0.7673, Class Loss: 2.6250, Domain Acc: 68.68%, Class Acc: 41.71%, Domain mAP: 0.6984, Class mAP: 0.3613, Time: 148.49s
Test set: Epoch: 1, Avg loss: 1.0920, Domain Loss: 0.7076, Class Loss: 1.4763, Domain Acc: 73.20%, Class Acc: 63.96%, Domain mAP: 0.8015, Class mAP: 0.7280, Time: 42.72s

새로운 최고 Class mAP: 0.7280, Domain mAP: 0.8015
새로운 최고 Domain mAP: 0.8015
Performance improved. Saving model ...
Domain mAP: -inf --> 0.8015, Class mAP: -inf --> 0.7280


  0%|▎                                                                                          | 1/300 [03:12<15:57:50, 192.21s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [2], Batch [20/44], Loss: 0.9723, Domain Loss: 0.5098, Class Loss: 1.4347, LR: 0.001000
Epoch [2], Batch [40/44], Loss: 0.7825, Domain Loss: 0.4669, Class Loss: 1.0982, LR: 0.001000
Train set: Epoch: 2, Avg loss: 0.9198, Domain Loss: 0.5195, Class Loss: 1.3201, Domain Acc: 80.01%, Class Acc: 65.85%, Domain mAP: 0.8213, Class mAP: 0.6623, Time: 141.35s
Test set: Epoch: 2, Avg loss: 0.8471, Domain Loss: 0.6234, Class Loss: 1.0707, Domain Acc: 75.47%, Class Acc: 72.46%, Domain mAP: 0.8350, Class mAP: 0.7974, Time: 43.95s

새로운 최고 Class mAP: 0.7974, Domain mAP: 0.8350
새로운 최고 Domain mAP: 0.8350
Domain mAP improved (0.8015 --> 0.8350).
Class mAP improved (0.7280 --> 0.7974).
Performance improved. Saving model ...
Domain mAP: 0.8015 --> 0.8350, Class mAP: 0.7280 --> 0.7974


  1%|▌                                                                                          | 2/300 [06:18<15:38:06, 188.88s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [3], Batch [20/44], Loss: 0.7282, Domain Loss: 0.4228, Class Loss: 1.0336, LR: 0.001000
Epoch [3], Batch [40/44], Loss: 0.6515, Domain Loss: 0.4987, Class Loss: 0.8043, LR: 0.001000
Train set: Epoch: 3, Avg loss: 0.7328, Domain Loss: 0.4465, Class Loss: 1.0192, Domain Acc: 83.23%, Class Acc: 72.98%, Domain mAP: 0.8612, Class mAP: 0.7586, Time: 140.63s
Test set: Epoch: 3, Avg loss: 0.7685, Domain Loss: 0.5254, Class Loss: 1.0116, Domain Acc: 79.83%, Class Acc: 74.07%, Domain mAP: 0.8472, Class mAP: 0.8133, Time: 44.18s

새로운 최고 Class mAP: 0.8133, Domain mAP: 0.8472
새로운 최고 Domain mAP: 0.8472
Domain mAP improved (0.8350 --> 0.8472).
Class mAP improved (0.7974 --> 0.8133).
Performance improved. Saving model ...
Domain mAP: 0.8350 --> 0.8472, Class mAP: 0.7974 --> 0.8133


  1%|▉                                                                                          | 3/300 [09:24<15:28:22, 187.55s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [4], Batch [20/44], Loss: 0.6863, Domain Loss: 0.4717, Class Loss: 0.9010, LR: 0.001000
Epoch [4], Batch [40/44], Loss: 0.6347, Domain Loss: 0.4617, Class Loss: 0.8077, LR: 0.001000
Train set: Epoch: 4, Avg loss: 0.6254, Domain Loss: 0.4085, Class Loss: 0.8424, Domain Acc: 84.53%, Class Acc: 77.05%, Domain mAP: 0.8764, Class mAP: 0.8130, Time: 142.60s
Test set: Epoch: 4, Avg loss: 0.7133, Domain Loss: 0.5155, Class Loss: 0.9110, Domain Acc: 81.11%, Class Acc: 75.75%, Domain mAP: 0.8553, Class mAP: 0.8318, Time: 43.63s

새로운 최고 Class mAP: 0.8318, Domain mAP: 0.8553
새로운 최고 Domain mAP: 0.8553
Domain mAP improved (0.8472 --> 0.8553).
Class mAP improved (0.8133 --> 0.8318).
Performance improved. Saving model ...
Domain mAP: 0.8472 --> 0.8553, Class mAP: 0.8133 --> 0.8318


  1%|█▏                                                                                         | 4/300 [12:32<15:25:05, 187.52s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [5], Batch [20/44], Loss: 0.5856, Domain Loss: 0.4569, Class Loss: 0.7144, LR: 0.001000
Epoch [5], Batch [40/44], Loss: 0.5978, Domain Loss: 0.3796, Class Loss: 0.8159, LR: 0.001000
Train set: Epoch: 5, Avg loss: 0.5510, Domain Loss: 0.3703, Class Loss: 0.7317, Domain Acc: 86.20%, Class Acc: 79.99%, Domain mAP: 0.8938, Class mAP: 0.8472, Time: 143.06s
Test set: Epoch: 5, Avg loss: 0.7020, Domain Loss: 0.5267, Class Loss: 0.8772, Domain Acc: 80.11%, Class Acc: 76.91%, Domain mAP: 0.8517, Class mAP: 0.8352, Time: 43.68s

새로운 최고 Class mAP: 0.8352, Domain mAP: 0.8517
체크포인트가 checkpoints/Dilated_ResNet50_MultiTask_layer3 - Style, all_Category_checkpoint.pth에 저장되었습니다.
Class mAP improved (0.8318 --> 0.8352).
Performance improved. Saving model ...
Domain mAP: 0.8553 --> 0.8517, Class mAP: 0.8318 --> 0.8352


  2%|█▌                                                                                         | 5/300 [15:40<15:22:37, 187.65s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [6], Batch [20/44], Loss: 0.5414, Domain Loss: 0.4024, Class Loss: 0.6803, LR: 0.001000
Epoch [6], Batch [40/44], Loss: 0.4868, Domain Loss: 0.3426, Class Loss: 0.6310, LR: 0.001000
Train set: Epoch: 6, Avg loss: 0.4899, Domain Loss: 0.3391, Class Loss: 0.6407, Domain Acc: 87.36%, Class Acc: 81.91%, Domain mAP: 0.9087, Class mAP: 0.8730, Time: 139.15s
Test set: Epoch: 6, Avg loss: 0.7060, Domain Loss: 0.5169, Class Loss: 0.8951, Domain Acc: 81.36%, Class Acc: 75.85%, Domain mAP: 0.8572, Class mAP: 0.8361, Time: 43.99s

새로운 최고 Class mAP: 0.8361, Domain mAP: 0.8572
새로운 최고 Domain mAP: 0.8572
Domain mAP improved (0.8553 --> 0.8572).
Class mAP improved (0.8352 --> 0.8361).
Performance improved. Saving model ...
Domain mAP: 0.8553 --> 0.8572, Class mAP: 0.8352 --> 0.8361


  2%|█▊                                                                                         | 6/300 [18:44<15:13:44, 186.48s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [7], Batch [20/44], Loss: 0.4189, Domain Loss: 0.2968, Class Loss: 0.5410, LR: 0.001000
Epoch [7], Batch [40/44], Loss: 0.4381, Domain Loss: 0.3427, Class Loss: 0.5335, LR: 0.001000
Train set: Epoch: 7, Avg loss: 0.4444, Domain Loss: 0.3207, Class Loss: 0.5681, Domain Acc: 87.91%, Class Acc: 83.79%, Domain mAP: 0.9162, Class mAP: 0.8935, Time: 141.40s
Test set: Epoch: 7, Avg loss: 0.7676, Domain Loss: 0.6415, Class Loss: 0.8937, Domain Acc: 75.79%, Class Acc: 76.35%, Domain mAP: 0.8313, Class mAP: 0.8406, Time: 45.37s

새로운 최고 Class mAP: 0.8406, Domain mAP: 0.8313
Class mAP improved (0.8361 --> 0.8406).
Performance improved. Saving model ...
Domain mAP: 0.8572 --> 0.8313, Class mAP: 0.8361 --> 0.8406


  2%|██                                                                                         | 7/300 [21:51<15:12:35, 186.88s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [8], Batch [20/44], Loss: 0.4417, Domain Loss: 0.3780, Class Loss: 0.5055, LR: 0.001000
Epoch [8], Batch [40/44], Loss: 0.4179, Domain Loss: 0.3079, Class Loss: 0.5279, LR: 0.001000
Train set: Epoch: 8, Avg loss: 0.4020, Domain Loss: 0.3001, Class Loss: 0.5039, Domain Acc: 89.14%, Class Acc: 85.74%, Domain mAP: 0.9264, Class mAP: 0.9118, Time: 142.88s
Test set: Epoch: 8, Avg loss: 0.7054, Domain Loss: 0.5452, Class Loss: 0.8656, Domain Acc: 80.42%, Class Acc: 77.56%, Domain mAP: 0.8593, Class mAP: 0.8443, Time: 43.68s

새로운 최고 Class mAP: 0.8443, Domain mAP: 0.8593
새로운 최고 Domain mAP: 0.8593
Domain mAP improved (0.8572 --> 0.8593).
Class mAP improved (0.8406 --> 0.8443).
Performance improved. Saving model ...
Domain mAP: 0.8572 --> 0.8593, Class mAP: 0.8406 --> 0.8443


  3%|██▍                                                                                        | 8/300 [24:59<15:10:55, 187.18s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [9], Batch [20/44], Loss: 0.3723, Domain Loss: 0.2567, Class Loss: 0.4879, LR: 0.001000
Epoch [9], Batch [40/44], Loss: 0.3733, Domain Loss: 0.2415, Class Loss: 0.5051, LR: 0.001000
Train set: Epoch: 9, Avg loss: 0.3631, Domain Loss: 0.2779, Class Loss: 0.4484, Domain Acc: 89.88%, Class Acc: 86.90%, Domain mAP: 0.9363, Class mAP: 0.9257, Time: 144.43s
Test set: Epoch: 9, Avg loss: 0.7941, Domain Loss: 0.7422, Class Loss: 0.8460, Domain Acc: 75.51%, Class Acc: 77.00%, Domain mAP: 0.8498, Class mAP: 0.8472, Time: 44.25s

새로운 최고 Class mAP: 0.8472, Domain mAP: 0.8498
Class mAP improved (0.8443 --> 0.8472).
Performance improved. Saving model ...
Domain mAP: 0.8593 --> 0.8498, Class mAP: 0.8443 --> 0.8472


  3%|██▋                                                                                        | 9/300 [28:09<15:11:31, 187.94s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [10], Batch [20/44], Loss: 0.2784, Domain Loss: 0.2120, Class Loss: 0.3449, LR: 0.001000
Epoch [10], Batch [40/44], Loss: 0.3834, Domain Loss: 0.2646, Class Loss: 0.5022, LR: 0.001000
Train set: Epoch: 10, Avg loss: 0.3397, Domain Loss: 0.2743, Class Loss: 0.4050, Domain Acc: 90.07%, Class Acc: 88.36%, Domain mAP: 0.9369, Class mAP: 0.9382, Time: 144.91s
Test set: Epoch: 10, Avg loss: 0.6924, Domain Loss: 0.5639, Class Loss: 0.8210, Domain Acc: 80.73%, Class Acc: 77.56%, Domain mAP: 0.8604, Class mAP: 0.8559, Time: 44.34s

새로운 최고 Class mAP: 0.8559, Domain mAP: 0.8604
새로운 최고 Domain mAP: 0.8604
체크포인트가 checkpoints/Dilated_ResNet50_MultiTask_layer3 - Style, all_Category_checkpoint.pth에 저장되었습니다.
Domain mAP improved (0.8593 --> 0.8604).
Class mAP improved (0.8472 --> 0.8559).
Performance improved. Saving model ...
Domain mAP: 0.8593 --> 0.8604, Class mAP: 0.8472 --> 0.8559


  3%|███                                                                                       | 10/300 [31:20<15:12:23, 188.77s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [11], Batch [20/44], Loss: 0.2741, Domain Loss: 0.1963, Class Loss: 0.3518, LR: 0.001000
Epoch [11], Batch [40/44], Loss: 0.3364, Domain Loss: 0.2600, Class Loss: 0.4127, LR: 0.001000
Train set: Epoch: 11, Avg loss: 0.3061, Domain Loss: 0.2423, Class Loss: 0.3699, Domain Acc: 91.06%, Class Acc: 89.43%, Domain mAP: 0.9485, Class mAP: 0.9452, Time: 143.62s
Test set: Epoch: 11, Avg loss: 0.7289, Domain Loss: 0.5767, Class Loss: 0.8811, Domain Acc: 80.36%, Class Acc: 77.40%, Domain mAP: 0.8591, Class mAP: 0.8491, Time: 44.08s

EarlyStopping 카운터: 1 / 30 (Domain 최고: 0.8604, Class 최고: 0.8559)


  4%|███▎                                                                                      | 11/300 [34:28<15:08:12, 188.56s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [12], Batch [20/44], Loss: 0.2514, Domain Loss: 0.1944, Class Loss: 0.3085, LR: 0.001000
Epoch [12], Batch [40/44], Loss: 0.2762, Domain Loss: 0.2008, Class Loss: 0.3517, LR: 0.001000
Train set: Epoch: 12, Avg loss: 0.2806, Domain Loss: 0.2260, Class Loss: 0.3352, Domain Acc: 91.70%, Class Acc: 90.47%, Domain mAP: 0.9556, Class mAP: 0.9519, Time: 143.48s
Test set: Epoch: 12, Avg loss: 0.7908, Domain Loss: 0.7489, Class Loss: 0.8326, Domain Acc: 77.56%, Class Acc: 78.18%, Domain mAP: 0.8507, Class mAP: 0.8491, Time: 43.60s

EarlyStopping 카운터: 2 / 30 (Domain 최고: 0.8604, Class 최고: 0.8559)


  4%|███▌                                                                                      | 12/300 [37:35<15:03:24, 188.21s/it]

체크포인트가 checkpoints/latest_checkpoint.pth에 저장되었습니다.
Epoch [13], Batch [20/44], Loss: 0.2283, Domain Loss: 0.1683, Class Loss: 0.2883, LR: 0.001000
