# v3 embedding model

* includes two generators, ContrastiveExamples & SceneExamples for two inputs

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras

In [2]:
from typing import List, Tuple

import numpy as np

import jax.numpy as jnp
from jax import vmap, jit, value_and_grad, nn

import optax

from data import ObjIdsHelper, ContrastiveExamples, SceneExamples
from models.models import construct_embedding_model

In [141]:
class Opts:
    num_batches = 1             # effective epoch length
    
    obj_height_width = 64    
    num_obj_references = 5     # N number of reference examples given for each object
    num_focus_objs = 3         # C total number of classes used for contrasting & focus in scene
    obj_filter_sizes = [4, 8]  
    obj_embedding_dim = 64     # E dim for obj reference embeddings
    
    scene_height_width = 640    
    scene_filter_sizes = [4, 8]
    scene_feature_dim = 96     # F dim for scene features

    classifier_filter_sizes = [4, 8]
    
    learning_rate = 1e-4

opts = Opts()

def shapes(debug_str, list_of_variables):
    return f"{debug_str} ({len(list_of_variables)}) {[v.shape for v in list_of_variables]}"

In [30]:
# from models.models import conv_bn_relu

from keras.layers import Input, Dense, Conv2D, GlobalMaxPooling2D, Reshape
from keras.layers import Layer, BatchNormalization, Activation, Concatenate
from keras.models import Model
from keras.initializers import TruncatedNormal, Constant

def conv_bn_relu(filters, y, name, one_by_one=False):   
    if one_by_one:
        main = Conv2D(
            filters=filters, strides=1, kernel_size=1,
            activation=None, padding='same',
            name=f"{name}_conv1x1")(y)
    else:
        main = Conv2D(
            filters=filters, strides=2, kernel_size=3,
            activation=None, padding='same',
            name=f"{name}_conv")(y)
        
    main = BatchNormalization(name=f"{name}_bn")(main)
    main = Activation('relu', name=f"{name}_relu")(main)

    # TODO add this residual back in at end when scaling up
    # branch = Conv2D(
    #     filters=filters, strides=1, kernel_size=3,
    #     activation=None, padding='same')(main)
    # branch = BatchNormalization()(branch)
    # branch = Activation('relu')(branch)
    
    return main #+ branch

class L2Normalisation(Layer):
    def call(self, x):
        norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
        return x / norm

class Tiling(Layer):    
    def __init__(self, grid_size, name):
        super().__init__(name=name)
        self.grid_size = grid_size        
    def call(self, x):
        return jnp.tile(
            x[:,None,None,:],
            (1, self.grid_size, self.grid_size, 1))

## embedding model

In [36]:
def construct_embedding_model(
        height_width: int,
        filter_sizes: List[int],
        embedding_dim: int
        ):

    input = Input((height_width, height_width, 3))

    y = input
    for i, f in enumerate(filter_sizes):
        y = conv_bn_relu(filters=f, y=y, name=f"obj_e_{i}")
    y = GlobalMaxPooling2D(name='obj_e_gp')(y)  # (B, E)

    # embed, with normalisation
    embeddings = Dense(
        embedding_dim,
        use_bias=False,
        kernel_initializer=TruncatedNormal(),
        name='obj_embeddings')(y)  # (B, E)
    embeddings = L2Normalisation(name='obj_e_l2')(embeddings)

    return Model(input, embeddings)

In [37]:
embedding_model = construct_embedding_model(
    opts.obj_height_width, opts.obj_filter_sizes, opts.obj_embedding_dim)
embedding_model.summary()

## scene model

