In [1]:
from shapely.geometry import Polygon
import numpy as np
import cv2
from PIL import Image
import math
import os
import torch
import torchvision.transforms as transforms
from torch.utils import data


def cal_distance(x1, y1, x2, y2):
	'''return the Euclidean distance between points (x1, y1) and (x2, y2)'''
	dx = x1 - x2
	dy = y1 - y2
	return math.sqrt(dx**2 + dy**2)


def move_points(vertices, index1, index2, r, coef):
	'''pull the two end points of one edge towards each other
	The array is modified in place and the same array is returned.
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>, laid out x1,y1,...,x4,y4
		index1  : vertex number of the first end point (taken modulo 4)
		index2  : vertex number of the second end point (taken modulo 4)
		r       : [r1, r2, r3, r4] in paper, one reference length per vertex
		coef    : shrink ratio in paper
	Output:
		vertices: vertices where one edge has been shrunk
	'''
	first = index1 % 4
	second = index2 % 4
	ax, ay = first * 2 + 0, first * 2 + 1
	bx, by = second * 2 + 0, second * 2 + 1

	ref_a = r[first]
	ref_b = r[second]
	# edge vector pointing from the second end point to the first one
	delta_x = vertices[ax] - vertices[bx]
	delta_y = vertices[ay] - vertices[by]
	edge_len = cal_distance(vertices[ax], vertices[ay], vertices[bx], vertices[by])
	# edges that are not longer than one unit are left untouched
	if not (edge_len > 1):
		return vertices

	shift_a = (ref_a * coef) / edge_len
	vertices[ax] += shift_a * (-delta_x)
	vertices[ay] += shift_a * (-delta_y)
	shift_b = (ref_b * coef) / edge_len
	vertices[bx] += shift_b * delta_x
	vertices[by] += shift_b * delta_y
	return vertices


def shrink_poly(vertices, coef=0.3):
	'''shrink the text region
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
		coef    : shrink ratio in paper
	Output:
		v       : vertices of shrunk text region <numpy.ndarray of float64, (8,)>
	'''
	x1, y1, x2, y2, x3, y3, x4, y4 = vertices
	# lengths of the four edges of the quadrilateral
	d12 = cal_distance(x1,y1,x2,y2)
	d23 = cal_distance(x2,y2,x3,y3)
	d34 = cal_distance(x3,y3,x4,y4)
	d41 = cal_distance(x4,y4,x1,y1)
	# reference length of a vertex = the shorter of its two adjacent edges
	r = [min(d12, d41), min(d12, d23), min(d23, d34), min(d41, d34)]

	# obtain offset to perform move_points() automatically
	if d12 + d34 > d23 + d41:
		offset = 0 # two longer edges are (x1y1-x2y2) & (x3y3-x4y4)
	else:
		offset = 1 # two longer edges are (x2y2-x3y3) & (x4y4-x1y1)

	# Work on a float copy: move_points() updates the array in place with
	# fractional offsets, which an integer array would silently truncate.
	v = np.array(vertices, dtype=np.float64)
	# the pair of longer edges is shrunk first, then the pair of shorter ones
	for start in (0, 2, 1, 3):
		v = move_points(v, start + offset, start + 1 + offset, r, coef)
	return v


def get_rotate_mat(theta):
	'''build the 2x2 rotation matrix for theta (radians);
	positive theta value means rotate clockwise'''
	cos_t = math.cos(theta)
	sin_t = math.sin(theta)
	return np.array([[cos_t, -sin_t], [sin_t, cos_t]])


def rotate_vertices(vertices, theta, anchor=None):
	'''rotate the four vertices around an anchor point
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
		theta   : angle in radian measure
		anchor  : fixed position during rotation, a (2,1) column;
		          the first vertex is used when it is None
	Output:
		rotated vertices <numpy.ndarray, (8,)>
	'''
	points = vertices.reshape((4,2)).T  # 2 x 4, one column per vertex
	pivot = points[:,:1] if anchor is None else anchor
	centered = points - pivot
	rotated = np.dot(get_rotate_mat(theta), centered)
	return (rotated + pivot).T.reshape(-1)


def get_boundary(vertices):
	'''get the axis-aligned bounding box of the four vertices
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
	Output:
		x_min, x_max, y_min, y_max
	'''
	x1, y1, x2, y2, x3, y3, x4, y4 = vertices
	xs = (x1, x2, x3, x4)
	ys = (y1, y2, y3, y4)
	return min(xs), max(xs), min(ys), max(ys)


def cal_error(vertices):
	'''measure how far the vertex order is from the default orientation
	default orientation is x1y1 : left-top, x2y2 : right-top, x3y3 : right-bot, x4y4 : left-bot
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
	Output:
		err     : sum of the distances between each vertex and its bounding-box corner
	'''
	x_min, x_max, y_min, y_max = get_boundary(vertices)
	x1, y1, x2, y2, x3, y3, x4, y4 = vertices
	left_top  = cal_distance(x1, y1, x_min, y_min)
	right_top = cal_distance(x2, y2, x_max, y_min)
	right_bot = cal_distance(x3, y3, x_max, y_max)
	left_bot  = cal_distance(x4, y4, x_min, y_max)
	return left_top + right_top + right_bot + left_bot


def find_min_rect_angle(vertices):
	'''search the rotation angle that gives the smallest bounding rectangle
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
	Output:
		the best angle <radian measure>
	'''
	angle_interval = 1
	candidates = list(range(-90, 90, angle_interval))

	def to_radian(degree):
		return degree / 180 * math.pi

	def bbox_area(degree):
		# area of the axis-aligned box around the polygon rotated by `degree`
		rx1, ry1, rx2, ry2, rx3, ry3, rx4, ry4 = rotate_vertices(vertices, to_radian(degree))
		return (max(rx1, rx2, rx3, rx4) - min(rx1, rx2, rx3, rx4)) * \
			(max(ry1, ry2, ry3, ry4) - min(ry1, ry2, ry3, ry4))

	areas = [bbox_area(degree) for degree in candidates]
	ranking = sorted(range(len(areas)), key=areas.__getitem__)

	# among the smallest-area candidates keep the one with the correct orientation
	rank_num = 10
	best_index = -1
	min_error = float('inf')
	for candidate in ranking[:rank_num]:
		error = cal_error(rotate_vertices(vertices, to_radian(candidates[candidate])))
		if error < min_error:
			min_error, best_index = error, candidate
	return to_radian(candidates[best_index])


