In [65]:
import torch
from transformers import AutoTokenizer, AutoModel

In [52]:
src_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en") #sanity -- this may not be a joint vocab
tgt_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")

model = AutoModel.from_pretrained("Helsinki-NLP/opus-mt-es-en", output_attentions=True)

Some weights of the model checkpoint at Helsinki-NLP/opus-mt-es-en were not used when initializing MarianModel: ['final_logits_bias']
- This IS expected if you are initializing MarianModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MarianModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
# check out vocab tokenization

In [36]:
src_tokenizer.tokenize("compositional")

['▁com', 'posi', 'tional']

In [37]:
src_tokenizer.tokenize("compositiva")

['▁com', 'posi', 'tiva']

In [38]:
src_tokenizer.tokenize("My vocabulary is limited.")

['▁My', '▁', 'voca', 'bula', 'ry', '▁is', '▁limite', 'd', '.']

In [59]:
#sanity check -- use the non-auto versions for the helsinki models
#https://huggingface.co/docs/transformers/main/en/model_doc/marian#transformers.MarianMTModel
from transformers import MarianTokenizer, MarianMTModel

src = "es"  # source language
trg = "en"  # target language

model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)

In [60]:
sample_text = "where is the bus?"
batch = tokenizer([sample_text], return_tensors="pt")

generated_ids = model.generate(**batch)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

'Where is the bus?'

In [97]:
#get embeddings given generated_ids
def get_embed_from_text(text):
    tokens = tokenizer([text], return_tensors="pt")
#     print(tokens)
    weight = model.get_input_embeddings().weight
    embed = []
    for token_id, mask in zip(tokens['input_ids'][0], tokens['attention_mask'][0]):
        this_embed = weight[token_id] * mask
        embed.append(this_embed)
    embed.pop() #last is eos padding -- will be added seemingly regardless (id = 0), so pop it off here
    return torch.stack(embed)

In [106]:
embed_orig = get_embed_from_text("couch")
embed_swap = get_embed_from_text("sofa")

In [107]:
embed_orig

tensor([[ 0.0109,  0.0232, -0.0419,  ..., -0.0266,  0.0284, -0.0354],
        [ 0.0196,  0.0979, -0.0318,  ..., -0.0589, -0.1090, -0.0285]],
       grad_fn=<StackBackward0>)

In [108]:
tokenizer.tokenize("couches")

['▁co', 'uche', 's']

In [109]:
tokenizer.tokenize("sofa")

['▁so', 'fa']

In [111]:
embed_orig.shape

torch.Size([2, 512])

In [149]:
#requires tensors of the same shape -- 1&1, or subword decomps into the same # of units
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
output = cos(embed_orig, embed_swap)
print(output)

tensor([0.4034, 0.5770], grad_fn=<DivBackward0>)


In [147]:
def compute_cos(embed1, embed2):
    #first, average subwords to get 1 embed per word
    embed1_avg = torch.mean(embed1, dim=0).unsqueeze(0)
    embed2_avg = torch.mean(embed2, dim=0).unsqueeze(0)
    
    #second, cosine similarity
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    output = cos(embed1_avg, embed2_avg)
#     print(output)
    return output[0]

In [148]:
embed_orig = get_embed_from_text("couch")
embed_swap = get_embed_from_text("sofa")

compute_cos(embed_orig, embed_swap)

tensor(1., grad_fn=<SelectBackward0>)

In [39]:
# visualize alignments with bertviz

In [32]:
src_sent = "No copies y pegues oraciones de otros sitios."
tgt_sent = "Do not copy-paste sentences from elsewhere."

encoder_input_ids = src_tokenizer(src_sent, return_tensors="pt", add_special_tokens=False).input_ids
decoder_input_ids = tgt_tokenizer(tgt_sent, return_tensors="pt", add_special_tokens=False).input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = src_tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tgt_tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

In [33]:
from bertviz import head_view
head_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens = decoder_text
)

<IPython.core.display.Javascript object>

In [151]:
# awesome align -- demo https://colab.research.google.com/drive/1205ubqebM0OsZa1nRgbGJBtitgHqIVv6?usp=sharing#scrollTo=ODwJ_gQ8bnqR

In [152]:
# !pip install transformers==3.1.0 (using 4+ works)
import torch
import transformers
import itertools

In [153]:
model = transformers.BertModel.from_pretrained('bert-base-multilingual-cased')
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-multilingual-cased')

Downloading: 100%|██████████| 625/625 [00:00<00:00, 288kB/s]
Downloading: 100%|██████████| 681M/681M [00:19<00:00, 36.5MB/s] 
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Downloading: 100%|

In [156]:
# src = 'awesome-align is awesome !'
# tgt = '牛对齐 是 牛 ！'

src = "No copies y pegues oraciones de otros sitios."
tgt = "Do not copy-paste sentences from elsewhere."

In [157]:
# pre-processing
sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [tokenizer.tokenize(word) for word in sent_tgt]
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids'], tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids']
sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]

# alignment
align_layer = 8
threshold = 1e-3
model.eval()
with torch.no_grad():
  out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
  out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]

  dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

  softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
  softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)

  softmax_inter = (softmax_srctgt > threshold)*(softmax_tgtsrc > threshold)

align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
align_words = set()
for i, j in align_subwords:
  align_words.add( (sub2word_map_src[i], sub2word_map_tgt[j]) )

# printing
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

for i, j in sorted(align_words):
  print(f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}==={color.BOLD}{color.RED}{sent_tgt[j]}{color.END}')

[1m[94mNo[0m===[1m[91mnot[0m
[1m[94mcopies[0m===[1m[91mcopy-paste[0m
[1m[94my[0m===[1m[91mcopy-paste[0m
[1m[94mpegues[0m===[1m[91mcopy-paste[0m
[1m[94moraciones[0m===[1m[91msentences[0m
[1m[94mde[0m===[1m[91mfrom[0m
[1m[94motros[0m===[1m[91melsewhere.[0m
[1m[94msitios.[0m===[1m[91melsewhere.[0m
