In [None]:
!pip install imgaug --upgrade
# !pip install albumentations --upgrade

In [2]:
# Pick up edits to imported modules (e.g. the local `util` / `model` files)
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Makes the `%tensorboard` magic available for the visualization cell below.
%load_ext tensorboard

In [None]:
# Mount Drive if on Colab
on_colab = True
if on_colab:
  import os
  import sys

  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)

  FOLDERNAME = 'chow chow 231n/trainer/'
  # Reject both None and an empty string.
  assert FOLDERNAME, "[!] Enter the foldername."

  # Absolute project directory on the mounted Drive; used for both the import
  # path and the working directory so the two can never disagree.
  project_dir = '/content/drive/My Drive/{}'.format(FOLDERNAME)

  # Guard against duplicate entries when the cell is re-run.
  if project_dir not in sys.path:
    sys.path.append(project_dir)

  # Change directory via the absolute path: a relative path only resolves
  # while the working directory is still /content, so re-running would fail.
  os.chdir(project_dir)
  print(os.getcwd())

In [68]:
import util
import model

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import numpy as np
import imgaug as ia
import sklearn

from os.path import isfile 
from datetime import datetime
import pickle
import pathlib

In [74]:
# Define paths to data
d = '../data/'  # data root; every path below is this prefix plus a name

# Directories: image folders and LMDB stores for the val / test splits.
dirs = {
    key: d + name
    for key, name in (
        ('imgs_val', 'val_data/'),
        ('imgs_test', "test_data/"),
        ('lmdb_val', 'val_lmdb'),
        ('lmdb_test', 'test_lmdb'),
    )
}

# Individual files directly under the data root.
files = {
    key: d + name
    for key, name in (
        ('val_pkl', "lmdb_val_out.pkl"),
        ('test_pkl', "lmdb_test_out.pkl"),
        ('classes_pkl', "classes1M.pkl"),
        ('rvocab_pkl', "vocab.txt"),  # note: a .txt file despite the key name
        ('pt_pkl', 'partition.pkl'),  # image ids
        ('lb_pkl', 'labels.pkl'),  # labels
        ('ingr_pkl', 'ingredients.pkl'),  # ingredient names
        # ('cl_pkl', 'classes.pkl'),  # class names
    )
}

In [71]:
# Read LMDB into pkl
# TODO: add train dir, out file
# The LMDB stores are converted once and cached as pickles. The conversion is
# skipped only when BOTH pickles already exist; otherwise both are rebuilt.
if not (isfile(files['val_pkl']) and isfile(files['test_pkl'])):
    print("LMDBs processing into pkls")

    lmdb_val_out = util.read_lmdb(dirs['lmdb_val'])
    # `with` closes the handle as soon as the dump finishes, so the pickle is
    # flushed to disk before any later cell tries to read it.
    with open(files['val_pkl'], 'wb') as val_file:
        pickle.dump(lmdb_val_out, val_file)

    lmdb_test_out = util.read_lmdb(dirs['lmdb_test'])
    with open(files['test_pkl'], 'wb') as test_file:
        pickle.dump(lmdb_test_out, test_file)
else:
    print("LMDBs already saved as pkls")


LMDBs already saved as pkls


In [77]:
# Load data
# `partition`, `labels` and `ingrs` are built once by util.load_data and then
# cached as pickles. The cache is used only when all three files exist.
if not (isfile(files['pt_pkl']) and isfile(files['lb_pkl']) and isfile(files['ingr_pkl'])):
    print("Datasets processing for the 1st time into pkls")
    partition, labels, ingrs = util.load_data(dirs, files)

    # Context managers guarantee each cache file is closed (and flushed).
    with open(files['pt_pkl'], "wb") as pt_file:
        pickle.dump(partition, pt_file)
    with open(files['lb_pkl'], "wb") as lb_file:
        pickle.dump(labels, lb_file)
    with open(files['ingr_pkl'], "wb") as ingr_file:
        pickle.dump(ingrs, ingr_file)
    print("Datasets loaded")
else:
    # SECURITY: pickle.load can execute arbitrary code. Only load cache files
    # that were written by this notebook, never files from an untrusted source.
    with open(files['pt_pkl'], "rb") as pt_file:
        partition = pickle.load(pt_file)
    with open(files['lb_pkl'], "rb") as lb_file:
        labels = pickle.load(lb_file)
    with open(files['ingr_pkl'], "rb") as ingr_file:
        ingrs = pickle.load(ingr_file)
    print("Datasets loaded from pkls")


Datasets loaded from pkls


In [81]:
# Create data generators
# Keyword settings shared by the training and validation generators.
data_params = dict(
    dim=(224, 224),
    batch_size=2,
    n_ingrs=len(ingrs),
    n_channels=3,
    shuffle=True,
)

# NOTE(review): the training generator reads images from the 'imgs_val'
# directory and the validation generator from 'imgs_test' - presumably because
# no train directory is wired up yet; confirm this mapping is intended.
image_dir_tr, image_dir_val = dirs['imgs_val'], dirs['imgs_test']

gen_tr = model.DataGenerator(
    partition['train'], labels, image_dir_tr, **data_params)
gen_val = model.DataGenerator(
    partition['validation'], labels, image_dir_val, **data_params)

In [None]:
# Check generator and fetching speed
# Fetch the first batch from each generator and display only its first sample
# (image plus label vector) as a quick visual sanity check.
for generator in (gen_tr, gen_val):
    X, y = generator[0]
    for i, img in enumerate(X):
        ia.imshow(img)
        print(y[i])
        break  # one sample per generator is enough

