In [1]:
import multiprocessing
import numpy as np
import os
from abc import abstractmethod
import cv2
import tensorflow as tf
import tqdm

In [2]:
import tensorpack

In [5]:
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.stats import RatioCounter

### 1. A minimum example

In [31]:
from tensorpack import DataFlow, DataFromGenerator
from tensorpack.dataflow.parallel import PlasmaGetData, PlasmaPutData  # noqa
def my_data_loader():
  # load data from somewhere with Python, and yield them
  for k in range(100):
    yield [my_array, my_label]

df = DataFromGenerator(my_data_loader)

In [33]:
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
class MyDataFlow(DataFlow):
  def __iter__(self):
    # load data from somewhere with Python, and yield them
#     ds = PrintData(ds, num=2, max_list=2)
 
    for k in range(10):
      digit = np.random.rand(2, 2)
      label = np.random.randint(10)
      yield [digit, label]
      
df = MyDataFlow()
df = BatchData(df, 3)
# df = PlasmaGetData(df)
df.reset_state()
for datapoint in df:
    
    print("")
    print(datapoint[0][0], datapoint[1][0])
# print(df[0])


[[0.56543431 0.46338831]
 [0.28324179 0.32366504]] 7

[[0.4780968  0.03412478]
 [0.96600137 0.96960545]] 9

[[0.35308466 0.12336586]
 [0.4294453  0.78506172]] 0


### 2. A higher level demo

In [10]:
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
class MyDataFlow(DataFlow):
  def __iter__(self):
    # load data from somewhere with Python, and yield them
#     ds = PrintData(ds, num=2, max_list=2)
    for k in range(10):
      digit = np.random.rand(2, 2)
      label = np.random.randint(10)
      yield [digit, label]
      
df = MyDataFlow()
df = BatchData(df, 3)
# df = PlasmaGetData(df)
df.reset_state()
for datapoint in df:
    
    print("")
    print(datapoint[0][0], datapoint[1][0])

In [11]:
def test_orig(dir, name, augs, batch):
    ds = dataset.ILSVRC12(dir, name, shuffle=True)
    ds = AugmentImageComponent(ds, augs)

    ds = BatchData(ds, batch)
    # ds = PlasmaPutData(ds)
    ds = PrefetchDataZMQ(ds, 50, hwm=80)
    return ds

def test_lmdb_train(db, augs, batch):
    ds = LMDBData(db, shuffle=False)
    ds = LocallyShuffleData(ds, 50000)
    ds = PrefetchData(ds, 5000, 1)
    return ds

    ds = LMDBDataPoint(ds)

    def f(x):
        return cv2.imdecode(x, cv2.IMREAD_COLOR)
    ds = MapDataComponent(ds, f, 0)
    ds = AugmentImageComponent(ds, augs)

    ds = BatchData(ds, batch, use_list=True)
    # ds = PlasmaPutData(ds)
    ds = PrefetchDataZMQ(ds, 40, hwm=80)
    # ds = PlasmaGetData(ds)
    return ds

def test_inference(dir, name, augs, batch=128):
    ds = dataset.ILSVRC12Files(dir, name, shuffle=False, dir_structure='train')

    aug = imgaug.AugmentorList(augs)

    def mapf(dp):
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        im = aug.augment(im)
        return im, cls
    ds = ThreadedMapData(ds, 30, mapf, buffer_size=2000, strict=True)
    ds = BatchData(ds, batch)
    ds = PrefetchDataZMQ(ds, 1)
    return ds

In [None]:
class MyDataFlow(DataFlow):
  def __iter__(self, dir):
    # load data from somewhere with Python, and yield them
    video_set = []
    for root, dirs, files in os.walk(dir):
        for file in files:
            if file.endswith('.mp4'):
                video_set.append(os.path.splitext(file)[0])
                
      yield [digit, label]
    
class ProcessingDataFlow(DataFlow):
  def __init__(self, ds):
    self.ds = ds
    
  def reset_state(self):
    self.ds.reset_state()

  def __iter__(self):
    for datapoint in self.ds:
      # do something
      yield new_datapoint
# df = MyDataFlow()
# df.reset_state()
# for datapoint in df:
#     print(datapoint[0], datapoint[1]
df = MyDataFlow(dir='/my/data', shuffle=True)
# resize the image component of each datapoint
df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3)