In [261]:
import json
import argparse
import ssl
import datetime
import os
import math
from pathlib import Path
import numpy as np

import tensorflow as tf
from tensorflow.keras import optimizers

from sklearn.preprocessing import OneHotEncoder

from model import FewShotModel
from data import get_dataset, get_zoo_elephants_images_and_labels, get_support_and_query_sets
from train import get_w_init, my_loss_fn

In [262]:
class Args():
    """Plain container for the run settings used throughout this notebook.

    Every setting can be overridden by keyword; calling ``Args()`` with no
    arguments yields the same values the notebook has always used.

    Parameters
    ----------
    data_dir : str
        Directory passed to ``get_zoo_elephants_images_and_labels`` and used as
        the parent of the dataset cache files.
    n_support : int
        Value forwarded to ``get_support_and_query_sets``.
    epochs : int
        Epoch count setting (stored on the instance).
    """

    def __init__(self,
                 data_dir='/Users/deepakduggirala/Documents/Elephants-dataset-cropped-png-1024',
                 n_support=7,
                 epochs=1):
        # NOTE(review): the default data_dir is an absolute, machine-specific
        # path; pass data_dir=... explicitly when running elsewhere.
        self.data_dir = data_dir
        self.n_support = n_support
        self.epochs = epochs


args = Args()

In [3]:
# Hyper-parameters live in a JSON file; it is read as raw bytes and decoded by json.
with open('hyperparameters/init.json', 'rb') as f:
    params = json.loads(f.read())

In [263]:
# Collect every image with its label, then split into support and query sets.
image_paths, image_labels = get_zoo_elephants_images_and_labels(args.data_dir)
support_image_paths, support_labels, query_image_paths, query_labels = get_support_and_query_sets(
    image_paths, image_labels, args.n_support)

# One-hot encode the labels as dense arrays. The encoder is fitted on the
# support labels only; query labels it has not seen encode to all zeros
# (handle_unknown='ignore').
try:
    enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
except TypeError:
    # scikit-learn < 1.2 only knows the older `sparse` keyword.
    enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
support_labels_enc = enc.fit_transform(np.array(support_labels).reshape(-1, 1))
query_labels_enc = enc.transform(np.array(query_labels).reshape(-1, 1))

# Dataset cache files are kept next to the images.
cache_files = {
    'train': str(Path(args.data_dir) / 'few_shot_train.cache'),
    'val': str(Path(args.data_dir) / 'few_shot_val.cache'),
}

# Support set -> augmented, shuffled training dataset.
train_ds, N_train, _ = get_dataset(support_image_paths, support_labels_enc,
                                   params,
                                   augment=True,
                                   cache_file=cache_files['train'],
                                   shuffle=True,
                                   batch_size=params['batch_size']['train'])

# Query set -> un-augmented validation dataset in a fixed order.
val_ds, _, _ = get_dataset(query_image_paths, query_labels_enc,
                           params,
                           augment=False,
                           cache_file=cache_files['val'],
                           shuffle=False,
                           batch_size=params['batch_size']['val'])

In [265]:
# Number of support images and number of query images, shown as a pair.
tuple(map(len, (support_image_paths, query_image_paths)))

(119, 1572)

In [14]:
# Container object exposing the base model and a factory for the few-shot model.
model_cnt = FewShotModel(params)

# Initial weights are derived from the support images using the base model.
w_init = get_w_init(
    params,
    model_cnt.base_model,
    support_image_paths,
    support_labels,
    categories=enc.categories_[0],
)
few_shot_model = model_cnt.get_model(w_init)



In [32]:
# Configure training: Adam with a 1e-5 step size, the project loss, accuracy metric.
few_shot_model.compile(
    optimizer=optimizers.Adam(1e-5),
    loss=my_loss_fn,
    metrics=['accuracy'],
)

In [33]:
# Train on the support-set dataset and validate on the query-set dataset.
# NOTE(review): the epoch count is hard-coded to 10 here although the Args
# container sets `epochs` to 1 - confirm which value is intended.
few_shot_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
)

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 

In [15]:
# Run the few-shot model over the query (validation) dataset.
# NOTE(review): the execution counts suggest this cell last ran before the
# compile/fit cells, so y_pred presumably comes from the model as built from
# w_init, not from a fine-tuned model - confirm by re-running top to bottom.
y_pred = few_shot_model.predict(val_ds, verbose=True)



