In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
import torch 
import numpy as np
from torchvision import ops
from sklearn import metrics

from models import maskrcnn2d
from datasets import T4SegmentationDataset2DDepthAsClass
from predict_frame import simple_nms, spicy_nms, global_nms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def crowd_nms(image_tensor, masks, boxes, scores, depths, conf_threshold=0.6, iou_threshold=0.9):
    scores = scores.cpu().detach()
    boxes = boxes.cpu().detach()
    boxes_orig = torch.clone(boxes)
    boxes_orig = np.array(boxes_orig).tolist()

    if len(boxes) == 0:
        return []
    box_keep = []
    scores_keep = []
    keep = [0 for i in range(len(boxes))]

    while len(scores) > 0:
        filter1 = np.where(scores > 0.2)[0]
        boxes = boxes[filter1]
        scores = scores[filter1]
        if len(scores) < 1:
            break
        order = np.argsort(-scores)
        i = np.argwhere(order == 0)[0][0]
        a1 = boxes[i]
        a2 = scores[i]
        s = (a1[3] - a1[1]) * (a1[2] - a1[0])
        boxes = np.delete(boxes, i, axis=0)
        scores = np.delete(scores, i, axis=0)
        box_keep.append(a1)
        scores_keep.append(a2)

        if len(order) == 1:
            break
        for i in range(len(boxes)):
            xx1 = boxes[i][0]
            yy1 = boxes[i][1]
            xx2 = boxes[i][2]
            yy2 = boxes[i][3]
            x1 = max(a1[0], xx1)
            y1 = max(a1[1], yy1)
            x2 = min(a1[2], xx2)
            y2 = min(a1[3], yy2)
            inter = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
            area = (yy2 - yy1) * (xx2 - xx1)
            union = s + area - inter
            iou = inter / (union + 1e-6)
            ar = np.minimum(s, area) / np.maximum(s, area)
            if iou < iou_threshold:
                continue
            scores[i] *= 1 / np.pi * np.exp(-0.5 * (iou ** 2 + ar ** 2))
    if len(box_keep) > 1:
        box_keep = np.array(box_keep).tolist()
    box_keep = [box.numpy().tolist() for box in box_keep]
    for box in box_keep:
        keep[boxes_orig.index(box)] = 1
    return keep

In [4]:
def dice_score(pred, target):
    # Binarise vectors and put into contigious memory
    pred = ((pred > 0.5).float() * 1).squeeze().contiguous()
    target = ((target > 0.5).float() * 1).squeeze().contiguous()


    intersection = (pred * target).sum().sum()
    union = (pred + target).sum().sum()

    return ((2. * intersection) / union).mean()

In [5]:
def compute_confusion_and_dice(boxes, masks, depths, gt_boxes, gt_masks, gt_depths, iou_threshold=0.7):
    tp, fp, fn = 0, 0, 0
    matched_boxes = []
    dice_scores = []
    for gt_box, gt_mask, gt_depth in zip(gt_boxes, gt_masks, gt_depths):
        delta_depths = [abs(depth.item() - gt_depth.item()) for depth in depths]
        # Find all detections within 1 FP
        detections_at_similar_depth = [(box, mask) for box, mask, delta in zip(boxes, masks, delta_depths) if delta <= 1 and box.numpy().tolist() not in matched_boxes]
        if len(detections_at_similar_depth) == 0:
            # No detections on same plane
            fn += 1
        else:
            boxes_at_similar_depth = torch.cat([box.unsqueeze(0) for box, _ in detections_at_similar_depth])
            masks_at_similar_depth = torch.cat([mask.unsqueeze(0) for _, mask in detections_at_similar_depth])
            # Compute IOUs
            ious = ops.box_iou(boxes_at_similar_depth, gt_box.unsqueeze(0))
            if ious.squeeze().max() < iou_threshold:
                # No sufficiently overlapping detections
                fn += 1
            else:
                max_iou_idx = np.argmax(ious.squeeze())
                matched_boxes.append(boxes_at_similar_depth[max_iou_idx].numpy().tolist())
                tp += 1
                # Compute dice score
                dice = dice_score(gt_mask.cuda(), masks_at_similar_depth[max_iou_idx].cuda())
                dice_scores.append(dice.cpu().item())
    fp = len(boxes) - len(matched_boxes)
    return tp, fp, fn, dice_scores

