# Siamese Network

In [2]:
import trax
from trax import layers as tl
import trax.fastmath.numpy as np
import numpy

# Setting random seeds.
# One shared seed value for every generator this notebook draws from, so the
# numpy example tensors below (and whatever trax seeds) are reproducible
# across Restart & Run All.
SEED = 10
trax.supervised.trainer_lib.init_random_number_generators(SEED)
numpy.random.seed(SEED)



INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 




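## L2 Normalization
- Define a function that L2-normalizes a tensor along its last axis: $\hat{x} = \dfrac{x}{\sqrt{\sum_i x_i^2}}$
- The custom loss function expects the tensors it receives to be normalized, so the model's outputs are passed through this function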

In [3]:
def normalize(x):
    """L2-normalize ``x`` along its last axis.

    Each vector ``x[..., :]`` is divided by its Euclidean norm, so the result
    has unit length along axis -1. The custom loss function expects its input
    tensors to be normalized this way.

    Args:
        x: Array of floats; the last axis holds the vector components.

    Returns:
        Array with the same shape as ``x`` and unit L2 norm along the last
        axis. All-zero vectors come back as zeros instead of NaN.
    """
    # Floor on the *squared* norm. It stops an all-zero vector from dividing
    # by zero (NaN). Clamping before the sqrt also keeps the sqrt's argument
    # strictly positive. Vectors whose squared norm is >= this floor are
    # scaled exactly as by a plain x / ||x||.
    # Kept as a local (not a default argument) so `normalize` remains usable
    # with layer wrappers that reject functions having defaults.
    min_squared_norm = 1e-12
    squared_norm = np.sum(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(np.maximum(squared_norm, min_squared_norm))

In [7]:
# Imported for the `%%render` cell magic used in the next cell, which
# displays that cell's calculation as a LaTeX object.
import handcalcs.render
# Scalar `math.sqrt`, used by the formula in the `%%render` cell.
from math import sqrt
# Scalar example value substituted into the rendered L2 formula.
# NOTE(review): this rebinds the module-level name `x`; later cells that read
# a global `x` will see 10.
x =10

In [17]:
%%render
L_2 = x / (sqrt(sum(x ^ 2)))

TypeError: 'int' object is not iterable

<IPython.core.display.Latex object>

In [18]:
# Small random 2x5 example input for the normalization demo.
tensor = numpy.random.random(size=(2, 5))
print(
    f'The tensor is of type: {type(tensor)}\n\n'
    'And looks like this:\n\n'
    f' {tensor}'
)

The tensor is of type: <class 'numpy.ndarray'>

And looks like this:

 [[0.77132064 0.02075195 0.63364823 0.74880388 0.49850701]
 [0.22479665 0.19806286 0.76053071 0.16911084 0.08833981]]


In [19]:
# Apply the L2 normalization defined above to the random example tensor.
norm_tensor = normalize(tensor)
print(
    f'The normalized tensor is of type: {type(norm_tensor)}\n\n'
    'And looks like this:\n\n'
    f' {norm_tensor}'
)

The normalized tensor is of type: <class 'jax.interpreters.xla.DeviceArray'>

And looks like this:

 [[0.57393795 0.01544148 0.4714962  0.55718327 0.37093794]
 [0.26781026 0.23596111 0.9060541  0.20146926 0.10524315]]


## Siamese Model
- An LSTM encoder built with the `Serial` combinator layer
- A `Parallel` combinator that runs the same encoder on both inputs, creating the Siamese model

In [20]:
model_dimension = 128
vocab_size = 500

# Encoder branch: Embedding -> LSTM -> Mean -> Normalize.
LSTM = tl.Serial(
    tl.Embedding(d_feature=model_dimension, vocab_size=vocab_size),
    tl.LSTM(model_dimension),
    # Averages over axis 1. This is presumably the sequence axis of a
    # (batch, seq_len, model_dimension) activation; TODO confirm.
    tl.Mean(axis=1),
    # Wrap the plain `normalize` function as a one-input layer named 'Normalize'.
    tl.Fn('Normalize', lambda activations: normalize(activations)),
)

# Siamese model: the *same* `LSTM` object is handed to both branches of
# `Parallel`, so the two inputs are encoded by one layer (and its weights).
Siamese = tl.Parallel(LSTM, LSTM)

In [22]:
def show_layers(model, layer_prefix):
    """Print a numbered listing of a combinator's direct sublayers.

    Args:
        model: A layer exposing a ``sublayers`` sequence (here, a trax
            ``Serial`` or ``Parallel`` combinator).
        layer_prefix: Label printed before each sublayer's index,
            e.g. ``'Serial.sublayers'`` gives ``Serial.sublayers_0: ...``.
    """
    # Read the property once instead of on every iteration.
    sublayers = model.sublayers
    print(f'Total Layers: {len(sublayers)}\n')
    for i, layer in enumerate(sublayers):
        print("============")
        print(f'{layer_prefix}_{i}: {layer}\n')
        
# Print the structure of the whole Siamese model, then of one encoder branch.
for section_title, section_model, section_prefix in (
    ('Siamese Model: \n', Siamese, 'Parallel.sublayers'),
    ('Detail of LSTM models: \n', LSTM, 'Serial.sublayers'),
):
    print(section_title)
    show_layers(section_model, section_prefix)

Siamese Model: 

Total Layers: 2

Parallel.sublayers_0: Serial[
  Embedding_500_128
  LSTM_128
  Mean
  Normalize
]

Parallel.sublayers_1: Serial[
  Embedding_500_128
  LSTM_128
  Mean
  Normalize
]

Detail of LSTM models: 

Total Layers: 4

Serial.sublayers_0: Embedding_500_128

Serial.sublayers_1: LSTM_128

Serial.sublayers_2: Mean

Serial.sublayers_3: Normalize

