# In [None]:
from pytorch_model import load_wpod
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms, utils, models
from torch import nn
import torch
from src.label import Label, Shape
from src.utils import getWH, nms
from src.projection_utils import getRectPts, find_T_matrix
import time

# In [None]:
class DLabel (Label):
    """Label described by a set of corner points rather than just a box.

    The raw points are kept in ``self.pts``. The axis-aligned bounding box
    handed to ``Label`` comes from the per-row minimum and maximum of
    ``pts``. This assumes ``pts`` has shape (2, N) with an x row and a
    y row; TODO confirm.
    """

    def __init__(self,cl,pts,prob):
        self.pts = pts
        top_left = np.min(pts, axis=1)
        bottom_right = np.max(pts, axis=1)
        super().__init__(cl, top_left, bottom_right, prob)


def reconstruct(Iorig,I,Y,out_size,threshold=.9):
    """Decode a WPOD-NET output map into plate labels and rectified crops.

    Args:
        Iorig: image the plate crops are warped from. Its width/height
            (via ``getWH``) scale the normalised label points back to pixels.
        I: image that was fed to the network. Its size, divided by the
            network stride, gives the output grid size used to normalise
            the points.
        Y: output map indexed as (row, col, channel). Channel 0 is the
            detection probability. Channels 2: are reshaped to a 2x3 affine
            matrix per cell, so exactly 8 channels are expected. ``Y`` is
            not modified.
        out_size: (width, height) of each rectified crop, passed to
            ``cv2.warpPerspective``.
        threshold: minimum probability for a grid cell to yield a candidate.

    Returns:
        ``(final_labels, TLps)``:
        - ``final_labels`` is the ``DLabel`` list kept by ``nms``, sorted
          by descending probability.
        - ``TLps`` holds one warped crop per label, in the same order.
    """
    net_stride = 2**4
    side = ((208. + 40.)/2.)/net_stride  # 7.75
    alpha = 0.5

    Probs = Y[..., 0]
    Affines = Y[..., 2:]

    # Homogeneous corners of an alpha-scaled square centred on the origin:
    # (-a,-a), (a,-a), (a,a), (-a,a). Columns are points.
    base = np.array([[-alpha, -alpha, 1.],
                     [ alpha, -alpha, 1.],
                     [ alpha,  alpha, 1.],
                     [-alpha,  alpha, 1.]]).T

    # Output grid size (width, height) as a column vector for normalising.
    MN = (getWH(I.shape)/net_stride).reshape((2, 1))

    labels = []
    # np.where yields (row indices, column indices), i.e. (y, x).
    for y, x in zip(*np.where(Probs > threshold)):
        # Copy so the clamping below does not write back into the caller's Y.
        A = np.reshape(Affines[y, x], (2, 3)).copy()
        A[0, 0] = max(A[0, 0], 0.)
        A[1, 1] = max(A[1, 1], 0.)

        # Centre of this grid cell, in grid units.
        mn = np.array([[float(x) + .5], [float(y) + .5]])

        pts_MN = (A @ base)*side + mn
        pts_prop = pts_MN/MN

        labels.append(DLabel(0, pts_prop, Probs[y, x]))

    final_labels = nms(labels, .1)
    TLps = []

    if final_labels:
        final_labels.sort(key=lambda lbl: lbl.prob(), reverse=True)

        # Both are the same for every label, so compute them once.
        t_ptsh = getRectPts(0, 0, out_size[0], out_size[1])
        orig_wh = getWH(Iorig.shape).reshape((2, 1))

        for label in final_labels:
            ptsh = np.concatenate((label.pts*orig_wh, np.ones((1, 4))))
            H = find_T_matrix(ptsh, t_ptsh)
            Ilp = cv2.warpPerspective(Iorig, H, out_size, borderValue=.0)
            TLps.append(Ilp)

    return final_labels, TLps
    

