In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import sys
import os
import time
import random
import numpy as np
import wandb
from tqdm import tqdm
from utility.utils import DualMAPEarlyStopping  # DualMAPEarlyStopping으로 변경
from torchmetrics.classification import MulticlassAveragePrecision
from Dataset.data import OfficeHomeDataset  # 기존 데이터셋 클래스 임포트
from models.resnet_dilated import ResnetDilated  # ResnetDilated 임포트
import timm
from models.efficientnet_teacher import EfficientNetTeacher, MultiTaskKDLoss, pretrain_teacher, train_with_kd


# Create the checkpoint directory up front so later torch.save calls can write into it.
os.makedirs('checkpoints', exist_ok=True)

# WandB setup.
# SECURITY: API keys must never be hard-coded in source or notebooks. With no
# explicit key, wandb.login() picks up the WANDB_API_KEY environment variable
# (or credentials stored by a previous `wandb login`). Any key that was
# previously committed here should be revoked and rotated.
wandb.login()
wandb.init(
    project="office-home-classification",
    name="Dilated_ResNet50_MultiTask_layer3 - Style, all_Category_KD",
    entity="hh0804352-hanyang-university",
)

# Use the wandb run name in the checkpoint path so different runs do not
# overwrite each other's checkpoints.
run_name = wandb.run.name
CHECKPOINT_PATH = os.path.join('checkpoints', f'{run_name}_checkpoint.pth')

# Experiment hyperparameters, read throughout this script via config[...].
config = dict(
    # Base training settings
    model="resnet50",
    batch_size=128,
    num_epochs=300,
    learning_rate=0.001,
    optimizer="Adam",
    seed=2025,
    deterministic=False,
    patience=20,                    # early-stopping patience (previously 30)
    max_epochs_wait=float('inf'),
    num_domains=4,
    num_classes=65,

    # Loss weighting tuned for KD: KD is relied on to curb overfitting, so the
    # domain term is down-weighted and the class term up-weighted.
    domain_weight=0.2,              # previously 0.5
    class_weight=0.8,               # previously 0.5

    # Teacher settings
    teacher_epochs=20,              # number of teacher pre-training epochs

    # Data loading, checkpointing and LR-scheduler settings
    num_workers=20,
    pin_memory=True,
    save_every=5,
    scheduler="ReduceLROnPlateau",
    scheduler_mode="max",
    scheduler_factor=0.1,
    scheduler_patience=5,
    scheduler_verbose=True,
)

# Record the hyperparameters defined in `config` in the wandb run's config.
wandb.config.update(config)

class ImprovedMultiTaskModel(nn.Module):
    """Two-head (domain + class) classifier on a shared dilated ResNet-50 backbone.

    The domain head branches off the backbone's ``layer3`` output and the class
    head off its ``layer4`` output. ``forward`` returns ``(domain_logits, class_logits)``.

    Args:
        num_domains: Number of outputs of the domain head.
        num_classes: Number of outputs of the class head.
        freeze_backbone: If True (default, original behavior), every backbone
            parameter gets ``requires_grad=False``.
        freeze_backbone_bn: If True, BatchNorm layers inside the backbone are
            kept in eval mode even while the model is in training mode, so
            their running statistics stay at the pretrained values. Defaults to
            False, which preserves the original behavior: ``requires_grad=False``
            alone does not stop BatchNorm running stats from updating under
            ``model.train()``.
    """

    def __init__(self, num_domains=4, num_classes=65, freeze_backbone=True,
                 freeze_backbone_bn=False):
        super().__init__()
        # Consulted by train(); a plain bool, so it does not affect state_dict.
        self.freeze_backbone_bn = freeze_backbone_bn

        # Shared backbone: pretrained torchvision ResNet-50 wrapped by ResnetDilated.
        pretrained_resnet = models.resnet50(pretrained=True)
        self.backbone = ResnetDilated(pretrained_resnet, dilate_scale=8)

        # Freeze backbone weights so only the branches/heads receive gradients.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Domain branch on layer3 features: 1x1 conv reduces 1024 -> 128 channels.
        self.domain_branch = nn.Sequential(
            nn.Conv2d(1024, 128, kernel_size=1),
            nn.GroupNorm(8, 128),
            nn.ReLU(inplace=True),
        )

        # Class branch on layer4 features: 1x1 conv reduces 2048 -> 512 channels.
        self.class_branch = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Global average pooling shared by both branches.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification heads.
        self.domain_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(64, num_domains)
        )

        self.class_classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def train(self, mode=True):
        """Set train/eval mode; optionally pin backbone BatchNorm layers to eval mode.

        Returns ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        if mode and getattr(self, "freeze_backbone_bn", False):
            bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
            for module in self.backbone.modules():
                if isinstance(module, bn_types):
                    module.eval()
        return self

    def forward(self, x):
        """Return ``(domain_logits, class_logits)`` for the input batch ``x``."""
        # Run the backbone stem and early stages manually so the intermediate
        # layer3 output can be tapped for the domain branch.
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu1(x)  # attribute name as defined by ResnetDilated
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)

        # layer3 output feeds the domain branch.
        layer3_output = self.backbone.layer3(x)

        # layer4 output feeds the class branch.
        layer4_output = self.backbone.layer4(layer3_output)

        # Domain path.
        domain_features = self.domain_branch(layer3_output)
        domain_features = self.avgpool(domain_features)
        domain_features = torch.flatten(domain_features, 1)
        domain_out = self.domain_classifier(domain_features)

        # Class path.
        class_features = self.class_branch(layer4_output)
        class_features = self.avgpool(class_features)
        class_features = torch.flatten(class_features, 1)
        class_out = self.class_classifier(class_features)

        return domain_out, class_out


# Image preprocessing. RandomResizedCrop was intentionally removed, so the train
# and test pipelines are identical: fixed 224x224 resize, tensor conversion,
# then per-channel normalization with the ImageNet mean/std values.
def _make_resize_normalize_transform():
    """Build a fresh Resize(224x224) -> ToTensor -> Normalize pipeline."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)