In [107]:
def construct_scene_model(
    scene_height_width: int,
    scene_filter_sizes: List[int],
    scene_feature_dim: int,
    expected_obj_embedding_dim: int,
    classifier_filter_sizes: List[int]
    ):

    # scene backbone
    scene_input = Input((scene_height_width, scene_height_width, 3), name='scene_input')
    y = scene_input
    for i, f in enumerate(scene_filter_sizes):
        y = conv_bn_relu(filters=f, y=y, name=f"scene_{i}")
        
    # final feature layer ( projection, no relu )
    scene_features = Dense(
        scene_feature_dim,
        use_bias=False, activation=None,
        kernel_initializer=TruncatedNormal(),        
        name='scene_features')(y)  # (B, F)    

    # input branch from obj_embeddings
    obj_embedding_input = Input((expected_obj_embedding_dim,), name='obj_embedding_inp')

    # tile the embeddings to match the spatial size of the features 
    # from the scene backbone
    grid_size = scene_features.shape[-2]  # assume square, dangerous?        
    tiled_obj_embeddings = Tiling(grid_size, name='tiled_obj_emb')(obj_embedding_input)

    # combine the two sets of features
    obj_scene_features = Concatenate(axis=-1)([scene_features, tiled_obj_embeddings])

    # add classifier ( logits )
    classifier = obj_scene_features
    for i, f in enumerate(scene_filter_sizes):
        classifier = conv_bn_relu(filters=f, y=classifier, 
                                  name=f"classifier_{i}", one_by_one=True)
    classifier = Dense(1, name='classifier')(classifier)

    return Model(inputs=[scene_input, obj_embedding_input], 
                 outputs=classifier)

scene_model = construct_scene_model(
    scene_height_width=opts.scene_height_width,
    scene_filter_sizes=[4, 8, 16, 32, 64, 64], #opts.scene_filter_sizes,
    scene_feature_dim=opts.scene_feature_dim,
    expected_obj_embedding_dim=opts.obj_embedding_dim,
    classifier_filter_sizes=[8, 16] #opts.classifier_filter_sizes
)
scene_model.summary()

# run example stateless
B = 3
eg_embedding_batch = np.ones((B, opts.obj_embedding_dim))
eg_scene_batch = np.ones((B, opts.scene_height_width, opts.scene_height_width, 3))
classifier_out, s_nt_params = scene_model.stateless_call(
    scene_model.trainable_variables,
    scene_model.non_trainable_variables,
    [eg_scene_batch, eg_embedding_batch])
print('classifier_out', classifier_out.shape)
print(shapes('s_nt_params', s_nt_params))

classifier_out (3, 10, 10, 1)
s_nt_params [(4,), (4,), (8,), (8,), (16,), (16,), (32,), (32,), (64,), (64,), (64,), (64,), (4,), (4,), (8,), (8,), (16,), (16,), (32,), (32,), (64,), (64,), (64,), (64,)]


## dataset

In [86]:
import copy

obj_ids_helper = ObjIdsHelper(
    root_dir='data/train/reference_patches/',
    obj_ids=["061", "135","182",  # x3 red
             "111", "153","198",  # x3 green
             "000", "017","019"], # x3 blue
    seed=123
)

obj_egs = ContrastiveExamples(obj_ids_helper)
obj_ds = obj_egs.dataset(num_batches=opts.num_batches,
                   num_obj_references=opts.num_obj_references,
                   num_contrastive_examples=opts.num_focus_objs)

scene_egs = SceneExamples(
    obj_ids_helper=obj_ids_helper,
    grid_size=10,
    num_other_objs=4,
    instances_per_obj=3,
    seed=123)
scene_ds = scene_egs.dataset(
    num_batches=opts.num_batches,
    num_focus_objects=opts.num_focus_objs)

for (obj_x, _obj_y), (scene_x, scene_y_true) in zip(obj_ds, scene_ds):
    obj_x = jnp.array(obj_x)    
    scene_x = jnp.array(scene_x)
    scene_y_true = jnp.array(scene_y_true)    
    print('obj_ x', obj_x.shape)
    print('scene_ x', scene_x.shape, 'y_true', scene_y_true.shape)

scene._example_generator batch=0 obj_id=017 ( label = 7 ) 
scene._example_generator batch=0 obj_id=135 ( label = 1 ) 
scene._example_generator batch=0 obj_id=198 ( label = 5 ) 
obj_ x (3, 2, 5, 64, 64, 3)
scene_ x  (3, 640, 640, 3) y_true (3, 10, 10)


