<a href="https://colab.research.google.com/github/ClementWalter/Keras-FewShotLearning/blob/master/notebooks/omniglot/basic_siamese_nets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from pathlib import Path
from unittest.mock import patch

import imgaug.augmenters as iaa
import numpy as np
import pandas as pd
from tensorflow.python.keras import Model
from tensorflow.python.keras.callbacks import ModelCheckpoint, TensorBoard

!pip install git+https://github.com/ClementWalter/Keras-FewShotLearning.git
from keras_fsl.datasets import omniglot
from keras_fsl.models import SiameseNets
from keras_fsl.sequences import (
    DeterministicSequence,
    RandomBalancedPairsSequence,
)
from keras_fsl.utils import patch_len, default_workers
# Workaround for the multiprocessing / long-sequence issue described in
# https://github.com/keras-team/keras/issues/13226.
# The stock `fit_generator` is wrapped with the two keras_fsl helpers, and the wrapper is
# installed as the side effect of a mock that replaces `Model.fit_generator`. Calling
# `Model.fit_generator(model, ...)` afterwards therefore goes through the wrapper.
# The patch is started here and is never stopped in this cell.
_original_fit_generator = Model.fit_generator
_wrapped_fit_generator = default_workers(patch_len(_original_fit_generator))
patch_fit_generator = patch(
    target='tensorflow.keras.Model.fit_generator',
    side_effect=_wrapped_fit_generator,
)
patch_fit_generator.start()

Collecting git+https://github.com/ClementWalter/Keras-FewShotLearning.git
  Cloning https://github.com/ClementWalter/Keras-FewShotLearning.git to /tmp/pip-req-build-nu7gs0w_
  Running command git clone -q https://github.com/ClementWalter/Keras-FewShotLearning.git /tmp/pip-req-build-nu7gs0w_
