In [1]:
#e
import itertools
import random

In [2]:
bs = 4
num_items = 5
indices = tuple(range(num_items))

full_batches = (num_items) // bs
last_full_batch_index = full_batches*bs
leftover = num_items - last_full_batch_index

it = iter(indices)
batches = list(tuple(itertools.islice(it, bs)) for _ in range(full_batches))

if leftover:
    need_extra = bs - leftover
    last_batch = list(it) + random.sample(indices[:last_full_batch_index], need_extra)
    batches.append(tuple(last_batch))

batches

[(0, 1, 2, 3), (4, 0, 3, 1)]

In [3]:
#e
def chunkify(container, chunk_size):
    num_items = len(container)
    it = iter(container)

    full_chunks = num_items // chunk_size
    leftover = num_items - full_chunks*chunk_size

    chunks = [list(itertools.islice(it, chunk_size)) for _ in range(full_chunks)]
    if leftover: chunks.append(list(it))
    
    return chunks

In [4]:
chunk_size = 3
num_items = 5
indices = list(range(num_items))

list(chunkify(indices, chunk_size))

[[0, 1, 2], [3, 4]]

In [15]:
#e
class SamplerIter:
    def __init__(self, indices, batch_size, shuffle, drop_last):
        self.indices = indices
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        if self.shuffle: random.shuffle(self.indices)

        batches = chunkify(self.indices, self.batch_size)
        need_extra = self.batch_size - len(batches[-1])
        if need_extra:
            if self.drop_last:
                batches.pop()
            else:
                batches[-1].extend(random.choices(self.indices, k=need_extra))

        yield from batches

class Sampler:
    def __init__(self, num_items):
        self.num_items = num_items

    def iter(self, batch_size, shuffle=False, drop_last=False):
        indices = list(range(self.num_items))
        return SamplerIter(indices, batch_size, shuffle, drop_last)

In [16]:
s = Sampler(6)

print("default\t\t", list(s.iter(4)))
print("drop\t\t", list(s.iter(4, drop_last=True)))
print("shuffle\t\t", list(s.iter(4, shuffle=True)))
print("shuffle+drop\t", list(s.iter(4, shuffle=True, drop_last=True)))

default		 [[0, 1, 2, 3], [4, 5, 4, 3]]
drop		 [[0, 1, 2, 3]]
shuffle		 [[1, 2, 3, 5], [0, 4, 0, 1]]
shuffle+drop	 [[0, 3, 2, 4]]


In [13]:
s = Sampler(1)
list(s.iter(16))

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [17]:
import z_export
z_export.export()

Processing minai_nbs/datasets.ipynb -> minai/minai/datasets.py
  2 cells exported, took 0.001s 
Processing minai_nbs/sampler.ipynb -> minai/minai/sampler.py
  3 cells exported, took 0.000s 
Processing minai_nbs/setup+template.py -> minai/setup.py
  same contents, skipping, took 0.000s
Processing minai_nbs/__init__+template.py -> minai/minai/__init__.py
  same contents, skipping, took 0.000s
Processing minai_nbs/plot.ipynb -> minai/minai/plot.py
  nothing to export, took 0.000s
Processing minai_nbs/mintils.py -> minai/minai/mintils.py
  same contents, skipping, took 0.000s
Processing minai_nbs/data.ipynb -> minai/minai/data.py
  3 cells exported, took 0.000s 

All done... took 0.003s
  lib_name: minai
  author: nblzv
  version: 0.1.1
