In [None]:
%matplotlib inline

Use ImageNet to tag your own photos
===

In this lab, we'll use an existing network (Xception) to tag our own photos.

Load an image
---

We'll use the Keras interface to the Pillow library to load up images. But then we'll have to manipulate them by hand.

The models we use below expect pixel values between $-1$ and $1$ instead of between $0$ and $255$.

- check what `x` looks like: `numpy`'s `.shape` attribute and a simple `print` call will do wonders!

- transform `x` so that its values shift from $[0, 255]$ to $[-1, 1]$.

In [None]:
from keras.preprocessing import image
import numpy as np

def process_img(path, target_size=(299, 299)):
    """Load one image file and turn it into a single-image batch in [-1, 1].

    Args:
        path: location of the image file, passed to ``image.load_img``.
        target_size: size forwarded to ``image.load_img``; defaults to
            ``(299, 299)``, the value this notebook uses for Xception.

    Returns:
        The array produced by ``image.img_to_array`` with a new leading axis
        of length 1, rescaled so that an input value of 0 maps to -1 and an
        input value of 255 maps to 1.
    """
    img = image.load_img(path, target_size=target_size)
    pixels = image.img_to_array(img)
    # Add a leading axis so that several results can be concatenated into
    # one batch along axis 0.
    batch = np.expand_dims(pixels, axis=0)
    # Same arithmetic as dividing by 255, subtracting 0.5 and doubling, but
    # done out-of-place so it does not depend on the array dtype accepting
    # in-place true division.
    return (batch / 255 - 0.5) * 2

def process_imgs(paths):
    """Preprocess several image files and join them into one batch.

    Every path is handed to ``process_img``; the resulting arrays are
    concatenated along their first axis, in the order the paths were given.
    """
    batches = []
    for path in paths:
        batches.append(process_img(path))
    return np.concatenate(batches, axis=0)

Import an existing model
---

Keras applications contains a lot of high performing models. In this lab, we'll use the Xception model.

- load the Xception network into the `model` variable.

In [None]:
from keras.applications.xception import Xception

# Instantiate Xception with weights='imagenet' and bind it to the
# notebook-level name `model`, which predict_list looks up when it is called.
model = Xception(weights='imagenet')

Plot the model
---

- plot the model thanks to Keras's awesome visualization facilities :)

In [None]:
from keras.utils import plot_model

# Render the architecture of `model` to the PNG file named below.
# NOTE(review): plot_model presumably needs pydot/graphviz installed — confirm.
plot_model(model, to_file='xception-architecture.png')


Predict classes
---

We can now finally predict classes from our images thanks to our model! To do so, we'll use the helper function `decode_predictions` from `imagenet_utils` of Keras.

- define the `predict_list` function so that it takes preprocessed images as input and outputs, for each image, a list of its top 3 class names.

In [None]:
from keras.applications.imagenet_utils import decode_predictions


def predict_list(imgs, top=3):
    """Return the names of the best-scoring classes for each image.

    The notebook-level ``model`` is looked up when this function is called,
    not when it is defined, so rebinding ``model`` in a later cell changes
    which network is used here.

    Args:
        imgs: batch of preprocessed images, passed to ``model.predict``.
        top: how many classes to keep per image; forwarded to
            ``decode_predictions``. Defaults to 3.

    Returns:
        One list per image holding the class names taken from the middle
        element of each triple returned by ``decode_predictions``.
    """
    scores = model.predict(imgs)
    decoded = decode_predictions(scores, top=top)
    # Each decoded entry is a 3-tuple; only its second element (the name)
    # is kept.
    return [[name for (_, name, _) in entries] for entries in decoded]

# Preprocess three sample pictures into a single batch.
imgs = process_imgs(['pics/bird.jpg', 'pics/car.jpg', 'pics/dogs.jpg'])

# Last expression of the cell: the predicted class names for each picture.
predict_list(imgs)

In [None]:
from keras.applications.xception import Xception
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.optimizers import Adam
import os
import numpy as np



def load_data():
    """Build the brontosaurus/stegosaurus training and test arrays.

    For each class directory, every entry of its ``train`` and ``test``
    sub-directories is preprocessed with ``process_imgs``. The targets are
    one-hot rows whose hot column is the class index (0 for brontosaurus,
    1 for stegosaurus).

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where the rows of the
        first class come before the rows of the second class.
    """
    class_dirs = ('pics/brontosaurus', 'pics/stegosaurus')
    # Width of the one-hot targets follows the number of classes instead of
    # being a separate literal that has to be kept in sync.
    n_classes = len(class_dirs)

    def list_dir(d):
        # os.listdir yields entries in arbitrary order; sorting makes the
        # row order of the dataset reproducible from run to run.
        return [os.path.join(d, name) for name in sorted(os.listdir(d))]

    def make_targets(x, y_class):
        # One row per sample in x, with a 1 in column y_class.
        y = np.zeros((x.shape[0], n_classes))
        y[:, y_class] = 1
        return y

    def make_dataset_part(d, y_class):
        x_train = process_imgs(list_dir(os.path.join(d, 'train')))
        x_test = process_imgs(list_dir(os.path.join(d, 'test')))
        return ((x_train, make_targets(x_train, y_class)),
                (x_test, make_targets(x_test, y_class)))

    def glue_parts(parts):
        # Regroup [(train, test), ...] into all train pairs / all test
        # pairs, then concatenate each array once instead of growing it
        # pairwise inside a loop.
        def stack(pairs):
            xs, ys = zip(*pairs)
            return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)

        train_pairs, test_pairs = zip(*parts)
        return stack(train_pairs), stack(test_pairs)

    parts = [make_dataset_part(d, y_class)
             for y_class, d in enumerate(class_dirs)]
    return glue_parts(parts)

# Xception with the 'imagenet' weights, requested with include_top=False
# (presumably this omits the ImageNet classification head — confirm).
base_model = Xception(weights='imagenet', include_top = False)
# Write a diagram of the base network to a PNG file.
plot_model(base_model, to_file='xception-architecture-stripped.png')
# New head stacked on the base output: global average pooling, a 64-unit
# ReLU layer, then a 2-unit softmax (one unit per class built by load_data).
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(64, activation='relu')(x)
predictions = Dense(2, activation='softmax')(x)
# NOTE: this rebinds the notebook-level name `model`; predict_list reads
# `model` at call time, so after this cell it will use this 2-output network.
model = Model(inputs=base_model.input, outputs=predictions)
# Mark every layer of the base network as non-trainable before compiling;
# the layers added above are not touched by this loop.
for layer in base_model.layers:
    layer.trainable = False
adam = Adam(lr=0.001)
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Load the images and one-hot targets for both classes.
(x_train, y_train), (x_test, y_test) = load_data()
# Fit for 3 epochs in batches of 8; the test split doubles as validation data.
model.fit(x_train,
          y_train,
          epochs=3,
          batch_size=8,
          validation_data=(x_test, y_test))
