In [1]:
from glob import glob
from tqdm import tqdm
from time import time
import argparse
import logging
import os
import cv2
import numpy as np
from PIL import Image
import base64
from matplotlib import pyplot as plt
import torch

from model import Unet
from dataset import load_image
from tensorrt_model import TrtModel

logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s')
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', '1'):
        return True
    elif v.lower() in ('false', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
        
def dice_loss(inputs, targets, smooth=1):
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (inputs * targets).sum()                            
    dice = (2.*intersection + smooth) / (inputs.sum() + targets.sum() + smooth)  

    return 1-dice 

def display_image(img, mask, local = False):
    img = img[0].cpu().detach().numpy() 
    mask = mask[0].cpu().detach().numpy() 
    
    img = np.transpose(img, (1,2,0))
    mask = np.transpose(mask, (1,2,0))
    
    img = img * 255
    img = np.minimum(np.maximum(img, 0), 255)
    mask[mask > 0.5] = 255
    mask[mask <= 0.5] = 0
    
    # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = img.astype(np.int16)
    green = np.zeros_like(mask)
    green[:,:,1] = mask[:,:,1]
    img[green >= 255] = img[green >= 255] * 3
    img[img >= 255] = 255

    other = np.zeros_like(mask)
    other[:,:,[0,2]] = mask[:,:,[0,2]] 
    img[other >= 255] = img[other >= 255] * 0.3
    
    plt.imshow(img)
    # cv2.imshow('img', img)
    # cv2.waitKey(1)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def inference(model_path, data_path, display = False):
    logger.info('model loading.. {}'.format(model_path))
    batch_size = 1
     # os.path.join("..","models","main.trt")
    model = TrtModel(model_path)
    shape = model.engine.get_binding_shape(0)
    
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = Unet().to(device)
    # model.load_state_dict(torch.load('/home/workspace/iot_ai_model/check_points/unet/model_state_dict_latest.pt'))
    # model = model.eval()
    
    # data_paths = glob(dataset_path)
    
    
    logger.info('dataset loading..')
   
    with open(data_path, 'r') as f:
        line = f.readlines()

    total = len(line)
    logger.info('number of test dataset : {}'.format(total))
    
    logger.info('start inferencing')
    preds = []
    targets = []
    cnt = 0
    
    base_dir = os.path.dirname(data_path)
    imgs = []
    masks = []
    filepaths = []
    
#     for row in tqdm(line):
#         img_path, mask_path = row.rstrip().split(',')
        
#         img = load_image(os.path.join(base_dir, img_path))
#         mask = load_image(os.path.join(base_dir, mask_path))
#         img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
#         mask = mask.reshape(1, mask.shape[0], mask.shape[1], mask.shape[2])
        
#         imgs.append(img)
#         masks.append(mask)
#         filepaths.append(os.path.basename(img_path))
    
    start_time = time()
    pre_elap = 0.0
    fps = 0.0
    cost = .0
    loss = .0
    for idx, row in enumerate(line):
        filename, _ = row.rstrip().split(',')
        filename = os.path.basename(filename)
        
        img = load_image(os.path.join(base_dir, 'images', filename))
        mask = load_image(os.path.join(base_dir, 'masks', filename))
        img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
        mask = mask.reshape(1, mask.shape[0], mask.shape[1], mask.shape[2])
        
        # img = load_image(os.path.join(base_dir,'images',filename))
        # img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
        
        # img = torch.tensor(img).to(device).type(torch.float32)
        # mask = torch.tensor(mask).to(device).type(torch.float32)
        
        # print(img)
        
        output = model(img)
        output = output[0].reshape(mask.shape)
        
        loss = dice_loss(output, mask)
        
        cost += loss
        
        print('{}/{} - {},  fps: {:.1f}, dice loss: {:.1f}'.format(idx+1, total, filename, fps, (loss)))

        # if(display):
        # display_image(img, output)
        
        elap = time() - start_time
        fps = max(0.0, 1.0 / (elap - pre_elap))
        pre_elap = elap
        
        
    if(display):
        cv2.destroyAllWindows()

    # preds = torch.tensor(preds)
    # targets = torch.tensor(targets)
    # # acc = (correct/len(dataset))
    # f1_score = f1(preds, targets) 
    
    elap = time() - start_time
    fps = total / elap
    logger.info('dice coefficient: {:.4f}, fps: {:.4f}'.format(cost/total, fps))

In [None]:
inference('/home/workspace/iot_ai_model/check_points/unet/model.engine', '/home/workspace/iot_ai_model/dataset/supervisely_person/test_data_list.txt')

2022-10-04 08:14:34,328 - model loading.. /home/workspace/iot_ai_model/check_points/unet/model.engine
2022-10-04 08:14:34,362 - dataset loading..
2022-10-04 08:14:34,364 - number of test dataset : 534
2022-10-04 08:14:34,365 - start inferencing


load /home/workspace/iot_ai_model/check_points/unet/model.engine
1/534 - ds5_pexels-photo-245241.png,  fps: 0.0, dice loss: 0.9
2/534 - ds8_pexels-photo-66152_GX5RTeXShS.png,  fps: 20.1, dice loss: 0.7
3/534 - ds6_light-red-white-home.png,  fps: 21.7, dice loss: 0.7
4/534 - ds6_pexels-photo-756439.png,  fps: 29.0, dice loss: 0.8
5/534 - ds6_pexels-photo-819398.png,  fps: 33.4, dice loss: 0.8
6/534 - ds8_pexels-photo-303473.png,  fps: 60.6, dice loss: 0.7
7/534 - ds10_thailand-costume-girl-woman-157857.png,  fps: 22.5, dice loss: 0.6
8/534 - ds8_pexels-photo-207772_aEgMrxjXyq.png,  fps: 20.9, dice loss: 0.8
9/534 - ds8_studio-portrait-blond-blondie-girl-47736.png,  fps: 21.1, dice loss: 0.8
10/534 - ds1_pexels-photo-373984.png,  fps: 41.4, dice loss: 0.7
11/534 - ds2_pexels-photo-194087.png,  fps: 22.2, dice loss: 0.7
12/534 - ds8_pexels-photo-590479.png,  fps: 23.4, dice loss: 0.6
13/534 - ds8_pexels-photo-413805.png,  fps: 35.4, dice loss: 0.7
14/534 - ds1_bow-tie-fashion-man-person.p