# Settings

In [1]:
from keras.layers import Input, Conv2D, Lambda, merge, Dense, Flatten,MaxPooling2D
from keras.models import Model, Sequential
from keras.regularizers import l2
from keras import backend as K
from keras.optimizers import SGD,Adam
from keras.losses import binary_crossentropy
import numpy.random as rnd
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
%matplotlib inline

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


# Load data

In [2]:
import pickle
import boto3
from io import BytesIO

In [13]:
s3 = boto3.resource('s3')
data_subsets = ['train', 'val']
data = {}

for name in data_subsets:
    with BytesIO() as files:
        path = "omniglot_images/" +name+ ".pickle"
        s3.Bucket("research-paper-omniglot-data").download_fileobj(path, files)
        files.seek(0)    # move back to the beginning after writing
        (X,c) = pickle.load(files)
        data[name] = X

In [4]:
path = '../../omniglot_images/'
data_subsets = ["train", "val", "test"]

data = {}
categories = {}
info = {}
        
for name in data_subsets:
    file_path = os.path.join(path, name + ".pickle")
    print("loading data from {}".format(file_path))
    with open(file_path,"rb") as f:
        (X,c) = pickle.load(f)
        data[name] = X
        categories[name] = c

loading data from ../../omniglot_images/train.pickle


FileNotFoundError: [Errno 2] No such file or directory: '../../omniglot_images/train.pickle'

In [5]:
def create_data(size, s='train'):
    #get train data and shape
    X=data[s]
    n_classes, n_examples, w, h = X.shape
    
    #initialize 2 empty arrays for the input size in a list
    pairs=[np.zeros((size, h, w,1)) for i in range(2)]
    
    #initialize vector for the targets
    targets=np.zeros((size,1))
    
    for x in range(size):
        #randomly sample one class (character)
        category = rnd.choice(n_classes,1,replace=False)
        #randomly sample one example from class (1-20 characters)
        idx_1 = rnd.randint(0, n_examples)
        pairs[0][x,:,:,:] = X[category, idx_1].reshape(w, h, 1)
        #randomly sample again one example from class and add last class with modulo
        # ..to ensure not same class pairs are created
        idx_2 = (idx_1 + rnd.randint(0, n_examples)) % n_examples
        #pick images of different class for 1st half and same class for 2nd half
        if x >= size // 2:
            category_2 = category
            targets[x] = 1
        else: 
        #add a random number to the category modulo n classes to ensure 2nd image has
        # ..different category
            idx_2 = rnd.randint(0, n_examples) 
            category_2 = (category + rnd.randint(1,n_classes)) % n_classes
            targets[x] = 0
        pairs[1][x,:,:,:] = X[category_2,idx_2].reshape(w, h,1)
        
    return pairs, targets

In [5]:
train_set, train_labels = create_train_data(10000)
#val_set, val_labels = create_train_data(10000)

# Create graph

In [6]:
def W_init(shape,name=None):
    """Initialize weights as in paper"""
    values = rnd.normal(loc=0,scale=1e-2,size=shape)
    return K.variable(values,name=name)

In [7]:
def b_init(shape,name=None):
    """Initialize bias as in paper"""
    values=rnd.normal(loc=0.5,scale=1e-2,size=shape)
    return K.variable(values,name=name)

In [8]:
input_shape = (105, 105, 1)
left_input = Input(input_shape)
right_input = Input(input_shape)
#build convnet to use in each siamese 'leg'
convnet = Sequential()
convnet.add(Conv2D(64,(10,10),activation='relu',input_shape=input_shape,
                   kernel_initializer=W_init,kernel_regularizer=l2(2e-4)))
convnet.add(MaxPooling2D())
convnet.add(Conv2D(128,(7,7),activation='relu',
                   kernel_regularizer=l2(2e-4),kernel_initializer=W_init,bias_initializer=b_init))
