In [11]:
import os
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import ClsDataset
from model import efficientnet_b0

# Validation images for task "10"; the next cell treats the sub-directories of
# this folder as the class names.
data_dir = '../data/valid/10'
data_set = ClsDataset(data_dir)

# Inference visits every image exactly once, so shuffling buys nothing and only
# makes the per-image results and timings differ from run to run.
data_loader = DataLoader(
    data_set,
    batch_size=1,  # `inference` maps one prediction to one label per batch
    shuffle=False,
    num_workers=8,
)

In [12]:
# Class names in output-index order: the sorted sub-directory names of
# `data_dir`. Only directories are kept so a stray file in the folder cannot
# shift the index -> class mapping or change `num_classes=len(lst)`.
# NOTE(review): this order must match the one used at training time —
# confirm against ClsDataset / the training code.
lst = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())
lst

['10011001',
 '10011002',
 '10011003',
 '10011004',
 '10011005',
 '10012001',
 '10012002',
 '10012003',
 '10012004',
 '10012005',
 '10012006',
 '10012007',
 '10012008',
 '10012009',
 '10014001',
 '10014002',
 '10014003',
 '10014004',
 '10014005',
 '10014006']

In [31]:
@torch.no_grad()
def inference(model, loader=None, class_names=None, warmup=0):
    """Run ``model`` over every batch of a loader and time each forward pass.

    Each batch is moved to the notebook-level ``device``, passed through
    ``model``, and the arg-max output index is mapped to a class name.

    Args:
        model: Callable taking an image batch and returning class scores
            (an ``nn.Module`` or a TorchScript module in this notebook).
        loader: Iterable of ``(img, label)`` pairs. Defaults to the
            notebook's ``data_loader``.
        class_names: Sequence mapping an output index to a class name.
            Defaults to the notebook's ``lst``.
        warmup: Number of untimed forward passes on the first batch before
            measuring. ``0`` (the default) reproduces the original behaviour;
            a few passes keep one-time start-up costs out of the totals.

    Returns:
        dict with
          - ``"inference"``: ``label -> predicted class name``
          - ``"time"["inference"]``: ``label -> seconds`` for that batch
          - ``"time"["runtime"]["inference_only"]``: sum of the timed seconds
        ``"profile"``, ``"runtime"["all"]`` and ``"macs"`` keep their
        placeholder values.

    Raises:
        ValueError/RuntimeError from ``.item()`` if a batch holds more than
        one sample (one prediction per label is assumed).
    """
    import time

    loader = data_loader if loader is None else loader
    class_names = lst if class_names is None else class_names
    # CUDA events need a CUDA device; fall back to a host clock so the
    # function also works when `device` is the CPU.
    on_cuda = torch.device(device).type == "cuda"

    result = {
        "inference": {},
        "time": {
            "profile": {"cuda": float("inf"), "cpu": float("inf")},
            "runtime": {"all": 0, "inference_only": 0},
            "inference": {},
        },
        "macs": float("inf"),
    }

    if warmup > 0:
        first_img, _ = next(iter(loader))
        first_img = first_img.to(device)
        for _ in range(warmup):
            model(first_img)
        if on_cuda:
            torch.cuda.synchronize()

    time_measure_inference = 0
    for img, label in tqdm(loader, "Running inference ..."):
        if on_cuda:
            t_start = torch.cuda.Event(enable_timing=True)
            t_end = torch.cuda.Event(enable_timing=True)
            t_start.record()
        else:
            t_host_start = time.perf_counter()

        # The timed region includes the host-to-device copy of the batch.
        img = img.to(device)
        pred = model(img)
        # Arg-max over the class dimension (not the flattened batch).
        pred_idx = pred.argmax(dim=-1).item()

        if on_cuda:
            t_end.record()
            torch.cuda.synchronize()
            # elapsed_time() is in milliseconds.
            t_inference = t_start.elapsed_time(t_end) / 1000
        else:
            t_inference = time.perf_counter() - t_host_start
        time_measure_inference += t_inference

        # NOTE(review): `label` is used as-is as a dict key. If it is a tensor,
        # keys hash by object identity rather than value — confirm what
        # ClsDataset returns.
        result["inference"][label] = class_names[pred_idx]
        result["time"]["inference"][label] = t_inference

    result["time"]["runtime"]["inference_only"] = time_measure_inference

    return result

In [32]:
# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Stand-alone sigmoid module; none of the cells shown here call it.
m = torch.nn.Sigmoid()

In [33]:
model = efficientnet_b0(num_classes=len(lst))

model_path = '../models/classification_10_ts/best.pt'
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_checkpoint = torch.load(model_path, map_location=device)
# Load the model weights. strict=False tolerates key mismatches, so surface
# them instead of silently timing a model with partly untrained weights.
incompatible = model.load_state_dict(model_checkpoint, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        f"{model_path}: missing keys {incompatible.missing_keys}, "
        f"unexpected keys {incompatible.unexpected_keys}"
    )

model = model.to(device)
model.eval()

result = inference(model)
result["time"]["runtime"]["inference_only"]

Running inference ...: 100%|██████████| 1600/1600 [00:22<00:00, 71.14it/s]


19.768724025726303

In [34]:
model_path = '../models/classification_10_ts/best.ts'
# map_location remaps tensors that were saved on a GPU so the scripted module
# also loads on a CPU-only machine; .to(device) below places it explicitly.
model = torch.jit.load(model_path, map_location=device)

model = model.to(device)
model.eval()

# NOTE(review): this run and the eager run above both time from the very
# first batch, so each total includes any first-call overhead.
result = inference(model)
result["time"]["runtime"]["inference_only"]

Running inference ...: 100%|██████████| 1600/1600 [00:15<00:00, 100.16it/s]


13.273399139881146