# v1.5 embedding model

* v1: reproducing contrastive loss for metric learning; input is (2C, H, W, 3)
* v1.5: vectorising that model to operate on the average of N embeddings; input is (N, 2C, H, W, 3)

In [None]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras

In [None]:
import jax.numpy as jnp
import jax
from jax import jit, value_and_grad, vmap

import optax

from keras.layers import Input, Dense, Conv2D, GlobalMaxPooling2D
from keras.layers import Layer, BatchNormalization, Activation, Dropout
from keras.models import Model

import numpy as np
np.set_printoptions(precision=5, threshold=10000, suppress=True, linewidth=10000)

# print the versions of the main libraries in use
print('keras', keras.__version__, 
      'jax', jax.__version__, 
      'optax', optax.__version__)

In [None]:
class Opts:
    """Notebook-wide settings.

    The letters B, N, C and E are the ones used in the shape comments of the
    cells that follow.
    """

    # side length of the (square) input images
    height_width = 64
    # B: outer batch size
    batch_size = 2
    # N: number of reference examples given for each object
    num_obj_references = 4
    # C: total number of classes used for contrasting
    num_contrastive_objs = 3
    # E: embedding dim
    embedding_dim = 64
    # step size handed to the optimiser
    learning_rate = 1e-4


opts = Opts()

def shapes(debug_str, list_of_variables):
    """Return `debug_str` followed by the list of `.shape` values of the given variables."""
    shape_list = []
    for variable in list_of_variables:
        shape_list.append(variable.shape)
    return "{} {}".format(debug_str, shape_list)

In [None]:
# plotting debug

from PIL import Image

def collage(pil_imgs, rows, cols):
    """Tile equally sized PIL images into one `rows` x `cols` RGB image.

    Images are laid out column by column: the first `rows` images fill the
    left-most column top to bottom, the next `rows` fill the second column,
    and so on. Every tile is assumed to have the size of `pil_imgs[0]`.

    Args:
        pil_imgs: sequence of PIL images; its length must equal rows * cols.
        rows: number of tiles stacked vertically.
        cols: number of tiles placed side by side.

    Returns:
        A new RGB PIL image of size (cols * width, rows * height).

    Raises:
        ValueError: if len(pil_imgs) != rows * cols.
    """
    n = len(pil_imgs)
    if n != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows} * {cols} = {rows * cols} images, got {n}")
    # PIL reports size as (width, height), not (height, width)
    img_w, img_h = pil_imgs[0].size
    canvas = Image.new('RGB', (cols * img_w, rows * img_h))
    for i, img in enumerate(pil_imgs):
        # column-major placement: col = i // rows, row = i % rows
        col, row = divmod(i, rows)
        # paste takes the (x, y) of the tile's upper-left corner
        canvas.paste(img, (col * img_w, row * img_h))
    return canvas
    
def to_pil_img(a):
    """Build a PIL image from `a` after scaling it by 255 and casting to uint8."""
    # presumably `a` holds pixel values in [0, 1] — confirm against the data loader
    return Image.fromarray(np.array(a * 255, dtype=np.uint8))

In [None]:
from data import ContrastiveExamples, load_fname

# start with the simple case of R, G, B examples: one object id per colour is
# active, the other ids of the same colour are left commented out
c_egs = ContrastiveExamples(
    root_dir='data/train/reference_patches/',
    obj_ids=["061", # "135","182",  # x3 red
             "111", # "153","198",  # x3 green
             "000", #"017","019", # x3 blue
            ]
)
ds = c_egs.dataset(num_batches=1,                                   # epoch length
                   batch_size=opts.batch_size,                      # B
                   num_obj_references=opts.num_obj_references,      # N
                   num_contrastive_objs=opts.num_contrastive_objs)  # C
# keep only the first (x, y) element yielded by the dataset
for x, y in ds:
    break

# convert the batch to a jax array so it can be indexed / transposed with jnp below
x = jnp.array(x)
print(x.shape, y)

