<a href="https://colab.research.google.com/github/graehl/awesome-align/blob/master/awesome_align_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AWESOME: Aligning Word Embedding Spaces of Multilingual Encoders

[``awesome-align``](https://github.com/neulab/awesome-align) is a tool that can extract word alignments from multilingual BERT (mBERT) and allows you to fine-tune mBERT on parallel corpora for better alignment quality (see [our paper](https://arxiv.org/abs/2101.08231) for more details).

This is a simple demo of how `awesome-align` extracts word alignments from mBERT.

First, install and import the following packages. (Note that the original `awesome-align` tool does not require the `transformers` package.)

In [66]:
!pwd
# Fetch (or update) the fork, then install its dependencies and the ONNX tooling.
!git clone https://github.com/graehl/awesome-align.git || (cd awesome-align && git pull)
!pip install -r awesome-align/requirements.txt
!pip install transformers onnx skl2onnx

import itertools
import sys

# Make the cloned package importable in Colab.
sys.path.append('/content/awesome-align')
sys.path.append('/content')

import onnx
import torch
import transformers
from skl2onnx.helpers import onnx_helper

from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.modeling_utils import PreTrainedModel
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer




In [55]:
# ANSI escape codes used to color the printed alignments
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'


Load the multilingual BERT model and its tokenizer.

In [56]:
bertmodel='bert-base-multilingual-cased'
model = transformers.AutoModel.from_pretrained(bertmodel)
tokenizer = transformers.AutoTokenizer.from_pretrained(bertmodel)

Input *tokenized* source and target sentences.

In [57]:
src = 'I bought a new car because I was going through a midlife crisis .'
# Alternative Russian target; uncomment to use it instead of the Spanish one.
# tgt = 'Я купил новую тачку , потому что я переживал кризис среднего возраста .'
tgt = 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'

Run the model and print the resulting alignments.
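
The pre-processing in the next cell splits every whitespace-separated word into subwords and records which word each subword came from, so that subword-level links can later be reported as word-level links. Here is a minimal, self-contained sketch of that bookkeeping. Note that `toy_tokenize` is a hypothetical stand-in for the mBERT WordPiece tokenizer (it does not reproduce its real splits), and `subword_to_word_map` is an illustrative helper name that does not exist in `awesome-align`.

```python
# Hypothetical stand-in for a WordPiece tokenizer: cuts a word into
# fixed-size pieces and marks continuation pieces with "##".
def toy_tokenize(word, piece_len=4):
    pieces = [word[i:i + piece_len] for i in range(0, len(word), piece_len)]
    return [p if k == 0 else "##" + p for k, p in enumerate(pieces)]


def subword_to_word_map(words, tokenize):
    """Return the flat subword list and, for each subword, the index of its word."""
    subwords, sub2word = [], []
    for word_idx, word in enumerate(words):
        pieces = tokenize(word)
        subwords.extend(pieces)
        sub2word.extend([word_idx] * len(pieces))
    return subwords, sub2word


words = "I bought a midlife crisis .".split()
subwords, sub2word = subword_to_word_map(words, toy_tokenize)
print(subwords)  # ['I', 'boug', '##ht', 'a', 'midl', '##ife', 'cris', '##is', '.']
print(sub2word)  # [0, 1, 1, 2, 3, 3, 4, 4, 5]
```

Both subwords of `bought` map to word index 1, so a link found on either piece is reported for the whole word.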

In [58]:
import pdb  # debugging aid
# pre-processing: split each whitespace-separated word into subwords, then map the subwords to vocabulary ids
sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [tokenizer.tokenize(word) for word in sent_tgt]
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

def ids_for_model(wids, tokenizer):
  # Flatten the per-word id lists and let the tokenizer add the special tokens.
  flat_ids = list(itertools.chain(*wids))
  return tokenizer.prepare_for_model(flat_ids, return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']

ids_src, ids_tgt = ids_for_model(wid_src, tokenizer), ids_for_model(wid_tgt, tokenizer)

sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]


model.eval()  # switch to inference mode (disables dropout); this call does not need to run under torch.no_grad()

def to_onnx(model, onnx_file_path, inputs=['input_ids', 'attention_mask'], outputs=['output'], dynamic=True, batch=True, opset_version=14, return_tensor_names=True):
  captions = {0 : 'batch_size', 1: 'sequence_length'} if batch else {0 : 'sequence_length'}
  dynamic_axes = {}
  if dynamic:
    for k in inputs:
      dynamic_axes[k] = captions
    for k in outputs:
      dynamic_axes[k] = captions

  # Create dummy input data
  batch_size = 1
  sequence_length = 128
  dims = (batch_size, sequence_length) if batch else (sequence_length,)
  # Dummy inputs: random token ids for 'input_ids', all-ones integer tensors for every other input (e.g. 'attention_mask').
  inputs_ones = tuple(torch.ones(dims, dtype=torch.long) if x != 'input_ids' else torch.randint(0, model.config.vocab_size, dims) for x in inputs)
  # The ONNX file is written to the caller-supplied onnx_file_path.

  # Export the model to ONNX
  torch.onnx.export(
      model,
      inputs_ones, #(input_ids, attention_mask),
      onnx_file_path,
      export_params=True,
      opset_version=opset_version,
      do_constant_folding=True,
      input_names = inputs,
      output_names = outputs,
      dynamic_axes=dynamic_axes,
  )

  if return_tensor_names:
    om = onnx_helper.load_onnx_model(onnx_file_path)
    return list(onnx_helper.enumerate_model_node_outputs(om))
  else:
    return f"Model exported to {onnx_file_path}"

for x in to_onnx(model,"model.onnx"): print(str(x))


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


/Constant_output_0
/Shape_output_0
/Constant_1_output_0
/Gather_output_0
/Shape_1_output_0
/Constant_2_output_0
/Gather_1_output_0
onnx::Slice_209
/Constant_3_output_0
/Constant_4_output_0
/Constant_5_output_0
/Unsqueeze_output_0
/Constant_6_output_0
/Slice_output_0
/Constant_7_output_0
/Unsqueeze_1_output_0
/Constant_8_output_0
/Unsqueeze_2_output_0
/Concat_output_0
/Constant_9_output_0
/Reshape_output_0
/Shape_2_output_0
/ConstantOfShape_output_0
/Constant_10_output_0
/Mul_output_0
/Equal_output_0
/Where_output_0
/Expand_output_0
onnx::Slice_233
/embeddings/Constant_output_0
/embeddings/Constant_1_output_0
/embeddings/Constant_2_output_0
/embeddings/Unsqueeze_output_0
/embeddings/Constant_3_output_0
/embeddings/Slice_output_0
/embeddings/word_embeddings/Gather_output_0
/embeddings/token_type_embeddings/Gather_output_0
/embeddings/Add_output_0
/embeddings/position_embeddings/Gather_output_0
/embeddings/Add_1_output_0
/embeddings/LayerNorm/ReduceMean_output_0
/embeddings/LayerNorm/Sub_
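
The alignment cell that follows keeps a subword pair only when its softmax probability exceeds the threshold in both directions (source-to-target and target-to-source), then maps the surviving subword links back to words. The pure-Python sketch below illustrates that intersection on a made-up similarity matrix. The function name `extract_alignments`, the matrix values, and the subword-to-word maps are assumptions chosen for illustration; they are not output from mBERT.

```python
import math


def softmax(values):
    # Numerically stable softmax over a list of floats.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def extract_alignments(sim, sub2word_src, sub2word_tgt, threshold=1e-3):
    """sim[i][j] is the similarity of source subword i and target subword j."""
    n_src, n_tgt = len(sim), len(sim[0])
    # Source-to-target: normalize each row over the target subwords.
    src2tgt = [softmax(row) for row in sim]
    # Target-to-source: normalize each column over the source subwords.
    tgt2src = [softmax([sim[i][j] for i in range(n_src)]) for j in range(n_tgt)]
    links = set()
    for i in range(n_src):
        for j in range(n_tgt):
            # Keep the pair only if both directions exceed the threshold.
            if src2tgt[i][j] > threshold and tgt2src[j][i] > threshold:
                links.add((sub2word_src[i], sub2word_tgt[j]))
    return sorted(links)


# Toy data: three source subwords (the last two belong to word 1)
# and two target subwords (one word each).
sim = [
    [10.0, 0.0],
    [0.0, 10.0],
    [0.0, 10.0],
]
sub2word_src = [0, 1, 1]
sub2word_tgt = [0, 1]

print(extract_alignments(sim, sub2word_src, sub2word_tgt, threshold=1e-3))
# → [(0, 0), (1, 1)]
print(len(extract_alignments(sim, sub2word_src, sub2word_tgt, threshold=1e-5)))
# → 4
```

In this toy case, lowering the threshold from 1e-3 to 1e-5 lets the weak cross links through, which mirrors how the number of links can grow as the threshold shrinks.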

In [59]:

# alignment

def sent_without_startend(batch, sent=0): return batch[sent, 1:-1]
def alignvec(batch, align_layer=8, sent=0): return sent_without_startend(batch[align_layer], sent=sent)
def hidden(model, ids): return model(ids.unsqueeze(0), output_hidden_states=True)[2]

with torch.no_grad():
  # The hidden states do not depend on the layer or threshold, so compute them once.
  hidden_src = hidden(model, ids_src)
  hidden_tgt = hidden(model, ids_tgt)

for align_layer in range(7, 12):
  last_align = None
  threshold = 1e-1
  for it in range(6):
    threshold = threshold * 1e-2  # 1e-3, 1e-5, ..., 1e-13
    out_src = alignvec(hidden_src, align_layer)
    out_tgt = alignvec(hidden_tgt, align_layer)

    # Similarity of every source subword with every target subword.
    dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

    softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
    softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)

    # Keep a subword pair only if both directions exceed the threshold.
    softmax_inter = (softmax_srctgt > threshold) * (softmax_tgtsrc > threshold)

    align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
    align_words = set()
    for i, j in align_subwords:
      align_words.add((sub2word_map_src[i], sub2word_map_tgt[j]))
    align_words = sorted(align_words)
    # Print only when the alignment changed since the previous threshold.
    if align_words != last_align:
      print(f" (layer {align_layer} > {threshold:.3g}) {len(align_words)} links for '{src}' to '{tgt}'")
      for i, j in align_words:
        print(f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}==={color.BOLD}{color.RED}{sent_tgt[j]}{color.END}')
    last_align = align_words

 (layer 7 > 0.001) 11 links for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
bought===Compré
a===un
new===nuevo
car===auto
because===porque
was===estaba
going===pasando
through===por
a===una
crisis===crisis
.===.
 (layer 7 > 1e-07) 12 links for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
bought===Compré
a===un
new===nuevo
car===auto
because===porque
was=