In [1]:
import json

# Load the token mapping. The encoding is given explicitly so the file is
# decoded the same way on every platform instead of using the locale default
# (JSON text is UTF-8 by specification).
with open('tr_token_mapping.json', encoding='utf-8') as f:
    token_list = json.load(f)

In [2]:
# Number of entries in the loaded token mapping.
len(token_list)

30158

In [5]:
import torch

from transformers import Gemma2Model

# Checkpoint whose input-embedding table is reused further down.
GEMMA_CHECKPOINT = "google/gemma-2-2b-it"

gemma_model = Gemma2Model.from_pretrained(GEMMA_CHECKPOINT)
gemma_model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,),

In [8]:
# Source embedding matrix taken from the Gemma-2 model, plus a zero-filled
# destination with one row per token_list entry and the same embedding width.
gemma_embeddings = gemma_model.embed_tokens.weight
num_rows, embed_dim = len(token_list), gemma_embeddings.size(1)
embeddings = torch.zeros((num_rows, embed_dim))
embeddings.shape

torch.Size([30158, 2304])

In [21]:
# Show the first mapping entry next to one Gemma-2 embedding row.
# NOTE(review): 58714 is hard-coded; in the captured output it is the last
# entry of token_list[0]['gemma2_token_ids'] — confirm if the mapping changes.
token_list[0], gemma_embeddings[58714]

({'tr_token': 'salavat',
  'tr_token_id': 21105,
  'llama_token_ids': [19776, 402, 266],
  'gemma2_token_ids': [7871, 58714]},
 tensor([ 0.0039,  0.0043,  0.0183,  ...,  0.0139, -0.0064,  0.0596],
        grad_fn=<SelectBackward0>))

In [27]:
# Worked example: average two embedding rows by hand.
e1, e2 = gemma_embeddings[7871], gemma_embeddings[58714]

# e1 is rebound to the element-wise sum (out of place), then halved.
e1 = torch.add(e1, e2)
average = torch.div(e1, 2)
average

tensor([-0.0067, -0.0076,  0.0065,  ...,  0.0170, -0.0135,  0.0192],
       grad_fn=<DivBackward0>)

In [28]:
# For every entry in token_list, look up the Gemma-2 embedding rows listed in
# its 'gemma2_token_ids' and store their mean in row 'tr_token_id' of
# `embeddings`. An entry with a single Gemma-2 id just copies that row.

# gemma_embeddings tracks gradients, so without no_grad() every in-place write
# would attach autograd history to `embeddings` (and to whatever is saved later).
with torch.no_grad():
    for token_map in token_list:
        index = token_map['tr_token_id']
        gemma2_token_ids = token_map['gemma2_token_ids']
        if not gemma2_token_ids:
            # Fail with a clear message rather than an IndexError on [0].
            raise ValueError(f"no gemma2_token_ids for tr_token_id {index}")
        # Indexing with the id list gathers the rows as (num_ids, dim);
        # the mean over dim 0 collapses them to a single vector.
        embeddings[index] = gemma_embeddings[gemma2_token_ids].mean(dim=0)

embeddings[0]

tensor([-0.0002, -0.0059,  0.0222,  ...,  0.0152, -0.0074, -0.0119],
       grad_fn=<SelectBackward0>)

In [30]:
# Sanity check: count rows of `embeddings` that are still entirely zero,
# i.e. rows that no token_list entry wrote to.
# One vectorised reduction replaces the per-row Python loop:
# elementwise == 0 -> all() across each row -> number of True rows.
count_of_zero_embeddings = int((embeddings == 0).all(dim=1).sum().item())
count_of_zero_embeddings

0

In [31]:
# Save the embeddings tensor to a file.
# detach() drops any autograd history so the file holds plain values and the
# tensor does not come back with requires_grad=True when loaded.
torch.save(embeddings.detach(), 'tr_gemma2_embeddings.pt')