In [None]:
import os
import random
import torch
import skimage.io as io
import numpy as np
import matplotlib.pyplot as plt
from torchvision import ops
from pycocotools.coco import COCO

from config import Config
from model import MaskRCNN
from acne_data import AcneSegDataset, transforms, expand_mask, seg_to_mask
import visualize
import utils

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

MODEL_PATH = 'run/resnet101_all_epoch_160.pth'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The cells below pick random images with the `random` module. Leave SEED as None
# for a different sample on every run, or set an int to make the picks reproducible.
SEED = None
if SEED is not None:
    random.seed(SEED)

In [None]:
# Lesion class names; the list position is the class id (0 = background).
# Original labels: background, papule, nevus, nodule,
# open comedo, closed comedo,
# atrophic scar, hypertrophic scar,
# melasma, pustule, other
categories = [
    'BG',
    'papule', 'nevus', 'nodule',
    'open_comedo', 'closed_comedo',
    'atrophic_scar', 'hypertrophic_scar',
    'melasma', 'pustule',
    'other',
]
category_to_id = dict(zip(categories, range(len(categories))))

cfg = Config()
data_root = cfg.DATA_BASE_DIR
dataset = AcneSegDataset(
    os.path.join(data_root, 'test_patch'),
    os.path.join(data_root, 'annotations', 'acne_test.json'),
    'test',
    cfg,
    transforms(cfg.RGB_MEAN, cfg.RGB_STD, cfg.IMAGE_SHAPE[:2], 'test'),
)

# Build the detector on the selected device and load the trained checkpoint.
mrcnn = MaskRCNN(cfg).to(DEVICE)
utils.load_weights(mrcnn, MODEL_PATH)

In [None]:
@torch.no_grad()
def infer(model, images, config):
    """Run `model` on a batch of images and return per-image detections as numpy.

    Args:
        model: detector, called as ``model([images])`` and returning
            ``(detections, masks)``. Each detection row is read as
            ``(x1, y1, x2, y2, class_id, score)``; the box part is presumably
            normalised to [0, 1] (it is multiplied by the image size) — TODO confirm.
        images: float tensor of shape (B, C, H, W); a copy is moved to ``DEVICE``.
        config: provides ``IMAGE_SHAPE`` as (height, width, ...), used to scale
            the boxes back to pixels.

    Returns:
        ``(class_ids, scores, boxes, masks)``: four lists with one entry per
        image. Boxes are ``[x1, y1, x2, y2]`` in pixels, clipped and rounded.
        Masks have the instance axis last. Background detections (class id 0)
        and zero-area boxes are removed.
    """
    model.eval()
    images = images.to(DEVICE)

    # IMAGE_SHAPE is (rows, cols, ...), i.e. height first.
    img_h, img_w = config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1]

    detections, masks = model([images])
    bboxes = detections[:, :, :4]
    class_ids = detections[:, :, 4].long()
    scores = detections[:, :, 5]

    pred_class_ids, pred_scores, pred_bboxes, pred_masks = [], [], [], []
    for i in range(images.size(0)):
        # Drop background / padding rows (class id 0).
        idx = torch.nonzero(class_ids[i])[:, 0]
        b_class_ids = class_ids[i, idx]
        b_scores = scores[i, idx]
        b_bboxes = bboxes[i, idx]
        # Keep only the mask channel belonging to each detection's predicted class.
        b_masks = masks[i, idx, b_class_ids]

        # Boxes are [x1, y1, x2, y2]: x scales with the width, y with the height.
        # The clip window is given in the same coordinate order as the boxes.
        b_bboxes[:, [0, 2]] *= img_w
        b_bboxes[:, [1, 3]] *= img_h
        b_bboxes = utils.clip_boxes(b_bboxes, [0, 0, img_w, img_h])
        b_bboxes = b_bboxes.round()

        # Filter out detections with zero area. Often only happens in early
        # stages of training when the network weights are still a bit random.
        areas = (b_bboxes[:, 2] - b_bboxes[:, 0]) * (b_bboxes[:, 3] - b_bboxes[:, 1])
        keep = areas > 0
        b_class_ids = b_class_ids[keep]
        b_scores = b_scores[keep]
        b_bboxes = b_bboxes[keep]
        b_masks = b_masks[keep]

        pred_class_ids.append(b_class_ids.int().cpu().numpy())
        pred_scores.append(b_scores.cpu().numpy())
        pred_bboxes.append(b_bboxes.cpu().numpy())
        # (N, h, w) -> (h, w, N): instance axis last, as the display code expects.
        pred_masks.append(b_masks.cpu().numpy().transpose((1, 2, 0)))

    return pred_class_ids, pred_scores, pred_bboxes, pred_masks

