## 0. Library 불러오기 및 경로설정

In [53]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize

In [54]:
# Inference configuration: the timm backbone name, the evaluation dataset
# folder, and the checkpoint file to restore. Adjust these for your setup.
args = {
    'load_model': 'vit_base_patch32_224_clip_laion2b',
}
test_dir = '/opt/ml/input/data/eval'
# Adjacent string literals are concatenated into one path:
# <checkpoint root>/<run name>/<checkpoint file>.
check_point = (
    '/opt/ml/level1_imageclassification-cv-04/checkpoint/'
    'kfold4_0_ViTclip_multiclass59_bs64_ep100_adamw_lr1e-06_vit_base_patch32_224_clip_laion2b/'
    'epoch(59)_acc(0.910)_loss(0.276)_f1(0.878)_state_dict.pt'
)

## 1. Model 정의

In [55]:
import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary
from timm import create_model, list_models
from metric import pred_to_label
class MyModel(nn.Module):
    """Thin wrapper around a timm backbone used as the classifier.

    Args:
        args: Dict of settings.
            ``args['load_model']`` (required): timm model name handed to
            ``create_model``. A falsy value builds no backbone, in which
            case ``forward`` returns its input unchanged.
            ``args['num_classes']`` (optional, default 8): number of outputs
            requested from ``create_model``.
            ``args['pretrained']`` (optional, default True): forwarded to
            ``create_model``. Set it to False when a checkpoint is loaded
            right after construction, so pretrained weights are not fetched
            only to be overwritten.
    """

    def __init__(self, args):
        super().__init__()

        # Defaults reproduce the previously hard-coded values.
        self.num_classes = args.get('num_classes', 8)
        self.load_model = args['load_model']
        if self.load_model:
            self.backbone = create_model(
                self.load_model,
                pretrained=args.get('pretrained', True),
                num_classes=self.num_classes,
            )

    def forward(self, x):
        """Run ``x`` through the backbone, or pass it through if none was built."""
        if self.load_model:
            x = self.backbone(x)
        return x


## 2. Test Dataset 정의

In [56]:
class TestDataset(Dataset):
    """Dataset that yields preprocessed evaluation images (no labels).

    Args:
        img_paths: Sequence of image file paths, indexed by position.
        transform: Optional callable applied to each loaded PIL image. When
            omitted, the default pipeline is used: resize to 224x224, convert
            to a tensor, then normalize with mean (0.485, 0.456, 0.406) and
            std (0.229, 0.224, 0.225).
    """

    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        # Honor a caller-supplied transform; previously the argument was
        # accepted but silently ignored.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index`` and return it after the transform."""
        # convert('RGB') guarantees three channels for the 3-channel Normalize
        # (grayscale / RGBA files would otherwise fail); the context manager
        # closes the underlying file once the pixel data has been read.
        with Image.open(self.img_paths[index]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)


## 3. Inference

In [57]:
# Load the metadata and derive the image directory.
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))
image_dir = os.path.join(test_dir, 'images')

# Build the test dataset and its DataLoader.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]

dataset = TestDataset(image_paths)

# shuffle=False keeps predictions aligned with the rows of `submission`.
# NOTE(review): no batch_size is given, so DataLoader uses its default of 1;
# a larger batch would be faster if pred_to_label handles batched input —
# confirm against the metric module before changing.
loader = DataLoader(
    dataset,
    shuffle=False
)

# Build the model and restore the trained weights. Fall back to the CPU when
# CUDA is unavailable, and map the checkpoint tensors onto the chosen device
# so a checkpoint saved on GPU can still be loaded.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel(args).to(device)
state_dict = torch.load(check_point, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.eval()

# Predict over the test set and collect the results.
all_predictions = []
with torch.no_grad():  # entered once for the whole loop
    for images in tqdm(loader):
        images = images.to(device)
        pred = model(images)
        pred = pred_to_label(pred)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

# Save the submission file.
submission.to_csv(os.path.join(test_dir, 'submission.csv'), index=False)
print('test inference is done!')

 78%|███████▊  | 9889/12600 [02:49<00:48, 55.42it/s]