In [1]:
#e
import itertools
import random

In [2]:
# Scratch experiment: split `num_items` indices into batches of exactly `bs`,
# padding a short final batch with indices borrowed from the full batches.
bs = 4
num_items = 5
indices = tuple(range(num_items))

# Number of complete batches, and how many items spill past them.
full_batches, leftover = divmod(num_items, bs)
last_full_batch_index = full_batches * bs

it = iter(indices)
batches = [tuple(itertools.islice(it, bs)) for _ in range(full_batches)]

if leftover:
    # Fill the remainder up to `bs` with distinct indices sampled (without
    # replacement) from the items that already went into the full batches.
    need_extra = bs - leftover
    last_batch = [*it, *random.sample(indices[:last_full_batch_index], need_extra)]
    batches.append(tuple(last_batch))

batches

[(0, 1, 2, 3), (4, 1, 0, 3)]

In [3]:
#e
def chunkify_calc_sizes(container, chunk_size):
    """Work out how `container` divides into chunks of `chunk_size`.

    Args:
        container: Any object supporting `len()`.
        chunk_size: Number of items per chunk.

    Returns:
        A `(full_chunks, leftover)` tuple: the count of complete chunks and
        the number of items remaining after them.
    """
    full_chunks, leftover = divmod(len(container), chunk_size)
    return full_chunks, leftover

def chunkify(container, chunk_size):
    """Split `container` into consecutive lists of `chunk_size` items.

    Args:
        container: Sized iterable to split.
        chunk_size: Number of items in each full chunk.

    Returns:
        A list of lists. All chunks hold `chunk_size` items except possibly
        the last one, which holds whatever is left over. An empty container
        yields an empty list.
    """
    n_full, remainder = chunkify_calc_sizes(container, chunk_size)

    source = iter(container)
    pieces = []
    for _ in range(n_full):
        pieces.append(list(itertools.islice(source, chunk_size)))

    if remainder:
        # Everything still unread forms the final, shorter chunk.
        pieces.append(list(source))

    return pieces

In [4]:
# Demo: 5 items in chunks of 3 -> one full chunk followed by a 2-item remainder.
num_items = 5
chunk_size = 3
indices = [*range(num_items)]

list(chunkify(indices, chunk_size))

[[0, 1, 2], [3, 4]]

In [5]:
#e
class SIOPTS: # SamplerIterOpts
    """Options controlling how a `SamplerIter` produces batches.

    Attributes:
        batch_size: Number of indices in every yielded batch.
        shuffle: When true, indices are shuffled before being batched.
        drop_last: When true, an incomplete final batch is discarded instead
            of being padded up to `batch_size`.
    """

    def __init__(self, batch_size=64, shuffle=False, drop_last=False):
        self.batch_size, self.shuffle, self.drop_last = batch_size, shuffle, drop_last

    def __repr__(self):
        # Note: the printed prefix is "SIO", not the class name.
        return f"SIO(batch_size={self.batch_size}, shuffle={self.shuffle}, drop_last={self.drop_last})"

class SamplerIter:
    """Iterable yielding fixed-size batches (lists) of indices.

    Every yielded batch holds exactly `opts.batch_size` entries. When the
    number of indices is not a multiple of the batch size, the short trailing
    batch is either dropped (`drop_last=True`) or topped up with indices drawn
    at random, with replacement, from the whole index pool (`drop_last=False`).

    Args:
        indices: Sized iterable of indices to batch. It is never modified.
        sampler_iter_opts: `SIOPTS` instance with batch_size / shuffle / drop_last.

    Attributes:
        indices: The indices as passed in.
        opts: The options object as passed in.
        num_batches: Number of batches one pass over this object yields.
    """

    def __init__(self, indices, sampler_iter_opts: SIOPTS):
        self.indices = indices
        self.opts = sampler_iter_opts

        full_chunks, leftover = chunkify_calc_sizes(self.indices, self.opts.batch_size)
        # A partial trailing batch only counts when it is kept (and padded).
        keeps_partial = bool(leftover) and not self.opts.drop_last
        self.num_batches = full_chunks + int(keeps_partial)

    def __repr__(self):
        return f"SamplerIter(num_batches={self.num_batches}, opts={self.opts})"

    def __iter__(self):
        order = self.indices
        if self.opts.shuffle:
            # Shuffle a private copy: the caller's sequence stays untouched and
            # immutable inputs (tuple, range) work too.
            order = list(order)
            random.shuffle(order)

        batches = chunkify(order, self.opts.batch_size)
        if not batches:
            # No indices at all: yield nothing, consistent with num_batches == 0.
            return

        need_extra = self.opts.batch_size - len(batches[-1])
        if need_extra:
            if self.opts.drop_last:
                batches.pop()
            else:
                # Pad the short final batch by sampling with replacement from
                # the full pool, so duplicates may appear within the batch.
                batches[-1].extend(random.choices(order, k=need_extra))

        yield from batches

class Sampler:
    """Factory for `SamplerIter` objects over the indices `0 .. num_items - 1`."""

    def __init__(self, num_items):
        self.num_items = num_items

    def iter(self, sampler_iter_opts=None):
        """Build a `SamplerIter` over a fresh index list.

        Args:
            sampler_iter_opts: Batching options; a default `SIOPTS()` is used
                when none is given.

        Returns:
            A new `SamplerIter` over `list(range(self.num_items))`.
        """
        if not sampler_iter_opts:
            sampler_iter_opts = SIOPTS()
        return SamplerIter(list(range(self.num_items)), sampler_iter_opts)

In [6]:
# Sampler over 6 items; batching by 4 in the next cell leaves a 2-item partial batch.
s = Sampler(6)

In [7]:
# Show the iterator's attributes, then the batches for each option combination.
print(vars(s.iter(SIOPTS(4))))
for label, opts in (
    ("default\t\t", SIOPTS(4)),
    ("drop\t\t", SIOPTS(4, drop_last=True)),
    ("shuffle\t\t", SIOPTS(4, shuffle=True)),
    ("shuffle+drop\t", SIOPTS(4, shuffle=True, drop_last=True)),
):
    print(label, list(s.iter(opts)))

{'indices': [0, 1, 2, 3, 4, 5], 'opts': SIO(batch_size=4, shuffle=False, drop_last=False), 'num_batches': 2}
default		 [[0, 1, 2, 3], [4, 5, 3, 2]]
drop		 [[0, 1, 2, 3]]
shuffle		 [[5, 3, 2, 0], [1, 4, 1, 4]]
shuffle+drop	 [[1, 0, 3, 2]]


In [10]:
# Edge case: a single item with batch_size=16. The only batch is partial, so it
# is padded up to 16 entries, all of which can only be index 0.
s = Sampler(1)
list(s.iter(SIOPTS(16)))

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [11]:
# Run the project's export helper. Per its log output below, it maps the
# notebooks and templates under minai_nbs/ to files in the minai package.
# NOTE(review): presumably only cells tagged `#e` are exported — confirm in z_export.
import z_export
z_export.export()

Processing minai_nbs/datasets.ipynb -> minai/minai/datasets.py  |  same contents, skipping, took 0.001s
Processing minai_nbs/sampler.ipynb -> minai/minai/sampler.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/setup+template.py -> minai/setup.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/__init__+template.py -> minai/minai/__init__.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/plot.ipynb -> minai/minai/plot.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/mintils.py -> minai/minai/mintils.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/data.ipynb -> minai/minai/data.py  |  same contents, skipping, took 0.000s

All done... took 0.003s
  lib_name: minai
  author: nblzv
  version: 0.1.1
