In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import sys
import os
import time
import random
import numpy as np
import wandb
from tqdm import tqdm
from utility.utils import DualMAPEarlyStopping
from torchmetrics.classification import MulticlassAveragePrecision
from Dataset.data import OfficeHomeDataset
from models.resnet_dilated import ResnetDilated
import timm
from models.efficientnet_teacher import EfficientNetTeacher, MultiTaskKDLoss, pretrain_teacher, train_with_kd

# Create the checkpoint directory (no-op if it already exists).
os.makedirs('checkpoints', exist_ok=True)

# WandB setup.
# SECURITY: the API key must not be hard-coded in source. wandb.login() with no
# key reads the WANDB_API_KEY environment variable (or a stored ~/.netrc login).
# The key that was previously committed here should be treated as leaked and rotated.
wandb.login()
wandb.init(
    project="office-home-classification",
    name="Dilated_ResNet50_MultiTask_layer3 - Style, all_Category_KD",
    entity="hh0804352-hanyang-university"
)

# Build a per-run checkpoint path from the wandb run name.
run_name = wandb.run.name
CHECKPOINT_PATH = os.path.join('checkpoints', f'{run_name}_checkpoint.pth')

# Experiment configuration; the whole dict is mirrored into the wandb run config below.
config = dict(
    # Model / optimization / early stopping
    model="resnet50",
    batch_size=128,
    num_epochs=300,
    learning_rate=0.0001,
    optimizer="Adam",
    seed=2025,
    deterministic=False,
    patience=20,
    max_epochs_wait=float('inf'),
    num_domains=4,
    num_classes=65,

    # Loss weighting: evaluate() combines losses as
    # domain_weight * domain_loss + class_weight * class_loss.
    domain_weight=0.2,
    class_weight=0.8,

    # Epochs used to pretrain the teacher when no teacher checkpoint is found.
    teacher_epochs=20,

    # Data loading, checkpoint cadence (epochs), and LR scheduler settings.
    num_workers=20,
    pin_memory=True,
    save_every=5,
    scheduler="ReduceLROnPlateau",
    scheduler_mode="max",
    scheduler_factor=0.1,
    scheduler_patience=5,
    scheduler_verbose=True,
)

wandb.config.update(config)

class ImprovedMultiTaskModel(nn.Module):
    """Two-head (domain / class) classifier on a dilated ResNet-50 backbone.

    The domain head reads the backbone's layer3 feature map (1024 channels) and
    the class head reads the layer4 feature map (2048 channels). Each head is a
    1x1 conv + GroupNorm + ReLU branch, global average pooling, and a small MLP
    with dropout.

    Args:
        num_domains: Number of logits produced by the domain head.
        num_classes: Number of logits produced by the class head.
        freeze_backbone: If True (the original, default behaviour) all backbone
            parameters get ``requires_grad=False`` so only the branches and
            heads are trainable.
        freeze_bn_stats: Opt-in. If True and the backbone is frozen, the
            backbone is kept in eval mode even when ``train()`` is called.
            ``requires_grad=False`` alone does not stop BatchNorm layers from
            normalizing with batch statistics and updating their running
            mean/var in train mode.
    """

    def __init__(self, num_domains=4, num_classes=65, freeze_backbone=True, freeze_bn_stats=False):
        super(ImprovedMultiTaskModel, self).__init__()

        # Shared backbone: ImageNet-pretrained ResNet-50 wrapped by the project's
        # ResnetDilated (dilate_scale=8).
        pretrained_resnet = models.resnet50(pretrained=True)
        self.backbone = ResnetDilated(pretrained_resnet, dilate_scale=8)

        # Plain bool attributes: they do not appear in state_dict, so existing
        # checkpoints still load.
        self.freeze_backbone = freeze_backbone
        self.freeze_bn_stats = freeze_bn_stats
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Domain branch on the layer3 output (1024 -> 128 channels).
        self.domain_branch = nn.Sequential(
            nn.Conv2d(1024, 128, kernel_size=1),
            nn.GroupNorm(8, 128),
            nn.ReLU(inplace=True),
        )

        # Class branch on the layer4 output (2048 -> 256 channels).
        self.class_branch = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=1),
            nn.GroupNorm(8, 256),
            nn.ReLU(inplace=True)
        )

        # Global average pooling shared by both heads.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # MLP classification heads.
        self.domain_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(64, num_domains)
        )

        self.class_classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(128, num_classes)
        )

    def train(self, mode=True):
        """Set train/eval mode; keep a frozen backbone in eval mode if requested.

        Returns ``self`` (same contract as ``nn.Module.train``).
        """
        super().train(mode)
        if self.freeze_backbone and self.freeze_bn_stats:
            # Keep pretrained normalization statistics fixed while the heads train.
            self.backbone.eval()
        return self

    def _head(self, feature_map, branch, classifier):
        """Apply a 1x1-conv branch, global-average-pool, flatten, then classify."""
        pooled = self.avgpool(branch(feature_map))
        return classifier(torch.flatten(pooled, 1))

    def forward(self, x):
        """Return ``(domain_logits, class_logits)`` for an input image batch ``x``."""
        # Backbone stem and early stages.
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu1(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)

        # layer3 feeds the domain head; layer4 (computed from layer3) feeds the class head.
        layer3_output = self.backbone.layer3(x)
        layer4_output = self.backbone.layer4(layer3_output)

        domain_out = self._head(layer3_output, self.domain_branch, self.domain_classifier)
        class_out = self._head(layer4_output, self.class_branch, self.class_classifier)
        return domain_out, class_out