transform_train = _make_resize_normalize_transform()
transform_test = _make_resize_normalize_transform()

# Load the Office-Home train/test splits from disk.
trainset = OfficeHomeDataset(root_dir='./Dataset/train', transform=transform_train)
testset = OfficeHomeDataset(root_dir='./Dataset/test', transform=transform_test)

# Loader settings shared by both splits; only shuffling differs.
_shared_loader_kwargs = {
    "batch_size": config["batch_size"],
    "pin_memory": config["pin_memory"],
    "num_workers": config["num_workers"],
}
trainloader = DataLoader(trainset, shuffle=True, **_shared_loader_kwargs)
testloader = DataLoader(testset, shuffle=False, **_shared_loader_kwargs)

print(f"Train set size: {len(trainset)}")
print(f"Test set size: {len(testset)}")

def train(model, trainloader, domain_criterion, class_criterion, optimizer, device, epoch):
    """Train the multi-task (domain + class) model for one epoch.

    The optimized loss is
    ``config["domain_weight"] * domain_loss + config["class_weight"] * class_loss``.

    Args:
        model: Module whose forward returns ``(domain_logits, class_logits)``.
        trainloader: Yields ``(inputs, domain_labels, class_labels)`` batches.
        domain_criterion: Loss applied to the domain head.
        class_criterion: Loss applied to the class head.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used only in log messages).

    Returns:
        ``(epoch_loss, epoch_domain_loss, epoch_class_loss, domain_accuracy,
        class_accuracy, domain_map, class_map)``: losses are the mean of the
        per-batch losses, accuracies are percentages, and mAPs are macro
        averages computed by torchmetrics.

    Raises:
        ValueError: If ``trainloader`` has no batches.
    """
    num_batches = len(trainloader)
    if num_batches == 0:
        raise ValueError("trainloader is empty; cannot run a training epoch")

    model.train()
    start_time = time.time()
    running_loss = 0.0
    running_domain_loss = 0.0
    running_class_loss = 0.0
    domain_correct = 0
    class_correct = 0
    total = 0

    # Macro-averaged AP accumulators for this epoch.
    domain_map = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    for i, (inputs, domain_labels, class_labels) in enumerate(trainloader):
        # non_blocking lets host->device copies overlap with compute when the
        # DataLoader uses pinned memory; it is a plain copy otherwise.
        inputs = inputs.to(device, non_blocking=True)
        domain_labels = domain_labels.to(device, non_blocking=True)
        class_labels = class_labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        domain_outputs, class_outputs = model(inputs)

        domain_loss = domain_criterion(domain_outputs, domain_labels)
        class_loss = class_criterion(class_outputs, class_labels)
        loss = config["domain_weight"] * domain_loss + config["class_weight"] * class_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_domain_loss += domain_loss.item()
        running_class_loss += class_loss.item()

        # Metric bookkeeping must not hold on to the autograd graph, so work
        # on detached logits (the metrics keep predictions for the whole epoch).
        domain_logits = domain_outputs.detach()
        class_logits = class_outputs.detach()

        _, domain_preds = domain_logits.max(1)
        domain_correct += domain_preds.eq(domain_labels).sum().item()

        _, class_preds = class_logits.max(1)
        class_correct += class_preds.eq(class_labels).sum().item()

        total += inputs.size(0)

        domain_map.update(domain_logits, domain_labels)
        class_map.update(class_logits, class_labels)

        if (i + 1) % 20 == 0:
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(trainloader)}], Loss: {loss.item():.4f}, '
                  f'Domain Loss: {domain_loss.item():.4f}, Class Loss: {class_loss.item():.4f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

    # Epoch statistics.
    epoch_loss = running_loss / num_batches
    epoch_domain_loss = running_domain_loss / num_batches
    epoch_class_loss = running_class_loss / num_batches
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    train_time = time.time() - start_time

    print(f'Train set: Epoch: {epoch+1}, Avg loss: {epoch_loss:.4f}, '
          f'Domain Loss: {epoch_domain_loss:.4f}, Class Loss: {epoch_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {train_time:.2f}s')

    return epoch_loss, epoch_domain_loss, epoch_class_loss, domain_accuracy, class_accuracy, domain_map_value, class_map_value

