In [1]:
import os
import cv2
import yaml
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from collections import Counter
import h5py
import io
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torch.utils.data
import torch.nn as nn
from utils.general import non_max_suppression

In [2]:
class opt(object):
    """Plain container for the inference settings read by the rest of the notebook."""

    def __init__(self):
        settings = {
            'weight': 'runs/train/exp2/weights/epoch_299.pt',  # checkpoint that gets loaded
            'data': 'data/plate.yaml',                          # dataset yaml holding the class names
            'input_size': 1024,                                 # side of the square network input
            'device': 'cuda',                                   # device the input tensors are moved to
            'conf_thres': 0.25,                                 # confidence threshold handed to NMS
            'iou_thres': 0.65,                                  # IoU threshold handed to NMS
        }
        for key, value in settings.items():
            setattr(self, key, value)


# The class name is rebound to its single instance, so every later cell reads
# the settings as ``opt.<name>``.
opt = opt()

In [3]:
# Parse the dataset description; only its 'names' entry (the class-name list,
# indexed by class id) is needed afterwards.
# NOTE(review): FullLoader is acceptable for a trusted local file; if the yaml
# only stores plain data, yaml.safe_load would be the stricter choice.
with open(opt.data) as f:
    data = yaml.load(f, Loader=yaml.FullLoader)
classes = data['names']

In [4]:
# Load the checkpoint straight onto the configured device instead of whatever
# device it was saved from. The checkpoint stores the model object itself under
# 'model' and the training epoch under 'epoch'.
# NOTE(review): recent PyTorch releases default `torch.load` to
# `weights_only=True`, which rejects pickled model objects -- pass
# `weights_only=False` there if this checkpoint is trusted.
checkpoint = torch.load(opt.weight, map_location=opt.device)
print('Loaded {}, epoch {}'.format(opt.weight, checkpoint['epoch']))
# Honour `opt.device` (the same device the input tensors are sent to) rather
# than hard-coding `.cuda()`; inference runs in half precision, matching the
# `.half()` applied to the inputs in `preproccess_img`.
model = checkpoint['model'].to(opt.device).half()
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

Loaded runs/train/exp2/weights/epoch_299.pt, epoch 299