def is_cross_text(start_loc, length, vertices):
	'''check if the crop image crosses text regions
	Input:
		start_loc: left-top position (x, y) of the square crop
		length   : side length of the crop
		vertices : vertices of text regions <numpy.ndarray, (n,8)>
	Output:
		True if the crop cuts through a text region, i.e. the crop covers
		between 1% and 99% of the area of at least one region
	'''
	if vertices.size == 0:
		return False
	start_w, start_h = start_loc
	end_w, end_h = start_w + length, start_h + length
	crop_corners = np.array([start_w, start_h, end_w, start_h,
		end_w, end_h, start_w, end_h]).reshape((4,2))
	crop_poly = Polygon(crop_corners).convex_hull
	for vertice in vertices:
		text_poly = Polygon(vertice.reshape((4,2))).convex_hull
		text_area = text_poly.area
		if text_area == 0:
			# degenerate annotation: the overlap ratio is undefined (it used to
			# raise ZeroDivisionError) and a zero-area region cannot be cut
			continue
		overlap = crop_poly.intersection(text_poly).area
		if 0.01 <= overlap / text_area <= 0.99:
			return True
	return False
		

def crop_img(img, vertices, labels, length):
	'''take a random square crop of the image and move the vertices with it
	Input:
		img         : PIL Image
		vertices    : vertices of text regions <numpy.ndarray, (n,8)>
		labels      : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
		length      : side length of the cropped region
	Output:
		region      : cropped image region
		new_vertices: vertices expressed in the cropped region
	'''
	x_cols = [0,2,4,6]
	y_cols = [1,3,5,7]
	orig_h, orig_w = img.height, img.width
	# enlarge the image when its shorter side is smaller than the crop
	if orig_h >= orig_w and orig_w < length:
		img = img.resize((length, int(orig_h * length / orig_w)), Image.Resampling.BILINEAR)
	elif orig_h < orig_w and orig_h < length:
		img = img.resize((int(orig_w * length / orig_h), length), Image.Resampling.BILINEAR)
	ratio_w = img.width / orig_w
	ratio_h = img.height / orig_h
	assert(ratio_w >= 1 and ratio_h >= 1)

	new_vertices = np.zeros(vertices.shape)
	if vertices.size > 0:
		new_vertices[:,x_cols] = vertices[:,x_cols] * ratio_w
		new_vertices[:,y_cols] = vertices[:,y_cols] * ratio_h

	# draw random positions until the crop does not cut a valid text region
	# (the last draw is used anyway after 1000 attempts)
	remain_h = img.height - length
	remain_w = img.width - length
	for _ in range(1000):
		start_w = int(np.random.rand() * remain_w)
		start_h = int(np.random.rand() * remain_h)
		if not is_cross_text([start_w, start_h], length, new_vertices[labels==1,:]):
			break
	region = img.crop((start_w, start_h, start_w + length, start_h + length))
	if new_vertices.size == 0:
		return region, new_vertices

	new_vertices[:,x_cols] -= start_w
	new_vertices[:,y_cols] -= start_h
	return region, new_vertices


def rotate_all_pixels(rotate_mat, anchor_x, anchor_y, length):
	'''rotate the coordinates of every pixel of a length x length image around an anchor
	Input:
		rotate_mat: rotation matrix <2x2>
		anchor_x  : fixed x position
		anchor_y  : fixed y position
		length    : length of image
	Output:
		rotated_x : rotated x positions <numpy.ndarray, (length,length)>
		rotated_y : rotated y positions <numpy.ndarray, (length,length)>
	'''
	axis = np.arange(length)
	grid_x, grid_y = np.meshgrid(axis, np.arange(length))
	# one column per pixel: first row holds x, second row holds y
	coords = np.vstack((grid_x.reshape((1, grid_x.size)), grid_y.reshape((1, grid_x.size))))
	pivot = np.array([[anchor_x], [anchor_y]])
	rotated = np.dot(rotate_mat, coords - pivot) + pivot
	rotated_x = rotated[0, :].reshape(grid_x.shape)
	rotated_y = rotated[1, :].reshape(grid_y.shape)
	return rotated_x, rotated_y


def adjust_height(img, vertices, ratio=0.2):
	'''randomly stretch or squeeze the image height to augment data
	Input:
		img         : PIL Image
		vertices    : vertices of text regions <numpy.ndarray, (n,8)>
		ratio       : the height factor is drawn from [1 - ratio, 1 + ratio]
	Output:
		img         : adjusted PIL Image
		new_vertices: adjusted vertices
	'''
	y_cols = [1,3,5,7]
	factor = 1 + ratio * (np.random.rand() * 2 - 1)
	old_h = img.height
	new_h = int(np.around(old_h * factor))
	img = img.resize((img.width, new_h), Image.Resampling.BILINEAR)

	new_vertices = vertices.copy()
	if vertices.size > 0:
		# only the y coordinates change, by the height ratio actually applied
		stretch = new_h / old_h
		new_vertices[:,y_cols] = vertices[:,y_cols] * stretch
	return img, new_vertices


