In [1]:
from __future__ import print_function, division
import re
import xml.etree.ElementTree as ET
import cv2
from skimage import io
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image, ImageDraw
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import xml.etree.ElementTree as ET

### Utils

In [2]:
def IoU(anchor, bbox):
    """Intersection-over-union of two axis-aligned boxes.

    Args:
        anchor: ``(x1, y1, x2, y2)`` box.
        bbox: ``(x1, y1, x2, y2)`` box.

    Returns:
        Intersection area divided by union area. Returns ``0.0`` when the union
        area is not positive (e.g. both boxes have zero area) instead of
        dividing by zero.
    """
    (x1, y1, x2, y2) = anchor
    (x3, y3, x4, y4) = bbox

    intersect_width = max(0.0, min(x2, x4) - max(x1, x3))
    intersect_height = max(0.0, min(y2, y4) - max(y1, y3))
    intersect = intersect_width * intersect_height
    union = (y2 - y1) * (x2 - x1) + (y4 - y3) * (x4 - x3) - intersect
    if union <= 0:
        # degenerate boxes: no meaningful overlap ratio
        return 0.0
    return intersect / union

def parse_pbtxt(file):
    """Parse a TensorFlow-style ``.pbtxt`` label map into a list of dicts.

    Each ``item { ... }`` block becomes one dict mapping field names to their
    string values (surrounding single quotes stripped), e.g.
    ``{'id': '1', 'name': 'tooth'}``.

    Args:
        file: path to the ``.pbtxt`` file.

    Returns:
        List of dicts, one per ``item`` block, in file order.
    """
    # Read-only mode (the file is never written) and a context manager so the
    # handle is closed even if reading fails.
    with open(file, 'r') as handle:
        text = handle.read()
    items = re.findall(r"item {([^}]*)}", text)
    return [dict(re.findall(r"(\w*): '*([^\n']*)'*", item)) for item in items]

def get_label_map_from_pbtxt(file):
    """Build ``{class_id: class_name}`` from a ``.pbtxt`` label map.

    Later items with a repeated id overwrite earlier ones.
    """
    return {int(entry['id']): entry['name'] for entry in parse_pbtxt(file)}

def get_inverse_label_map_from_pbtxt(file):
    """Build ``{class_name: class_id}`` from a ``.pbtxt`` label map.

    Later items with a repeated name overwrite earlier ones.
    """
    return {entry['name']: int(entry['id']) for entry in parse_pbtxt(file)}

def nms(dets, cls, thresh):
    """Greedy non-maximum suppression.

    Args:
        dets: array whose first four columns are ``x1, y1, x2, y2``.
        cls: one score per row of ``dets``.
        thresh: a box is discarded when its overlap with an already kept,
            higher-scoring box is greater than this value.

    Returns:
        List of row indices of ``dets`` that survive, highest score first.
    """
    left = dets[:, 0]
    top = dets[:, 1]
    right = dets[:, 2]
    bottom = dets[:, 3]
    # "+1": box sizes follow the inclusive-pixel convention.
    box_areas = (right - left + 1) * (bottom - top + 1)
    remaining = cls.argsort()[::-1]

    survivors = []
    while remaining.size > 0:
        best = remaining.item(0)
        survivors.append(best)
        others = remaining[1:]

        overlap_w = np.maximum(0.0, np.minimum(right[best], right[others]) - np.maximum(left[best], left[others]) + 1)
        overlap_h = np.maximum(0.0, np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others]) + 1)
        overlap = overlap_w * overlap_h
        iou = overlap / (box_areas[best] + box_areas[others] - overlap)

        # keep only candidates that do not overlap the chosen box too much
        remaining = others[iou <= thresh]

    return survivors

def parametrize(anchors, bboxes):
    """Encode ``bboxes`` as regression offsets relative to ``anchors``.

    For each row, with anchor width/height ``(wa, ha)``:
    ``tx = (cx_box - cx_anchor) / wa``, ``ty = (cy_box - cy_anchor) / ha``,
    ``tw = log(w_box / wa)``, ``th = log(h_box / ha)``.

    Args:
        anchors: (n, 4) array of ``[x1, y1, x2, y2]``.
        bboxes: (n, 4) array of matching ``[x1, y1, x2, y2]`` boxes, or empty.

    Returns:
        float32 array shaped like ``anchors``; all zeros when ``bboxes`` is empty.
    """
    deltas = np.zeros(anchors.shape, dtype=np.float32)
    if len(bboxes) == 0:
        return deltas

    # (lo, hi) column pairs: x uses columns 0/2, y uses columns 1/3.
    for axis, (lo, hi) in enumerate(((0, 2), (1, 3))):
        span = anchors[:, hi] - anchors[:, lo]
        deltas[:, axis] = 0.5 * (bboxes[:, lo] + bboxes[:, hi] - anchors[:, lo] - anchors[:, hi]) / span
        deltas[:, axis + 2] = np.log((bboxes[:, hi] - bboxes[:, lo]) / span)
    return deltas

