# v1 embedding model

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
keras.__version__

'3.5.0'

In [2]:
import numpy as np
import jax.numpy as jnp
import jax

from keras.layers import Input, Dense, Conv2D, GlobalMaxPooling2D
from keras.layers import Layer, BatchNormalization, Activation
from keras.models import Model

from jax import jit, value_and_grad, vmap

In [27]:
class Opts:
    height_width = 64
    batch_size = 5          # B
    num_classes = 6         # C
    num_egs_per_class = 4   # N
    embedding_dim = 128     # E
    
opts = Opts()

In [28]:
from data import ConstrastiveExamples

c_egs = ConstrastiveExamples(
    root_dir='data/reference_egs',
    obj_ids=["061","135","182",  # x3 red
             "111","153","198",  # x3 green
             "000","017","019"], # x3 blue
)
ds = c_egs.dataset(batch_size=opts.batch_size,
                   objs_per_batch=opts.num_egs_per_class)
for x, y in ds:
    print(x.shape, y)

x = jnp.array(x)

self.label_idx_to_str {0: '061', 1: '135', 2: '182', 3: '111', 4: '153', 5: '198', 6: '000', 7: '017', 8: '019'}
(5, 8, 64, 64, 3) tf.Tensor(
[[5 5 7 7 0 0 1 1]
 [8 8 1 1 0 0 3 3]
 [6 6 7 7 5 5 4 4]
 [6 6 3 3 5 5 7 7]
 [5 5 3 3 1 1 4 4]], shape=(5, 8), dtype=int64)


2024-10-19 11:45:44.887051: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [29]:
def conv_bn_relu(filters, y):
  y = Conv2D(filters=filters, strides=2, kernel_size=3, activation=None, padding='same')(y)
  y = BatchNormalization()(y)
  return Activation('relu')(y)

input = Input((opts.height_width, opts.height_width, 3))
y = conv_bn_relu(filters=16, y=input)
y = conv_bn_relu(filters=32, y=y)
y = conv_bn_relu(filters=64, y=y)
y = conv_bn_relu(filters=128, y=y)
y = GlobalMaxPooling2D()(y)  # (B, E)

class L2Normalisation(Layer):
    def call(self, x):
        norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
        return x / norm

# embed, with normalisation
embeddings = Dense(
    opts.embedding_dim,
    use_bias=False,
    kernel_initializer=keras.initializers.TruncatedNormal(),
    name='embeddings')(y)  # (B, E)
embeddings = L2Normalisation()(embeddings)

embedding_model = Model(input, embeddings)

embedding_model.summary()

In [30]:
params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables

params, nt_params, [ntp.shape for ntp in nt_params]

