In [1]:
#e
import copy
import math
import os
import queue
import threading

import datasets as hfds

from minai.sampler import chunkify, Sampler, SamplerIter, SIO
from minai.datasets import SimpleDataset

In [2]:
#e
class CMTO: # CollatorMTOpts
    def __init__(self, 
                 sampler_iter: SamplerIter = None, 
                 getitem_func=None, 
                 collate_func=None, 
                 num_workers=os.cpu_count(), 
                 max_available_batches=2, 
                 chunk_size_per_thread=4,
                 is_hf_ds=False):
        
        self.sampler_iter = sampler_iter
        self.getitem_func = getitem_func
        self.collate_func = collate_func
        self.num_workers = num_workers
        self.max_available_batches = max_available_batches
        self.chunk_size_per_thread = chunk_size_per_thread
        self.is_hf_ds = is_hf_ds

        # Extra special flags, gotta be set very manually
        self.COLLATOR_DEBUG = False
        self.WORKERS_DEBUG = False

    def __repr__(self):
        return f"CMTO({self.sampler_iter.opts},\n"\
            f"  getitem_func={self.getitem_func.__qualname__},\n"\
            f"  collate_func={self.collate_func.__qualname__},\n"\
            f"  num_workers={self.num_workers},\n"\
            f"  max_available_batches={self.max_available_batches},\n"\
            f"  chunk_size_per_thread={self.chunk_size_per_thread},\n"\
            f"  is_hf_ds={self.is_hf_ds})"


class CollatorCTX: # Internal
    """State shared by CollatorMT, the collator thread and the worker threads.

    Holds references to the CMTO fields the threads read, plus the queues and the
    batch-permit semaphore they communicate through. Built by CollatorMT.__init__.
    """

    def __init__(self, opts: CMTO, work_chunk_size):
        # Fields read by threadproc_collator / threadproc_worker, taken from the options.
        for name in ("sampler_iter", "getitem_func", "collate_func",
                     "max_available_batches", "is_hf_ds"):
            setattr(self, name, getattr(opts, name))
        self.DEBUG = opts.WORKERS_DEBUG
        self.work_chunk_size = work_chunk_size

        # Work distribution: (chunk_index, indices) jobs go out, (chunk_index, results) come back.
        self.indices_queue, self.results_queue = queue.SimpleQueue(), queue.SimpleQueue()
        self.workers = []

        # Finished batches (and the None end-of-pass sentinel) for the consuming iterator.
        self.collated_batches = queue.SimpleQueue()
        self.exit_requested = False

        # Batch permits, bounded by max_available_batches. Every permit is taken up front so
        # the counter starts at 0 and the collator thread waits until permits are released.
        permits = threading.BoundedSemaphore(self.max_available_batches)
        for _ in range(self.max_available_batches):
            permits.acquire()
        self.request_batch_event = permits


def threadproc_worker(ctx: CollatorCTX):
    """Worker thread loop: fetch the data for one work chunk at a time.

    Takes (work_ind, indices) jobs from ctx.indices_queue until the None sentinel arrives,
    and answers each with (work_ind, results) on ctx.results_queue. With is_hf_ds the
    getitem_func is called once with the whole index list; otherwise once per index, in order.

    NOTE(review): exceptions are not caught here. If getitem_func raises, this thread dies
    and threadproc_collator keeps waiting for the missing chunk result.
    """
    for job in iter(ctx.indices_queue.get, None):
        work_ind, indices = job
        results = (ctx.getitem_func(indices) if ctx.is_hf_ds
                   else [ctx.getitem_func(i) for i in indices])
        ctx.results_queue.put((work_ind, results))
        # Drop this thread's reference before blocking on the next job.
        del results

