# Explore Shared Embedding

## Imports

In [8]:
import sys
import os

import numpy as np

# Resolve the parent of the current working directory and make it importable,
# so the `src` package used by the import below can be found.
root_path = os.path.abspath(os.path.join('..'))
already_importable = root_path in sys.path
if not already_importable:
    # Append only once so re-running this cell does not grow sys.path.
    sys.path.append(root_path)

from src.models.shared_embedding import SharedEmbedding

## Overview

To get a better feel for the operations going on inside `SharedEmbedding`, let's first print the dimensions of the shared embedding's input, weights, and outputs using `shared_embedding_dimensions` for the following parameter values:

* `vocab_size=10`
* `d_model=4`

> Note: The input to the embedding layer will already have gone through the data pipeline process above, so to replicate such a dataset we can simply create an array of integer vectors of a set length.

> Note: The embedding weights are subject to learning during the model training process. The weights illustrated below serve as initial values, which, after just a single forward pass through the SharedEmbedding layer, do not hold meaningful representations. Throughout training, these weights will iteratively adjust to encapsulate more useful and semantically relevant information, driven by the minimization of the model's loss function.

## Initialize SharedEmbedding

In [6]:
# Deliberately tiny sizes: the full (vocab_size, d_model) weight matrix is
# printed later in this notebook, so it has to stay readable.
shared_embedding = SharedEmbedding(
    vocab_size=10,
    d_model=4,
)

## Print input and output for a batch containing 1 input sequence

### Input to Shared Embedding

In [28]:
# A batch holding one sequence of five integer indices, all below vocab_size=10.
dummy_sentence = np.array([3, 1, 9, 8, 7]).reshape(1, -1)

# Emit the shape line and the raw array with a single print call.
print(
    f"Dimensions (batch_size, seq_length): {dummy_sentence.shape}",
    f"{dummy_sentence}",
    sep="\n",
)

Dimensions (batch_size, seq_length): (1, 5)
[[3 1 9 8 7]]


### Output of Shared Embedding

In [31]:
# Run the single-sequence batch through the layer. The result carries one
# extra trailing axis compared with the input (recorded shape: (1, 5, 4)),
# i.e. one d_model-sized vector per input index.
embedded_output = shared_embedding(dummy_sentence)

# The label names all three axes; the previous two-axis label did not match
# the 3-D shape that is printed next to it.
print(f"Dimensions (batch_size, seq_length, d_model): {embedded_output.shape}")
print(f"{embedded_output}")

Dimensions (batch_size, seq_length): (1, 5, 4)
[[[-0.09908666  0.06959189 -0.05134678  0.06398902]
  [ 0.07289138  0.05813029 -0.00244436 -0.00594692]
  [-0.05870557  0.08982063  0.0197406   0.04603086]
  [-0.09961352  0.01636253  0.01097564  0.04978543]
  [-0.02301206 -0.01816873  0.03561945 -0.0379561 ]]]


## Print input and output for a batch containing 3 input sequences

### Input to Shared Embedding

In [34]:
# A batch of three sequences, five integer indices each; the first row is the
# same sequence used in the single-sequence example.
dummy_sentences = np.array(
    [
        [3, 1, 9, 8, 7],
        [0, 2, 3, 4, 5],
        [6, 7, 1, 8, 9],
    ]
)

# Emit the shape line and the raw array with a single print call.
print(
    f"Dimensions (batch_size, seq_length): {dummy_sentences.shape}",
    f"{dummy_sentences}",
    sep="\n",
)

Dimensions (batch_size, seq_length): (3, 5)
[[3 1 9 8 7]
 [0 2 3 4 5]
 [6 7 1 8 9]]


### Output of Shared Embedding

In [35]:
# Run the three-sequence batch through the same layer instance. Note that this
# rebinds `embedded_output`, replacing the single-sequence result.
embedded_output = shared_embedding(dummy_sentences)

# The label names all three axes (recorded shape: (3, 5, 4)); the previous
# two-axis label did not match the 3-D shape that is printed next to it.
print(f"Dimensions (batch_size, seq_length, d_model): {embedded_output.shape}")
print(f"{embedded_output}")

Dimensions (batch_size, seq_length): (3, 5, 4)
[[[-0.09908666  0.06959189 -0.05134678  0.06398902]
  [ 0.07289138  0.05813029 -0.00244436 -0.00594692]
  [-0.05870557  0.08982063  0.0197406   0.04603086]
  [-0.09961352  0.01636253  0.01097564  0.04978543]
  [-0.02301206 -0.01816873  0.03561945 -0.0379561 ]]

 [[ 0.04770365 -0.00831404  0.09543822  0.00717072]
  [-0.03777781  0.03604748 -0.03230224  0.09779359]
  [-0.09908666  0.06959189 -0.05134678  0.06398902]
  [-0.04405236 -0.08592837  0.02742467  0.09911964]
  [-0.06826501  0.09861846  0.02513096 -0.02877231]]

 [[ 0.02805883 -0.08670046  0.00149784  0.05752458]
  [-0.02301206 -0.01816873  0.03561945 -0.0379561 ]
  [ 0.07289138  0.05813029 -0.00244436 -0.00594692]
  [-0.09961352  0.01636253  0.01097564  0.04978543]
  [-0.05870557  0.08982063  0.0197406   0.04603086]]]


## Print the embedding matrix weights for the above

In [36]:
# Fetch the layer's weight matrix. In the recorded output, row i equals the
# vector the layer produced for input index i (e.g. row 3 is the first output
# vector of the sequence starting with 3).
embedding_weights = shared_embedding.get_embedding()

shape_template = "Dimensions (vocab_size, d_model): {}"
print(shape_template.format(embedding_weights.shape))
print(f"{embedding_weights}")

Dimensions (vocab_size, d_model): (10, 4)
[[ 0.04770365 -0.00831404  0.09543822  0.00717072]
 [ 0.07289138  0.05813029 -0.00244436 -0.00594692]
 [-0.03777781  0.03604748 -0.03230224  0.09779359]
 [-0.09908666  0.06959189 -0.05134678  0.06398902]
 [-0.04405236 -0.08592837  0.02742467  0.09911964]
 [-0.06826501  0.09861846  0.02513096 -0.02877231]
 [ 0.02805883 -0.08670046  0.00149784  0.05752458]
 [-0.02301206 -0.01816873  0.03561945 -0.0379561 ]
 [-0.09961352  0.01636253  0.01097564  0.04978543]
 [-0.05870557  0.08982063  0.0197406   0.04603086]]
