In [1]:
#테스트 코드
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from torchvision import models

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [10]:
# 모델 정의 및 조정
model_m = models.resnet50(pretrained=False)
num_ftrs = model_m.fc.in_features
model_m.fc = torch.nn.Linear(num_ftrs, 2)  
model_m = model_m.to(device)

# 모델 로드
checkpoint_path = 'D:/minkwan/무신사 크롤링/coordikitty-ML-DL/중분류 모델링/model_resnet50_중분류.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model_m.load_state_dict(checkpoint['model_state_dict'])
model_m.eval()

# 데이터 전처리
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 라벨 인덱스를 medium_category로 매핑
dfm = pd.read_csv('D:/minkwan/무신사 크롤링/coordikitty-ML-DL/중분류 모델링/중분류(데님,스웨트).csv')
class_names_m = dfm['medium_category'].unique()
class_names_m = sorted(class_names_m, key=lambda x: list(dfm['medium_category']).index(x))
idx_to_class_m = {i: class_name for i, class_name in enumerate(class_names_m)}

def predict_image_category_m(image_path, model_m, transform, device, idx_to_class):
    """Predict the medium-category label of a single image file.

    Args:
        image_path: Path of the image file to classify.
        model_m: Classifier to run; expected to already be in eval mode.
        transform: Callable mapping a PIL image to a tensor; a leading batch
            dimension of size 1 is added here.
        device: Device the input tensor is moved to (should match the model's).
        idx_to_class: Mapping from predicted output index to label.

    Returns:
        ``idx_to_class[i]`` where ``i`` is the index of the largest model output.

    Raises:
        KeyError: If the predicted index has no entry in ``idx_to_class``.
    """
    # Use a context manager so the file handle is released immediately
    # instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        batch = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model_m(batch)
    predicted_idx = logits.argmax(dim=1).item()

    if predicted_idx not in idx_to_class:
        # e.g. the model has more output units than the label CSV has distinct labels
        raise KeyError(
            f'predicted index {predicted_idx} has no label '
            f'(known indices: {sorted(idx_to_class)})'
        )
    return idx_to_class[predicted_idx]






In [11]:
# 모델 정의 및 조정
model_s = models.resnet50(pretrained=False)
num_ftrs = model_s.fc.in_features
model_s.fc = torch.nn.Linear(num_ftrs, 3)  # 체크포인트와 일치하도록 클래스 수를 3으로 변경
model_s = model_s.to(device)

# 모델 로드
checkpoint_path = 'D:/minkwan/무신사 크롤링/coordikitty-ML-DL/소분류 모델링/model_resnet50_소분류.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model_s.load_state_dict(checkpoint['model_state_dict'])
model_s.eval()

# 데이터 전처리
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 라벨 인덱스를 small_category로 매핑
dfs = pd.read_csv('D:/minkwan/무신사 크롤링/coordikitty-ML-DL/소분류 모델링/소분류(데님,스웨트).csv')
class_names_s = dfs['small_category'].unique()
class_names_s = sorted(class_names_s, key=lambda x: list(dfs['small_category']).index(x))
idx_to_class_s = {i: class_name for i, class_name in enumerate(class_names_s)}

def predict_image_category_s(image_path, model_s, transform, device, idx_to_class):
    """Predict the small-category label of a single image file.

    Args:
        image_path: Path of the image file to classify.
        model_s: Classifier to run; expected to already be in eval mode.
        transform: Callable mapping a PIL image to a tensor; a leading batch
            dimension of size 1 is added here.
        device: Device the input tensor is moved to (should match the model's).
        idx_to_class: Mapping from predicted output index to label.

    Returns:
        ``idx_to_class[i]`` where ``i`` is the index of the largest model output.

    Raises:
        KeyError: If the predicted index has no entry in ``idx_to_class``.
    """
    # Use a context manager so the file handle is released immediately
    # instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        batch = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model_s(batch)
    predicted_idx = logits.argmax(dim=1).item()

    if predicted_idx not in idx_to_class:
        # e.g. the model has more output units than the label CSV has distinct labels
        raise KeyError(
            f'predicted index {predicted_idx} has no label '
            f'(known indices: {sorted(idx_to_class)})'
        )
    return idx_to_class[predicted_idx]




In [12]:
# Image to classify (taken from the long-pants / sweat test folder).
test_image_path = 'D:/minkwan/무신사 크롤링/coordikitty-ML-DL/압축/롱팬츠_스웨트/long_pants_sweat_test_data_90장/KakaoTalk_20240601_233416332_01.jpg'

# Display names for the label values stored in the two CSVs.
# NOTE(review): assumes the CSV labels are the integers 0/1, as the original comparisons implied.
# model_s has 3 outputs, but only labels 0 and 1 have a display name here.
medium_display_names = {0: '롱팬츠', 1: '숏팬츠'}
small_display_names = {0: '데님', 1: '스웨트'}

# Medium-category prediction
predicted_category_m = predict_image_category_m(test_image_path, model_m, transform, device, idx_to_class_m)
if predicted_category_m in medium_display_names:
    print(f'The predicted medium_category is: {medium_display_names[predicted_category_m]}')
else:
    # Previously an unmapped label printed nothing at all; report it instead.
    print(f'The predicted medium_category is: {predicted_category_m} (no display name)')

# Small-category prediction
predicted_category_s = predict_image_category_s(test_image_path, model_s, transform, device, idx_to_class_s)
if predicted_category_s in small_display_names:
    print(f'The predicted small_category is: {small_display_names[predicted_category_s]}')
else:
    # Previously an unmapped label printed nothing at all; report it instead.
    print(f'The predicted small_category is: {predicted_category_s} (no display name)')

The predicted medium_category is: 숏팬츠
The predicted small_category is: 스웨트