def _collate_work_chunks(ctx: CollatorCTX, batch):
    """Fan one sampler batch out to the workers and return collate_func's result for it."""
    work_chunks = chunkify(batch, ctx.work_chunk_size)
    if ctx.DEBUG: print("will queue", len(work_chunks))
    for ind, work_chunk in enumerate(work_chunks):
        if ctx.DEBUG: print(ind, "put", len(work_chunk), work_chunk)
        ctx.indices_queue.put((ind, work_chunk))

    # Workers post one (ind, results) pair per chunk in completion order; sorting on the
    # chunk index restores batch order.
    # NOTE(review): if a worker died, this waits forever for its chunk.
    work_chunks_results = []
    for _ in range(len(work_chunks)):
        ind, results = ctx.results_queue.get()
        if ctx.DEBUG: print(ind, "got", len(results), results)
        work_chunks_results.append((ind, results))
    work_chunks_results.sort(key=lambda x: x[0])

    if ctx.is_hf_ds:
        # SOA: one batched lookup per chunk; keep one entry per chunk.
        ordered_results = [results for _, results in work_chunks_results]
    else:
        # AOS: one list item per index; flatten into a single per-sample list.
        ordered_results = [item for _, results in work_chunks_results for item in results]
    del work_chunks_results

    return ctx.collate_func(ordered_results)


def _permit_or_exit(ctx: CollatorCTX):
    """Block for one batch permit; return True if shutdown was requested meanwhile."""
    ctx.request_batch_event.acquire()
    return ctx.exit_requested


def threadproc_collator(ctx: CollatorCTX):
    """Collator thread loop: produce collated batches, one sampler pass per iteration.

    Permit protocol on ctx.request_batch_event (starts at 0):
      * CollatorMT.__iter__ releases max_available_batches permits to start a pass. It
        releases one more each time the consumer comes back for the next batch.
      * This thread takes one permit to start a pass and one after each batch it queues.
        So at most max_available_batches batches are prepared ahead of the consumer.
      * Once the sampler is exhausted it reclaims the remaining max_available_batches - 1
        permits. That means waiting until the consumer has come back after every batch.
        Then it queues the None end-of-pass sentinel, and the counter is 0 again.
    CollatorMT.__del__ sets ctx.exit_requested and releases one permit. The flag is checked
    after every acquire, so the thread exits wherever it was waiting.
    """
    while True:
        if ctx.DEBUG: print("batches top")
        if _permit_or_exit(ctx): break

        if ctx.DEBUG: print("batches start")
        for batch in ctx.sampler_iter:
            ctx.collated_batches.put(_collate_work_chunks(ctx, batch))
            if _permit_or_exit(ctx): break
        if ctx.exit_requested: break

        if ctx.DEBUG: print("batches done")
        # Reclaim the outstanding permits, checking for shutdown after each one. __del__ only
        # releases a single permit, so without the check this could block forever.
        # any() stops acquiring at the first permit that reports an exit request.
        if any(_permit_or_exit(ctx) for _ in range(ctx.max_available_batches - 1)): break
        ctx.collated_batches.put(None)

    if ctx.DEBUG: print("collator exit")


