In [1]:
import numpy as np
from nems.models.base import DataSet

# Build a DataSet from a single random 2-D array stored under the key 'x'.
# The dict is bound to `input1` rather than `input` so that this cell does not
# shadow Python's built-in input() for the rest of the kernel session.
input1 = {'x': np.random.rand(1000, 18)}
d = DataSet(input1)

In [2]:
# prepend_samples() gives back a dataset whose 'x' array has an extra leading
# axis of length 1 (see the shape displayed below).
new_d = d.prepend_samples()
new_d.inputs['x'].shape

(1, 1000, 18)

In [3]:
# Sanity check: calling prepend_samples() in the previous cell must not modify
# `d` itself, so the shape of its 'x' array should be unchanged.
d.inputs['x'].shape

(1000, 18)

In [4]:
# True here means the two arrays overlap in memory, indicating that
# prepend_samples() added the leading axis without copying the data.
np.shares_memory(d.inputs['x'], new_d.inputs['x'])

True

In [5]:
# squeeze_samples() reverses the earlier prepend: per the shape displayed
# below, the leading length-1 axis is dropped and 'x' is 2-D again.
sd = new_d.squeeze_samples()
sd.inputs['x'].shape

(1000, 18)

In [6]:
# Two 3-D inputs whose first axis differs in length: 1 for 'y', 10 for 'z'.
input2 = {'y': np.random.rand(1, 1000, 18), 'z': np.random.rand(10, 1000, 1)}
d2 = DataSet(input2)
# Show the shape stored under each input key.
for k, v in d2.inputs.items():
    print(f'{k}: {v.shape}')

y: (1, 1000, 18)
z: (10, 1000, 1)


In [7]:
# as_broadcasted_samples(): per the output below, the length-1 leading axis of
# 'y' is expanded to 10 so that it matches 'z', while 'z' keeps its shape.
d3 = d2.as_broadcasted_samples()
# Show the shape stored under each input key after broadcasting.
for k, v in d3.inputs.items():
    print(f'{k}: {v.shape}')

y: (10, 1000, 18)
z: (10, 1000, 1)


In [8]:
# Index 5 along the first axis of the broadcasted 'y' still overlaps in memory
# with the original (1, 1000, 18) array, indicating the broadcast was produced
# without copying the data for that sample.
np.shares_memory(d3.inputs['y'][5,...], d2.inputs['y'])

True

In [9]:
def print_shapes(d, batch_size):
    """Print the shape of every input array, batch by batch and sample by sample.

    ``d`` is split with ``d.as_batches(batch_size)``. For each batch, one line
    is printed per entry of ``batch.inputs`` followed by a rule of ``=``. The
    batch is then split with ``batch.as_samples()``, and for each sample one
    indented line is printed per entry of ``sample.inputs`` followed by an
    indented rule of ``-``.

    Parameters
    ----------
    d
        Object providing ``as_batches(batch_size)``; each yielded batch must
        expose an ``inputs`` mapping and an ``as_samples()`` method.
    batch_size
        Forwarded unchanged to ``d.as_batches``.

    Returns
    -------
    None
        Everything is written to standard output.
    """
    batch_rule = '=' * 30
    sample_rule = '    ' + '-' * 26

    for i, batch in enumerate(d.as_batches(batch_size)):
        # Batch-level summary: one line per input key.
        for k, v in batch.inputs.items():
            print(f'batch {i}| {k}: {v.shape}')
        print(batch_rule)

        # Sample-level summary, indented beneath its batch.
        for j, sample in enumerate(batch.as_samples()):
            for k, v in sample.inputs.items():
                print(f'    sample {j}| {k}: {v.shape}')
            print(sample_rule)

In [10]:
# 10 samples in batches of 2: the recorded output shows batches 0-4, each with
# a leading axis of length 2.
# NOTE(review): the recorded output has no '=' separator lines although
# print_shapes prints one after every batch header, and it is cut off partway
# through batch 4 -- it looks stale/truncated; re-run this cell to refresh it.
print_shapes(d3, batch_size=2)

batch 0| y: (2, 1000, 18)
batch 0| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
batch 1| y: (2, 1000, 18)
batch 1| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
batch 2| y: (2, 1000, 18)
batch 2| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
batch 3| y: (2, 1000, 18)
batch 3| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
batch 4| y: (2, 1000, 18)
batch 4| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
  

In [11]:
# 10 is not a multiple of 4: in the recorded output the first batch holds only
# 2 samples and the batches after it hold 4 each.
# NOTE(review): the recorded output is cut off inside batch 2 and lacks the
# '=' separator lines that print_shapes emits; re-run this cell to refresh it.
print_shapes(d3, batch_size=4)

batch 0| y: (2, 1000, 18)
batch 0| z: (2, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
batch 1| y: (4, 1000, 18)
batch 1| z: (4, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
    sample 2| y: (1000, 18)
    sample 2| z: (1000, 1)
    --------------------------
    sample 3| y: (1000, 18)
    sample 3| z: (1000, 1)
    --------------------------
batch 2| y: (4, 1000, 18)
batch 2| z: (4, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
    sample 2| y: (1000, 18)
    sample 2| z: (1000, 1)
    --------------------------
    sample 3| y: (1000, 18)
    sample 3| z: (1000, 1)
    --------------

In [12]:
# batch_size=None: the recorded output shows a single batch holding all 10
# samples.
print_shapes(d3, batch_size=None)

batch 0| y: (10, 1000, 18)
batch 0| z: (10, 1000, 1)
    sample 0| y: (1000, 18)
    sample 0| z: (1000, 1)
    --------------------------
    sample 1| y: (1000, 18)
    sample 1| z: (1000, 1)
    --------------------------
    sample 2| y: (1000, 18)
    sample 2| z: (1000, 1)
    --------------------------
    sample 3| y: (1000, 18)
    sample 3| z: (1000, 1)
    --------------------------
    sample 4| y: (1000, 18)
    sample 4| z: (1000, 1)
    --------------------------
    sample 5| y: (1000, 18)
    sample 5| z: (1000, 1)
    --------------------------
    sample 6| y: (1000, 18)
    sample 6| z: (1000, 1)
    --------------------------
    sample 7| y: (1000, 18)
    sample 7| z: (1000, 1)
    --------------------------
    sample 8| y: (1000, 18)
    sample 8| z: (1000, 1)
    --------------------------
    sample 9| y: (1000, 18)
    sample 9| z: (1000, 1)
    --------------------------
