In [1]:
from torchvision.transforms import Normalize, ToTensor, Resize, Compose
from torchvision.datasets import ImageFolder
from efficientnet_pytorch import EfficientNet
from PIL import Image
from glob import glob
import numpy as np
import cv2
import pandas as pd
from torch import nn
import torch
import os
import torch.functional as F
from glob import glob
from tqdm import tqdm

from utils import t2np, np_softmax
from data import get_preprocessing
from model import SignModel


def draw_rpn_prediction(img, bboxes):
    """Return a new image with every box in ``bboxes`` outlined.

    Args:
        img: Anything ``np.array`` can convert to a pixel array (the caller
            in this notebook passes a PIL image).
        bboxes: Iterable of boxes. Entries 0..3 of each box are truncated
            with ``int()`` and used as the two opposite rectangle corners
            ``(b[0], b[1])`` and ``(b[2], b[3])``.

    Returns:
        A ``PIL.Image.Image`` built from the annotated array. Outlines are
        3 px thick in colour ``(255, 0, 0)`` (red if the array is RGB).
    """
    canvas = np.array(img)
    import cv2
    outline_colour = (255, 0, 0)
    outline_thickness = 3
    for box in bboxes:
        first_corner = (int(box[0]), int(box[1]))
        second_corner = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, first_corner, second_corner, outline_colour, outline_thickness)
    return Image.fromarray(canvas)


def draw_prediction(img, bboxes, scores, classes):
    """Outline each box and caption it with its class and integer percent score.

    Args:
        img: Anything ``np.array`` can convert to a pixel array (the caller
            in this notebook passes a PIL image).
        bboxes: Iterable of boxes; entries 0..3 of each are the rectangle
            corners ``(b[0], b[1])`` and ``(b[2], b[3])``.
        scores: Iterable of numeric scores, one per box; rendered as
            ``int(100 * score)``.
        classes: Iterable of class labels, one per box.

    Returns:
        A ``PIL.Image.Image`` built from the annotated array.
    """
    canvas = np.array(img)
    import cv2
    # Each caption is drawn twice: a thick black pass, then a thinner white
    # pass on top, which leaves a dark outline around light text.
    caption_passes = (((0, 0, 0), 6), ((255, 255, 255), 3))
    for box, label, score in zip(bboxes, classes, scores):
        left, top = int(box[0]), int(box[1])
        right, bottom = int(box[2]), int(box[3])
        cv2.rectangle(canvas, (left, top), (right, bottom), (0, 0, 255), 3)

        caption = '{} ({:d})'.format(label, int(100*score))
        anchor = (left, top - 10)  # text origin sits 10 px above the box corner
        for colour, thickness in caption_passes:
            # font face 1 (cv2.FONT_HERSHEY_PLAIN), font scale 2
            cv2.putText(canvas, caption, anchor, 1, 2, colour, thickness)
    return Image.fromarray(canvas)
    

def gather_crops(img, t):
    """Cut out and preprocess every box listed in ``t``.

    Args:
        img: Image object exposing ``crop(box)`` (the caller in this notebook
            passes a PIL image).
        t: DataFrame with ``xtl``, ``ytl``, ``xbr`` and ``ybr`` columns, one
            row per box.

    Returns:
        ``(batch, boxes)`` where ``batch`` is the preprocessed crops
        concatenated along a new leading axis, cast to half precision and
        moved to CUDA, and ``boxes`` is a NumPy array with one integer
        ``(xtl, ytl, xbr, ybr)`` row per crop. ``(None, None)`` when ``t``
        has no rows.

    Note:
        Relies on the notebook-level ``val_transform`` being defined before
        the first call with a non-empty ``t``.
    """
    corner_columns = ('xtl', 'ytl', 'xbr', 'ybr')
    boxes = []
    crops = []
    for _, row in t.iterrows():
        # int(0.5 + v) rounds to the nearest integer for non-negative v.
        box = tuple(int(0.5 + float(row[column])) for column in corner_columns)
        boxes.append(box)
        patch = val_transform(img.crop(box))
        crops.append(patch[None, ...])  # add a leading batch axis

    if not crops:
        return None, None
    return torch.cat(crops).half().cuda(), np.array(boxes)

Loaded pretrained weights for efficientnet-b4


IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [None]:
# Inputs for the rendering cell below: frame paths (sorted so the video plays
# in filename order), the detector's box table, class names and preprocessing.
impaths = sorted(glob('/media/grisha/hdd/icevision/ice/images/2018-03-23_1352_right/*.jpg'))
df = pd.read_csv('/media/grisha/hdd/icevision/models/Simultaneous-Traffic-Sign-Detection-and-Classification-with-RetinaNet/2018-03-23_1352_right.csv')
# Class names indexed by class id; presumably the same folder the classifier
# was trained on, so indices line up with the model outputs -- TODO confirm.
id2class = ImageFolder('classification/classification_data/train/').classes
val_transform = get_preprocessing(train=False)

# Switch to evaluation mode right away: this notebook only runs inference, and
# a freshly constructed nn.Module is in training mode (dropout active,
# batch-norm using per-batch statistics). Harmless if SignModel already does it.
model = SignModel().half().cuda().eval()
checkpoint = torch.load('classification/class_ckpts/init/2_ckpt.pth')
# Element 0 of the saved checkpoint is what load_state_dict consumes.
model.load_state_dict(checkpoint[0])

In [2]:
# Video settings. The frame size is shared by the writer and the per-frame
# resize so the two can never drift apart.
FRAME_SIZE = (1224, 1024)  # (width, height)
FPS = 30.0

# A classified box is drawn only if it passes all three filters.
MIN_SCORE = 0.7           # minimum softmax score of the top class
MIN_BOX_AREA = 100        # minimum box area, in squared box-coordinate units
IGNORED_CLASS = 'other'   # class name that is never drawn

fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('output_cls.avi', fourcc, FPS, FRAME_SIZE)

try:
    for impath in tqdm(impaths):
        imname = os.path.basename(impath)
        # Force 3-channel RGB so the channel flip before writing always
        # produces BGR, even for a grayscale or CMYK JPEG.
        img = Image.open(impath).convert('RGB')

        # Detector boxes belonging to this frame.
        t = df[df.imname == imname]
        crops, boxes = gather_crops(img, t)
        if boxes is not None:
            # Draw every proposed box first, then the accepted classifications.
            img = draw_rpn_prediction(img, boxes)

            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                y_hat = np_softmax(t2np(model(crops)))

            amax = y_hat.argmax(-1).flatten()
            scores = y_hat.max(-1).flatten()
            preds = np.array([id2class[x] for x in amax])
            area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            keep = (scores > MIN_SCORE) & (preds != IGNORED_CLASS) & (area > MIN_BOX_AREA)
            img = draw_prediction(img, boxes[keep], scores[keep], preds[keep])

        # RGB -> BGR for OpenCV. The reversed view has negative strides, so
        # hand the writer a plain contiguous buffer.
        frame = np.array(img.resize(FRAME_SIZE))[..., ::-1]
        out.write(np.ascontiguousarray(frame))
finally:
    # Finalise the video file even if a frame fails part-way through.
    out.release()

100%|██████████| 8539/8539 [15:57<00:00,  9.05it/s]