In [None]:
# recall each element of the outer batch holds N reference examples of
# anchor/positive pairs; show images 0 and 1 of reference example 0 in
# outer-batch element 0
collage(
    [to_pil_img(x[0, 0, k]) for k in range(2)],
    rows=1, cols=2)

In [None]:
# same pair as the previous cell, but taken from reference example 1 of
# outer-batch element 0
collage(
    [to_pil_img(x[0, 1, k]) for k in range(2)],
    rows=1, cols=2)

In [None]:
# recall x.shape is (B=2, N=4, 2C=6, 64, 64, 3)
# vmap peels off leading axes one at a time, so to batch from the outside in
# we want (B=2, 2C=6, N=4, 64, 64, 3): swap axes 1 and 2
x = x.transpose((0, 2, 1, 3, 4, 5))
print(x.shape)


The model is a simple embedding model.

Its output is an L2-normalised embedding, so dot products can be used directly as similarities and as the logits of the cross-entropy contrastive loss.

In [None]:
from models.models import construct_embedding_model

# build the embedding model with explicit sizes
# NOTE(review): embedding_dim here (128) differs from opts.embedding_dim (64),
# and height_width is hard-coded rather than read from opts — confirm which
# values are intended
embedding_model = construct_embedding_model(
    height_width=64,
    filter_sizes=[16,32,64,128],
    embedding_dim=128)

In [None]:

# pull the model's variables out so they can be passed explicitly to stateless_call
params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables

def mean_embeddings(params, nt_params, x):
    """Embed a stack of images and return their re-normalised mean embedding.

    x is (N, H, W, 3). Returns the (E,) mean embedding scaled to unit L2 norm,
    together with the non-trainable variables returned by the model call.
    """
    per_image, updated_nt_params = embedding_model.stateless_call(
        params, nt_params, x, training=True)  # (N, E)
    centroid = jnp.mean(per_image, axis=0)  # (E)
    # averaging changes the length of the vector, so L2 normalise again
    length = jnp.linalg.norm(centroid, axis=-1, keepdims=True)
    return centroid / length, updated_nt_params

# sanity check on a single (N, H, W, 3) slice: x[0,0] is the first of the 2C
# entries of the first outer-batch element
embeddings, nt_params_2 = mean_embeddings(params, nt_params, x[0,0])

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print(shapes('ntps', nt_params_2))

In [None]:
# vectorise over the leading axis of x[0]; params / nt_params are shared (in_axes None)
embeddings, nt_params_2 =  vmap(mean_embeddings, in_axes=(None, None, 0))(params, nt_params, x[0])
# vmap stacks the returned non-trainable variables along a new leading axis; average it away
nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print(shapes('ntps', nt_params_2))

In [None]:
# define the contrastive loss over a 'batch' of 2C mean embeddings ( C anchor/positive pairs )

def main_diagonal_softmax_cross_entropy(logits):
    """Summed softmax cross entropy where row i's target class is column i.

    With labels (0, 1, 2, ...) the one-hot mask over the log-softmax is the
    identity matrix, so only the main diagonal contributes.
    """
    log_probs = jax.nn.log_softmax(logits)
    on_diagonal = jnp.diag(log_probs)
    return -jnp.sum(on_diagonal)

def mean_embeddings(params, nt_params, x):
    """Return the unit-norm mean of the embeddings of the images in x.

    x is (N, H, W, 3). The result is an (E,) vector plus the non-trainable
    variables returned by the stateless model call.
    """
    model_out, new_nt_params = embedding_model.stateless_call(
        params, nt_params, x, training=True)  # (N, E)
    mean_vec = jnp.mean(model_out, axis=0)  # (E)
    # (re) L2 normalise, since the mean no longer has unit length in general
    mean_vec = mean_vec / jnp.linalg.norm(mean_vec, axis=-1, keepdims=True)
    return mean_vec, new_nt_params
    