class CollatorMT:
    """Multi-threaded batch producer: one collator thread plus `num_workers` worker threads.

    The collator thread walks `sampler_iter` and splits each batch into work chunks for the
    workers. It reassembles their results in order and passes them to `collate_func`.
    Iterating a CollatorMT yields those collated batches for one pass over the sampler.
    """

    def __init__(self, collatormt_opts: CMTO):
        self.DEBUG = collatormt_opts.COLLATOR_DEBUG
        # Private shallow copy. num_workers may be lowered below, and the caller's CMTO may
        # be shared with other collators. Mutating it would change their worker count, and
        # with it how many sentinels their __del__ queues.
        self.opts = copy.copy(collatormt_opts)

        batch_size = self.opts.sampler_iter.opts.batch_size
        chunk_size_per_thread = self.opts.chunk_size_per_thread
        num_workers = self.opts.num_workers
        # Chunks hold at least chunk_size_per_thread indices. For large batches they grow so
        # that each worker receives roughly chunk_size_per_thread chunks per batch.
        work_chunk_size = max(chunk_size_per_thread,
                              batch_size // (num_workers * chunk_size_per_thread))

        # Never start more workers than there are chunks in one batch.
        new_num_workers = min(max(1, math.ceil(batch_size / work_chunk_size)), num_workers)
        if new_num_workers != num_workers:
            print(f"Number of workers reduced from {num_workers} to {new_num_workers}, since "\
                  f"num_workers*work_chunk_size > batch_size ({num_workers}*{work_chunk_size} > {batch_size})")
            self.opts.num_workers = new_num_workers

        self.ctx = CollatorCTX(self.opts, work_chunk_size)

        # Daemon threads: they only stop when __del__ runs. __del__ may never run before
        # interpreter shutdown, for example when the loader is still referenced from module
        # globals. The interpreter waits for non-daemon threads at exit, so it would hang.
        threading.Thread(target=threadproc_collator, args=(self.ctx,), daemon=True).start()
        for _ in range(self.opts.num_workers):
            threading.Thread(target=threadproc_worker, args=(self.ctx,), daemon=True).start()

    def __del__(self):
        """Ask the collator and worker threads to exit (best effort, on finalization)."""
        ctx = getattr(self, "ctx", None)
        if ctx is None:
            return  # __init__ failed before the context (and any thread) was created
        ctx.exit_requested = True
        try:
            # Hand the collator a permit so that, if it is waiting for one, it wakes up and
            # sees exit_requested.
            ctx.request_batch_event.release()
        except ValueError:
            # BoundedSemaphore already at its bound: normally no thread is blocked on it.
            # Keep going so the worker sentinels below are still queued.
            pass
        # One None sentinel per worker thread started in __init__.
        for _ in range(self.opts.num_workers):
            ctx.indices_queue.put(None)

    def __iter__(self):
        """Yield collated batches for one pass over sampler_iter.

        Releases max_available_batches permits to start the pass, and one more each time
        the caller asks for the next batch. So the collator thread stays at most
        max_available_batches batches ahead. Stops at the None end-of-pass sentinel.

        NOTE(review): if the caller stops early (break, or next(iter(...)) for one batch),
        the collator thread is left mid-pass. The next __iter__ call continues that pass,
        starting with any batches already queued from it.
        """
        self.ctx.request_batch_event.release(self.ctx.max_available_batches)

        while True:
            if self.DEBUG: print("-> iter request")
            collated = self.ctx.collated_batches.get()
            if collated is None:
                if self.DEBUG: print("-> iter done")
                break

            if self.DEBUG: print("-> iter got")

            yield collated
            # The caller came back for more: let the collator prepare one more batch.
            self.ctx.request_batch_event.release()

In [3]:
#e
def simple_collate_func(results):
    """Collate (x, y, ...) items into a tuple of two parallel lists: (xs, ys)."""
    def column(k):
        # The k-th field of every result, in order.
        return [item[k] for item in results]

    return column(0), column(1)

In [4]:
import time

def ds_getitem(i):
    """Demo getitem: sleep 0.1 s to mimic slow loading, then return the pair (i, i**2)."""
    time.sleep(0.1)
    square = i ** 2
    return i, square

sampler = Sampler(14)
batch_size = 13
collator = CollatorMT(CMTO(
    sampler.iter(SIO(batch_size)),
    ds_getitem,
    simple_collate_func,
    max_available_batches=2,
))

# Run two passes over the same collator to check that it can be iterated again.
for pass_idx in range(2):
    if pass_idx == 1:
        print("agane")
    for collated in collator:
        print("--------", collated)

Number of workers reduced from 15 to 4, since num_workers*work_chunk_size > batch_size (15*4 > 13)
-------- ([10, 4, 9, 12, 11, 1, 8, 0, 3, 13, 2, 6, 5], [100, 16, 81, 144, 121, 1, 64, 0, 9, 169, 4, 36, 25])
-------- ([7, 2, 1, 3, 0, 4, 11, 2, 13, 13, 11, 13, 0], [49, 4, 1, 9, 0, 16, 121, 4, 169, 169, 121, 169, 0])
agane
-------- ([1, 7, 10, 4, 8, 12, 13, 11, 2, 3, 5, 9, 0], [1, 49, 100, 16, 64, 144, 169, 121, 4, 9, 25, 81, 0])
-------- ([6, 9, 5, 4, 7, 6, 10, 0, 2, 4, 0, 12, 13], [36, 81, 25, 16, 49, 36, 100, 0, 4, 16, 0, 144, 169])


In [5]:
#e
def hf_first_ds(dsd):
    """Return the first split (in iteration order) of a DatasetDict-like mapping.

    Raises:
        ValueError: if `dsd` has no splits. Previously a bare StopIteration escaped, which
            is confusing and turns into RuntimeError inside generators.
    """
    for ds in dsd.values():
        return ds
    raise ValueError("hf_first_ds: got an empty DatasetDict (no splits)")

class HFCollate:
    """Collate function for work-chunk results fetched from a Hugging Face dataset.

    __call__ receives a list of results. Each result is a mapping from feature name to a
    list of values, as returned by a batched lookup such as ds[[0, 1]]. The results are
    concatenated column by column into one list per feature, in ds.features order.
    """
    def __init__(self, ds: "hfds.Dataset"):
        # A DatasetDict (or subclass) is reduced to its first split, which supplies the feature names.
        if isinstance(ds, hfds.DatasetDict): ds = hf_first_ds(ds)
        # Snapshot the feature names so the column order is fixed at construction.
        self.features = list(ds.features.keys())
        self.features_len = len(self.features)

    def __call__(self, results):
        """Return [[values of feature 0...], [values of feature 1...], ...] for `results`."""
        collated = [[] for _ in range(self.features_len)]
        for result in results:
            for column, feature in zip(collated, self.features):
                column.extend(result[feature])
        return collated

    def __repr__(self):
        # Report the actual class name.
        return f"{type(self).__name__}(features={list(self.features)})"

In [6]:
import minai.datasets as minds
# Load Fashion-MNIST through the project's HF helper. The result is indexed by split name
# (fashion_mnist["train"]) in the HFCollate and DataLoader.hf demos below.
fashion_mnist = minds.hf_load(minds.HF_DATASETS.FASHION_MNIST)

Found cached dataset fashion_mnist (/home/nblzv/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)


  0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Smoke-test HFCollate. Passing the whole DatasetDict makes it read the feature names from
# its first split. Two batched lookups (2 rows + 1 row) are then merged into one list per feature.
c = HFCollate(fashion_mnist)
print(c)
print(c([fashion_mnist["train"][[0, 1]], fashion_mnist["train"][[2]]]))

CollateHF(features=['image', 'label'])
[[<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7FCF8BA269D0>, <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7FCF8BAF1ED0>, <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7FCF88C30D50>], [9, 0, 0]]


In [8]:
#e
class DataLoader:
    """A dataset paired with a CollatorMT that loads and collates its batches on background threads.

    Iterating the DataLoader yields the collator's batches for one pass over the sampler.
    """
    def __init__(self, dataset, collatormt_opts:CMTO):
        self.dataset = dataset
        self.collator = CollatorMT(collatormt_opts)

    @staticmethod
    def _own_opts(collatormt_opts):
        """Return a CMTO this loader may fill in: a shallow copy of the caller's, or a new default.

        The factories below overwrite sampler_iter/getitem_func/collate_func/is_hf_ds.
        Copying keeps them from clobbering an options object the caller may reuse for
        another loader.
        """
        return CMTO() if collatormt_opts is None else copy.copy(collatormt_opts)

    @classmethod
    def simple(cls, simple_ds: SimpleDataset,
               sampler_iter_opts: SIO = None,
               collatormt_opts: CMTO = None):
        """Build a loader over a SimpleDataset: one __getitem__ call per index, collated by simple_collate_func."""
        sampler_iter_opts = sampler_iter_opts or SIO()
        opts = cls._own_opts(collatormt_opts)

        opts.sampler_iter = Sampler(len(simple_ds)).iter(sampler_iter_opts)
        opts.getitem_func = simple_ds.__getitem__
        opts.collate_func = simple_collate_func
        return cls(simple_ds, opts)

    @classmethod
    def hf(cls, hf_ds: hfds.Dataset,
           sampler_iter_opts: SIO = None,
           collatormt_opts: CMTO = None):
        """Build a loader over one Hugging Face Dataset split: one batched lookup per work chunk, collated by HFCollate."""
        assert type(hf_ds) is hfds.Dataset, "Dataset expected (not DatasetDict)"
        sampler_iter_opts = sampler_iter_opts or SIO()
        opts = cls._own_opts(collatormt_opts)

        opts.sampler_iter = Sampler(len(hf_ds)).iter(sampler_iter_opts)
        opts.getitem_func = hf_ds.__getitem__
        opts.collate_func = HFCollate(hf_ds).__call__
        opts.is_hf_ds = True
        return cls(hf_ds, opts)

    def __iter__(self):
        yield from self.collator

    def __repr__(self):
        # Indent the multi-line CMTO repr one level so it nests under DataLoader(...).
        opts_repr = repr(self.collator.opts).replace("\n", "\n  ")

        return f"DataLoader(ds={self.dataset},\n  {opts_repr})"

In [9]:
# 100 (x, y) pairs with y == -x, served in batches of 16 without shuffling.
ds = SimpleDataset(list(range(100)), [-x for x in range(100)])
dl = DataLoader.simple(ds, SIO(16, False))
(ds, dl)

Number of workers reduced from 15 to 4, since num_workers*work_chunk_size > batch_size (15*4 > 16)


(SimpleDataset(len=100, xs=int, ys=int),
 DataLoader(ds=SimpleDataset(len=100, xs=int, ys=int),
   CMTO(SIO(batch_size=16, shuffle=False, drop_last=False),
     getitem_func=SimpleDataset.__getitem__,
     collate_func=simple_collate_func,
     num_workers=4,
     max_available_batches=2,
     chunk_size_per_thread=4,
     is_hf_ds=False)))

In [10]:
# Two full passes over the same loader, to check that it can be iterated again.
for _ in range(2):
    for xs, ys in dl:
        print(xs, ys)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] [0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15]
