# In [None]:
import cv2
import numpy as np

def parse_annot(src_path):
  """
    Parse a comma-separated annotation file into polygons and transcriptions.

    Input: path to the annotation file. Each non-blank line is expected to look like
      x1,y1,x2,y2,x3,y3,x4,y4,transcription   e.g. '641,173,656,168,657,181,643,187,###'
    Output: list of dictionaries {'poly': [[x, y], ...], 'text': str}, one per word.
      - the first 8 fields are parsed as floats and paired into (x, y) points
      - a transcription of '1' is treated as the '###' (don't-care) label
      - lines yielding fewer than 3 points, and blank lines, are skipped
  """
  annot = []
  # 'utf-8-sig' transparently drops a leading UTF-8 byte-order mark; the per-field
  # strip below still removes one defensively. The `with` closes the file.
  with open(src_path, 'r', encoding='utf-8-sig') as f:
    lines = f.readlines()

  for raw in lines:
    raw = raw.strip()
    if not raw:
      continue  # a blank line would otherwise make float('') raise

    # Split at most 8 times so a transcription that itself contains commas
    # (e.g. 'Hello, world') is kept whole as the last field.
    parts = raw.split(',', 8)
    label = parts[-1]

    # edge case: '1' is used as the don't-care label
    if label == '1':
      label = '###'

    coords = [p.strip('\ufeff').strip('\xef\xbb\xbf') for p in parts[:8]]
    # 2D list of [x, y] points, e.g. [[641.0, 173.0], [656.0, 168.0], [657.0, 181.0], [643.0, 187.0]]
    poly = np.array(list(map(float, coords))).reshape((-1, 2)).tolist()
    if len(poly) < 3:
      continue

    annot.append({'poly': poly, 'text': label})

  return annot

def preprocess_img_annot(image, annots, training = True, img_size = 640, augmenter = None):
  """
    Optionally augment and crop, then resize an image with its annotations.

    Input:
      image: image array (as loaded by cv2.imread in DataGenerator)
      annots: list of {'poly': ..., 'text': ...} dicts, as produced by parse_annot
      training: when True, apply `transform` with a deterministic copy of the
        augmenter, then `crop`, before resizing
      img_size: target size passed to `resize`; the default 640 matches
        DataGenerator's default img_size. NOTE(review): callers using a different
        size must pass it explicitly, or the maps they build will not match.
      augmenter: object providing `.to_deterministic()`; when None, the
        module-level `transform_aug` (defined elsewhere in the notebook) is used
    Output:
      (image, annots) after preprocessing, keeping only annotations whose
      Polygon(ann['poly']) is valid
  """
  if training:
    if augmenter is None:
      augmenter = transform_aug
    # Bind to a new local name: assigning to `transform_aug` here would make it
    # local to the whole function and raise UnboundLocalError on the read.
    # Presumably imgaug's to_deterministic(), so the image and polygons share
    # the same random parameters — confirm against `transform`.
    aug_det = augmenter.to_deterministic()
    image, annots = transform(aug_det, image, annots)
    image, annots = crop(image, annots)

  image, annots = resize(img_size, image, annots)

  annots = [ann for ann in annots if Polygon(ann['poly']).is_valid]

  return image, annots

