# 고양이 그림 점수 생성을 위한 모델 학습

### 라이브러리 호출

In [None]:
import os
import glob
import matplotlib.pyplot as plt
from PIL import Image
import random
import numpy as np
from tqdm.notebook import tqdm
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import torchvision.models as models

import timm

In [None]:
# Fix random seeds so runs are reproducible.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU RNG and every device RNG
torch.cuda.manual_seed_all(SEED)
# cudnn autotuning (benchmark=True) selects kernels nondeterministically,
# which defeats the seeding above; request deterministic kernels instead.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook
# still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### 이미지 확인

In [None]:
# Preview up to the first 20 training images of the 'cat' class in a 4x5 grid.
CAT_TRAIN_DIR = '/Users/kimhongseok/squid_game_heaven/data/train/cat'

# Sort for a stable preview order (os.listdir order is arbitrary) and skip
# hidden files such as macOS's .DS_Store, which PIL cannot open.
train_img_list = sorted(
    name for name in os.listdir(CAT_TRAIN_DIR) if not name.startswith('.')
)

# Slicing avoids an IndexError when the folder holds fewer than 20 images.
for i, name in enumerate(train_img_list[:20]):
    plt.subplot(4, 5, i + 1)
    root = os.path.join(CAT_TRAIN_DIR, name)
    # The context manager releases the file handle; imshow converts the PIL
    # image to an array as soon as it is called.
    with Image.open(root) as img:
        plt.imshow(img)

# Custom Dataset Class

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset that labels each file by its class sub-directory.

    Expects the layout ``dir/<class_name>/<image files>``; every image found
    under ``classes[i]`` receives integer label ``i``.

    Args:
        dir: Root directory containing one sub-directory per class.
        classes: Ordered class sub-directory names; the position of a name
            defines its label.
        transform: Callable applied to each RGB PIL image in ``__getitem__``.
            When ``None``, the RGB PIL image is returned unchanged.
    """

    def __init__(self, dir, classes, transform=None):
        super().__init__()
        self.data = []  # list of (image_path, label) tuples
        self.transform = transform

        for label, class_name in enumerate(classes):
            root_dir = os.path.join(dir, class_name)
            # Sort so the sample order (and hence seeded shuffling) does not
            # depend on the filesystem's os.listdir ordering.
            for img in sorted(os.listdir(root_dir)):
                path = os.path.join(root_dir, img)
                # Skip hidden files (e.g. macOS .DS_Store) and sub-directories;
                # PIL cannot open either of them as an image.
                if img.startswith('.') or '.DS_Store' in img or not os.path.isfile(path):
                    continue
                self.data.append((path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` if one
        was given.
        """
        path, label = self.data[idx]
        # convert() returns a new in-memory image, so the source file can be
        # closed immediately instead of waiting for garbage collection.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [None]:
# Preprocessing shared by train and valid: fixed 480x480 resize, then tensor.
transforms = T.Compose([T.Resize((480, 480)), T.ToTensor()])

# Build the train/valid datasets (same class order, so labels agree) ...
train_dataset, valid_dataset = (
    CustomDataset(f'data/{split}', ['cat', 'non_cat'], transforms)
    for split in ('train', 'valid')
)

# ... and their loaders; only the training loader shuffles.
train_dataloader, valid_dataloader = (
    torch.utils.data.DataLoader(ds, batch_size=32, shuffle=is_train)
    for ds, is_train in ((train_dataset, True), (valid_dataset, False))
)

# Train, Validation code

