## Imports

In [None]:
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import pickle

## Load data

In [None]:
# Carve the single "train" split of tf_flowers into three disjoint subsets:
# test = first 10%, validation = next 15%, training = remaining 75%.
# as_supervised=True yields (image, label) pairs; with_info=True additionally
# returns the dataset metadata (`info`) used below for class names and sizes.
(test_set_raw, valid_set_raw, train_set_raw), info = tfds.load(
    "tf_flowers",
    split=["train[:10%]", "train[10%:25%]", "train[25%:]"],
    as_supervised=True,
    with_info=True,
)

## Sample images

In [None]:
# Dataset metadata reused throughout the notebook.
class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
# Size of the whole "train" split (all three subsets together); used as the shuffle buffer size.
dataset_size = info.splits["train"].num_examples

# Show the first 9 training images in a single 3x3 grid.
# plt.show() must run once, after every subplot is drawn: calling it inside the
# loop flushes the current figure each iteration, so with Jupyter's inline
# backend each image would land in its own figure instead of the grid.
plt.figure(figsize=(12, 8))
sample_images = train_set_raw.take(9)
for index, (image, label) in enumerate(sample_images, start=1):
    plt.subplot(3, 3, index)
    plt.imshow(image)
    plt.title("Class: {}".format(class_names[label]))
    plt.axis("off")
plt.show(block=False)

## 2.2 CNN

### 2.2.1 Scaling

In [None]:
def preprocess(image, label):
    """Resize one image to 224x224 for the from-scratch CNN.

    Pixel values are left as-is (the CNN below starts with a Rescaling layer);
    the label is returned unchanged.
    """
    return tf.image.resize(image, [224, 224]), label

In [None]:
batch_size = 32

# Resize every split with `preprocess`; only the training split is shuffled
# (buffer = size of the whole "train" split), then all splits are batched and prefetched.
train_set = (
    train_set_raw.map(preprocess)
    .shuffle(dataset_size)
    .batch(batch_size)
    .prefetch(1)
)
valid_set, test_set = (
    split.map(preprocess).batch(batch_size).prefetch(1)
    for split in (valid_set_raw, test_set_raw)
)

In [None]:
# Preview up to 12 images of one shuffled, resized training batch in a 3x4 grid.
# The resized images keep the original 0-255 pixel range, so they are divided
# by 255 to give imshow floats in [0, 1].
plt.figure(figsize=(8, 8))
sample_batch = train_set.take(1)
for X_batch, y_batch in sample_batch:
    # Guard against a batch smaller than the 12 grid cells.
    for index in range(min(12, len(X_batch))):
        plt.subplot(3, 4, index + 1)
        plt.imshow(X_batch[index] / 255.0)
        plt.title("Class: {}".format(class_names[y_batch[index]]))
        plt.axis("off")
# Show once, after all subplots exist; calling show() inside the loop splits
# the grid into one figure per image.
plt.show()

### 2.2.2 Model

In [None]:
# Baseline CNN trained from scratch:
#   - Rescaling multiplies pixel values by 1/255,
#   - two Conv2D (ReLU, "same" padding) + 2x2 max-pool stages with 96 and 256 filters,
#   - a dropout-regularized dense head ending in a softmax over the n_classes flower classes.
# NOTE(review): with 224x224 inputs, Flatten yields 56*56*256 = 802,816 features, so the
# first Dense(256) layer alone holds ~205M weights; consider more conv/pool stages or
# global pooling if memory or overfitting becomes a problem.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Rescaling(1.0 / 255))

# Convolutional feature extractor
model.add(tf.keras.layers.Conv2D(filters=96, kernel_size=7, padding="same", activation="relu"))
model.add(tf.keras.layers.MaxPool2D(pool_size=2))
model.add(tf.keras.layers.Conv2D(filters=256, kernel_size=5, padding="same", activation="relu"))
model.add(tf.keras.layers.MaxPool2D(pool_size=2))

# Classifier head
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(256, activation="relu"))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(n_classes, activation="softmax"))

In [None]:
# `lr` is a deprecated alias that recent TensorFlow/Keras releases no longer honour
# (Keras 3 rejects it outright); `learning_rate` is the supported argument name.
model.compile(
    loss="sparse_categorical_crossentropy",  # labels are integer class indices
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    metrics=["accuracy"],
)
model.fit(train_set, epochs=10, validation_data=valid_set)

In [None]:
# Evaluate on train / validation / test; with metrics=["accuracy"], each
# model.evaluate call returns a [loss, accuracy] pair.
eval_tuple = (model.evaluate(train_set), model.evaluate(valid_set), model.evaluate(test_set))
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" to the output.
model.summary()

In [None]:
# Persist the from-scratch CNN evaluation results.
# NOTE(review): eval_tuple holds a [loss, accuracy] pair per split despite the
# "_acc" file name -- confirm this is the format the consumer expects.
with open('simple_cnn_acc.pkl', mode='wb') as pkl_file:
    pickle.dump(eval_tuple, pkl_file)
    

## 2.3 Transfer learning

### 2.3.1 Prepare data

In [None]:
def preprocess(image, label):
    """Resize one image to 224x224 and apply Xception's input preprocessing.

    This redefines the CNN `preprocess` for the transfer-learning pipelines;
    Xception's preprocess_input scales pixels into [-1, 1]. The label is
    returned unchanged.
    """
    return (
        tf.keras.applications.xception.preprocess_input(tf.image.resize(image, [224, 224])),
        label,
    )

In [None]:
batch_size = 32

# Rebuild the pipelines with the Xception `preprocess`; only the training split
# is shuffled, then every split is batched and prefetched.
train_set = (
    train_set_raw.map(preprocess)
    .shuffle(dataset_size)
    .batch(batch_size)
    .prefetch(1)
)
valid_set, test_set = (
    split.map(preprocess).batch(batch_size).prefetch(1)
    for split in (valid_set_raw, test_set_raw)
)

In [None]:
# Preview 12 images of one Xception-preprocessed training batch in a 3x4 grid.
# Pixels are in [-1, 1] after preprocess_input; x / 2 + 0.5 maps them back to
# [0, 1] so imshow can display them.
plt.figure(figsize=(8, 8))
sample_batch = train_set.take(1)
for X_batch, y_batch in sample_batch:
    for index in range(12):
        image_01 = X_batch[index] / 2 + 0.5
        label_name = class_names[y_batch[index]]
        plt.subplot(3, 4, index + 1)
        plt.imshow(image_01)
        plt.title("Class: {}".format(label_name))
        plt.axis("off")
plt.show()


### 2.3.2 Model

In [None]:
# Xception with ImageNet weights and without its original classification head
# (include_top=False); a new head for the flower classes is added below.
base_model = tf.keras.applications.xception.Xception(
    weights="imagenet",
    include_top=False,
)


Korzystając z API funkcyjnego Keras, dodaj warstwy:

- uśredniającą wartości wszystkich „pikseli”,
- wyjściową, gęstą, odpowiednią dla problemu.

In [None]:
# Classification head built with the functional API:
# global average pooling collapses each Xception feature map to one value,
# then a softmax Dense layer maps the pooled features to n_classes probabilities.
avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(units=n_classes, activation="softmax")(avg)
model = tf.keras.Model(base_model.input, output)


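### 2.3.3 Training

First train only the new output layer with the Xception base frozen, then unfreeze the base and fine-tune the whole network with a lower learning rate.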

In [None]:
# Phase 1: freeze every Xception layer so only the new Dense output layer is trained.
for layer in base_model.layers:
    layer.trainable = False

# Compile after changing `trainable` so Keras picks up the frozen layers.
# `learning_rate` replaces the deprecated `lr` alias, which recent TensorFlow/Keras
# releases no longer honour (Keras 3 rejects it).
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.2),
    metrics=["accuracy"],
)
model.fit(train_set, epochs=5, validation_data=valid_set)

In [None]:
# Phase 2: unfreeze the Xception base and fine-tune end to end with a smaller
# learning rate plus momentum; recompile so the trainability change takes effect.
# NOTE(review): because `model` is built directly on base_model.input, Xception's
# BatchNormalization layers will also run in training mode during this fit;
# the Keras fine-tuning guide recommends keeping them in inference mode -- confirm intent.
base_model.trainable = True
model.compile(
    loss="sparse_categorical_crossentropy",
    # `learning_rate` replaces the deprecated `lr` alias.
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
    metrics=["accuracy"],
)
model.fit(train_set, epochs=10, validation_data=valid_set)

In [None]:
# Evaluate the fine-tuned model on train / validation / test and save the results.
# With metrics=["accuracy"], each model.evaluate call returns a [loss, accuracy] pair.
eval_tuple = (model.evaluate(train_set), model.evaluate(valid_set), model.evaluate(test_set))
# summary() prints the table itself and returns None; print(model.summary())
# would additionally emit a stray "None".
model.summary()

with open('xception_acc.pkl', 'wb') as f:
    pickle.dump(eval_tuple, f)