def unparametrize(anchors, reg):
    """Decode regression offsets back into ``[x1, y1, x2, y2]`` boxes.

    Inverse of ``parametrize``. ``reg`` is reshaped to ``anchors.shape`` and each
    row ``(tx, ty, tw, th)`` is applied to the matching anchor. The corner
    arithmetic is carried out in float64; the result is float32.

    Args:
        anchors: (n, 4) tensor of ``[x1, y1, x2, y2]``.
        reg: tensor with ``n * 4`` elements of offsets.

    Returns:
        (n, 4) float32 tensor of decoded boxes.
    """
    offsets = reg.view(anchors.shape).float()
    anchor_w = anchors[:, 2] - anchors[:, 0]
    anchor_h = anchors[:, 3] - anchors[:, 1]

    # decoded centre and size, promoted to float64
    ctr_x = (anchor_w * offsets[:, 0] + (anchors[:, 0] + anchors[:, 2]) / 2.0).double()
    ctr_y = (anchor_h * offsets[:, 1] + (anchors[:, 1] + anchors[:, 3]) / 2.0).double()
    width = (anchor_w * torch.exp(offsets[:, 2])).double()
    height = (anchor_h * torch.exp(offsets[:, 3])).double()

    left = ctr_x - width / 2.0
    top = ctr_y - height / 2.0
    return torch.stack((left, top, left + width, top + height), dim=1).float()

def count_positive_anchors_on_image(i, dataset, anchor_source=None):
    """Print the anchors labelled positive for image ``i`` and return how many there are.

    Args:
        i: image id passed to ``dataset.get_truth_bboxes``.
        dataset: object whose ``get_truth_bboxes(i)`` returns ``(bboxes, classes)``.
        anchor_source: object providing ``get_image_anchors()`` and
            ``get_positive_negative_anchors(anchors, bboxes)`` (e.g. an ``RPN``
            instance). Defaults to ``dataset`` for backward compatibility.

    Returns:
        Number of positive anchors.
    """
    if anchor_source is None:
        anchor_source = dataset
    # get_truth_bboxes returns (bboxes, classes); only the boxes are matched against anchors
    bboxes, _ = dataset.get_truth_bboxes(i)
    anchors, _ = anchor_source.get_image_anchors()
    _, positives, _ = anchor_source.get_positive_negative_anchors(anchors, bboxes)
    print(anchors[np.where(positives)])
    # len(np.where(mask)) would be the number of mask dimensions, not the number of positives
    return int(np.count_nonzero(positives))

### Dataset (Pascal VOC format)

