In [None]:
import os
import random
import torch
import skimage
import skimage.io  # `import skimage` alone does not load the `io` submodule on older scikit-image releases
import numpy as np
import matplotlib.pyplot as plt
from torchvision import ops

from config import Config
from model import MaskRCNN
from acne_data import AcneSegDataset, transforms, parse_image_metas, expand_mask, seg_to_mask
import visualize
import utils

# Default to the first GPU, but respect a CUDA_VISIBLE_DEVICES the user already
# exported. The variable only takes effect if it is set before CUDA is
# initialised in this process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Trained Mask R-CNN checkpoint evaluated by this notebook.
MODEL_PATH = '../autodl-tmp/mask-rcnn/resnet101_all_epoch_160.pth'
# 'cuda:0' is the first *visible* GPU (after CUDA_VISIBLE_DEVICES filtering); fall back to CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Class names indexed by category id. Index 0 ('BG') is the background class;
# the rest are: papule, nevus, nodule, open comedo, closed comedo,
# atrophic scar, hypertrophic scar, melasma, pustule, other.
categories = [
    'BG',
    'papule', 'nevus', 'nodule',
    'open_comedo', 'closed_comedo',
    'atrophic_scar', 'hypertrophic_scar',
    'melasma', 'pustule', 'other',
]
# Reverse lookup: class name -> integer id (its position in `categories`).
category_to_id = dict(zip(categories, range(len(categories))))

cfg = Config()
# Test split, using the 'test'-mode preprocessing pipeline built from the configured RGB statistics.
test_transforms = transforms(cfg.RGB_MEAN, cfg.RGB_STD, 'test')
dataset = AcneSegDataset(cfg.DATA_BASE_DIR, 'test', cfg, test_transforms)

# Build the network, move it to the target device, then load the trained checkpoint into it.
mrcnn = MaskRCNN(cfg).to(DEVICE)
utils.load_weights(mrcnn, MODEL_PATH)

In [None]:
@torch.no_grad()
def infer(model, image_patches, image_metas, config, device=None, class_agnostic=True):
    """Run Mask R-CNN on a stack of image patches and merge their detections.

    Each patch's detections are mapped back into image coordinates using that
    patch's window from ``image_metas``. Zero-area boxes are dropped, and the
    detections of all patches are merged with NMS, so that an object seen by
    several overlapping patches is reported once.

    Args:
        model: Mask R-CNN model. It is put in eval mode and called as
            ``model([image_patches])``, and must return ``(detections, masks)``.
            ``detections[p, k]`` holds ``[box(4), class_id, score]``;
            ``masks[p, k, c]`` is the mask of detection ``k`` for class ``c``.
        image_patches: tensor whose first dimension indexes patches.
        image_metas: metadata understood by ``parse_image_metas``. Its third
            output supplies one window per patch. The window uses the same
            coordinate order as the boxes.
        config: only ``config.DETECTION_NMS_THRESHOLD`` is read.
        device: device to run on. Defaults to the module-level ``DEVICE``.
        class_agnostic: if True (the default, original behaviour), the
            cross-patch NMS suppresses overlapping boxes regardless of their
            class. If False, a box only suppresses boxes of the same class
            (``ops.batched_nms``).

    Returns:
        ``(class_ids, scores, bboxes, masks)`` as numpy arrays:

        - ``class_ids``: int32, shape (N,).
        - ``scores``: shape (N,).
        - ``bboxes``: shape (N, 4), in image pixels, rounded.
        - ``masks``: shape (h, w, N), stacked along the last axis.

        Entries are in the order returned by NMS (decreasing score). ``h, w``
        is the model's mask resolution, presumably the mini-mask size when
        ``config.USE_MINI_MASK`` is set. TODO confirm.
    """
    if device is None:
        device = DEVICE
    model.eval()
    image_patches = image_patches.to(device)

    num_patches = image_patches.size(0)
    detections, masks = model([image_patches])
    bboxes, class_ids, scores = detections[:, :, :4], detections[:, :, 4].long(), detections[:, :, 5]
    _, _, windows = parse_image_metas(image_metas)

    pred_class_ids = []
    pred_scores = []
    pred_bboxes = []
    pred_masks = []
    for i in range(num_patches):
        # Class id 0 marks background / zero padding: keep only real detections.
        idx = torch.nonzero(class_ids[i])[:, 0]
        p_class_ids = class_ids[i, idx]
        p_scores = scores[i, idx]
        p_bboxes = bboxes[i, idx]
        # For each detection, select the mask channel of its predicted class.
        p_masks = masks[i, idx, p_class_ids]
        window = windows[i]

        # Boxes are presumably normalised relative to the patch window: scale
        # them by the window extent, then shift by the window origin to get
        # image pixel coordinates.
        p_bboxes[:, [0, 2]] *= window[2] - window[0]
        p_bboxes[:, [1, 3]] *= window[3] - window[1]
        p_bboxes[:, [0, 2]] += window[0]
        p_bboxes[:, [1, 3]] += window[1]
        p_bboxes = p_bboxes.round()

        # Filter out detections with zero area. Often only happens in early
        # stages of training when the network weights are still a bit random.
        areas = (p_bboxes[:, 2] - p_bboxes[:, 0]) * (p_bboxes[:, 3] - p_bboxes[:, 1])
        nonempty = torch.nonzero(areas > 0)[:, 0]

        # Class ids stay tensors (on device) so that class-aware NMS can use them.
        pred_class_ids.append(p_class_ids[nonempty])
        pred_scores.append(p_scores[nonempty])
        pred_bboxes.append(p_bboxes[nonempty])
        pred_masks.append(p_masks[nonempty].cpu().numpy().transpose((1, 2, 0)))

    all_bboxes = torch.cat(pred_bboxes, dim=0)
    all_scores = torch.cat(pred_scores, dim=0)
    all_class_ids = torch.cat(pred_class_ids, dim=0)

    # Merge duplicates coming from overlapping patches.
    if class_agnostic:
        keep = ops.nms(all_bboxes, all_scores, config.DETECTION_NMS_THRESHOLD)
    else:
        keep = ops.batched_nms(all_bboxes, all_scores, all_class_ids,
                               config.DETECTION_NMS_THRESHOLD)
    keep = keep.cpu().numpy()

    out_bboxes = all_bboxes.cpu().numpy()[keep]
    out_scores = all_scores.cpu().numpy()[keep]
    out_class_ids = all_class_ids.int().cpu().numpy()[keep]
    out_masks = np.concatenate(pred_masks, axis=2)[:, :, keep]

    return out_class_ids, out_scores, out_bboxes, out_masks