def constrastive_loss(params, nt_params, x):
    """Contrastive cross-entropy loss over C anchor/positive pairs.

    Args:
        params: trainable variables of the embedding model.
        nt_params: non-trainable variables of the embedding model.
        x: (2C, N, H, W, 3); entries 2c and 2c+1 are the anchor and positive
            of pair c, each given as N reference images.

    Returns:
        (loss, nt_params): the scalar loss and the non-trainable variables
        averaged over the 2C vmapped model calls.
    """
    embeddings, nt_params = vmap(mean_embeddings, in_axes=(None, None, 0))(params, nt_params, x)
    # vmap stacks the returned variables along a leading 2C axis; average it away
    nt_params = [jnp.mean(p, axis=0) for p in nt_params]
    # embeddings (2C, E). Read E off the array rather than opts.embedding_dim:
    # if the model was built with a different embedding_dim, reshaping with the
    # opts value would silently split each embedding into bogus "pairs".
    embeddings = embeddings.reshape((-1, 2, embeddings.shape[-1]))  # (C, 2, E)
    anchors = embeddings[:, 0]    # (C, E)
    positives = embeddings[:, 1]  # (C, E)
    # similarity of every anchor with every positive; matching pairs sit on the diagonal
    gram_ish_matrix = jnp.einsum('ae,be->ab', anchors, positives)  # (C, C)
    # already a scalar, so no further reduction is needed
    xent = main_diagonal_softmax_cross_entropy(logits=gram_ish_matrix)
    return xent, nt_params

def batch_constrastive_loss(params, nt_params, x):
    """Mean of constrastive_loss over the outer batch; x is (B, 2C, N, H, W, 3).

    Returns the scalar mean loss and the non-trainable variables averaged over
    the B vmapped calls.
    """
    per_example_loss = vmap(constrastive_loss, in_axes=(None, None, 0))
    example_losses, stacked_nt_params = per_example_loss(params, nt_params, x)
    averaged_nt_params = [jnp.mean(p, axis=0) for p in stacked_nt_params]
    return jnp.mean(example_losses), averaged_nt_params


# evaluate the batched loss once on the whole of x as a smoke test
loss, nt_params_2 = batch_constrastive_loss(params, nt_params, x)
print('loss', loss)
print(shapes('ntps', nt_params_2))

In [None]:
# define gradients and a simple training loop

def calculate_gradients(params, nt_params, x):
    """Return ((loss, nt_params), grads) of constrastive_loss w.r.t. params.

    x is a single outer-batch element, (2C, N, H, W, 3). has_aux=True lets the
    updated non-trainable variables ride along next to the loss.
    """
    return value_and_grad(constrastive_loss, has_aux=True)(params, nt_params, x)

opt = optax.adam(learning_rate=opts.learning_rate)


def train_step(params, nt_params, opt_state, x):
    """Run one optimiser step and return (params, nt_params, opt_state, loss)."""
    (loss, new_nt_params), grads = calculate_gradients(params, nt_params, x)
    param_updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_nt_params, new_opt_state, loss

# start from a freshly constructed model (default sizes) and optimiser state
embedding_model = construct_embedding_model()

params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables
opt_state = opt.init(params)

# wrap the step in jit once, outside the loop, instead of re-wrapping it on
# every iteration
train_step_jit = jit(train_step)

# repeatedly fit the single outer-batch element x[0]
for e in range(1000):
    params, nt_params, opt_state, loss = train_step_jit(params, nt_params, opt_state, x[0])
    if e % 100 == 0:
        print(loss)


In [None]:
# test against batch
# NOTE(review): after the transpose cell x[0] carries an extra N axis,
# (2C, N, H, W, 3) per that cell's comment, while the model is elsewhere called
# on (N, H, W, 3) stacks — this call looks like it predates that change; confirm
embeddings, _ = embedding_model.stateless_call(params, nt_params, x[0], training=False)
embeddings.shape

# looks good (0,1) (2,3) (4,5) all pair well ( and others are -0.5 )
jnp.around(jnp.dot(embeddings, embeddings.T), 2)

Next, we get things working on a batch of these examples.

In [None]:
# to use the model batched we vmap it first 

