In [2]:
import multiprocessing
import numpy as np
import os
from abc import abstractmethod
import cv2
import tensorflow as tf
import tqdm

In [3]:
import tensorpack

In [4]:
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.stats import RatioCounter

### 1. A minimal example

In [4]:

# def my_data_loader():
#   # load data from somewhere with Python, and yield them
#   for k in range(100):
#     yield [my_array, my_label]

# df = DataFromGenerator(my_data_loader)

In [11]:
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
from tensorpack import DataFlow, DataFromGenerator
from tensorpack.dataflow.parallel import PlasmaGetData, PlasmaPutData  # noqa
class MyDataFlow(DataFlow):
    """Toy DataFlow: ten datapoints, each ``[2x2 random float array, random int label in [0, 10)]``."""

    def __iter__(self):
        for _ in range(10):
            sample = np.random.rand(2, 2)
            target = np.random.randint(10)
            yield [sample, target]
      
# Group the toy datapoints into batches of 3 (the output below shows 3 batches from 10 datapoints).
df = BatchData(MyDataFlow(), 3)
df.reset_state()
# For every batch, show the label of its first datapoint, preceded by an empty line.
for datapoint in df:
    first_label = datapoint[1][0]
    print("")
    print(first_label)


6

5

8


### 2. A higher-level demo: sampling ICDAR2013 video clips

In [None]:
import cv2
from time import time
from random import randint
from icdar_smart import load_annotations_solo, check_and_validate_polys
from tensorpack import DataFlow, DataFromGenerator
from tensorpack.dataflow.parallel import PlasmaGetData, PlasmaPutData  # noqa
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
# Dataset locations. The defaults are the author's absolute paths; set the
# ICDAR2013_DIR / EAST_PREPROCESSED_DIR environment variables to use other locations.
ICDAR2013 = os.environ.get('ICDAR2013_DIR', '/work/cascades/lxiaol9/ARC/EAST/data/ICDAR2013/')
# Training videos and their annotations. The trailing separator is kept because
# downstream code builds file paths by string concatenation (video_dir + name).
video_dir = os.path.join(ICDAR2013, 'train', '')
# Root of the per-frame score_maps/, geo_maps/ and training_masks/ .npy files.
data_dir = os.environ.get('EAST_PREPROCESSED_DIR', '/work/cascades/lxiaol9/ARC/EAST/data/pre-processed/')

def _random_offset(length, patch_size, margin_lo=2, margin_hi=5):
    """Return a random start index for a ``patch_size``-long window along an axis of ``length``.

    Keeps at least ``margin_lo`` elements before and ``margin_hi`` after the window when
    the axis is long enough. Otherwise it returns 0 instead of letting ``randint`` raise
    ``ValueError`` on an empty range.
    """
    high = length - patch_size - margin_hi
    if high < margin_lo:
        return 0
    return randint(margin_lo, high)


def crop_all_random_seq(num_steps, data, score_maps, geo_maps, training_masks, crop_background=False, max_tries=20):
    '''
    Cut the same random 512x512 window out of every frame of a clip.

    :param num_steps: number of frames in the clip (first axis of every array).
    :param data: frames, indexed ``data[t, row, col, channel]`` (3 channels).
    :param score_maps: indexed ``score_maps[t, row, col]``.
    :param geo_maps: indexed ``geo_maps[t, row, col, k]`` (5 channels).
    :param training_masks: indexed ``training_masks[t, row, col]``.
    :param crop_background: unused; kept for interface compatibility.
    :param max_tries: random windows tried (frames >= 512 rows) before giving up.
    :return: ``(images, scores, geos, masks)`` crops, or four ``None`` when no acceptable
        window was found within ``max_tries`` attempts.

    Frames with fewer than 512 rows are zero-padded to 512x512 (training masks are padded
    with ones) around a random column window. Taller frames get a random window whose
    top-left pixel has a zero score on the first frame.
    '''
    input_size = 512
    px_size = py_size = 512  # real patch size
    score_map = score_maps[0, :, :]
    frame_height, frame_width = score_map.shape

    if frame_height < input_size:
        images_new = np.zeros((num_steps, input_size, input_size, 3), dtype=np.uint8)
        scores_new = np.zeros((num_steps, input_size, input_size), dtype=np.uint8)
        geos_new = np.zeros((num_steps, input_size, input_size, 5), dtype=np.float32)
        tmasks_new = np.ones((num_steps, input_size, input_size), dtype=np.uint8)
        y = _random_offset(frame_width, py_size)
        # Frames narrower than the patch give a narrower slice; pad it instead of failing.
        w = min(py_size, frame_width - y)
        images_new[:, :frame_height, :w, :] = data[:, :, y:y + w, :]
        scores_new[:, :frame_height, :w] = score_maps[:, :, y:y + w]
        geos_new[:, :frame_height, :w, :] = geo_maps[:, :, y:y + w, :]
        tmasks_new[:, :frame_height, :w] = training_masks[:, :, y:y + w]
        return images_new, scores_new, geos_new, tmasks_new

    # Every sampled window counts as an attempt, so this loop always terminates.
    for _ in range(max_tries):
        x = _random_offset(frame_height, px_size)
        y = _random_offset(frame_width, py_size)
        # Reject windows whose top-left corner has a positive score.
        if score_map[x, y] > 0:
            continue
        return (data[:, x:x + px_size, y:y + py_size, :],
                score_maps[:, x:x + px_size, y:y + py_size],
                geo_maps[:, x:x + px_size, y:y + py_size],
                training_masks[:, x:x + px_size, y:y + py_size])
    return None, None, None, None