In [10]:
class VOCDataset(Dataset):
    """Grayscale tooth images with Pascal VOC XML annotations.

    ``root_dir`` must contain ``Annotations/<id>.xml``, ``JPEGImages/<id>.jpg`` and
    ``pascal_label_map.pbtxt``. ``INPUT_SIZE`` is ``(width, height)``: boxes are
    rescaled with ``INPUT_SIZE[0]`` as width in ``get_truth_bboxes`` and PIL
    resizes to it in ``get_resized_image``.
    """
    INPUT_SIZE = (1600, 800)

    def __init__(self, root_dir):
        """
        Args:
            root_dir (string): Directory with all the images under VOC format.
        """
        self.root_dir = root_dir
        self.label_map_path = os.path.join(root_dir, 'pascal_label_map.pbtxt')
        # Sorted so the index -> image mapping is reproducible (os.listdir order is
        # arbitrary); files that are not annotations are ignored.
        self.tooth_images_paths = sorted(
            name for name in os.listdir(os.path.join(root_dir, 'Annotations'))
            if name.endswith('.xml'))
        self.label_map = self.get_label_map(self.label_map_path)
        self.inverse_label_map = self.get_inverse_label_map(self.label_map_path)

    def __len__(self):
        return len(self.tooth_images_paths)

    def _image_id(self, i):
        """File stem of the ``i``-th annotation (dots inside the stem are kept)."""
        return os.path.splitext(self.tooth_images_paths[i])[0]

    def __getitem__(self, i):
        """Return ``(image, bboxes, classes)`` for the ``i``-th annotation.

        ``image`` has shape ``(1, 3, INPUT_SIZE[0], INPUT_SIZE[1])``: axis 2 runs
        along the image's x (width) direction and axis 3 along y, matching the
        ``(x1, y1, x2, y2)`` boxes returned by ``get_truth_bboxes``.
        """
        image_id = self._image_id(i)
        image = self.get_image(image_id)
        bboxes, classes = self.get_truth_bboxes(image_id)
        # skimage's resize takes (rows, cols) = (height, width). Resize to
        # (INPUT_SIZE[1], INPUT_SIZE[0]) so the content is scaled exactly like the
        # boxes, then transpose so the x axis comes first; the array shape
        # (INPUT_SIZE[0], INPUT_SIZE[1]) is unchanged for downstream code.
        resized = resize(image, (self.INPUT_SIZE[1], self.INPUT_SIZE[0])).T
        # image input is grayscale, replicate it into 3 channels for an RGB backbone
        im = np.expand_dims(np.stack((resized,) * 3), axis=0)
        return im, bboxes, classes

    def get_classes(self):
        """Integer class ids present in the label map."""
        return list(self.inverse_label_map.values())

    def get_image(self, i):
        """Load ``JPEGImages/<i>.jpg`` as grayscale and apply ``preprocess_image``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = os.path.join(self.root_dir, 'JPEGImages', str(i) + '.jpg')
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            raise FileNotFoundError('Could not read image: %s' % path)
        return self.preprocess_image(img)

    def preprocess_image(self, img):
        """Apply CLAHE (contrast-limited adaptive histogram equalisation)."""
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
        cl = clahe.apply(img)
        return cl

    def get_truth_bboxes(self, i):
        """Read ground-truth boxes and class ids of image ``i``, scaled to INPUT_SIZE.

        Returns:
            ``(bboxes, classes)``: a float (n, 4) array of ``xmin, ymin, xmax, ymax``
            in INPUT_SIZE pixel coordinates and an (n,) array of label-map ids.
            Both are empty arrays when the annotation has no objects.
        """
        path = os.path.join(self.root_dir, 'Annotations', str(i) + '.xml')
        root = ET.parse(path).getroot()

        # we need to resize the bboxes to the INPUT_SIZE
        size = root.find('size')
        width_ratio = float(size.find('width').text) / float(self.INPUT_SIZE[0])
        height_ratio = float(size.find('height').text) / float(self.INPUT_SIZE[1])

        boxes = []
        classes = []
        for obj in root.findall('object'):
            bndbox = obj.find('bndbox')
            # look fields up by tag instead of relying on child order; float()
            # also accepts annotation tools that write fractional coordinates
            boxes.append([float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            classes.append(int(self.inverse_label_map[obj.find('name').text]))
        if not boxes:
            return np.array([]), np.array([])

        # float array: scaling an int array in place would truncate the coordinates
        bboxes = np.array(boxes, dtype=np.float64)
        bboxes[:, [0, 2]] /= width_ratio
        bboxes[:, [1, 3]] /= height_ratio
        return bboxes, np.array(classes)

    def get_label_map(self, label_map_path):
        return get_label_map_from_pbtxt(label_map_path)

    def get_inverse_label_map(self, label_map_path):
        return get_inverse_label_map_from_pbtxt(label_map_path)

    def get_resized_image(self, i):
        """Preprocessed image ``i`` resized to INPUT_SIZE, as an RGB PIL image."""
        image = self.get_image(i)
        temp_im = Image.fromarray(image).resize(self.INPUT_SIZE)
        im = Image.new('RGB', temp_im.size)
        im.paste(temp_im)
        return im

    def visualise_proposals_on_image(self, bboxes, i):
        """Draw ``bboxes`` (INPUT_SIZE pixel coordinates) on image ``i`` and show it."""
        im = self.get_resized_image(i)
        draw = ImageDraw.Draw(im)

        for bbox in bboxes:
            draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], outline='blue')

        im.show()

### RPN

In [4]:
class RPN(nn.Module):
    """Region Proposal Network head on top of a backbone feature map.

    Anchors sit on an ``OUTPUT_SIZE`` grid of ``OUTPUT_CELL_SIZE``-pixel cells and
    are flattened x-major: anchor ``k`` of cell ``(x, y)`` has index
    ``(x * NUMBER_ANCHORS_HEIGHT + y) * anchor_number + k`` (see ``get_image_anchors``).
    """
    INPUT_SIZE = (1600, 800)
    OUTPUT_SIZE = (100, 50)
    OUTPUT_CELL_SIZE = float(INPUT_SIZE[0]) / float(OUTPUT_SIZE[0])

    # anchors constants
    ANCHORS_RATIOS = [0.25, 0.5, 0.9]
    ANCHORS_SCALES = [4, 6, 8]

    NUMBER_ANCHORS_WIDE = OUTPUT_SIZE[0]
    NUMBER_ANCHORS_HEIGHT = OUTPUT_SIZE[1]

    NEGATIVE_THRESHOLD = 0.3
    POSITIVE_THRESHOLD = 0.6

    ANCHOR_SAMPLING_SIZE = 256

    NMS_THRESHOLD = 0.5
    PRE_NMS_MAX_PROPOSALS = 6000
    POST_NMS_MAX_PROPOSALS = 100

    def __init__(self, in_dim):
        """
        Args:
            in_dim: number of channels of the input feature map.
        """
        super(RPN, self).__init__()

        self.in_dim = in_dim
        self.anchor_dimensions = self.get_anchor_dimensions()
        self.anchor_number = len(self.anchor_dimensions)
        mid_layers = 1024
        self.RPN_conv = nn.Conv2d(self.in_dim, mid_layers, 3, 1, 1)
        # cls layer: 2 scores per anchor (column 1 = object)
        self.cls_layer = nn.Conv2d(mid_layers, 2 * self.anchor_number, 1, 1, 0)
        # reg_layer: 4 offsets per anchor
        self.reg_layer = nn.Conv2d(mid_layers, 4 * self.anchor_number, 1, 1, 0)

        # initialize layers
        torch.nn.init.normal_(self.RPN_conv.weight, std=0.01)
        torch.nn.init.normal_(self.cls_layer.weight, std=0.01)
        torch.nn.init.normal_(self.reg_layer.weight, std=0.01)

    def forward(self, x):
        """Takes feature map as input.

        Returns:
            cls_output: (n_anchors, 2) softmax probabilities per anchor.
            reg_output: (n_anchors, 4) regression offsets per anchor.
        """
        rpn_conv = F.relu(self.RPN_conv(x), inplace=True)
        # channels last so the flattened order is (row, col, anchor); batch size 1 assumed
        cls_output = self.cls_layer(rpn_conv).permute(0, 2, 3, 1).contiguous().view(1, -1, 2)
        reg_output = self.reg_layer(rpn_conv).permute(0, 2, 3, 1).contiguous().view(1, -1, 4)

        cls_output = F.softmax(cls_output.view(-1, 2), dim=1)
        reg_output = reg_output.view(-1, 4)
        return cls_output, reg_output

    def get_target(self, bboxes):
        """Build RPN training targets for the ground-truth ``bboxes``.

        Returns:
            reg_target: (n_anchors, 4) float32 tensor of offsets to each anchor's best box.
            cls_truth: (n_anchors, 2) one-hot tensor, column 1 set for positive anchors.
            selected_indices: sampled anchor indices for the classification loss.
            positive_indices: indices of all positive anchors (regression loss).
        """
        anchors, _ = self.get_image_anchors()
        truth_bbox, positives, negatives = self.get_positive_negative_anchors(anchors, bboxes)
        reg_target = parametrize(anchors, truth_bbox)

        n = len(anchors)
        indices = np.arange(n)
        selected_indices, positive_indices = self.get_selected_indices_sample(indices, positives, negatives)

        cls_truth = np.zeros((n, 2))
        cls_truth[indices, positives.astype(int)] = 1.0
        return torch.from_numpy(reg_target), torch.from_numpy(cls_truth), selected_indices, positive_indices

    def get_anchor_dimensions(self):
        """(width, height) of every anchor shape, in feature-map cells."""
        dimensions = []
        for r in self.ANCHORS_RATIOS:
            for s in self.ANCHORS_SCALES:
                width = s * np.sqrt(r)
                height = s * np.sqrt(1.0 / r)
                dimensions.append((width, height))
        return dimensions

    def get_anchors_at_position(self, pos):
        # dimensions of anchors: (self.anchor_number, 4)
        # each anchor is [xa, ya, xb, yb]
        x, y = pos
        anchors = np.zeros((self.anchor_number, 4))
        for i in range(self.anchor_number):
            center_x = self.OUTPUT_CELL_SIZE * (float(x) + 0.5)
            center_y = self.OUTPUT_CELL_SIZE * (float(y) + 0.5)

            width = self.anchor_dimensions[i][0] * self.OUTPUT_CELL_SIZE
            height = self.anchor_dimensions[i][1] * self.OUTPUT_CELL_SIZE

            top_x = center_x - width / 2.0
            top_y = center_y - height / 2.0
            anchors[i, :] = [top_x, top_y, top_x + width, top_y + height]
        return anchors

    def get_proposals(self, reg, cls):
        """Turn RPN outputs into at most POST_NMS_MAX_PROPOSALS object proposals.

        Only anchors fully inside the image and classified as objects
        (argmax of ``cls`` == 1) are kept; the PRE_NMS_MAX_PROPOSALS
        highest-scoring of them go through ``nms``.

        Returns:
            (k, 4) float tensor of ``[x1, y1, x2, y2]`` boxes, highest score first.
        """
        anchors_np, inside = self.get_image_anchors()
        anchors = torch.from_numpy(anchors_np).float()
        bboxes = unparametrize(anchors, reg).reshape((-1, 4))[inside]
        inside_cls = cls[inside]
        is_object = torch.argmax(inside_cls, dim=1) == 1

        bboxes = bboxes[is_object]
        # scores are filtered with exactly the same masks as the boxes so rows stay aligned
        scores = inside_cls[is_object][:, 1].detach().numpy()
        # keep the best-scoring candidates, not the first ones in anchor order
        top = np.argsort(-scores, kind='stable')[:self.PRE_NMS_MAX_PROPOSALS]
        bboxes = bboxes[torch.from_numpy(top)]
        keep = nms(bboxes.detach().numpy(), scores[top], self.NMS_THRESHOLD)[:self.POST_NMS_MAX_PROPOSALS]
        return bboxes[torch.as_tensor(keep, dtype=torch.long)]

    def get_training_proposals(self, reg, cls):
        """Same proposals as ``get_proposals`` (kept as a separate entry point)."""
        return self.get_proposals(reg, cls)

    def get_image_anchors(self):
        """All anchors of the image and the indices of those inside it.

        Returns:
            anchors: (NUMBER_ANCHORS_WIDE * NUMBER_ANCHORS_HEIGHT * anchor_number, 4)
                array of ``[x1, y1, x2, y2]``, ordered by x cell, then y cell, then
                anchor shape.
            inside: ``np.where`` tuple of indices of anchors that do not cross the
                image border.
        """
        cell = self.OUTPUT_CELL_SIZE
        center_x = cell * (np.arange(self.NUMBER_ANCHORS_WIDE, dtype=np.float64) + 0.5)
        center_y = cell * (np.arange(self.NUMBER_ANCHORS_HEIGHT, dtype=np.float64) + 0.5)
        dims = np.asarray(self.anchor_dimensions, dtype=np.float64) * cell
        widths = dims[:, 0]
        heights = dims[:, 1]

        # broadcast to (W, H, A): same arithmetic as get_anchors_at_position
        top_x = center_x[:, None, None] - widths[None, None, :] / 2.0
        top_y = center_y[None, :, None] - heights[None, None, :] / 2.0
        corners = np.broadcast_arrays(top_x, top_y, top_x + widths, top_y + heights)
        anchors = np.stack(corners, axis=-1).reshape((-1, 4))

        crosses_border = ((anchors[:, 0] < 0) | (anchors[:, 1] < 0)
                          | (anchors[:, 2] > self.INPUT_SIZE[0]) | (anchors[:, 3] > self.INPUT_SIZE[1]))
        return anchors, np.where(~crosses_border)

    @staticmethod
    def _pairwise_iou(boxes_a, boxes_b):
        """(len(boxes_a), len(boxes_b)) IoU matrix for ``[x1, y1, x2, y2]`` boxes (0 where union is 0)."""
        a = np.asarray(boxes_a, dtype=np.float64)
        b = np.asarray(boxes_b, dtype=np.float64)
        inter_w = np.maximum(0.0, np.minimum(a[:, 2:3], b[:, 2]) - np.maximum(a[:, 0:1], b[:, 0]))
        inter_h = np.maximum(0.0, np.minimum(a[:, 3:4], b[:, 3]) - np.maximum(a[:, 1:2], b[:, 1]))
        inter = inter_w * inter_h
        area_a = (a[:, 3] - a[:, 1]) * (a[:, 2] - a[:, 0])
        area_b = (b[:, 3] - b[:, 1]) * (b[:, 2] - b[:, 0])
        union = area_a[:, None] + area_b[None, :] - inter
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(union > 0, inter / union, 0.0)

    def get_positive_negative_anchors(self, anchors, bboxes):
        """Label every anchor against the ground-truth boxes.

        Returns:
            truth_bbox: (n_anchors, 4) best-matching box per anchor (empty array when
                there are no boxes).
            positives: 1-D bool mask; IoU > POSITIVE_THRESHOLD, plus the best anchor
                of every box.
            negatives: 1-D bool mask; IoU < NEGATIVE_THRESHOLD and not positive.
        """
        n_anchors = anchors.shape[0]
        if not len(bboxes):
            # one flag per anchor (a 1-D mask, like the non-empty case)
            return np.array([]), np.zeros(n_anchors, dtype=bool), np.ones(n_anchors, dtype=bool)

        bboxes = np.asarray(bboxes)
        ious = self._pairwise_iou(anchors, bboxes)
        best_bbox_for_anchor = np.argmax(ious, axis=1)
        best_anchor_for_bbox = np.argmax(ious, axis=0)
        max_iou_per_anchor = np.amax(ious, axis=1)

        # truth box for each anchor
        truth_bbox = bboxes[best_bbox_for_anchor, :]

        # Selecting all ious > POSITIVE_THRESHOLD
        positives = max_iou_per_anchor > self.POSITIVE_THRESHOLD
        # Adding max iou for each ground truth box
        positives[best_anchor_for_bbox] = True
        # a forced positive (low IoU best match) must not also be sampled as a negative
        negatives = (max_iou_per_anchor < self.NEGATIVE_THRESHOLD) & ~positives
        return truth_bbox, positives, negatives

    def get_selected_indices_sample(self, indices, positives, negatives):
        """Randomly sample up to ANCHOR_SAMPLING_SIZE anchors, at most half positive."""
        positive_indices = indices[positives]
        negative_indices = indices[negatives]
        random_positives = np.random.permutation(positive_indices)[:self.ANCHOR_SAMPLING_SIZE // 2]
        random_negatives = np.random.permutation(negative_indices)[:self.ANCHOR_SAMPLING_SIZE - len(random_positives)]
        selected_indices = np.concatenate((random_positives, random_negatives))
        return selected_indices, positive_indices

    def get_positive_anchors(self, bboxes):
        """Anchors labelled positive for ``bboxes``."""
        anchors, _ = self.get_image_anchors()
        _, positives, _ = self.get_positive_negative_anchors(anchors, bboxes)
        return anchors[positives]

### Faster R-CNN

In [5]:
class FasterRCNN(nn.Module):
    """Faster R-CNN: backbone feature map -> RPN proposals -> per-RoI class/box heads.

    Assumes a batch of one image. The feature map is indexed as
    ``(C, OUTPUT_SIZE[0], OUTPUT_SIZE[1])`` with the x axis first (see ``forward``).
    Label 0 is the background class.
    """
    INPUT_SIZE = (1600, 800)
    OUTPUT_SIZE = (100, 50)
    OUTPUT_CELL_SIZE = float(INPUT_SIZE[0]) / float(OUTPUT_SIZE[0])

    NEGATIVE_THRESHOLD = 0.3
    POSITIVE_THRESHOLD = 0.5

    def __init__(self, n_classes, model='resnet50', path='fasterrcnn_resnet50.pt', training=False):
        """
        Args:
            n_classes: number of foreground classes; a background class is added.
            model: backbone name, ``'resnet50'`` or ``'vgg16'``.
            path: checkpoint passed to ``load_state_dict`` when the file exists.
            training: initial value of ``self.training``.

        Raises:
            ValueError: if ``model`` is not a supported backbone.
        """
        super(FasterRCNN, self).__init__()

        self.n_roi_sample = 128
        self.pos_ratio = 0.25
        self.pos_iou_thresh = 0.5
        self.neg_iou_thresh_hi = 0.5
        self.neg_iou_thresh_lo = 0.0

        if model == 'resnet50':
            self.in_dim = 1024
            resnet = models.resnet50(pretrained=True)
            # drop layer4, avgpool and fc: layer3 output has 1024 channels at stride 16
            self.feature_map = nn.Sequential(*list(resnet.children())[:-3])
        elif model == 'vgg16':
            self.in_dim = 512
            vgg = models.vgg16(pretrained=True)
            # conv layers without the final max-pool: 512 channels at stride 16.
            # (vgg.children()[:-1] would keep the AdaptiveAvgPool2d((7, 7)) and
            # collapse the map to 7x7 instead of OUTPUT_SIZE.)
            self.feature_map = nn.Sequential(*list(vgg.features.children())[:-1])
        else:
            raise ValueError("Unsupported backbone %r (expected 'resnet50' or 'vgg16')" % (model,))

        self.n_classes = n_classes + 1
        self.in_fc_dim = 7 * 7 * self.in_dim
        self.out_fc_dim = 1024

        self.rpn = RPN(self.in_dim)
        self.fc = nn.Linear(self.in_fc_dim, self.out_fc_dim)
        self.cls_layer = nn.Linear(self.out_fc_dim, self.n_classes)
        self.reg_layer = nn.Linear(self.out_fc_dim, self.n_classes * 4)

        self.training = training

        # initialize layers
        torch.nn.init.normal_(self.fc.weight, std=0.01)
        torch.nn.init.normal_(self.cls_layer.weight, std=0.1)
        torch.nn.init.normal_(self.reg_layer.weight, std=0.01)

        if os.path.isfile(path):
            # torch.load unpickles the file: only load checkpoints from a trusted source
            self.load_state_dict(torch.load(path))

    def _roi_to_cells(self, roi):
        """Convert a pixel ``[x1, y1, x2, y2]`` box into inclusive feature-map cell bounds.

        Bounds are clamped to the feature map so every RoI gives a non-empty slice.
        """
        max_x = self.OUTPUT_SIZE[0] - 1
        max_y = self.OUTPUT_SIZE[1] - 1
        # the box decoder uses exp(), so extreme offsets may overflow to inf/nan
        roi = torch.nan_to_num(roi.float(), nan=0.0, posinf=float(max(self.INPUT_SIZE)), neginf=0.0)
        x1, y1, x2, y2 = (int(v) for v in torch.floor(roi.clamp(min=0.0) / self.OUTPUT_CELL_SIZE).tolist())
        x1 = min(x1, max_x)
        y1 = min(y1, max_y)
        x2 = min(max(x2, x1), max_x)
        y2 = min(max(y2, y1), max_y)
        return x1, y1, x2, y2

    def forward(self, x):
        """Run the detector on one image tensor.

        Returns:
            all_cls: (n_rois, n_classes) class logits per proposal.
            all_reg: (n_rois, 4) box offsets of each RoI's highest-scoring class.
            proposals: (n_rois, 4) RPN proposals in pixel coordinates.
            cls, reg: raw RPN outputs.
        """
        feature_map = self.feature_map(x)
        cls, reg = self.rpn(feature_map)
        # drop the batch axis (batch size 1): (C, OUTPUT_SIZE[0], OUTPUT_SIZE[1])
        feature_map = feature_map.view((-1, self.OUTPUT_SIZE[0], self.OUTPUT_SIZE[1]))
        proposals = self.rpn.get_proposals(reg, cls)

        all_cls = []
        all_reg = []
        for roi in proposals.detach():
            x1, y1, x2, y2 = self._roi_to_cells(roi)
            roi_feature_map = feature_map[:, x1:x2 + 1, y1:y2 + 1]
            pooled_roi = F.adaptive_max_pool2d(roi_feature_map, (7, 7)).view((-1, self.in_fc_dim))
            r = F.relu(self.fc(pooled_roi))
            r_cls = self.cls_layer(r)
            r_reg = self.reg_layer(r).view((self.n_classes, 4))
            all_cls.append(r_cls)
            # keep only the offsets predicted for this RoI's best class
            all_reg.append(r_reg[torch.argmax(r_cls)])

        if not all_cls:
            # no proposals: torch.stack cannot stack an empty list
            return (feature_map.new_zeros((0, self.n_classes)), feature_map.new_zeros((0, 4)),
                    proposals, cls, reg)
        return torch.stack(all_cls).view((-1, self.n_classes)), torch.stack(all_reg), proposals, cls, reg

    def get_target(self, proposals, bboxes, classes):
        """Class and box-regression targets for every proposal.

        A proposal whose best IoU with a ground-truth box exceeds
        POSITIVE_THRESHOLD takes that box's class; every other proposal is
        background (0). One row is returned per proposal so the targets line up
        with the per-RoI outputs of ``forward``.

        Returns:
            labels: (n_proposals,) int64 tensor.
            reg_target: (n_proposals, 4) float32 tensor of ``parametrize`` offsets to
                each proposal's best-matching box.
        """
        proposals_np = proposals.detach().numpy()
        n_proposals = proposals_np.shape[0]
        if n_proposals == 0 or len(bboxes) == 0:
            return (torch.zeros(n_proposals, dtype=torch.int64),
                    torch.zeros((n_proposals, 4), dtype=torch.float32))

        ious = np.zeros((n_proposals, len(bboxes)))
        for i in range(n_proposals):
            for j in range(len(bboxes)):
                ious[i, j] = IoU(proposals_np[i], bboxes[j])
        best_bbox_for_proposal = np.argmax(ious, axis=1)
        max_iou_per_proposal = np.amax(ious, axis=1)

        # fancy indexing copies, so zeroing negatives below does not touch `classes`
        labels = np.asarray(classes)[best_bbox_for_proposal].astype(np.int64)
        truth_bbox = parametrize(proposals_np, bboxes[best_bbox_for_proposal, :])

        positives = max_iou_per_proposal > self.POSITIVE_THRESHOLD
        # Assign background label to every non-positive proposal
        labels[~positives] = 0
        return torch.from_numpy(labels), torch.from_numpy(truth_bbox)

    def get_proposals(self, reg, cls, rpn_proposals):
        """Decode final boxes for the RoIs whose predicted class is not background."""
        objects = torch.argmax(F.softmax(cls, dim=1), dim=1)
        bboxes = unparametrize(rpn_proposals, reg)
        return bboxes[objects != 0]

### Train

In [15]:
model = 'resnet50'
MODEL_PATH = os.path.join('models', f'fasterrcnn_{model}.pt')

def train(dataset, save_range=40, lamb=10.0, lr=0.0001):
    """Train Faster R-CNN for one pass over ``dataset``, one image per step.

    Args:
        dataset: object where ``dataset[i]`` returns ``(image, bboxes, classes)`` and
            ``get_classes()`` lists the class ids.
        save_range: write a checkpoint to MODEL_PATH every ``save_range`` steps
            (and once more at the end).
        lamb: weight of the RPN box-regression loss.
        lr: Adam learning rate.
    """
    n_classes = len(dataset.get_classes())

    fasterrcnn = FasterRCNN(n_classes, model=model, path=MODEL_PATH, training=True)
    optimizer = optim.Adam(fasterrcnn.parameters(), lr=lr)
    # torch.save does not create missing directories
    save_dir = os.path.dirname(MODEL_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i in range(len(dataset)):
        im, bboxes, classes = dataset[i]
        if not len(classes):
            # images without annotated objects provide no targets
            continue
        optimizer.zero_grad()
        all_cls, all_reg, proposals, rpn_cls, rpn_reg = fasterrcnn(torch.from_numpy(im).float())

        rpn_reg_target, rpn_cls_target, rpn_selected_indices, rpn_positives = fasterrcnn.rpn.get_target(bboxes)
        # box regression is supervised on positive anchors only
        rpn_reg_loss = F.smooth_l1_loss(rpn_reg[rpn_positives], rpn_reg_target[rpn_positives])
        # look at a sample of positive + negative boxes for classification
        rpn_cls_loss = F.binary_cross_entropy(rpn_cls[rpn_selected_indices], rpn_cls_target[rpn_selected_indices].float())
        loss = rpn_cls_loss + lamb * rpn_reg_loss

        if len(all_cls):
            cls_target, reg_target = fasterrcnn.get_target(proposals, bboxes, classes)
            fastrcnn_loss = F.cross_entropy(all_cls, cls_target)
            # background RoIs (label 0) have no meaningful box target
            foreground = cls_target > 0
            if foreground.any():
                fastrcnn_loss = fastrcnn_loss + F.smooth_l1_loss(all_reg[foreground], reg_target[foreground])
            loss = loss + fastrcnn_loss

        loss.backward()
        optimizer.step()

        print('[%d] loss: %.5f' % (i, loss.item()))

        if (i + 1) % save_range == 0:
            torch.save(fasterrcnn.state_dict(), MODEL_PATH)
    torch.save(fasterrcnn.state_dict(), MODEL_PATH)
    print('Finished Training')


In [16]:
# Dataset root in Pascal VOC layout (Annotations/, JPEGImages/, pascal_label_map.pbtxt).
dataset = VOCDataset(root_dir='VOC2007')

In [17]:
# Display the dataset object (IPython shows its repr as the cell output).
dataset

<__main__.VOCDataset at 0x7fb3af303490>

In [None]:
# Run training; checkpoints are written to MODEL_PATH.
train(dataset=dataset)