In [1]:
# Dependencies
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend
from tensorflow.keras import layers

import random
import sys
import time

from tensorflow.keras import utils
from tensorflow.keras import preprocessing
from os import listdir
from os.path import isdir, join

import numpy as np
import matplotlib.pyplot as plt

# Reproduction
seed = 333
np.random.seed(seed)

### Data processing

In [2]:
from os.path import join
project_path = r"C:\Users\ktub2\Dropbox\family\Kausthubh\UW Madison\Coursework\ECE 539\Project"

In [3]:
omni_path = join(project_path, r"omniglot-processed-train")
omni_train_datasets = dict()
omni_val_datasets = dict()

for name in listdir(omni_path):
    path = join(omni_path, name)
    if isdir(path):
        omni_train_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.25, subset="training")
        omni_val_datasets[name] = preprocessing.image_dataset_from_directory(path, label_mode='categorical',
                                                             color_mode='grayscale', batch_size=1000, image_size=(28,28),
                                                             seed=seed, validation_split=0.25, subset="validation")

Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for validation.
Found 400 files belonging to 20 classes.
Using 300 files for training.
Found 400 files belonging to 20 classes.
Using 100 files for vali

In [4]:
def dataset_to_tensors(dataset):
    xs = []
    ys = []
    for x, y in dataset:
        xs.extend(x)
        ys.extend(y)
    # xs = np.array(xs)
    # ys = np.array(ys)
    return tf.convert_to_tensor(xs), tf.convert_to_tensor(ys)

xs, ys = dataset_to_tensors(omni_val_datasets['Grantha'])
print(xs.shape, ys.shape)

(100, 28, 28, 1) (100, 20)


In [5]:
omni_train_data = []
omni_train_labels = []
omni_val_data = []
omni_val_labels = []
for i in omni_train_datasets.keys():
    xs,ys = dataset_to_tensors(omni_train_datasets[i])
    omni_train_data.append(xs)
    omni_train_labels.append(ys)
    print('Train dataset converted')
    xs,ys = dataset_to_tensors(omni_val_datasets[i])
    omni_val_data.append(xs)
    omni_val_labels.append(ys)
    print('Val dataset converted')

Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val dataset converted
Train dataset converted
Val datase

In [9]:
complete_datasets = []
for i in range(len(omni_train_data)):
    if(len(omni_val_data[i]) == 100):
        complete_datasets.append( (omni_train_data[i],omni_train_labels[i],omni_val_data[i],omni_val_labels[i]) )

### Model definition and relevant functions

In [70]:
def create_model():
    relu_initializer = tf.keras.initializers.HeNormal()
    softmax_initializer = tf.keras.initializers.GlorotNormal()
    
    inputs = keras.Input(shape=(28,28,1))
    for i in range(4):
        x = layers.Conv2D(64, (3,3), (2,2), kernel_initializer=relu_initializer, bias_initializer='zeros', 
                          activation='relu', padding='same')(inputs)
        x = layers.BatchNormalization()(x)
#         x = layers.MaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(20, kernel_initializer=softmax_initializer)(x)
    outputs = layers.Softmax()(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs, name='maml_model')
    return model

In [71]:
def compute_loss(model, x, y, loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True)):
    logits = model(x)
    loss = loss_fn(y, logits)
    return loss

### Meta Learning

In [68]:
lr = 0.001
alpha = 0.05
num_tasks = 2
epochs = 20
optimizer = keras.optimizers.Adam(learning_rate=lr)

model = create_model()

