In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms
from resnetRS import ResNetRS18  # 직접 작성한 ResNetRS 모델 불러오기
from early_stopping import EarlyStopping  # Early Stopping 클래스 임포트

ModuleNotFoundError: No module named 'timm'

In [None]:
# CIFAR-100 superclasses (coarse labels) and the five fine classes in each.
CIFAR100_SUPERCLASS_MAPPING = {
    'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
    'fish': ['aquarium fish', 'flatfish', 'ray', 'shark', 'trout'],
    'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
    'food containers': ['bottle', 'bowl', 'can', 'cup', 'plate'],
    'fruit and vegetables': ['apple', 'mushroom', 'orange', 'pear', 'sweet pepper'],
    'household electrical devices': ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
    'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
    'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
    'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
    'large man-made outdoor things': ['bridge', 'castle', 'house', 'road', 'skyscraper'],
    'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'],
    'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
    'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
    'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
    'people': ['baby', 'boy', 'girl', 'man', 'woman'],
    'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
    'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
    'trees': ['maple tree', 'oak tree', 'palm tree', 'pine tree', 'willow tree'],
    'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup truck', 'train'],
    'vehicles 2': ['lawn mower', 'rocket', 'streetcar', 'tank', 'tractor']
}

# Map each fine-class name to the integer label the dataset actually uses.
# CIFAR-100 numbers its 100 fine labels in alphabetical order of the label
# names, which spell multi-word classes with underscores ('apple' -> 0,
# 'aquarium_fish' -> 1, ..., 'worm' -> 99; compare torchvision's
# `trainset.classes`). Enumerating the superclass dict in insertion order would
# give indices that do not match the labels produced by the DataLoader.
_fine_class_names = sorted(
    (name for names in CIFAR100_SUPERCLASS_MAPPING.values() for name in names),
    key=lambda name: name.replace(' ', '_'),
)
class_to_idx = {cls: idx for idx, cls in enumerate(_fine_class_names)}
assert len(class_to_idx) == 100, "expected 100 distinct CIFAR-100 fine classes"

# Superclass name -> list of fine-class label indices belonging to it.
superclass_to_classes = {sc: [class_to_idx[c] for c in classes] for sc, classes in CIFAR100_SUPERCLASS_MAPPING.items()}

In [None]:
# Load CIFAR-100 (downloaded into ./data if it is not there yet).
# Images are converted to tensors, then each RGB channel is normalised with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Training batches are reshuffled each epoch; test batches keep dataset order.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

In [None]:
# Model initialisation and training setup.
# Seed torch's global RNG before building the model so weight initialisation
# (and the default-generator shuffle order used by the training DataLoader)
# repeat across kernel restarts. GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetRS18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# patience/path are forwarded to the project's EarlyStopping class
# (early_stopping.py, not shown here).
early_stopping = EarlyStopping(patience=5, path='best_model.pth')


In [None]:
# Model training function
def _model_device(model):
    """Return the device holding the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _evaluate_loss(model, loader, criterion, device):
    """Return the sample-weighted mean of ``criterion`` over ``loader`` in eval mode."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            total += criterion(model(inputs), labels).item() * inputs.size(0)
    # Guard against an empty dataset instead of dividing by zero.
    return total / max(len(loader.dataset), 1)


def train_model(model, trainloader, criterion, optimizer, num_epochs=25,
                valloader=None, stopper=None):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Each epoch makes one pass over ``trainloader`` and prints the sample-weighted
    mean training loss. The early-stopping object is then called once with the
    monitored loss, and training ends as soon as its ``early_stop`` flag is set.

    Args:
        model: network to train. Batches are moved to the device of its first
            parameter (CPU if it has none).
        trainloader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        valloader: optional DataLoader of held-out ``(inputs, labels)``. When
            given, the mean validation loss is printed and monitored instead of
            the training loss. Training loss generally keeps falling as the
            model fits the training set, so it is a weak stopping signal.
        stopper: callable invoked as ``stopper(loss, model)`` that exposes an
            ``early_stop`` attribute. Defaults to the notebook-level
            ``early_stopping`` object.
    """
    if stopper is None:
        stopper = early_stopping
    device = _model_device(model)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch is averaged correctly.
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / max(len(trainloader.dataset), 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

        monitored_loss = epoch_loss
        if valloader is not None:
            monitored_loss = _evaluate_loss(model, valloader, criterion, device)
            print(f'Epoch [{epoch+1}/{num_epochs}], Val Loss: {monitored_loss:.4f}')

        # Early-stopping check
        stopper(monitored_loss, model)
        if stopper.early_stop:
            print("Early stopping")
            break


In [None]:
# Model evaluation and superclass prediction function
def predict_superclass(model, input_data, mapping=None):
    """Predict a superclass, then the most likely fine class inside it, per sample.

    The model's logits are converted to class probabilities with a softmax. The
    probabilities of the fine classes belonging to each superclass are summed,
    and the superclass with the largest total is chosen. Within that superclass,
    the fine class with the highest probability is selected.

    Args:
        model: classifier whose output is indexed as (batch, num_classes) logits.
        input_data: batch of inputs, already on the model's device.
        mapping: dict of superclass name -> list of fine-class indices. Defaults
            to the notebook-level ``superclass_to_classes``.

    Returns:
        For a single-sample batch, a ``(superclass_name, class_index)`` tuple.
        For any other batch size, a tuple of two lists with one entry per sample.
    """
    if mapping is None:
        mapping = superclass_to_classes
    superclass_names = list(mapping)
    class_groups = [list(mapping[sc]) for sc in superclass_names]

    model.eval()
    with torch.no_grad():
        # The network is trained with CrossEntropyLoss, so it emits raw logits;
        # only probabilities can be meaningfully summed within a superclass.
        probs = torch.softmax(model(input_data), dim=1).cpu()

    # (batch, n_superclasses): probability mass assigned to each superclass.
    superclass_mass = torch.stack([probs[:, idx].sum(dim=1) for idx in class_groups], dim=1)
    best_groups = superclass_mass.argmax(dim=1).tolist()

    top_superclasses, top_classes = [], []
    for sample_probs, group in zip(probs, best_groups):
        members = class_groups[group]
        top_superclasses.append(superclass_names[group])
        top_classes.append(members[int(sample_probs[members].argmax())])

    if len(top_superclasses) == 1:
        return top_superclasses[0], top_classes[0]
    return top_superclasses, top_classes

In [None]:
# Train the model; the early-stopping check runs inside train_model each epoch.
train_model(
    model=model,
    trainloader=trainloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=25,
)

In [None]:
# Run predictions on the test data, one predict_superclass call per test batch,
# printing one line per batch.
# NOTE(review): the ground-truth `labels` are never compared with the
# predictions, so no accuracy is reported here.
for inputs, labels in testloader:
    inputs = inputs.to(device)
    top_superclass, top_class = predict_superclass(model, inputs)
    print(f"Predicted Superclass: {top_superclass}, Predicted Class: {top_class}")