In [1]:
# Re-import edited modules (e.g. src.tokenizer, src.architecture) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict

import numpy as np

import torch
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

In [3]:
# Local directory of the TF SavedModel being ported; also passed to get_tokenizer below.
MODEL_PATH = 'models/universal-sentence-encoder-multilingual-large-3'

In [4]:
# Load the reference TensorFlow model; its outputs are the ground truth for the torch port.
model = tf.saved_model.load(MODEL_PATH)
model

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x35f72b1f0>

In [5]:
# Reference embedding from TF for a single sentence; the displayed result has shape (1, 512).
sentence = "Hello, how are you?"
ref = model(sentence)
ref

<tf.Tensor: shape=(1, 512), dtype=float32, numpy=
array([[ 2.77998135e-03,  2.53025256e-02, -1.34592326e-02,
        -6.00954033e-02, -3.57862748e-03,  2.16413625e-02,
         1.99971721e-02, -3.70096639e-02, -2.91313529e-02,
        -1.07441902e-01,  2.73036622e-02, -8.28249753e-02,
         1.23869926e-02,  2.55314838e-02, -3.11783515e-02,
         7.67881498e-02,  5.88866174e-02,  9.00684148e-02,
        -2.02289019e-02,  5.06051891e-02,  4.59435731e-02,
        -9.70152542e-02,  1.44991055e-02,  1.12706097e-02,
        -2.17566569e-03, -3.42217949e-03,  7.06949830e-02,
        -3.22969817e-02, -2.33398993e-02, -2.30249688e-02,
         1.44220600e-02, -1.73140336e-02, -6.50571808e-02,
        -6.18350431e-02,  7.21207708e-02, -6.45894706e-02,
         2.22161412e-02,  8.38220119e-04, -1.77222937e-02,
        -4.84013595e-02,  4.19013053e-02, -1.89781878e-02,
        -1.09328344e-01,  4.43844795e-02,  1.54845491e-02,
         4.25968394e-02,  4.46535647e-02,  7.47142434e-02,
      

In [6]:
# Group the SavedModel's trainable variables by logical name. Each variable is
# named '<scope>/<shard>:<output>' where <shard> is 'part_<k>' or 'sharded_<k>'.
# Shards are ordered by <k> explicitly instead of trusting the iteration order
# of `model.trainable_variables`, because the merge step depends on that order.
shards_by_name = defaultdict(list)
appendix_set = set()
parts_set = set()

for variable in model.trainable_variables:
    name, appendix = variable.name.split(':')
    appendix_set.add(appendix)
    name, part = name.rsplit('/', maxsplit=1)
    parts_set.add(part)
    kind, _, shard_idx = part.rpartition('_')
    assert kind in ('part', 'sharded') and shard_idx.isdigit(), (
        f'unexpected shard suffix {part!r} in {variable.name}'
    )
    shards_by_name[name].append((int(shard_idx), variable.numpy()))

assert appendix_set == {'0'}

# name -> list of shard arrays in shard-index order (keys keep first-seen order).
mapper = defaultdict(list)
for name, shards in shards_by_name.items():
    shards.sort(key=lambda item: item[0])  # sort on the index only; never compare arrays
    shard_ids = [idx for idx, _ in shards]
    assert shard_ids == list(range(len(shards))), f'{name}: missing/duplicate shards {shard_ids}'
    mapper[name] = [array for _, array in shards]

In [7]:
# Every variable's last name component is a shard suffix ('part_<k>' or 'sharded_<k>'), so shards must be merged below.
parts_set

{'part_0',
 'part_1',
 'part_10',
 'part_11',
 'part_12',
 'part_13',
 'part_14',
 'part_15',
 'part_16',
 'part_2',
 'part_3',
 'part_4',
 'part_5',
 'part_6',
 'part_7',
 'part_8',
 'part_9',
 'sharded_0',
 'sharded_1',
 'sharded_10',
 'sharded_11',
 'sharded_12',
 'sharded_13',
 'sharded_14',
 'sharded_15',
 'sharded_16',
 'sharded_2',
 'sharded_3',
 'sharded_4',
 'sharded_5',
 'sharded_6',
 'sharded_7',
 'sharded_8',
 'sharded_9'}

In [8]:
def merge_mod_sharded(shards):
    """Reassemble a matrix whose rows were distributed round-robin over shards.

    Row ``idx`` of the full matrix is row ``idx // len(shards)`` of shard
    ``idx % len(shards)``. All shards must share one shape (``np.stack`` raises
    otherwise), and the result keeps the shards' dtype instead of upcasting.
    """
    # (rows_per_shard, n_shards, *row_shape); flattening the first two axes
    # puts row j of shard i at position j * n_shards + i.
    stacked = np.stack(shards, axis=1)
    return stacked.reshape(-1, *stacked.shape[2:])


# Merge the shards of every variable into a single array. 'Embeddings' is
# interleaved row-wise across shards; other variables are joined by
# concatenating their parts along axis 0.
weights = {}
for name, tensors_list in mapper.items():
    if name == 'Embeddings':
        weights[name] = merge_mod_sharded(tensors_list)
    else:
        weights[name] = np.concatenate(tensors_list, axis=0)

In [9]:
# Show every merged weight with its shape to sanity-check the reassembly.
for weight_name in weights:
    print(f'{weight_name}: {weights[weight_name].shape}')

Embeddings: (128010, 512)
EncoderDNN/DNN/ResidualHidden_0/dense/kernel: (512, 320)
EncoderDNN/DNN/ResidualHidden_1/dense/kernel: (320, 320)
EncoderDNN/DNN/ResidualHidden_1/AdjustDepth/projection/kernel: (512, 320)
EncoderDNN/DNN/ResidualHidden_2/dense/kernel: (320, 512)
EncoderDNN/DNN/ResidualHidden_3/dense/kernel: (512, 512)
EncoderDNN/DNN/ResidualHidden_3/AdjustDepth/projection/kernel: (320, 512)
EncoderTransformer/Transformer/dense/kernel: (512, 512)
EncoderTransformer/Transformer/dense/bias: (512,)
EncoderTransformer/Transformer/SparseTransformerEncode/Layer_0/SelfAttention/layer_prepostprocess/layer_norm/layer_norm_scale: (512,)
EncoderTransformer/Transformer/SparseTransformerEncode/Layer_0/SelfAttention/layer_prepostprocess/layer_norm/layer_norm_bias: (512,)
EncoderTransformer/Transformer/SparseTransformerEncode/Layer_0/SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_q/kernel: (512, 512)
EncoderTransformer/Transformer/SparseTransformerEncode/Layer_0/SelfAttention/Sparse

### Tokenizer

In [10]:
from functools import partial
from src.tokenizer import get_tokenizer, tokenize as tokenize_with

# Build the tokenizer via get_tokenizer(MODEL_PATH) and bind it once, so later
# cells can call `tokenize(text)` directly.
tokenizer = get_tokenizer(MODEL_PATH)
tokenize = partial(tokenize_with, tokenizer=tokenizer)

### Architecture

In [11]:
from src.architecture import MUSE

# Size the embedding table from the merged TF weights. d_model=512 matches the
# 512-wide TF layers listed above; num_heads=8 is assumed for this checkpoint
# (TODO confirm against the TF graph if outputs ever diverge).
embedding_shape = weights['Embeddings'].shape
model_torch = MUSE(
    num_embeddings=embedding_shape[0],
    embedding_dim=embedding_shape[1],
    d_model=512,
    num_heads=8,
)

In [12]:
# Init: embedding table and the top-level `linear` layer.
# Shapes are checked before copying because `Tensor.copy_` silently broadcasts a
# broadcastable source; TF dense kernels are transposed to torch's
# (out_features, in_features) layout.
with torch.no_grad():
    for label, param, array in (
        ('embedding.weight', model_torch.embedding.weight,
         weights['Embeddings']),
        ('linear.weight', model_torch.linear.weight,
         weights['EncoderTransformer/Transformer/dense/kernel'].T),
        ('linear.bias', model_torch.linear.bias,
         weights['EncoderTransformer/Transformer/dense/bias']),
    ):
        assert tuple(param.shape) == array.shape, f'{label}: {tuple(param.shape)} != {array.shape}'
        param.copy_(torch.from_numpy(array))

In [13]:
# Init
# Parameters of each transformer `Block` and the TF variables they are loaded
# from, as (attribute path on the block, TF name relative to
# `.../SparseTransformerEncode/Layer_<n>/`, transpose?). TF dense kernels are
# transposed because torch `nn.Linear.weight` is laid out (out_features, in_features).
BLOCK_PARAM_MAP = (
    ('ln1.weight', 'SelfAttention/layer_prepostprocess/layer_norm/layer_norm_scale', False),
    ('ln1.bias', 'SelfAttention/layer_prepostprocess/layer_norm/layer_norm_bias', False),
    ('attn.query.weight', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_q/kernel', True),
    ('attn.query.bias', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_q/bias', False),
    ('attn.key.weight', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_k/kernel', True),
    ('attn.key.bias', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_k/bias', False),
    ('attn.value.weight', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_v/kernel', True),
    ('attn.value.bias', 'SelfAttention/SparseMultiheadAttention/ComputeQKV/compute_v/bias', False),
    ('ln2.weight', 'FFN/Layer1/layer_prepostprocess/layer_norm/layer_norm_scale', False),
    ('ln2.bias', 'FFN/Layer1/layer_prepostprocess/layer_norm/layer_norm_bias', False),
    ('linear1.weight', 'FFN/Layer1/dense/kernel', True),
    ('linear1.bias', 'FFN/Layer1/dense/bias', False),
    ('linear2.weight', 'FFN/Layer2/dense/kernel', True),
    ('linear2.bias', 'FFN/Layer2/dense/bias', False),
)


def init_block(block_num, torch_model=None, weight_map=None):
    """Copy the TF weights of transformer layer `block_num` into `block{block_num}`.

    Args:
        block_num: Layer index; selects both the torch submodule
            `block{block_num}` and the TF scope `Layer_{block_num}`.
        torch_model: Model owning the blocks; defaults to the notebook's `model_torch`.
        weight_map: Mapping from TF variable name to numpy array; defaults to
            the notebook's `weights`.

    Raises:
        AttributeError: if the model has no `block{block_num}`.
        KeyError: if a required TF variable is missing.
        AssertionError: if a TF array's shape does not match its torch parameter.
    """
    if torch_model is None:
        torch_model = model_torch
    if weight_map is None:
        weight_map = weights

    block = getattr(torch_model, f'block{block_num}')
    prefix = f'EncoderTransformer/Transformer/SparseTransformerEncode/Layer_{block_num}/'

    with torch.no_grad():
        for attr_path, tf_suffix, transpose in BLOCK_PARAM_MAP:
            # Resolve dotted paths such as 'attn.query.weight'.
            param = block
            for attr in attr_path.split('.'):
                param = getattr(param, attr)

            array = weight_map[prefix + tf_suffix]
            if transpose:
                array = array.T

            # copy_ would silently broadcast a broadcastable but wrong source,
            # so the shapes must match exactly.
            assert tuple(param.shape) == array.shape, (
                f'block{block_num}.{attr_path}: torch shape {tuple(param.shape)} '
                f'!= TF shape {array.shape} ({prefix + tf_suffix})'
            )
            param.copy_(torch.from_numpy(array))

# Initialise every transformer block. The layer count is read from the TF
# variable names instead of being hard-coded, and the torch model is checked
# for extra blocks that would otherwise keep their random initialisation.
LAYER_SCOPE = 'EncoderTransformer/Transformer/SparseTransformerEncode/Layer_'
num_layers = len({
    name[len(LAYER_SCOPE):].split('/', 1)[0]
    for name in weights
    if name.startswith(LAYER_SCOPE)
})
assert num_layers > 0, 'no transformer layers found in the TF weights'
assert not hasattr(model_torch, f'block{num_layers}'), (
    f'model_torch has more blocks than the {num_layers} TF layers'
)
# init_block raises if a layer index is missing on either side.
for block_num in range(num_layers):
    init_block(block_num)

In [14]:
# Init: transformer-level layer norm, AttentionPooling layers and tanh_layer of the head.
# Shapes are checked before copying because `Tensor.copy_` silently broadcasts a
# broadcastable source; TF dense kernels are transposed to torch's
# (out_features, in_features) layout.
with torch.no_grad():
    for label, param, array in (
        ('head.ln.weight', model_torch.head.ln.weight,
         weights['EncoderTransformer/Transformer/layer_prepostprocess/layer_norm/layer_norm_scale']),
        ('head.ln.bias', model_torch.head.ln.bias,
         weights['EncoderTransformer/Transformer/layer_prepostprocess/layer_norm/layer_norm_bias']),
        ('head.linear1.weight', model_torch.head.linear1.weight,
         weights['EncoderTransformer/Transformer/AttentionPooling/AttentionHidden/kernel'].T),
        ('head.linear1.bias', model_torch.head.linear1.bias,
         weights['EncoderTransformer/Transformer/AttentionPooling/AttentionHidden/bias']),
        ('head.linear2.weight', model_torch.head.linear2.weight,
         weights['EncoderTransformer/Transformer/AttentionPooling/AttentionLogits/kernel'].T),
        ('head.linear2.bias', model_torch.head.linear2.bias,
         weights['EncoderTransformer/Transformer/AttentionPooling/AttentionLogits/bias']),
        ('head.tanh_layer.weight', model_torch.head.tanh_layer.weight,
         weights['EncoderTransformer/hidden_layers/tanh_layer_0/dense/kernel'].T),
        ('head.tanh_layer.bias', model_torch.head.tanh_layer.bias,
         weights['EncoderTransformer/hidden_layers/tanh_layer_0/dense/bias']),
    ):
        assert tuple(param.shape) == array.shape, f'{label}: {tuple(param.shape)} != {array.shape}'
        param.copy_(torch.from_numpy(array))

In [15]:
# Display the module tree to check the layer structure against the TF weight shapes above.
model_torch

MUSE(
  (embedding): Embedding(128010, 512)
  (linear): Linear(in_features=512, out_features=512, bias=True)
  (pe): PositionalEncoding()
  (block0): Block(
    (ln1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
    (attn): MultiheadSelfAttention(
      (query): Linear(in_features=512, out_features=512, bias=True)
      (key): Linear(in_features=512, out_features=512, bias=True)
      (value): Linear(in_features=512, out_features=512, bias=True)
    )
    (ln2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (block1): Block(
    (ln1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
    (attn): MultiheadSelfAttention(
      (query): Linear(in_features=512, out_features=512, bias=True)
      (key): Linear(in_features=512, out_features=512, bias=True)
      (value): Linear(in_features=512, out_features=512, bias=True)
    )

### Compare

In [16]:
# Token ids of the reference sentence, produced by the tokenizer built above.
input_ids = tokenize(sentence)
input_ids

tensor([    1, 17486,     7,   430,    69,    37,    21,     2])

In [17]:
# Forward pass of the ported model on the reference sentence. eval() disables any
# train-only behaviour MUSE may have, and no_grad() avoids building an autograd
# graph that inference never uses.
model_torch.eval()
with torch.no_grad():
    res = model_torch(input_ids)
res

tensor([ 2.7751e-03,  2.5281e-02, -1.3426e-02, -6.0007e-02, -3.5832e-03,
         2.1640e-02,  2.0059e-02, -3.6970e-02, -2.9052e-02, -1.0739e-01,
         2.7322e-02, -8.2878e-02,  1.2412e-02,  2.5597e-02, -3.1173e-02,
         7.6832e-02,  5.8875e-02,  9.0064e-02, -2.0296e-02,  5.0571e-02,
         4.5956e-02, -9.7047e-02,  1.4511e-02,  1.1183e-02, -2.2249e-03,
        -3.4421e-03,  7.0693e-02, -3.2278e-02, -2.3314e-02, -2.2967e-02,
         1.4435e-02, -1.7352e-02, -6.5063e-02, -6.1835e-02,  7.2137e-02,
        -6.4542e-02,  2.2277e-02,  8.1046e-04, -1.7738e-02, -4.8312e-02,
         4.1955e-02, -1.8915e-02, -1.0933e-01,  4.4421e-02,  1.5477e-02,
         4.2578e-02,  4.4713e-02,  7.4773e-02,  2.8420e-02,  3.0054e-02,
        -1.8123e-02,  1.0940e-02, -2.1887e-02,  4.0501e-02, -1.6136e-02,
        -6.7140e-02, -4.6971e-02, -2.9988e-02,  2.4044e-02, -5.0831e-02,
         2.1481e-02, -7.5423e-02, -6.1983e-02,  8.0178e-03,  6.1077e-02,
        -1.4429e-02,  8.0560e-03,  1.6757e-02, -4.1

In [18]:
# TF returns a batch of one, shape (1, 512); the torch port returns a single
# 1-D vector. Drop the batch axis and check shapes explicitly so np.allclose
# cannot silently broadcast mismatched arrays.
res_np = res.detach().numpy()
ref_np = ref.numpy()[0]
assert res_np.shape == ref_np.shape, f'{res_np.shape} != {ref_np.shape}'
np.allclose(res_np, ref_np, atol=1e-3)

True

### Multilingual Compare

In [19]:
# Some texts of different lengths, in three languages; entry i of each list is
# the same sentence translated. (The Japanese strings had been mangled to '?'
# by an encoding round-trip, which made the Japanese comparison meaningless.)
english_sentences = ["dog", "Puppies are nice.", "I enjoy taking long walks along the beach with my dog."]
italian_sentences = ["cane", "I cuccioli sono carini.", "Mi piace fare lunghe passeggiate lungo la spiaggia con il mio cane."]
japanese_sentences = ["犬", "子犬はいいです", "私は犬と一緒にビーチを散歩するのが好きです"]

# Compute embeddings.
en_result = model(english_sentences).numpy()
it_result = model(italian_sentences).numpy()
ja_result = model(japanese_sentences).numpy()

# Compute similarity matrix. Higher score indicates greater similarity.
similarity_matrix_it = np.inner(en_result, it_result)
similarity_matrix_ja = np.inner(en_result, ja_result)

In [20]:
def embed_torch(sentences):
    """Embed `sentences` with the torch port, one token sequence per call.

    Returns a numpy array with one row per sentence, shaped like the TF model's
    batched output so the two can be compared directly.
    """
    with torch.no_grad():  # inference only; no autograd graph needed
        rows = [model_torch(tokenize(text)) for text in sentences]
    return torch.stack(rows).numpy()


en_result_torch = embed_torch(english_sentences)
it_result_torch = embed_torch(italian_sentences)
ja_result_torch = embed_torch(japanese_sentences)

similarity_matrix_it_torch = np.inner(en_result_torch, it_result_torch)
similarity_matrix_ja_torch = np.inner(en_result_torch, ja_result_torch)

In [21]:
# Compare each TF output with its torch counterpart (atol=1e-3): the three
# embedding batches first, then the two similarity matrices.
comparisons = [
    (en_result, en_result_torch),
    (it_result, it_result_torch),
    (ja_result, ja_result_torch),
    (similarity_matrix_it, similarity_matrix_it_torch),
    (similarity_matrix_ja, similarity_matrix_ja_torch),
]
for tf_out, torch_out in comparisons:
    print(np.allclose(tf_out, torch_out, atol=1e-3))

True
True
True
True
True


### Save Model

In [22]:
# Persist only the parameters (state_dict); loading them requires first building
# MUSE with the same constructor arguments as above.
torch.save(model_torch.state_dict(), 'models/model.pt')