In [None]:
def training(model, train_dataloader, train_dataset, criterion, optimizer, epoch, num_epochs):
    """Run one training epoch over ``train_dataloader``.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_dataloader: Iterable of ``(images, labels)`` batches.
        train_dataset: Dataset behind the loader; its length normalizes the
            epoch metrics.
        criterion: Loss function. Assumed to use ``reduction='mean'``, as
            ``nn.CrossEntropyLoss()`` does by default.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Zero-based index of the current epoch (progress bar only).
        num_epochs: Total number of epochs (progress bar only).

    Returns:
        ``(model, train_loss, train_accuracy)``, where ``train_loss`` is the
        mean per-sample loss and ``train_accuracy`` is the fraction of
        correctly classified samples.
    """
    model.train()
    train_loss = 0.0
    train_accuracy = 0

    tbar = tqdm(train_dataloader)
    for images, labels in tbar:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is a batch mean; weight it by the batch size so that
        # dividing by the dataset size below yields a true per-sample mean
        # (the last batch may be smaller than the rest).
        train_loss += loss.item() * images.size(0)
        # softmax is monotonic, so argmax over raw logits picks the same class.
        preds = outputs.argmax(dim=1)
        train_accuracy += (preds == labels).sum().item()

        tbar.set_description(f'Epoch/Epochs [{epoch+1}/{num_epochs}] Loss: {loss.item():.4f}')

    # Guard against ZeroDivisionError on an empty dataset.
    num_samples = max(len(train_dataset), 1)
    train_loss /= num_samples
    train_accuracy /= num_samples

    return model, train_loss, train_accuracy

def evaluation(model, valid_dataloader, valid_dataset, criterion, epoch, num_epochs):
    """Evaluate ``model`` on ``valid_dataloader`` without updating weights.

    Args:
        model: Network to evaluate; expected to already live on ``device``.
        valid_dataloader: Iterable of ``(images, labels)`` batches.
        valid_dataset: Dataset behind the loader; its length normalizes the
            metrics.
        criterion: Loss function. Assumed to use ``reduction='mean'``, as
            ``nn.CrossEntropyLoss()`` does by default.
        epoch: Zero-based index of the current epoch (progress bar only).
        num_epochs: Total number of epochs (progress bar only).

    Returns:
        ``(valid_loss, valid_accuracy)``, where ``valid_loss`` is the mean
        per-sample loss and ``valid_accuracy`` is the fraction of correctly
        classified samples.
    """
    model.eval()
    valid_loss = 0.0
    valid_accuracy = 0

    with torch.no_grad():
        tbar = tqdm(valid_dataloader)
        for images, labels in tbar:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size so the final division
            # by the dataset size gives a per-sample mean.
            valid_loss += loss.item() * images.size(0)
            # argmax over logits equals argmax over softmax probabilities.
            preds = outputs.argmax(dim=1)
            valid_accuracy += (preds == labels).sum().item()

            tbar.set_description(f'Epoch/Epochs [{epoch+1}/{num_epochs}] Loss: {loss.item():.4f}')

    # Guard against ZeroDivisionError on an empty dataset.
    num_samples = max(len(valid_dataset), 1)
    valid_loss /= num_samples
    valid_accuracy /= num_samples

    return valid_loss, valid_accuracy

def training_loop(model, train_dataloader, valid_dataloader, train_dataset, valid_dataset, criterion, optimizer, num_epochs):
    """Train for ``num_epochs`` epochs and checkpoint the model.

    After each epoch, train/valid metrics are printed. The model is saved to
    ``best_model.pth`` whenever the validation loss strictly improves. After
    the final epoch it is saved to ``last_model.pth``.

    Returns:
        The model after the last epoch.
    """
    model.to(device)
    lowest_valid_loss = float('inf')

    for epoch_idx in range(num_epochs):
        model, train_loss, train_accuracy = training(
            model, train_dataloader, train_dataset, criterion, optimizer, epoch_idx, num_epochs
        )
        valid_loss, valid_accuracy = evaluation(
            model, valid_dataloader, valid_dataset, criterion, epoch_idx, num_epochs
        )

        print(f'Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Valid Loss: {valid_loss}, Valid Accuracy: {valid_accuracy}')

        # Checkpoint only on strict improvement; written as `not (<)` so a NaN
        # loss is skipped rather than saved.
        if not valid_loss < lowest_valid_loss:
            continue
        lowest_valid_loss = valid_loss
        # NOTE(review): this pickles the whole module. Recent torch.load
        # defaults to weights_only=True, so reloading needs weights_only=False;
        # saving model.state_dict() is the more portable option.
        torch.save(model, 'best_model.pth')
        print(f'Best model updated at epoch {epoch_idx + 1} (Valid Loss: {valid_loss:.4f})')

    torch.save(model, 'last_model.pth')
    return model