def evaluate(model, dataloader, domain_criterion, class_criterion, device, epoch):
    """Evaluate the multi-task model on ``dataloader`` with gradients disabled.

    Uses the same loss weighting as training:
    ``config["domain_weight"] * domain_loss + config["class_weight"] * class_loss``.

    Args:
        model: Module whose forward returns ``(domain_logits, class_logits)``.
        dataloader: Yields ``(inputs, domain_labels, class_labels)`` batches.
        domain_criterion: Loss applied to the domain head.
        class_criterion: Loss applied to the class head.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used only in the printed summary).

    Returns:
        ``(eval_loss, eval_domain_loss, eval_class_loss, domain_accuracy,
        class_accuracy, domain_map, class_map)``: losses are the mean of the
        per-batch losses, accuracies are percentages, and mAPs are macro averages.
    """
    model.eval()
    tic = time.time()

    loss_sum = {"combined": 0.0, "domain": 0.0, "class": 0.0}
    n_correct = {"domain": 0, "class": 0}
    n_seen = 0

    domain_ap = MulticlassAveragePrecision(num_classes=config["num_domains"], average='macro').to(device)
    class_ap = MulticlassAveragePrecision(num_classes=config["num_classes"], average='macro').to(device)

    with torch.no_grad():
        for batch_x, batch_domain, batch_class in dataloader:
            batch_x = batch_x.to(device)
            batch_domain = batch_domain.to(device)
            batch_class = batch_class.to(device)

            domain_logits, class_logits = model(batch_x)

            d_loss = domain_criterion(domain_logits, batch_domain)
            c_loss = class_criterion(class_logits, batch_class)
            weighted = config["domain_weight"] * d_loss + config["class_weight"] * c_loss

            loss_sum["combined"] += weighted.item()
            loss_sum["domain"] += d_loss.item()
            loss_sum["class"] += c_loss.item()

            # Top-1 hits: index of the max logit per sample.
            n_correct["domain"] += domain_logits.max(dim=1).indices.eq(batch_domain).sum().item()
            n_correct["class"] += class_logits.max(dim=1).indices.eq(batch_class).sum().item()
            n_seen += batch_x.size(0)

            domain_ap.update(domain_logits, batch_domain)
            class_ap.update(class_logits, batch_class)

    n_batches = len(dataloader)
    eval_loss = loss_sum["combined"] / n_batches
    eval_domain_loss = loss_sum["domain"] / n_batches
    eval_class_loss = loss_sum["class"] / n_batches
    domain_accuracy = 100.0 * n_correct["domain"] / n_seen
    class_accuracy = 100.0 * n_correct["class"] / n_seen

    domain_map_value = domain_ap.compute().item()
    class_map_value = class_ap.compute().item()

    eval_time = time.time() - tic

    print(f'Test set: Epoch: {epoch+1}, Avg loss: {eval_loss:.4f}, '
          f'Domain Loss: {eval_domain_loss:.4f}, Class Loss: {eval_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {eval_time:.2f}s')
    print()

    return (
        eval_loss,
        eval_domain_loss,
        eval_class_loss,
        domain_accuracy,
        class_accuracy,
        domain_map_value,
        class_map_value,
    )

