In [3]:
import os
import time
import json
os.chdir('d:\\qk_maskrcnn_trs\\')

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision

import import_ipynb
from hy_tools.nets_option import create_model
from draw_box_utils import draw_objs

In [1]:
def time_synchronized():
    """Return the current wall-clock time in seconds.

    When CUDA is available, ``torch.cuda.synchronize()`` is called first, so
    the timestamp is taken after that call returns. Without CUDA this is just
    ``time.time()``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def qk_predict(model, img_path, weights_path):
    """Load a checkpoint into ``model`` and run it on a single image.

    The previous version also required and parsed a hard-coded label JSON and
    copied every prediction tensor to NumPy, but none of those results were
    used or returned; that dead work has been removed.

    Args:
        model: Network instance. It must accept the checkpoint through
            ``load_state_dict`` and, when called on a batch tensor, return a
            sequence whose first element is the prediction for the image.
        img_path: Path of the image to run inference on.
        weights_path: Path of the checkpoint. It may be the state dict itself
            or a dict that stores the state dict under the ``"model"`` key.

    Returns:
        ``model(batch)[0]`` for the one-image batch, left on the inference
        device. Other functions in this notebook read the keys ``"boxes"``,
        ``"labels"``, ``"scores"`` and ``"masks"`` from it.

    Raises:
        AssertionError: If ``weights_path`` or ``img_path`` does not exist.
    """
    # Pick the first GPU when CUDA is usable, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Load the trained weights on the CPU first, then move the model.
    assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here.
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict)
    model.to(device)

    # Load the image; the context manager closes the underlying file handle
    # once the RGB copy has been made.
    assert os.path.exists(img_path), f"{img_path} does not exits."
    with Image.open(img_path) as img_file:
        original_img = img_file.convert('RGB')

    # PIL image -> tensor (no normalisation), then add the batch dimension.
    img = transforms.ToTensor()(original_img)
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # Warm-up pass on an all-zero image of the same size so that the
        # timed pass below does not include one-off initialisation work.
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

    return predictions

In [4]:
def non_max_suppression(boxes, scores, threshold):
    """Apply torchvision's non-maximum suppression and return its result.

    Args:
        boxes: Forwarded unchanged as the boxes argument of
            ``torchvision.ops.nms``.
        scores: Forwarded unchanged as the scores argument.
        threshold: Forwarded as ``iou_threshold``.

    Returns:
        The value returned by ``torchvision.ops.nms``; callers in this
        notebook iterate over it and use each element as an index into the
        prediction tensors.
    """
    kept_indices = torchvision.ops.nms(boxes, scores, iou_threshold=threshold)
    return kept_indices

In [5]:
def filter_draw_bbox(model, img_path, weights_path, iou_threshold, score_threshold, is_filter_shard=True):
    """Run inference on one image and keep only the detections worth using.

    Detections are first reduced with ``non_max_suppression`` and then
    filtered by score. Optionally, boxes that come within 5 pixels of the
    mask's border are dropped as well.

    Args:
        model, img_path, weights_path: Forwarded to ``qk_predict``.
        iou_threshold: Threshold handed to ``non_max_suppression``.
        score_threshold: A detection is kept only if its score is
            ``>= score_threshold``.
        is_filter_shard: When equal to ``True``, discard boxes with
            ``x1 < 5``, ``y1 < 5``, ``x2 > mask.shape[2] - 5`` or
            ``y2 > mask.shape[1] - 5``.

    Returns:
        Four parallel lists ``(boxes, labels, scores, masks)`` holding the
        per-detection entries indexed out of the prediction tensors. All four
        are empty when nothing survives.
    """
    predictions = qk_predict(model, img_path, weights_path)
    all_boxes = predictions['boxes']
    all_labels = predictions['labels']
    all_scores = predictions['scores']
    all_masks = predictions['masks']

    kept_boxes, kept_labels, kept_scores, kept_masks = [], [], [], []

    # NMS first, then per-detection score / border checks.
    for idx in non_max_suppression(all_boxes, all_scores, iou_threshold):
        if not (all_scores[idx] >= score_threshold):
            continue

        # Equality with True (not plain truthiness) is what enables the
        # border filter.
        if is_filter_shard == True:
            box = all_boxes[idx]
            mask = all_masks[idx]
            touches_border = bool(
                box[0] < 5
                or box[1] < 5
                or box[2] > mask.shape[2] - 5
                or box[3] > mask.shape[1] - 5
            )
            if touches_border:
                continue

        kept_boxes.append(all_boxes[idx])
        kept_labels.append(all_labels[idx])
        kept_scores.append(all_scores[idx])
        kept_masks.append(all_masks[idx])

    return kept_boxes, kept_labels, kept_scores, kept_masks

In [6]:
import cv2
def draw_bbox_on_image(img_path, model, weights_path, iou_threshold, score_threshold, save_path=None):
    """Draw the filtered detections on the image, then save or display it.

    Detections come from ``filter_draw_bbox`` with the border filter turned
    off. Each box is drawn with a ``labels:..,scores..`` caption above it.

    Both paths are split on backslashes, so they must be Windows-style and
    contain at least one directory component:
      * the model name is the part after the last ``_`` in the checkpoint's
        parent directory name;
      * the output stem is ``<image parent dir>_<image name up to first dot>``.

    Args:
        img_path: Image to annotate.
        model, weights_path, iou_threshold, score_threshold: Forwarded to
            ``filter_draw_bbox``.
        save_path: Output directory. When given, the annotated image is
            written there as ``<stem>_<model name>.png`` (the directory is
            created if needed). When falsy, the image is shown with
            matplotlib instead.

    Raises:
        OSError: If OpenCV cannot read ``img_path``.
    """
    bbox, labels, scores, _ = filter_draw_bbox(
        model, img_path, weights_path, iou_threshold, score_threshold, is_filter_shard=False)

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of deep inside a draw call.
        raise OSError('cv2.imread could not read image: {0}'.format(img_path))

    model_name = weights_path.split('\\')[-2].split('_')[-1]
    split_name = img_path.split('\\')[-2] + '_' + img_path.split('\\')[-1].split('.')[0]

    # Colours are BGR tuples: bright green for the 'fpn' model, dark green
    # for any other model name.
    box_color = (0, 255, 0) if model_name == 'fpn' else (0, 100, 0)

    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(image, (x1, y1), (x2, y2), box_color, thickness=2)
        text = 'labels:{0},scores{1:.2f}'.format(labels[i], scores[i])
        cv2.putText(image, text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 1)

    if save_path:
        os.makedirs(save_path, exist_ok=True)
        image_name = os.path.join(save_path, split_name + '_' +  model_name + '.png')
        cv2.imwrite(image_name, image)
        print('image({0}) has been saved successed'.format(image_name))
    else:
        # OpenCV images are BGR while matplotlib expects RGB; convert so the
        # preview is not shown with red and blue swapped.
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

In [13]:
def save_pred(img_path, model, weights_path, iou_threshold,
              score_threshold,save_csv_path = None, save_mask_png_path = None, is_return=False):
    """Export the filtered detections of one image as a CSV and/or mask images.

    Detections come from ``filter_draw_bbox`` with its default border filter.
    Output files share the stem
    ``<image parent dir>_<image name up to first dot>_<model name>``, where
    the model name is the part after the last ``_`` in the checkpoint's
    parent directory name. Both paths are split on backslashes, so they must
    be Windows-style and contain at least one directory component.

    Args:
        img_path: Image to run inference on.
        model, weights_path, iou_threshold, score_threshold: Forwarded to
            ``filter_draw_bbox``.
        save_csv_path: Directory for ``<stem>.csv`` with the columns
            ``label, socre, x1, y1, x2, y2``. The file is only written when
            at least one detection survives.
        save_mask_png_path: Directory for one ``<stem>_<i>.jpg`` per mask.
        is_return: When truthy, return the results instead of ``None``.

    Returns:
        ``(df, masks)`` when ``is_return`` is truthy, where ``df`` is the
        detection table (empty, but with the same columns, when nothing
        survives) and ``masks`` is the list of masks as NumPy arrays.
        Otherwise ``None``.
    """
    i_name = img_path.split('\\')[-1].split('.')[0]
    dir_namm = img_path.split('\\')[-2]
    m_name = weights_path.split('\\')[-2].split('_')[-1]
    file = dir_namm + '_' + i_name + '_' + m_name

    bbox, labels, scores, masks = filter_draw_bbox(model, img_path, weights_path, iou_threshold, score_threshold)
    b = [box.to('cpu').numpy() for box in bbox]
    l = [label.to('cpu').numpy() for label in labels]
    s = [score.to('cpu').numpy() for score in scores]
    m = [mask.to('cpu').numpy() for mask in masks]

    if len(b) == 0:
        # Previously `df` was left undefined here, so `is_return=True` raised
        # UnboundLocalError for images without detections. Return an empty
        # table with the usual columns instead.
        df = pd.DataFrame(columns=['label', 'socre', 'x1', 'y1', 'x2', 'y2'])
    else:
        df_bbox = pd.DataFrame(b)
        df_bbox.columns = ['x1','y1','x2','y2']
        df_label = pd.DataFrame(l)
        df_label.columns = ['label']
        df_socre = pd.DataFrame(s)
        df_socre.columns = ['socre']
        df = pd.concat([df_label,df_socre,df_bbox], axis=1)
        if save_csv_path:
            os.makedirs(save_csv_path, exist_ok=True)
            file_name = file + '.csv'
            save_path_file = os.path.join(save_csv_path,file_name)
            df.to_csv(save_path_file, index=False)

    if save_mask_png_path:
        os.makedirs(save_mask_png_path, exist_ok=True)
        for i, mask in enumerate(m):
            # Move the leading axis to the end (3-axis mask -> channel-last)
            # and scale by 255 before writing.
            # NOTE(review): presumably mask values lie in [0, 1] - confirm.
            mask_T = np.transpose(mask,(1,2,0)) * 255
            png_name = file + '_' + str(i) + '.jpg'
            png_name_path = os.path.join(save_mask_png_path, png_name)
            cv2.imwrite(png_name_path,mask_T)

    if is_return:
        return df, m