def rotate_img(img, vertices, angle_range=10):
	'''rotate the image by a random angle in [-angle_range, angle_range] degree to augment data
	Input:
		img         : PIL Image
		vertices    : vertices of text regions <numpy.ndarray, (n,8)>
		angle_range : rotate range in degree
	Output:
		img         : rotated PIL Image
		new_vertices: rotated vertices
	'''
	pivot = np.array([[(img.width - 1) / 2], [(img.height - 1) / 2]])
	angle = angle_range * (np.random.rand() * 2 - 1)
	img = img.rotate(angle, Image.Resampling.BILINEAR)
	# the vertices are rotated by the negated angle, around the image centre
	radian = -angle / 180 * math.pi
	new_vertices = np.zeros(vertices.shape)
	for row, quad in enumerate(vertices):
		new_vertices[row,:] = rotate_vertices(quad, radian, pivot)
	return img, new_vertices


def get_score_geo(img, vertices, labels, scale, length):
	'''build the score, geometry and ignore ground-truth maps of one image
	Input:
		img     : PIL Image
		vertices: vertices of text regions <numpy.ndarray, (n,8)>
		labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
		scale   : feature map / image
		length  : image length
	Output:
		score gt, geo gt, ignored map as channel-first torch tensors
		(1, 5 and 1 channels respectively)
	'''
	map_h = int(img.height * scale)
	map_w = int(img.width * scale)
	score_map   = np.zeros((map_h, map_w, 1), np.float32)
	geo_map     = np.zeros((map_h, map_w, 5), np.float32)
	ignored_map = np.zeros((map_h, map_w, 1), np.float32)

	# image positions that are sampled for the cells of the feature map
	sample = np.arange(0, length, int(1/scale))
	sample_x, sample_y = np.meshgrid(sample, sample)
	ignored_polys = []
	polys = []

	for i, vertice in enumerate(vertices):
		if labels[i] == 0:
			ignored_polys.append(np.around(scale * vertice.reshape((4,2))).astype(np.int32))
			continue

		# scaled & shrunk polygon marks the cells that belong to this region
		poly = np.around(scale * shrink_poly(vertice).reshape((4,2))).astype(np.int32)
		polys.append(poly)
		temp_mask = np.zeros(score_map.shape[:-1], np.float32)
		cv2.fillPoly(temp_mask, [poly], 1)

		theta = find_min_rect_angle(vertice)
		rotate_mat = get_rotate_mat(theta)

		rotated_vertices = rotate_vertices(vertice, theta)
		x_min, x_max, y_min, y_max = get_boundary(rotated_vertices)
		rotated_x, rotated_y = rotate_all_pixels(rotate_mat, vertice[0], vertice[1], length)

		# distances to the four sides of the rotated box, in geo channel order 0..3
		distances = (rotated_y - y_min, y_max - rotated_y, rotated_x - x_min, x_max - rotated_x)
		for channel, dist in enumerate(distances):
			dist[dist<0] = 0
			geo_map[:,:,channel] += dist[sample_y, sample_x] * temp_mask
		geo_map[:,:,4] += theta * temp_mask

	cv2.fillPoly(ignored_map, ignored_polys, 1)
	cv2.fillPoly(score_map, polys, 1)
	score_gt = torch.Tensor(score_map).permute(2,0,1)
	geo_gt = torch.Tensor(geo_map).permute(2,0,1)
	ignored_gt = torch.Tensor(ignored_map).permute(2,0,1)
	return score_gt, geo_gt, ignored_gt


def extract_vertices(lines):
	'''extract vertices info from txt lines
	Each line is expected to be "x1,y1,x2,y2,x3,y3,x4,y4,<text>"; a line that
	contains '###' is labelled as ignored. Blank lines are skipped.
	Input:
		lines   : list of string info
	Output:
		vertices: vertices of text regions <numpy.ndarray, (n,8)>
		labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
	'''
	labels = []
	vertices = []
	for line in lines:
		# drop a leading byte-order mark and surrounding whitespace / line ends
		cleaned = line.lstrip('\ufeff').strip()
		if not cleaned:
			# e.g. a trailing empty line; int('') would raise ValueError
			continue
		vertices.append([int(field) for field in cleaned.split(',')[:8]])
		labels.append(0 if '###' in line else 1)
	if not vertices:
		# keep the documented (n,8) layout even when there is no annotation
		return np.zeros((0, 8)), np.array(labels)
	return np.array(vertices), np.array(labels)

def resize(image, anns,size):
    '''Fit the image into a size x size canvas, keeping the aspect ratio.

    The image is scaled so that its longer side equals `size`, placed in the
    top-left corner, and the rest of the canvas is filled with the per-channel
    mean colour. Annotations are scaled by the same factor and truncated to int.
    Input:
        image: PIL Image (or array) with three channels
        anns : iterable of annotation arrays
        size : side length of the output canvas
    Output:
        padded PIL Image, scaled annotations <numpy.ndarray>
    '''
    pixels = np.asarray(image)
    src_h, src_w, channels = pixels.shape
    factor = min(size / src_w, size / src_h)
    dst_h = int(src_h * factor)
    dst_w = int(src_w * factor)
    channel_mean = np.mean(pixels, axis=(0,1))

    canvas = np.zeros((size, size, channels), pixels.dtype)
    for ch in range(3):
        canvas[:,:,ch] = np.full((size, size), channel_mean[ch])
    canvas[:dst_h, :dst_w] = cv2.resize(pixels, (dst_w, dst_h))

    # random_pad=np.random.randint(-10,10, size=(4, 2))
    # poly =np.clip(poly*scale+random_pad, 1, size-1)
    scaled_anns = [[int(coord) for coord in ann * factor] for ann in anns]
    return Image.fromarray(canvas), np.array(scaled_anns)