# Checkpoint saving.
def save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, filename=CHECKPOINT_PATH):
    """Save the full training state (model, optimizer, scheduler, early stopping, config).

    The checkpoint is first written to ``<filename>.tmp`` and then moved onto
    ``filename`` with ``os.replace``. An interruption mid-write (Ctrl+C, crash,
    full disk) therefore cannot leave a truncated file at ``filename``, and the
    previous good checkpoint survives. A failed partial temp file is removed.

    Args:
        model: Model to save; a ``nn.DataParallel`` wrapper is unwrapped first.
        optimizer: Optimizer whose ``state_dict`` is stored.
        scheduler: LR scheduler whose ``state_dict`` is stored.
        epoch: Index of the last completed epoch.
        best_class_map: Best class mAP seen so far.
        best_domain_map: Best domain mAP seen so far.
        early_stopping: ``DualMAPEarlyStopping`` instance whose fields are stored.
        filename: Destination path.
    """
    # Store the unwrapped module's weights so they load without DataParallel.
    model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_class_map': best_class_map,
        'best_domain_map': best_domain_map,
        'early_stopping_counter': early_stopping.counter,
        'early_stopping_domain_best_score': early_stopping.domain_best_score,
        'early_stopping_class_best_score': early_stopping.class_best_score,
        'early_stopping_best_epoch': early_stopping.best_epoch,
        'early_stopping_domain_best_epoch': early_stopping.domain_best_epoch,
        'early_stopping_class_best_epoch': early_stopping.class_best_epoch,
        'early_stopping_early_stop': early_stopping.early_stop,
        'early_stopping_domain_map_max': early_stopping.domain_map_max,
        'early_stopping_class_map_max': early_stopping.class_map_max,
        'config': config,  # hyperparameters, for reproducibility
    }

    tmp_filename = f"{filename}.tmp"
    try:
        torch.save(checkpoint, tmp_filename)
        # Atomic on the same filesystem: readers see either the old or new file.
        os.replace(tmp_filename, filename)
    except BaseException:
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise
    print(f"체크포인트가 {filename}에 저장되었습니다.")

# Checkpoint loading.
def load_checkpoint(model, optimizer, scheduler, early_stopping, filename=CHECKPOINT_PATH):
    """Restore training state saved by ``save_checkpoint``.

    If ``filename`` does not exist, nothing is modified and training starts
    from scratch.

    Args:
        model: Model to load weights into (``nn.DataParallel`` is unwrapped).
        optimizer: Optimizer whose state is restored.
        scheduler: LR scheduler whose state is restored when present.
        early_stopping: ``DualMAPEarlyStopping`` instance whose fields are restored.
        filename: Checkpoint path.

    Returns:
        ``(model, optimizer, scheduler, early_stopping, start_epoch,
        best_class_map, best_domain_map)``. ``start_epoch`` is the epoch after
        the saved one, or 0 (with both best mAPs 0.0) when no file exists.
    """
    if not os.path.exists(filename):
        print(f"체크포인트 파일 {filename}이 존재하지 않습니다. 처음부터 학습을 시작합니다.")
        return model, optimizer, scheduler, early_stopping, 0, 0.0, 0.0

    print(f"체크포인트 {filename}을 로드합니다.")
    # Load onto CPU first so resuming works regardless of which device saved the
    # file; load_state_dict copies model weights onto the parameters' device and
    # Optimizer.load_state_dict moves its state tensors to the parameters' device.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if the
    # early-stopping scores stored here are not plain Python/tensor values, loading
    # may need weights_only=False for this trusted, self-written file — confirm.
    checkpoint = torch.load(filename, map_location="cpu")

    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Older checkpoints may predate scheduler-state saving.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore DualMAP early-stopping state; .get() defaults keep older checkpoints loadable.
    early_stopping.counter = checkpoint['early_stopping_counter']
    early_stopping.domain_best_score = checkpoint.get('early_stopping_domain_best_score')
    early_stopping.class_best_score = checkpoint.get('early_stopping_class_best_score')
    early_stopping.best_epoch = checkpoint['early_stopping_best_epoch']
    early_stopping.domain_best_epoch = checkpoint.get('early_stopping_domain_best_epoch', early_stopping.best_epoch)
    early_stopping.class_best_epoch = checkpoint.get('early_stopping_class_best_epoch', early_stopping.best_epoch)
    early_stopping.early_stop = checkpoint['early_stopping_early_stop']
    # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0, and a
    # .get() default is evaluated even when the key exists.
    early_stopping.domain_map_max = checkpoint.get('early_stopping_domain_map_max', -np.inf)
    early_stopping.class_map_max = checkpoint.get('early_stopping_class_map_max', -np.inf)

    start_epoch = checkpoint['epoch'] + 1  # resume with the next epoch
    best_class_map = checkpoint['best_class_map']
    best_domain_map = checkpoint['best_domain_map']

    print(f"체크포인트에서 로드 완료: 에폭 {start_epoch}부터 시작합니다.")
    print(f"이전 최고 성능: Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map:.4f}")

    return model, optimizer, scheduler, early_stopping, start_epoch, best_class_map, best_domain_map

