# Experiment with CIFAR-10
## Centralized-learning baseline

In [None]:
# Standard library.
import os
# Uncomment to hide all CUDA GPUs and force a CPU-only run; it only has an
# effect when set before TensorFlow is imported below.
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import pickle
import time
import datetime

# Third-party. TensorFlow's logging is turned down immediately after import,
# before the remaining imports and the training cells run.
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # drop TensorFlow Python-logger messages below ERROR
tf.autograph.set_verbosity(0)  # silence AutoGraph's own diagnostic output
import tensorflow_hub as hub
# NOTE(review): time, numpy and pandas appear unused in the cells shown; kept
# in case other notebook code relies on them.
import numpy as np
import pandas as pd

In [None]:
# Report how many GPUs TensorFlow can see; 0 means training falls back to the CPU.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# CIFAR-10 class names. The position of each name is its label index, and the
# list is passed as `classes=` to flow_from_directory, so every name must match
# a class sub-directory of the train/test folders.
labelnames = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]
labelnames

In [None]:
# Absolute path to the centralized ("center") copy of the dataset, resolved
# against the current working directory (after any leading "~" expansion).
dataset_root = os.path.join(
    os.path.abspath(os.path.expanduser('dataset-cifar10')),
    'center',
)
dataset_root

In [None]:
# One directory iterator per split. Pixels are rescaled to [0, 1] and images are
# resized to the backbone's spatial input size (height, width); labels are taken
# from sub-directory names in `labelnames` order.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
target_size = (224, 224, 3)
datasets = {
    split: image_generator.flow_from_directory(
        os.path.join(dataset_root, split),
        classes=labelnames,
        target_size=target_size[:-1],  # drop the channel dimension
        shuffle=True,
        follow_links=True,
    )
    for split in ('train', 'test')
}
datasets

In [None]:
# Candidate TF Hub feature-vector backbones; `feature_extractor_model` selects one.
# NOTE(review): the Inception V3 module is documented for 299x299 inputs, while
# target_size is (224, 224, 3); confirm the input size before switching to it.
mobilenet_v2, inception_v3 = (
    'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
    'https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4',
)
feature_extractor_model = mobilenet_v2

In [None]:
# Pretrained backbone loaded from TF Hub. With trainable=True its weights are
# updated together with the classification head during fit().
# NOTE(review): fine-tuning the whole backbone with Adam's default learning
# rate may be aggressive; consider a smaller rate if accuracy degrades.
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    trainable=True,
    input_shape=target_size,
)
feature_extractor_layer

In [None]:
# Backbone followed by a linear head with one output per class. The head emits
# raw logits (no activation), which is why the loss uses from_logits=True.
model = tf.keras.Sequential()
model.add(feature_extractor_layer)
model.add(tf.keras.layers.Dense(len(labelnames)))
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
model.summary()

In [None]:
# Train on the train split and report metrics on the test split after every
# epoch; `returns` is the History object whose .history is saved below.
epochs = 120
returns = model.fit(
    x=datasets['train'],
    epochs=epochs,
    validation_data=datasets['test'],
)

In [None]:
# Persist the per-epoch training/validation metrics. The file name carries a
# local-time stamp (yymmddHHMMSS) so repeated runs do not overwrite each other.
timestamp = datetime.datetime.now().strftime('%y%m%d%H%M%S')
result_path = f'result-cifar10-centralizedlearning-{timestamp}.pkl'
with open(result_path, 'wb') as fp:
    pickle.dump(returns.history, fp)