[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] [-16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
[32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] [-32, -33, -34, -35, -36, -37, -38, -39, -40, -41, -42, -43, -44, -45, -46, -47]
[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] [-48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63]
[64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] [-64, -65, -66, -67, -68, -69, -70, -71, -72, -73, -74, -75, -76, -77, -78, -79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95] [-80, -81, -82, -83, -84, -85, -86, -87, -88, -89, -90, -91, -92, -93, -94, -95]
[96, 97, 98, 99, 17, 92, 70, 38, 9, 60, 82, 68, 44, 61, 9, 42] [-96, -97, -98, -99, -17, -92, -70, -38, -9, -60, -82, -68, -44, -61, -9, -42]
[0,

In [11]:
# Batches of 9 without shuffling, with at most one batch prepared ahead (max_available_batches=1).
# NOTE(review): only the first batch is pulled and the generator is abandoned, so the collator
# stays mid-pass. A later iteration over this `dl` would continue that pass instead of starting fresh.
dl = DataLoader.hf(fashion_mnist["train"], SIO(9, False), CMTO(max_available_batches=1))
dl, next(iter(dl.collator))

Number of workers reduced from 15 to 3, since num_workers*work_chunk_size > batch_size (15*4 > 9)


(DataLoader(ds=Dataset({
     features: ['image', 'label'],
     num_rows: 60000
 }),
   CMTO(SIO(batch_size=9, shuffle=False, drop_last=False),
     getitem_func=Dataset.__getitem__,
     collate_func=HFCollate.__call__,
     num_workers=3,
     max_available_batches=1,
     chunk_size_per_thread=4,
     is_hf_ds=True)),
 [[<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
   <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>],
  [9, 0, 0, 3, 0, 2, 7, 2, 5]])

In [12]:
import z_export
# Export the notebooks to the minai package. For this notebook that is presumably the five
# cells tagged `#e` (the log below reports "5 cells exported" for data.ipynb).
z_export.export()

Processing minai_nbs/datasets.ipynb -> minai/minai/datasets.py  |  same contents, skipping, took 0.001s
Processing minai_nbs/sampler.ipynb -> minai/minai/sampler.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/setup+template.py -> minai/setup.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/__init__+template.py -> minai/minai/__init__.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/plot.ipynb -> minai/minai/plot.py  |  nothing to export, took 0.000s
Processing minai_nbs/mintils.py -> minai/minai/mintils.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/data.ipynb -> minai/minai/data.py  |  5 cells exported, took 0.000s 

All done... took 0.002s
  lib_name: minai
  author: nblzv
  version: 0.1.1