for epoch in range(epochs):
    all_meta_gradients = []
    total_loss = 0
    i = 0
    
    for task in random.sample(complete_datasets,num_tasks):
        train_data, train_labels, test_data, test_labels = task
        
        with tf.GradientTape() as test_tape:
            with tf.GradientTape() as train_tape:
                train_loss = compute_loss(model, train_data, train_labels)
            inner_gradients = train_tape.gradient(train_loss, model.trainable_weights)
            
            model_copy = create_model()
            model_copy.set_weights(model.get_weights())
            k = 0
            for i in range(len(model_copy.layers)):
                if(model_copy.layers[i].trainable):
                    if (hasattr(model_copy.layers[i], 'kernel')):
                        model_copy.layers[i].kernel = tf.subtract(model.layers[i].kernel,
                                        tf.multiply(alpha, inner_gradients[k]))
                        model_copy.layers[i].bias = tf.subtract(model.layers[i].bias,
                                        tf.multiply(alpha, inner_gradients[k+1]))
                        k += 2
                    elif (hasattr(model_copy.layers[i], 'gamma')):
                        model_copy.layers[i].gamma = tf.subtract(model.layers[i].gamma,
                                        tf.multiply(alpha, inner_gradients[k]))
                        model_copy.layers[i].beta = tf.subtract(model.layers[i].beta,
                                        tf.multiply(alpha, inner_gradients[k+1]))
                        k += 2
            print('.',end='')
            
            test_loss = compute_loss(model_copy, test_data, test_labels)
            total_loss += tf.reduce_sum(test_loss)
            i += 1
            
        meta_gradients = test_tape.gradient(test_loss, model.trainable_weights)
        all_meta_gradients.append(meta_gradients)
      
    print('')
    print('Gradient check: '+str(all_meta_gradients[0][0][0,0,0,0].numpy()))
    avg_meta_gradients = all_meta_gradients[0]
    for i in range(1,len(all_meta_gradients)):
        for j in range(len(all_meta_gradients[i])):
            avg_meta_gradients[j] = avg_meta_gradients[j] + all_meta_gradients[i][j]
    avg_meta_gradients = [avg_meta_gradients[i]/num_tasks for i in range(len(avg_meta_gradients))]
    optimizer.apply_gradients(zip(avg_meta_gradients, model.trainable_weights))
    
    print('Meta Update: Epoch Number '+str(epoch))
    print('Avg. Loss: '+str(total_loss/i))

..
Gradient check: 0.0
Meta Update: Epoch Number 0
Avg. Loss: tf.Tensor(5.996309, shape=(), dtype=float32)
..
Gradient check: -1.3927067e-35
Meta Update: Epoch Number 1
Avg. Loss: tf.Tensor(5.996309, shape=(), dtype=float32)
..
Gradient check: 0.0
Meta Update: Epoch Number 2
Avg. Loss: tf.Tensor(5.996309, shape=(), dtype=float32)
..
Gradient check: 0.0
Meta Update: Epoch Number 3
Avg. Loss: tf.Tensor(5.996159, shape=(), dtype=float32)
..
Gradient check: 0.0
Meta Update: Epoch Number 4
Avg. Loss: tf.Tensor(6.0963087, shape=(), dtype=float32)
..
Gradient check: 0.0
Meta Update: Epoch Number 5
Avg. Loss: tf.Tensor(6.0963087, shape=(), dtype=float32)
..

KeyboardInterrupt: 

In [72]:
lr = 1
epochs = 20
optimizer = keras.optimizers.Adam(learning_rate=lr)

model = create_model()

for task in random.sample(complete_datasets,1):
    train_data, train_labels, test_data, test_labels = task
    for epoch in range(epochs):
        
        for i in range(len(train_data)):
            dat = tf.reshape(train_data[i], [1,28,28,1])
            labl = tf.reshape(train_labels[i], [1,20])
            with tf.GradientTape() as train_tape:
                train_loss = compute_loss(model, dat, labl)
            inner_gradients = train_tape.gradient(train_loss, model.trainable_weights)
            optimizer.apply_gradients(zip(inner_gradients, model.trainable_weights))
            print(train_loss)
            print('.')

tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=flo

tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=flo

tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=flo

tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(2.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=flo

tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=float32)
.
tf.Tensor(3.0781543, shape=(), dtype=flo

KeyboardInterrupt: 

In [46]:
train_data.shape

TensorShape([300, 28, 28, 1])

In [47]:
train_data[0].shape

TensorShape([28, 28, 1])

In [64]:
model = create_model()
meow = tf.reshape(train_data[1], [1,28,28,1])
model(meow)

<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 1.0552893e-13, 1.7753738e-11, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],
      dtype=float32)>

In [65]:
tf.reshape(train_labels[0], [1,20])

<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float32)>