In [6]:
import json
import argparse
import ssl
import datetime
import os
import math
from pathlib import Path
import numpy as np

import tensorflow as tf
from tensorflow.keras import optimizers

from sklearn.preprocessing import OneHotEncoder

from model import FewShotModel
from data import get_dataset, get_zoo_elephants_images_and_labels, get_support_and_query_sets
from train import get_w_init, my_loss_fn

In [11]:
class Args():
    """Notebook stand-in for a command-line argument namespace.

    Attributes
    ----------
    data_dir : str
        Root directory passed to get_zoo_elephants_images_and_labels; the
        few-shot train/val cache files are also written under it.
    n_support : int
        Support-set size handed to get_support_and_query_sets.
    epochs : int
        Intended number of training epochs. NOTE(review): no cell visible in
        this notebook reads it; the training cell passes its own epoch count.
    """

    def __init__(self, data_dir=None, n_support=5, epochs=1):
        # Avoid depending on one machine's absolute path: honour the
        # ELEPHANTS_DATA_DIR environment variable when set, and fall back to
        # the original location so existing runs behave exactly as before.
        if data_dir is None:
            data_dir = os.environ.get(
                'ELEPHANTS_DATA_DIR',
                '/Users/deepakduggirala/Documents/Elephants-dataset-cropped-png-1024')
        self.data_dir = data_dir
        self.n_support = n_support
        self.epochs = epochs


args = Args()

In [3]:
# Load the hyperparameter dictionary (batch sizes etc.) shared with the training
# code. json.loads accepts raw bytes and detects the UTF-8/16/32 encoding itself.
params = json.loads(Path('hyperparameters/init.json').read_bytes())

In [4]:
# Split the zoo elephant images into a support set (sized by args.n_support)
# and a query set, then one-hot encode the identity labels.
image_paths, image_labels = get_zoo_elephants_images_and_labels(args.data_dir)
support_image_paths, support_labels, query_image_paths, query_labels = get_support_and_query_sets(
    image_paths, image_labels, args.n_support)

# scikit-learn 1.2 renamed `sparse` to `sparse_output` and 1.4 removed `sparse`
# entirely, so try the new keyword first and fall back for older releases.
try:
    enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
except TypeError:
    enc = OneHotEncoder(handle_unknown='ignore', sparse=False)

# Categories are learned from the support set only. Because of
# handle_unknown='ignore', a query label that never occurs in the support set
# encodes to an all-zero row rather than raising.
support_labels_enc = enc.fit_transform(np.array(support_labels).reshape(-1, 1))
query_labels_enc = enc.transform(np.array(query_labels).reshape(-1, 1))

# One cache file per split, stored under the data directory; the path is
# forwarded to get_dataset (presumably a tf.data cache — confirm in data.py).
cache_files = dict(
    train=str(Path(args.data_dir).joinpath('few_shot_train.cache')),
    val=str(Path(args.data_dir).joinpath('few_shot_val.cache')),
)

# Training pipeline: support images, augmented and shuffled.
train_ds, N_train, _ = get_dataset(
    support_image_paths,
    support_labels_enc,
    params,
    batch_size=params['batch_size']['train'],
    shuffle=True,
    augment=True,
    cache_file=cache_files['train'],
)

# Validation pipeline: query images, no augmentation and no shuffling, so that
# predictions can be compared row-by-row against the query labels.
val_ds, _, _ = get_dataset(
    query_image_paths,
    query_labels_enc,
    params,
    batch_size=params['batch_size']['val'],
    shuffle=False,
    augment=False,
    cache_file=cache_files['val'],
)

In [14]:
# Build the few-shot model. The initial weights come from get_w_init, which
# sees the support images together with the encoder's category order so the
# weight columns presumably line up with the one-hot label columns — confirm.
model_cnt = FewShotModel(params)
w_init = get_w_init(
    params,
    model_cnt.base_model,
    support_image_paths,
    support_labels,
    categories=enc.categories_[0],
)
few_shot_model = model_cnt.get_model(w_init)



In [32]:
# Compile with Adam at a very small learning rate and the project's custom loss.
# NOTE(review): the learning rate is a magic number; consider moving it into params.
few_shot_model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-5),
    loss=my_loss_fn,
    metrics=['accuracy'],
)

In [33]:
# Fine-tune on the support set, validating on the query set. Keep the returned
# History so per-epoch loss/accuracy can be inspected or plotted afterwards
# instead of only echoing its repr.
# NOTE(review): epochs is hard-coded to 10 while Args defines an unused
# `epochs`; decide which one is authoritative.
history = few_shot_model.fit(train_ds,
                             epochs=10,
                             validation_data=val_ds)

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 

In [15]:
# Class scores for every query image. val_ds was built with shuffle=False, so
# the rows of y_pred should follow the order of query_labels.
y_pred = few_shot_model.predict(x=val_ds, verbose=True)



In [31]:
# Query-set accuracy. Compare decoded class names instead of argmax over
# query_labels_enc: with handle_unknown='ignore', a query label missing from the
# support set encodes to an all-zero row whose argmax is 0, and that row would be
# scored as correct whenever the model predicts class 0. When every query label
# is known, the result is identical to the argmax comparison.
# NOTE(review): in the saved outputs this cell (In [31]) and predict (In [15])
# ran before compile/fit (In [32]/[33]), so the recorded number reflects the
# w_init-initialised model without fine-tuning.
predicted_labels = enc.categories_[0][np.argmax(y_pred, axis=1)]
np.mean(predicted_labels == np.asarray(query_labels))

0.6301369863013698