convnet.add(MaxPooling2D())
convnet.add(Conv2D(128,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init))
convnet.add(MaxPooling2D())
convnet.add(Conv2D(256,(4,4),activation='relu',kernel_initializer=W_init,kernel_regularizer=l2(2e-4),bias_initializer=b_init))
convnet.add(Flatten())
convnet.add(Dense(4096,activation="sigmoid",kernel_regularizer=l2(1e-3),kernel_initializer=W_init,bias_initializer=b_init))

#call the convnet Sequential model on each of the input tensors so params will be shared
encoded_l = convnet(left_input)
encoded_r = convnet(right_input)
#layer to merge two encoded inputs with the l1 distance between them
L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
#call this layer on list of two input tensors.
L1_distance = L1_layer([encoded_l, encoded_r])
prediction = Dense(1,activation='sigmoid',bias_initializer=b_init)(L1_distance)
siamese_net = Model(inputs=[left_input,right_input],outputs=prediction)

optimizer = Adam(0.00006)
#//TODO: get layerwise learning rates and momentum annealing scheme described in paperworking
siamese_net.compile(loss="binary_crossentropy",optimizer=optimizer)

siamese_net.count_params()

38951745

# Training

In [9]:
def generate_oneshot_set(N, s='val'):
    """Create pairs of test image, support set for testing N way one-shot learning. """
    X=data[s]
    n_classes, n_examples, w, h = X.shape
    indices = rnd.randint(0,n_examples,size=(N,))
    categories = rnd.choice(range(n_classes),size=(N,),replace=False)            
    true_category = categories[0]
    ex1, ex2 = rnd.choice(n_examples,replace=False,size=(2,))
    test_image = np.asarray([X[true_category,ex1,:,:]]*N).reshape(N, w, h,1)
    support_set = X[categories,indices,:,:]
    support_set[0,:,:] = X[true_category,ex2]
    support_set = support_set.reshape(N, w, h,1)
    targets = np.zeros((N,))
    targets[0] = 1
    targets, test_image, support_set = shuffle(targets, test_image, support_set)
    pairs = [test_image,support_set]

    return pairs, targets

In [10]:
def prediction(probs, labels):
    probs_rounded = np.round_(probs)
    output = np.equal(probs_rounded, labels)
    accuracy = np.mean(output.astype(int))
    
    return(accuracy)

This is just as fast as the Tensorflow implementation

In [None]:
#Training loop
print("!")
evaluate_every = 50 # interval for evaluating on one-shot tasks
batch_size = 32
n_iter = 90000
N_way = 20 # how many classes for testing one-shot tasks>
n_val = 250 #how mahy one-shot tasks to validate on?

print("training")
for i in range(1, n_iter):
    #batch_x1, batch_x2, batch_y = shuffle(train_set[0],train_set[1], train_labels, n_samples = batch_size)
    #loss=siamese_net.train_on_batch([batch_x1, batch_x2],batch_y)
    batch_x, batch_y = create_data(batch_size)
    loss=siamese_net.train_on_batch(batch_x,batch_y)
    if i % evaluate_every == 0:
        train_x, train_y = create_data(10)
        val_x, val_y = create_data(10, s='val')
        prob_train = siamese_net.predict(train_x)
        prob_val = siamese_net.predict(val_x)
        print(prob_train)
        acc_train = prediction(prob_train, train_y)
        acc_val = prediction(prob_val, val_y)
        
        print('acc train:', acc_train)
        print('acc val:', acc_val)
        print('batch iteration:', i)
        print('loss:', loss)
        print("evaluating")
        n_correct = 0
        for j in range(n_val):
            pairs, labels = generate_oneshot_set(N_way, s='val')
            probs = siamese_net.predict(pairs)
            if np.argmax(probs) == np.argmax(labels):
                    n_correct+=1
        percent_correct = (n_correct / n_val)
        print('Validation accuracy:', percent_correct)