In [31]:
# Fraction of query images whose arg-max prediction equals the arg-max of the one-hot label.
(query_labels_enc.argmax(axis=1) == y_pred.argmax(axis=1)).mean()

0.6301369863013698

In [102]:
def get_embeddings(image_paths, image_labels, params, n_repeat=1, cache_file=None, batch_size=32):
    """Embed images with the base model over one or more augmented passes.

    Builds an augmented, unshuffled dataset with ``get_dataset``, repeats it
    ``n_repeat`` times and feeds it to ``model_cnt.base_model.predict``.
    Relies on the notebook-level ``model_cnt`` object being defined.

    Parameters
    ----------
    image_paths
        Image locations, forwarded to ``get_dataset``.
    image_labels
        One label per image (images along axis 0); forwarded to
        ``get_dataset`` and repeated in the returned labels.
    params
        Hyper-parameter dictionary, forwarded to ``get_dataset``.
    n_repeat : int, optional
        Number of passes over the dataset.
    cache_file : str or None, optional
        Forwarded to ``get_dataset``.
    batch_size : int, optional
        Batch size forwarded to ``get_dataset``; defaults to 32, the value
        that used to be hard-coded.

    Returns
    -------
    tuple
        ``(embeddings, labels)`` where ``labels`` is ``image_labels`` stacked
        ``n_repeat`` times along axis 0, so it has one entry per predicted row
        (assuming ``predict`` emits one row per image per pass).
    """
    ds_aug, _, _ = get_dataset(image_paths, image_labels,
                               params,
                               augment=True,
                               cache_file=cache_file,
                               shuffle=False,
                               batch_size=batch_size)
    ds_aug = ds_aug.repeat(n_repeat)
    embeddings_aug = model_cnt.base_model.predict(ds_aug, verbose=True)

    ls = np.array(image_labels)

    # Stack along axis 0 so labels line up with the repeated predictions for
    # 2-D labels (e.g. one-hot) as well; np.hstack would join those along
    # axis 1. For 1-D labels the result is unchanged.
    return embeddings_aug, np.concatenate([ls] * n_repeat, axis=0)

In [94]:
# Persist the embedding array to disk as a plain (non-pickled) .npy file.
# NOTE(review): `embs` is not defined in any visible cell - it presumably came
# from a get_embeddings(...) call in a cell that has since been removed.
np.save('embeddings.npy', embs, allow_pickle=False)

In [95]:
# Persist the label array to disk as a plain (non-pickled) .npy file.
# NOTE(review): `ls` is not defined at notebook level in any visible cell -
# confirm where it is produced before relying on Run All.
np.save('image_labels.npy', ls, allow_pickle=False)

In [96]:
# Read back the two arrays written by the np.save cells.
embs_loaded, ls_loaded = np.load('embeddings.npy'), np.load('image_labels.npy')

In [103]:
# Shorthand for tf.data's AUTOTUNE constant.
# NOTE(review): no visible cell uses this name.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Rebuild the support/query split and the one-hot label encodings.
# NOTE(review): this repeats the first steps of the data-preparation cell near
# the top of the notebook.
image_paths, image_labels = get_zoo_elephants_images_and_labels(args.data_dir)
support_image_paths, support_labels, query_image_paths, query_labels = get_support_and_query_sets(
    image_paths, image_labels, args.n_support)

# Dense one-hot encoder fitted on the support labels only.
try:
    enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
except TypeError:
    # scikit-learn < 1.2 only knows the older `sparse` keyword.
    enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
support_labels_enc = enc.fit_transform(np.array(support_labels).reshape(-1, 1))
query_labels_enc = enc.transform(np.array(query_labels).reshape(-1, 1))

In [133]:
# Load the stored embeddings and their labels for the whole zoo dataset.
all_embs, all_labels = np.load('zoo_embeddings.npy'), np.load('zoo_image_labels.npy')

In [206]:
# Regroup the flat arrays so that axis 1 indexes images: x becomes
# (pass, image, feature) and y becomes (pass, image).
# The sizes are taken from the data instead of being hard-coded (previously
# 1691 images and 2048 features); the boolean masks built from `image_labels`
# later on require axis 1 to have exactly len(image_labels) entries.
n_images = len(image_labels)
x = all_embs.reshape(-1, n_images, all_embs.shape[-1])
y = all_labels.reshape(-1, n_images)

