In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from collections import deque

import numpy as np

import tensorflow as tf

from tensorflow.keras.layers import Dense

from micron2.clustering.embedding_moco import update_key_model, train_moco

# Hide every GPU from TensorFlow so this notebook runs entirely on CPU.
# Must execute before TensorFlow initializes its devices (i.e. before any op runs),
# otherwise set_visible_devices raises a RuntimeError.
tf.config.set_visible_devices([], 'GPU')
# Debug toggle: uncomment to log the device each op is placed on.
# tf.debugging.set_log_device_placement(True)

In [None]:
# Exploratory: print the Dense layer's documentation (e.g. kernel_initializer / bias_initializer).
help(Dense)

In [None]:
from tensorflow.keras.initializers import (HeNormal, Zeros)
class Net(tf.keras.Model):
    """Tiny two-layer dense network (Dense(4) -> Dense(2)) for exercising the MoCo update.

    The input width is inferred by Keras on the first call. No activation argument is
    passed to either Dense layer.

    Args:
        initializer: Keras initializer (class, instance or string identifier) applied
            to BOTH the kernels and the biases of every layer, so every trainable
            variable is determined by this single choice.
    """

    def __init__(self, initializer=HeNormal):
        super().__init__()
        self.d1 = Dense(4, kernel_initializer=initializer, bias_initializer=initializer)
        self.d2 = Dense(2, kernel_initializer=initializer, bias_initializer=initializer)

    def call(self, x):
        """Forward pass: apply d1, then d2, and return the (batch, 2) output."""
        # Implement `call`, not `__call__`: keras.Model.__call__ wraps `call` with input
        # conversion, weight building and `built` bookkeeping. Overriding `__call__`
        # silently skips that (e.g. `summary()` then refuses to run on an "unbuilt" model).
        x = self.d1(x)
        return self.d2(x)
    
    
# Query network: random He-normal kernels AND biases.
# Key network: same architecture, but every kernel and bias starts at exactly zero,
# which makes the effect of a single momentum update easy to predict.
model, kmodel = Net(initializer=HeNormal), Net(initializer=Zeros)

In [None]:
# Keras creates layer weights lazily on the first call, so push one random
# (5, 8) float32 batch through each network; the outputs themselves are discarded.
x = np.random.standard_normal(size=(5, 8)).astype(np.float32)
_ = model(x)
_ = kmodel(x)

In [None]:
# Query-model variables: kernels and biases were drawn from HeNormal, so they hold random values.
model.trainable_variables

In [None]:
# The key model was built with Zeros for kernels and biases, so every variable must be
# exactly zero before any momentum update. (Fails if the update cell has already run.)
key_vars = kmodel.trainable_variables
assert key_vars, "key model has no variables yet - call it once to build it"
assert all(not np.any(var.numpy()) for var in key_vars), "key model should start at zero"
kmodel.trainable_variables

In [None]:
# Apply one momentum update from the query model (`model`) into the key model (`kmodel`).
# The return value is unused and the next cell inspects kmodel's variables, so this
# presumably updates kmodel in place - re-running this cell would apply a second update.
update_key_model(model, kmodel)

In [None]:
# After one momentum update from an all-zero key model, each key variable should equal
# (1 - m) * the matching query variable (model.trainable_variables).
# m = 0.999 is the momentum this check assumes update_key_model uses by default - TODO confirm.
MOMENTUM = 0.999
assert len(kmodel.trainable_variables) == len(model.trainable_variables)
for q_var, k_var in zip(model.trainable_variables, kmodel.trainable_variables):
    # float32 arithmetic: compare with a relative tolerance rather than exact equality.
    np.testing.assert_allclose(k_var.numpy(), (1 - MOMENTUM) * q_var.numpy(), rtol=1e-5)
kmodel.trainable_variables