def training_call(x):
    """Call the embedding model statelessly in training mode on x.

    Uses whatever the globals `params` and `nt_params` hold at call time and
    returns the (outputs, non-trainable variables) pair from stateless_call.
    """
    result = embedding_model.stateless_call(params, nt_params, x, training=True)
    return result

# rebind the name to the vectorised version, mapping over the leading axis of x
training_call = vmap(training_call)
embeddings, nt_params_2 = training_call(x)

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print(shapes('ntps_2', nt_params_2))

In [None]:
# but note that the nt_params returned have been vectorised too
# i.e. they are [(B, p1), (B, p2), ...] instead of [(p1,), (p2,), ...]
# so we need to aggregate them; here by taking the mean over the leading axis

nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]
print(shapes('ntps', nt_params_2))

In [None]:
# as before we can vectorise the loss
# takes (B, 2C, H, W, 3)
constrastive_loss_v = vmap(constrastive_loss, in_axes=[None, None, 0])

# and run over all of x; the result has one loss per outer-batch element
loss_v, nt_params_2 = constrastive_loss_v(params, nt_params, x)  # (B)

print("before inner batch aggregation...")
print('loss_v', loss_v)
print(shapes('ntps', nt_params_2))

# reduce both the losses and the stacked non-trainable variables over the vmapped axis
loss = jnp.mean(loss_v)
nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]

print("after inner batch aggregation...")
print('loss', loss)
print(shapes('ntps', nt_params_2))

In [None]:
# we can calculate grads just as before but we must call a function that includes the loss 
# aggregation ( i.e. grads only are applicable for a scalar loss ) so wrap the vmap
# and jnp.mean in one function.

def constrastive_loss_v(params, nt_params, x):
    """Scalar mean of constrastive_loss over the leading axis of x.

    The vmap and both reductions live in one function so that value_and_grad
    sees a scalar loss. Returns (loss, nt_params) with the non-trainable
    variables averaged over the vmapped axis.
    """
    batched_loss = vmap(constrastive_loss, in_axes=(None, None, 0))
    per_example_losses, stacked_nt_params = batched_loss(params, nt_params, x)
    # TODO: what does this do for rng seeds?
    mean_nt_params = [jnp.mean(p, axis=0) for p in stacked_nt_params]
    return jnp.mean(per_example_losses), mean_nt_params

# differentiate w.r.t. params (the first argument); has_aux=True carries the
# aggregated non-trainable variables alongside the loss
(loss, nt_params_2), grads = jit(value_and_grad(constrastive_loss_v, has_aux=True))(params, nt_params, x)

print('loss', loss)
print(shapes('grads', grads))
print(shapes('ntps', nt_params_2))


In [None]:
# stitch together into a training loop

def main_diagonal_softmax_cross_entropy(logits):
    """Cross entropy, summed over rows, with row i labelled as class i.

    Because the labels are just (0, 1, 2, ...), the one-hot mask applied to the
    log-softmax is the identity, i.e. only the main diagonal is kept.
    """
    row_log_probs = jax.nn.log_softmax(logits)
    matched = jnp.diag(row_log_probs)
    return -jnp.sum(matched)
    
def constrastive_loss(params, nt_params, x):
    """Contrastive cross-entropy loss over C anchor/positive image pairs.

    Args:
        params: trainable variables of the embedding model.
        nt_params: non-trainable variables of the embedding model.
        x: (2C, H, W, 3); images 2c and 2c+1 are the anchor and positive of pair c.

    Returns:
        (loss, nt_params): the scalar loss and the non-trainable variables
        returned by the model call.
    """
    embeddings, nt_params = embedding_model.stateless_call(params, nt_params, x, training=True)  # (2C, E)
    # infer the pair count and E from the array itself; the Opts class has no
    # num_egs_per_class attribute, so reading it would raise AttributeError
    embeddings = embeddings.reshape((-1, 2, embeddings.shape[-1]))  # (C, 2, E)
    anchors = embeddings[:, 0]    # (C, E)
    positives = embeddings[:, 1]  # (C, E)
    # similarity of every anchor with every positive; matching pairs sit on the diagonal
    gram_ish_matrix = jnp.einsum('ae,be->ab', anchors, positives)  # (C, C)
    # already a scalar, so no further reduction is needed
    xent = main_diagonal_softmax_cross_entropy(logits=gram_ish_matrix)
    return xent, nt_params