In [None]:
# Pick a random test patch and run the detector on it as a batch of one.
idx = random.randint(0, len(dataset) - 1)
image, img_id = dataset[idx]
pred_class_ids, pred_scores, pred_bboxes, pred_masks = infer(mrcnn, image.unsqueeze(0), cfg)

# `infer` returns per-image lists. The display cell expects `pred_masks` to be a
# single (H, W, N) array for this one image in both branches.
if cfg.USE_MINI_MASK:
    # Presumably mini masks are box-local; expand_mask pastes each one back into a
    # full-size canvas using its box.
    refined_masks = np.zeros((cfg.IMAGE_SHAPE[0], cfg.IMAGE_SHAPE[1], len(pred_class_ids[0])))
    for i in range(len(pred_class_ids[0])):
        refined_masks[:, :, i] = expand_mask(pred_bboxes[0][i], pred_masks[0][:, :, i], cfg.IMAGE_SHAPE[:2])
    pred_masks = refined_masks
else:
    pred_masks = pred_masks[0]

In [None]:
# Ground truth for the sampled test patch.
# Undo the dataset normalisation so the patch can be displayed. The guard keeps the
# cell re-runnable: after the first run `image` is already a 0-255 numpy array.
if torch.is_tensor(image):
    image = image.numpy().transpose((1, 2, 0))
    image = np.clip(((image * cfg.RGB_STD) + cfg.RGB_MEAN) * 255, 0, 255).astype(np.int32)

anns = dataset.coco.imgToAnns[img_id]

# Rasterise every annotation into a patch-sized mask and convert the COCO
# [x, y, w, h] box into [x1, y1, x2, y2].
gt_class_ids = np.zeros((len(anns),), dtype=int)
gt_bboxes = np.zeros((len(anns), 4))
gt_masks = np.zeros((cfg.IMAGE_SHAPE[0], cfg.IMAGE_SHAPE[1], len(anns)))
for i, ann in enumerate(anns):
    gt_class_ids[i] = ann['category_id']
    x, y, w, h = ann['bbox']
    gt_bboxes[i, :] = np.round([x, y, x + w, y + h])
    gt_masks[:, :, i] = seg_to_mask(ann['segmentation'], cfg.IMAGE_SHAPE[0], cfg.IMAGE_SHAPE[1])

fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
visualize.display_instances(image, gt_bboxes, gt_masks, gt_class_ids, categories, ax=ax)

In [None]:
# Model predictions (with scores) for the same patch.
fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
visualize.display_instances(image, pred_bboxes[0], pred_masks, pred_class_ids[0], categories, pred_scores[0], ax=ax)

In [None]:
# Full-resolution image: tile it into overlapping windows for the model and show
# its ground truth. NOTE: this samples from the *training* annotations.
coco = COCO(os.path.join(cfg.DATA_BASE_DIR, 'annotations', 'acne_train.json'))
img_id = random.choice(coco.getImgIds())
img_obj = coco.imgs[img_id]
anns = coco.imgToAnns[img_id]

image = io.imread(os.path.join(cfg.DATA_BASE_DIR, 'images', img_obj['file_name']))
# RGB_MEAN / RGB_STD are per RGB channel (presumably 3 values), so force a
# 3-channel image. Otherwise grayscale files fail on the 3-axis slice below and
# RGBA files fail to broadcast.
if image.ndim == 2:
    image = np.stack([image] * 3, axis=-1)
