In [1]:
import tensorflow as tf
import numpy as np
import os 

# An empty CUDA_VISIBLE_DEVICES exposes no GPU to this process; the intent is
# to keep TensorFlow on the CPU for the small examples of this notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

## Basics
We can create a basic dataset from array-like objects (TF tensors, NumPy arrays, lists, tuples, even pandas DataFrames) using `from_tensor_slices()`. The first dimension of the input is sliced away and becomes the dataset dimension.

In [2]:
# Slice a 3x2 array along its first axis: every row becomes one dataset element.
rows = np.arange(6).reshape(3, 2)
dataset = tf.data.Dataset.from_tensor_slices(rows)
dataset

<TensorSliceDataset shapes: (2,), types: tf.int64>

In [3]:
# Build an iterator explicitly and pull its first element.
dataset_iterator = iter(dataset)
next(dataset_iterator)

<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>

A `Dataset` is an iterable, so we can loop through its elements by creating an *iterator* and requesting the *next* element. This is done easily (and transparently) with a `for` loop:

In [4]:
# The for loop creates the iterator and advances it behind the scenes.
for element in dataset:
    print(element)

tf.Tensor([0 1], shape=(2,), dtype=int64)
tf.Tensor([2 3], shape=(2,), dtype=int64)
tf.Tensor([4 5], shape=(2,), dtype=int64)


*Note*: `__iter__` relies on Python-level iteration, so it is only available when executing eagerly (or inside a `tf.function`).

The elements are tensors; if we are only interested in their values we can use `as_numpy_iterator()`:

In [5]:
# Dataset.range(10) yields the integers 0..9, much like
# tf.data.Dataset.from_tensor_slices(np.arange(10)).
dataset = tf.data.Dataset.range(10)
values = dataset.as_numpy_iterator()
list(values)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

We can perform some actions on a Dataset, for example using the `shuffle()` method:

In [6]:
# Shuffle through a 3-element buffer; every pass below creates a fresh iterator.
shuffled_dt = dataset.shuffle(buffer_size=3)
for _pass in range(2):
    print(list(shuffled_dt.as_numpy_iterator()))

[1, 3, 0, 2, 5, 6, 8, 4, 9, 7]
[1, 3, 2, 0, 5, 6, 8, 7, 4, 9]


Every time a new iterator is created, the data are shuffled.

*Note*: The transformation does not happen in place but a new `Dataset` is returned.

We can also **batch** the elements of a dataset together:

In [7]:
# Group consecutive elements three at a time; the last batch holds the remainder.
batched_dt = shuffled_dt.batch(batch_size=3)
for _pass in range(2):
    print(list(batched_dt.as_numpy_iterator()))

[array([0, 1, 3]), array([4, 5, 2]), array([8, 9, 6]), array([7])]
[array([0, 2, 3]), array([5, 1, 4]), array([7, 8, 6]), array([9])]


We can apply a function element-wise with the `map()` method:

In [194]:
def fn(x):
    """Return the sum of all the entries of ``x``."""
    return tf.reduce_sum(x)


# map() calls fn on every element of batched_dt, i.e. on every batch.
transformed_dt = batched_dt.map(fn)
batch_sums = list(transformed_dt.as_numpy_iterator())
print(batch_sums)

[3, 13, 20, 9]


With the `cache()` method we cache the data in memory. This is done after the first (complete) iteration:

In [52]:
cached_dataset = shuffled_dt.cache()

# Partial reads: the source is never iterated to its end, so nothing is cached yet.
for _partial_read in range(2):
    print(list(cached_dataset.take(4).as_numpy_iterator()))

# The first complete pass fills the cache; the second one replays the same order.
for _full_read in range(2):
    print(list(cached_dataset.as_numpy_iterator()))

[1, 2, 3, 0]
[2, 3, 1, 0]
[0, 1, 4, 5, 6, 7, 3, 9, 8, 2]
[0, 1, 4, 5, 6, 7, 3, 9, 8, 2]


So... Shuffle after caching!

# Reading files

Let's create some CSV files.

In [250]:
# Write six small CSV files (columns x1, x2, y) used by the rest of the notebook.
# `os` and `np` come from the import cell at the top of the notebook.
data_dir = 'data/dataset/'
os.makedirs(data_dir, exist_ok=True)

N_el = 3      # rows per file
N_files = 6   # number of files to write
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so that every run writes the same values

# x2 is the row index 0..N_el-1; it is the same in every file.
x2 = np.arange(N_el).reshape((N_el, 1))

