In [14]:
import numpy as np
# Module-level switches read by iterate_balanced_minibatches_multiclass:
#   slices       -- when True, each example is passed through extract_slice and
#                   the batch is assembled with form_correct_shape_array.
#   augment_data -- when True, augment_slice is applied to every example in a
#                   batch (augment_slice is not defined in the cells shown here;
#                   TODO confirm it exists before enabling this flag).
slices, augment_data = True, False

def extract_slice(x):
    """Placeholder slice extractor: hands back its argument untouched.

    Parameters
    ----------
    x : object
        A single example.

    Returns
    -------
    object
        ``x`` itself (the same object, not a copy).
    """
    unchanged = x
    return unchanged

def form_correct_shape_array(x):
    """Combine a sequence of per-example arrays into one batch array.

    Uses ``np.hstack``, which concatenates 1-D inputs end to end and
    concatenates higher-dimensional inputs along axis 1.

    Parameters
    ----------
    x : sequence of array_like
        The pieces to stack.

    Returns
    -------
    np.ndarray
        The horizontally stacked result.
    """
    stacked = np.hstack(x)
    return stacked

def iterate_balanced_minibatches_multiclass(inputs, targets, full_batchsize, shuffle=True):
    """Yield class-balanced minibatches, oversampling the smaller classes.

    Every batch contains ``full_batchsize // n_classes`` examples from each
    class present in ``targets``. The largest class sets the epoch length:
    every batch draws fresh examples from it, while smaller classes wrap back
    to their beginning and are reused once exhausted. Trailing examples of the
    largest class that do not fill a whole per-class chunk are not served in
    this epoch. If even the largest class has fewer than
    ``full_batchsize // n_classes`` examples, no batch is yielded.

    Behaviour also depends on the module-level flags ``slices`` and
    ``augment_data`` and on the helpers ``extract_slice``,
    ``form_correct_shape_array`` and ``augment_slice``.

    Parameters
    ----------
    inputs : sequence or np.ndarray
        Examples, indexable by integer position. When ``slices`` is False it
        must also support fancy indexing with an integer array (e.g. an
        ndarray).
    targets : array_like
        Class label of each example; same length as ``inputs``.
    full_batchsize : int
        Requested total batch size. It is rounded down to a multiple of the
        number of classes, so the actual batch size is
        ``(full_batchsize // n_classes) * n_classes``.
    shuffle : bool, optional
        If True (default), each class's indices are shuffled once with
        ``np.random.shuffle`` before iterating.

    Yields
    ------
    (X, y)
        Batch inputs and the matching ``targets`` entries, grouped by class
        in ascending label order.

    Raises
    ------
    ValueError
        If ``full_batchsize`` is smaller than the number of classes, which
        would leave no room for even one example per class.
    """
    assert len(inputs) == len(targets)

    # np.asarray lets plain lists work with the elementwise ``==`` and the
    # fancy indexing ``targets[full_idxs]`` below; it is a no-op for arrays.
    targets = np.asarray(targets)
    all_targets = np.unique(targets)  # sorted, so batch layout is deterministic
    n_classes = len(all_targets)
    if n_classes == 0:
        # Empty dataset: nothing to yield (previously a ZeroDivisionError).
        return

    # Floor division keeps this an int on Python 3 as well as Python 2;
    # ``/`` would give a float there and break range()/np.arange below.
    per_class_batchsize = full_batchsize // n_classes
    if per_class_batchsize < 1:
        raise ValueError(
            "full_batchsize (%d) must be at least the number of classes (%d)"
            % (full_batchsize, n_classes))

    idxs = {this_targ: np.where(targets == this_targ)[0]
            for this_targ in all_targets}

    if shuffle:
        for this_targ in all_targets:
            np.random.shuffle(idxs[this_targ])

    # The largest class defines the epoch size.
    examples_in_epoch = max(len(class_idxs) for class_idxs in idxs.values())

    # In each batch, new data from the largest class is provided; data from
    # the other classes is reused (mode='wrap') once it runs out.
    for start_idx in range(0, examples_in_epoch - per_class_batchsize + 1, per_class_batchsize):
        positions = np.arange(start_idx, start_idx + per_class_batchsize)
        excerpts = [np.take(idxs[this_targ], positions, mode='wrap')
                    for this_targ in all_targets]

        # Reform the full balanced inputs and outputs.
        full_idxs = np.hstack(excerpts)

        if slices:
            # Each example goes through extract_slice individually, so
            # ``inputs`` only needs plain integer indexing on this path.
            Xs = [extract_slice(inputs[xx]) for xx in full_idxs]
            if augment_data:
                Xs = [augment_slice(x) for x in Xs]
            yield form_correct_shape_array(Xs), targets[full_idxs]
        else:
            Xs = inputs[full_idxs]
            if augment_data:
                # A list rather than a lazy ``map`` so Python 3 callers get a
                # concrete sequence, as Python 2's ``map`` returned.
                Xs = [augment_slice(x) for x in Xs]
            yield Xs, targets[full_idxs]


In [15]:
# Toy data: 100 examples with integer labels 0..10. linspace+astype maps only
# the final value 10.0 (index 99) to label 10, so class 10 has a single example
# and must be heavily oversampled by the balanced iterator.
np.random.seed(0)  # fix the per-class shuffles so the batches below are reproducible
inputs = np.arange(100)
targets = np.linspace(0, 10, 100).astype(np.int32)

# Parenthesised %-format prints identically under Python 2 and Python 3.
print("%s %s" % (inputs.shape, targets.shape))

batches = list(iterate_balanced_minibatches_multiclass(inputs, targets, 30))
x_iter = [x for x, _ in batches]
y_iter = [y for _, y in batches]

(100,) (100,)


In [19]:
XX = np.hstack(x_iter)
YY = np.hstack(y_iter)
# Parenthesised %-format prints identically under Python 2 and Python 3.
print("%s %s" % (XX, YY))

print("%s %s" % (np.unique(XX).shape, inputs.shape))
print("%s %s" % (np.unique(YY).shape, np.bincount(YY)))

# Inline checks instead of eyeballing the printout: every input should be
# served at least once in the epoch, and every label should appear equally
# often across the epoch.
label_counts = np.bincount(YY)
assert np.array_equal(np.unique(XX), np.unique(inputs)), "some inputs were never served"
assert np.all(label_counts == label_counts[0]), "classes are not balanced: %s" % label_counts

[ 9  5 17 14 28 29 37 39 40 46 57 52 61 65 71 77 82 80 92 96 99 99  7  6 10
 11 25 21 31 38 47 42 59 55 62 63 78 74 86 81 95 90 99 99  4  2 18 13 27 23
 30 36 44 41 53 50 64 66 76 73 89 83 98 93 99 99  3  8 12 15 24 20 35 32 45
 48 51 56 68 69 72 75 85 87 97 91 99 99  0  1 16 19 22 26 33 34 43 49 54 58
 60 67 79 70 88 84 94 92 99 99] [ 0  0  1  1  2  2  3  3  4  4  5  5  6  6  7  7  8  8  9  9 10 10  0  0  1
  1  2  2  3  3  4  4  5  5  6  6  7  7  8  8  9  9 10 10  0  0  1  1  2  2
  3  3  4  4  5  5  6  6  7  7  8  8  9  9 10 10  0  0  1  1  2  2  3  3  4
  4  5  5  6  6  7  7  8  8  9  9 10 10  0  0  1  1  2  2  3  3  4  4  5  5
  6  6  7  7  8  8  9  9 10 10]
(100,) (100,)
(11,) [10 10 10 10 10 10 10 10 10 10 10]