# Main (non-KD) training loop.
def main_training_loop(model, trainloader, testloader, domain_criterion, class_criterion, optimizer, scheduler, device, num_epochs=None, patience=None, max_epochs_wait=None):
    """Train and evaluate with dual-mAP early stopping, checkpointing and wandb logging.

    Resumes from ``load_checkpoint``'s default path when that file exists. Each
    epoch:
    - trains and evaluates the model;
    - steps the scheduler on the mean of the test domain/class mAPs;
    - logs the metrics to wandb;
    - saves the best-class and best-domain weights;
    - updates early stopping;
    - writes checkpoints.

    After training, the best-class-mAP weights are reloaded and evaluated once more.

    Args:
        model, trainloader, testloader, domain_criterion, class_criterion,
        optimizer, scheduler, device: Passed through to ``train``/``evaluate``.
        num_epochs: Total epochs (defaults to ``config["num_epochs"]``).
        patience: Early-stopping patience (defaults to ``config["patience"]``).
        max_epochs_wait: Forwarded to ``DualMAPEarlyStopping(max_epochs=...)``
            (defaults to ``config["max_epochs_wait"]``).
    """
    if num_epochs is None:
        num_epochs = config["num_epochs"]
    if patience is None:
        patience = config["patience"]
    if max_epochs_wait is None:
        max_epochs_wait = config["max_epochs_wait"]

    early_stopping = DualMAPEarlyStopping(
        patience=patience,
        verbose=True,
        path='checkpoint.pt',
        max_epochs=max_epochs_wait
    )

    # Try to resume from an earlier run.
    model, optimizer, scheduler, early_stopping, start_epoch, best_class_map, best_domain_map = load_checkpoint(
        model, optimizer, scheduler, early_stopping
    )

    if start_epoch == 0:
        best_class_map = 0.0
        best_domain_map = 0.0

    # Index of the last fully completed epoch. Initialized from the resume point
    # so the summary and final checkpoint below stay well-defined even if the
    # loop body never runs (e.g. resuming a run that already reached num_epochs).
    last_epoch = start_epoch - 1

    for epoch in tqdm(range(start_epoch, num_epochs)):
        # A resumed run whose checkpoint was already early-stopped must not train further.
        if early_stopping.early_stop:
            print("체크포인트가 이미 조기 종료된 상태이므로 추가 학습을 건너뜁니다.")
            break

        train_loss, train_domain_loss, train_class_loss, train_domain_acc, train_class_acc, train_domain_map, train_class_map = train(
            model,
            trainloader,
            domain_criterion,
            class_criterion,
            optimizer,
            device,
            epoch
        )

        test_loss, test_domain_loss, test_class_loss, test_domain_acc, test_class_acc, test_domain_map, test_class_map = evaluate(
            model,
            testloader,
            domain_criterion,
            class_criterion,
            device,
            epoch
        )

        # Drive ReduceLROnPlateau with the mean of both test mAPs.
        avg_map = (test_domain_map + test_class_map) / 2
        scheduler.step(avg_map)

        wandb.log({
            "epoch": epoch + 1,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "train_loss": train_loss,
            "train_domain_loss": train_domain_loss,
            "train_class_loss": train_class_loss,
            "train_domain_accuracy": train_domain_acc,
            "train_class_accuracy": train_class_acc,
            "train_domain_map": train_domain_map,
            "train_class_map": train_class_map,
            "test_loss": test_loss,
            "test_domain_loss": test_domain_loss,
            "test_class_loss": test_class_loss,
            "test_domain_accuracy": test_domain_acc,
            "test_class_accuracy": test_class_acc,
            "test_domain_map": test_domain_map,
            "test_class_map": test_class_map
        })

        # Save weights with the best class mAP so far.
        if test_class_map > best_class_map:
            best_class_map = test_class_map
            best_domain_map_at_best_class = test_domain_map
            print(f'새로운 최고 Class mAP: {best_class_map:.4f}, Domain mAP: {best_domain_map_at_best_class:.4f}')
            model_path = f'best_model_class_{wandb.run.name}.pth'
            model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(model_state_dict, model_path)
            wandb.save(model_path)

        # Save weights with the best domain mAP so far.
        if test_domain_map > best_domain_map:
            best_domain_map = test_domain_map
            print(f'새로운 최고 Domain mAP: {best_domain_map:.4f}')
            model_path = f'best_model_domain_{wandb.run.name}.pth'
            model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(model_state_dict, model_path)
            wandb.save(model_path)

        # Update early stopping BEFORE checkpointing so every checkpoint written
        # for this epoch already contains this epoch's early-stopping result
        # (resuming starts at epoch + 1 and would otherwise skip that update).
        early_stopping(test_domain_map, test_class_map, model, epoch)
        last_epoch = epoch

        # Periodic run-specific checkpoint (the one load_checkpoint resumes from).
        if (epoch + 1) % config["save_every"] == 0:
            save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping)

        # Always keep the most recent state as well.
        save_checkpoint(model, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping,
                        filename=os.path.join('checkpoints', 'latest_checkpoint.pth'))

        if early_stopping.early_stop:
            print(f"에폭 {epoch+1}에서 학습 조기 종료. 최고 성능 에폭: {early_stopping.best_epoch+1}")
            break

    # Reload the best-class-mAP weights for the final evaluation.
    print("최고 클래스 mAP 모델 로드 중...")
    model_path = f'best_model_class_{wandb.run.name}.pth'
    if os.path.exists(model_path):
        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(torch.load(model_path))
    else:
        print(f"경고: {model_path} 파일이 존재하지 않습니다. 최종 모델을 사용합니다.")

    final_test_loss, final_test_domain_loss, final_test_class_loss, final_test_domain_acc, final_test_class_acc, final_test_domain_map, final_test_class_map = evaluate(
        model, testloader, domain_criterion, class_criterion, device, num_epochs-1
    )

    print(f'완료! 최고 Class mAP: {best_class_map:.4f}, 최고 Domain mAP: {best_domain_map:.4f}')

    wandb.run.summary["best_class_map"] = best_class_map
    wandb.run.summary["best_domain_map"] = best_domain_map
    wandb.run.summary["final_test_class_map"] = final_test_class_map
    wandb.run.summary["final_test_domain_map"] = final_test_domain_map

    if early_stopping.early_stop:
        wandb.run.summary["early_stopped"] = True
        wandb.run.summary["early_stopped_epoch"] = last_epoch + 1
        wandb.run.summary["best_epoch"] = early_stopping.best_epoch+1
        wandb.run.summary["domain_best_epoch"] = early_stopping.domain_best_epoch+1
        wandb.run.summary["class_best_epoch"] = early_stopping.class_best_epoch+1
    else:
        wandb.run.summary["early_stopped"] = False

    # Final checkpoint, stamped with the last epoch that actually completed.
    save_checkpoint(model, optimizer, scheduler, last_epoch, best_class_map, best_domain_map, early_stopping,
                    filename=os.path.join('checkpoints', 'final_checkpoint.pth'))