!
training
[[0.21489626]
 [0.02116116]
 [0.56986755]
 [0.2681466 ]
 [0.32567844]
 [0.6789337 ]
 [0.54133695]
 [0.6069878 ]
 [0.65379405]
 [0.69988036]]
acc train: 0.9
acc val: 0.6
batch iteration: 50
loss: 3.649458
evaluating
Validation accuracy: 0.164
[[0.0717481 ]
 [0.24659957]
 [0.4208979 ]
 [0.3138581 ]
 [0.00544414]
 [0.31875396]
 [0.4252426 ]
 [0.71478903]
 [0.62507814]
 [0.78646296]]
acc train: 0.8
acc val: 0.8
batch iteration: 100
loss: 3.0759401
evaluating
Validation accuracy: 0.236
[[0.04436253]
 [0.6144382 ]
 [0.47510242]
 [0.49111122]
 [0.26286408]
 [0.59563845]
 [0.6277725 ]
 [0.8159533 ]
 [0.86841345]
 [0.6668283 ]]
acc train: 0.9
acc val: 0.8
batch iteration: 150
loss: 2.796114
evaluating
Validation accuracy: 0.244
[[0.1353483 ]
 [0.27533314]
 [0.70290995]
 [0.00791663]
 [0.08165593]
 [0.69338185]
 [0.85048085]
 [0.8181452 ]
 [0.6259519 ]
 [0.64523166]]
