In [1]:
#e
from minai.sampler import chunkify, Sampler, SamplerIter

import threading
import queue
import os
import math

In [2]:
#e
class CollatorCTX:
    """State shared between the collator thread, the worker threads and the consumer.

    Bundles the batch source, the user callbacks, the queues the threads
    communicate over, and the semaphore used to request more batches.
    """

    def __init__(self, sampler_iter, work_chunk_size, getitem_func, collate_func, max_available_batches):
        # Flags read by the threads.
        self.DEBUG = 0
        self.exit_requested = False

        # Batch source plus the callbacks applied to it.
        self.sampler_iter = sampler_iter
        self.getitem_func = getitem_func
        self.collate_func = collate_func

        # Sizing parameters.
        self.work_chunk_size = work_chunk_size
        self.max_available_batches = max_available_batches

        # Worker plumbing: index chunks go out on one queue, fetched items
        # come back on the other.
        self.workers = []
        self.indices_queue = queue.SimpleQueue()
        self.results_queue = queue.SimpleQueue()

        # Collated batches waiting to be picked up.
        self.collated_batches = queue.SimpleQueue()

        # The bound is max_available_batches, but every permit is taken right
        # away so the semaphore starts at zero: a waiter blocks until someone
        # releases it.
        permits = threading.BoundedSemaphore(max_available_batches)
        for _ in range(max_available_batches):
            permits.acquire()
        self.request_batch_event = permits
        

def threadproc_worker(ctx: CollatorCTX):
    """Worker thread body: fetch items for queued chunks of indices.

    Takes ``(chunk_no, indices)`` jobs from ``ctx.indices_queue``, calls
    ``ctx.getitem_func`` on each index and posts ``(chunk_no, items)`` to
    ``ctx.results_queue``. A ``None`` job ends the thread.

    An exception raised by ``ctx.getitem_func`` is not caught here: it ends
    the thread before a result for that chunk is posted.
    """
    while (job := ctx.indices_queue.get()) is not None:
        chunk_no, chunk_indices = job

        fetched = [ctx.getitem_func(index) for index in chunk_indices]
        ctx.results_queue.put((chunk_no, fetched))

        # Drop our reference before blocking on the next get().
        del fetched

def threadproc_collator(ctx: CollatorCTX):
    """Collator thread body: turn sampler batches into collated batches.

    Each pass waits on ``ctx.request_batch_event`` and then walks
    ``ctx.sampler_iter``. Every batch is split with
    ``chunkify(batch, ctx.work_chunk_size)``; the chunks are handed out through
    ``ctx.indices_queue``, the results are put back in chunk order, passed to
    ``ctx.collate_func`` and queued on ``ctx.collated_batches``. One permit is
    taken after every queued batch, so the thread stalls once permits stop
    being released. When the sampler is exhausted the remaining permits are
    taken, ``None`` is queued as the end-of-pass marker, and the thread waits
    for the next pass.

    ``ctx.exit_requested`` is re-checked after every permit wait, not only at
    the top of a pass. A shutdown requested while a pass is in progress
    therefore ends the thread instead of letting it queue more work and block.

    NOTE(review): one result is awaited per queued chunk, without a timeout.
    If a worker raises inside ``getitem_func`` it never posts its result and
    this thread would keep waiting — confirm whether that needs handling.
    """
    while 1:
        if ctx.DEBUG: print("batches top")
        ctx.request_batch_event.acquire()
        if ctx.exit_requested: break
        if ctx.DEBUG: print("batches start")

        for batch in ctx.sampler_iter:
            # Fan the batch out to the workers, tagging each chunk with its
            # position so the original order can be restored afterwards.
            work_chunks = chunkify(batch, ctx.work_chunk_size)
            if ctx.DEBUG: print("will queue", len(work_chunks))
            for ind, work_chunk in enumerate(work_chunks):
                if ctx.DEBUG: print("put", len(work_chunk), work_chunk)
                ctx.indices_queue.put((ind, work_chunk))

            # Results arrive in completion order; collect exactly one per chunk.
            work_chunks_results = []
            for _ in range(len(work_chunks)):
                ind, results = ctx.results_queue.get()
                if ctx.DEBUG: print("got", len(results), results)
                work_chunks_results.append((ind, results))

            # Restore chunk order and flatten into one list of items.
            sorted_work_chunks_results = []
            work_chunks_results.sort(key=lambda x: x[0])
            for work_chunk_result in work_chunks_results:
                sorted_work_chunks_results.extend(work_chunk_result[1])

            if ctx.DEBUG: print("collating", len(sorted_work_chunks_results), sorted_work_chunks_results)
            ctx.collated_batches.put(ctx.collate_func(sorted_work_chunks_results))

            # Wait for permission to prepare another batch. The wake-up may
            # instead be a shutdown request, so check before doing more work.
            ctx.request_batch_event.acquire()
            if ctx.exit_requested: return

        if ctx.DEBUG: print("batches done")
        # Take the permits still outstanding so the semaphore is back at zero
        # before the end-of-pass marker is published.
        for _ in range(ctx.max_available_batches-1):
            ctx.request_batch_event.acquire()
            if ctx.exit_requested: return
        ctx.collated_batches.put(None)