def detect_lp(model,I,max_dim,net_step,out_size,threshold):
    """Run the plate detector on one image.

    The image is rescaled so its *shorter* side becomes ``max_dim``. Each
    resulting dimension is then rounded up to a multiple of ``net_step``
    before the image is fed to ``model``.

    Args:
        model: torch module. It takes a 1xCxHxW tensor, and its output is
            treated as NCHW (permuted to NHWC here).
        I: HxWxC image array. Its dtype is passed to the model unchanged;
            presumably float32, so confirm it matches the model's weights.
        max_dim: target length of the shorter image side, before rounding.
        net_step: stride the resized width/height are rounded up to.
        out_size: (width, height) of each rectified plate crop.
        threshold: detection probability threshold forwarded to
            ``reconstruct``.

    Returns:
        ``(labels, plate_crops, elapsed)``. ``elapsed`` is the wall time of
        the forward pass plus output conversion, in seconds.
    """
    min_dim_img = min(I.shape[:2])
    factor = float(max_dim)/min_dim_img

    w, h = (np.array(I.shape[1::-1], dtype=float)*factor).astype(int).tolist()
    # Round each side up to the next multiple of net_step.
    w += (w % net_step != 0)*(net_step - w % net_step)
    h += (h % net_step != 0)*(net_step - h % net_step)
    Iresized = cv2.resize(I, (w, h))

    # HxWxC image -> 1xCxHxW tensor (torch.tensor copies the data).
    T = torch.tensor(Iresized[np.newaxis]).permute(0, 3, 1, 2)

    model.eval()
    start = time.perf_counter()
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        Yr = model(T).permute(0, 2, 3, 1).cpu().numpy()
    # Drop only the batch axis. np.squeeze would also collapse a spatial
    # axis of length 1 and break reconstruct's (row, col, channel) indexing.
    Yr = Yr[0]
    elapsed = time.perf_counter() - start

    L, TLps = reconstruct(I, Iresized, Yr, out_size, threshold)

    return L, TLps, elapsed


# In [None]:
def adjust_pts(pts,lroi):
    """Map points normalised to a region of interest into the enclosing frame.

    Every column of ``pts`` is scaled by ``lroi.wh()`` and then offset by
    ``lroi.tl()``. ``pts`` is expected to have two rows (x and y), since
    both factors are reshaped to (2, 1).
    """
    scale = lroi.wh().reshape((2, 1))
    offset = lroi.tl().reshape((2, 1))
    return pts*scale + offset


output_dir = "./temp"

lp_threshold = .5

wpod_net = load_wpod()

print('Searching for license plates using WPOD-NET')

img_path = 'Plate_examples/germany_car_plate.jpg'
Ivehicle = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if Ivehicle is None:
    raise FileNotFoundError("Could not read image: %s" % img_path)

ratio = float(max(Ivehicle.shape[:2]))/min(Ivehicle.shape[:2])
side  = int(ratio*288.)
bound_dim = min(side + (side%(2**4)),608)
print("\t\tBound dim: %d, ratio: %f" % (bound_dim,ratio))

# This used to call `im2single`, which is never defined or imported here.
# Convert uint8 pixels to float32 in [0, 1] inline instead.
# NOTE(review): confirm this matches the model's training preprocessing.
Ivehicle_f = Ivehicle.astype(np.float32)/255.
Llp,LlpImgs,_ = detect_lp(wpod_net,Ivehicle_f,bound_dim,2**4,(240,80),lp_threshold)

if LlpImgs:
    # Take the best (first) plate crop, convert it to grayscale and back
    # to 3 channels.
    Ilp = LlpImgs[0]
    Ilp = cv2.cvtColor(Ilp, cv2.COLOR_BGR2GRAY)
    Ilp = cv2.cvtColor(Ilp, cv2.COLOR_GRAY2BGR)

    s = Shape(Llp[0].pts)

    plt.imshow(Ilp)