In [105]:
# Re-read the image paths and labels from the dataset directory; the label
# list is what the per-class counts and masks below are built from.
image_paths, image_labels = get_zoo_elephants_images_and_labels(args.data_dir)

In [108]:
from collections import Counter
from data import shuffle

In [109]:
# Seed NumPy's global RNG before the per-class shuffles.
np.random.seed(99)
# Images per label, then one `shuffle(n)` result per label (in first-seen order).
counts = Counter(image_labels)
shuffled_idxs = {label: shuffle(n) for label, n in counts.items()}

In [114]:
# Worked example for a single class: label '1005' with 5 support images.
n_support = 5
c = '1005'
idxs = shuffled_idxs[c]

In [115]:
# The first n_support shuffled indices form the support set, the rest the query set.
s_idxs, q_idxs = idxs[:n_support], idxs[n_support:]

In [207]:
# Boolean mask over all images, True where the image's label equals class `c`.
mask = np.array(image_labels) == c

In [213]:
# Keep only the columns (axis 1) that belong to class `c`.
c_image_labels, c_embs = y[:, mask], x[:, mask, :]

In [None]:
# Shape check: support labels of class `c`, flattened across axis 0 and the selected columns.
c_image_labels[:, s_idxs].ravel().shape

(300,)

In [227]:
# Shape check: query embeddings of class `c`, flattened to rows of 2048 features.
c_embs[:, q_idxs].reshape((-1, 2048)).shape

(6000, 2048)

In [257]:
def get_support_and_query_sets(X, y, image_labels, n_support, seed=99):
    """Split embeddings and labels into per-class support and query sets.

    For every distinct label in ``image_labels`` the images of that class are
    selected from axis 1 of ``X`` and ``y``. The first ``n_support`` indices
    returned by ``shuffle`` go to the support set and the remainder to the
    query set; each part is then flattened across axis 0.

    NOTE(review): this definition shadows the function of the same name
    imported from ``data`` at the top of the notebook, which is called with a
    different argument list - earlier cells will not work once this cell has
    run.

    Parameters
    ----------
    X : np.ndarray
        Embeddings indexed as ``X[:, image, :]``; the last axis is the feature
        axis and may have any width.
    y : np.ndarray
        Labels indexed as ``y[:, image]``.
    image_labels : sequence
        One label per image, aligned with axis 1 of ``X`` and ``y``.
    n_support : int
        Number of shuffled indices per class assigned to the support set.
    seed : int, optional
        Seed for NumPy's global RNG; note that the global state is reset.

    Returns
    -------
    tuple of np.ndarray
        ``(support_embs, support_labels, query_embs, query_labels)`` with the
        embeddings as 2-D ``(rows, features)`` arrays and the labels as 1-D
        arrays, classes concatenated in first-seen order.
    """
    support_embs, support_labels = [], []
    query_embs, query_labels = [], []

    # Feature width comes from the data rather than a hard-coded 2048.
    emb_dim = X.shape[-1]
    # Converted once; the per-class masks below reuse it.
    labels_arr = np.array(image_labels)

    np.random.seed(seed)
    counts = Counter(image_labels)
    # presumably shuffle(n) yields a permutation of range(n) - defined in `data`
    shuffled_idxs = {c: shuffle(count) for c, count in counts.items()}

    for c, idxs in shuffled_idxs.items():
        s_idxs = idxs[:n_support]
        q_idxs = idxs[n_support:]

        # Restrict to this class; s_idxs/q_idxs index within the class slice.
        mask = labels_arr == c
        c_image_labels = y[:, mask]
        c_embs = X[:, mask, :]

        support_embs.append(c_embs[:, s_idxs, :].reshape(-1, emb_dim))
        support_labels.append(c_image_labels[:, s_idxs].flatten())
        query_embs.append(c_embs[:, q_idxs, :].reshape(-1, emb_dim))
        query_labels.append(c_image_labels[:, q_idxs].flatten())

    return (np.vstack(support_embs), np.concatenate(support_labels),
            np.vstack(query_embs), np.concatenate(query_labels))

In [258]:
# Embedding-level split with 5 support images per class. This rebinds
# support_labels / query_labels, which earlier cells also assign.
(support_embs, support_labels,
 query_embs, query_labels) = get_support_and_query_sets(x, y, image_labels, n_support=5)