This notebook illustrates some examples of TensorFlow (`tf.data`) datasets, including tips and gotchas.

In [1]:
import numpy as np
import os
import sys

In [4]:
import tensorflow as tf
# Show the installed TensorFlow version; the outputs captured in this notebook came from 2.1.0.
tf.__version__


'2.1.0'

In [None]:
# Show the interpreter's module search path.
sys.path

In [3]:
# Widen the notebook's content area to the full browser width.
# HTML and display are imported from the public IPython.display module rather than
# the internal IPython.core.display path, and display is imported explicitly
# instead of relying on the name being injected into the notebook namespace.
from IPython.display import HTML, display
display(HTML("<style>.container { width:100% !important; }</style>"))

# Basics to make sure we understand how TF2 datasets work

In [6]:
# Create some super-simple datasets of integers:
#   inc_dataset counts up:   0, 1, 2, ..., 99
#   dec_dataset counts down: 0, -1, -2, ..., -99
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
# Pair the two element-wise, so `dataset` yields tuples of the form (i, -i).
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))

In [7]:
# Peek at the first three elements of a dataset.
# .numpy() turns each eager tensor into a plain value for printing.
first_three = inc_dataset.take(3)
for elem in first_three:
    print(elem.numpy())


0
1
2


In [9]:
# Group consecutive elements into batches of 8.  The dataset has 100 elements,
# which is not a multiple of 8, so the final batch holds only the 4 leftovers.
batched_dataset = inc_dataset.batch(8)

In [10]:
# Now ask for 15 of those batches.  Because there are only 100 elements, we only get
# 13 batches: 12 full batches of 8 plus a final partial batch of 4.
for batch in batched_dataset.take(15):
    print([arr.numpy() for arr in batch])

[0, 1, 2, 3, 4, 5, 6, 7]
[8, 9, 10, 11, 12, 13, 14, 15]
[16, 17, 18, 19, 20, 21, 22, 23]
[24, 25, 26, 27, 28, 29, 30, 31]
[32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47]
[48, 49, 50, 51, 52, 53, 54, 55]
[56, 57, 58, 59, 60, 61, 62, 63]
[64, 65, 66, 67, 68, 69, 70, 71]
[72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87]
[88, 89, 90, 91, 92, 93, 94, 95]
[96, 97, 98, 99]


In [11]:
# Repeat the batched dataset before taking 15 batches.  This time we do get all 15:
# after the final partial batch the data starts over from the beginning.
repeated_batches = batched_dataset.repeat()
for batch in repeated_batches.take(15):
    print([arr.numpy() for arr in batch])

[0, 1, 2, 3, 4, 5, 6, 7]
[8, 9, 10, 11, 12, 13, 14, 15]
[16, 17, 18, 19, 20, 21, 22, 23]
[24, 25, 26, 27, 28, 29, 30, 31]
[32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47]
[48, 49, 50, 51, 52, 53, 54, 55]
[56, 57, 58, 59, 60, 61, 62, 63]
[64, 65, 66, 67, 68, 69, 70, 71]
[72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87]
[88, 89, 90, 91, 92, 93, 94, 95]
[96, 97, 98, 99]
[0, 1, 2, 3, 4, 5, 6, 7]
[8, 9, 10, 11, 12, 13, 14, 15]


## Data augmentation

It's important to understand how non-tensorflow operations work and when they execute.  
Consider two versions of an "augment" function that simply adds a random number to each element
of a dataset containing integers.

In [12]:
# Use numpy to generate the random number
def augment(x):
    """Add a numpy-generated random offset to `x`.

    The offset is a float drawn uniformly from [0, 3) and rounded to the
    nearest whole number, so it is one of 0.0, 1.0, 2.0 or 3.0.

    Args:
        x: The value to offset.

    Returns:
        `x` plus the rounded random offset.
    """
    offset = np.round(np.random.uniform(0, 3))
    return x + offset

In [13]:
# Use TF to generate the random number.
def tf_augment(x):
    """Add a TensorFlow-generated random integer offset to `x`.

    Args:
        x: An int64 tensor, or a Python int, to offset.

    Returns:
        An int64 tensor equal to `x` plus an offset drawn uniformly from {0, 1, 2}.
    """
    # shape=[] requests a rank-0 (scalar) tensor.  For integer dtypes maxval is
    # exclusive, so the offset is 0, 1 or 2.  It is already an integer, so no
    # rounding is needed (tf.round is a no-op on integer tensors).
    offset = tf.random.uniform([], 0, 3, dtype=tf.dtypes.int64)
    return x + offset

In [19]:
# Call tf_augment eagerly a few times to confirm that every call draws a fresh random number.
for i in range(3):
    sample = tf_augment(5)
    print(sample)

tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64)


Now let's insert these augment functions into a dataset pipeline.

In [22]:
# With augment(), which uses numpy to generate the random offset, the offset is drawn
# only once per map() call and then applied to every element drawn from that dataset.
# Each iteration below calls map() again, so the offset can differ between iterations,
# but within an iteration it is always the same.
# THIS PROBABLY ISN'T WHAT YOU WANTED.
for i in range(3):
    print("\nIteration {}".format(i))
    numpy_augmented = inc_dataset.map(augment)
    for elt in numpy_augmented.take(15):
        print(elt.numpy())
    


Iteration 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

Iteration 1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

Iteration 2
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16


In [23]:
# On the other hand, with tf_augment(), which uses TF to generate the random numbers,
# a different random offset is applied to each element of the dataset.
# THIS IS PROBABLY WHAT YOU WANT.
for i in range(3):
    print("\nIteration {}".format(i))
    tf_augmented = inc_dataset.map(tf_augment)
    for elt in tf_augmented.take(15):
        print(elt.numpy())


Iteration 0
1
1
2
4
6
6
6
8
10
9
10
13
13
15
16

Iteration 1
1
1
3
5
5
5
7
7
9
9
11
11
14
15
14

Iteration 2
2
1
3
4
6
6
7
9
10
11
12
13
13
14
14


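Conclusion: A TensorFlow pipeline operates on tensors.  Anything else is evaluated once, when `map()` converts the function into a graph, and is then held fixed.  
I.e. it's inserted in the graph as a tf.constant.

So numpy.random.uniform() gets called exactly once per `map()` call, instead of once per element. That is why the numpy offset above can change from one iteration to the next but never within an iteration.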

In [24]:
# Now let's test what happens if we oversample from the dataset.
# Here we draw 105 elements from the 100-element dataset, which requires a repeat() call.
# Observe that the last 5 elements get different random offsets than the first 5,
# even though their base values in inc_dataset are the same.
# This tells us that the random values do not repeat.  Which is a good thing.
oversampled = inc_dataset.map(tf_augment).repeat().take(105)
for elt in oversampled:
    print(elt.numpy())

1
2
2
5
4
6
8
9
8
9
11
13
12
15
15
16
18
19
18
20
22
23
24
25
24
26
26
29
28
31
30
32
34
34
34
36
38
39
38
39
40
42
42
43
46
47
48
48
48
50
52
51
54
53
56
56
57
57
58
60
62
61
63
63
66
66
67
69
70
70
70
71
72
74
76
75
78
77
80
81
80
83
83
83
85
87
88
88
90
90
91
92
94
95
96
96
96
98
99
100
0
2
3
4
6


In [25]:
# The same holds true when batches are added to the picture: when batches repeat,
# they come back with different random offsets applied to them.
# Again, this is a good thing.
augmented_batches = inc_dataset.map(tf_augment).batch(8)
for batch in augmented_batches.repeat().take(15):
    print([arr.numpy() for arr in batch])

[1, 1, 3, 3, 4, 5, 7, 8]
[9, 11, 12, 13, 12, 13, 16, 16]
[17, 18, 18, 19, 21, 22, 22, 24]
[25, 26, 26, 28, 30, 30, 30, 31]
[32, 35, 34, 35, 37, 38, 39, 40]
[41, 43, 42, 44, 44, 45, 47, 47]
[50, 51, 51, 53, 53, 55, 55, 57]
[58, 57, 60, 60, 62, 61, 63, 65]
[66, 65, 66, 69, 70, 70, 70, 73]
[72, 73, 74, 76, 76, 79, 78, 81]
[80, 83, 84, 83, 86, 86, 88, 87]
[88, 89, 91, 93, 92, 94, 94, 95]
[97, 97, 98, 99]
[2, 2, 4, 5, 6, 6, 6, 7]
[9, 9, 10, 11, 13, 13, 14, 16]


## Shuffling

In [26]:
# Shuffle the individual elements first and batch afterwards, so that each batch is
# a unique mix of elements.
# NOTE THAT THIS IS NOT HOW I NORMALLY THINK OF BATCHES.  Also note that this will be
# quite slow in practice if you reshuffle with a buffer that is the size of 100,000 elements.
shuffled_then_batched = dataset.shuffle(20, seed=5).batch(8)
for batch in shuffled_then_batched.repeat().take(15):
    print([arr.numpy() for arr in batch])

[array([ 5, 12, 17, 22, 13, 11, 23, 15]), array([ -5, -12, -17, -22, -13, -11, -23, -15])]
[array([21, 20,  2, 10, 29,  4,  9,  0]), array([-21, -20,  -2, -10, -29,  -4,  -9,   0])]
[array([32, 27, 30,  1, 36, 35, 33, 24]), array([-32, -27, -30,  -1, -36, -35, -33, -24])]
[array([31,  6,  7,  8, 43, 44, 42, 19]), array([-31,  -6,  -7,  -8, -43, -44, -42, -19])]
[array([ 3, 28, 48, 53, 18, 41, 26, 16]), array([ -3, -28, -48, -53, -18, -41, -26, -16])]
[array([38, 14, 50, 62, 60, 47, 34, 64]), array([-38, -14, -50, -62, -60, -47, -34, -64])]
[array([56, 55, 57, 58, 71, 37, 49, 65]), array([-56, -55, -57, -58, -71, -37, -49, -65])]
[array([25, 66, 69, 63, 78, 70, 54, 82]), array([-25, -66, -69, -63, -78, -70, -54, -82])]
[array([40, 73, 80, 85, 77, 67, 83, 75]), array([-40, -73, -80, -85, -77, -67, -83, -75])]
[array([74, 90, 79, 92, 86, 96, 84, 46]), array([-74, -90, -79, -92, -86, -96, -84, -46])]
[array([95, 98, 89, 91, 51, 45, 59, 94]), array([-95, -98, -89, -91, -51, -45, -59, -94])]

In [27]:
# Batch first and shuffle afterwards: the contents of each batch stay the same and
# only the order of the batches is random.
# THIS IS HOW I USUALLY THINK OF THINGS.
batched_then_shuffled = dataset.batch(8).shuffle(20, seed=5)
for batch in batched_then_shuffled.repeat().take(15):
    print([arr.numpy() for arr in batch])

[array([64, 65, 66, 67, 68, 69, 70, 71]), array([-64, -65, -66, -67, -68, -69, -70, -71])]
[array([96, 97, 98, 99]), array([-96, -97, -98, -99])]
[array([88, 89, 90, 91, 92, 93, 94, 95]), array([-88, -89, -90, -91, -92, -93, -94, -95])]
[array([ 8,  9, 10, 11, 12, 13, 14, 15]), array([ -8,  -9, -10, -11, -12, -13, -14, -15])]
[array([80, 81, 82, 83, 84, 85, 86, 87]), array([-80, -81, -82, -83, -84, -85, -86, -87])]
[array([56, 57, 58, 59, 60, 61, 62, 63]), array([-56, -57, -58, -59, -60, -61, -62, -63])]
[array([24, 25, 26, 27, 28, 29, 30, 31]), array([-24, -25, -26, -27, -28, -29, -30, -31])]
[array([40, 41, 42, 43, 44, 45, 46, 47]), array([-40, -41, -42, -43, -44, -45, -46, -47])]
[array([16, 17, 18, 19, 20, 21, 22, 23]), array([-16, -17, -18, -19, -20, -21, -22, -23])]
[array([48, 49, 50, 51, 52, 53, 54, 55]), array([-48, -49, -50, -51, -52, -53, -54, -55])]
[array([32, 33, 34, 35, 36, 37, 38, 39]), array([-32, -33, -34, -35, -36, -37, -38, -39])]
[array([0, 1, 2, 3, 4, 5, 6, 7]), array([ 0, -1, -2, -3, -4, -5, -6, -7])]

In [28]:
# If you shuffle before you create batches, e.g. to create train/test splits, make sure
# the shuffle does not get redone after each epoch.  If it does, elements that were in
# training batches on epoch 1 may end up in validation batches on epoch 2.
# To avoid this, pass reshuffle_each_iteration=False.
fixed_shuffle = dataset.shuffle(20, seed=5, reshuffle_each_iteration=False)
for batch in fixed_shuffle.batch(8).repeat().take(15):
    print([arr.numpy() for arr in batch])

[array([11, 20,  6, 22, 16, 19, 15, 24]), array([-11, -20,  -6, -22, -16, -19, -15, -24])]
[array([ 7, 28,  5, 27,  1, 14, 10, 17]), array([ -7, -28,  -5, -27,  -1, -14, -10, -17])]
[array([25, 18, 32,  9,  4, 40,  0, 39]), array([-25, -18, -32,  -9,  -4, -40,   0, -39])]
[array([ 3, 31, 26, 37, 34, 29, 35, 13]), array([ -3, -31, -26, -37, -34, -29, -35, -13])]
[array([33, 45, 51,  8, 48, 38, 30,  2]), array([-33, -45, -51,  -8, -48, -38, -30,  -2])]
[array([46, 55, 41, 47, 63, 50, 42, 57]), array([-46, -55, -41, -47, -63, -50, -42, -57])]
[array([61, 52, 36, 21, 64, 70, 12, 72]), array([-61, -52, -36, -21, -64, -70, -12, -72])]
[array([23, 66, 59, 56, 44, 79, 77, 78]), array([-23, -66, -59, -56, -44, -79, -77, -78])]
[array([76, 49, 85, 65, 81, 74, 67, 86]), array([-76, -49, -85, -65, -81, -74, -67, -86])]
[array([75, 69, 93, 73, 95, 62, 90, 94]), array([-75, -69, -93, -73, -95, -62, -90, -94])]
[array([83, 88, 53, 92, 82, 43, 98, 71]), array([-83, -88, -53, -92, -82, -43, -98, -71])]

## A crude way to do train/val/test splits

I don't really recommend this except if you're seriously pressed for time.  
A far better way is to use either scikit-learn's `train_test_split` or 
`tf.data.experimental.sample_from_datasets` or `tf.data.experimental.rejection_resample`.

In [33]:
# Sizes of the three splits; together they account for all 100 elements of `dataset`.
num_train = 75
num_val = 10
num_test = 15
# The training split is simply the first num_train elements, in their original order.
train_ds = dataset.take(num_train)

In [32]:
# Everything after the first num_train elements, i.e. whatever is not in the training split.
nontrain_ds = dataset.skip(num_train)

In [35]:
# Print every (increasing, decreasing) pair in the training split.
for elem in train_ds:
    pair = [arr.numpy() for arr in elem]
    print(pair)
    

[0, 0]
[1, -1]
[2, -2]
[3, -3]
[4, -4]
[5, -5]
[6, -6]
[7, -7]
[8, -8]
[9, -9]
[10, -10]
[11, -11]
[12, -12]
[13, -13]
[14, -14]
[15, -15]
[16, -16]
[17, -17]
[18, -18]
[19, -19]
[20, -20]
[21, -21]
[22, -22]
[23, -23]
[24, -24]
[25, -25]
[26, -26]
[27, -27]
[28, -28]
[29, -29]
[30, -30]
[31, -31]
[32, -32]
[33, -33]
[34, -34]
[35, -35]
[36, -36]
[37, -37]
[38, -38]
[39, -39]
[40, -40]
[41, -41]
[42, -42]
[43, -43]
[44, -44]
[45, -45]
[46, -46]
[47, -47]
[48, -48]
[49, -49]
[50, -50]
[51, -51]
[52, -52]
[53, -53]
[54, -54]
[55, -55]
[56, -56]
[57, -57]
[58, -58]
[59, -59]
[60, -60]
[61, -61]
[62, -62]
[63, -63]
[64, -64]
[65, -65]
[66, -66]
[67, -67]
[68, -68]
[69, -69]
[70, -70]
[71, -71]
[72, -72]
[73, -73]
[74, -74]


In [36]:
# The validation split is the first num_val elements that follow the training split.
val_ds = nontrain_ds.take(num_val)
for elem in val_ds:
    pair = [arr.numpy() for arr in elem]
    print(pair)

[75, -75]
[76, -76]
[77, -77]
[78, -78]
[79, -79]
[80, -80]
[81, -81]
[82, -82]
[83, -83]
[84, -84]


In [37]:
# The test split is the num_test elements that follow the validation split.
# Capping it with take(num_test) keeps the split at its declared size even if the
# dataset holds more than num_train + num_val + num_test elements.
test_ds = nontrain_ds.skip(num_val).take(num_test)
for elem in test_ds:
    print([arr.numpy() for arr in elem])

[85, -85]
[86, -86]
[87, -87]
[88, -88]
[89, -89]
[90, -90]
[91, -91]
[92, -92]
[93, -93]
[94, -94]
[95, -95]
[96, -96]
[97, -97]
[98, -98]
[99, -99]
