In [1]:
import multiprocessing
import os
from importlib import import_module

import pandas as pd
import torch
from torch.utils.data import DataLoader

from dataset import TestDataset, MaskBaseDataset
from collections import Counter

In [2]:
# Ensemble members, in voting order. Each entry is looked up as a class in the
# project's `model` module and loaded from ./ensemble/<name>.pth by load_model.
# NOTE(review): with an even number of members the hard vote in inference()
# can tie; Counter.most_common resolves ties toward the earlier entry here.
model_list = [
    'Swin_b_Deep',
    'Swin_T_Deep',
]

In [3]:
def load_model(model_name, num_classes, device, model_dir='./ensemble'):
    """Instantiate ``model.<model_name>`` and load its saved weights.

    The class is looked up by name in the project-local ``model`` module,
    constructed with ``num_classes``, and populated from the checkpoint at
    ``<model_dir>/<model_name>.pth``.

    Args:
        model_name: Name of a class defined in the ``model`` module; also the
            file stem of its checkpoint.
        num_classes: Forwarded to the model constructor as ``num_classes``.
        device: ``map_location`` used when deserializing the checkpoint.
        model_dir: Directory holding the ``.pth`` checkpoints. Defaults to
            ``'./ensemble'``, the location that used to be hard-coded.

    Returns:
        The model instance with the checkpoint's state dict loaded. This
        function does not call ``.to(device)``; callers do that themselves.

    Raises:
        AttributeError: if the ``model`` module has no attribute ``model_name``.
    """
    model_cls = getattr(import_module("model"), model_name)
    model = model_cls(num_classes=num_classes)

    # Checkpoint path: <model_dir>/<model_name>.pth
    model_path = os.path.join(model_dir, f'{model_name}.pth')
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    return model

In [6]:
@torch.no_grad()
def inference(data_dir, output_dir, resize, batch_size, models=None,
              output_name='output2.csv'):
    """Run each ensemble member over the eval set and majority-vote the labels.

    Reads ``<data_dir>/info.csv``, builds a ``TestDataset`` over
    ``<data_dir>/images/<ImageID>``, and for every model name stores that
    model's argmax predictions in a column named after the model. The ``ans``
    column is then set to the most common prediction across those columns.
    Rows keep ``info.csv`` order because the loader does not shuffle.

    Args:
        data_dir: Eval data root containing ``info.csv`` and ``images/``.
        output_dir: Directory for the result CSV; created if it does not exist.
        resize: Forwarded to ``TestDataset``.
        batch_size: DataLoader batch size.
        models: Model names to ensemble. Defaults to the module-level
            ``model_list``.
        output_name: File name of the CSV written inside ``output_dir``.

    Returns:
        None. Writes ``info`` (original columns, one column per model, and
        ``ans``) to ``<output_dir>/<output_name>``.

    Raises:
        ValueError: if no model names are given.
    """
    models = list(model_list if models is None else models)
    if not models:
        raise ValueError("inference() needs at least one model name")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    num_classes = MaskBaseDataset.num_classes  # 18
    img_root = os.path.join(data_dir, 'images')
    info = pd.read_csv(os.path.join(data_dir, 'info.csv'))
    img_paths = [os.path.join(img_root, img_id) for img_id in info.ImageID]
    dataset = TestDataset(img_paths, resize)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=multiprocessing.cpu_count() // 2,
        shuffle=False,  # predictions must line up with info.csv rows
        pin_memory=use_cuda,
        drop_last=False,
    )

    for model_name in models:
        model = load_model(model_name, num_classes, device).to(device)
        model.eval()
        print("Calculating inference results..")
        preds = []
        for images in loader:
            logits = model(images.to(device))
            preds.extend(logits.argmax(dim=-1).cpu().numpy())
        info[model_name] = preds

        # Release this member before the next one is loaded.
        del model
        if use_cuda:
            torch.cuda.empty_cache()

    # Hard vote over the model columns only. Selecting them by name (instead of
    # info.iloc[:, 2:]) does not depend on how many columns info.csv has.
    # Counter.most_common breaks ties by first occurrence, i.e. toward the
    # earlier name in `models`.
    # NOTE(review): with exactly two models this always yields the first model's
    # label when they disagree; averaging softmax outputs would use both.
    info['ans'] = [
        Counter(row).most_common(1)[0][0] for row in info[models].to_numpy()
    ]

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, output_name)
    info.to_csv(save_path, index=False)
    print(f"Inference Done! Inference result saved at {save_path}")

In [7]:
# Ensemble the eval split and write the voted CSV under ./output/ensemble.
inference(
    data_dir='/opt/ml/input/data/eval',
    output_dir='./output/ensemble',
    resize=(128, 96),
    batch_size=1000,
)

Calculating inference results..
Calculating inference results..
Inference Done! Inference result saved at ./output/ensemble/output2.csv
