# Pytorch starter - FasterRCNN Inference

- You can find the [train notebook here](https://www.kaggle.com/pestipeti/pytorch-starter-fasterrcnn-train)
- The weights are [available here](https://www.kaggle.com/dataset/7d5f1ed9454c848ecb909c109c6fa8e573ea4de299e249c79edc6f47660bf4c5)

In [1]:
import pandas as pd
import numpy as np
import cv2
import os
import re

from PIL import Image

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt
from tqdm import tqdm
from numba import jit
import numba

DIR_INPUT = '/home/hy/dataset/gwd'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'

DIR_WEIGHTS = '/home/hy/kaggle/gwd'

WEIGHTS_FILE = f'{DIR_WEIGHTS}/0513_fasterrcnn_resnet50_fpn.pth'

In [2]:
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
train_df.shape

(147793, 5)

In [3]:
train_df['x'] = -1
train_df['y'] = -1
train_df['w'] = -1
train_df['h'] = -1

def expand_bbox(x):
    r = np.array(re.findall("([0-9]+[.]?[0-9]*)", x))
    if len(r) == 0:
        r = [-1, -1, -1, -1]
    return r

train_df[['x', 'y', 'w', 'h']] = np.stack(train_df['bbox'].apply(lambda x: expand_bbox(x)))
train_df.drop(columns=['bbox'], inplace=True)
train_df['x'] = train_df['x'].astype(np.float)
train_df['y'] = train_df['y'].astype(np.float)
train_df['w'] = train_df['w'].astype(np.float)
train_df['h'] = train_df['h'].astype(np.float)

In [4]:
image_ids = train_df['image_id'].unique()
valid_ids = image_ids[-665:]
train_ids = image_ids[:-665]

In [5]:
valid_df = train_df[train_df['image_id'].isin(valid_ids)]
train_df = train_df[train_df['image_id'].isin(train_ids)]

In [6]:
valid_df.shape, train_df.shape

((25006, 8), (122787, 8))

In [7]:
class WheatDataset(Dataset):

    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()

        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, index: int):

        image_id = self.image_ids[index]
        records = self.df[self.df['image_id'] == image_id]

        image = cv2.imread(f'{self.image_dir}/{image_id}.jpg', cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0

        boxes = records[['x', 'y', 'w', 'h']].values
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        area = torch.as_tensor(area, dtype=torch.float32)

        # there is only one class
        labels = torch.ones((records.shape[0],), dtype=torch.int64)
        
        # suppose all instances are not crowd
        iscrowd = torch.zeros((records.shape[0],), dtype=torch.int64)
        
        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        # target['masks'] = None
        target['image_id'] = torch.tensor([index])
        target['area'] = area
        target['iscrowd'] = iscrowd

        if self.transforms:
            sample = {
                'image': image,
                'bboxes': target['boxes'],
                'labels': labels
            }
            sample = self.transforms(**sample)
            image = sample['image']
            
            target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)

        return image, target, image_id

    def __len__(self) -> int:
        return self.image_ids.shape[0]

In [8]:
# Albumentations
def get_train_transform():
    return A.Compose([
        A.Flip(0.5),
        ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

def get_valid_transform():
    return A.Compose([
        ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

In [9]:
# load a model; pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)

In [10]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_classes = 2  # 1 class (wheat) + background

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Load the trained weights
model.load_state_dict(torch.load(WEIGHTS_FILE))
model.to(device)
model.eval()


FasterRCNN(
  (transform): GeneralizedRCNNTransform()
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(original_name=FrozenBatchNorm2d)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(original_name=FrozenBatchNorm2d)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(original_name=FrozenBatchNorm2d)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(original_name=FrozenBatchNorm2d)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
       

In [11]:
def collate_fn(batch):
    return tuple(zip(*batch))

train_dataset = WheatDataset(train_df[:100], DIR_TRAIN, get_train_transform())
valid_dataset = WheatDataset(valid_df[:50], DIR_TRAIN, get_valid_transform())


# split the dataset in train and test set
indices = torch.randperm(len(train_dataset)).tolist()

train_data_loader = DataLoader(
    train_dataset,
    batch_size=28,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn
)

valid_data_loader = DataLoader(
    valid_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn
)

In [12]:
validation_image_precisions = []
iou_thresholds = [x for x in np.arange(0.5, 0.76, 0.05)]

### MAP calculation

In [13]:
@jit(nopython=True)
def calculate_iou(gt, pr, form='pascal_voc') -> float:
    """Calculates the Intersection over Union.

    Args:
        gt: (np.ndarray[Union[int, float]]) coordinates of the ground-truth box
        pr: (np.ndarray[Union[int, float]]) coordinates of the prdected box
        form: (str) gt/pred coordinates format
            - pascal_voc: [xmin, ymin, xmax, ymax]
            - coco: [xmin, ymin, w, h]
    Returns:
        (float) Intersection over union (0.0 <= iou <= 1.0)
    """
    if form == 'coco':
        gt = gt.copy()
        pr = pr.copy()

        gt[2] = gt[0] + gt[2]
        gt[3] = gt[1] + gt[3]
        pr[2] = pr[0] + pr[2]
        pr[3] = pr[1] + pr[3]

    # Calculate overlap area
    dx = min(gt[2], pr[2]) - max(gt[0], pr[0]) + 1
    
    if dx < 0:
        return 0.0
    
    dy = min(gt[3], pr[3]) - max(gt[1], pr[1]) + 1

    if dy < 0:
        return 0.0

    overlap_area = dx * dy

    # Calculate union area
    union_area = (
            (gt[2] - gt[0] + 1) * (gt[3] - gt[1] + 1) +
            (pr[2] - pr[0] + 1) * (pr[3] - pr[1] + 1) -
            overlap_area
    )

    return overlap_area / union_area

In [14]:
@jit(nopython=True)
def find_best_match(gts, pred, pred_idx, threshold = 0.5, form = 'pascal_voc', ious=None) -> int:
    """Returns the index of the 'best match' between the
    ground-truth boxes and the prediction. The 'best match'
    is the highest IoU. (0.0 IoUs are ignored).

    Args:
        gts: (List[List[Union[int, float]]]) Coordinates of the available ground-truth boxes
        pred: (List[Union[int, float]]) Coordinates of the predicted box
        pred_idx: (int) Index of the current predicted box
        threshold: (float) Threshold
        form: (str) Format of the coordinates
        ious: (np.ndarray) len(gts) x len(preds) matrix for storing calculated ious.

    Return:
        (int) Index of the best match GT box (-1 if no match above threshold)
    """
    best_match_iou = -np.inf
    best_match_idx = -1

    for gt_idx in range(len(gts)):
        
        if gts[gt_idx][0] < 0:
            # Already matched GT-box
            continue
        
        iou = -1 if ious is None else ious[gt_idx][pred_idx]

        if iou < 0:
            iou = calculate_iou(gts[gt_idx], pred, form=form)
            
            if ious is not None:
                ious[gt_idx][pred_idx] = iou

        if iou < threshold:
            continue

        if iou > best_match_iou:
            best_match_iou = iou
            best_match_idx = gt_idx

    return best_match_idx

@jit(nopython=True)
def calculate_precision(gts, preds, threshold = 0.5, form = 'coco', ious=None) -> float:
    """Calculates precision for GT - prediction pairs at one threshold.

    Args:
        gts: (List[List[Union[int, float]]]) Coordinates of the available ground-truth boxes
        preds: (List[List[Union[int, float]]]) Coordinates of the predicted boxes,
               sorted by confidence value (descending)
        threshold: (float) Threshold
        form: (str) Format of the coordinates
        ious: (np.ndarray) len(gts) x len(preds) matrix for storing calculated ious.

    Return:
        (float) Precision
    """
    n = len(preds)
    tp = 0
    fp = 0
    
    # for pred_idx, pred in enumerate(preds_sorted):
    for pred_idx in range(n):

        best_match_gt_idx = find_best_match(gts, preds[pred_idx], pred_idx,
                                            threshold=threshold, form=form, ious=ious)

        if best_match_gt_idx >= 0:
            # True positive: The predicted box matches a gt box with an IoU above the threshold.
            tp += 1
            # Remove the matched GT box
            gts[best_match_gt_idx] = -1

        else:
            # No match
            # False positive: indicates a predicted box had no associated gt box.
            fp += 1

    # False negative: indicates a gt box had no associated predicted box.
    fn = (gts.sum(axis=1) > 0).sum()

    return tp / (tp + fp + fn)


@jit(nopython=True)
def calculate_image_precision(gts, preds, thresholds = (0.5, ), form = 'coco') -> float:
    """Calculates image precision.

    Args:
        gts: (List[List[Union[int, float]]]) Coordinates of the available ground-truth boxes
        preds: (List[List[Union[int, float]]]) Coordinates of the predicted boxes,
               sorted by confidence value (descending)
        thresholds: (float) Different thresholds
        form: (str) Format of the coordinates

    Return:
        (float) Precision
    """
    n_threshold = len(thresholds)
    image_precision = 0.0
    
    ious = np.ones((len(gts), len(preds))) * -1
    # ious = None

    for threshold in thresholds:
        precision_at_threshold = calculate_precision(gts.copy(), preds, threshold=threshold,
                                                     form=form, ious=ious)
        image_precision += precision_at_threshold / n_threshold

    return image_precision

In [19]:
for images, targets, image_ids in tqdm(valid_data_loader):
    
        images = list(image.to(device) for image in images)
        outputs = model(images)
        for i, image in enumerate(images):
    
            preds = outputs[i]['boxes'].data.cpu().numpy()
            scores = outputs[i]['scores'].data.cpu().numpy()
            image_id = image_ids[i]
            print('image_id:',image_id)
            print('scores:',scores)
            preds[:, 2] = preds[:, 2] - preds[:, 0]
            preds[:, 3] = preds[:, 3] - preds[:, 1]
            
            gt_boxes = valid_df[valid_df['image_id'] == image_id][['x', 'y', 'w', 'h']].values
            gt_boxes = gt_boxes.astype(np.int)
            #print('preds:',preds.astype(np.int))
            preds_sorted_idx = np.argsort(scores)[::-1]
            #print('preds_sorted_idx:',preds_sorted_idx)
            preds_sorted = preds[preds_sorted_idx].astype(np.int)
            #print('preds_sorted:',preds_sorted)
            image_precision = calculate_image_precision(preds_sorted,
                                                        gt_boxes,
                                                        thresholds=iou_thresholds,
                                                        form='coco')
            validation_image_precisions.append(image_precision)
            print('validation_image_precisions:',validation_image_precisions)
print("Validation IOU: {0:.4f}".format(np.mean(validation_image_precisions)))

100%|██████████| 1/1 [00:00<00:00,  1.87it/s]

image_id: bbce58f71
scores: [0.99686426 0.9966024  0.9957345  0.9953365  0.9952716  0.99508375
 0.9947802  0.99468046 0.9933749  0.992993   0.9929323  0.9924238
 0.9923258  0.99190485 0.99023527 0.99004984 0.9899067  0.98919463
 0.9884278  0.9867121  0.98631895 0.984937   0.9844738  0.9841298
 0.98174804 0.9814953  0.97958666 0.97730404 0.9770215  0.97695816
 0.9679233  0.9672326  0.9626706  0.96222174 0.9561973  0.95138264
 0.9457221  0.9418765  0.9223838  0.8865785  0.8547953  0.7260833
 0.54316944 0.51430106 0.29412213 0.28823933 0.15232734 0.13994728
 0.09286306 0.09246738 0.06437422]
preds: [[672  86 147 100]
 [931 910  92 109]
 [312 584 130  99]
 [789 646 148  72]
 [581  54 118  76]
 [ 60   3  88  68]
 [343 831 103 115]
 [320 921  88 101]
 [522 458 140  77]
 [479 314 144  73]
 [891 173 105  65]
 [944 802  79 103]
 [158 302  94  97]
 [212 699  96  76]
 [718 415 130  96]
 [257 379  94  68]
 [770 319 151  70]
 [427   7 106  72]
 [192 473  88  72]
 [555 721 151  77]
 [948 703  75  75




In [16]:
#sample_id = '1ef16dab1'
sample_id = 'bbce58f71'
gt_boxes = valid_df[valid_df['image_id'] == sample_id][['x', 'y', 'w', 'h']].values
gt_boxes = gt_boxes.astype(np.int)

# Ground-truth boxes of our sample
gt_boxes

array([[430,   0, 142,  75],
       [255, 377, 100,  69],
       [788, 634, 146,  77],
       [344, 836, 102, 110],
       [217, 694, 107,  85],
       [885, 174, 115,  64],
       [752, 263, 160,  64],
       [  2, 685, 122, 100],
       [709, 901, 117,  96],
       [587,  49, 120,  89],
       [519, 455, 141,  80],
       [547, 716, 162,  80],
       [195, 482,  87,  66],
       [932, 473,  92,  77],
       [480, 911,  95, 105],
       [  0, 801,   5,  79],
       [309, 573, 128, 107],
       [671,  83, 161, 110],
       [ 57,   0, 100,  66],
       [314, 127, 144,  77],
       [157, 312, 103,  87],
       [217, 539,  94,  64],
       [910, 389, 113,  73],
       [465, 314, 163,  63],
       [808, 408, 128,  75],
       [626, 493, 174, 121],
       [394, 558,  84,  71],
       [  0, 326,  94,  81],
       [721, 412, 130, 104],
       [583, 555, 154,  67],
       [  0, 742, 102,  80],
       [310, 922, 100, 102],
       [795, 880, 114, 142],
       [770, 325, 151,  61],
       [376, 1

In [22]:
preds:[[672,  86, 147, 100],
 [931, 910,  92, 109],
 [312, 584, 130,  99],
 [789, 646, 148,  72],
 [581,  54, 118,  76],
 [ 60,   3,  88,  68],
 [343, 831, 103, 115],
 [320, 921,  88, 101],
 [522, 458, 140,  77],
 [479, 314, 144,  73],
 [891, 173, 105,  65],
 [944, 802,  79, 103],
 [158, 302,  94,  97],
 [212, 699,  96,  76],
 [718, 415, 130,  96],
 [257, 379,  94,  68],
 [770, 319, 151,  70],
 [427,   7, 106,  72],
 [192, 473,  88,  72],
 [555, 721, 151,  77],
 [948, 703,  75,  75],
 [698, 889, 129, 104],
 [919, 478, 104,  71],
 [471, 901,  97, 106],
 [788, 403, 133,  79],
 [399, 561,  82,  65],
 [218, 543,  98,  64],
 [587, 558, 134,  70],
 [  0, 683, 133,  93],
 [  0, 315,  89,  78],
 [765, 267, 110,  61],
 [535, 262, 178,  72],
 [320, 131, 114,  68],
 [910, 388, 112,  66],
 [  0, 259,  79,  71],
 [368, 163, 100,  65],
 [  3, 747,  85,  77],
 [146, 925,  80,  72],
 [791, 878, 102, 123],
 [657, 533, 136,  77],
 [338, 141, 121,  80],
 [979, 223,  44,  54],
 [812, 939,  92,  82],
 [633, 504, 149,  85],
 [762, 279, 146,  97],
 [460, 993, 101,  30],
 [710, 898, 180, 116],
 [494, 585,  54,  56],
 [ 62, 665,  64,  47],
 [ 71, 634,  71,  38],
 [808, 519,  33,  85]]

In [23]:
len(preds)

62

In [24]:
scores = np.array([0.99686426, 0.9966024, 0.9957345, 0.9953365, 0.9952716, 0.99508375,
 0.9947802,  0.99468046, 0.9933749,  0.992993,   0.9929323,  0.9924238,
 0.9923258,  0.99190485, 0.99023527, 0.99004984, 0.9899067,  0.98919463,
 0.9884278,  0.9867121,  0.98631895, 0.984937,   0.9844738,  0.9841298,
 0.98174804, 0.9814953,  0.97958666, 0.97730404, 0.9770215,  0.97695816,
 0.9679233,  0.9672326,  0.9626706,  0.96222174, 0.9561973,  0.95138264,
 0.9457221,  0.9418765,  0.9223838,  0.8865785,  0.8547953,  0.7260833,
 0.54316944, 0.51430106, 0.29412213, 0.28823933, 0.15232734, 0.13994728,
 0.09286306, 0.09246738, 0.06437422])

In [25]:
len(scores)

51

In [26]:
# Sort highest confidence -> lowest confidence
preds_sorted_idx = np.argsort(scores)[::-1]
preds_sorted = preds[preds_sorted_idx]

In [30]:
len(preds_sorted)

51

In [32]:
gt_boxes.copy()

array([[ 38, 547, 100,  87],
       [266, 181,  97,  87],
       [958,  80,  66,  52],
       [732, 348, 192,  84],
       [ 20,   0, 115,  53],
       [512, 479, 190, 195],
       [422, 294, 108, 135],
       [488, 261,  98,  64],
       [501, 143,  90,  79],
       [ 32, 472,  71,  70],
       [614, 266, 154,  64],
       [910, 117,  83,  73],
       [570,  46, 154,  84],
       [302,  10,  82,  71],
       [880, 775, 108, 182],
       [ 42, 843, 121, 149],
       [ 15, 688,  72, 118],
       [483, 808,  82, 149],
       [  0, 862,  56, 121],
       [770, 747,  90,  90],
       [563, 174, 115,  71],
       [714, 256, 228,  84],
       [401,  61,  75,  64],
       [419,   0, 121,  37],
       [371, 204,  94,  82],
       [ 81, 425,  84,  95],
       [  0, 292,  39,  98],
       [670, 547, 110,  67],
       [283, 852,  78, 120],
       [391, 151,  77,  97],
       [588, 888, 103,  95],
       [632, 780,  77,  87],
       [  0, 227,  76,  69],
       [586,  81,  69,  72],
       [499,  

In [34]:
precision = calculate_precision(gt_boxes.copy(), preds_sorted, threshold=0.5, form='coco')
print("Precision at threshold 0.5: {0:.4f}".format(precision))

Precision at threshold 0.5: 0.7885