2024-10-26 10:49:48.164659: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


## composite model

In [163]:
loss_weights = { 'constrastive': 1.0, 'scene': 100.0 }

def mean_embeddings(params, nt_params, x):
    # x (N, H, W, 3)
    embeddings, nt_params = embedding_model.stateless_call(
        params, nt_params, x, training=True)  # (N, E)
    # average over N
    embeddings = jnp.mean(embeddings, axis=0)  # (E)
    # (re) L2 normalise
    embeddings /= jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings, nt_params  # (E,)

def main_diagonal_softmax_cross_entropy(logits):
    # cross entropy assuming "labels" are just (0, 1, 2, ...) i.e. where
    # one_hot mask for log_softmax ends up just being the main diagonal
    return -jnp.sum(jnp.diag(nn.log_softmax(logits)))


def forward(params, nt_params, obj_x, scene_x):
    # obj_x    (C, 2, N, oHW, oHW, 3)
    # scene_x  (C, sHW, sHW, 3)
    # scene_y  (C, G, G, 1)
    
    e_params, s_params = params
    e_nt_params, s_nt_params = nt_params

    # first run obj reference branch
    
    # first flatten obj_x to single 2C "batch" over N to get common batch norm stats
    # TODO: how are these stats skewed w.r.t to fact we'll call over N during inference
    C = obj_x.shape[0]
    nhwc = obj_x.shape[-4:]
    obj_x = obj_x.reshape((-1, *nhwc))  # (2C, N, oHW, oHW, 3)

    # run through mean embeddings which reduces over N
    # ( and average non trainables )
    v_mean_embeddings = vmap(mean_embeddings, in_axes=(None, None, 0))
    obj_embeddings, e_nt_params = v_mean_embeddings(
        e_params, e_nt_params, obj_x)  # (2C, E)
    e_nt_params = [jnp.mean(p, axis=0) for p in e_nt_params]

    # reshape back to split anchors and positives
    obj_embeddings = obj_embeddings.reshape((C, 2, -1))  # (C, 2, E)
    anchors = obj_embeddings[:,0]
    positives = obj_embeddings[:,1]
    #print('anchors', anchors.shape)

    # second; run scene branch runs ( with just anchors for obj references )
    
    # classifier_out (C, G, G, 1) ( logits )
    classifier_out, s_nt_params = scene_model.stateless_call(
        s_params, s_nt_params, [scene_x, anchors], training=True)
    #print('classifier_out', classifier_out.shape)
    #print(shapes('s_nt_params', s_nt_params))

    nt_params = e_nt_params, s_nt_params
    return anchors, positives, classifier_out, nt_params

def calculate_individual_losses(params, nt_params, obj_x, scene_x, scene_y_true):
    # obj_x    (C, 2, N, oHW, oHW, 3)
    # scene_x  (C, sHW, sHW, 3)
    # scene_y  (C, G, G, 1)

    # run forward through two networks
    anchors, positives, classifier_out, nt_params = forward(
        params, nt_params, obj_x, scene_x)
        
    # calculate contrastive loss from obj embeddings
    gram_ish_matrix = jnp.einsum('ae,be->ab', anchors, positives)
    metric_losses = main_diagonal_softmax_cross_entropy(logits=gram_ish_matrix)
    metric_loss = jnp.mean(metric_losses)

    # calculate classifier loss is binary cross entropy ( mean across all instances )
    scene_losses = optax.losses.sigmoid_binary_cross_entropy(
        logits=classifier_out.flatten(),
        labels=scene_y_true.flatten())
    scene_loss = jnp.mean(scene_losses)

    # return losses ( with nt_params updated from forward call )
    return metric_loss, scene_loss, nt_params

