In [1]:
import multiprocessing
import numpy as np
import os
from abc import abstractmethod
import cv2
import tensorflow as tf
import tqdm

In [2]:
import tensorpack

In [5]:
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.stats import RatioCounter

In [8]:
from tensorpack import DataFlow, DataFromGenerator
def my_data_loader():
  """Yield 100 placeholder ``[array, label]`` datapoints.

  The template version yielded the names ``my_array`` and ``my_label``
  without ever defining them, so the first iteration raised ``NameError``.
  Here each datapoint is generated on the fly: a 28x28 array of uniform
  random floats and a random integer label in ``[0, 10)``. Replace the two
  assignments with real loading code when adapting this template.

  Yields:
    list: ``[my_array, my_label]`` for each of the 100 datapoints.
  """
  # load data from somewhere with Python, and yield them
  for _ in range(100):
    my_array = np.random.rand(28, 28)
    my_label = np.random.randint(10)
    yield [my_array, my_label]

# Build a dataflow from the generator function above using tensorpack's
# DataFromGenerator; note the function itself is passed, not a generator object.
df = DataFromGenerator(my_data_loader)

In [9]:
class MyDataFlow(DataFlow):
  """Toy dataflow that produces 100 random ``[digit, label]`` datapoints."""

  def __iter__(self):
    """Yield ``[digit, label]`` lists.

    ``digit`` is a 28x28 array of uniform random floats and ``label`` is a
    random integer in ``[0, 10)``; both are drawn fresh for every datapoint.
    """
    # Stand-in for real loading code: the list literal draws the array
    # first and the label second, once per datapoint.
    for _ in range(100):
      yield [np.random.rand(28, 28), np.random.randint(10)]
      
# Instantiate the toy dataflow, reset its state, then print every datapoint.
df = MyDataFlow()
df.reset_state()
for digit, label in df:
    # Each datapoint is a two-element list: the array and its label.
    print(digit, label)