for i in range(N_files):
    x1 = np.full((N_el, 1), float(i))  # the file index, repeated on every row
    y = rng.random((N_el, 1))          # uniform samples in [0, 1)
    table = np.hstack([x1, x2, y])
    # `filename` is left pointing at the last file and is reused by later cells.
    filename = os.path.join(data_dir, "data_%d.csv" % i)
    # savetxt writes the header as a comment line ("# x1,x2,y").
    np.savetxt(filename, table, delimiter=",", header="x1,x2,y")

In [251]:
# List the files found in data_dir.
! ls {data_dir}

data_0.csv data_1.csv data_2.csv data_3.csv data_4.csv data_5.csv


In [253]:
# Show the content of `filename`, the last file written by the generation loop.
! cat {filename}

# x1,x2,y
5.000000000000000000e+00,0.000000000000000000e+00,1.510211592071286635e-02
5.000000000000000000e+00,1.000000000000000000e+00,6.108007938288444461e-01
5.000000000000000000e+00,2.000000000000000000e+00,3.874759526742046489e-01


### Dataset of file paths
We can create a dataset that iterates through the filenames:

In [254]:
# Match only the CSV files, so that any other file living in data_dir is not
# fed to the CSV pipeline built in the following cells.
# list_files() shuffles the matches by default, hence the random order below.
filename_dt = tf.data.Dataset.list_files(os.path.join(data_dir, "*.csv"))
list(filename_dt.as_numpy_iterator())

[b'data/dataset/data_0.csv',
 b'data/dataset/data_4.csv',
 b'data/dataset/data_2.csv',
 b'data/dataset/data_1.csv',
 b'data/dataset/data_3.csv',
 b'data/dataset/data_5.csv']

### Dataset comprising lines from files
More importantly we can create a dataset that iterates through the lines of a **single file**:

In [255]:
# `filename` still points at the last CSV file written by the generation loop.
print("Filename:", filename)
single_file_dt = tf.data.TextLineDataset(filenames=filename)
file_lines = single_file_dt.as_numpy_iterator()
list(file_lines)

Filename: data/dataset/data_5.csv


[b'# x1,x2,y',
 b'5.000000000000000000e+00,0.000000000000000000e+00,1.510211592071286635e-02',
 b'5.000000000000000000e+00,1.000000000000000000e+00,6.108007938288444461e-01',
 b'5.000000000000000000e+00,2.000000000000000000e+00,3.874759526742046489e-01']

**Yuck!** There is the header! As it is the first line, we can skip it easily.

In [256]:
# Drop the first element of the dataset, i.e. the header line of the file.
no_header_dt = single_file_dt.skip(count=1)
data_lines = no_header_dt.as_numpy_iterator()
list(data_lines)

[b'5.000000000000000000e+00,0.000000000000000000e+00,1.510211592071286635e-02',
 b'5.000000000000000000e+00,1.000000000000000000e+00,6.108007938288444461e-01',
 b'5.000000000000000000e+00,2.000000000000000000e+00,3.874759526742046489e-01']

With the same API we can create a dataset iterating through lines from multiple files:

In [257]:
# A dataset of filenames is accepted as well: the files are read one after the other.
multi_file_dt = tf.data.TextLineDataset(filename_dt)
# take(6) stops after six lines instead of reading every file just to slice the list.
list(multi_file_dt.take(6).as_numpy_iterator())

[b'# x1,x2,y',
 b'1.000000000000000000e+00,0.000000000000000000e+00,1.761952109868600846e-01',
 b'1.000000000000000000e+00,1.000000000000000000e+00,2.937346730679679663e-02',
 b'1.000000000000000000e+00,2.000000000000000000e+00,3.545499146055860473e-01',
 b'# x1,x2,y',
 b'4.000000000000000000e+00,0.000000000000000000e+00,4.467986475199151597e-01']

This time the header is repeated several times across the dataset. To remove it we can filter out all lines starting with "#".

In [258]:
def is_data_line(line):
    """Return a boolean tensor that is False when ``line`` starts with "#" (a header line)."""
    first_char = tf.strings.substr(line, 0, 1)
    return tf.not_equal(first_char, "#")


no_header_dt = multi_file_dt.filter(is_data_line)
# take(7) limits the preview inside the pipeline rather than slicing a fully materialized list.
list(no_header_dt.take(7).as_numpy_iterator())

