In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import easydict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torch
import torch.utils.data as data

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

import sklearn

from tqdm import notebook
import gc
import random

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Hash randomisation is fixed when the interpreter starts, so this value is
    # only picked up by Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one (silently ignored
    # when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms per run, which
    # defeats the determinism requested above; it must be disabled.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:
class TestDataset(Dataset):
    """Dataset that loads images from a list of file paths.

    Each item is only the (optionally transformed) image; nothing else is
    returned alongside it.

    Args:
        img_paths: Indexable collection of image file paths.
        transform: Callable applied to every loaded image, or ``None`` to
            return the RGB PIL image unchanged.
    """

    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the transform, if any."""
        # The context manager releases the file handle promptly; convert()
        # returns a fully loaded copy, so it remains valid after the file closes.
        # Forcing RGB keeps the channel count at three even for grayscale,
        # palette or RGBA files.
        with Image.open(self.img_paths[index]) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)

In [None]:
class swinBaseModel(nn.Module):
    """timm ``swin_base_patch4_window7_224`` backbone followed by a linear head.

    The backbone is created with ``pretrained=True`` and its output is passed
    through ``Linear(1000, class_n)`` to produce the final logits.

    NOTE(review): ``timm`` is not imported in the visible import cell; it must
    be in scope before this class is instantiated.

    Args:
        class_n: Number of output classes of the linear head.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone_name = 'swin_base_patch4_window7_224'
        self.model = timm.create_model(backbone_name, pretrained=True)
        # The head consumes 1000 features per sample (presumably the backbone's
        # ImageNet logits - confirm against the timm model config).
        self.classify = nn.Linear(1000, class_n)

    def forward(self, x):
        """Run the backbone, then the linear head, and return the logits."""
        return self.classify(self.model(x))
    
class swinTinyModel(nn.Module):
    """timm ``swin_tiny_patch4_window7_224`` backbone followed by a linear head.

    The backbone is created with ``pretrained=True`` and its output is passed
    through ``Linear(1000, class_n)`` to produce the final logits.

    NOTE(review): ``timm`` is not imported in the visible import cell; it must
    be in scope before this class is instantiated.

    Args:
        class_n: Number of output classes of the linear head.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone_name = 'swin_tiny_patch4_window7_224'
        self.model = timm.create_model(backbone_name, pretrained=True)
        # The head consumes 1000 features per sample (presumably the backbone's
        # ImageNet logits - confirm against the timm model config).
        self.classify = nn.Linear(1000, class_n)

    def forward(self, x):
        """Run the backbone, then the linear head, and return the logits."""
        return self.classify(self.model(x))

class swinLargeModel(nn.Module):
    """timm ``swin_large_patch4_window7_224`` backbone followed by a linear head.

    The backbone is created with ``pretrained=True`` and its output is passed
    through ``Linear(1000, class_n)`` to produce the final logits.

    NOTE(review): ``timm`` is not imported in the visible import cell; it must
    be in scope before this class is instantiated.

    Args:
        class_n: Number of output classes of the linear head.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone_name = 'swin_large_patch4_window7_224'
        self.model = timm.create_model(backbone_name, pretrained=True)
        # The head consumes 1000 features per sample (presumably the backbone's
        # ImageNet logits - confirm against the timm model config).
        self.classify = nn.Linear(1000, class_n)

    def forward(self, x):
        """Run the backbone, then the linear head, and return the logits."""
        return self.classify(self.model(x))

class EfficientNet7(nn.Module):
    """Wrapper around a pretrained ``efficientnet-b7`` with ``class_n`` outputs.

    NOTE(review): ``EfficientNet`` is not imported in the visible import cell;
    the ``from_pretrained('efficientnet-b7', ...)`` call looks like the
    ``efficientnet_pytorch`` package - confirm and import it before use.

    Args:
        class_n: Number of output classes requested from the network.
    """

    def __init__(self, class_n=18):
        super().__init__()
        # Pass the class count by keyword: the second positional parameter of
        # from_pretrained is not num_classes in every efficientnet_pytorch
        # release (it is weights_path / advprop in later ones), so a positional
        # value would be silently misinterpreted.
        self.model = EfficientNet.from_pretrained('efficientnet-b7', num_classes=class_n)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
    
class EfficientNet5(nn.Module):
    """Wrapper around a pretrained ``efficientnet-b5`` with ``class_n`` outputs.

    NOTE(review): ``EfficientNet`` is not imported in the visible import cell;
    the ``from_pretrained('efficientnet-b5', ...)`` call looks like the
    ``efficientnet_pytorch`` package - confirm and import it before use.

    Args:
        class_n: Number of output classes requested from the network.
    """

    def __init__(self, class_n=18):
        super().__init__()
        # Pass the class count by keyword: the second positional parameter of
        # from_pretrained is not num_classes in every efficientnet_pytorch
        # release (it is weights_path / advprop in later ones), so a positional
        # value would be silently misinterpreted.
        self.model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=class_n)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

class caitBaseModel(nn.Module):
    """timm ``cait_s24_224`` backbone followed by a linear head.

    The backbone is created with ``pretrained=True`` and its output is passed
    through ``Linear(1000, class_n)`` to produce the final logits.

    NOTE(review): ``timm`` is not imported in the visible import cell; it must
    be in scope before this class is instantiated.

    Args:
        class_n: Number of output classes of the linear head.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone_name = 'cait_s24_224'
        self.model = timm.create_model(backbone_name, pretrained=True)
        # The head consumes 1000 features per sample (presumably the backbone's
        # ImageNet logits - confirm against the timm model config).
        self.classify = nn.Linear(1000, class_n)

    def forward(self, x):
        """Run the backbone, then the linear head, and return the logits."""
        return self.classify(self.model(x))

In [None]:
# Load the metadata and build the image paths.
test_dir = './eval'
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))
image_dir = os.path.join(test_dir, 'images')

# Create the test Dataset object and wrap it in a DataLoader.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
transform = transforms.Compose([
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
dataset = TestDataset(image_paths, transform)

# NOTE(review): no batch_size is given, so the DataLoader default (one image
# per batch) applies; raising it would speed up inference if memory allows.
loader = DataLoader(
    dataset,
    shuffle=False
)

# Load the model.
# Fall back to the CPU so the cell still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: '{model_path}' is an unfilled placeholder - point it at the real checkpoint.
model_path = '{model_path}'

# The checkpoint is a fully pickled model (``.eval()`` is called directly on the
# loaded object), so no fresh model needs to be constructed first; the model
# class definitions above must be in scope for unpickling to succeed.
# SECURITY: torch.load unpickles arbitrary objects - only load trusted files.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled models - confirm against the installed version.
model = torch.load(model_path, map_location=device)
# Make sure the weights live on the same device the inputs are sent to.
model.to(device)
model.eval()

# Predict the test dataset and collect the results.
all_predictions = []
with torch.no_grad():
    for images in loader:
        logits = model(images.to(device))
        all_predictions.extend(logits.argmax(dim=-1).cpu().tolist())
submission['ans'] = all_predictions

# Save the submission file.
submission.to_csv(os.path.join(test_dir, 'submission.csv'), index=False)
print('test inference is done!')

test inference is done!
