# Store Dataset

In [None]:
import cv2
import os
from glob import glob
from tqdm.auto import tqdm
import shutil
import matplotlib.pyplot as plt
import numpy as np
import math
import os.path as osp
import pyclipper
from shapely.geometry import Polygon
#-------------------------------
# polygons whose shorter bounding-box side is below this many pixels are only masked out
min_text_size = 8
# ratio used in the offset-distance formula for shrinking / dilating text polygons
shrink_ratio = 0.4
# side length (in pixels) of the square image and maps that are produced
IMAGE_SIZE = 512

# destination root of the processed dataset
ds_dir = "/home/apsisdev/ansary/DATASETS/APSIS/Detection/processed/"
# source root; its entries are listed and used as dataset identifiers
src_dir = "/backup/RAW/DET/DBNet/"
#-------------------------------
def create_dir(base,ext):
    '''
        creates a directory extending base (missing parent folders are created too)
        args:
            base    =   base path 
            ext     =   the folder to create
        returns:
            the joined path, which exists as a directory afterwards
    '''
    _path=os.path.join(base,ext)
    # exist_ok removes the check-then-create race; makedirs also copes with a missing base
    os.makedirs(_path,exist_ok=True)
    return _path
#-------------------------------
# ds_dir is rebound to <ds_dir>/data first, so every folder created below nests inside it
ds_dir           = create_dir(ds_dir,"data")
img_dir          = create_dir(ds_dir,"image")
gt_dir           = create_dir(ds_dir,"gt")
mask_dir         = create_dir(ds_dir,"mask")
thresh_map_dir   = create_dir(ds_dir,"thresh_map")
thresh_mask_dir  = create_dir(ds_dir,"thresh_mask")

# every entry found in src_dir is later used as a dataset identifier
ds_idens = os.listdir(src_dir)
# bare expression: presumably kept to display the list as the notebook cell output
ds_idens

In [None]:
def get_anns(img_path,ds_iden):
    '''
        parses the ground-truth text file that belongs to an image
        args:
            img_path    =   path of the image; every "images" in the path is replaced by "gts"
                            to locate the annotation file
            ds_iden     =   dataset identifier; selects the annotation file naming and the line layout
        returns:
            list of dicts with keys
                poly    =   list of [x,y] points (entries with fewer than 3 points are dropped)
                text    =   "ignore" when the transcription is "###", otherwise "text"
    '''
    # gt_path
    gt_path=img_path.replace("images","gts")
    if ds_iden in ['mlt2017train','mlt2017eval','icdar2015test']:
        # for these sets every "img" in the path becomes "gt_img"
        gt_path=gt_path.replace("img","gt_img")

    if ds_iden in ['boise_state',"bw",'sorieTest','sorieTrain',
                   'mlt2017train','mlt2017eval','icdar2015test']:
        # <name>.<ext> --> <name>.txt
        # only the file name is cut at its first dot, so dots in folder names
        # (e.g. "./images/x.jpg") no longer truncate the whole path
        head,tail=os.path.split(gt_path)
        gt_path=os.path.join(head,tail.split(".")[0]+".txt")
    else: #ctw,funsd,icdar2015train,"tr400","td500test","td500train",'totaltext',"wildreceipt"
        gt_path=gt_path+".txt"

    # datasets whose lines always start with 8 coordinates (4 xy points)
    len8s=("wildreceipt",
           "tr400",
           "td500test",
           "td500train",
           'mlt2017train',
           'mlt2017eval',
           'icdar2015train',
           'icdar2015test',
           'sorieTrain',
           'sorieTest',
           'funsd',
           'bw',
           'boise_state')
    # td500 datasets
    tds=("tr400","td500test","td500train")

    lines = []
    # ann
    with open(gt_path, 'r') as gt_file:
        reader = gt_file.readlines()

    for line in reader:
        # strip first: the mlt label is cut out of the line itself, a trailing
        # newline would turn "###" into "###\n" and the region would not be ignored
        line=line.strip()
        if not line:
            # blank lines carry no annotation
            continue

        if ds_iden=="ctw":
            # <coords>####<text> : commas inside the text are masked so that the
            # comma split below keeps the text in a single field
            parts=line.split("####")
            label=parts[-1].replace(",","*")
            line=parts[0]+label

        parts=line.strip().split(",")
        # mlt2017
        if "mlt2017" in ds_iden:
            # field 8 is the language, the transcription is what follows it
            lang=parts[8]
            label=line.split(f"{lang},")[-1]
        elif ds_iden in tds:
            label = parts[-1]
            # a last field of '1' is treated as an ignore marker
            if label == '1':
                label = '###'
        elif ds_iden in len8s:
            label="".join(parts[8:])
        else:
            label = parts[-1]

        # wildreceipt: an empty transcription counts as ignore
        if ds_iden=='wildreceipt' and len(label)==0:
            label="###"

        # conversion
        label="ignore" if label=="###" else "text"

        #--> poly
        # remove a byte-order-mark that may be glued to the first field
        fields = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in parts]
        if ds_iden in len8s:
            num_points = 8
        else:
            # all fields but the last one, rounded down to an even count
            num_points = math.floor((len(fields) - 1) / 2) * 2
        poly = np.array(list(map(float, fields[:num_points]))).reshape((-1, 2)).tolist()
        if len(poly) < 3:
            continue
        lines.append({'poly': poly, 'text': label})
    return lines
    