In [6]:
def eval_seg(fold, iou_threshold=0.7, conf_threshold=0.7, nms=spicy_nms):
    # Setup
    dataset = T4SegmentationDataset2DDepthAsClass(
        data_dir='/datasets/test/stacks/t4', 
        label_dir='/datasets/test/seg/t4'
    )
    print(f'N = {len(dataset)}')
    model = maskrcnn2d(
        12).cuda() if torch.cuda.is_available() else maskrcnn2d(12)
    model.load_state_dict(torch.load(f'../fold_{fold}_model_2000_new_data.ckpt'))
    model.eval()

    tp = 0
    fp = 0
    fn = 0
    dice_scores = []

    # Forward through seg model
    for img_idx in range(len(dataset)):
        image, target = dataset[img_idx]
        image_tensor = torch.from_numpy(image).cuda().float().permute(2, 0, 1).unsqueeze(0)
        pred = model.forward(image_tensor)
        # Extract result
        pred = pred[0]
        masks = pred['masks']
        boxes = pred['boxes'].int()
        scores = pred['scores']
        depths = pred['labels']
        # Do NMS
        keep = np.array(nms(None, masks, boxes, scores, depths, conf_threshold=conf_threshold, iou_threshold=iou_threshold))
        if len([k for k in keep if k == 1]) == 0:
            fn += len(keep)
            continue
        final_masks = [masks[i] for i, k in enumerate(keep) if k == 1]
        final_boxes = torch.cat([boxes[i].unsqueeze(0) for i, k in enumerate(keep) if k == 1], axis=0).int().cpu()
        final_depths = [depths[i] for i, k in enumerate(keep) if k == 1]
        # Compute metrics
        sample_tp, sample_fp, sample_fn, sample_dice_scores = compute_confusion_and_dice(
            final_boxes, 
            final_masks, 
            final_depths, 
            target['boxes'].int(),
            target['masks'], 
            target['labels'], 
            iou_threshold=iou_threshold
        )
        tp += sample_tp
        fp += sample_fp
        fn += sample_fn
        dice_scores.extend(sample_dice_scores)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    mean_dice, std_dice = np.mean(dice_scores), np.std(dice_scores)
    N = tp + fn
    f1 = 2 * (precision * recall) / (precision + recall)
    return precision, recall, f1, mean_dice, std_dice, N

In [7]:
def eval_with_folds(conf_threshold, iou_threshold, nms):
    recalls = []
    precisions = []
    f1s = []
    dice_scores = []
    for fold in range(5):
        print(f'Evaluating fold {fold}')
        precision, recall, f1, mean_dice, std_dice, N = eval_seg(fold, conf_threshold=conf_threshold, iou_threshold=iou_threshold, nms=nms)
        recalls.append(recall)
        precisions.append(precision)
        f1s.append(f1)
        dice_scores.append(mean_dice)

    print(f'Precision {np.mean(precisions)} ({np.std(precisions)}); Recall {np.mean(recalls)} ({np.std(recalls)}); F1 {np.mean(f1s)} ({np.std(f1s)}); Mean Dice {np.mean(dice_scores)} ({np.std(dice_scores)})')

In [8]:
algorithms = [simple_nms, crowd_nms] # spicy_nms, global_nms
conf_thresholds = [0.6, 0.7, 0.8, 0.9]
iou_thresholds = [0.6, 0.7, 0.8, 0.9]

for algorithm in algorithms:
    best_params = None
    best_f1 = 0
    for c in conf_thresholds:
        for i in iou_thresholds:
            _, _, f1, _, _, _ = eval_seg(0, iou_threshold=i, conf_threshold=c, nms=algorithm)
            if f1 > best_f1:
                best_f1 = f1
                best_params = (c, i)
    print(f'Best params for {algorithm} are {best_params}')



N = 62


KeyboardInterrupt: 

In [9]:
eval_with_folds(
    conf_threshold=0.8,
    iou_threshold=0.6,
    nms=spicy_nms
)

eval_with_folds(
    conf_threshold=0.8,
    iou_threshold=0.6,
    nms=global_nms
)

eval_with_folds(
    conf_threshold=0.6,
    iou_threshold=0.6,
    nms=crowd_nms
)

Evaluating fold 0
N = 62
Evaluating fold 1
N = 62
Evaluating fold 2
N = 62
Evaluating fold 3
N = 62
Evaluating fold 4
N = 62
Precision 0.8816839923230733 (0.017096472518584607); Recall 0.8888274243567178 (0.011482971568964842); F1 0.8852077350234266 (0.013556726217415651); Mean Dice 0.9444936133946277 (0.0011286591456861924)
Evaluating fold 0
N = 62
Evaluating fold 1
N = 62
Evaluating fold 2
N = 62
Evaluating fold 3
N = 62
Evaluating fold 4
N = 62
Precision 0.925548247645515 (0.01026635117917693); Recall 0.812402265407043 (0.014121644760438588); F1 0.8652811412711866 (0.012404216408863735); Mean Dice 0.9437340234908944 (0.0010560684080046247)
Evaluating fold 0
N = 62


  box_keep = np.array(box_keep).tolist()
  box_keep = np.array(box_keep).tolist()


Evaluating fold 1
N = 62
Evaluating fold 2
N = 62
Evaluating fold 3
N = 62
Evaluating fold 4
N = 62
Precision 0.8670909937549409 (0.023328274347099227); Recall 0.8445993031358885 (0.00566135080462436); F1 0.8554628936846559 (0.00852486722857628); Mean Dice 0.9432191722881766 (0.0013864537431772781)