In [None]:
class MoCoQueue:
    """FIFO buffer of key-feature batches used as MoCo negatives.

    Items are stored whole, so ``max_len`` bounds the number of enqueued *batches*,
    not the number of feature rows. ``getqueue`` concatenates them along axis 0.

    Args:
        max_len: Maximum number of items kept. Enqueuing beyond it discards the
            oldest item. ``None`` means unbounded.
    """

    def __init__(self, max_len=8):
        # deque(maxlen=...) drops the oldest entry on append in O(1), instead of
        # list.pop(0), which shifts every remaining element.
        self.Q = deque(maxlen=max_len)
        self.max_len = max_len

    def __len__(self):
        return len(self.Q)

    def enqueue(self, z):
        """Append ``z``. If the queue is already full, the oldest item is discarded."""
        self.Q.append(z)

    def dequeue(self):
        """Remove the oldest item in queue.

        Raises:
            IndexError: if the queue is empty.
        """
        self.Q.popleft()

    def getqueue(self):
        """Return all queued items concatenated along axis 0, oldest first.

        Raises:
            ValueError: if the queue is empty (there is nothing to concatenate).
        """
        if not self.Q:
            raise ValueError("MoCoQueue is empty; enqueue at least one batch before getqueue()")
        return tf.concat(list(self.Q), axis=0)
    
    
    
def fake_encoder(batch_size=16, dim=5):
    """Stand-in for an encoder: return a (batch_size, dim) tensor of random features.

    Values are standard-normal draws from NumPy's global RNG (float64), so the
    returned tensor is float64.

    Args:
        batch_size: Number of rows (samples).
        dim: Feature width; defaults to 5, the previously hard-coded value.
    """
    return tf.constant(np.random.randn(batch_size, dim))



def moco_loss(q_feat, key_feat, queue, batch_size=None, temp=1.):
    """MoCo contrastive (InfoNCE) loss: each query's positive is its own key,
    and its negatives are every row of ``queue``.

    Adapted from https://github.com/ppwwyyxx/moco.tensorflow/blob/master/main_moco.py

    Args:
        q_feat: (n, c) query features.
        key_feat: (n, c) key features, row-aligned with ``q_feat``.
        queue: (K, c) negative key features.
        batch_size: Ignored; kept for backward compatibility. The number of labels
            is now taken from the logits, so it can never disagree with the batch
            (previously the default of 1 broke any batch larger than one).
        temp: Softmax temperature; logits are scaled by 1 / temp.

    Returns:
        Scalar tensor: mean cross-entropy over the n queries, where the correct
        class is always logit column 0 (the positive).
    """
    # Positive logits: row-wise dot product of each query with its own key -> (n, 1).
    l_pos = tf.reshape(tf.einsum('nc,nc->n', q_feat, key_feat), (-1, 1))
    # Negative logits: every query against every queued key -> (n, K).
    l_neg = tf.einsum('nc,kc->nk', q_feat, queue)
    logits = tf.concat([l_pos, l_neg], axis=1)  # (n, 1 + K)
    logits = logits * (1 / temp)
    # The positive sits in column 0 for every row; size labels from the logits
    # themselves so they always match the batch dimension.
    labels = tf.zeros(tf.shape(logits)[:1], dtype=tf.int64)  # (n,)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    loss = tf.reduce_mean(loss, name='xentropy-loss')
    return loss

In [None]:
# Pre-fill the negatives queue with a few batches of random key features so the
# loss cell below has something to contrast each query against.
k_history = MoCoQueue(max_len=8)

n_queue_batches = 4
for _ in range(n_queue_batches):
    k_history.enqueue(fake_encoder())

# The queue keeps at most max_len batches, evicting the oldest beyond that.
assert len(k_history) == min(n_queue_batches, k_history.max_len)
len(k_history)

In [None]:
# Smoke-test the loss: each step scores a fresh random query/key batch against the
# current queue, then pushes the keys onto it. MoCoQueue.enqueue evicts the oldest
# batch itself once full, so no explicit dequeue is needed here.
batch_size = 16
n_steps = 20

for step in range(n_steps):
    q_feat = fake_encoder(batch_size)
    k_feat = fake_encoder(batch_size)

    # Compute the loss against the queue *before* enqueuing this batch's keys,
    # so a batch's own keys never appear among its negatives.
    loss = moco_loss(q_feat, k_feat, k_history.getqueue(), batch_size=batch_size)

    k_history.enqueue(k_feat)
    print(f"step {step:2d}  loss {float(loss):.4f}  queued batches {len(k_history)}")