class custom_dataset(data.Dataset):
	'''Dataset pairing every image with the annotation file at the same sorted position.

	In training mode an item is (image, score gt, geo gt, ignored map, gt path);
	otherwise it is (image, score gt, geo gt, ignore tags, gt path, polygons).
	'''
	def __init__(self, img_path, gt_path, scale=0.25, length=512,is_training=True):
		super().__init__()
		# both directory listings are sorted, so index i of one matches index i of the other
		self.img_files = [os.path.join(img_path, name) for name in sorted(os.listdir(img_path))]
		self.gt_files  = [os.path.join(gt_path, name) for name in sorted(os.listdir(gt_path))]
		self.scale = scale
		self.length = length
		self.is_training = is_training

	def __len__(self):
		return len(self.img_files)

	def __getitem__(self, index):
		gt_file = self.gt_files[index]
		with open(gt_file, 'r') as f:
			lines = f.readlines()
		vertices, labels = extract_vertices(lines)

		img = Image.open(self.img_files[index]).convert("RGB")
		# img, vertices = adjust_height(img, vertices)
		# img, vertices = rotate_img(img, vertices)

		img, vertices = resize(img, vertices, self.length)
		# img, vertices = crop_img(img, vertices, labels, self.length)
		normalize = transforms.Compose([
			transforms.ToTensor(),
			transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
		])
		# True marks an annotation that must be ignored (label other than 1)
		ignore_tags = [False if label == 1 else True for label in labels]

		score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
		if self.is_training:
			return normalize(img), score_map, geo_map, ignored_map, gt_file

		# evaluation additionally needs every polygon as four (x, y) float pairs
		polygons = []
		for quad in vertices:
			coords = [float(value) for value in quad]
			polygons.append([(coords[0], coords[1]), (coords[2], coords[3]),
				(coords[4], coords[5]), (coords[6], coords[7])])
		return normalize(img), score_map, geo_map, ignore_tags, gt_file, polygons

  from .autonotebook import tqdm as notebook_tqdm


# Loss

In [2]:
import torch
import torch.nn as nn


def get_dice_loss(gt_score, pred_score):
	'''dice loss: 1 - 2 * sum(gt * pred) / (sum(gt) + sum(pred) + 1e-5)'''
	overlap = torch.sum(gt_score * pred_score)
	total = torch.sum(gt_score) + torch.sum(pred_score) + 1e-5
	dice = 2 * overlap / total
	return 1. - dice
	 

def get_geo_loss(gt_geo, pred_geo):
	'''per-position IoU loss and angle loss between two geometry maps
	Both inputs are split along dim 1 into four distance channels and one angle channel.
	Output:
		iou_loss_map  : -log((intersection + 1) / (union + 1))
		angle_loss_map: 1 - cos(angle_pred - angle_gt)
	'''
	top_gt, bottom_gt, left_gt, right_gt, angle_gt = gt_geo.split(1, 1)
	top_pred, bottom_pred, left_pred, right_pred, angle_pred = pred_geo.split(1, 1)

	gt_area = (top_gt + bottom_gt) * (left_gt + right_gt)
	pred_area = (top_pred + bottom_pred) * (left_pred + right_pred)
	# the overlap of two boxes sharing the same inner point uses the smaller distance per side
	overlap_w = torch.min(left_gt, left_pred) + torch.min(right_gt, right_pred)
	overlap_h = torch.min(top_gt, top_pred) + torch.min(bottom_gt, bottom_pred)
	intersect_area = overlap_w * overlap_h
	union_area = gt_area + pred_area - intersect_area

	iou_loss_map = -torch.log((intersect_area + 1.0)/(union_area + 1.0))
	angle_loss_map = 1 - torch.cos(angle_pred - angle_gt)
	return iou_loss_map, angle_loss_map


class Loss(nn.Module):
	'''Total loss: dice loss on the score map plus weight_angle * angle loss
	plus IoU loss, the geometry terms being averaged over the positive positions.'''
	def __init__(self, weight_angle=10):
		super().__init__()
		self.weight_angle = weight_angle

	def forward(self, gt_score, pred_score, gt_geo, pred_geo, ignored_map):
		positive_sum = torch.sum(gt_score)
		if positive_sum < 1:
			# no text in the batch: return a zero that is still attached to the predictions
			return torch.sum(pred_score + pred_geo) * 0

		# predictions inside ignored regions are masked out of the score loss
		masked_pred = pred_score*(1-ignored_map)
		classify_loss = get_dice_loss(gt_score, masked_pred)
		iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo)

		angle_loss = torch.sum(angle_loss_map*gt_score) / positive_sum
		iou_loss = torch.sum(iou_loss_map*gt_score) / positive_sum
		geo_loss = self.weight_angle * angle_loss + iou_loss
		print('classify loss is {:.8f}, angle loss is {:.8f}, iou loss is {:.8f}'.format(classify_loss, angle_loss, iou_loss))
		return geo_loss + classify_loss

# Model

In [3]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import math
# Layer layout consumed by make_layers(): an integer adds a 3x3 convolution with
# that many output channels, 'M' adds a 2x2 max-pooling layer.
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


def make_layers(cfg, batch_norm=False):
	'''turn a layer layout into an nn.Sequential
	Input:
		cfg       : list where 'M' means 2x2 max-pooling and any other entry is the
		            number of output channels of a 3x3 convolution followed by ReLU
		batch_norm: insert BatchNorm2d between every convolution and its ReLU
	Output:
		nn.Sequential taking a 3-channel input
	'''
	modules = []
	channels_in = 3
	for spec in cfg:
		if spec == 'M':
			modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
			continue
		modules.append(nn.Conv2d(channels_in, spec, kernel_size=3, padding=1))
		if batch_norm:
			modules.append(nn.BatchNorm2d(spec))
		modules.append(nn.ReLU(inplace=True))
		channels_in = spec
	return nn.Sequential(*modules)


class VGG(nn.Module):
	'''VGG classifier: the given feature extractor, a 7x7 adaptive average
	pooling and a three-layer fully connected head with 1000 outputs.'''
	def __init__(self, features):
		super().__init__()
		self.features = features
		self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
		head = [
			nn.Linear(512 * 7 * 7, 4096),
			nn.ReLU(True),
			nn.Dropout(),
			nn.Linear(4096, 4096),
			nn.ReLU(True),
			nn.Dropout(),
			nn.Linear(4096, 1000),
		]
		self.classifier = nn.Sequential(*head)
		self._init_weights()

	def _init_weights(self):
		'''re-initialise every convolution, batch-norm and linear layer'''
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
				continue
			if isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)
				continue
			if isinstance(m, nn.Linear):
				nn.init.normal_(m.weight, 0, 0.01)
				nn.init.constant_(m.bias, 0)

	def forward(self, x):
		pooled = self.avgpool(self.features(x))
		# flatten everything but the batch dimension before the linear head
		flat = pooled.view(pooled.size(0), -1)
		return self.classifier(flat)