# 모델 생성 및 학습

In [None]:
# timm ResNet-18 with pretrained weights and a 2-output classifier head
# (cat / non_cat).
model = timm.create_model('resnet18', num_classes=2, pretrained=True)
model  # bare expression so the notebook cell displays the architecture

In [None]:
# Choose which parameters are trained.
# Set FREEZE_BACKBONE = True to train only the final classifier layer
# (model.fc) and keep every other parameter fixed.
FREEZE_BACKBONE = False

if FREEZE_BACKBONE:
    for param in model.parameters():
        param.requires_grad = False

    for param in model.fc.parameters():
        param.requires_grad = True

In [None]:
# Cross-entropy over the class logits; Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

model = training_loop(
    model,
    train_dataloader,
    valid_dataloader,
    train_dataset,
    valid_dataset,
    criterion,
    optimizer,
    num_epochs=100,
)

# Test

In [None]:
# NOTE(review): this "test" set points at the valid split, which appears to be
# the same data training_loop used for best-model selection, so these scores
# are likely optimistic. Point it at a held-out split if one exists.
test_dataset = CustomDataset(
    dir='/Users/kimhongseok/squid_game_heaven/data/valid',
    classes=['cat', 'non_cat'],
    transform=transforms,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,  # keep dataset order so predictions line up with indices
)

In [None]:
# Collect softmax class probabilities for every test image. The test loader
# does not shuffle, so total_preds[i] corresponds to test_dataset[i].
model.eval()
total_preds = []

with torch.no_grad():
    for images, _ in tqdm(test_dataloader):
        outputs = model(images.to(device))
        probs = torch.nn.functional.softmax(outputs, dim=1)
        # Move to CPU: each stored row is a view of its batch tensor, so
        # device-resident rows would keep every batch alive in device memory.
        total_preds.extend(probs.cpu())

In [None]:
# Show up to 20 test images, each titled with its predicted probability (%)
# for class index 0 ('cat').
plt.figure(figsize=(10, 10))

# Bound the count so small test sets do not raise IndexError.
num_shown = min(20, len(test_dataset), len(total_preds))
for i in range(num_shown):
    plt.subplot(4, 5, i + 1)
    img_tensor, _ = test_dataset[i]
    # ToTensor yields channels-first; imshow expects channels-last.
    plt.imshow(img_tensor.permute(1, 2, 0))
    plt.title(f'Score: {total_preds[i][0].cpu().item()*100:.2f}')

In [None]:
import torch
from PIL import Image
import torchvision.transforms as T

# Classify a single image with the trained model.

# 1. Switch the model to evaluation mode
model.eval()

# 2. Load one image (context manager closes the file after conversion)
image_path = "/Users/kimhongseok/squid_game_heaven/data/valid/cat/25_17.jpg"  # image to classify
with Image.open(image_path) as src:
    image = src.convert("RGB")

# 3. Preprocess exactly as during training: 480x480 resize, then ToTensor with
#    no normalization. Keep this in sync with the training transform.
#    Use the `T` alias: `from torchvision import transforms` would overwrite
#    the global `transforms` pipeline that the dataset cells pass to
#    CustomDataset.
transform = T.Compose([
    T.Resize((480, 480)),
    T.ToTensor(),
])

image = transform(image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)
image = image.to(device)

# 4. Predict
with torch.no_grad():
    outputs = model(image)
    probs = torch.nn.functional.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probs, dim=1).item()

print("예측 클래스:", predicted_class)
# Print the full probability vector (index 0 = cat, 1 = non_cat), not only
# class 0, so the output matches its "per-class probabilities" label.
print("클래스별 확률:", probs.cpu().numpy()[0])