([<KerasVariable shape=(3, 3, 3, 16), dtype=float32, path=conv2d_12/kernel>,
  <KerasVariable shape=(16,), dtype=float32, path=conv2d_12/bias>,
  <KerasVariable shape=(16,), dtype=float32, path=batch_normalization_12/gamma>,
  <KerasVariable shape=(16,), dtype=float32, path=batch_normalization_12/beta>,
  <KerasVariable shape=(3, 3, 16, 32), dtype=float32, path=conv2d_13/kernel>,
  <KerasVariable shape=(32,), dtype=float32, path=conv2d_13/bias>,
  <KerasVariable shape=(32,), dtype=float32, path=batch_normalization_13/gamma>,
  <KerasVariable shape=(32,), dtype=float32, path=batch_normalization_13/beta>,
  <KerasVariable shape=(3, 3, 32, 64), dtype=float32, path=conv2d_14/kernel>,
  <KerasVariable shape=(64,), dtype=float32, path=conv2d_14/bias>,
  <KerasVariable shape=(64,), dtype=float32, path=batch_normalization_14/gamma>,
  <KerasVariable shape=(64,), dtype=float32, path=batch_normalization_14/beta>,
  <KerasVariable shape=(3, 3, 64, 128), dtype=float32, path=conv2d_15/kernel>,
  <K

In [31]:
print("x[0]", x[0].shape)
embeddings = embedding_model(x[0])
embeddings.shape, jnp.linalg.norm(embeddings, axis=1)

x[0] (8, 64, 64, 3)


((8, 128),
 Array([0.99999994, 1.        , 0.99999994, 1.0000001 , 1.        ,
        1.        , 1.        , 1.        ], dtype=float32))

In [38]:
# the model sees "a batch" as the set of (anchor, positive) pairs
# whereas x is a batch of these.

embeddings, nt_params_2 = embedding_model.stateless_call(params, nt_params, x[0], training=True)

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print("ntps", [p.shape for p in nt_params_2])

e shape (8, 128)
e norms [1.         1.         1.         0.99999994 1.         1.
 1.         1.        ]
ntps [(16,), (16,), (32,), (32,), (64,), (64,), (128,), (128,)]


In [43]:
# so to use the model batched we actually need to vmap it first 

def training_call(x):
    return embedding_model.stateless_call(params, nt_params, x, training=True)

training_call = vmap(training_call)
embeddings, nt_params_2 = training_call(x)

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print("ntps", [p.shape for p in nt_params_2])

e shape (5, 8, 128)
e norms [[0.99999994 1.         1.         0.99999994 1.         1.
  1.         1.        ]
 [0.99999994 1.0000001  1.         1.         1.         1.
  1.         1.        ]
 [1.         1.         1.         0.99999994 1.         1.
  0.99999994 1.        ]
 [0.99999994 1.         1.         1.         1.         0.99999994
  1.         1.        ]
 [1.         0.99999994 1.         1.         0.99999994 1.
  1.         1.        ]]
ntps [(5, 16), (5, 16), (5, 32), (5, 32), (5, 64), (5, 64), (5, 128), (5, 128)]


In [44]:
# but note that the nt_params returned have been vectorised too
# i.e. they are [(B, p1), (B, p2), ...] instead of [(p1,), (p2,), ...]
# so, we need to aggreate them,

nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]
print("ntps", [p.shape for p in nt_params_2])

ntps [(16,), (16,), (32,), (32,), (64,), (64,), (128,), (128,)]


In [46]:
# now onto the constrastive loss,
# again first for a 'batch' of examples

def main_diagonal_softmax_cross_entropy(logits):
    # cross entropy assuming "labels" are just (0, 1, 2, ...) i.e. where
    # one_hot mask for log_softmax ends up just being the main diagonal
    return -jnp.sum(jnp.diag(jax.nn.log_softmax(logits)))
    
def constrastive_loss(params, nt_params, x):
    embeddings, nt_params = embedding_model.stateless_call(params, nt_params, x, training=True)
    embeddings = embeddings.reshape((opts.num_egs_per_class, 2, opts.embedding_dim))
    anchors = embeddings[:, 0]
    positives = embeddings[:, 1]
    gram_ish_matrix = jnp.einsum('ae,be->ab', anchors, positives)
    xent = main_diagonal_softmax_cross_entropy(logits=gram_ish_matrix)
    return jnp.mean(xent), nt_params


loss, nt_params_2 = constrastive_loss(params, nt_params, x[0])
loss, [ntp.shape for ntp in nt_params_2]

(Array(5.4682417, dtype=float32),
 [(16,), (16,), (32,), (32,), (64,), (64,), (128,), (128,)])

In [47]:
# as before we can vectorise the loss 
# takes (B, 2C, H, W, 3)
constrastive_loss_v = vmap(constrastive_loss, in_axes=[None, None, 0])

# and run over all of x 
per_eg_loss, nt_params_2 = constrastive_loss_v(params, nt_params, x)  # (N)

print('per_eg_loss', per_eg_loss)
print('ntp', [ntp.shape for ntp in nt_params_2])


per_eg_loss [5.4682417 5.4851623 5.500038  5.525987  5.538123 ]
ntp [(5, 16), (5, 16), (5, 32), (5, 32), (5, 64), (5, 64), (5, 128), (5, 128)]


In [48]:
# and, as before, we need to aggregate everythging

avg_loss = jnp.mean(per_eg_loss)
nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]

print('avg_loss', avg_loss)
print('ntp', [ntp.shape for ntp in nt_params_2])

avg_loss 5.5035105
ntp [(16,), (16,), (32,), (32,), (64,), (64,), (128,), (128,)]