class extractor(nn.Module):
	'''Backbone: the feature layers of a batch-normalised VGG built from cfg.

	forward() returns the outputs of all max-pooling layers except the first one.
	'''
	def __init__(self, pretrained):
		super().__init__()
		backbone = VGG(make_layers(cfg, batch_norm=True))
		if pretrained:
			# weights are read from a file expected in the working directory
			weights = torch.load('vgg16_bn-6c64b313.pth')
			backbone.load_state_dict(weights)
		self.features = backbone.features

	def forward(self, x):
		pooled = []
		for layer in self.features:
			x = layer(x)
			if isinstance(layer, nn.MaxPool2d):
				pooled.append(x)
		return pooled[1:]


class merge(nn.Module):
	'''Feature-merging branch: repeatedly upsamples the deepest feature map by 2,
	concatenates it with the next shallower one and reduces the channels with a
	1x1 and a 3x3 convolution; a final 3x3 convolution yields 32 channels.'''

	# (in_channels, out_channels, kernel_size) of conv1 ... conv7
	_CONV_SPECS = (
		(1024, 128, 1), (128, 128, 3),
		(384, 64, 1), (64, 64, 3),
		(192, 32, 1), (32, 32, 3),
		(32, 32, 3),
	)

	def __init__(self):
		super().__init__()
		# modules are registered as convN / bnN / reluN so that state-dict keys stay stable
		for number, (c_in, c_out, kernel) in enumerate(self._CONV_SPECS, 1):
			setattr(self, 'conv{}'.format(number), nn.Conv2d(c_in, c_out, kernel, padding=kernel // 2))
			setattr(self, 'bn{}'.format(number), nn.BatchNorm2d(c_out))
			setattr(self, 'relu{}'.format(number), nn.ReLU())

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)

	def _conv_bn_relu(self, number, y):
		'''apply convN -> bnN -> reluN'''
		conv = getattr(self, 'conv{}'.format(number))
		bn = getattr(self, 'bn{}'.format(number))
		relu = getattr(self, 'relu{}'.format(number))
		return relu(bn(conv(y)))

	def forward(self, x):
		y = x[3]
		# merge with x[2], x[1], x[0] in turn, using conv pairs (1,2), (3,4), (5,6)
		for stage, skip in enumerate((x[2], x[1], x[0])):
			y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
			y = torch.cat((y, skip), 1)
			y = self._conv_bn_relu(2 * stage + 1, y)
			y = self._conv_bn_relu(2 * stage + 2, y)
		return self._conv_bn_relu(7, y)

class output(nn.Module):
	'''Prediction head working on a 32-channel feature map.

	Input (constructor):
		scope: upper bound of the four predicted distances (sigmoid * scope)
	Output (forward):
		score: 1 channel in (0, 1)
		geo  : 5 channels, four distances in (0, scope) and one angle in (-pi/2, pi/2)
	'''
	def __init__(self, scope=512):
		super(output, self).__init__()
		self.conv1 = nn.Conv2d(32, 1, 1)
		self.sigmoid1 = nn.Sigmoid()
		self.conv2 = nn.Conv2d(32, 4, 1)
		self.sigmoid2 = nn.Sigmoid()
		self.conv3 = nn.Conv2d(32, 1, 1)
		self.sigmoid3 = nn.Sigmoid()
		# honour the constructor argument (it used to be overwritten with 512)
		self.scope = scope
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)

	def forward(self, x):
		score = self.sigmoid1(self.conv1(x))
		loc   = self.sigmoid2(self.conv2(x)) * self.scope
		# sigmoid output shifted to (-0.5, 0.5) and scaled by pi
		angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
		geo   = torch.cat((loc, angle), 1)
		return score, geo
		
	
class EAST(nn.Module):
	'''Full detector: extractor -> merge -> output; forward() returns (score, geo).'''
	def __init__(self, pretrained=True):
		super().__init__()
		self.extractor = extractor(pretrained)
		self.merge     = merge()
		self.output    = output()

	def forward(self, x):
		features = self.extractor(x)
		fused = self.merge(features)
		return self.output(fused)
		

# Smoke test: push one random 256x256 RGB batch through an EAST model built
# without the pretrained backbone weights and print the two output shapes.
if __name__ == '__main__':
	m = EAST(pretrained=False)
	x = torch.randn(1, 3, 256, 256)
	score, geo = m(x)
	print(score.shape)
	print(geo.shape)

torch.Size([1, 1, 64, 64])
torch.Size([1, 5, 64, 64])


In [4]:
import time
from torch.optim import lr_scheduler