# preprocessing function
class data_raw():
    def __init__(self, video_dir, datapre_path, is_print=False):
        self.video_set = []
        self.video_dir = video_dir
        self.datapre_path = datapre_path
        for root, dirs, files in os.walk(video_dir):
            for file in files:
                if file.endswith('.mp4'):
                    cap = cv2.VideoCapture(video_dir+file)
                    b = {}
                    b["video_name"]   = os.path.splitext(file)[0]
                    b["frame_width"]  = int(cap.get(3))
                    b["frame_height"] = int(cap.get(4))
                    b["frame_num"]    = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    xml_solo_path = video_dir + b["video_name"]
                    polys_array_list, tags_array_list, _, _ = load_annotations_solo(xml_solo_path, 1,\
                                                                                  b["frame_num"])
                    b["polys_list"]   = polys_array_list
                    b["tags_list"]    = tags_array_list
                    self.video_set.append(b)
                    cap.release()
        self.total_num = len(self.video_set)
        if is_print:
            for i in range(self.total_num):
                print("{} with {} frames of w:{}, h:{}, polys of the first frame: {}, and tags: {}\n".format(
                                                self.video_set[i]["video_name"], 
                                                self.video_set[i]["frame_num"], 
                                                self.video_set[i]["frame_width"],
                                                self.video_set[i]["frame_height"],
                                                self.video_set[i]["polys_list"][0],
                                                self.video_set[i]["tags_list"][0]
                )
                     )
          
class MyDataFlow(DataFlow):
    """Sample random ``num_steps``-frame clips and their pre-computed maps from indexed videos.

    Each datapoint is ``[images, score_maps, geo_maps, training_masks]`` as returned by
    ``crop_all_random_seq``. Clips whose frames cannot be decoded, or for which no crop is
    found, are skipped, so one pass yields at most ``num_samples`` datapoints.

    :param raw_data: a ``data_raw`` index (``video_set``, ``total_num``, ``video_dir``,
        ``datapre_path``).
    :param num_steps: frames per clip.
    :param is_training: only the training mode is implemented; otherwise nothing is yielded.
    :param num_samples: clips attempted per pass (default 100, the previous fixed value).
    """

    def __init__(self, raw_data, num_steps, is_training, num_samples=100):
        super(MyDataFlow, self).__init__()
        self.raw_data = raw_data
        self.num_steps = num_steps
        self.is_training = is_training
        self.num_samples = num_samples

    def _video_path(self, video):
        # Prefer a full path recorded by the index; fall back to the flat-directory layout.
        return video.get("video_path", self.raw_data.video_dir + video["video_name"] + '.mp4')

    def _npy_path(self, subdir, video_name, frame_idx):
        # <datapre_path><subdir>/<video_name>/frameNNNN.npy
        return '{}{}/{}/frame{:04d}.npy'.format(self.raw_data.datapre_path, subdir, video_name, frame_idx)

    def _load_clip(self, video, start):
        """Decode frames ``start .. start+num_steps-1`` and load their maps; None if decoding fails."""
        num_steps = self.num_steps
        h, w = video["frame_height"], video["frame_width"]
        name = video["video_name"]
        images = np.zeros([num_steps, h, w, 3], dtype=np.uint8)
        score_maps = np.zeros([num_steps, h, w], dtype=np.uint8)
        geo_maps = np.zeros([num_steps, h, w, 5], dtype=np.float32)
        training_masks = np.ones([num_steps, h, w], dtype=np.uint8)
        cap = cv2.VideoCapture(self._video_path(video))
        try:
            # The frames are consecutive, so seek once and then decode sequentially.
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
            for m in range(num_steps):
                frame_idx = start + m
                ret, image = cap.read()
                if not ret:
                    return None
                images[m] = image
                text_polys, _text_tags = check_and_validate_polys(
                    video["polys_list"][frame_idx], video["tags_list"][frame_idx], (h, w))
                if text_polys.shape[0] == 0:
                    # no boxes: keep zero score/geo maps and an all-ones mask
                    continue
                score_maps[m] = np.load(self._npy_path('score_maps', name, frame_idx))
                geo_maps[m] = np.load(self._npy_path('geo_maps', name, frame_idx))
                training_masks[m] = np.load(self._npy_path('training_masks', name, frame_idx))
        finally:
            cap.release()
        return images, score_maps, geo_maps, training_masks

    def __iter__(self):
        if not self.is_training:
            return
        raw_data = self.raw_data
        for _ in range(self.num_samples):
            # The last 3 videos of video_set are never sampled.
            # NOTE(review): presumably a held-out split; confirm.
            video = raw_data.video_set[randint(0, raw_data.total_num - 3 - 1)]
            start = randint(0, video["frame_num"] - self.num_steps - 1)
            clip = self._load_clip(video, start)
            if clip is None:
                continue
            cropped = crop_all_random_seq(self.num_steps, *clip)
            if cropped[0] is None:
                # no acceptable crop; yielding None would break BatchData
                continue
            yield list(cropped)