image = image[:, :, :3]  # drop an alpha channel if present

win_gen = utils.WindowGenerator(img_obj['height'], img_obj['width'],
                                cfg.INFER_WINDOW_SIZE[0], cfg.INFER_WINDOW_SIZE[1],
                                cfg.INFER_WINDOW_STRIDES[0], cfg.INFER_WINDOW_STRIDES[1])
img_patches = []
windows = []  # [x1, y1, x2, y2] of each patch in full-image pixels
for slice_h, slice_w in win_gen:
    img_patch = image[slice_h, slice_w, :].astype(np.float32) / 255.0
    img_patch = (img_patch - cfg.RGB_MEAN) / cfg.RGB_STD
    # HWC -> CHW. The copy makes the transposed view contiguous before wrapping.
    img_patches.append(torch.tensor(img_patch.transpose((2, 0, 1)).copy()).float())
    windows.append([slice_w.start, slice_h.start, slice_w.stop, slice_h.stop])
img_patches = torch.stack(img_patches, dim=0)

# Rasterise every annotation into a full-image mask and convert the COCO
# [x, y, w, h] box into [x1, y1, x2, y2].
gt_class_ids = np.zeros((len(anns),), dtype=int)
gt_bboxes = np.zeros((len(anns), 4))
gt_masks = np.zeros((img_obj['height'], img_obj['width'], len(anns)))
for i, ann in enumerate(anns):
    gt_class_ids[i] = ann['category_id']
    x, y, w, h = ann['bbox']
    gt_bboxes[i, :] = np.round([x, y, x + w, y + h])
    gt_masks[:, :, i] = seg_to_mask(ann['segmentation'], img_obj['height'], img_obj['width'])

fig, ax = plt.subplots(figsize=(12, 16), dpi=150)
visualize.display_instances(image, gt_bboxes, gt_masks, gt_class_ids, categories, ax=ax)

In [None]:
# Run the detector on every window, then merge the per-window detections into
# full-image coordinates.
pred_class_ids, pred_scores, pred_bboxes, pred_masks = infer(mrcnn, img_patches, cfg)

# Boxes are [x1, y1, x2, y2] relative to their window. Shift each one by the
# window's top-left corner. `infer` returns fresh arrays, so shifting in place is safe.
for bboxes, (win_x1, win_y1, _, _) in zip(pred_bboxes, windows):
    bboxes[:, [0, 2]] += win_x1
    bboxes[:, [1, 3]] += win_y1

pred_class_ids = np.concatenate(pred_class_ids, axis=0)
pred_scores = np.concatenate(pred_scores, axis=0)
pred_bboxes = np.concatenate(pred_bboxes, axis=0)
pred_masks = np.concatenate(pred_masks, axis=2)  # (h, w, N): instances on the last axis

# Overlapping windows detect the same lesion more than once, so suppress the
# duplicates. NOTE: this NMS is class-agnostic (detections of all classes compete).
NMS_IOU_THRESHOLD = 0.3
keep = ops.nms(torch.tensor(pred_bboxes), torch.tensor(pred_scores), NMS_IOU_THRESHOLD).numpy()
pred_class_ids = pred_class_ids[keep]
pred_scores = pred_scores[keep]
pred_bboxes = pred_bboxes[keep]
pred_masks = pred_masks[:, :, keep]  # select along the instance axis, not image rows

# NOTE(review): in the non-mini-mask path, the per-window masks are not shifted into
# full-image coordinates. Confirm what `infer` returns in that mode.
if cfg.USE_MINI_MASK:
    refined_masks = np.zeros((img_obj['height'], img_obj['width'], len(pred_class_ids)))
    for i in range(len(pred_class_ids)):
        refined_masks[:, :, i] = expand_mask(pred_bboxes[i], pred_masks[:, :, i], [img_obj['height'], img_obj['width']])
    pred_masks = refined_masks

fig, ax = plt.subplots(figsize=(12, 16), dpi=150)
visualize.display_instances(image, pred_bboxes, pred_masks, pred_class_ids, categories, pred_scores, ax=ax)