class CollatorMT:
    """Multi-threaded collator: iterating it yields collated batches.

    Starts one collator thread and ``num_workers`` worker threads that share a
    ``CollatorCTX``. The workers call ``getitem_func`` on sampler indices; the
    collator thread assembles their results and calls ``collate_func``.

    Parameters:
        sampler_iter: batch source; must expose ``batch_size`` and is iterated
            by the collator thread.
        getitem_func: called with one index, returns one item.
        collate_func: called with the list of items of one batch.
        num_workers: requested number of worker threads. ``None`` falls back
            to ``os.cpu_count()`` (or 1 if that is unknown); values below 1
            are raised to 1.
        max_available_batches: how many batches may be prepared ahead of the
            consumer.
        chunk_size_per_thread: divisor used when deriving the per-job chunk
            size from the batch size.
    """

    def __init__(self, sampler_iter: SamplerIter, getitem_func, collate_func, *, num_workers=os.cpu_count(), max_available_batches=2, chunk_size_per_thread=4):
        self.DEBUG = 0

        # os.cpu_count() returns None when the count cannot be determined, and
        # a worker count below 1 would either divide by zero below or start no
        # workers at all.
        if num_workers is None: num_workers = os.cpu_count() or 1
        num_workers = max(1, num_workers)

        work_chunk_size = max(4, sampler_iter.batch_size // (num_workers * chunk_size_per_thread))
        if work_chunk_size * num_workers > sampler_iter.batch_size:
            # More workers than one batch can keep busy: keep only as many as
            # there are chunks in a batch.
            new_num_workers = max(1, math.ceil(sampler_iter.batch_size / work_chunk_size))
            if new_num_workers != num_workers:
                print(f"Number of workers reduced from {num_workers} to {new_num_workers}, since num_workers*work_chunk_size > batch_size ({num_workers}*{work_chunk_size} > {sampler_iter.batch_size})")
            num_workers = new_num_workers

        self.num_workers = num_workers
        self.ctx = CollatorCTX(sampler_iter, work_chunk_size, getitem_func, collate_func, max_available_batches)

        # Daemon threads: while idle they block forever on a semaphore/queue,
        # and non-daemon threads are joined at interpreter exit before
        # __del__ of a still-referenced instance could stop them.
        threading.Thread(target=threadproc_collator, args=(self.ctx,), daemon=True).start()
        for _ in range(num_workers): threading.Thread(target=threadproc_worker, args=(self.ctx,), daemon=True).start()

    def __del__(self):
        """Ask the collator thread and every worker thread to stop."""
        # __init__ may have raised before the context existed.
        ctx = getattr(self, "ctx", None)
        if ctx is None: return

        ctx.exit_requested = True
        # Wake the collator thread so it sees the flag.
        ctx.request_batch_event.release()

        # One None per worker ends its loop.
        for _ in range(self.num_workers): ctx.indices_queue.put(None)

    def __iter__(self):
        """Yield collated batches until the collator thread queues ``None``.

        Releases ``max_available_batches`` permits to start a pass, then one
        more permit each time the consumer comes back for the next batch.

        NOTE(review): the per-batch permit is only released when the generator
        is resumed, so abandoning the loop early leaves the pass unfinished;
        a later iteration then appears to pick up mid-pass — confirm intended.
        """
        self.ctx.request_batch_event.release(self.ctx.max_available_batches)
        
        while 1:
            if self.DEBUG: print("-> iter request")
            collated = self.ctx.collated_batches.get()
            if collated is None: 
                if self.DEBUG: print("-> iter done")
                break

            if self.DEBUG: print("-> iter got")

            yield collated
            self.ctx.request_batch_event.release()
            

In [4]:
import time

def ds_getitem(i):
    """Toy item lookup: sleep for 0.1 s, then return ``(i, i**2)``."""
    delay_seconds = 0.1
    time.sleep(delay_seconds)

    squared = i**2
    return (i, squared)

def collate(results):
    """Split a sequence of items into two parallel lists.

    Returns ``(firsts, seconds)`` where ``firsts`` holds element 0 and
    ``seconds`` holds element 1 of every item, in input order.
    """
    firsts = list(map(lambda item: item[0], results))
    seconds = list(map(lambda item: item[1], results))
    return firsts, seconds

# Demo: iterate the same collator twice to show that it can be reused.
batch_size = 13
sampler = Sampler(14)
collator = CollatorMT(sampler.iter(batch_size), ds_getitem, collate, max_available_batches=2)

for pass_no in range(2):
    # Separator printed between the first and the second pass.
    if pass_no > 0:
        print("agane")

    for collated in collator:
        print("--------",collated)

Number of workers reduced from 15 to 4, since num_workers*work_chunk_size > batch_size (15*4 > 13)
-------- ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144])
-------- ([13, 12, 4, 2, 12, 5, 3, 3, 3, 7, 11, 7, 12], [169, 144, 16, 4, 144, 25, 9, 9, 9, 49, 121, 49, 144])
agane
-------- ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144])
-------- ([13, 6, 0, 1, 12, 9, 12, 12, 6, 3, 9, 3, 0], [169, 36, 0, 1, 144, 81, 144, 144, 36, 9, 81, 9, 0])