In [5]:
def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval,
		resume_path='/home/lab/khanhnd/EAST/saved_model/model_epoch_20.pth', start_epoch=20):
	'''train EAST and save a checkpoint every `interval` epochs
	Input:
		train_img_path: directory with the training images
		train_gt_path : directory with the annotation files
		pths_path     : directory where 'model_epoch_<n>.pth' files are written
		batch_size    : mini-batch size (incomplete last batches are dropped)
		lr            : initial learning rate of Adam
		num_workers   : number of DataLoader workers
		epoch_iter    : epoch number at which training stops
		interval      : save a checkpoint whenever (epoch + 1) is a multiple of it
		resume_path   : state dict loaded before training; None starts without loading one
		start_epoch   : index of the first epoch to run; the defaults of resume_path
		                and start_epoch reproduce the previously hard-coded resume
	'''
	trainset = custom_dataset(train_img_path, train_gt_path)
	train_loader = data.DataLoader(trainset, batch_size=batch_size,
		shuffle=True, num_workers=num_workers, drop_last=True)
	# with drop_last=True this is the number of mini-batches actually produced
	num_batches = len(train_loader)

	criterion = Loss()
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	# multi-GPU (nn.DataParallel) training is not used here
	model = EAST(pretrained=True)
	model.to(device)
	optimizer = torch.optim.Adam(model.parameters(), lr=lr)
	scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter//2], gamma=0.1)
	if resume_path is not None:
		# map_location lets a checkpoint saved on GPU be restored on a CPU-only host
		checkpoint = torch.load(resume_path, map_location=device)
		model.load_state_dict(checkpoint)
	# NOTE(review): the scheduler is not fast-forwarded to start_epoch, so after a
	# resume the milestone is counted from this run's first epoch - confirm intent.

	for epoch in range(start_epoch, epoch_iter):
		model.train()

		epoch_loss = 0
		epoch_time = time.time()
		for i, (img, gt_score, gt_geo, ignored_map, path) in enumerate(train_loader):
			start_time = time.time()

			img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(device)
			pred_score, pred_geo = model(img)
			loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

			epoch_loss += loss.item()
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(\
				epoch+1, epoch_iter, i+1, num_batches, time.time()-start_time, loss.item()))
		scheduler.step()
		# max(..., 1) avoids a ZeroDivisionError when the dataset is smaller than one batch
		print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss/max(num_batches, 1), time.time()-epoch_time))
		print(time.asctime(time.localtime(time.time())))
		print('='*50)
		if (epoch + 1) % interval == 0:
			state_dict = model.state_dict()
			print("Saving--------------------------------")
			torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch+1)))
			print("Saved---------------------------------")

In [6]:
# Training images and their annotation files, in the form expected by the first
# two arguments of train(); these are absolute paths on the author's machine.
train_img_path="/home/lab/khanhnd/STD_DBNet/dataset/vietnamese/train_images"
train_gt_path="/home/lab/khanhnd/STD_DBNet/dataset/vietnamese/train_gts"

In [7]:

# train(train_img_path, train_gt_path, pths_path='/home/lab/khanhnd/EAST/saved_model', batch_size=5, lr=0.0003, num_workers=2, epoch_iter=40, interval=10)



# Utils


In [8]:
import lanms

def is_valid_poly(res, score_shape, scale):
	'''check if the restored poly lies in the image scope
	Input:
		res        : restored poly in original image <2 x n, row 0 is x, row 1 is y>
		score_shape: score map shape (rows, cols)
		scale      : feature map -> image
	Output:
		True if at most one vertex falls outside the image
	'''
	outside = sum(
		1 for i in range(res.shape[1])
		if res[0,i] < 0 or res[0,i] >= score_shape[1] * scale or
		res[1,i] < 0 or res[1,i] >= score_shape[0] * scale
	)
	return outside <= 1
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
	'''restore polys from feature maps in given positions
	The input arrays are left unmodified.
	Input:
		valid_pos  : potential text positions <numpy.ndarray, (n,2)>, as [x, y] on the feature map
		valid_geo  : geometry in valid_pos <numpy.ndarray, (5,n)>
		score_shape: shape of score map
		scale      : image / feature map
	Output:
		restored polys <numpy.ndarray, (n,8)>, index of the positions that were kept
	'''
	polys = []
	index = []
	# scale a copy: an in-place `*=` would silently change the caller's array
	positions = np.asarray(valid_pos) * scale
	d = valid_geo[:4, :] # 4 x N
	angle = valid_geo[4, :] # N,

	for i in range(positions.shape[0]):
		x = positions[i, 0]
		y = positions[i, 1]
		y_min = y - d[0, i]
		y_max = y + d[1, i]
		x_min = x - d[2, i]
		x_max = x + d[3, i]
		rotate_mat = get_rotate_mat(-angle[i])

		# corners of the axis-aligned box relative to (x, y), rotated back around it
		temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
		temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
		coordinates = np.concatenate((temp_x, temp_y), axis=0)
		res = np.dot(rotate_mat, coordinates)
		res[0,:] += x
		res[1,:] += y

		if is_valid_poly(res, score_shape, scale):
			index.append(i)
			polys.append([res[0,0], res[1,0], res[0,1], res[1,1], res[0,2], res[1,2],res[0,3], res[1,3]])
	return np.array(polys),  index
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
	'''turn the model outputs into merged boxes
	Input:
		score       : score map from model <numpy.ndarray, (1,row,col)>
		geo         : geo map from model <numpy.ndarray, (5,row,col)>
		score_thresh: threshold to segment score map
		nms_thresh  : threshold in nms
	Output:
		boxes       : final polys <numpy.ndarray, (n,9)>, or None when nothing is found
	'''
	score_plane = score[0,:,:]
	text_rc = np.argwhere(score_plane > score_thresh) # n x 2, format is [r, c]
	if text_rc.size == 0:
		return None

	# order the candidate positions by row
	text_rc = text_rc[np.argsort(text_rc[:, 0])]
	rows, cols = text_rc[:, 0], text_rc[:, 1]
	valid_pos = text_rc[:, ::-1].copy() # n x 2, [x, y]
	valid_geo = geo[:, rows, cols] # 5 x n
	polys_restored, index = restore_polys(valid_pos, valid_geo, score_plane.shape)
	if polys_restored.size == 0:
		return None

	# 8 coordinates followed by the score of the position the poly came from
	boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
	boxes[:, :8] = polys_restored
	boxes[:, 8] = score_plane[rows[index], cols[index]]
	return lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