# Entry point: seed RNGs, build the EfficientNet teacher and dilated-ResNet
# student, then run the knowledge-distillation (KD) training loop.
if __name__ == "__main__":
    # Seed every RNG in use.
    seed = config["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels vs. autotuned (faster, non-deterministic) ones.
    if config["deterministic"]:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Teacher model.
    teacher = EfficientNetTeacher(
        num_domains=config["num_domains"],
        num_classes=config["num_classes"]
    ).to(device)

    print(f"Teacher 파라미터: {sum(p.numel() for p in teacher.parameters()):,}")

    # Student model.
    student = ImprovedMultiTaskModel(
        num_domains=config["num_domains"],
        num_classes=config["num_classes"]
    ).to(device)

    print(f"Student 파라미터: {sum(p.numel() for p in student.parameters()):,}")

    # Reuse a saved teacher if available; otherwise pre-train one.
    teacher_checkpoint_path = "efficientnet_teacher_best.pth"

    if os.path.exists(teacher_checkpoint_path):
        print(f"✅ 기존 Teacher 모델 발견: {teacher_checkpoint_path}")
        print("📁 저장된 Teacher 모델을 로드합니다...")
        teacher.load_state_dict(torch.load(teacher_checkpoint_path, map_location=device))
        print("🎓 Teacher 모델 로드 완료! KD 학습을 시작합니다.")
    else:
        print("❌ Teacher 체크포인트가 없습니다. Teacher를 훈련합니다...")
        teacher = pretrain_teacher(teacher, trainloader, testloader, device, epochs=config["teacher_epochs"])

    # The teacher only supplies soft targets during KD.
    # - eval() fixes its BatchNorm statistics and disables dropout, so the
    #   targets are deterministic and its BN stats do not drift.
    # - requires_grad_(False) stops autograd from building or storing
    #   gradients for it; the optimizer below never updates the teacher anyway.
    # Both are no-ops if train_with_kd already does the same.
    # NOTE(review): confirm train_with_kd does not switch the teacher back to train().
    teacher.eval()
    teacher.requires_grad_(False)

    # KD loss. The hyperparameters are also pushed to the wandb config so runs
    # are reproducible from the dashboard.
    kd_hparams = {
        "domain_alpha": 0.8,
        "class_alpha": 0.6,
        "domain_temp": 5.0,
        "class_temp": 3.0,
    }
    kd_loss_fn = MultiTaskKDLoss(**kd_hparams)
    wandb.config.update({f"kd_{name}": value for name, value in kd_hparams.items()})

    # Only the trainable branches/heads are optimized (the student backbone is frozen).
    optimizer = optim.Adam([
        {'params': student.domain_branch.parameters(), 'lr': config["learning_rate"]},
        {'params': student.class_branch.parameters(), 'lr': config["learning_rate"]},
        {'params': student.domain_classifier.parameters(), 'lr': config["learning_rate"]},
        {'params': student.class_classifier.parameters(), 'lr': config["learning_rate"]}
    ])

    # Multi-GPU data parallelism.
    if torch.cuda.device_count() > 1:
        print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
        student = nn.DataParallel(student)
        teacher = nn.DataParallel(teacher)

    wandb.watch(student, log="all")

    # NOTE: the `verbose` argument is deprecated in recent PyTorch releases.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=config["scheduler_mode"],
        factor=config["scheduler_factor"],
        patience=config["scheduler_patience"],
        verbose=config["scheduler_verbose"]
    )

    # KD training loop (used instead of main_training_loop).
    early_stopping = DualMAPEarlyStopping(
        patience=config["patience"],
        verbose=True,
        path='checkpoint_kd.pt',
        max_epochs=config["max_epochs_wait"]
    )

    best_class_map = 0.0
    best_domain_map = 0.0
    # Evaluation losses are stateless, so build them once instead of every epoch.
    eval_domain_criterion = nn.CrossEntropyLoss()
    eval_class_criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    print("Knowledge Distillation 훈련 시작!")

    run_failed = False
    try:
        for epoch in range(config["num_epochs"]):
            # KD training step (train_with_kd instead of train).
            train_results = train_with_kd(
                student, teacher, trainloader, kd_loss_fn, optimizer, device, epoch, config
            )

            train_loss, train_domain_loss, train_class_loss, train_domain_acc, train_class_acc, train_domain_map, train_class_map, detailed_losses = train_results

            # Evaluation with plain cross-entropy on both heads.
            test_loss, test_domain_loss, test_class_loss, test_domain_acc, test_class_acc, test_domain_map, test_class_map = evaluate(
                student, testloader, eval_domain_criterion, eval_class_criterion, device, epoch
            )

            # Drive ReduceLROnPlateau with the mean of both test mAPs.
            avg_map = (test_domain_map + test_class_map) / 2
            scheduler.step(avg_map)

            wandb.log({
                "epoch": epoch + 1,
                "learning_rate": optimizer.param_groups[0]['lr'],
                "train_loss": train_loss,
                "train_domain_accuracy": train_domain_acc,
                "train_class_accuracy": train_class_acc,
                "train_domain_map": train_domain_map,
                "train_class_map": train_class_map,
                "test_loss": test_loss,
                "test_domain_accuracy": test_domain_acc,
                "test_class_accuracy": test_class_acc,
                "test_domain_map": test_domain_map,
                "test_class_map": test_class_map,
                # KD-specific loss components
                "kd_domain_kd_loss": detailed_losses['domain_kd'],
                "kd_domain_hard_loss": detailed_losses['domain_hard'],
                "kd_class_kd_loss": detailed_losses['class_kd'],
                "kd_class_hard_loss": detailed_losses['class_hard'],
            })

            # Save the best student weights per metric (unwrapping DataParallel).
            if test_class_map > best_class_map:
                best_class_map = test_class_map
                model_to_save = student.module if isinstance(student, nn.DataParallel) else student
                torch.save(model_to_save.state_dict(), f'best_student_class_kd_{wandb.run.name}.pth')
                print(f'🏆 새로운 최고 Class mAP: {best_class_map:.4f}')

            if test_domain_map > best_domain_map:
                best_domain_map = test_domain_map
                model_to_save = student.module if isinstance(student, nn.DataParallel) else student
                torch.save(model_to_save.state_dict(), f'best_student_domain_kd_{wandb.run.name}.pth')
                print(f'🏆 새로운 최고 Domain mAP: {best_domain_map:.4f}')

            early_stopping(test_domain_map, test_class_map, student, epoch)
            if early_stopping.early_stop:
                print(f"🛑 KD 훈련 조기 종료 - 에폭 {epoch+1}")
                break

        print(f'🎓 KD 훈련 완료!')
        print(f'최고 Class mAP: {best_class_map:.4f}')
        print(f'최고 Domain mAP: {best_domain_map:.4f}')

        wandb.run.summary["best_class_map_kd"] = best_class_map
        wandb.run.summary["best_domain_map_kd"] = best_domain_map

    except KeyboardInterrupt:
        print("KD 훈련이 중단되었습니다.")
    except Exception:
        # Let the error propagate, but mark the wandb run as failed first (below).
        run_failed = True
        raise
    finally:
        # Always record timing and close the wandb run, even on errors.
        end_time = time.time()
        total_time = end_time - start_time
        wandb.log({"total_training_time": total_time})
        print(f"전체 학습 시간: {total_time:.2f} 초")

        if run_failed:
            wandb.finish(exit_code=1)
        else:
            wandb.finish()



