In [1]:
# Dependencies
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend
from tensorflow.keras import layers

import random
import sys
import time

from tensorflow.keras import utils
from tensorflow.keras import preprocessing
from os import listdir
from os.path import isdir, join

import numpy as np
import matplotlib.pyplot as plt

# Reproduction
seed = 333
np.random.seed(seed)

### Data processing

In [2]:
from os.path import join
project_path = r"C:\Users\ktub2\Dropbox\family\Kausthubh\UW Madison\Coursework\ECE 539\Project"

In [3]:
omni_path = join(project_path, r"omniglot-processed-train")
omni_train_datasets = dict()
omni_val_datasets = dict()

for name in listdir(omni_path):
    path = join(omni_path, name)
    if isdir(path):
        omni_train_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.25, subset="training")
        omni_val_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.25, subset="validation")

Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for vali

In [4]:
def dataset_to_tensors(dataset):
    xs = []
    ys = []
    for x, y in dataset:
        xs.extend(x)
        ys.extend(y)
    # xs = np.array(xs)
    # ys = np.array(ys)
    return tf.convert_to_tensor(xs), tf.convert_to_tensor(ys)

xs, ys = dataset_to_tensors(omni_val_datasets['Grantha'])
print(xs.shape, ys.shape)

(100, 28, 28, 1) (100, 20)


In [5]:
omni_train_data = []
omni_train_labels = []
omni_val_data = []
omni_val_labels = []
for i in omni_train_datasets.keys():
    xs,ys = dataset_to_tensors(omni_train_datasets[i])
    omni_train_data.append(xs)
    omni_train_labels.append(ys)
    print('Train dataset converted')
    xs,ys = dataset_to_tensors(omni_val_datasets[i])
    omni_val_data.append(xs)
    omni_val_labels.append(ys)
    print('Val dataset converted')

Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val datase

In [6]:
complete_datasets = []
for i in range(len(omni_train_data)):
    if(len(omni_val_data[i]) == 100):
        complete_datasets.append( (omni_train_data[i],omni_train_labels[i],omni_val_data[i],omni_val_labels[i]) )

### Model definition and relevant functions

