In [1]:
from dataset import ClsDataset
from torch.utils.data import DataLoader
import matplotlib as mpl
import matplotlib.pyplot as plt

from model import efficientnet_b0
import os
import torch

from tqdm import tqdm

data_dir = '../data/valid/10'
data_set = ClsDataset(data_dir)
data_loader = DataLoader(
    data_set,
    batch_size=1,
    shuffle=True,
    num_workers=8
)

In [2]:
lst = sorted(os.listdir(data_dir))
lst

['10011001',
 '10011002',
 '10011003',
 '10011004',
 '10011005',
 '10012001',
 '10012002',
 '10012003',
 '10012004',
 '10012005',
 '10012006',
 '10012007',
 '10012008',
 '10012009',
 '10014001',
 '10014002',
 '10014003',
 '10014004',
 '10014005',
 '10014006']

In [13]:
def inference(model):
    result = {
        "inference": {},
        "time": {
            "profile": {"cuda": float("inf"), "cpu": float("inf")},
            "runtime": {"all": 0, "inference_only": 0},
            "inference": {},
        },
        "macs": float("inf"),
    }
    time_measure_inference = 0
    for img, label in tqdm(data_loader, "Running inference ..."):
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)

        t_start.record()
        img = img.to(device)
        pred = model(img)
        pred = torch.argmax(pred)

        t_end.record()
        torch.cuda.synchronize()
        t_inference = t_start.elapsed_time(t_end) / 1000
        time_measure_inference += t_inference

        result["inference"][label] = lst[int(pred.detach())]
        result["time"]["inference"][label] = t_inference

    result["time"]["runtime"]["inference_only"] = time_measure_inference

    return result

In [4]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
m = torch.nn.Sigmoid()

In [7]:
model = efficientnet_b0(num_classes=len(lst))

model_path = '../models/classification_10_ts/best.pt'
model_checkpoint = torch.load(model_path)
model.load_state_dict(model_checkpoint, strict=False)  #모델 불러오기

model = model.to(device)
model.eval()

inference(model)

Running inference ...: 100%|██████████| 800/800 [00:12<00:00, 63.45it/s]


{'inference': {tensor([12]): '10012003',
  tensor([2]): '10012009',
  tensor([12]): '10012006',
  tensor([17]): '10012003',
  tensor([5]): '10011002',
  tensor([17]): '10012009',
  tensor([4]): '10012001',
  tensor([11]): '10014001',
  tensor([12]): '10012009',
  tensor([9]): '10012002',
  tensor([7]): '10011002',
  tensor([5]): '10011005',
  tensor([17]): '10012009',
  tensor([5]): '10012002',
  tensor([2]): '10012003',
  tensor([9]): '10012006',
  tensor([13]): '10012002',
  tensor([9]): '10014004',
  tensor([7]): '10012002',
  tensor([13]): '10012003',
  tensor([15]): '10012009',
  tensor([19]): '10012002',
  tensor([1]): '10012009',
  tensor([11]): '10011002',
  tensor([14]): '10014001',
  tensor([7]): '10012002',
  tensor([17]): '10012006',
  tensor([11]): '10012009',
  tensor([3]): '10012009',
  tensor([16]): '10011002',
  tensor([16]): '10012006',
  tensor([8]): '10011002',
  tensor([16]): '10012009',
  tensor([5]): '10011002',
  tensor([7]): '10011002',
  tensor([16]): '1001200

In [14]:
# model = efficientnet_b0(num_classes=len(lst))

model_path = '../models/classification_10_ts/best.ts'
model = torch.jit.load(model_path)

model = model.to(device)
model.eval()

inference(model)

Running inference ...: 100%|██████████| 800/800 [00:55<00:00, 14.45it/s]


{'inference': {tensor([17]): '10014004',
  tensor([0]): '10011001',
  tensor([15]): '10014002',
  tensor([10]): '10012006',
  tensor([0]): '10011001',
  tensor([8]): '10012004',
  tensor([19]): '10014006',
  tensor([7]): '10012003',
  tensor([2]): '10011003',
  tensor([6]): '10012002',
  tensor([11]): '10012007',
  tensor([19]): '10014006',
  tensor([15]): '10014002',
  tensor([11]): '10012007',
  tensor([12]): '10012008',
  tensor([6]): '10012002',
  tensor([16]): '10014003',
  tensor([17]): '10014004',
  tensor([5]): '10012001',
  tensor([10]): '10012006',
  tensor([2]): '10011003',
  tensor([0]): '10011001',
  tensor([14]): '10014001',
  tensor([18]): '10014005',
  tensor([6]): '10012002',
  tensor([14]): '10014001',
  tensor([3]): '10011004',
  tensor([18]): '10014005',
  tensor([18]): '10014005',
  tensor([0]): '10011001',
  tensor([9]): '10012005',
  tensor([13]): '10012009',
  tensor([11]): '10012007',
  tensor([17]): '10014004',
  tensor([11]): '10012007',
  tensor([9]): '10012