In [None]:


def draw_thresh_map(polygon, canvas, mask, shrink_ratio=0.4):
    '''
        draws the threshold (border) map of one text polygon
        args:
            polygon         =   (N,2) array of xy points; left untouched
            canvas          =   2D threshold map, updated in place with the element-wise
                                maximum of its old values and the new border values
            mask            =   2D threshold mask, the dilated polygon is filled with 1.0 in place
            shrink_ratio    =   ratio used to derive the dilation distance
    '''
    assert polygon.ndim == 2
    assert polygon.shape[1] == 2

    # work on a copy: the points are shifted into bounding-box coordinates below
    # and the caller's array must not change
    polygon = polygon.copy()

    polygon_shape = Polygon(polygon)
    distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
    subject = [tuple(l) for l in polygon]
    padding = pyclipper.PyclipperOffset()
    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    padded_polygon = np.array(padding.Execute(distance)[0])
    cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

    # bounding box of the dilated polygon (may reach outside the canvas)
    xmin = padded_polygon[:, 0].min()
    xmax = padded_polygon[:, 0].max()
    ymin = padded_polygon[:, 1].min()
    ymax = padded_polygon[:, 1].max()
    width = xmax - xmin + 1
    height = ymax - ymin + 1

    canvas_h, canvas_w = canvas.shape[0], canvas.shape[1]
    # nothing of the box is visible: leave the canvas as it is
    if xmax < 0 or ymax < 0 or xmin > canvas_w - 1 or ymin > canvas_h - 1:
        return

    polygon[:, 0] = polygon[:, 0] - xmin
    polygon[:, 1] = polygon[:, 1] - ymin

    # pixel coordinate grids of the bounding box
    xs = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
    ys = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))

    # per edge: distance of every pixel to the edge, normalised by the dilation distance
    distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
    for i in range(polygon.shape[0]):
        j = (i + 1) % polygon.shape[0]
        absolute_distance = compute_distance(xs, ys, polygon[i], polygon[j])
        distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
    distance_map = np.min(distance_map, axis=0)

    # part of the box that lies inside the canvas (inclusive pixel indices)
    xmin_valid = min(max(0, xmin), canvas_w - 1)
    xmax_valid = min(max(0, xmax), canvas_w - 1)
    ymin_valid = min(max(0, ymin), canvas_h - 1)
    ymax_valid = min(max(0, ymax), canvas_h - 1)
    # the *_valid bounds are inclusive, hence the +1 on every slice end; without it
    # the last row/column (e.g. the image border for clipped boxes) is never written
    canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
        1 - distance_map[
            ymin_valid - ymin:ymax_valid - ymin + 1,
            xmin_valid - xmin:xmax_valid - xmin + 1],
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
    