In [7]:
def create_model():
    relu_initializer = tf.keras.initializers.HeNormal()
    softmax_initializer = tf.keras.initializers.GlorotNormal()
    
    inputs = keras.Input(shape=(28,28,1))
    for i in range(4):
        x = layers.Conv2D(64, (3,3), kernel_initializer=relu_initializer, bias_initializer='zeros', 
                          activation='relu', padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(20, kernel_initializer=softmax_initializer)(x)
    outputs = layers.Softmax()(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs, name='maml_model')
    return model

In [8]:
def compute_loss(model, x, y, loss_fn=keras.losses.CategoricalCrossentropy()):
    logits = model(x)
    loss = loss_fn(y, logits)
    return loss

### Meta Learning

In [16]:
lr = 0.001
alpha = 0.1
num_tasks = 15
epochs = 1000
inner_epochs = 3
i = 0
optimizer_inner = keras.optimizers.SGD(learning_rate=alpha)
optimizer_outer = keras.optimizers.Adam(learning_rate=lr)

model = create_model()
model.compile()

for epoch in range(epochs):
    all_meta_gradients = []
    total_loss = 0
    
    for task in random.sample(complete_datasets,num_tasks):
        train_data, train_labels, test_data, test_labels = task
        
        with tf.GradientTape() as test_tape:
            model_copy = create_model()
            model_copy.set_weights(model.get_weights())
            model_copy.compile(optimizer=optimizer_inner, loss=tf.keras.losses.CategoricalCrossentropy())
            inner_history = model_copy.fit(train_data, train_labels, epochs=inner_epochs, verbose=0)
            test_loss = compute_loss(model_copy, test_data, test_labels)
            total_loss += test_loss
            i += 1
        print('.',end='')
        meta_gradients = test_tape.gradient(test_loss, model_copy.trainable_weights)
        all_meta_gradients.append(meta_gradients)
      
    print('')
    print('Gradient check: '+str(all_meta_gradients[0][0][0,0,0,0].numpy()))
    sum_meta_gradients = all_meta_gradients[0]
    for i in range(1,len(all_meta_gradients)):
        for j in range(len(all_meta_gradients[i])):
            sum_meta_gradients[j] = sum_meta_gradients[j] + all_meta_gradients[i][j]
    optimizer_outer.apply_gradients(zip(sum_meta_gradients, model.trainable_weights))
    
    print('Meta Update: Epoch Number '+str(epoch))
    print('Avg. Loss: '+str(total_loss/(i+1)))

...............
Gradient check: 0.027422095
Meta Update: Epoch Number 0
Avg. Loss: tf.Tensor(14.2548895, shape=(), dtype=float32)
...............
Gradient check: 0.00023497517
Meta Update: Epoch Number 1
Avg. Loss: tf.Tensor(14.134669, shape=(), dtype=float32)
...............
Gradient check: 0.21215045
Meta Update: Epoch Number 2
Avg. Loss: tf.Tensor(14.164075, shape=(), dtype=float32)
...............
Gradient check: 0.11242764
Meta Update: Epoch Number 3
Avg. Loss: tf.Tensor(13.886563, shape=(), dtype=float32)
...............
Gradient check: 0.098595545
Meta Update: Epoch Number 4
Avg. Loss: tf.Tensor(13.903209, shape=(), dtype=float32)
...............
Gradient check: -0.017359382
Meta Update: Epoch Number 5
Avg. Loss: tf.Tensor(13.942461, shape=(), dtype=float32)
...............
Gradient check: 0.0
Meta Update: Epoch Number 6
Avg. Loss: tf.Tensor(13.764863, shape=(), dtype=float32)
...............
Gradient check: 0.025217641
Meta Update: Epoch Number 7
Avg. Loss: tf.Tensor(14.257921,

...............
Gradient check: 0.0
Meta Update: Epoch Number 64
Avg. Loss: tf.Tensor(13.813412, shape=(), dtype=float32)
...............
Gradient check: 0.07400082
Meta Update: Epoch Number 65
Avg. Loss: tf.Tensor(13.980189, shape=(), dtype=float32)
...............
Gradient check: -0.029282533
Meta Update: Epoch Number 66
Avg. Loss: tf.Tensor(13.754109, shape=(), dtype=float32)
...............
Gradient check: 0.05635599
Meta Update: Epoch Number 67
Avg. Loss: tf.Tensor(13.740127, shape=(), dtype=float32)
...............
Gradient check: 0.04662149
Meta Update: Epoch Number 68
Avg. Loss: tf.Tensor(13.729827, shape=(), dtype=float32)
...............
Gradient check: 0.09538406
Meta Update: Epoch Number 69
Avg. Loss: tf.Tensor(13.943166, shape=(), dtype=float32)
...............
Gradient check: 0.11516012
Meta Update: Epoch Number 70
Avg. Loss: tf.Tensor(13.357179, shape=(), dtype=float32)
...............
Gradient check: 0.00023881206
Meta Update: Epoch Number 71
Avg. Loss: tf.Tensor(13.481

...............
Gradient check: 0.027903926
Meta Update: Epoch Number 128
Avg. Loss: tf.Tensor(12.551792, shape=(), dtype=float32)
...............
Gradient check: -1.7033793e-05
Meta Update: Epoch Number 129
Avg. Loss: tf.Tensor(11.641217, shape=(), dtype=float32)
...............
Gradient check: -0.023108598
Meta Update: Epoch Number 130
Avg. Loss: tf.Tensor(10.862893, shape=(), dtype=float32)
...............
Gradient check: -0.03413912
Meta Update: Epoch Number 131
Avg. Loss: tf.Tensor(12.598458, shape=(), dtype=float32)
...............
Gradient check: -0.002644211
Meta Update: Epoch Number 132
Avg. Loss: tf.Tensor(12.017406, shape=(), dtype=float32)
...............
Gradient check: 0.04387481
Meta Update: Epoch Number 133
Avg. Loss: tf.Tensor(11.6700535, shape=(), dtype=float32)
...............
Gradient check: -0.013389387
Meta Update: Epoch Number 134
Avg. Loss: tf.Tensor(12.972873, shape=(), dtype=float32)
...............
Gradient check: -0.00042616017
Meta Update: Epoch Number 135


...............
Gradient check: 0.0038110262
Meta Update: Epoch Number 191
Avg. Loss: tf.Tensor(9.727454, shape=(), dtype=float32)
...............
Gradient check: 0.23437859
Meta Update: Epoch Number 192
Avg. Loss: tf.Tensor(9.57741, shape=(), dtype=float32)
...............
Gradient check: -0.20363198
Meta Update: Epoch Number 193
Avg. Loss: tf.Tensor(9.789383, shape=(), dtype=float32)
...............
Gradient check: -0.001863813
Meta Update: Epoch Number 194
Avg. Loss: tf.Tensor(9.222288, shape=(), dtype=float32)
...............
Gradient check: 0.15988407
Meta Update: Epoch Number 195
Avg. Loss: tf.Tensor(8.814728, shape=(), dtype=float32)
...............
Gradient check: 0.031027708
Meta Update: Epoch Number 196
Avg. Loss: tf.Tensor(9.06058, shape=(), dtype=float32)
...............
Gradient check: 0.0
Meta Update: Epoch Number 197
Avg. Loss: tf.Tensor(11.509184, shape=(), dtype=float32)
...............
Gradient check: -0.13522142
Meta Update: Epoch Number 198
Avg. Loss: tf.Tensor(10.4

...............
Gradient check: -0.06597738
Meta Update: Epoch Number 255
Avg. Loss: tf.Tensor(9.224429, shape=(), dtype=float32)
...............
Gradient check: 0.0850343
Meta Update: Epoch Number 256
Avg. Loss: tf.Tensor(8.50832, shape=(), dtype=float32)
...............
Gradient check: -0.05938028
Meta Update: Epoch Number 257
Avg. Loss: tf.Tensor(7.955424, shape=(), dtype=float32)
...............
Gradient check: -0.1376755
Meta Update: Epoch Number 258
Avg. Loss: tf.Tensor(9.601846, shape=(), dtype=float32)
...............
Gradient check: -0.105711944
Meta Update: Epoch Number 259
Avg. Loss: tf.Tensor(8.667279, shape=(), dtype=float32)
...............
Gradient check: -1.1022784e-05
Meta Update: Epoch Number 260
Avg. Loss: tf.Tensor(8.066893, shape=(), dtype=float32)
...............
Gradient check: -0.21784297
Meta Update: Epoch Number 261
Avg. Loss: tf.Tensor(8.906478, shape=(), dtype=float32)
...............
Gradient check: -0.14379759
Meta Update: Epoch Number 262
Avg. Loss: tf.Te

...............
Gradient check: -0.017673219
Meta Update: Epoch Number 319
Avg. Loss: tf.Tensor(13.707903, shape=(), dtype=float32)
...............
Gradient check: 0.0
Meta Update: Epoch Number 320
Avg. Loss: tf.Tensor(14.261331, shape=(), dtype=float32)
.......

KeyboardInterrupt: 

## Fine-tuning

In [17]:
omni_test_path = join(project_path, r"omniglot-processed-test")
omni_test_train_datasets = dict()
omni_test_val_datasets = dict()

for name in listdir(omni_test_path):
    path = join(omni_test_path, name)
    if isdir(path):
        omni_test_train_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.75, subset="training")
        omni_test_val_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.75, subset="validation")

Found 400 files belonging to 20 classes.
Using 100 files for training.
Found 400 files belonging to 20 classes.
Using 300 files for validation.
Found 400 files belonging to 20 classes.
Using 100 files for training.
Found 400 files belonging to 20 classes.
Using 300 files for validation.
Found 400 files belonging to 20 classes.
Using 100 files for training.
Found 400 files belonging to 20 classes.
Using 300 files for validation.


In [18]:
omni_test_train_data = []
omni_test_train_labels = []
omni_test_val_data = []
omni_test_val_labels = []

for i in omni_test_train_datasets.keys():
    xs,ys = dataset_to_tensors(omni_test_train_datasets[i])
    omni_test_train_data.append(xs)
    omni_test_train_labels.append(ys)
    print('Test train dataset converted')
    xs,ys = dataset_to_tensors(omni_test_val_datasets[i])
    omni_test_val_data.append(xs)
    omni_test_val_labels.append(ys)
    print('Test val dataset converted')

Test train dataset converted
Test val dataset converted
Test train dataset converted
Test val dataset converted
Test train dataset converted
Test val dataset converted


In [19]:
greek_fine_tuning_data = omni_test_train_data[0]
greek_fine_tuning_labels = omni_test_train_labels[0]

greek_val_data = omni_test_val_data[0]
greek_val_labels = omni_test_val_labels[0]

model_greek = create_model()
model_greek.set_weights(model.get_weights())

In [20]:
model_greek.compile(optimizer=keras.optimizers.Adam(learning_rate=0.005),
                    loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')
model_greek.fit(greek_fine_tuning_data, greek_fine_tuning_labels, epochs=200, verbose=0)

<tensorflow.python.keras.callbacks.History at 0x2a6e0651f70>

In [21]:
model_greek.evaluate(greek_val_data, greek_val_labels)



[29.5891170501709, 0.34333333373069763]

In [22]:
model_greek_meow = create_model()
model_greek_meow.compile(optimizer=keras.optimizers.Adam(learning_rate=0.005),
                    loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')
model_greek_meow.fit(greek_fine_tuning_data, greek_fine_tuning_labels, epochs=200, verbose=0)

<tensorflow.python.keras.callbacks.History at 0x2a6c04c4160>

In [23]:
model_greek_meow.evaluate(greek_val_data, greek_val_labels)



[22.00643539428711, 0.3233333230018616]