[[0.25598935 0.6957873  0.50170093 0.50373152 0.28187876 0.57597462
  0.64973292 0.99044841 0.25694408 0.24305311 0.21153445 0.35572204
  0.4397992  0.15008425 0.61202061 0.46300748 0.87588557 0.94524365
  0.46456376 0.9322684  0.55532813 0.49999921 0.5437869  0.55556233
  0.83654931 0.53840303 0.03508824 0.84452918]
 [0.99070756 0.52694362 0.73845929 0.4743786  0.03490928 0.52936005
  0.62534422 0.80830235 0.0784521  0.4686222  0.77583546 0.87180126
  0.31989658 0.93975098 0.59906338 0.72763588 0.15165074 0.51004738
  0.38441431 0.56077041 0.96396749 0.40177079 0.57099798 0.86629605
  0.31892959 0.19044568 0.82471358 0.84826746]
 [0.55495838 0.52308657 0.8497032  0.77200587 0.38275811 0.95118674
  0.37162102 0.06025261 0.86303291 0.87173111 0.90270528 0.42671068
  0.36745239 0.22259063 0.95190169 0.54362313 0.86139349 0.19563167
  0.90271021 0.2400999  0.32233287 0.71170456 0.04908137 0.82416976
  0.38870785 0.73860865 0.74648331 0.41747826]
 [0.09541808 0.32450299 0.59335447 0.091128

[[3.66083376e-02 2.63140354e-01 3.84287821e-01 9.63140236e-01
  8.28928577e-01 7.46134203e-01 1.23626219e-02 5.87892264e-01
  6.13462742e-01 4.85601578e-01 6.71125374e-01 1.01500262e-01
  4.64560078e-01 1.23200596e-01 9.77338205e-02 8.22096709e-02
  5.30703427e-01 9.47002170e-01 8.70670676e-01 1.86173308e-01
  1.99532239e-01 8.68835509e-01 5.07067259e-01 9.70506212e-01
  7.32611298e-01 7.24792446e-02 6.44015398e-01 7.14813804e-01]
 [7.41717201e-01 7.39329326e-01 7.64450570e-01 9.99912230e-01
  3.94050563e-01 6.18915214e-01 7.29324651e-01 3.50824906e-02
  7.32742905e-01 2.43681819e-02 3.78780591e-02 2.56545212e-01
  2.88619226e-01 6.41201787e-01 1.68009053e-01 4.59966074e-01
  4.87723048e-01 6.53382201e-01 4.08632199e-01 4.50917344e-02
  3.76311472e-01 6.10274561e-01 5.43108005e-01 8.57335988e-01
  5.36767912e-01 7.48644570e-01 9.62412329e-01 7.12886245e-01]
 [2.13596766e-01 5.43565317e-01 6.27964742e-01 6.78097514e-01
  8.20427353e-01 5.37992621e-01 2.42426666e-01 6.03520115e-02
  9.54

[[0.07367586 0.0393507  0.5550371  0.78139712 0.92849986 0.88224426
  0.92939809 0.4501708  0.96995701 0.35530977 0.87795074 0.81006257
  0.22384263 0.34801472 0.38557924 0.34426966 0.22561822 0.19000067
  0.79261418 0.203711   0.43195973 0.74305648 0.14798287 0.32891698
  0.8127112  0.27228682 0.91409972 0.74811365]
 [0.15309768 0.48855679 0.0611863  0.78770835 0.68024252 0.31700275
  0.52265086 0.87987435 0.62487796 0.70120394 0.50184137 0.77579039
  0.40743128 0.52417993 0.44147858 0.09169874 0.1811569  0.07273898
  0.15493011 0.93888351 0.76516051 0.24985691 0.53976917 0.69997916
  0.02480412 0.79934963 0.46090576 0.53129426]
 [0.26384802 0.36064661 0.77437704 0.18216468 0.55870697 0.72993145
  0.97523971 0.19674241 0.03498223 0.82597092 0.85582212 0.56960306
  0.47443231 0.79812491 0.80608103 0.41390146 0.89007419 0.6958358
  0.68491275 0.63759462 0.61459025 0.25671514 0.00534223 0.47362908
  0.14457146 0.18744173 0.78245206 0.2957234 ]
 [0.99666039 0.64466548 0.59864339 0.3973086

In [10]:
class ProcessingDataFlow(DataFlow):
  """Template for a dataflow that post-processes another dataflow.

  Wraps an upstream dataflow ``ds`` and re-yields its datapoints after an
  (for now identity) transformation. The template previously yielded the
  undefined name ``new_datapoint`` and raised ``NameError`` on iteration;
  it now passes each datapoint through unchanged until real processing
  logic is filled in.
  """

  def __init__(self, ds):
    """Store the upstream dataflow.

    Args:
      ds: the dataflow whose datapoints will be processed.
    """
    self.ds = ds

  def reset_state(self):
    """Forward the state reset to the wrapped dataflow."""
    self.ds.reset_state()

  def __iter__(self):
    """Iterate over the wrapped dataflow, yielding processed datapoints."""
    for datapoint in self.ds:
      # do something -- placeholder: pass the datapoint through unchanged
      new_datapoint = datapoint
      yield new_datapoint

In [11]:
def test_orig(dir, name, augs, batch):
    """Build the ILSVRC12 input pipeline: read, augment, batch, prefetch.

    Args:
        dir: first argument forwarded to ``dataset.ILSVRC12``.
        name: second argument forwarded to ``dataset.ILSVRC12``.
        augs: augmentors passed to ``AugmentImageComponent``.
        batch: batch size passed to ``BatchData``.

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the batched pipeline.
    """
    shuffled_images = dataset.ILSVRC12(dir, name, shuffle=True)
    augmented = AugmentImageComponent(shuffled_images, augs)
    batched = BatchData(augmented, batch)
    # A PlasmaPutData stage was left disabled at this point in the pipeline.
    # Run the batched dataflow in 50 parallel processes (hwm=80).
    return PrefetchDataZMQ(batched, 50, hwm=80)

def test_lmdb_train(db, augs, batch):
    """Build the LMDB reading pipeline: read, locally shuffle, prefetch.

    The function has always returned right after the ``PrefetchData`` stage;
    the decode / augment / batch / ZMQ-prefetch statements that followed that
    early ``return`` were unreachable and have been removed. The returned
    dataflow is therefore unchanged.

    Args:
        db: LMDB database argument forwarded to ``LMDBData``.
        augs: unused; kept so existing callers keep working.
        batch: unused; kept so existing callers keep working.

    Returns:
        The ``PrefetchData`` dataflow wrapping the locally shuffled LMDB reader.
    """
    # Imported locally because the notebook's import cells do not bring
    # these names into scope (calling the function raised NameError).
    from tensorpack.dataflow import LMDBData, LocallyShuffleData, PrefetchData

    ds = LMDBData(db, shuffle=False)
    # NOTE(review): 50000 is presumably the shuffle buffer size -- confirm
    # against the tensorpack documentation.
    ds = LocallyShuffleData(ds, 50000)
    # NOTE(review): presumably (prefetch queue size, number of processes)
    # -- confirm against the tensorpack documentation.
    ds = PrefetchData(ds, 5000, 1)
    return ds

def test_inference(dir, name, augs, batch=128):
    """Build the inference pipeline: list files, read + augment, batch, prefetch.

    ``ThreadedMapData`` was never imported in this notebook, so calling the
    function raised ``NameError``. It now uses ``MultiThreadMapData``, which
    the import cell already provides and which is called with the same
    arguments.

    Args:
        dir: first argument forwarded to ``dataset.ILSVRC12Files``.
        name: second argument forwarded to ``dataset.ILSVRC12Files``.
        augs: augmentors combined into one ``imgaug.AugmentorList``.
        batch: batch size passed to ``BatchData`` (default 128).

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the batched pipeline.
    """
    ds = dataset.ILSVRC12Files(dir, name, shuffle=False, dir_structure='train')

    aug = imgaug.AugmentorList(augs)

    def mapf(dp):
        """Read the image named by the datapoint and augment it."""
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        im = aug.augment(im)
        return im, cls
    # Apply mapf through MultiThreadMapData (second argument 30, buffer of 2000).
    ds = MultiThreadMapData(ds, 30, mapf, buffer_size=2000, strict=True)
    ds = BatchData(ds, batch)
    ds = PrefetchDataZMQ(ds, 1)
    return ds

In [None]:
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
class MyDataFlow(DataFlow):
  """Example dataflow yielding 100 randomly generated ``[digit, label]`` lists."""

  def __iter__(self):
    """Yield one ``[digit, label]`` list per datapoint.

    ``digit`` is a 28x28 array of uniform random floats; ``label`` is a
    random integer in ``[0, 10)``.
    """
    # Placeholder for real loading code; the array is drawn before the
    # label on every iteration.
    for _ in range(100):
      yield [np.random.rand(28, 28), np.random.randint(10)]
# MyDataFlow defines no __init__, so it must be constructed without
# arguments; passing dir=/shuffle= raised TypeError.
df = MyDataFlow()
# resize the image component of each datapoint
df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3)