acc train: 0.9
acc val: 0.8
batch iteration: 200
loss: 2.275612
evaluating
Validation accuracy: 0.316
[[4.5272237e-01]


Validation accuracy: 0.528
[[5.6223962e-02]
 [1.9561036e-03]
 [1.4719598e-05]
 [2.3180945e-02]
 [1.0427052e-01]
 [7.7750629e-01]
 [3.0298769e-01]
 [9.5694965e-01]
 [8.2337838e-01]
 [9.7120273e-01]]
acc train: 0.9
acc val: 0.8
batch iteration: 1650
loss: 0.63395834
evaluating
Validation accuracy: 0.644
[[4.9269732e-02]
 [1.7255105e-01]
 [8.7776099e-04]
 [9.1932202e-03]
 [2.6057083e-03]
 [9.7254366e-01]
 [9.8771685e-01]
 [6.3550341e-01]
 [8.8006121e-01]
 [9.3286467e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 1700
loss: 0.7407397
evaluating
Validation accuracy: 0.576
[[0.0618976 ]
 [0.01376141]
 [0.01051199]
 [0.00093557]
 [0.00670638]
 [0.75248355]
 [0.30210048]
 [0.16843945]
 [0.9285243 ]
 [0.7393047 ]]
acc train: 0.8
acc val: 0.9
batch iteration: 1750
loss: 0.6963006
evaluating
Validation accuracy: 0.624
[[2.0215558e-02]
 [2.4758643e-03]
 [3.6093742e-03]
 [7.2707109e-02]
 [9.4644882e-04]
 [9.7014767e-01]
 [8.0788541e-01]
 [8.3331412e-01]
 [5.6159008e-01]
 [8.5687923e-01]]
acc tr

Validation accuracy: 0.628
[[1.9461707e-04]
 [4.2605773e-02]
 [3.3370338e-03]
 [8.0299914e-02]
 [1.9484435e-04]
 [6.6305315e-01]
 [7.8821868e-01]
 [2.9930511e-01]
 [9.6735507e-01]
 [9.2071491e-01]]
acc train: 0.9
acc val: 0.7
batch iteration: 3200
loss: 0.5163076
evaluating
Validation accuracy: 0.568
[[4.0160466e-02]
 [3.6977747e-04]
 [2.7056376e-04]
 [1.3725085e-03]
 [9.4842577e-01]
 [7.7872556e-01]
 [9.8080415e-01]
 [9.0445727e-01]
 [9.7583073e-01]
 [9.7795850e-01]]
acc train: 0.9
acc val: 0.8
batch iteration: 3250
loss: 0.47671762
evaluating
Validation accuracy: 0.588
[[2.5924503e-05]
 [9.0287381e-01]
 [8.4486049e-01]
 [5.2395249e-03]
 [1.0904301e-01]
 [9.7092396e-01]
 [9.9680871e-01]
 [9.9336386e-01]
 [9.9199206e-01]
 [9.9692082e-01]]
acc train: 0.8
acc val: 1.0
batch iteration: 3300
loss: 0.41720355
evaluating
Validation accuracy: 0.612
[[3.8724190e-03]
 [1.4645311e-01]
 [7.6019545e-03]
 [4.0358089e-02]
 [3.0041983e-05]
 [8.1005448e-01]
 [9.8468000e-01]
 [9.7379106e-01]
 [8.919091

Validation accuracy: 0.656
[[0.9047847 ]
 [0.01862606]
 [0.19844948]
 [0.0088355 ]
 [0.00373808]
 [0.653395  ]
 [0.9383089 ]
 [0.96812683]
 [0.653395  ]
 [0.98257345]]
acc train: 0.9
acc val: 0.8
batch iteration: 4700
loss: 0.36721918
evaluating
Validation accuracy: 0.632
[[1.1419986e-02]
 [1.7042981e-03]
 [5.1342277e-04]
 [8.7053549e-01]
 [8.7371771e-04]
 [9.6824229e-01]
 [9.9687177e-01]
 [9.6034038e-01]
 [9.9365211e-01]
 [9.2767501e-01]]
acc train: 0.9
acc val: 1.0
batch iteration: 4750
loss: 0.27956054
evaluating
Validation accuracy: 0.672
[[7.9674060e-03]
 [6.9128914e-04]
 [1.5703518e-02]
 [4.7984443e-04]
 [3.5713192e-02]
 [9.7276431e-01]
 [9.9118292e-01]
 [6.5396398e-01]
 [8.4426528e-01]
 [9.9876356e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 4800
loss: 0.36091867
evaluating
Validation accuracy: 0.628
[[8.0586586e-05]
 [1.1268719e-01]
 [2.0767772e-03]
 [6.0776751e-05]
 [4.2586122e-04]
 [9.3382573e-01]
 [9.9913687e-01]
 [9.5109159e-01]
 [9.8744965e-01]
 [9.6964866e-01]]
acc 

Validation accuracy: 0.696
[[4.6743494e-03]
 [3.2884310e-04]
 [2.8018188e-04]
 [3.2179093e-03]
 [4.2272359e-04]
 [8.8351601e-01]
 [8.1809610e-01]
 [6.0894454e-01]
 [9.6130389e-01]
 [9.9626285e-01]]
acc train: 1.0
acc val: 1.0
batch iteration: 6200
loss: 0.25119647
evaluating
Validation accuracy: 0.692
[[0.00393814]
 [0.0017015 ]
 [0.00972249]
 [0.00859185]
 [0.00517758]
 [0.98776466]
 [0.99959713]
 [0.9780211 ]
 [0.99266875]
 [0.9912266 ]]
acc train: 1.0
acc val: 0.7
batch iteration: 6250
loss: 0.33151373
evaluating
Validation accuracy: 0.692
[[2.9087878e-05]
 [9.0518072e-02]
 [3.1074348e-03]
 [3.8215000e-04]
 [2.5670096e-01]
 [9.9888009e-01]
 [9.9284345e-01]
 [9.9847656e-01]
 [6.0642332e-02]
 [9.9317873e-01]]
acc train: 0.9
acc val: 0.8
batch iteration: 6300
loss: 0.3176959
evaluating
Validation accuracy: 0.7
[[1.5918064e-04]
 [1.9666133e-02]
 [2.4077367e-06]
 [5.7037332e-04]
 [1.7171090e-03]
 [8.2455355e-01]
 [9.7043073e-01]
 [9.4392812e-01]
 [3.8635680e-01]
 [9.9204093e-01]]
acc tra

Validation accuracy: 0.7
[[2.0875914e-05]
 [5.1109400e-03]
 [1.5565116e-03]
 [5.0503784e-03]
 [7.8984076e-04]
 [9.9448496e-01]
 [9.8753268e-01]
 [9.9760240e-01]
 [8.5458630e-01]
 [9.5869398e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 7700
loss: 0.26955992
evaluating
Validation accuracy: 0.676
[[3.3656728e-05]
 [3.2821029e-02]
 [2.3837697e-03]
 [9.6886009e-03]
 [2.1529256e-04]
 [9.8051161e-01]
 [7.6015276e-01]
 [9.9744058e-01]
 [9.6403861e-01]
 [6.0718775e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 7750
loss: 0.18456413
evaluating
Validation accuracy: 0.72
[[1.7499700e-03]
 [1.1814588e-04]
 [9.9541945e-03]
 [2.7100094e-02]
 [3.6976917e-04]
 [6.7181873e-01]
 [9.7242910e-01]
 [6.7181873e-01]
 [9.5086867e-01]
 [9.8247093e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 7800
loss: 0.24142885
evaluating
Validation accuracy: 0.74
[[1.2832205e-06]
 [1.4453249e-04]
 [1.0987907e-02]
 [3.8948376e-04]
 [6.8888990e-06]
 [9.3175137e-01]
 [2.1856219e-01]
 [9.5874888e-01]
 [9.9903560e-

Validation accuracy: 0.656
[[1.0086796e-02]
 [9.7343157e-04]
 [3.1284623e-02]
 [5.8477336e-01]
 [2.4074088e-04]
 [9.9757224e-01]
 [4.7500321e-01]
 [6.8010718e-01]
 [9.9544537e-01]
 [5.7827932e-01]]
acc train: 0.8
acc val: 0.9
batch iteration: 9200
loss: 0.28938228
evaluating
Validation accuracy: 0.72
[[1.31911775e-02]
 [3.26849637e-03]
 [6.67206012e-04]
 [6.63877930e-04]
 [1.28380386e-02]
 [6.80418074e-01]
 [9.93609846e-01]
 [9.96009707e-01]
 [8.03697407e-01]
 [9.98237014e-01]]
acc train: 1.0
acc val: 1.0
batch iteration: 9250
loss: 0.23075733
evaluating
Validation accuracy: 0.764
[[5.5617217e-02]
 [3.1213520e-05]
 [1.0652120e-06]
 [7.0039481e-05]
 [7.8743738e-05]
 [9.5782214e-01]
 [9.4635350e-01]
 [6.3837987e-01]
 [9.8402530e-01]
 [9.9632084e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 9300
loss: 0.22248365
evaluating
Validation accuracy: 0.748
[[2.6477283e-05]
 [3.1102706e-02]
 [8.5864734e-08]
 [3.0031326e-05]
 [3.7638565e-06]
 [9.9510819e-01]
 [9.5226294e-01]
 [9.4724375e-01]


Validation accuracy: 0.7
[[8.9422613e-04]
 [1.1290879e-02]
 [7.1652001e-03]
 [1.3083753e-01]
 [3.1805341e-03]
 [9.8802525e-01]
 [9.5462471e-01]
 [9.9872005e-01]
 [9.6144241e-01]
 [9.6613556e-01]]
acc train: 1.0
acc val: 0.6
batch iteration: 10700
loss: 0.28311926
evaluating
Validation accuracy: 0.7
[[1.6270037e-03]
 [1.5975662e-04]
 [2.6833673e-04]
 [2.1707025e-01]
 [1.2099503e-03]
 [9.9398947e-01]
 [9.9792105e-01]
 [9.5810360e-01]
 [5.9605175e-01]
 [9.8595178e-01]]
acc train: 1.0
acc val: 1.0
batch iteration: 10750
loss: 0.21377988
evaluating
Validation accuracy: 0.764
[[3.2070898e-03]
 [2.9063531e-06]
 [2.3082538e-02]
 [1.0528718e-05]
 [1.6332999e-03]
 [9.9794298e-01]
 [9.9177295e-01]
 [9.7726029e-01]
 [9.8477316e-01]
 [8.8499898e-01]]
acc train: 1.0
acc val: 0.8
batch iteration: 10800
loss: 0.2594343
evaluating
Validation accuracy: 0.76
[[9.83002596e-04]
 [1.16863595e-02]
 [7.11618077e-06]
 [1.08169865e-04]
 [6.82021433e-04]
 [6.08325541e-01]
 [9.84966934e-01]
 [9.75184560e-01]
 [9.

Validation accuracy: 0.768
[[2.6803002e-05]
 [8.9677572e-03]
 [7.1724375e-05]
 [1.2878515e-04]
 [1.3256865e-05]
 [9.8897552e-01]
 [9.6725935e-01]
 [9.9423122e-01]
 [9.9749267e-01]
 [9.9856597e-01]]
acc train: 1.0
acc val: 0.7
batch iteration: 12200
loss: 0.24581066
evaluating
Validation accuracy: 0.756
[[3.8191953e-03]
 [3.7086412e-02]
 [3.6423247e-02]
 [1.7514714e-05]
 [4.5280780e-05]
 [9.9033701e-01]
 [8.8241643e-01]
 [9.5414305e-01]
 [9.8782325e-01]
 [6.9705349e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 12250
loss: 0.19227369
evaluating
Validation accuracy: 0.776
[[1.6110685e-01]
 [9.7112643e-06]
 [5.6325318e-04]
 [4.5649838e-05]
 [3.0998340e-01]
 [9.8218322e-01]
 [7.8074098e-01]
 [9.7229034e-01]
 [7.9746759e-01]
 [9.4514352e-01]]
acc train: 1.0
acc val: 1.0
batch iteration: 12300
loss: 0.15882261
evaluating
Validation accuracy: 0.708
[[1.2886726e-07]
 [5.5008858e-01]
 [6.3524456e-03]
 [6.2729763e-05]
 [3.4583132e-05]
 [8.1088692e-01]
 [9.8684525e-01]
 [9.9974066e-01]
 [9.88

Validation accuracy: 0.712
[[1.9407987e-04]
 [4.1371160e-03]
 [7.5332619e-06]
 [2.6327302e-06]
 [2.2980319e-04]
 [6.1630917e-01]
 [9.9905759e-01]
 [9.9708015e-01]
 [6.4274293e-01]
 [9.9758887e-01]]
acc train: 1.0
acc val: 0.8
batch iteration: 13700
loss: 0.25042173
evaluating
Validation accuracy: 0.796
[[2.9411034e-03]
 [2.2718627e-03]
 [3.7601585e-03]
 [6.0438118e-03]
 [1.9082936e-04]
 [9.7920632e-01]
 [9.9720359e-01]
 [9.9952006e-01]
 [9.8787284e-01]
 [8.9828420e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 13750
loss: 0.21016519
evaluating
Validation accuracy: 0.736
[[4.7784591e-01]
 [1.4803201e-04]
 [4.8026357e-02]
 [2.2807251e-07]
 [5.0716599e-06]
 [9.4913852e-01]
 [9.8199791e-01]
 [9.9569219e-01]
 [9.9591345e-01]
 [9.8969942e-01]]
acc train: 1.0
acc val: 0.9
batch iteration: 13800
loss: 0.1925904
evaluating
Validation accuracy: 0.768
[[3.2225583e-04]
 [2.4254735e-04]
 [9.6792282e-05]
 [6.3530874e-06]
 [5.2100499e-03]
 [9.7988242e-01]
 [9.9916553e-01]
 [9.7964418e-01]
 [9.528