def adjust_ratio(boxes, ratio_w, ratio_h):
	'''map boxes back to the original image size
	The x / y columns of `boxes` are divided in place; the rounded result is a new array.
	Input:
		boxes  : detected polys <numpy.ndarray, (n,9)>
		ratio_w: ratio of width
		ratio_h: ratio of height
	Output:
		refined boxes, or None when there is no box
	'''
	has_boxes = boxes is not None and boxes.size != 0
	if not has_boxes:
		return None
	x_cols = [0,2,4,6]
	y_cols = [1,3,5,7]
	boxes[:,x_cols] /= ratio_w
	boxes[:,y_cols] /= ratio_h
	return np.around(boxes)

# Metric

In [9]:

from collections import namedtuple
import numpy as np
from shapely.geometry import Polygon


class DetectionIoUEvaluator(object):
    """IoU-based precision / recall / h-mean evaluation of detected polygons.

    iou_constraint: a detection matches a ground-truth polygon when their IoU
        is strictly greater than this value.
    area_precision_constraint: a detection is marked "don't care" when more
        than this fraction of its area lies inside a "don't care" ground truth.
    """

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        self.iou_constraint = iou_constraint
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Evaluate the detections of one image.

        gt:   list of dicts with keys 'points' (polygon vertices) and
              'ignore' (truthy for "don't care" regions).
        pred: list of dicts with key 'points'.
        Returns a dict with the per-image metrics, the matched pairs, the IoU
        matrix (empty list when there are more than 100 detections), the kept
        polygons, the care / don't-care bookkeeping and a textual log.
        """

        def usable_shape(points):
            # buffer(0) is applied before the validity test, as in the
            # per-pair computations; returns None for unusable polygons
            shape = Polygon(points).buffer(0)
            if not shape.is_valid or not shape.is_simple:
                return None
            return shape

        # kept polygons (raw points) and their cleaned shapes, index-aligned
        gtPols, gtShapes = [], []
        detPols, detShapes = [], []
        gtDontCarePolsNum = []
        detDontCarePolsNum = []
        pairs = []
        detMatched = 0
        evaluationLog = ""

        for item in gt:
            points = item['points']
            dontCare = item['ignore']
            shape = usable_shape(points)
            if shape is None:
                continue
            gtPols.append(points)
            gtShapes.append(shape)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
            " (" + str(len(gtDontCarePolsNum)) +
            " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")

        for item in pred:
            points = item['points']
            shape = usable_shape(points)
            if shape is None:
                continue
            detPols.append(points)
            detShapes.append(shape)
            if len(gtDontCarePolsNum) > 0:
                # the denominator is the area of the raw (un-buffered) detection
                detArea = Polygon(points).area
                for dontCareNum in gtDontCarePolsNum:
                    intersected_area = gtShapes[dontCareNum].intersection(shape).area
                    covered = 0 if detArea == 0 else intersected_area / detArea
                    if covered > self.area_precision_constraint:
                        detDontCarePolsNum.append(len(detPols) - 1)
                        break

        evaluationLog += "DET polygons: " + str(len(detPols)) + (
            " (" + str(len(detDontCarePolsNum)) +
            " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")

        # placeholder reported when there is nothing to compare (zeros rather
        # than uninitialised memory)
        iouMat = np.zeros([1, 1])
        if len(gtPols) > 0 and len(detPols) > 0:
            # IoU of every ground-truth / detection pair
            iouMat = np.empty([len(gtPols), len(detPols)])
            for gtNum, gtShape in enumerate(gtShapes):
                for detNum, detShape in enumerate(detShapes):
                    union_area = detShape.union(gtShape).area
                    inter_area = detShape.intersection(gtShape).area
                    # two empty shapes have no union; report an IoU of 0
                    iouMat[gtNum, detNum] = inter_area / union_area if union_area > 0 else 0.0

            # greedy one-to-one matching in (gt, det) index order
            gtIgnored = set(gtDontCarePolsNum)
            detIgnored = set(detDontCarePolsNum)
            detTaken = [False] * len(detPols)
            for gtNum in range(len(gtPols)):
                if gtNum in gtIgnored:
                    continue
                for detNum in range(len(detPols)):
                    if detTaken[detNum] or detNum in detIgnored:
                        continue
                    if iouMat[gtNum, detNum] > self.iou_constraint:
                        detTaken[detNum] = True
                        detMatched += 1
                        pairs.append({'gt': gtNum, 'det': detNum})
                        evaluationLog += "Match GT #" + \
                            str(gtNum) + " with Det #" + str(detNum) + "\n"
                        # this ground truth is matched; move on to the next one
                        break

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(
                detMatched) / numDetCare

        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        perSampleMetrics = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPols,
            'detPolPoints': detPols,
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

        return perSampleMetrics

    def combine_results(self, results):
        """Aggregate per-image results into dataset-level precision, recall and h-mean.

        results: iterable of dicts holding 'gtCare', 'detCare' and 'detMatched'
        (as returned by evaluate_image).
        """
        numGlobalCareGt = 0
        numGlobalCareDet = 0
        matchedSum = 0
        for result in results:
            numGlobalCareGt += result['gtCare']
            numGlobalCareDet += result['detCare']
            matchedSum += result['detMatched']

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (methodRecall + methodPrecision)

        methodMetrics = {
            'precision': methodPrecision,
            'recall': methodRecall,
            'hmean': methodHmean
        }

        return methodMetrics

def to_list_tuples_coords(anns):
    """Convert nested annotation coordinates into lists of (x, y) tuples.

    Every annotation in ``anns`` is iterated as (x, y) pairs; the first
    element of each coordinate is taken and converted with ``.tolist()``.
    NOTE(review): the ``[0]`` presumably strips a batch dimension of size 1
    added by the data loader -- confirm against the caller.

    Returns:
        A list with one list of ``(x, y)`` tuples per annotation.
    """
    return [
        [(x[0].tolist(), y[0].tolist()) for x, y in ann]
        for ann in anns
    ]
class AverageMeter:
    """Track the most recent value and the running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the newly observed value; stored as ``self.val``.
            n: weight of the observation (number of samples it stands for).

        Returns:
            ``self``, so calls can be chained.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) used to
        # raise ZeroDivisionError; report an average of 0 in that case.
        self.avg = self.sum / self.count if self.count else 0
        return self