def constrastive_loss_v(params, nt_params, x):
    """Mean of constrastive_loss over the leading axis of x, (B, 2C, H, W, 3).

    The non-trainable variables are returned still stacked along that leading
    axis; aggregating them is left to the caller.
    """
    per_example = vmap(constrastive_loss, in_axes=(None, None, 0))
    losses, stacked_nt_params = per_example(params, nt_params, x)
    return jnp.mean(losses), stacked_nt_params

def calculate_gradients(params, nt_params, x):
    """Return ((loss, nt_params_v), grads) of constrastive_loss_v w.r.t. params.

    x is (B, 2C, H, W, 3). nt_params_v is the auxiliary output of the loss,
    i.e. the non-trainable variables still stacked over the batch axis.
    """
    return value_and_grad(constrastive_loss_v, has_aux=True)(params, nt_params, x)

opt = optax.adam(learning_rate=opts.learning_rate)


def train_step(params, nt_params, opt_state, x):
    """Run one optimiser step and return (params, nt_params, opt_state, loss).

    The non-trainable variables come back from the loss stacked over the batch
    axis and are averaged here.
    """
    (loss, stacked_nt_params), grads = calculate_gradients(params, nt_params, x)
    param_updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, param_updates)
    new_nt_params = [jnp.mean(p, axis=0) for p in stacked_nt_params]
    return new_params, new_nt_params, new_opt_state, loss

# start from a freshly constructed model (default sizes) and optimiser state
embedding_model = construct_embedding_model()

params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables
opt_state = opt.init(params)

# wrap the step in jit once, outside the loop, instead of re-wrapping it on
# every iteration
train_step_jit = jit(train_step)

# repeatedly fit the same batch x
for epoch in range(200):
    params, nt_params, opt_state, loss = train_step_jit(params, nt_params, opt_state, x)
    if epoch % 20 == 0:
        print('e', epoch, 'loss', loss)


In [None]:
# try a couple of the examples from the batch
# each looks good

# bound the loop by the actual outer batch size: jax clamps out-of-range
# indices instead of raising, so with fewer than 3 elements x[i] would
# silently repeat the last one
for i in range(min(3, x.shape[0])):
    print("-"*10, i)
    embeddings, _ = embedding_model.stateless_call(params, nt_params, x[i], training=False)
    print(jnp.around(jnp.dot(embeddings, embeddings.T), 2))

In [None]:
# pairwise dot products of the inference-mode embeddings for the first batch element
embeddings, _ = embedding_model.stateless_call(params, nt_params, x[0], training=False)
print(jnp.around(jnp.dot(embeddings, embeddings.T), 2))

In [None]:
# write the trained values back into the model's own variables so that the
# model can be called / saved without passing params explicitly
for variable, value in zip(embedding_model.trainable_variables, params):
    variable.assign(value)
for variable, value in zip(embedding_model.non_trainable_variables, nt_params):
    variable.assign(value)

In [None]:
# same check as the stateless call, now through the model's regular call
embeddings = embedding_model(x[0], training=False)
print(jnp.around(jnp.dot(embeddings, embeddings.T), 2))

In [None]:
import pickle
# serialise the list returned by get_weights() to a local file
with open('test.pkl', 'wb') as f:
    pickle.dump(embedding_model.get_weights(), f)

In [None]:
# round-trip check: build a second model and load the pickled weights into it
embedding_model_2 = construct_embedding_model()

# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook
with open('test.pkl', 'rb') as f:
    reloaded_weights = pickle.load(f)
embedding_model_2.set_weights(reloaded_weights)

# should print the same similarity matrix as the original model
embeddings = embedding_model_2(x[0], training=False)
print(jnp.around(jnp.dot(embeddings, embeddings.T), 2))

In [None]:
# display the model's configuration
embedding_model.get_config()