In [1]:
import numpy as np

In [71]:
class ClassWeightedBatchSampler:

    def __init__(self, class_weights, class_idxs, batch_size, n_batches):
        self.class_weights = class_weights
        self.class_idxs = [CircularList(idx) for idx in class_idxs]
        self.batch_size = batch_size
        self.n_batches = n_batches

        self.n_classes = len(self.class_weights)
        self.class_sizes = np.asarray([int(batch_size * w) for w in self.class_weights])

    def _get_batch(self, start_idxs):
        selected = []
        for c, size in enumerate(self.class_sizes):
            selected.extend(self.class_idxs[c][start_idxs[c]:start_idxs[c] + size])
        np.random.shuffle(selected)
        return selected

    def __iter__(self):
        [cidx.shuffle() for cidx in self.class_idxs]
        start_idxs = np.zeros(self.n_classes, dtype=int)
        for bidx in range(self.n_batches):
            yield self._get_batch(start_idxs)
            start_idxs += self.class_sizes

    def __len__(self):
        return self.n_batches


class CircularList:
    """
    Applies modulo function to indexing.
    """
    def __init__(self, items):
        self._items = items
        self._mod = len(self._items)

    def shuffle(self):
        np.random.shuffle(self._items)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return [self[i] for i in range(key.start, key.stop)]
        return self._items[key % self._mod]


In [72]:
weights = [10 / 20, 4 / 20, 6 / 20]
idxs  = [
    np.arange(20),
    [21, 22, 23],
    np.arange(30, 40)
]
print(weights, idxs)

[0.5, 0.2, 0.3] [array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), [21, 22, 23], array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])]


In [73]:
sampler = ClassWeightedBatchSampler(weights, idxs, 20, 2)
len(sampler)

2

In [74]:
sampler.class_sizes

array([10,  4,  6])

In [75]:
for batch in sampler:
    print(batch)

[38, 39, 22, 23, 23, 18, 1, 7, 34, 21, 14, 0, 12, 2, 37, 33, 6, 3, 17, 32]
[39, 31, 13, 10, 11, 23, 4, 15, 36, 21, 5, 22, 35, 37, 8, 9, 30, 19, 16, 21]


In [69]:
class CircularList:

    def __init__(self, items):
        self._items = items
        self._mod = len(self._items)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return [self[i] for i in range(key.start, key.stop)]
        return self._items[key % self._mod]

In [50]:
cl = CircularList([1, 2, 3])

In [51]:
cl[1]

2

In [52]:
cl[4]

2

In [53]:
cl[5]

3

In [54]:
cl[2:10]

[3, 1, 2, 3, 1, 2, 3, 1]

In [55]:
a = np.zeros((3))
a

array([0., 0., 0.])

In [57]:
a += np.array([1, 2, 3])
a

array([2., 4., 6.])