def save_kd_checkpoint(student, teacher, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping, filename="checkpoint_kd.pt"):
    """Write the full knowledge-distillation training state to ``filename``.

    The saved dict holds the student and teacher weights (DataParallel wrappers
    are unwrapped so the keys carry no ``module.`` prefix), optimizer and
    scheduler state, the best mAP values so far, the early-stopping tracker's
    fields, and the module-level ``config``. Saved with ``torch.save``.
    """
    def _bare(model):
        # Save the inner module so the weights load into an unwrapped model.
        return model.module if isinstance(model, nn.DataParallel) else model

    stopper = early_stopping
    state = {
        'epoch': epoch,
        'student_state_dict': _bare(student).state_dict(),
        'teacher_state_dict': _bare(teacher).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_class_map': best_class_map,
        'best_domain_map': best_domain_map,
        'early_stopping_counter': stopper.counter,
        'early_stopping_domain_best_score': stopper.domain_best_score,
        'early_stopping_class_best_score': stopper.class_best_score,
        'early_stopping_best_epoch': stopper.best_epoch,
        'early_stopping_domain_best_epoch': stopper.domain_best_epoch,
        'early_stopping_class_best_epoch': stopper.class_best_epoch,
        'early_stopping_early_stop': stopper.early_stop,
        'early_stopping_domain_map_max': stopper.domain_map_max,
        'early_stopping_class_map_max': stopper.class_map_max,
        'config': config,
    }
    torch.save(state, filename)
    print(f"KD 체크포인트가 {filename}에 저장되었습니다.")