def calculate_single_loss(params, nt_params, obj_x, scene_x, scene_y_true):
    metric_loss, scene_loss, nt_params = calculate_individual_losses(
        params, nt_params, obj_x, scene_x, scene_y_true)
    loss = (loss_weights['constrastive']) * metric_loss + (loss_weights['scene'] * scene_loss)
    return loss,  nt_params

def calculate_gradients(params, nt_params, obj_x, scene_x, scene_y_true):
    # obj_x    (C, 2, N, oHW, oHW, 3)
    # scene_x  (C, sHW, sHW, 3)
    # scene_y  (C, G, G, 1)
    grad_fn = value_and_grad(calculate_single_loss, has_aux=True)    
    (loss, nt_params), grads = grad_fn(
        params, nt_params, obj_x, scene_x, scene_y_true)
    return (loss, nt_params), grads

opt = optax.adam(learning_rate=opts.learning_rate)

def train_step(
    params, nt_params, opt_state, 
    obj_x, scene_x, scene_y_true):

    # calculate gradients
    (loss, nt_params), grads = calculate_gradients(
        params, nt_params, obj_x, scene_x, scene_y_true)
    
    # # this is bit clumsy; because params was passed to grad call 
    # # with _all_ params (including non trainables, we get back 
    # # grads w.r.t to the non trainables ( which
    # # will be zero and can be ignored... )
    # e_params, _, s_params, _ = params
    # e_grads, _, s_grads, _ = grads

    # calculate updates from optimiser
    updates, opt_state = opt.update(grads, opt_state, params)

    # apply updates to get new params
    params = optax.apply_updates(params, updates)

    # return
    return params, nt_params, opt_state, loss
    

print('initial params')
e_params = embedding_model.trainable_variables
e_nt_params = embedding_model.non_trainable_variables
s_params = scene_model.trainable_variables
s_nt_params = scene_model.non_trainable_variables
print(shapes('e_params', e_params))
print(shapes('e_nt_params', e_nt_params))
print(shapes('s_params', s_params))
print(shapes('s_nt_params', s_nt_params))

# package up trainable and non trainables in tuples
params = e_params, s_params
nt_params = e_nt_params, s_nt_params

# optimser will run against both
opt_state = opt.init(params)

train_step = jit(train_step)

for e in range(1000):
    params, nt_params, opt_state, loss = train_step(
        params, nt_params, opt_state,
        obj_x, scene_x, scene_y_true)
    if e % 50 == 0:
        metric_loss, scene_loss, _ = calculate_individual_losses(
            params, nt_params, obj_x, scene_x, scene_y_true)
        print('metric_loss', metric_loss, 'scene_loss', scene_loss)



# (loss, e_nt_params, s_nt_params), grads = jit(calculate_gradients)(
#     params, opt_state,
#     obj_x, scene_x, scene_y_true,)

# print('from grad call...')
# print('loss', loss)
# print(shapes('e_nt_params', e_nt_params))
# print(shapes('s_nt_params', s_nt_params))

# # bit clumsy; side effect of passing all params in...

# e_grads = grads[0]
# s_grads = grads[2]
# for i in range(4):
#     print('grads', i, len(grads[i]))    


initial params
e_params (9) [(3, 3, 3, 4), (4,), (4,), (4,), (3, 3, 4, 8), (8,), (8,), (8,), (8, 64)]
e_nt_params (4) [(4,), (4,), (8,), (8,)]
s_params (51) [(3, 3, 3, 4), (4,), (4,), (4,), (3, 3, 4, 8), (8,), (8,), (8,), (3, 3, 8, 16), (16,), (16,), (16,), (3, 3, 16, 32), (32,), (32,), (32,), (3, 3, 32, 64), (64,), (64,), (64,), (3, 3, 64, 64), (64,), (64,), (64,), (64, 96), (1, 1, 160, 4), (4,), (4,), (4,), (1, 1, 4, 8), (8,), (8,), (8,), (1, 1, 8, 16), (16,), (16,), (16,), (1, 1, 16, 32), (32,), (32,), (32,), (1, 1, 32, 64), (64,), (64,), (64,), (1, 1, 64, 64), (64,), (64,), (64,), (64, 1), (1,)]
s_nt_params (24) [(4,), (4,), (8,), (8,), (16,), (16,), (32,), (32,), (64,), (64,), (64,), (64,), (4,), (4,), (8,), (8,), (16,), (16,), (32,), (32,), (64,), (64,), (64,), (64,)]
metric_loss 3.2671108 scene_loss 0.55294204
metric_loss 3.2582855 scene_loss 0.22210969
metric_loss 3.241908 scene_loss 0.14593804
metric_loss 3.214454 scene_loss 0.10628454
metric_loss 3.161618 scene_loss 0.0818252