class DataGenerator(tf.keras.utils.Sequence):
  """
  Keras Sequence yielding batches of images and text-detection target maps.

  Image files in `in_folder` and annotation files in `label_folder` are paired by
  position in `list_IDs` / `label_IDs`, e.g.
    list_IDs  = ['img1.jpg', 'img2.jpg', ...]
    label_IDs = ['img1.txt', 'img2.txt', ...]
  For each image it builds img_size x img_size float32 maps:
    - gt:          1 inside each polygon shrunk by shrink_ratio, 0 elsewhere
    - mask:        0 over ignored regions ('###', tiny, or unshrinkable polygons), 1 elsewhere
    - thresh_map:  map from `draw_thresh_map`, transformed as t * (thresh_max - thresh_min) + thresh_min
    - thresh_mask: mask from `draw_thresh_map`
  __getitem__ returns (inputs, outputs) with inputs = [images, gts, masks,
  thresh_maps, thresh_masks] and outputs a zero vector of length batch size
  (a placeholder target — presumably the loss is computed inside the model; confirm).
  """

  def __init__(self, in_folder, label_folder, list_IDs, label_IDs, batch_size = 16,
               img_size = 640, min_text_size = 8, shrink_ratio = 0.4, thresh_min = 0.3, thresh_max = 0.7, training = True):
    # Keras 3's PyDataset (aliased as tf.keras.utils.Sequence) expects its
    # initializer to run; on older Keras this is just object.__init__().
    super().__init__()
    self.in_folder = in_folder
    self.label_folder = label_folder
    self.list_IDs = list_IDs
    self.label_IDs = label_IDs
    self.batch_size = batch_size
    self.img_size = img_size
    self.min_text_size = min_text_size
    self.shrink_ratio = shrink_ratio
    self.thresh_min = thresh_min
    self.thresh_max = thresh_max
    self.training = training
    # Define `indexes` from construction rather than only after the first epoch.
    self.on_epoch_end()

  def on_epoch_end(self):
    """
    Reset `indexes` to 0..len(list_IDs)-1 after each epoch (no shuffling is done).
    """
    self.indexes = np.arange(len(self.list_IDs))

  def __len__(self):
    """Number of full batches per epoch; a trailing partial batch is dropped."""
    return int(np.floor(len(self.list_IDs) / self.batch_size))

  def generate_annotations(self, anns):
    """
    Rasterize one image's annotations into (gt, mask, thresh_map, thresh_mask).

    anns: list of {'poly': [[x, y], ...], 'text': str} in resized-image coordinates.
    Side effect: each kept annotation's 'poly' is replaced by the polygon returned
    from `draw_thresh_map`.
    """
    size = (self.img_size, self.img_size)
    gt = np.zeros(size, dtype=np.float32)
    mask = np.ones(size, dtype=np.float32)
    thresh_map = np.zeros(size, dtype=np.float32)
    thresh_mask = np.zeros(size, dtype=np.float32)

    for ann in anns:
      # e.g. [[641.0, 173.0], [656.0, 168.0], [657.0, 181.0], [643.0, 187.0]]
      poly = np.array(ann['poly'])
      poly_int = poly.astype(np.int32)[np.newaxis, :, :]
      height = max(poly[:, 1]) - min(poly[:, 1])
      width = max(poly[:, 0]) - min(poly[:, 0])
      polygon = Polygon(poly)

      # Ignore degenerate, too-small, or don't-care ('###') regions.
      if polygon.area < 1 or min(height, width) < self.min_text_size or ann['text'] == '###':
        cv2.fillPoly(mask, poly_int, 0)
        continue

      # Shrink offset D = A * (1 - r^2) / L (polygon area A, perimeter L, ratio r).
      distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
      subject = [tuple(pt) for pt in ann['poly']]
      padding = pyclipper.PyclipperOffset()
      padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
      shrinked = padding.Execute(-distance)

      # If shrinking collapses or invalidates the polygon, mask it out instead.
      shrinked = np.array(shrinked[0]).reshape(-1, 2) if shrinked else None
      if shrinked is None or shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
        cv2.fillPoly(mask, poly_int, 0)
        continue

      cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
      ann['poly'], thresh_map, thresh_mask = draw_thresh_map(ann['poly'], thresh_map, thresh_mask, shrink_ratio = self.shrink_ratio)

    thresh_map = thresh_map * (self.thresh_max - self.thresh_min) + self.thresh_min

    return gt, mask, thresh_map, thresh_mask

  def __data_generation(self, batch_x, batch_y):
    """
    Load, preprocess, and rasterize the samples named in batch_x / batch_y.

    Raises FileNotFoundError if an image cannot be read.
    """
    images = []
    gts = []
    masks = []
    thresh_maps = []
    thresh_masks = []
    batch_loss = np.zeros([len(batch_x), ], dtype=np.float32)

    for i, img_name in enumerate(batch_x):
      img_path = self.in_folder + '/' + img_name
      annot_path = self.label_folder + '/' + batch_y[i]

      image = cv2.imread(img_path)
      # cv2.imread returns None (it does not raise) for a missing/unreadable file;
      # fail here with the path instead of an opaque error further down.
      if image is None:
        raise FileNotFoundError('Could not read image: ' + img_path)
      anns = parse_annot(annot_path)  # [{'poly', 'text'}, ...]

      # data augmentation if training + image resizing
      image, anns = preprocess_img_annot(image, anns, self.training)

      gt, mask, thresh_map, thresh_mask = self.generate_annotations(anns)

      # Per-channel mean subtraction. NOTE(review): `mean` is a module-level
      # constant defined elsewhere in the notebook; presumably in BGR order to
      # match cv2.imread — confirm.
      image = image.astype(np.float32)
      for c in range(3):
        image[..., c] -= mean[c]

      images.append(image)
      gts.append(gt)
      masks.append(mask)
      thresh_maps.append(thresh_map)
      thresh_masks.append(thresh_mask)

    inputs = [np.array(images), np.array(gts), np.array(masks), np.array(thresh_maps), np.array(thresh_masks)]
    outputs = batch_loss

    return inputs, outputs

  def __getitem__(self, index):
    """
    Return batch number `index` as (inputs, outputs).
    """
    start = index * self.batch_size
    stop = (index + 1) * self.batch_size
    batch_x = self.list_IDs[start:stop]
    batch_y = self.label_IDs[start:stop]

    X, y = self.__data_generation(batch_x, batch_y)

    return X, y