In [134]:
# Name experiment
trial_name = 'test'
today = datetime.now().strftime('%m_%d')
# Each run logs to 'logs/<exp_name>' (see `logdir` below), so today's trial
# number is the count of today's run directories directly under logs/, plus
# one. A non-recursive glob with an explicit is_dir() filter avoids walking
# the whole log tree and avoids counting files or nested entries that merely
# share the date prefix.
todays_runs = [p for p in pathlib.Path('logs/').glob(today + '*') if p.is_dir()]
trial_ct = len(todays_runs) + 1
exp_name = '{}_{}_{}'.format(today, trial_ct, trial_name)
logdir = 'logs/' + exp_name
print('Running experiment ', exp_name)

# Stage 1 training
# Parameters used when initializing the last layer.
# 'eps' is the number of epochs (it is passed to fit() as `epochs`).
init_params = dict(
    lr=1e-2,
    input_shape=tuple(data_params['dim']) + (data_params['n_channels'],),
    num_ingrs=len(ingrs),
    eps=10,
)

# Stage 2 training
# Parameters for transfer learning on the pre-trained model.
# `best_init_name` names the stage 1 run whose saved weights are reused.
best_init_name = '_'.join(('08_27', '2', 'testtest'))
tl_params = dict(
    lr=1e-5,
    num_unfreeze=4,  # num layers to unfreeze from top of pre-trained model
    init_path='/'.join(('logs', best_init_name, 'best_epoch_model.h5')),
    eps=100,
)

# Load model
is_init = True  # choose stage of training

# Pick the message, model factory and parameter set for the chosen stage.
if is_init:
    stage_msg = 'Loaded stage 1 model'
    build_model, stage_params = model.create_init_model, init_params
else:
    stage_msg = 'Loaded stage 2 model'
    build_model, stage_params = model.create_tl_model, tl_params

print(stage_msg)
nn = build_model(stage_params)
eps = stage_params['eps']  # number of epochs to train in this stage

Running experiment  08_27_7_test
Loaded stage 1 model


In [114]:
# Define callbacks
# The schedule returns base_lr plus a bump of 0.02 * 0.5**(epoch + 1), i.e. the
# bump halves every epoch and the rate approaches base_lr.
# NOTE(review): assumes the base rate is the 'lr' of the stage selected by
# `is_init`; no standalone `lr` variable is defined in this notebook - confirm.
base_lr = (init_params if is_init else tl_params)['lr']
lr_decay_cb = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: base_lr + 0.02 * (0.5 ** (1 + epoch)),
    verbose=True)

# Writes TensorBoard logs (with per-epoch histograms) to this run's directory.
tensorboard_cb = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

class MetricsHistory(Callback):
    """Print train/validation F1 after each epoch and checkpoint the best model.

    Whenever the validation F1 exceeds the best value seen so far, the model
    is saved to '<logdir>/best_epoch_model.h5' (`logdir` is the notebook-level
    log directory of the current experiment).
    """

    def __init__(self):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        # Starts below any value that can beat it, so the first epoch saves.
        self.best_score = -1

    def on_epoch_end(self, epoch, logs=None):
        """Report F1 scores for `epoch` and save the model on improvement.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict supplied by Keras; must contain 'f1_ml' and
                'val_f1_ml' (presumably a metric named `f1_ml` registered on
                the model - verify against the model definition).

        Raises:
            KeyError: If either F1 entry is missing from `logs`.
        """
        logs = logs or {}  # avoid a shared mutable default argument
        tr_score = logs['f1_ml']
        val_score = logs['val_f1_ml']

        # NOTE(review): '{:2f}' means minimum width 2 (six decimals are
        # printed), not two decimals; '{:.2f}' may have been intended.
        print('\n\nTrain F1: {:2f} \nVal F1:   {:2f}\n\n'.format(tr_score, val_score))

        if val_score > self.best_score:
            print("\nBetter validation score! Saving model ...")
            # pathlib is already imported by the notebook (os is not).
            # exist_ok makes the call safe when the directory already exists.
            pathlib.Path(logdir).mkdir(parents=True, exist_ok=True)
            self.model.save(logdir + '/best_epoch_model.h5')
            self.best_score = val_score

metrics_cb = MetricsHistory()

In [None]:
# Visualize training
# Opens TensorBoard inline on the logs/ directory, which is the parent of each
# experiment's `logdir`.
%tensorboard --logdir logs/

In [None]:
# Train
# Callbacks active for this run. Currently not enabled (kept for reference):
#   use_multiprocessing=True, workers=6,
#   lr_decay_cb, validation_steps=val_steps
fit_callbacks = [tensorboard_cb, metrics_cb]
history = nn.fit(
    gen_tr,
    validation_data=gen_val,
    epochs=eps,
    callbacks=fit_callbacks,
)

In [None]:
# Export
# Save the model in its current state inside this experiment's log directory.
last_epoch_dir = logdir + "/last_epoch_model/"
nn.save(last_epoch_dir)

In [None]:
# Evaluate predictions
# For one validation batch (index 1), print the true and predicted ingredients
# of every sample, then show the image.
X_s, y_s = gen_val[1]
batch_preds = nn.predict(X_s)
for i in range(len(batch_preds)):
    pred = batch_preds[i]
    truth, img = y_s[i], X_s[i]

    # Positions labelled 1 in the ground truth.
    true_class_id = np.nonzero(truth == 1)[0]
    # Positions whose predicted value exceeds 1e-9.
    pred_class_id = np.nonzero(pred > 1e-9)[0]

    true_names = [ingrs[x] for x in true_class_id]
    pred_names = [ingrs[x] for x in pred_class_id]

    print("Image ", i)
    print("True: ", true_names)
    print("Predicted: ", pred_names)
    print()
    print(true_class_id)
    print(pred_class_id)
    print("Predicted logit: ", pred[pred_class_id])
    ia.imshow(img)