In [9]:
import os
import time
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SmallDataset
from model import efficientnet_b0

data_dir = 'new_data/valid/herbs'
data_set = SmallDataset(data_dir, 'valid')
# Evaluation/benchmark pass only: a fixed (unshuffled) order keeps runs reproducible,
# so the per-sample timings of two runs can be compared directly.
data_loader = DataLoader(
    data_set,
    batch_size=1,  # one image per step; the per-sample inference/timing code assumes this
    shuffle=False,
    num_workers=4,
)

In [2]:
# Class names: every entry of `data_dir`, sorted alphabetically. A predicted index i
# is presumably looked up as lst[i], so the order must match the training order.
# NOTE(review): os.listdir also returns plain files (e.g. hidden OS metadata files);
# any such entry would shift the class indices — confirm the folder holds only class dirs.
lst = os.listdir(data_dir)
lst.sort()
len(lst)

36

In [12]:
def _timed_forward(model, img, run_device):
    """Copy ``img`` to ``run_device``, run ``model`` and return ``(argmax prediction, seconds)``.

    The timed region covers the host-to-device copy, the forward pass and the argmax.
    On a CUDA device it is measured with CUDA events, because kernels launch
    asynchronously and need an explicit synchronize before the time can be read.
    On any other device ``time.perf_counter`` is used instead, since
    ``torch.cuda.Event`` / ``torch.cuda.synchronize`` need a CUDA device.
    """
    if torch.device(str(run_device)).type == "cuda":
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        t_start.record()
        pred = torch.argmax(model(img.to(run_device)), dim=-1)
        t_end.record()
        torch.cuda.synchronize()  # make sure t_end has completed before reading it
        return pred, t_start.elapsed_time(t_end) / 1000  # elapsed_time() is in ms
    t_start = time.perf_counter()
    pred = torch.argmax(model(img.to(run_device)), dim=-1)
    return pred, time.perf_counter() - t_start


@torch.no_grad()
def inference(model, loader=None, class_names=None, run_device=None):
    """Predict a class for every sample of a loader and record per-sample latency.

    Args:
        model: callable mapping an image batch to class scores (classes on the last dim).
        loader: iterable of ``(img, label)`` pairs; defaults to the global ``data_loader``.
            Each batch must hold exactly one sample (batch_size=1), because the
            prediction is turned into a single index with ``.item()``.
        class_names: sequence mapping a predicted index to a class name;
            defaults to the global ``lst``.
        run_device: device the images are moved to; defaults to the global ``device``.

    Returns:
        A dict where
          - ``result["inference"]`` maps label -> predicted class name,
          - ``result["time"]["inference"]`` maps label -> seconds spent on that sample,
          - ``result["time"]["runtime"]["inference_only"]`` is the sum of those seconds.
        ``result["time"]["profile"]``, ``result["time"]["runtime"]["all"]`` and
        ``result["macs"]`` are placeholders that this function never fills in.

    Note:
        ``label`` is used as a dict key as-is, so it must be hashable, and samples that
        share a label overwrite each other's entries.
    """
    # Fall back to the notebook globals so the original one-argument call still works.
    loader = data_loader if loader is None else loader
    class_names = lst if class_names is None else class_names
    run_device = device if run_device is None else run_device

    result = {
        "inference": {},
        "time": {
            "profile": {"cuda": float("inf"), "cpu": float("inf")},
            "runtime": {"all": 0, "inference_only": 0},
            "inference": {},
        },
        "macs": float("inf"),
    }
    time_measure_inference = 0
    for img, label in tqdm(loader, "Running inference ..."):
        pred, t_inference = _timed_forward(model, img, run_device)
        time_measure_inference += t_inference

        result["inference"][label] = class_names[pred.item()]
        result["time"]["inference"][label] = t_inference

    result["time"]["runtime"]["inference_only"] = time_measure_inference

    return result

In [4]:
# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sigmoid activation kept as `m`; it is not referenced by the visible cells.
m = torch.nn.Sigmoid()

In [13]:
model = efficientnet_b0(num_classes=len(lst))

model_path = 'model2/herbs_e20_step/best.pt'
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model_checkpoint = torch.load(model_path, map_location=device)
# Load the model weights. strict=False tolerates key mismatches, so report them instead
# of silently benchmarking a model whose unmatched layers keep their initial weights.
load_report = model.load_state_dict(model_checkpoint, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    warnings.warn(
        f"Checkpoint {model_path!r} does not fully match the model: "
        f"{len(load_report.missing_keys)} missing keys, "
        f"{len(load_report.unexpected_keys)} unexpected keys"
    )

model = model.to(device)
model.eval()

result = inference(model)
result["time"]["runtime"]["inference_only"]

Running inference ...: 100%|██████████| 7181/7181 [01:40<00:00, 71.76it/s]


86.94799757957418

In [14]:

model_path = 'model2/herbs_e20_step/best.ts'
# map_location lets a TorchScript archive exported on a GPU be loaded on a CPU-only machine.
model = torch.jit.load(model_path, map_location=device)

model = model.to(device)
model.eval()

# Same `inference` benchmark, now on the TorchScript export.
# Note: this rebinds `model` and `result`, replacing the eager-mode ones.
result = inference(model)
result["time"]["runtime"]["inference_only"]

Running inference ...: 100%|██████████| 7181/7181 [01:06<00:00, 108.27it/s]


55.82749732923508