Building wheels for collected packages: keras-fsl
  Building wheel for keras-fsl (setup.py) ... [?25l[?25hdone
  Created wheel for keras-fsl: filename=keras_fsl-0.0.1-cp36-none-any.whl size=17546 sha256=161dd22b89d1f1ecd8e85bf8740efd1ef8e9124a54d39ef57085229f492eac92
  Stored in directory: /tmp/pip-ephem-wheel-cache-2m1z5pyi/wheels/b2/b1/19/f3d9c95093ddc56d2c447747668d7a97f482a1ca2d9f1da22b
Successfully built keras-fsl
Installing collected packages: keras-fsl
Successfully installed keras-fsl-0.0.1


<MagicMock name='fit_generator' id='140508297199400'>

In [3]:
# Load Omniglot as two frames. The download log shows the images_background and
# images_evaluation archives — presumably train and test respectively; confirm in keras_fsl.
train_set, test_set = omniglot.load_data()

Downloading data from https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip
Downloading data from https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip


In [2]:
def _alphabet_qualified_label(df):
    """Return the `label` column prefixed with its `alphabet`, joined by an underscore."""
    return df.alphabet + '_' + df.label


# Presumably character labels repeat across alphabets, hence the prefix — confirm.
# NOTE: re-running this cell prepends the alphabet a second time.
train_set = train_set.assign(label=_alphabet_qualified_label)
test_set = test_set.assign(label=_alphabet_qualified_label)

NameError: ignored

In [0]:
siamese_nets = SiameseNets()

# Hold out 30% of the training images (sampled without replacement) for validation and
# keep the remaining rows for training.
# NOTE: this rebinds `train_set`, so re-running the cell shrinks the training set again;
# the split is also unseeded and therefore differs between runs.
val_set = train_set.sample(frac=0.3, replace=False)
train_set = train_set.loc[lambda df: ~df.index.isin(val_set.index)]

# Create the output folder before building the callbacks that write into it.
(Path('logs') / 'proto_nets').mkdir(parents=True, exist_ok=True)
callbacks = [
    TensorBoard(),
    # save_best_only=True so that `best_weights.h5` really holds the best epoch (lowest
    # val_loss, the default monitored quantity) instead of being overwritten every epoch.
    ModelCheckpoint('logs/proto_nets/best_weights.h5', save_best_only=True),
]

# Training-time augmentation only: random translation, rotation and shear.
# NOTE(review): imgaug expresses `shear` in degrees, so (-0.8, 1.2) is a tiny, asymmetric
# range; it looks like it may have been meant as a scale range — confirm.
preprocessing = iaa.Sequential([
    iaa.Affine(
        translate_percent={'x': (-0.2, 0.2), 'y': (-0.2, 0.2)},
        rotate=(-10, 10),
        shear=(-0.8, 1.2),
    )
])
train_sequence = RandomBalancedPairsSequence(train_set, preprocessing=preprocessing, batch_size=16)
val_sequence = RandomBalancedPairsSequence(val_set, batch_size=16)

siamese_nets.compile(optimizer='Adam', loss='binary_crossentropy')
Model.fit_generator(  # to use patched fit_generator, see first cell
    siamese_nets,
    train_sequence,
    validation_data=val_sequence,
    callbacks=callbacks,
    epochs=100,
    steps_per_epoch=1000,
    validation_steps=200,
    use_multiprocessing=True,
)

In [0]:
# Split the trained siamese network into its two sub-models, looked up by layer name.
encoder = siamese_nets.get_layer('branch_model')
head_model = siamese_nets.get_layer('head_model')

# Embed every test image once; the episode below indexes into this array.
# NOTE(review): assumes `embeddings` has one row per `test_set` row, in the same order, and
# that `test_set` has a default 0..N-1 index (index values are used as positions) — confirm.
test_sequence = DeterministicSequence(test_set, batch_size=16)
embeddings = encoder.predict_generator(test_sequence, verbose=1)

k_shot = 1
n_way = 5

# Support set: draw `n_way` labels at random, then `k_shot` images for each of them.
# `label` is dropped inside the apply and restored from the group key by reset_index.
support = (
    test_set
    .loc[lambda df: df.label.isin(test_set.label.drop_duplicates().sample(n_way))]
    .groupby('label')
    .apply(lambda group: group.sample(k_shot).drop('label', axis=1))
    .reset_index('label')
)

# Query set: every other image of the selected labels, each row repeated once per support
# image so that it can be paired with all of them.
query = (
    test_set
    .loc[lambda df: df.label.isin(support.label.unique())]
    .loc[lambda df: ~df.index.isin(support.index)]
    .loc[lambda df: df.index.repeat(k_shot * n_way)]
    .reset_index()
)

# Tile the support rows so that they line up row-by-row with the repeated query rows.
# `np` replaces the `pd.np` alias, which was deprecated in pandas 1.0 and later removed.
support = (
    support
    .loc[lambda df: np.tile(df.index, len(query) // (k_shot * n_way))]
    .reset_index()
)

# Score every (query, support) pair with the head model applied to the cached embeddings.
predictions = (
    pd.concat([
        query,
        pd.DataFrame(head_model.predict([embeddings[query['index']], embeddings[support['index']]]), columns=['score']),
        support.add_suffix('_support'),
    ], axis=1)
)

# For each query image keep the label of the highest-scoring support image, then count
# predicted label (rows, `label_support`) against true label (columns, `label`).
# Precision divides the diagonal by the row totals, recall by the column totals.
# NOTE(review): the diagonal only lines up if every selected label is predicted at least
# once (square table) — confirm this holds for the episode.
confusion_matrix = (
    predictions
    .groupby(query.columns.to_list())
    .apply(lambda group: group.nlargest(1, columns='score').label_support)
    .reset_index()
    .pivot_table(
        values='image_name',
        index='label_support',
        columns='label',
        aggfunc='count',
        margins=True,
        fill_value=0,
    )
    .assign(precision=lambda df: np.diag(df)[:-1] / df.All[:-1])
    .T.assign(recall=lambda df: np.diag(df)[:-1] / df.All[:-2]).T
    .assign(f1=lambda df: 2 * df.precision * df.loc['recall'] / (df.precision + df.loc['recall']))
)
