# Deep Learning for the Galaxy Zoo challenge with TensorFlow and TPUs

[Galaxy Zoo](https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge) was a machine learning competition on Kaggle in which competitors had to build a model that predicts the morphology features of a galaxy from its `JPG` image. In a [previous notebook](https://colab.research.google.com/drive/1i6ghXgyQPcyLn5Q9-c7QbIsMEpY4vqQQ), we used the `fastai` library to build a `ResNet18` model for this task. Due to the number of training images and the complexity of the model, we could not fit it on all the images. We instead had to rely on a small sample of images (about 10%) in order to train a model in a reasonable amount of time.

With this notebook, the aim is to see if a combination of the TensorFlow package and *Tensor Processing Units* (TPUs) (specialised hardware built for machine learning by Google) can help us train `ResNet18` or even deeper models on the entire training dataset in a reasonable amount of time.

The `fastai` package is built on top of another deep learning package called *PyTorch*. Thus, with this notebook, we will be exploring some TensorFlow syntax (Keras to be specific) and understanding how to configure a notebook to run on TPUs.

## Step 0: Prerequisites
Let us begin by mounting our Google Drive (where the data is stored) and loading the required packages.

In [None]:
from google.colab import files, drive
import sys

drive.mount("/content/gdrive", force_remount=True)
root_dir = "/content/gdrive/My Drive/Colab Notebooks/fastai_2022/data/galaxy-zoo-the-galaxy-challenge"
sys.path.append(root_dir)

Let us now do a quick check to ensure that we are able to access the contents of this directory in the notebook.

In [None]:
import os
os.listdir(root_dir)

In [None]:
from pathlib import Path
from sklearn.model_selection import train_test_split
from tensorflow.data import Dataset

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

## Step 1: Load data

We know that the training images are in the directory `images_training_rev1`. Let us begin by creating a `tf.data.Dataset` object that can access all the training images. 

`tf.data.Dataset` is an API for writing efficient input data pipelines. It provides a convenient way to:
- Create a data source from input data.
- Apply transformations to preprocess the data.
- Iterate over the data to process its elements.

This API can handle data of multiple types like images, text, etc.
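
The three operations listed above can be sketched on toy data as follows. This is only an illustration: the integer values stand in for real inputs and are not part of the Galaxy Zoo data.

```python
import tensorflow as tf

# 1. Create a data source from in-memory data (toy values, not galaxy images).
toy_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

# 2. Apply a transformation to every element.
toy_ds = toy_ds.map(lambda x: x * 10)

# 3. Iterate over the dataset to consume its elements.
values = [int(x.numpy()) for x in toy_ds]
print(values)  # [10, 20, 30, 40]
```

The same source, transform, iterate pattern is what the rest of this step applies to file paths and images.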

In [None]:
root_dir: Path = Path(root_dir)
training_img_dir: Path = Path(root_dir/"images_training_rev1")

In [None]:
list_ds: Dataset = tf.data.Dataset.list_files(f"{str(training_img_dir)}/*.jpg", shuffle=False)

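Let us shuffle the image paths so that the split into training and validation sets does not depend on the order in which the files are listed. Shuffling also gives us a way to sanity-check the model later: its results should not differ significantly between different shuffles. Setting `reshuffle_each_iteration` to `False` keeps the shuffled order fixed every time the dataset is iterated, instead of reshuffling on each epoch of model training.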

In [None]:
image_count: int = len(list(training_img_dir.glob("*.jpg")))
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
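
As a small sketch of what `reshuffle_each_iteration=False` does, the snippet below shuffles a toy dataset of ten integers and iterates over it twice. The `seed` argument and the `toy_` names are assumptions added for illustration; they are not used in this notebook.

```python
import tensorflow as tf

# Toy stand-in for the dataset of file paths.
toy_ds = tf.data.Dataset.range(10)

# Shuffle once; the resulting order is reused on every iteration.
fixed_order_ds = toy_ds.shuffle(10, seed=0, reshuffle_each_iteration=False)

first_pass = [int(x) for x in fixed_order_ds.as_numpy_iterator()]
second_pass = [int(x) for x in fixed_order_ds.as_numpy_iterator()]

print(first_pass == second_pass)  # True: both passes see the same order
```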

Now we list the paths of the first five images.

In [None]:
for f in list_ds.take(5):
    print(f.numpy())

Let us split the training images into training and validation sets.

In [None]:
val_factor: float = 0.2  # fraction of images held out for validation
val_size: int = int(image_count * val_factor)
train_ds: Dataset = list_ds.skip(val_size)
valid_ds: Dataset = list_ds.take(val_size)

In [None]:
for f in train_ds.take(5):
    print(f.numpy())
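
The `skip`/`take` split only yields disjoint sets when the shuffled order is the same for both calls. A minimal sketch on toy integers, assuming a fixed `seed` (not used in the notebook) and hypothetical `toy_` names:

```python
import tensorflow as tf

toy_count = 10
toy_val_factor = 0.2
toy_val_size = int(toy_count * toy_val_factor)

# Fixed shuffle order, so skip() and take() see the same sequence.
toy_ds = tf.data.Dataset.range(toy_count).shuffle(
    toy_count, seed=0, reshuffle_each_iteration=False
)

toy_train = [int(x) for x in toy_ds.skip(toy_val_size).as_numpy_iterator()]
toy_valid = [int(x) for x in toy_ds.take(toy_val_size).as_numpy_iterator()]

# No element appears in both sets, and together they cover all elements.
print(set(toy_train).isdisjoint(toy_valid))
```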

We know that the outputs for these images are available as columns in `training_solutions_rev1.csv`. Let us first load and peek at this table.

In [None]:
training_outputs: pd.DataFrame = pd.read_csv(root_dir/"training_solutions_rev1.csv")
training_outputs.head()

The values in the `GalaxyID` column match the filenames. So, let us modify these values to full paths for the corresponding files.

In [None]:
# Build the full image path for each galaxy, then drop the raw ID column.
training_outputs = (training_outputs.assign(GalaxyImageFile=lambda x: [f"{training_img_dir}/{galaxy_id}.jpg" for galaxy_id in x["GalaxyID"]])
                    .drop("GalaxyID", axis=1))
training_outputs.head()
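
To make the transformation concrete, here is the same `assign`/`drop` pattern on a two-row toy table. The directory, IDs, and the `feature_a` column are made up for illustration and do not come from the competition data.

```python
import pandas as pd

toy_dir = "/data/images"  # hypothetical directory
toy = pd.DataFrame({"GalaxyID": [101, 202], "feature_a": [0.25, 0.75]})

# Replace the numeric ID with the full path of the matching image file.
toy = (toy.assign(GalaxyImageFile=lambda df: [f"{toy_dir}/{gid}.jpg" for gid in df["GalaxyID"]])
          .drop("GalaxyID", axis=1))

print(toy["GalaxyImageFile"].tolist())
```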

Next, we split the data into training and validation sets.

In [None]:
val_factor: float = 0.2
x_train: list
y_train: np.ndarray
x_valid: list
y_valid: np.ndarray
x_train, x_valid, y_train, y_valid = train_test_split(
    training_outputs.loc[:, "GalaxyImageFile"].to_list(), 
    training_outputs.drop("GalaxyImageFile", axis=1).values, 
    test_size=val_factor, random_state=0)

Here `x_{train|valid}` holds the image file paths built from the Galaxy IDs and `y_{train|valid}` holds the outputs for each image. Let us now create datasets of these file paths.
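
The property we rely on is that `train_test_split` shuffles the paths and the outputs together, so each path stays aligned with its own row of outputs. A quick sketch with toy data (the file names and output values are invented):

```python
import numpy as np
from sklearn.model_selection import train_test_split

toy_paths = [f"img_{i}.jpg" for i in range(10)]
toy_outputs = np.arange(30, dtype=float).reshape(10, 3)  # row i belongs to img_i

p_train, p_valid, o_train, o_valid = train_test_split(
    toy_paths, toy_outputs, test_size=0.2, random_state=0
)

# Each path is still paired with its original row of outputs.
for path, row in zip(p_train, o_train):
    idx = int(path.split("_")[1].split(".")[0])
    assert np.array_equal(row, toy_outputs[idx])

print(len(p_train), len(p_valid), o_train.shape, o_valid.shape)
```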

In [None]:
train_id_ds: Dataset = tf.data.Dataset.from_tensor_slices(x_train)
valid_id_ds: Dataset = tf.data.Dataset.from_tensor_slices(x_valid)

Now, let us define a function to load the images. Along with it, let us define the desired dimensions of the image to be used for training the model and the number of images to process in each batch.

In [None]:
batch_size: int = 32
img_height: int = 180
img_width: int = 180

In [None]:
def decode_img(img: tf.Tensor) -> tf.Tensor:
    # Convert the compressed string to a 3D uint8 tensor.
    img = tf.io.decode_jpeg(img, channels=3)
    # Resize the image to the desired size. The output is a 3-D float Tensor
    # of shape `[new_height, new_width, channels]`.
    return tf.image.resize(img, [img_height, img_width])

In [None]:
def load_image(fp: tf.Tensor) -> tf.Tensor:
    file_contents: tf.Tensor = tf.io.read_file(fp)
    img: tf.Tensor = decode_img(file_contents)
    return img
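
To see what decoding and resizing do to shapes and dtypes without touching the Drive, we can encode a synthetic image to JPEG bytes in memory and run the same two calls on it. The 64x48 black image and the `toy_` names are assumptions for illustration only.

```python
import tensorflow as tf

toy_height, toy_width = 180, 180

# Synthetic 64x48 RGB image, encoded to JPEG bytes in memory.
raw = tf.zeros([64, 48, 3], dtype=tf.uint8)
jpeg_bytes = tf.io.encode_jpeg(raw)

# Decoding gives a uint8 tensor with the original spatial size.
decoded = tf.io.decode_jpeg(jpeg_bytes, channels=3)

# Resizing gives a float32 tensor with the requested spatial size.
resized = tf.image.resize(decoded, [toy_height, toy_width])

print(decoded.shape, decoded.dtype)
print(resized.shape, resized.dtype)
```

Note that `tf.image.resize` returns floats in the original 0 to 255 range; it does not rescale pixel values to 0 to 1.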

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_img_ds: Dataset = train_id_ds.map(load_image, num_parallel_calls=AUTOTUNE)
valid_img_ds: Dataset = valid_id_ds.map(load_image, num_parallel_calls=AUTOTUNE)

Next, we create tensors of the outputs.

In [None]:
train_output_ds: Dataset = tf.data.Dataset.from_tensor_slices(tf.cast(y_train, tf.float32))
valid_output_ds: Dataset = tf.data.Dataset.from_tensor_slices(tf.cast(y_valid, tf.float32))

Finally, let us combine the images and outputs into `(image, outputs)` pairs.

In [None]:
train_ds: Dataset = tf.data.Dataset.zip((train_img_ds, train_output_ds))
valid_ds: Dataset = tf.data.Dataset.zip((valid_img_ds, valid_output_ds))
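
`Dataset.zip` pairs elements by position, so it only works here because the image dataset and the output dataset are in the same order. A minimal sketch with toy tensors (shapes and values are invented):

```python
import tensorflow as tf

# Three tiny 4x4 RGB "images" and three 2-value output rows.
toy_images = tf.data.Dataset.from_tensor_slices(tf.zeros([3, 4, 4, 3]))
toy_outputs = tf.data.Dataset.from_tensor_slices(
    tf.constant([[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]])
)

# Element i of the result is (image i, outputs i).
toy_pairs = tf.data.Dataset.zip((toy_images, toy_outputs))

shapes = [(tuple(img.shape), tuple(out.shape)) for img, out in toy_pairs]
print(shapes)
```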

Let us print some information about one of the input-output pairs.

In [None]:
for image, output in train_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", output.numpy())

*Note*: We tried to load the data using the code structure and syntax provided in the [TensorFlow tutorial](https://www.tensorflow.org/tutorials/load_data/images#using_tfdata_for_finer_control), but as the value passed by `map()` in each iteration is a *symbolic tensor*, we could not use it to fetch the outputs from a dataframe directly.

Next, we configure the dataset for improved performance i.e.:
- Ensure it is well-shuffled
- Make it batched
- Ensure that batches are available as soon as possible.

In [None]:
def configure_for_performance(ds: Dataset) -> Dataset:
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds
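
The effect of this chain of calls is easiest to see on a toy dataset. The sketch below applies the same four steps to ten integers with a hypothetical batch size of 4; the last batch is smaller because 10 is not a multiple of 4.

```python
import tensorflow as tf

toy_batch_size = 4

toy_ds = tf.data.Dataset.range(10)
toy_ds = toy_ds.cache()                     # keep elements in memory after the first pass
toy_ds = toy_ds.shuffle(buffer_size=10)     # shuffle within a buffer of 10 elements
toy_ds = toy_ds.batch(toy_batch_size)       # group elements into batches
toy_ds = toy_ds.prefetch(tf.data.AUTOTUNE)  # prepare the next batch while the current one is used

batches = [b.numpy().tolist() for b in toy_ds]
batch_sizes = [len(b) for b in batches]
print(batch_sizes)  # [4, 4, 2]
```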

In [None]:
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(valid_ds)

As a final check, let us see one of the training images with its outputs.

In [None]:
image_batch, label_batch = next(iter(train_ds))


plt.figure(figsize=(10, 10))
for i in range(1):
    print(f"Output: {label_batch[i].numpy().tolist()}")
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].numpy().astype("uint8"))
    plt.axis("off")

## References
- [Load and preprocess images: Using tf.data for finer control](https://www.tensorflow.org/tutorials/load_data/images#using_tfdata_for_finer_control)
- [Get string value from a tensor in `Dataset.map()`](https://stackoverflow.com/questions/56122670/how-to-get-string-value-out-of-tf-tensor-which-dtype-is-string)
- [galaxy_zoo_Xception](https://www.kaggle.com/code/hironobukawaguchi/galaxy-zoo-xception)