Model(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): SyncBatchNorm(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): SyncBatchNorm(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (2): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): SyncBatchNorm(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (3): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): SyncBatchNorm(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      

In [5]:
def resize(img, size):
    '''
    Letterbox an image onto a square canvas while keeping its aspect ratio.
    Input:
        img = array of shape (h, w, c)
        size = side length of the square output
    Return:
        (canvas, meta) when the image is not already size x size:
            canvas = zero-filled float64 array (size, size, c) with the scaled
                     image centred in it
            meta = {'nw', 'nh': scaled size, 'dw', 'dh': padding offsets,
                    'w', 'h': original size}
        (img, {}) when the image already has the requested size (input is
        returned untouched).
    '''
    h, w, c = img.shape
    if h == size and w == size:
        return img, {}

    # The smaller of the two scale factors makes the longer side fit exactly.
    ratio = min(float(size / w), float(size / h))
    nw, nh = int(w * ratio), int(h * ratio)
    scaled = cv2.resize(img.copy(), (nw, nh))

    # Centre the scaled image; the remaining border stays zero.
    dw, dh = (size - nw) // 2, (size - nh) // 2
    canvas = np.zeros((size, size, c))
    canvas[dh: dh + nh, dw: dw + nw] = scaled
    return canvas, {'nw': nw, 'nh': nh, 'dw': dw, 'dh': dh, 'w': w, 'h': h}
    
def preproccess_img(img, device):
    '''
    Turn an image array into a half-precision network input batch.
    Input:
        img = array of shape (h, w, c) as expected by `resize`
        device = torch device the tensor is moved to
    Return:
        tensor = half tensor of shape (1, c, opt.input_size, opt.input_size)
        meta = letterbox metadata from `resize` (empty dict if no resize happened)
    '''
    padded, meta = resize(img, opt.input_size)
    scaled = (padded / 255.).astype(np.float32)
    # Reverse the channel axis (BGR to RGB) and go from HWC to CHW; the result
    # is copied into contiguous memory before being handed to torch.
    chw = np.ascontiguousarray(scaled[:, :, ::-1].transpose(2, 0, 1))
    tensor = torch.from_numpy(chw).to(device)
    # Add the batch dimension, then match the model's half precision.
    return torch.unsqueeze(tensor, 0).half(), meta

def warp_affine(pt, M):
    '''
    Apply an affine matrix to a single 2-D point.
    Input:
        pt = sequence whose first two entries are (x, y)
        M = affine matrix applied to the homogeneous point (x, y, 1)
    Return:
        the transformed point, i.e. np.dot(M, [x, y, 1])
    '''
    homogeneous = np.array((pt[0], pt[1], 1.), dtype=np.float32)
    return np.dot(M, homogeneous)

def affine_transform(dets, meta):
    '''
    Map network-input-sized predictions back to original image coordinates.
    Input:
        dets = [1(batch), num_objs, 6(x1, y1, x2, y2, conf, cls)]
        meta = letterbox metadata produced by `resize`:
               {'nw': resize_w, 'nh': resize_h,
                'dw': offset_w, 'dh': offset_h,
                'w': original_size_w, 'h': original_size_h}
               An empty dict means no resize happened.
    Return:
        numpy array with one row per detection; the box columns are modified
        only when `meta` is non-empty.
    '''
    dets = np.array([row.cpu().numpy() for row in dets[0]])
    if len(meta) == 0:
        # Image was already at network size: coordinates are final.
        return dets

    # Affine map taking the scaled-image corners onto the original-image corners.
    src = np.float32([[0, 0], [0, meta['nh']], [meta['nw'], 0]])
    dst = np.float32([[0, 0], [0, meta['h']], [meta['w'], 0]])
    M = cv2.getAffineTransform(src, dst)

    for det in dets:  # each `det` is a view, so edits land in `dets`
        # Undo the letterbox padding first ...
        det[0] -= meta['dw']
        det[1] -= meta['dh']
        det[2] -= meta['dw']
        det[3] -= meta['dh']
        # ... then rescale both corners to the original resolution.
        det[0:2] = warp_affine(det[0:2], M)
        det[2:4] = warp_affine(det[2:4], M)
    return dets

def transfer_det(dets, num_classes):
    '''
    Split detections into per-class lists.
    Input:
        dets = array [num_objs, 6(x1, y1, x2, y2, conf, cls)]
        num_classes = number of class ids (0 .. num_classes-1) to report
    Return:
        dict {class_id: [[x1, y1, x2, y2, conf], ...]} with float32-rounded
        values; classes without detections map to an empty list.
    '''
    cls_column = dets[:, -1]
    return {
        cls_id: dets[cls_column == cls_id, :5].astype(np.float32).tolist()
        for cls_id in range(num_classes)
    }

def get_iou(bb1, bb2):
    '''
    Intersection-over-union of two axis-aligned boxes.
    Input:
        bb1(groundtruth) = [left_top_x, y, right_bottom_x, y]
        bb2(predict_point) = [left_top_x, y, right_bottom_x, y]
    Return:
        iou = float in [0, 1]; 0.0 when the boxes do not overlap or when the
              union has no area (e.g. two coincident zero-area boxes).
    '''
    x_left = max(bb1[0], bb2[0])
    y_top = max(bb1[1], bb2[1])
    x_right = min(bb1[2], bb2[2])
    y_bottom = min(bb1[3], bb2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    bb1_area = (bb1[2] - bb1[0]) * (bb1[3] - bb1[1])
    bb2_area = (bb2[2] - bb2[0]) * (bb2[3] - bb2[1])
    union_area = bb1_area + bb2_area - intersection_area

    # Degenerate boxes (zero width/height on both sides) give a zero union;
    # report "no overlap" instead of dividing by zero.
    if union_area <= 0:
        return 0.0

    iou = intersection_area / float(union_area)
    assert iou >= 0.0
    assert iou <= 1.0
    return iou

In [6]:
def check_bbox(bb1, bb2):
    '''
    Tell whether the smaller of two boxes lies completely inside the larger one.
    The arguments are nominally (large, small), but they are swapped
    automatically when bb1 has the smaller area.
    Input:
        bb1(large region) = [x_min, y_min, x_max, y_max] (all int)
        bb2(small region) = [x_min, y_min, x_max, y_max] (all int)
    Return:
        True if the smaller box is fully covered (shared edges count as
        covered), otherwise False.
    '''
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    outer, inner = bb1, bb2
    if _area(bb1) < _area(bb2):
        outer, inner = bb2, bb1

    return bool(
        outer[0] <= inner[0]
        and outer[1] <= inner[1]
        and outer[2] >= inner[2]
        and outer[3] >= inner[3]
    )

def group_plate(outp, plate_cls=34):
    '''
    Group detected characters under the plate box that contains them.
    Input:
        outp = float[[x1, y1, x2, y2, score, class], ...]
            e.g. [[ 40.859  162.66  142.42  216.09  0.9668   34]
                  [ 125.08  178.28  138.83  208.91  0.93213   6]
                  [  45.039 173.44   64.414 205     0.89111  13]]
        plate_cls = class id that marks a plate box (default 34, the value the
            notebook has always used); every other class is a character.
    Return:
        a list of dict, one per plate box in input order, with the characters
        of each plate sorted left to right by x1:
        [{'plate': int[x1, y1, x2, y2], 'char': [int[x1, y1, x2, y2, idx], ...]}, ...]
            e.g. [{'plate': [40, 162, 142, 216],
                   'char':  [[45, 173, 64, 205, 13],
                             [125, 178, 138, 208, 6]]}]
    Notes:
        - Coordinates are truncated to int before the containment test.
        - A character that lies inside several plate boxes is added to each.
        - Characters inside no plate box are dropped.
    '''
    groups = []
    chars = []
    for obj in outp:
        if int(obj[-1]) == plate_cls:
            groups.append({'plate': [int(p) for p in obj[:4]], 'char': []})
        else:
            chars.append(obj)

    for obj in chars:
        # Column 4 is the score; only the box and the class index are kept.
        cha = [int(obj[0]), int(obj[1]), int(obj[2]), int(obj[3]), int(obj[5])]
        for g in groups:
            if check_bbox(g['plate'], cha[:4]):
                g['char'].append(cha)

    # Reading order: left-most character first (stable for equal x1).
    for g in groups:
        g['char'].sort(key=lambda x: x[0])
    return groups

def get_str(chars):
    '''
    Build the plate string from individually detected characters.
    Input:
        chars = a sorted list [int[x1, y1, x2, y2, idx], ...]; the last entry
                of each item indexes the global `classes` list
    Return:
        plate_str = plate string, e.g. 'DH4886' ('' for an empty list)
    '''
    return ''.join(classes[entry[-1]] for entry in chars)

def get_acc(inp, outp):
    """
    Count how many ground-truth characters also occur in the detection.
    The comparison is order-independent: it is the size of the multiset
    intersection of the two strings.
    Input:
        inp: ground truth(str)
        outp: detected result(str)
    Return:
        m: number of groundtruth characters(int)
        count: number of characters shared by both strings(int)
    """
    shared = Counter(inp) & Counter(outp)
    return len(inp), sum(shared.values())

def compute_acc(outp, labels):
    '''
    Accumulate character matches between detected plates and labels.
    Input:
        outp = a list of dict [{'plate': int[x1, y1, x2, y2], 'char':[int[x1, y1, x2, y2, idx], ...]}, ...]
        labels = a list of str, e.g. ['14,71,83,105,FS799', '215,188,324,240,DP4846']
    Return:
        total = number of label characters, counted once for every
                (label, detected plate) pair whose boxes have IoU >= 0.5
        correct = number of those characters found in the detected string
    '''
    total, correct = 0, 0
    for label in labels:
        fields = label.split(',')
        plate_gt = [int(v) for v in fields[:4]]
        gt_text = fields[-1]
        for group in outp:
            if not (get_iou(plate_gt, group['plate']) >= 0.5):
                continue
            n_chars, n_hit = get_acc(gt_text, get_str(group['char']))
            total += n_chars
            correct += n_hit
    return total, correct

def labels_len(labels):
    '''
    Count the plate characters contained in the ground truth.
    Input:
        labels = a list of str, e.g. ['14,71,83,105,FS799', '215,188,324,240,DP4846']
    Return:
        num = total length of the last comma-separated field over all labels
    '''
    return sum(len(label.split(',')[-1]) for label in labels)

def get_wer(r, h):
    """
    Compute the error rate between two sequences via edit distance.
    Input:
        r = ground truth (sequence of tokens, e.g. list of characters)
        h = predicted results (sequence of tokens)
    Return:
        result = WER (presented in percentage); 0.0 when both sequences are
                 empty, inf when only the ground truth is empty
        sid = substitution + insertion + deletion (plain Python int)
        total = the number of groundtruth tokens
    """
    n_ref, n_hyp = len(r), len(h)

    # int64 table: a uint16 table would silently wrap for large distances.
    d = np.zeros((n_ref + 1, n_hyp + 1), dtype=np.int64)
    d[0, :] = np.arange(n_hyp + 1)  # empty reference: j insertions
    d[:, 0] = np.arange(n_ref + 1)  # empty hypothesis: i deletions

    for i in range(1, n_ref + 1):
        for j in range(1, n_hyp + 1):
            if r[i - 1] == h[j - 1]:
                d[i, j] = d[i - 1, j - 1]
            else:
                substitution = d[i - 1, j - 1] + 1
                insertion = d[i, j - 1] + 1
                deletion = d[i - 1, j] + 1
                d[i, j] = min(substitution, insertion, deletion)

    # Return a Python int so callers that sum `sid` over many samples do not
    # accumulate in a fixed-width NumPy scalar type.
    sid = int(d[n_ref, n_hyp])
    total = n_ref
    if total == 0:
        # Rate is undefined without reference tokens; avoid ZeroDivisionError.
        result = 0.0 if sid == 0 else float('inf')
    else:
        result = float(sid) / total * 100

    return result, sid, total

def transfer_label(lab):
    '''
    Extract the integer bounding boxes from label strings.
    Input:
        lab: string list ['x_min,y_min,x_max,y_max,plate']
            e.g. ['14,71,83,105,FS799', '215,188,324,240,DP4846']
    Return:
        new_lab: int list [x_min, y_min, x_max, y_max]
            e.g. [[14,71,83,105], [215,188,324,240]]
    '''
    new_lab = []
    for entry in lab:
        fields = entry.split(',')
        # Explicit indexing: a label with fewer than four fields is an error.
        new_lab.append([int(fields[i]) for i in range(4)])
    return new_lab

# Test AOLP

In [7]:
"""
Create label dictionary.
Format: dic = {key: file_name(str), value: [obj1(str), obj2(str), ...]}
        obj format = 'x_min, y_min, x_max, y_max, plate'
   i.e. dic['train_LE_3'] = ['266,199,350,242,2972KK']
        dic['train_LE_33'] = ['14,71,83,105,FS799', '215,188,324,240,DP4846']
"""
label_data = {}
label_txt = 'E:/MTL_FTP/ChengJungC/dataset/AOLP/label.txt'
# Read inside a context manager so the file handle is always closed
# (the handle was previously left open for the life of the kernel).
with open(label_txt, 'r') as label_file:
    lines = label_file.readlines()
# Each line is "<file_name> <obj1> <obj2> ..." separated by single spaces;
# the first token is the key and the remaining tokens are the plate labels.
for line in lines:
    l = line.strip().split(' ')
    name = l[0]
    plates = l[1:]
    label_data[name] = plates

In [8]:
# Run the detector on every .jpg in the AOLP folder and keep, per image base
# name, the detections mapped back to original image coordinates.
img_dir = 'E:/MTL_FTP/ChengJungC/dataset/AOLP/original/'
img_paths = sorted(os.listdir(img_dir))

results = {}
for f in tqdm(img_paths, ncols=80):
    if not f.endswith('.jpg'):
        continue

    bname = os.path.splitext(f)[0]
    img_p = img_dir + f
    # NOTE(review): cv2.imread returns None for unreadable files; such a file
    # would fail inside preproccess_img.
    img = cv2.imread(img_p)
    img, meta = preproccess_img(img, opt.device)

    with torch.no_grad():
        ret = model(img)
    # Only the first element of the model output is passed to NMS.
    output = non_max_suppression(
        ret[0],
        conf_thres=opt.conf_thres,
        iou_thres=opt.iou_thres,
        labels=[],
        multi_label=True,
    )

    # Images without any surviving detection get no entry in `results`.
    if len(output[0]) > 0:
        output = affine_transform(output, meta)
        results[bname] = output

100%|███████████████████████████████████████| 2049/2049 [03:06<00:00, 11.02it/s]


### Detect plate only

In [9]:
# Plate-level evaluation: match predicted plate boxes to ground-truth boxes.
# A prediction is correct when its IoU with a ground-truth plate is >= 0.5, and
# every ground-truth plate can be claimed by at most one prediction. Without
# that one-to-one constraint, several overlapping predictions of the same plate
# were all counted as correct, which inflates recall (it could exceed 1.0).
# Predictions are visited in the order returned by group_plate
# (presumably highest confidence first -- depends on non_max_suppression).
correct_plate = {}
total_p, correct_p, pred_p = 0, 0, 0
for k, v in label_data.items():
    gt = transfer_label(v)
    total_p += len(gt)
    if k in results:
        r = group_plate(results[k]) # r = [{'plate': int[x1, y1, x2, y2], 'char':[int[x1, y1, x2, y2, label], ...]}, ...]
        pred_p += len(r)
        correct_r = []
        matched_gt = set()  # indices of ground-truth plates already claimed
        for _r in r:
            # Pick the still-unclaimed ground-truth plate with the best IoU.
            best_idx, best_iou = None, 0.0
            for idx, bbox in enumerate(gt):
                if idx in matched_gt:
                    continue
                iou = get_iou(bbox, _r['plate'])
                if iou >= 0.5 and (best_idx is None or iou > best_iou):
                    best_idx, best_iou = idx, iou
            if best_idx is not None:
                matched_gt.add(best_idx)
                n_r = {'plate': _r['plate'], 'char': _r['char'], 'idx': best_idx}
                correct_r.append(n_r)
        num = len(correct_r)
        if num > 0:
            correct_plate[k] = correct_r
            correct_p += num

print("Number of Correctly Detected Plates =", correct_p)
print("Number of Detected Plates =", pred_p)
print("Number of All Plates =", total_p)
# The ratios are undefined (nan) when there are no labels / no detections.
print("Recall = {:.4f}".format(correct_p/total_p if total_p else float('nan')))
print("Precision = {:.4f}".format(correct_p/pred_p if pred_p else float('nan')))

Number of Correctly Detected Plates = 2143
Number of Detected Plates = 3547
Number of All Plates = 2164
Recall = 0.9903
Precision = 0.6042


### Detect character only

In [10]:
# Character-level evaluation, restricted to the plates that were matched to a
# ground-truth plate in the plate-evaluation cell (`correct_plate`).
n_perfect = 0  ### number of perfectly recognized plates
n_sid = 0  ### number of failed recognized chars in detected
n_detected = 0  ### number of chars in detected plates

for k, v in label_data.items():
    gt_strs = [s.split(',')[-1] for s in v]
    if k in correct_plate:
        objs = correct_plate[k]
        for obj in objs: # obj = [{'plate': int[x1, y1, x2, y2], 'char':[int[x1, y1, x2, y2, label], ...], 'idx': int(i)}]
            pred_str = [classes[x[-1]] for x in obj['char']]
            wer, sid, t = get_wer(list(gt_strs[obj['idx']]), pred_str)
            # get_wer may hand back a fixed-width NumPy integer; convert so the
            # running sum is a Python int and cannot wrap around.
            n_sid += int(sid)
            n_detected += t
            if wer == 0:
                n_perfect += 1


print("Characters in Detected Plates = ", n_detected)
print("Error Characters (Detected) =", n_sid)
# Ratios are reported as nan when nothing was detected (avoids ZeroDivisionError).
print("Word Error Rate (Detected) = {:.4f}".format(n_sid/n_detected if n_detected else float('nan')))

print("\nNumber of Perfectly Recognized Plates = ", n_perfect)
print("Accuracy(Detected) = {:.4f}".format(n_perfect/correct_p if correct_p else float('nan')))

Characters in Detected Plates =  12789
Error Characters (Detected) = 1572
World Error Rate (Detected) = 0.1229

Number of Perfectly Recognized Plates =  1175
Accuracy(Detected) = 0.5483


# Test weather

In [11]:
"""
Create label dictionary.
Format: dic = {key: file_name(str), value: [obj1(str), obj2(str), ...]}
        obj format = 'x_min, y_min, x_max, y_max, plate'
   i.e. dic['train_LE_3'] = ['266,199,350,242,2972KK']
        dic['train_LE_33'] = ['14,71,83,105,FS799', '215,188,324,240,DP4846']
"""
label_data = {}
label_txt = 'E:/MTL_FTP/ChengJungC/dataset/weather/label.txt'
# Read inside a context manager so the file handle is always closed
# (the handle was previously left open for the life of the kernel).
with open(label_txt, 'r') as label_file:
    lines = label_file.readlines()
# Each line is "<file_name> <obj1> <obj2> ..." separated by single spaces;
# the first token is the key and the remaining tokens are the plate labels.
for line in lines:
    l = line.strip().split(' ')
    name = l[0]
    plates = l[1:]
    label_data[name] = plates

In [12]:
# Run the detector on every .jpg in the weather folder and keep, per image
# base name, the detections mapped back to original image coordinates.
img_dir = 'E:/MTL_FTP/ChengJungC/dataset/weather/original/'
img_paths = sorted(os.listdir(img_dir))

results = {}
for f in tqdm(img_paths, ncols=80):
    if not f.endswith('.jpg'):
        continue

    bname = os.path.splitext(f)[0]
    img_p = img_dir + f
    # NOTE(review): cv2.imread returns None for unreadable files; such a file
    # would fail inside preproccess_img.
    img = cv2.imread(img_p)
    img, meta = preproccess_img(img, opt.device)

    with torch.no_grad():
        ret = model(img)
    # Only the first element of the model output is passed to NMS.
    output = non_max_suppression(
        ret[0],
        conf_thres=opt.conf_thres,
        iou_thres=opt.iou_thres,
        labels=[],
        multi_label=True,
    )

    # Images without any surviving detection get no entry in `results`.
    if len(output[0]) > 0:
        output = affine_transform(output, meta)
        results[bname] = output

100%|███████████████████████████████████████████| 47/47 [00:04<00:00, 10.15it/s]


### Detect plate only

In [13]:
# Plate-level evaluation: match predicted plate boxes to ground-truth boxes.
# A prediction is correct when its IoU with a ground-truth plate is >= 0.5, and
# every ground-truth plate can be claimed by at most one prediction. Without
# that one-to-one constraint, several overlapping predictions of the same plate
# were all counted as correct, which inflates recall (it could exceed 1.0).
# Predictions are visited in the order returned by group_plate
# (presumably highest confidence first -- depends on non_max_suppression).
correct_plate = {}
total_p, correct_p, pred_p = 0, 0, 0
for k, v in label_data.items():
    gt = transfer_label(v)
    total_p += len(gt)
    if k in results:
        r = group_plate(results[k]) # r = [{'plate': int[x1, y1, x2, y2], 'char':[int[x1, y1, x2, y2, label], ...]}, ...]
        pred_p += len(r)
        correct_r = []
        matched_gt = set()  # indices of ground-truth plates already claimed
        for _r in r:
            # Pick the still-unclaimed ground-truth plate with the best IoU.
            best_idx, best_iou = None, 0.0
            for idx, bbox in enumerate(gt):
                if idx in matched_gt:
                    continue
                iou = get_iou(bbox, _r['plate'])
                if iou >= 0.5 and (best_idx is None or iou > best_iou):
                    best_idx, best_iou = idx, iou
            if best_idx is not None:
                matched_gt.add(best_idx)
                n_r = {'plate': _r['plate'], 'char': _r['char'], 'idx': best_idx}
                correct_r.append(n_r)
        num = len(correct_r)
        if num > 0:
            correct_plate[k] = correct_r
            correct_p += num

print("Number of Correctly Detected Plates =", correct_p)
print("Number of Detected Plates =", pred_p)
print("Number of All Plates =", total_p)
# The ratios are undefined (nan) when there are no labels / no detections.
print("Recall = {:.4f}".format(correct_p/total_p if total_p else float('nan')))
print("Precision = {:.4f}".format(correct_p/pred_p if pred_p else float('nan')))

Number of Correctly Detected Plates = 35
Number of Detected Plates = 47
Number of All Plates = 54
Recall = 0.6481
Precision = 0.7447


### Detect character only

In [14]:
# Character-level evaluation, restricted to the plates that were matched to a
# ground-truth plate in the plate-evaluation cell (`correct_plate`).
n_perfect = 0  ### number of perfectly recognized plates
n_sid = 0  ### number of failed recognized chars in detected
n_detected = 0  ### number of chars in detected plates

for k, v in label_data.items():
    gt_strs = [s.split(',')[-1] for s in v]
    if k in correct_plate:
        objs = correct_plate[k]
        for obj in objs: # obj = [{'plate': int[x1, y1, x2, y2], 'char':[int[x1, y1, x2, y2, label], ...], 'idx': int(i)}]
            pred_str = [classes[x[-1]] for x in obj['char']]
            wer, sid, t = get_wer(list(gt_strs[obj['idx']]), pred_str)
            # get_wer may hand back a fixed-width NumPy integer; convert so the
            # running sum is a Python int and cannot wrap around.
            n_sid += int(sid)
            n_detected += t
            if wer == 0:
                n_perfect += 1


print("Characters in Detected Plates = ", n_detected)
print("Error Characters (Detected) =", n_sid)
# Ratios are reported as nan when nothing was detected (avoids ZeroDivisionError).
print("Word Error Rate (Detected) = {:.4f}".format(n_sid/n_detected if n_detected else float('nan')))

print("\nNumber of Perfectly Recognized Plates = ", n_perfect)
print("Accuracy(Detected) = {:.4f}".format(n_perfect/correct_p if correct_p else float('nan')))

Characters in Detected Plates =  220
Error Characters (Detected) = 106
World Error Rate (Detected) = 0.4818

Number of Perfectly Recognized Plates =  4
Accuracy(Detected) = 0.1143