# Build the training pipeline and time how fast it produces batches.
# NOTE(review): PrefetchDataZMQ keeps pulling from its forked workers, and the wrapped
# flow has no __len__, so an unbounded loop over it is not expected to end. Stop after
# NUM_DEMO_BATCHES so "Restart & Run All" completes.
NUM_DEMO_BATCHES = 20
t1_start = time()
dr = data_raw(video_dir, data_dir, is_print=False)
df = MyDataFlow(raw_data=dr, num_steps=5, is_training=True)
df = BatchData(df, 16)
df = PrefetchDataZMQ(df, 16)
df.reset_state()
t1_end = time()
print("data loader preparation costs {} seconds".format(t1_end - t1_start))
for batch_idx, datapoint in enumerate(df):
    print("now passed {} seconds".format(time() - t1_end))
    if batch_idx + 1 >= NUM_DEMO_BATCHES:
        break
    

[32m[0106 15:21:52 @parallel.py:311][0m [PrefetchDataZMQ] Will fork a dataflow more than one times. This assumes the datapoints are i.i.d.
[32m[0106 15:21:52 @argtools.py:146][0m [5m[31mWRN[0m "import prctl" failed! Install python-prctl so that processes can be cleaned with guarantee.
data loader preparation costs 2.1918246746063232 seconds
now passed 34.47462201118469 seconds
now passed 42.04209542274475 seconds
now passed 42.729400396347046 seconds
now passed 43.465091943740845 seconds
now passed 44.5822970867157 seconds
now passed 44.81189441680908 seconds
now passed 47.75291585922241 seconds
now passed 48.540337800979614 seconds
now passed 48.86101245880127 seconds
now passed 49.04390048980713 seconds
now passed 49.817816495895386 seconds
now passed 49.99043846130371 seconds
now passed 50.406837463378906 seconds
now passed 50.50053524971008 seconds
now passed 50.67884349822998 seconds
now passed 51.746458292007446 seconds
now passed 84.2175965309143 seconds
now passed 84.381

In [11]:
def test_orig(dir, name, augs, batch):
    """Reference ILSVRC12 pipeline: shuffled files -> augment -> batch -> 50-process ZMQ prefetch."""
    return PrefetchDataZMQ(
        BatchData(
            AugmentImageComponent(dataset.ILSVRC12(dir, name, shuffle=True), augs),
            batch),
        50, hwm=80)

def test_lmdb_train(db, augs, batch, raw_only=True):
    """Build an LMDB-backed training pipeline.

    :param db: path to the LMDB database.
    :param augs: image augmentors applied to the decoded image (component 0).
    :param batch: batch size.
    :param raw_only: if True (default, the previous behaviour), return right after the
        raw read/shuffle/prefetch stage so only LMDB read throughput is measured; if False,
        also decode, augment, batch and prefetch in 40 processes.
    """
    # These were never imported in this notebook; import them where they are needed.
    from tensorpack.dataflow import LMDBData, LocallyShuffleData, MapDataComponent, PrefetchData

    ds = LMDBData(db, shuffle=False)
    ds = LocallyShuffleData(ds, 50000)
    ds = PrefetchData(ds, 5000, 1)
    if raw_only:
        return ds

    # LMDBDataPoint is deprecated in newer tensorpack releases; only needed on this path.
    from tensorpack.dataflow import LMDBDataPoint
    ds = LMDBDataPoint(ds)

    def f(x):
        # decode the stored encoded image bytes into a BGR array
        return cv2.imdecode(x, cv2.IMREAD_COLOR)
    ds = MapDataComponent(ds, f, 0)
    ds = AugmentImageComponent(ds, augs)

    ds = BatchData(ds, batch, use_list=True)
    ds = PrefetchDataZMQ(ds, 40, hwm=80)
    return ds

def test_inference(dir, name, augs, batch=128):
    """Build an ILSVRC12 evaluation pipeline that decodes and augments images in 30 threads.

    :param dir: dataset root passed to ``dataset.ILSVRC12Files``.
    :param name: split name passed to ``dataset.ILSVRC12Files``.
    :param augs: image augmentors applied to each decoded image.
    :param batch: batch size (default 128).
    :return: a DataFlow of batches ``[images, labels]`` prefetched by one ZMQ process.
    """
    ds = dataset.ILSVRC12Files(dir, name, shuffle=False, dir_structure='train')

    aug = imgaug.AugmentorList(augs)

    def mapf(dp):
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        # cv2.imread returns None for unreadable files; report the file instead of
        # failing later inside the augmentor
        if im is None:
            raise IOError("Cannot read image {}".format(fname))
        im = aug.augment(im)
        return im, cls
    # MultiThreadMapData is the imported tensorpack class (ThreadedMapData was undefined here)
    ds = MultiThreadMapData(ds, 30, mapf, buffer_size=2000, strict=True)
    ds = BatchData(ds, batch)
    ds = PrefetchDataZMQ(ds, 1)
    return ds

In [None]:
class MyDataFlow(DataFlow):
    """Yield ``[frame, video_index]`` for every frame of every ``.mp4`` video under ``dir``.

    ``video_index`` is the video's position in the sorted list of video paths, so it stays
    stable when ``shuffle`` reorders the videos.

    :param dir: directory searched recursively for ``.mp4`` files.
    :param shuffle: if True, the video order is shuffled on every pass.
    """

    def __init__(self, dir=None, shuffle=False):
        super(MyDataFlow, self).__init__()
        self.dir = dir
        self.shuffle = shuffle

    @staticmethod
    def _list_videos(dir):
        """Sorted full paths of all ``.mp4`` files below ``dir``."""
        video_paths = []
        for root, _dirs, files in os.walk(dir):
            for file in files:
                if file.endswith('.mp4'):
                    video_paths.append(os.path.join(root, file))
        return sorted(video_paths)

    def __iter__(self, dir=None):
        # `dir` may still be passed explicitly (the old signature); otherwise use the constructor's.
        dir = self.dir if dir is None else dir
        if dir is None:
            raise ValueError("MyDataFlow needs a directory: pass dir= to the constructor")
        video_paths = self._list_videos(dir)
        order = list(range(len(video_paths)))
        if self.shuffle:
            np.random.shuffle(order)
        for video_idx in order:
            cap = cv2.VideoCapture(video_paths[video_idx])
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    yield [frame, video_idx]
            finally:
                cap.release()
    
class ProcessingDataFlow(DataFlow):
    """Wrap a DataFlow and apply ``func`` to each of its datapoints.

    :param ds: the upstream DataFlow.
    :param func: callable ``datapoint -> new datapoint``; ``None`` passes datapoints through
        unchanged.
    """

    def __init__(self, ds, func=None):
        super(ProcessingDataFlow, self).__init__()
        self.ds = ds
        self.func = func

    def reset_state(self):
        # forward the reset to the wrapped flow
        self.ds.reset_state()

    def __iter__(self):
        for datapoint in self.ds:
            yield datapoint if self.func is None else self.func(datapoint)
# Tutorial pipeline, built inside-out:
#   read videos -> resize the image component of each datapoint to 225x225
#   -> group into batches of 128 -> run the whole chain in 3 worker processes.
df = PrefetchDataZMQ(
    BatchData(
        AugmentImageComponent(
            MyDataFlow(dir='/my/data', shuffle=True),
            [imgaug.Resize((225, 225))]),
        128),
    3)