class QuadMetric:
    """Score predicted quadrilaterals with a DetectionIoUEvaluator."""

    def __init__(self):
        self.evaluator = DetectionIoUEvaluator()

    def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6):
        """Evaluate the predictions of one image against its ground truth.

        Args:
            batch: mapping read through the keys 'anns' (ground-truth
                polygons) and 'ignore_tags' (one flag per ground-truth
                polygon). NOTE(review): element ``[0]`` of every tag is
                taken, presumably to strip a batch dimension of size 1 --
                confirm against the data loader.
            output: pair ``(polygons, scores)``; only the first entry of
                each is read, i.e. the predictions of a single image.
            is_output_polygon: accepted but not used by this method.
            box_thresh: accepted but not used by this method.

        Returns:
            A one-element list holding the value returned by
            ``self.evaluator.evaluate_image`` for that image.
        """
        predicted_polygons = np.array(output[0])
        predicted_scores = np.array(output[1])

        gt_polygons = to_list_tuples_coords(batch['anns'])
        gt_ignore_flags = [tag[0].tolist() for tag in batch['ignore_tags']]
        gt = [
            {'points': polygon, 'ignore': flag}
            for polygon, flag in zip(gt_polygons, gt_ignore_flags)
        ]

        # The scores only bound the iteration; no prediction is marked ignored.
        pred = [
            {'points': polygon, 'ignore': False}
            for polygon, _ in zip(predicted_polygons[0], predicted_scores[0])
        ]

        return [self.evaluator.evaluate_image(gt, pred)]

    def validate_measure(self,
                         batch,
                         output,
                         is_output_polygon=False,
                         box_thresh=0.6):
        """Forward all arguments unchanged to ``measure``."""
        return self.measure(batch, output, is_output_polygon, box_thresh)

    def gather_measure(self, raw_metrics):
        """Merge per-batch metric lists into overall scores.

        Args:
            raw_metrics: iterable of per-batch lists of per-image metrics,
                i.e. the values returned by ``measure``.

        Returns:
            dict with the keys 'precision', 'recall' and 'fmeasure', each
            mapped to an AverageMeter.
        """
        # Flatten the batch level so every entry is one image's metrics.
        per_image = []
        for batch_metrics in raw_metrics:
            per_image.extend(batch_metrics)

        combined = self.evaluator.combine_results(per_image)
        num_images = len(per_image)

        precision = AverageMeter()
        precision.update(combined['precision'], n=num_images)

        recall = AverageMeter()
        recall.update(combined['recall'], n=num_images)

        # The small epsilon keeps the F-measure defined when both are zero.
        p, r = precision.val, recall.val
        fmeasure = AverageMeter()
        fmeasure.update(2 * p * r / (p + r + 1e-8))

        return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure}


test_util

EVAL


In [19]:
import tqdm
def eval(
	valid_img_path,
	valid_gt_path,
	batch_size,
	checkpoint_path='/home/lab/khanhnd/EAST/saved_model/model_epoch_40.pth',
):
	'''evaluate a saved EAST checkpoint on a validation set and print the scores
	Input:
		valid_img_path : image location handed to custom_dataset
		valid_gt_path  : ground-truth location handed to custom_dataset
		batch_size     : batch size of the DataLoader
		                 NOTE(review): predictions are squeezed on dim 0 and
		                 scored as a single image per step, so presumably
		                 only batch_size=1 is meaningful -- confirm
		checkpoint_path: state-dict file loaded into the model
	Output:
		None; hmean, precision and recall are printed
	'''
	validset = custom_dataset(valid_img_path, valid_gt_path, is_training=False)
	valid_loader = data.DataLoader(
		validset, batch_size=batch_size, shuffle=False, drop_last=False)

	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	model = EAST(pretrained=True)
	model.to(device)
	# map_location lets a checkpoint saved on a GPU load on the CPU fallback.
	checkpoint = torch.load(checkpoint_path, map_location=device)
	model.load_state_dict(checkpoint)

	metric_cls = QuadMetric()
	model.eval()

	raw_metrics = []
	# Only the image, the ignore tags and the annotations are needed here;
	# the ground-truth score/geometry maps and the path are not read.
	for img, _gt_score, _gt_geo, ignore_tags, _path, anns in tqdm.tqdm(valid_loader):
		with torch.no_grad():
			pred_score, pred_geo = model(img.to(device))
		boxes = get_boxes(
			pred_score.squeeze(0).cpu().numpy(),
			pred_geo.squeeze(0).cpu().numpy())

		if boxes is None:
			# NOTE(review): get_boxes presumably returns None when nothing is
			# detected -- confirm; score the image with zero predictions.
			box_list, score_list = [[]], [[]]
		else:
			# Columns 0-7 are the four (x, y) corners, column 8 is the score.
			box_list = [boxes[:, :8].reshape(-1, 4, 2)]
			score_list = [boxes[:, 8]]

		valid_batch = {"anns": anns, "ignore_tags": ignore_tags}
		raw_metrics.append(
			metric_cls.validate_measure(valid_batch, (box_list, score_list)))

	metrics = metric_cls.gather_measure(raw_metrics)
	recall = metrics['recall'].avg
	precision = metrics['precision'].avg
	hmean = metrics['fmeasure'].avg
	print("hmean", hmean)
	print("precision", precision)
	print("recall", recall)

		
	

In [20]:
# Score the checkpoint on the dataset/vietnamese validation split, one image per batch.
eval(
    valid_img_path="/home/lab/khanhnd/STD_DBNet/dataset/vietnamese/valid_images",
    valid_gt_path="/home/lab/khanhnd/STD_DBNet/dataset/vietnamese/valid_gts",
    batch_size=1,
)

  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **kwargs)
  return lib.intersection(a, b, **

hmean 0.8399397842675745
precision 0.8138065143412737
recall 0.8678071539657853