In [5]:
#e
class Dataset:
    """Pairs two equally long sequences and serves ``(x, y)`` by integer index."""

    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys
        # Both sequences must line up item for item.
        assert len(xs) == len(ys)

    def __len__(self):
        """Number of items, taken from ``xs``."""
        return len(self.xs)

    def __getitem__(self, i):
        """Return ``(xs[i], ys[i])``; ``i`` must be exactly an ``int``."""
        # Exact type check: bools, floats, slices and the like are rejected.
        assert type(i) is int
        x_item = self.xs[i]
        y_item = self.ys[i]
        return x_item, y_item
    
class DataLoader:
    """Iterable over collated batches of a dataset.

    Builds a ``Sampler`` over ``len(dataset)``, takes its batch iterator and
    feeds it to a ``CollatorMT`` that uses ``dataset.__getitem__`` to fetch
    items. Extra keyword arguments are forwarded to ``CollatorMT``.
    """

    def __init__(self, dataset, batch_size, shuffle, collate_func, **kwargs):
        index_sampler = Sampler(len(dataset))
        self.sampler_iter = index_sampler.iter(batch_size, shuffle)

        fetch_item = dataset.__getitem__
        self.collator = CollatorMT(self.sampler_iter, fetch_item, collate_func, **kwargs)

    def __iter__(self):
        """Yield every batch produced by one iteration of the collator."""
        yield from self.collator

In [7]:
# Demo: 100 values paired with their negations, read twice through the loader.
ds = Dataset(list(range(100)), [-value for value in range(100)])
dl = DataLoader(ds, batch_size=16, shuffle=False, collate_func=collate)

for _ in range(2):
    for xs, ys in dl:
        print(xs, ys)


Number of workers reduced from 15 to 4, since num_workers*work_chunk_size > batch_size (15*4 > 16)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] [0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15]
[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] [-16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
[32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] [-32, -33, -34, -35, -36, -37, -38, -39, -40, -41, -42, -43, -44, -45, -46, -47]
[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] [-48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63]
[64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] [-64, -65, -66, -67, -68, -69, -70, -71, -72, -73, -74, -75, -76, -77, -78, -79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95] [-80, -81, -82, -83, -84, -85, -86, -87, -88, -89, -90, -91, -92, -93, -94, -95]
[96, 97, 98, 99, 13, 12, 52, 52, 39, 32, 12, 2

In [9]:
# Run the project's export step. The recorded output below shows sources under
# minai_nbs/ being processed into files of the minai package.
# NOTE(review): cells tagged `#e` are presumably the ones exported — confirm
# against z_export.
import z_export
z_export.export()

Processing minai_nbs/datasets.ipynb -> minai/minai/datasets.py
  same contents, skipping, took 0.001s
Processing minai_nbs/sampler.ipynb -> minai/minai/sampler.py
  same contents, skipping, took 0.000s
Processing minai_nbs/setup+template.py -> minai/setup.py
  same contents, skipping, took 0.000s
Processing minai_nbs/__init__+template.py -> minai/minai/__init__.py
  same contents, skipping, took 0.000s
Processing minai_nbs/plot.ipynb -> minai/minai/plot.py
  nothing to export, took 0.000s
Processing minai_nbs/mintils.py -> minai/minai/mintils.py
  same contents, skipping, took 0.000s
Processing minai_nbs/data.ipynb -> minai/minai/data.py
  same contents, skipping, took 0.000s

All done... took 0.003s
  lib_name: minai
  author: nblzv
  version: 0.1.1