총 11113 이미지, 65 클래스, 4 도메인을 로드했습니다.
총 3213 이미지, 65 클래스, 4 도메인을 로드했습니다.
Train set size: 11113
Test set size: 3213
Using device: cuda


  model = create_fn(


Teacher 파라미터: 10,489,543




Student 파라미터: 24,846,149
✅ 기존 Teacher 모델 발견: efficientnet_teacher_best.pth
📁 저장된 Teacher 모델을 로드합니다...
🎓 Teacher 모델 로드 완료! KD 학습을 시작합니다.
2개의 GPU를 사용합니다.
Knowledge Distillation 훈련 시작!




KD Epoch [1], Batch [20/87]
Total: 15.0850 | Domain KD: 17.6160 | Class KD: 23.3997
KD Epoch [1], Batch [40/87]
Total: 12.2148 | Domain KD: 16.2182 | Class KD: 18.1731
KD Epoch [1], Batch [60/87]
Total: 10.2639 | Domain KD: 12.7069 | Class KD: 15.6878
KD Epoch [1], Batch [80/87]
Total: 8.2902 | Domain KD: 10.6157 | Class KD: 12.5795
KD Train Epoch 1:
Total Loss: 12.1558 | Domain Acc: 54.44% | Class Acc: 39.13%
Domain mAP: 0.5311 | Class mAP: 0.3531
KD Losses - Domain: 15.1380 | Class: 18.5203
Time: 307.98s

Test set: Epoch: 1, Avg loss: 1.3692, Domain Loss: 1.0692, Class Loss: 1.4442, Domain Acc: 63.80%, Class Acc: 62.31%, Domain mAP: 0.7115, Class mAP: 0.7023, Time: 50.94s

🏆 새로운 최고 Class mAP: 0.7023
🏆 새로운 최고 Domain mAP: 0.7115
Performance improved. Saving model ...
Domain mAP: -inf --> 0.7115, Class mAP: -inf --> 0.7023
KD Epoch [2], Batch [20/87]
Total: 6.7091 | Domain KD: 8.0048 | Class KD: 10.1600
KD Epoch [2], Batch [40/87]
Total: 6.9191 | Domain KD: 10.4619 | Class KD: 9.8846
KD