In [None]:
# Pick a test image. Set SAMPLE_IDX to an int to reproduce a specific image;
# leave it as None to sample one at random. The chosen index is printed either way.
SAMPLE_IDX = None
idx = SAMPLE_IDX if SAMPLE_IDX is not None else random.randint(0, len(dataset) - 1)
print(f'dataset index: {idx}')
image_patches, image_metas = dataset[idx]

image_meta = image_metas[0].numpy()
# The meta vector may be stored as floats. Cast so that the id works as a COCO
# dict key, and the shape can be used as numpy array dimensions.
img_id = int(image_meta[0])
# Presumably (height, width, channels) of the original image. TODO confirm
# against the code that builds image_metas.
shape = image_meta[1:4].astype(int)
height, width = shape[0], shape[1]

img_obj = dataset.coco.imgs[img_id]
anns = dataset.coco.imgToAnns[img_id]
image = skimage.io.imread(os.path.join(cfg.DATA_BASE_DIR, 'images', img_obj['file_name']))

num_anns = len(anns)
gt_class_ids = np.zeros((num_anns,), dtype=int)
gt_bboxes = np.zeros((num_anns, 4))
gt_masks = np.zeros((height, width, num_anns))
for i, ann in enumerate(anns):
    gt_class_ids[i] = ann['category_id']
    # COCO boxes are [x, y, w, h]; convert them to corner form [x1, y1, x2, y2].
    # NOTE(review): confirm this is the box order visualize.display_instances expects.
    x, y, w, h = ann['bbox']
    gt_bboxes[i] = [x, y, x + w, y + h]
    gt_masks[:, :, i] = seg_to_mask(ann['segmentation'], height, width)

fig, ax = plt.subplots(figsize=(12, 16), dpi=150)
visualize.display_instances(image, gt_bboxes, gt_masks, gt_class_ids, categories, ax=ax)

In [None]:
pred_class_ids, pred_scores, pred_bboxes, pred_masks = infer(mrcnn, image_patches, image_metas, cfg)

# The meta-derived shape may be stored as floats; numpy array dimensions must be ints.
img_h, img_w = int(shape[0]), int(shape[1])
if cfg.USE_MINI_MASK:
    # Mini-masks are presumably per-box crops. Paste each one back into a
    # full-size canvas at its predicted box.
    refined_masks = np.zeros((img_h, img_w, len(pred_class_ids)))
    for i in range(len(pred_class_ids)):
        refined_masks[:, :, i] = expand_mask(pred_bboxes[i], pred_masks[:, :, i], (img_h, img_w))
    pred_masks = refined_masks

# infer() returns detections in decreasing-score order. Show only the top ones,
# as many as there are ground-truth annotations, so this figure is directly
# comparable with the ground-truth figure.
top_k = len(anns)
fig, ax = plt.subplots(figsize=(12, 16), dpi=150)
visualize.display_instances(image, pred_bboxes[:top_k], pred_masks[:, :, :top_k],
                            pred_class_ids[:top_k], categories, pred_scores[:top_k], ax=ax)