# Computational Exercise 7: Transfer Learning

**Note: this assignment may optionally be completed in groups of two students.**

**We recommend completing this exercise in Google Colab:**
- running it elsewhere may produce errors
- Colab gives you access to free GPU resources

---
In this exercise, we'll use transfer learning to train Inception v3, the convolutional neural network (CNN) from our earlier activity, to classify images of flowers. This is the same approach Esteva et al. used to classify skin lesions.

As in that activity, we'll import `tensorflow` as well as `tensorflow_hub`, which lets us load an Inception v3 model that has already been trained (i.e. *pre*-trained) on ImageNet.
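The core idea of transfer learning can be illustrated without Inception v3 at all. The sketch below is a toy NumPy analogy, not the method used in this exercise or by Esteva et al.: a fixed random projection (the hypothetical `W_frozen`) stands in for the pre-trained network and is never updated, and only a new softmax classification head (`W_head`, `b_head`) is trained on the features it produces. The dataset is synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a pre-trained feature extractor: a fixed random projection
# followed by ReLU. These "frozen" weights are never updated below.
W_frozen = rng.normal(size=(20, 64))
W_frozen_initial = W_frozen.copy()  # kept only to confirm nothing changed

def extract_features(x):
    return np.maximum(x @ W_frozen, 0.0)

# Synthetic 3-class dataset: each class is a cloud of points around its own center.
num_classes, n_per_class = 3, 100
centers = rng.normal(scale=3.0, size=(num_classes, 20))
X = np.concatenate([centers[c] + rng.normal(size=(n_per_class, 20))
                    for c in range(num_classes)])
y = np.repeat(np.arange(num_classes), n_per_class)

# Compute features once with the frozen extractor, then standardize them.
feats = extract_features(X)
feats = (feats - feats.mean(axis=0)) / (feats.std(axis=0) + 1e-8)

# New, trainable classification head: softmax regression on the features.
W_head = np.zeros((feats.shape[1], num_classes))
b_head = np.zeros(num_classes)
Y = np.eye(num_classes)[y]  # one-hot labels

lr = 0.1
for step in range(200):
    logits = feats @ W_head + b_head
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    grad = (probs - Y) / len(y)  # gradient of mean cross-entropy w.r.t. logits
    W_head -= lr * (feats.T @ grad)
    b_head -= lr * grad.sum(axis=0)

accuracy = np.mean((feats @ W_head + b_head).argmax(axis=1) == y)
print(f"training accuracy of the new head: {accuracy:.2f}")
print("frozen weights unchanged:", np.array_equal(W_frozen, W_frozen_initial))
```

In the real exercise the frozen part is the pre-trained Inception v3 network and the head is a new layer with one output per flower class; the principle of reusing learned features and fitting only a small new classifier is the same.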

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_hub as hub
```

Let's begin by using `tensorflow_hub` to load Inception v3. I found this URL by searching available models on [TensorFlow Hub](https://tfhub.dev). This step may take a minute or two. Once it's complete, we'll have the pre-trained Inception v3 model as a [SavedModel object](https://www.tensorflow.org/guide/saved_model) called `inception_v3`.

```python
INCEPTION_V3_URL = "https://tfhub.dev/google/imagenet/inception_v3/classification/4"
inception_v3 = hub.load(INCEPTION_V3_URL)
```

```python
# Download and extract the flowers dataset
# (one subdirectory of photos per flower class)
data_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
```
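By default (`labels='inferred'`), `image_dataset_from_directory` treats each subdirectory of `data_root` as one class and assigns integer labels in alphanumeric order of the subdirectory names. The sketch below mimics that rule with the standard library only; it builds a hypothetical miniature directory tree using the five flower folder names, with empty placeholder files instead of real images, and does not call Keras.

```python
import os
import tempfile

# Hypothetical miniature copy of the flower_photos layout: one subdirectory per class.
class_dirs = ["tulips", "daisy", "sunflowers", "roses", "dandelion"]
root = tempfile.mkdtemp()
for name in class_dirs:
    os.makedirs(os.path.join(root, name))
    # One empty placeholder file per class (not a real image).
    open(os.path.join(root, name, "example.jpg"), "w").close()

# A stray top-level file is not a class folder, so it is skipped below.
open(os.path.join(root, "LICENSE.txt"), "w").close()

# Class names are the subdirectory names in alphanumeric order;
# each name's position in that list becomes its integer label.
class_names = sorted(
    d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
)
label_of = {name: i for i, name in enumerate(class_names)}

print(class_names)  # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print(label_of)
```

With the real dataset, `train_ds.class_names` should list the same five names in the same order, so label `0` corresponds to `daisy` and label `4` to `tulips`.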

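```python
# Hold out 20% of the images for validation. Both calls must use the same
# seed so that the training and validation subsets do not overlap.
# Images are resized to 299x299, the input size Inception v3 expects.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(299, 299),
    batch_size=32
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(299, 299),
    batch_size=32
)
```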

```
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
```
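The counts above follow from simple arithmetic: 20% of 3,670 files is 734 for validation, leaving 2,936 for training, and with `batch_size=32` that gives 92 training batches (the last one partial) and 23 validation batches. Because `train_ds` and `val_ds` come from two separate calls, both need the same `seed`; otherwise the two shuffles differ and validation images can end up in the training set. The sketch below is a simplified NumPy model of the split (shuffle once with the seed, then hold out the last 734 indices); the helper `split_indices` is hypothetical, and Keras's internal implementation may differ in detail.

```python
import math
import numpy as np

num_files = 3670
validation_split = 0.2
batch_size = 32

num_val = int(validation_split * num_files)  # 734
num_train = num_files - num_val              # 2936

def split_indices(seed):
    """Hypothetical helper: shuffle file indices with `seed`, then use the
    last num_val as validation and the rest as training."""
    idx = np.arange(num_files)
    np.random.RandomState(seed).shuffle(idx)
    return idx[:-num_val], idx[-num_val:]

# Two separate calls, as with subset="training" and subset="validation".
train_idx, _ = split_indices(seed=123)
_, val_idx = split_indices(seed=123)
_, val_idx_other_seed = split_indices(seed=7)

print(num_train, num_val)
print("overlap with same seed:", len(np.intersect1d(train_idx, val_idx)))
print("overlap with different seeds:", len(np.intersect1d(train_idx, val_idx_other_seed)))
print("batches:", math.ceil(num_train / batch_size), math.ceil(num_val / batch_size))
```

With matching seeds the overlap is zero; with mismatched seeds, hundreds of validation indices also appear in the training subset.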


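```python
# Inspect a single batch from the training set. Images are float32 tensors
# of shape (batch_size, 299, 299, 3) with pixel values in [0, 255];
# labels are integer class indices.
for images, labels in train_ds.take(1):
    print(images.shape, images.dtype)
    print(images)
    print(labels)
```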
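Each batch is a pair `(images, labels)`: a float32 tensor of up to 32 images of shape `(299, 299, 3)` with pixel values between 0 and 255, and a vector of integer class indices. TF Hub image models such as this Inception v3 conventionally expect pixel values in [0, 1], so the images typically need rescaling before they reach the model (in TensorFlow, for example, with `tf.keras.layers.Rescaling(1./255)` or `train_ds.map(lambda x, y: (x / 255.0, y))`). The NumPy sketch below shows that step on a small fake batch of random pixels so it runs without downloading anything; the batch size of 8 is an arbitrary assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Fake batch shaped like one element of train_ds: 8 RGB images of 299x299
# pixels with float32 values in [0, 255], plus one integer label (0-4) each.
images = rng.integers(0, 256, size=(8, 299, 299, 3)).astype(np.float32)
labels = rng.integers(0, 5, size=(8,)).astype(np.int32)

# Rescale pixels to [0, 1], the input range TF Hub image models expect.
scaled = images / 255.0

print(images.shape, images.dtype, labels.shape, labels.dtype)
print("min:", float(scaled.min()), "max:", float(scaled.max()))
```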