def evaluate(model, dataloader, domain_criterion, class_criterion, device, epoch,
             domain_weight=None, class_weight=None, num_domains=None, num_classes=None):
    """Evaluate a two-head (domain, class) model and print a one-line summary.

    The model is put in eval mode and run under ``torch.no_grad()``. For every
    ``(inputs, domain_labels, class_labels)`` batch it computes both losses,
    top-1 accuracy for each head, and accumulates macro-averaged multiclass
    average precision (mAP) via torchmetrics.

    Args:
        model: Module whose forward returns ``(domain_logits, class_logits)``.
        dataloader: Iterable of ``(inputs, domain_labels, class_labels)`` batches.
        domain_criterion: Loss callable for the domain head.
        class_criterion: Loss callable for the class head. Both criteria are
            assumed to return a batch mean (reduction='mean', as with the
            ``nn.CrossEntropyLoss()`` the training loop passes); each batch
            loss is re-weighted by its size so reported averages are per-sample.
        device: Device that inputs and labels are moved to.
        epoch: Zero-based epoch index, used only in the printed summary.
        domain_weight: Weight of the domain loss in the combined loss;
            defaults to ``config["domain_weight"]``.
        class_weight: Weight of the class loss; defaults to ``config["class_weight"]``.
        num_domains: Classes for the domain mAP metric; defaults to ``config["num_domains"]``.
        num_classes: Classes for the class mAP metric; defaults to ``config["num_classes"]``.

    Returns:
        ``(loss, domain_loss, class_loss, domain_acc, class_acc, domain_map, class_map)``
        with accuracies in percent.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Fall back to the module-level config for anything not passed explicitly.
    if domain_weight is None:
        domain_weight = config["domain_weight"]
    if class_weight is None:
        class_weight = config["class_weight"]
    if num_domains is None:
        num_domains = config["num_domains"]
    if num_classes is None:
        num_classes = config["num_classes"]

    model.eval()
    start_time = time.time()

    # Loss sums are weighted by batch size so the final averages are per-sample
    # (dividing a sum of batch means by the batch count over-weights a short last batch).
    loss_sum = 0.0
    domain_loss_sum = 0.0
    class_loss_sum = 0.0
    domain_correct = 0
    class_correct = 0
    total = 0

    # Fresh metric objects per call so no state carries over between evaluations.
    domain_map = MulticlassAveragePrecision(num_classes=num_domains, average='macro').to(device)
    class_map = MulticlassAveragePrecision(num_classes=num_classes, average='macro').to(device)

    with torch.no_grad():
        for inputs, domain_labels, class_labels in dataloader:
            inputs = inputs.to(device)
            domain_labels = domain_labels.to(device)
            class_labels = class_labels.to(device)

            domain_outputs, class_outputs = model(inputs)

            domain_loss = domain_criterion(domain_outputs, domain_labels)
            class_loss = class_criterion(class_outputs, class_labels)
            loss = domain_weight * domain_loss + class_weight * class_loss

            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            domain_loss_sum += domain_loss.item() * batch_size
            class_loss_sum += class_loss.item() * batch_size

            # Top-1 accuracy per head.
            _, domain_preds = domain_outputs.max(1)
            domain_correct += domain_preds.eq(domain_labels).sum().item()
            _, class_preds = class_outputs.max(1)
            class_correct += class_preds.eq(class_labels).sum().item()

            total += batch_size

            domain_map.update(domain_outputs, domain_labels)
            class_map.update(class_outputs, class_labels)

    if total == 0:
        raise ValueError("evaluate() received a dataloader with no samples")

    eval_loss = loss_sum / total
    eval_domain_loss = domain_loss_sum / total
    eval_class_loss = class_loss_sum / total
    domain_accuracy = 100.0 * domain_correct / total
    class_accuracy = 100.0 * class_correct / total

    domain_map_value = domain_map.compute().item()
    class_map_value = class_map.compute().item()

    eval_time = time.time() - start_time

    print(f'Test set: Epoch: {epoch+1}, Avg loss: {eval_loss:.4f}, '
          f'Domain Loss: {eval_domain_loss:.4f}, Class Loss: {eval_class_loss:.4f}, '
          f'Domain Acc: {domain_accuracy:.2f}%, Class Acc: {class_accuracy:.2f}%, '
          f'Domain mAP: {domain_map_value:.4f}, Class mAP: {class_map_value:.4f}, '
          f'Time: {eval_time:.2f}s')
    print()

    return eval_loss, eval_domain_loss, eval_class_loss, domain_accuracy, class_accuracy, domain_map_value, class_map_value

# Image preprocessing.
# ImageNet channel statistics matching the pretrained ResNet-50 weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return fresh ToTensor + ImageNet Normalize steps (the shared tail of both pipelines)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]


# Training: resize to 255x255, random 224x224 crop, random horizontal flip.
transform_train = transforms.Compose(
    [
        transforms.Resize((255, 255)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
    ]
    + _tensor_and_normalize()
)

# Evaluation: resize the whole image straight to 224x224.
# NOTE(review): training crops 224 px out of a 255 px resize while evaluation
# resizes directly to 224 px, so objects appear ~14% larger in training than at
# test time. Resize((255, 255)) + CenterCrop(224) would match; left unchanged to
# keep results comparable with earlier runs.
transform_test = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize()
)

# Main entry point: seed, build data/teacher/student, run KD training, save and log results.
if __name__ == "__main__":
    import traceback  # only needed to report errors caught by the training loop

    # --- Reproducibility ----------------------------------------------------
    seed = config["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels trade speed for run-to-run reproducibility.
    if config["deterministic"]:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data ---------------------------------------------------------------
    trainset = OfficeHomeDataset(root_dir='./Dataset/train', transform=transform_train)
    testset = OfficeHomeDataset(root_dir='./Dataset/test', transform=transform_test)

    trainloader = DataLoader(
        trainset,
        batch_size=config["batch_size"],
        shuffle=True,
        pin_memory=config["pin_memory"],
        num_workers=config["num_workers"]
    )

    testloader = DataLoader(
        testset,
        batch_size=config["batch_size"],
        shuffle=False,
        pin_memory=config["pin_memory"],
        num_workers=config["num_workers"]
    )

    print(f"Train set size: {len(trainset)}")
    print(f"Test set size: {len(testset)}")

    # --- Teacher: load saved weights if present, otherwise pretrain it ---------
    teacher = EfficientNetTeacher(
        num_domains=config["num_domains"],
        num_classes=config["num_classes"]
    ).to(device)

    teacher_checkpoint_path = "efficientnet_teacher_best.pth"
    if os.path.exists(teacher_checkpoint_path):
        print(f" Teacher 모델 로드: {teacher_checkpoint_path}")
        teacher.load_state_dict(torch.load(teacher_checkpoint_path, map_location=device))
    else:
        print(" Teacher 모델이 없습니다. Teacher를 훈련합니다...")
        teacher = pretrain_teacher(teacher, trainloader, testloader, device, epochs=config["teacher_epochs"])

    # --- Student, KD loss, optimizer, scheduler --------------------------------
    student = ImprovedMultiTaskModel(
        num_domains=config["num_domains"],
        num_classes=config["num_classes"]
    ).to(device)

    kd_loss_fn = MultiTaskKDLoss(
        domain_alpha=0.8,
        class_alpha=0.6,
        domain_temp=5.0,
        class_temp=3.0
    )

    # Only the branches and heads are optimized; the backbone is frozen in the model.
    optimizer = optim.Adam([
        {'params': student.domain_branch.parameters(), 'lr': config["learning_rate"]},
        {'params': student.class_branch.parameters(), 'lr': config["learning_rate"]},
        {'params': student.domain_classifier.parameters(), 'lr': config["learning_rate"]},
        {'params': student.class_classifier.parameters(), 'lr': config["learning_rate"]}
    ])

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=config["scheduler_mode"],
        factor=config["scheduler_factor"],
        patience=config["scheduler_patience"],
        verbose=config["scheduler_verbose"]
    )

    # Multi-GPU: wrap after the optimizer is built (parameters are shared).
    if torch.cuda.device_count() > 1:
        print(f"{torch.cuda.device_count()}개의 GPU를 사용합니다.")
        student = nn.DataParallel(student)
        teacher = nn.DataParallel(teacher)

    # Early stopping on domain/class mAP. It is given path='checkpoint_kd.pt'
    # (presumably where it stores its best model — confirm in utility/utils.py),
    # so the periodic full checkpoints below go to the per-run CHECKPOINT_PATH
    # instead of save_kd_checkpoint's default file of the same name.
    early_stopping = DualMAPEarlyStopping(
        patience=config["patience"],
        verbose=True,
        path='checkpoint_kd.pt',
        max_epochs=config["max_epochs_wait"]
    )

    # Training always starts from scratch (no checkpoint resume).
    start_epoch = 0
    best_class_map = 0.0
    best_domain_map = 0.0
    # Bind `epoch` up front so the exception handlers and the final save never
    # hit a NameError (e.g. Ctrl-C before the first epoch, or num_epochs == 0).
    epoch = start_epoch

    def _unwrapped(model):
        """Return the inner module of a DataParallel wrapper, or the model itself."""
        return model.module if isinstance(model, nn.DataParallel) else model

    wandb.watch(student, log="all")

    print("Knowledge Distillation !")
    start_time = time.time()

    try:
        for epoch in range(start_epoch, config["num_epochs"]):
            print(f"\n 에폭 {epoch+1}/{config['num_epochs']} 시작")

            # One epoch of KD training.
            train_results = train_with_kd(
                student, teacher, trainloader, kd_loss_fn, optimizer, device, epoch, config
            )
            (train_loss, train_domain_loss, train_class_loss, train_domain_acc,
             train_class_acc, train_domain_map, train_class_map, detailed_losses) = train_results

            # Evaluation with plain cross-entropy on both heads.
            test_loss, test_domain_loss, test_class_loss, test_domain_acc, test_class_acc, test_domain_map, test_class_map = evaluate(
                student, testloader, nn.CrossEntropyLoss(), nn.CrossEntropyLoss(), device, epoch
            )

            # The plateau scheduler tracks the mean of the two test mAPs (mode="max").
            avg_map = (test_domain_map + test_class_map) / 2
            scheduler.step(avg_map)

            wandb.log({
                "epoch": epoch + 1,
                "learning_rate": optimizer.param_groups[0]['lr'],
                "train_loss": train_loss,
                "train_domain_accuracy": train_domain_acc,
                "train_class_accuracy": train_class_acc,
                "train_domain_map": train_domain_map,
                "train_class_map": train_class_map,
                "test_loss": test_loss,
                "test_domain_accuracy": test_domain_acc,
                "test_class_accuracy": test_class_acc,
                "test_domain_map": test_domain_map,
                "test_class_map": test_class_map,
                "kd_domain_kd_loss": detailed_losses['domain_kd'],
                "kd_domain_hard_loss": detailed_losses['domain_hard'],
                "kd_class_kd_loss": detailed_losses['class_kd'],
                "kd_class_hard_loss": detailed_losses['class_hard'],
            })

            # Keep separate best-so-far student weights for each task.
            if test_class_map > best_class_map:
                best_class_map = test_class_map
                torch.save(_unwrapped(student).state_dict(), f'best_student_class_kd_{run_name}.pth')
                print(f' 새로운 최고 Class mAP: {best_class_map:.4f}')

            if test_domain_map > best_domain_map:
                best_domain_map = test_domain_map
                torch.save(_unwrapped(student).state_dict(), f'best_student_domain_kd_{run_name}.pth')
                print(f' 새로운 최고 Domain mAP: {best_domain_map:.4f}')

            # Periodic full checkpoint.
            if (epoch + 1) % config["save_every"] == 0:
                save_kd_checkpoint(student, teacher, optimizer, scheduler, epoch, best_class_map,
                                   best_domain_map, early_stopping, filename=CHECKPOINT_PATH)

            early_stopping(test_domain_map, test_class_map, student, epoch)
            if early_stopping.early_stop:
                print(f" KD 훈련 조기 종료 - 에폭 {epoch+1}")
                break

        print(' KD 훈련 완료!')
        print(f'최고 Class mAP: {best_class_map:.4f}')
        print(f'최고 Domain mAP: {best_domain_map:.4f}')

        wandb.run.summary["best_class_map_kd"] = best_class_map
        wandb.run.summary["best_domain_map_kd"] = best_domain_map

    except KeyboardInterrupt:
        print("KD 훈련이 중단되었습니다. 체크포인트를 저장합니다.")
        save_kd_checkpoint(student, teacher, optimizer, scheduler, epoch, best_class_map,
                           best_domain_map, early_stopping, filename=CHECKPOINT_PATH)
    except Exception as e:
        # Best-effort: report the full traceback, save state, then continue to the
        # final save and wandb.finish() below.
        print(f"훈련 중 오류 발생: {e}")
        traceback.print_exc()
        print("체크포인트를 저장하고 종료합니다.")
        save_kd_checkpoint(student, teacher, optimizer, scheduler, epoch, best_class_map,
                           best_domain_map, early_stopping, filename=CHECKPOINT_PATH)

    # Final checkpoint, written however the loop ended.
    save_kd_checkpoint(student, teacher, optimizer, scheduler, epoch, best_class_map, best_domain_map, early_stopping,
                       filename="final_checkpoint_kd.pt")

    total_time = time.time() - start_time
    wandb.log({"total_training_time": total_time})
    print(f"전체 학습 시간: {total_time:.2f} 초")

    wandb.finish()



KeyboardInterrupt: 