[b'1.000000000000000000e+00,0.000000000000000000e+00,1.761952109868600846e-01',
 b'1.000000000000000000e+00,1.000000000000000000e+00,2.937346730679679663e-02',
 b'1.000000000000000000e+00,2.000000000000000000e+00,3.545499146055860473e-01',
 b'2.000000000000000000e+00,0.000000000000000000e+00,1.222223120351877412e-01',
 b'2.000000000000000000e+00,1.000000000000000000e+00,8.895071154196134700e-01',
 b'2.000000000000000000e+00,2.000000000000000000e+00,6.112288628589750417e-01',
 b'3.000000000000000000e+00,0.000000000000000000e+00,6.146688086690119679e-01']

### Interleave

In the previous dataset the lines were iterated sequentially, one file after the other. If we were dealing with huge files, this behaviour could affect our ability to randomise the content of the dataset. Using `interleave()` instead, we can iterate through groups of files, take a block of lines from each file in the group, and stack them together.

In [259]:
def read_data_lines(path):
    """Return the lines of the file at ``path`` without its first (header) line."""
    return tf.data.TextLineDataset(path).skip(1)


# Read 3 files at a time, taking 2 consecutive lines from each of them in turn...
interleaved_lines = filename_dt.interleave(
    read_data_lines,
    cycle_length=3,
    block_length=2,
)
# ...then shuffle the merged stream through a 5-element buffer.
interleave_dt = interleaved_lines.shuffle(5)
list(interleave_dt.as_numpy_iterator())

[b'3.000000000000000000e+00,0.000000000000000000e+00,6.146688086690119679e-01',
 b'5.000000000000000000e+00,1.000000000000000000e+00,6.108007938288444461e-01',
 b'5.000000000000000000e+00,2.000000000000000000e+00,3.874759526742046489e-01',
 b'2.000000000000000000e+00,1.000000000000000000e+00,8.895071154196134700e-01',
 b'3.000000000000000000e+00,2.000000000000000000e+00,9.307465018518346067e-01',
 b'5.000000000000000000e+00,0.000000000000000000e+00,1.510211592071286635e-02',
 b'2.000000000000000000e+00,2.000000000000000000e+00,6.112288628589750417e-01',
 b'2.000000000000000000e+00,0.000000000000000000e+00,1.222223120351877412e-01',
 b'4.000000000000000000e+00,0.000000000000000000e+00,4.467986475199151597e-01',
 b'0.000000000000000000e+00,0.000000000000000000e+00,1.400779891641089625e-01',
 b'1.000000000000000000e+00,1.000000000000000000e+00,2.937346730679679663e-02',
 b'1.000000000000000000e+00,0.000000000000000000e+00,1.761952109868600846e-01',
 b'0.000000000000000000e+00,1.0000000000

### Decoding bytes
To read the above lines and decode the numeric data we should map the `tf.io.decode_csv()` function:

In [260]:
def decode_record(line):
    """Parse one CSV line into a ``(features, target)`` pair.

    Args:
        line: scalar string tensor holding three comma-separated numbers.

    Returns:
        x: float32 tensor of shape (2,) made of the first two fields.
        y: scalar float32 tensor holding the last field.
    """
    # The three float defaults make decode_csv parse every column as float32.
    record = tf.io.decode_csv(line, record_defaults=[0., 0., 0.])
    x = tf.stack(record[:-1])
    # The target is already a single scalar tensor: there is nothing to stack.
    y = record[-1]
    return x, y


decoded_dt = interleave_dt.map(decode_record)
list(decoded_dt.as_numpy_iterator())

[(array([2., 1.], dtype=float32), 0.8895071),
 (array([4., 1.], dtype=float32), 0.884114),
 (array([1., 0.], dtype=float32), 0.1761952),
 (array([2., 0.], dtype=float32), 0.12222231),
 (array([4., 2.], dtype=float32), 0.22054629),
 (array([4., 0.], dtype=float32), 0.44679865),
 (array([1., 2.], dtype=float32), 0.3545499),
 (array([0., 0.], dtype=float32), 0.140078),
 (array([2., 2.], dtype=float32), 0.6112289),
 (array([1., 1.], dtype=float32), 0.029373467),
 (array([5., 0.], dtype=float32), 0.015102115),
 (array([0., 2.], dtype=float32), 0.30262077),
 (array([5., 2.], dtype=float32), 0.38747597),
 (array([3., 2.], dtype=float32), 0.9307465),
 (array([3., 1.], dtype=float32), 0.8132554),
 (array([0., 1.], dtype=float32), 0.15995905),
 (array([5., 1.], dtype=float32), 0.6108008),
 (array([3., 0.], dtype=float32), 0.6146688)]