def compute_distance(xs, ys, point_1, point_2):
    '''
        distance of every grid pixel to the edge point_1 -> point_2
        args:
            xs,ys       =   coordinate grids of identical shape
            point_1     =   (x,y) start of the edge
            point_2     =   (x,y) end of the edge
        returns:
            array shaped like xs: the height of the triangle (pixel,point_1,point_2)
            over the edge, replaced by the distance to the nearer end point wherever
            cosin < 0 (i.e. the angle at the pixel is acute)
    '''
    square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
    square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
    square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])

    nearest_end = np.sqrt(np.fmin(square_distance_1, square_distance_2))
    if square_distance == 0:
        # repeated vertex: the edge is a single point, dividing by its length gives nan
        return nearest_end

    # 0/0 happens for pixels that coincide with an end point; handled by nan_to_num below
    with np.errstate(divide='ignore', invalid='ignore'):
        cosin = (square_distance - square_distance_1 - square_distance_2) / \
                (2 * np.sqrt(square_distance_1 * square_distance_2))
    square_sin = 1 - np.square(cosin)
    # rounding can push the value slightly below 0, which would make the sqrt return nan
    square_sin = np.clip(np.nan_to_num(square_sin), 0, None)
    result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)

    beyond = cosin < 0
    result[beyond] = nearest_end[beyond]
    return result




