In [1]:
from peleenet import PeleeNet
import tensorflow as tf
import tensorflow
import numpy as np
from PIL import Image
import pickle
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#from tensorflow.contrib.tpu.python.tpu import keras_support

In [2]:
def generator(X, y, batch_size, use_augmentation, shuffle, scale):
    """Yield ``(images, labels)`` batches, optionally augmented and upscaled.

    Batches are drawn from ``ImageDataGenerator.flow`` and every image is
    divided by 255.0 before being yielded.

    Args:
        X: Image array passed to ``ImageDataGenerator.flow``; indexed here as
            (samples, height, width, channels).
        y: Labels passed through unchanged alongside each batch.
        batch_size: Number of samples requested per batch.
        use_augmentation: If true, apply random horizontal flips and
            width/height shifts of 4/32 of the image size.
        shuffle: Forwarded to ``flow`` to control sample shuffling.
        scale: Integer upscaling factor. When it is not 1, each image is
            resized with Lanczos resampling to ``scale`` times its height
            and width.

    Yields:
        Tuple ``(X_batch, y_base)`` where ``X_batch`` holds the rescaled
        (and possibly resized) images of the batch.
    """
    if use_augmentation:
        base_gen = ImageDataGenerator(
            horizontal_flip=True,
            width_shift_range=4.0/32.0,
            height_shift_range=4.0/32.0)
    else:
        base_gen = ImageDataGenerator()
    for X_base, y_base in base_gen.flow(X, y, batch_size=batch_size, shuffle=shuffle):
        if scale != 1:
            n_samples, height, width, n_channels = X_base.shape
            out_h, out_w = height * scale, width * scale
            X_batch = np.zeros((n_samples, out_h, out_w, n_channels), np.float32)
            for i in range(n_samples):
                # Pixel values are cast to uint8 for PIL; assumes inputs are in
                # the 0..255 range -- TODO confirm for non-CIFAR data.
                with Image.fromarray(X_base[i].astype(np.uint8)) as img:
                    # PIL sizes are (width, height), the reverse of numpy's
                    # (rows, cols) ordering, so width must come first.
                    resized = img.resize((out_w, out_h), Image.LANCZOS)
                    X_batch[i] = np.asarray(resized, np.float32) / 255.0
        else:
            X_batch = X_base / 255.0
        yield X_batch, y_base

In [3]:
def lr_scheduler(epoch):
    """Return the step-decayed learning rate for ``epoch``.

    The rate starts at 0.4 and is divided by 5 once for each of the
    milestones 70, 120 and 170 that ``epoch`` has reached.
    """
    rate = 0.4
    for milestone in (70, 120, 170):
        if epoch >= milestone:
            rate /= 5.0
    return rate

In [6]:
def train(use_augmentation, use_stem_block, epochs=1):
    """Train PeleeNet on CIFAR-10 and pickle the resulting training history.

    Args:
        use_augmentation: Whether the training generator applies random
            flips and shifts.
        use_stem_block: Whether the network is built with its stem block.
            When true, images are upscaled 7x (32 -> 224) and the model takes
            224x224 inputs; otherwise the native 32x32 images are used.
        epochs: Number of training epochs (defaults to 1, the value that was
            previously hard-coded).

    Side effects:
        Loads CIFAR-10 through Keras and writes the history dictionary to
        ``pelee_aug_{use_augmentation}_stem_{use_stem_block}.pkl`` in the
        current working directory.
    """
    (X_train, y_train), (X_test, y_test) = tensorflow.keras.datasets.cifar10.load_data()
    y_train = tensorflow.keras.utils.to_categorical(y_train)
    y_test = tensorflow.keras.utils.to_categorical(y_test)

    # generator
    batch_size = 512
    test_batch_size = 1000
    scale = 7 if use_stem_block else 1
    train_gen = generator(X_train, y_train, batch_size=batch_size,
                          use_augmentation=use_augmentation, shuffle=True, scale=scale)
    test_gen = generator(X_test, y_test, batch_size=test_batch_size,
                         use_augmentation=False, shuffle=False, scale=scale)

    # network
    # The input shape and stem-block flag must follow the arguments so that
    # the model matches the image size produced by the generators above.
    input_shape = (224,224,3) if use_stem_block else (32,32,3)
    model = PeleeNet(input_shape=input_shape, use_stem_block=use_stem_block,
                     num_init_channel=3, k=32, block_config=[3,4,8,6],
                     out_layers=[128,256,512,704], bottleneck_width=[2,2,4,4],
                     n_classes=10)
    # The initial rate (0.4) matches the starting value used by lr_scheduler.
    model.compile(tensorflow.keras.optimizers.SGD(0.4, 0.9), "categorical_crossentropy", ["acc"])

    scheduler = tensorflow.keras.callbacks.LearningRateScheduler(lr_scheduler)
    hist = tensorflow.keras.callbacks.History()

    model.fit_generator(train_gen, steps_per_epoch=X_train.shape[0]//batch_size,
                        validation_data=test_gen,
                        validation_steps=X_test.shape[0]//test_batch_size,
                        callbacks=[scheduler, hist], epochs=epochs, max_queue_size=1)

    history = hist.history
    with open(f"pelee_aug_{use_augmentation}_stem_{use_stem_block}.pkl", "wb") as fp:
        pickle.dump(history, fp)

In [None]:
# Run training with augmentation enabled and the stem block in use.
train(use_augmentation=True, use_stem_block=True)