# ( old stuff below for reference )

In [110]:
# y_true = np.array([0,0,1,1], dtype=float)
# print('y_true', y_true.shape, np.squeeze(y_true))
# y_pred_logits = np.array([-10,10,-10,10], dtype=float)
# print('y_pred_logits', y_pred_logits.shape, y_pred_logits)

# from optax.losses import sigmoid_binary_cross_entropy
# losses = sigmoid_binary_cross_entropy(logits=y_pred_logits, labels=y_true)
# print('losses', losses.shape, losses)

x = jnp.array([[0,0],[1,0]])
x.flatten()


    

Array([0, 0, 1, 0], dtype=int32)

In [88]:
y_pred = np.random.uniform(size=scene_y.shape)
y_pred.shape

(3, 10, 10)

In [10]:
# plotting debug

from PIL import Image

def collage(pil_imgs, rows, cols):
    n = len(pil_imgs) 
    if n != rows * cols:
        raise Exception()
    img_h, img_w = pil_imgs[0].size    
    collage = Image.new('RGB', (cols*img_h, rows*img_w))
    for i in range(n):
        pc, pr = i%rows, i//rows
        collage.paste(pil_imgs[i], (pr*img_h, pc*img_w))
    return collage
    
def to_pil_img(a):
    a = np.array(a*255, dtype=np.uint8)
    return Image.fromarray(a)

In [11]:
from data import ContrastiveExamples, load_fname

# start with simple case of  R, G, B examples
c_egs = ContrastiveExamples(
    root_dir='data/train/reference_patches/',
    obj_ids=["061", # "135","182",  # x3 red
             "111", # "153","198",  # x3 green
             "000", #"017","019", # x3 blue
            ]
)
ds = c_egs.dataset(num_batches=1,                                   # epoch length
                   batch_size=opts.batch_size,                      # B
                   num_obj_references=opts.num_obj_references,      # N
                   num_contrastive_objs=opts.num_contrastive_objs)  # C
for x, y in ds:
    break

x = jnp.array(x)
print(x.shape, y)

TypeError: ContrastiveExamples.__init__() got an unexpected keyword argument 'root_dir'

In [None]:
# recall each element of outer batch is xN examples of anchor/positive pairs
collage(
    [to_pil_img(x[0,0,0]), 
     to_pil_img(x[0,0,1])],
    rows=1, cols=2)

In [None]:
# recall each element of outer batch is xN examples of anchor/positive pairs
collage(
    [to_pil_img(x[0,1,0]), 
     to_pil_img(x[0,1,1])],
    rows=1, cols=2)

In [None]:
# recall X.shapee (B=2, N=4, C=6, 64, 64, 3)
# to batch from right to left though, we want (B=2, C=6, N=4, 64, 64, 3)
x = jnp.transpose(x, (0,2,1,3,4,5))
print(x.shape)


model is simple enough embedding model

output is L2 normalised embedding ( so dot products can be used for sims and xent contrastive )

In [None]:
from models.models import construct_embedding_model

embedding_model = construct_embedding_model(
    height_width=64,
    filter_sizes=[16,32,64,128],
    embedding_dim=128)

In [None]:

params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables

def mean_embeddings(params, nt_params, x):
    # x (N, H, W, 3)
    embeddings, nt_params = embedding_model.stateless_call(params, nt_params, x, training=True)  # (N, E)
    # average over N
    embeddings = jnp.mean(embeddings, axis=0)  # (E)
    # (re) L2 normalise
    embeddings /= jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings, nt_params

embeddings, nt_params_2 = mean_embeddings(params, nt_params, x[0,0])

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print(shapes('ntps', nt_params_2))

In [None]:
embeddings, nt_params_2 =  vmap(mean_embeddings, in_axes=(None, None, 0))(params, nt_params, x[0])
nt_params_2 = [jnp.mean(p, axis=0) for p in nt_params_2]

print("e shape", embeddings.shape)
print("e norms", jnp.linalg.norm(embeddings, axis=-1))
print(shapes('ntps', nt_params_2))

In [None]:
# define the constrastive loss based on the 'batch' of 2N examples ( N pairs )

def main_diagonal_softmax_cross_entropy(logits):
    # cross entropy assuming "labels" are just (0, 1, 2, ...) i.e. where
    # one_hot mask for log_softmax ends up just being the main diagonal
    return -jnp.sum(jnp.diag(jax.nn.log_softmax(logits)))

def mean_embeddings(params, nt_params, x):
    # x (N, H, W, 3)
    embeddings, nt_params = embedding_model.stateless_call(params, nt_params, x, training=True)  # (N, E)
    # average over N
    embeddings = jnp.mean(embeddings, axis=0)  # (E)
    # (re) L2 normalise
    embeddings /= jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings, nt_params
    
def constrastive_loss(params, nt_params, x):
    # x (2C, N, H, W, 3)
    embeddings, nt_params = vmap(mean_embeddings, in_axes=(None, None, 0))(params, nt_params, x)
    nt_params = [jnp.mean(p, axis=0) for p in nt_params]
    # embeddings (2C, E)    
    embeddings = embeddings.reshape((-1, 2, opts.embedding_dim))  # (C, 2, E)
    anchors = embeddings[:, 0]
    positives = embeddings[:, 1]
#    print('anchors', anchors.shape, 'positives', positives.shape)
    gram_ish_matrix = jnp.einsum('ae,be->ab', anchors, positives)
    xent = main_diagonal_softmax_cross_entropy(logits=gram_ish_matrix)
    return jnp.mean(xent), nt_params

def batch_constrastive_loss(params, nt_params, x):
    losses, nt_params = vmap(constrastive_loss, in_axes=(None, None, 0))(params, nt_params, x)
    # x (B, 2C, N, H, W, 3)
    nt_params = [jnp.mean(p, axis=0) for p in nt_params]
    return jnp.mean(losses), nt_params


loss, nt_params_2 = batch_constrastive_loss(params, nt_params, x)
print('loss', loss)
print(shapes('ntps', nt_params_2))

In [None]:
# define gradients and a simple training loop

def calculate_gradients(params, nt_params, x):
    # x (2C,H,W,3)
    grad_fn = value_and_grad(constrastive_loss, has_aux=True)    
    (loss, nt_params), grads = grad_fn(params, nt_params, x)
    return (loss, nt_params), grads

opt = optax.adam(learning_rate=opts.learning_rate)

def train_step(params, nt_params, opt_state, x):
    (loss, nt_params), grads = calculate_gradients(params, nt_params, x)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, nt_params, opt_state, loss

embedding_model = construct_embedding_model()

params = embedding_model.trainable_variables
nt_params = embedding_model.non_trainable_variables
opt_state = opt.init(params)

for e in range(1000):
    params, nt_params, opt_state, loss = jit(train_step)(params, nt_params, opt_state, x[0])
    if e % 100 == 0:
        print(loss)


In [None]:
# test against batch
embeddings, _ = embedding_model.stateless_call(params, nt_params, x[0], training=False)
embeddings.shape

# looks good (0,1) (2,3) (4,5) all pair well ( and others are -0.5 )
jnp.around(jnp.dot(embeddings, embeddings.T), 2)

next we get things working on a batch of these examples