def resize(size, image,pad):
    '''
        scales the image so that its longer side equals size (aspect ratio kept,
        nearest-neighbour) and pads the rest of the square on the bottom/right
        args:
            size    =   side length of the square output
            image   =   2D map or 3D (h,w,channels) image
            pad     =   value used for the padded area
        returns:
            the padded array in the dtype of image, singleton axes squeezed
    '''
    h, w = image.shape[0],image.shape[1]
    # integer arithmetic: int(h*(size/h)) can end up at size-1 through float rounding;
    # max(1,..) keeps extremely thin images from collapsing to a zero-sized side
    if w >= h:
        new_w = size
        new_h = max(1, (h * size) // w)
    else:
        new_h = size
        new_w = max(1, (w * size) // h)

    # () for 2D maps, (channels,) for images -- no fixed channel count
    channels = tuple(image.shape[2:])
    padimg = np.full((size, size) + channels, pad, dtype=image.dtype)
    resized = cv2.resize(image, (new_w, new_h),fx=0,fy=0, interpolation = cv2.INTER_NEAREST)
    # cv2.resize may hand back (h,w) for a single-channel (h,w,1) input;
    # the reshape makes both layouts fit the target slice
    padimg[:new_h, :new_w] = resized.reshape((new_h, new_w) + channels)
    return np.squeeze(padimg)

def process_single_data(img_path,ds_iden):
    '''
        builds the image and the four training maps of a single sample
        args:
            img_path    =   path of the image
            ds_iden     =   dataset identifier (forwarded to get_anns)
        returns:
            image,gt,mask,thresh_map,thresh_mask
            all passed through resize(IMAGE_SIZE,...); the four maps are float32
            and multiplied by 255
        raises:
            ValueError if the image can not be read
    '''
    # img
    image=cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise ValueError(f"could not read image: {img_path}")
    h,w=image.shape[:2]
    # masks
    dim=(h,w)
    gt = np.zeros(dim, dtype=np.float32)
    mask = np.ones(dim, dtype=np.float32)
    thresh_map = np.zeros(dim, dtype=np.float32)
    thresh_mask = np.zeros(dim, dtype=np.float32)

    anns=get_anns(img_path,ds_iden)
    # process annotations
    for ann in anns:
        poly = np.array(ann['poly'])
        height = max(poly[:, 1]) - min(poly[:, 1])
        width = max(poly[:, 0]) - min(poly[:, 0])
        polygon = Polygon(poly)
        # region that is zeroed in the mask whenever the polygon can not be used
        ignore_region = poly.astype(np.int32)[np.newaxis, :, :]

        # generate gt and mask
        if polygon.area < 1 or min(height, width) < min_text_size or ann['text'] == 'ignore':
            cv2.fillPoly(mask, ignore_region, 0)
            continue

        # shrink the polygon by the offset distance derived from area, perimeter and shrink_ratio
        distance = polygon.area * (1 - np.power(shrink_ratio, 2)) / polygon.length
        subject = [tuple(l) for l in ann['poly']]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        shrinked = padding.Execute(-distance)
        if len(shrinked) == 0:
            # nothing left after shrinking
            cv2.fillPoly(mask, ignore_region, 0)
            continue

        shrinked = np.array(shrinked[0]).reshape(-1, 2)
        if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
            cv2.fillPoly(mask, ignore_region, 0)
            continue
        cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)

        # generate thresh map and thresh mask
        draw_thresh_map(poly, thresh_map, thresh_mask, shrink_ratio=shrink_ratio)

    # resize (mask is padded with 1, everything else with 0)
    image= resize(IMAGE_SIZE, image,0)
    gt= resize(IMAGE_SIZE, gt,0)
    thresh_map= resize(IMAGE_SIZE,thresh_map,0)
    thresh_mask= resize(IMAGE_SIZE,thresh_mask,0)
    mask= resize(IMAGE_SIZE,mask,1)

    # scale the maps by 255 before they are returned
    thresh_map*=255
    gt*=255
    mask*=255
    thresh_mask*=255
    return image,gt,mask,thresh_map,thresh_mask

In [None]:
# running file index shared by all datasets; the outputs are saved as <fiden>.png
fiden=0
def process_dataset(ds_path,ds_iden):
    '''
        processes every file in <ds_path>/<ds_iden>/images and stores the results
        args:
            ds_path     =   folder that holds the datasets
            ds_iden     =   dataset identifier (folder name inside ds_path)
        notes:
            * image,gt,mask,thresh_map and thresh_mask are written as <fiden>.png into
              img_dir,gt_dir,mask_dir,thresh_map_dir and thresh_mask_dir
            * a sample that fails is reported on stdout and skipped,
              the global counter fiden only advances for fully saved samples
    '''
    global fiden
    dataset_path=os.path.join(ds_path,ds_iden)
    # glob order is arbitrary; sorting keeps the numbering reproducible between runs
    img_paths=sorted(glob(os.path.join(dataset_path,"images","*.*")))
    # extract anns 
    for img_path in tqdm(img_paths):
        try:
            img,gt,mask,thresh_map,thresh_mask=process_single_data(img_path,ds_iden)
            # save
            outputs=((img_dir,img),
                     (gt_dir,gt),
                     (mask_dir,mask),
                     (thresh_map_dir,thresh_map),
                     (thresh_mask_dir,thresh_mask))
            for out_dir,data in outputs:
                out_path=os.path.join(out_dir,f"{fiden}.png")
                # cv2.imwrite returns False on failure instead of raising
                if not cv2.imwrite(out_path,data):
                    raise IOError(f"could not write: {out_path}")
            fiden+=1
        except Exception as e:
            # best effort: report the sample and go on with the next one
            print("-------------------------------")
            print(ds_iden,":",img_path)
            print(e)
            print("-------------------------------")
   



In [None]:
def debug_process(ds_path,ds_iden):
    '''
        shows the image and the four generated maps for one file of a dataset
        args:
            ds_path     =   folder that holds the datasets
            ds_iden     =   dataset identifier (folder name inside ds_path)
    '''
    dataset_path=os.path.join(ds_path,ds_iden)
    # sorted so that the same file is shown on every run
    img_paths=sorted(glob(os.path.join(dataset_path,"images","*.*")))
    if not img_paths:
        # nothing to show for an empty images folder
        return

    img,gt,mask,thresh_map,thresh_mask=process_single_data(img_paths[0],ds_iden)
    # cv2 loads images as BGR while matplotlib expects RGB
    plt.imshow(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
    plt.show()
    for data in (gt,mask,thresh_map,thresh_mask):
        plt.imshow(data)
        plt.show()



In [None]:
# preview: show the generated maps for one image of every dataset
for ds_iden in ds_idens:
    print(ds_iden)
    debug_process(src_dir,ds_iden)

In [None]:
# full run: process and store every dataset found in src_dir
for ds_iden in ds_idens:
    print(ds_iden)
    process_dataset(src_dir,ds_iden)