# Images

In [None]:
import pathlib
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
# Fetch the archive from dataset_url (untar=True asks Keras to unpack it) and
# wrap the returned location in a Path so it can be globbed later on.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(fname='flower_photos', origin=dataset_url, untar=True)
)

In [None]:
# Count the .jpg files that sit exactly one directory below data_dir.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
image_count

### Loading images

In [None]:
# Dataset of file paths (as string tensors). The glob is restricted to .jpg,
# the same pattern used for image_count, because load_image decodes every
# file with decode_jpeg and any other file type in a class folder would fail.
dataset = tf.data.Dataset.list_files(str(data_dir/'*/*.jpg'))

In [None]:
# Peek at the first five file paths produced by the dataset.
for path_bytes in dataset.take(5).as_numpy_iterator():
    print(path_bytes)

In [None]:
def load_image(path, img_height=180, img_width=180):
    """Read a JPEG file from disk and resize it to a fixed size.

    Args:
        path: File path of the image (a string or scalar string tensor).
        img_height: Height of the returned image in pixels. Defaults to 180.
        img_width: Width of the returned image in pixels. Defaults to 180.

    Returns:
        The decoded 3-channel image resized to [img_height, img_width], as
        returned by tf.image.resize.
    """
    binary_format = tf.io.read_file(path)
    # channels=3 requests a 3-channel (RGB) decode for every file.
    image = tf.image.decode_jpeg(binary_format, channels=3)
    return tf.image.resize(image, [img_height, img_width])

In [None]:
# Input pipeline: decode in parallel, cache the decoded images, shuffle them,
# group them in batches of 2 and prefetch while the consumer is busy.
dataset = (
    dataset
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()  # cache only if the dataset fits in memory
    .shuffle(buffer_size=1000)
    .batch(2)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Print the shape of the first five batches.
for batch in dataset.take(5).as_numpy_iterator():
    print(batch.shape)

In [None]:
# Pull a single batch out of the pipeline to experiment with below.
batch_iterator = iter(dataset)
images = next(batch_iterator)
images.shape

### Filters

Each filter is a 3-dimensional tensor (rows × columns × input channels). TensorFlow stacks the individual filters along the last dimension, so a tensor of filters is 4-dimensional, laid out as:

```python
[rows, columns, channels, filters]
```

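where `channels` is the number of channels (feature maps) in the input tensor of a given layer, and `filters` is the number of filters, i.e. the number of output channels.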

In [None]:
# One row of the kernel, laid out as (columns, channels): only the centre
# column is set to one, for every channel.
center_column_row = tf.stack([tf.zeros(3), tf.ones(3), tf.zeros(3)])
# Repeat that row three times -> a 3x3x3 kernel in [rows, columns, channels]
# order whose ones run down the centre column.
hfilter = tf.stack([center_column_row] * 3)
hfilter

In [None]:
# Swap the row and column axes (perm [1, 0, 2]) and keep the channel axis
# last, so the line of ones in hfilter is rotated by 90 degrees within the
# image plane. Permuting with [0, 2, 1] would instead exchange columns with
# channels and produce a kernel that is all ones on a single input channel.
vfilter = tf.transpose(hfilter, perm=[1, 0, 2])
vfilter

Since the values of each filter (for a given pixel and channel) live along the last axis, we are going to stack both filters along the last axis.

In [None]:
# Both kernels are rank 3, so stacking on a new axis 3 (the last one) gives
# the [rows, columns, channels, filters] layout.
filters = tf.stack((hfilter, vfilter), axis=3)
filters.shape

In [None]:
# Apply both filters to the whole batch with a stride of one in every
# dimension; SAME padding keeps the spatial size of the input.
outputs = tf.nn.conv2d(input=images, filters=filters, strides=[1, 1, 1, 1], padding="SAME")

In [None]:
# Show the second image of the batch next to its two convolution feature maps.
# The explicit fig/axes interface replaces the pyplot state machine, and each
# panel gets a title so the figure can be read on its own.
fig, axes = plt.subplots(1, 3, figsize=(20, 60))
# The batch holds floats after resizing; cast to uint8 for RGB display.
axes[0].imshow(images[1].numpy().astype("uint8"))
axes[0].set_title("input image")
# Output channel i corresponds to the i-th kernel stacked in `filters`.
for i, name in enumerate(["hfilter", "vfilter"]):
    axes[i + 1].imshow(outputs[1, :, :, i], cmap="gray")
    axes[i + 1].set_title(f"{name} feature map")
for ax in axes:
    ax.axis("off")

### Pooling

In [None]:
# 2x2 spatial window moved in steps of 2; the batch and channel dimensions
# are left untouched (size 1 for both window and stride).
pool_window = (1, 2, 2, 1)
outputs = tf.nn.max_pool(images, ksize=pool_window, strides=pool_window, padding='SAME')
images.shape, outputs.shape

In [None]:
# One row per image: channel i of image i on the left, the same channel of
# the pooled output on the right.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for i, (before_ax, after_ax) in enumerate(axes):
    before_ax.imshow(images[i, :, :, i], cmap="gray")
    before_ax.axis("off")
    after_ax.imshow(outputs[i, :, :, i], cmap="gray")
    after_ax.axis("off")

###  Depthwise pooling

Pooling along all the channels for each pixel.

In [None]:
# Window and stride cover only the channel axis (1x1 spatially, 3 deep), so
# each pixel keeps the maximum of its 3 channels in a single output channel.
# Pinned to the CPU: pooling along the depth axis has historically had no GPU
# kernel in TensorFlow and raises an error there.
# NOTE(review): confirm whether this still applies to your TensorFlow version.
with tf.device("/CPU:0"):
    outputs = tf.nn.max_pool(images, ksize=(1,1,1,3), strides=(1,1,1,3), padding='SAME')
images.shape, outputs.shape

In [None]:
# One row per image: channel i of image i on the left, the single channel of
# the depthwise-pooled output on the right.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for i, (before_ax, after_ax) in enumerate(axes):
    before_ax.imshow(images[i, :, :, i], cmap="gray")
    before_ax.axis("off")
    after_ax.imshow(outputs[i, :, :, 0], cmap="gray")
    after_ax.axis("off")