<a href="https://colab.research.google.com/github/graehl/awesome-align/blob/master/awesome_align_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AWESOME: Aligning Word Embedding Spaces of Multilingual Encoders

[``awesome-align``](https://github.com/neulab/awesome-align) is a tool that can extract word alignments from multilingual BERT (mBERT) and allows you to fine-tune mBERT on parallel corpora for better alignment quality (see [our paper](https://arxiv.org/abs/2101.08231) for more details).

This is a simple demo of how `awesome-align` extracts word alignments from mBERT.

First, clone the repository, then install and import the following packages. (Note that the original `awesome-align` tool does not require the `transformers` package.)

In [1]:
# Show the working directory; later cells hard-code /content, which is what this prints on Colab.
!pwd
# Clone the fork; if the clone fails (e.g. the directory already exists on a re-run), pull the existing checkout instead.
!git clone https://github.com/graehl/awesome-align.git || (cd awesome-align && git pull)


/content
Cloning into 'awesome-align'...
remote: Enumerating objects: 369, done.[K
remote: Counting objects: 100% (168/168), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 369 (delta 129), reused 98 (delta 88), pack-reused 201 (from 1)[K
Receiving objects: 100% (369/369), 584.76 KiB | 2.54 MiB/s, done.
Resolving deltas: 100% (218/218), done.


In [2]:
# onnxruntime is used further down to run the exported ONNX graph.
!pip install onnxruntime
!pip install -r awesome-align/requirements.txt
import sys
# Make the cloned checkout importable; both paths assume the /content working directory.
sys.path.append('/content/awesome-align')
sys.path.append('/content')
# skl2onnx is needed only for `onnx_helper` (loading the exported model and listing its node outputs).
!pip install skl2onnx
import torch
import itertools
from skl2onnx.helpers import onnx_helper
!pip install transformers




Collecting onnxruntime
  Downloading onnxruntime-1.21.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.21.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m34.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m850.9 kB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pa



In [52]:
# printing
class color:
    """ANSI escape sequences for styling text written to a terminal or notebook output."""

    # foreground colors
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # resets every attribute set above
    END = '\033[0m'

def print_align(align_words, desc=''):
    """Print a one-line summary of `align_words`, then one colored line per link.

    `align_words` is an iterable of `(i, j)` index pairs; `i` indexes the
    notebook-level `sent_src` and `j` indexes `sent_tgt`. The summary line also
    reads the notebook-level `src` and `tgt` strings.
    """
    print(f"{desc} {len(align_words)} links {align_words} for '{src}' to '{tgt}'")
    for i, j in align_words:
        # source word in bold blue, target word in bold red
        src_word = f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}'
        tgt_word = f'{color.BOLD}{color.RED}{sent_tgt[j]}{color.END}'
        print(f'{src_word}==={tgt_word}')

Load the multilingual BERT model and its tokenizer.

In [4]:
model_name_or_path='bert-base-multilingual-cased'

import transformers
import awesome_align
from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer
from awesome_align.modeling_utils import PreTrainedModel

def init_model_and_tokenizer(
    model_name_or_path,
    config_name = None,
    cache_dir = None,
    tokenizer_name = None,
):
  """Load the awesome-align BERT config, tokenizer and masked-LM model.

  Args:
    model_name_or_path: value handed to `from_pretrained` for the model (and for
      the config/tokenizer when no dedicated name is given). When falsy, a
      default config is created and the model is built from the config alone.
    config_name: optional config to load in preference to `model_name_or_path`.
    cache_dir: forwarded as `cache_dir` to every `from_pretrained` call.
    tokenizer_name: optional tokenizer to load in preference to `model_name_or_path`.

  Returns:
    A `(model, tokenizer)` tuple.

  Raises:
    ValueError: if neither `tokenizer_name` nor `model_name_or_path` is given.

  Side effect: overwrites `modeling.PAD_ID`, `modeling.CLS_ID` and
  `modeling.SEP_ID` with the loaded tokenizer's special-token ids.
  """
  config_class, model_class, tokenizer_class = BertConfig, BertForMaskedLM, BertTokenizer
  if config_name:
      config = config_class.from_pretrained(config_name, cache_dir=cache_dir)
  elif model_name_or_path:
      config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  else:
      config = config_class()

  if tokenizer_name:
      tokenizer = tokenizer_class.from_pretrained(tokenizer_name, cache_dir=cache_dir)
  elif model_name_or_path:
      tokenizer = tokenizer_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  else:
      # The two literals are concatenated, so the space after "save it," must be explicit.
      raise ValueError(
          "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it, "
          "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
      )

  # Publish the special-token ids to the modeling module.
  # (Other cells of this notebook assume the pad id is 0 when building masks.)
  modeling.PAD_ID = tokenizer.pad_token_id
  modeling.CLS_ID = tokenizer.cls_token_id
  modeling.SEP_ID = tokenizer.sep_token_id

  if model_name_or_path:
      model = model_class.from_pretrained(
          model_name_or_path,
          # a ".ckpt" in the name is taken to mean a TensorFlow checkpoint
          from_tf=bool(".ckpt" in model_name_or_path),
          config=config,
          cache_dir=cache_dir,
      )
  else:
      model = model_class(config=config)

  return model, tokenizer

# Selects the model implementation: True loads the awesome-align classes via
# init_model_and_tokenizer, False loads the stock `transformers` Auto* classes.
USE_AWESOME_ALIGN = True
# True caused, in export to onnx, `Boolean value of Tensor with more than one value is ambiguous`, but we fixed with ExportNthLayer wrapper
if USE_AWESOME_ALIGN:
  model, tokenizer = init_model_and_tokenizer(model_name_or_path)
else:
  model, tokenizer = transformers.AutoModel.from_pretrained(model_name_or_path), transformers.AutoTokenizer.from_pretrained(model_name_or_path)


Downloading:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/714M [00:00<?, ?B/s]

Input *tokenized* source and target sentences.

In [5]:
src = 'I bought a new car because I was going through a midlife crisis .'
tgt = 'Я купил новую тачку , потому что я переживал кризис среднего возраста .'
tgt = 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
srctgt = f'{src} ||| {tgt}'
fpar = 'srctgt.txt'
with open(fpar, 'w') as f:
  f.write(srctgt)
if True:
  !rm align.txt
  !CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/content/awesome-align python /content/awesome-align/run_align.py --output_file=align.txt --model_name_or_path="$model_name_or_path" --data_file="$fpar" --extraction='softmax' --softmax_threshold=1e-3 --batch_size=32
!cat align.txt

rm: cannot remove 'align.txt': No such file or directory
2025-04-05 13:21:29.840113: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743859289.896570    1125 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743859289.908989    1125 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Loading the dataset...
Extracting: 1it [00:00,  1.80it/s]
13-14 2-1 10-8 0-0 5-4 4-2 12-9 11-12 3-3 8-6 1-0 7-5 9-7


In [72]:


# Put the model in evaluation mode before exporting / extracting hidden states.
model.eval()
# just sets mode of model, probably doesn't need to be under no_grad

def extend_mask(attention_mask, dtype=torch.float32):
    """Turn a multiplicative attention mask into an additive, head-broadcastable one.

    A 2-D mask gets two singleton dimensions inserted after the first dimension;
    a 3-D mask gets one. The result is cast to `dtype` and mapped so that a mask
    value of 1 becomes 0 and a mask value of 0 becomes -10000.

    Raises:
        ValueError: if `attention_mask` is neither 2-D nor 3-D.
    """
    rank = attention_mask.dim()
    if rank not in (2, 3):
        raise ValueError(
             "Wrong shape for input_ids or attention_mask"
        )
    if rank == 2:
        broadcastable = attention_mask[:, None, None, :]
    else:
        broadcastable = attention_mask[:, None, :, :]
    return (1.0 - broadcastable.to(dtype=dtype)) * -10000.0

def guess_dtype(model):
  """Best-effort guess of the floating-point dtype `model` computes in.

  Resolution order:
    1. `model.get_parameter_dtype()` when the model provides that method;
    2. the dtype of the first tensor yielded by `model.parameters()`;
    3. `torch.float32` when neither is available — including a module that
       has a `parameters()` method but no parameters (previously this case
       raised StopIteration).
  """
  if hasattr(model, 'get_parameter_dtype'):
    return model.get_parameter_dtype()
  if hasattr(model, 'parameters'):
    # Default of None avoids StopIteration on a parameter-less module.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
      return first_param.dtype
  return torch.float32

def make_ones_mask(ids, pad_id=0):
  """Build a multiplicative attention mask for `ids`.

  Args:
    ids: integer tensor of token ids (any shape).
    pad_id: id treated as padding; defaults to 0, the value this function
      previously hard-coded.

  Returns:
    A tensor with the same shape and device as `ids`, in torch's default
    floating dtype, holding 0 where `ids == pad_id` and 1 elsewhere.
  """
  # A plain comparison instead of in-place masked assignment: same values, no mutation.
  return (ids != pad_id).to(dtype=torch.get_default_dtype())

def make_extended_mask(ids, dtype=torch.float32):
  """Derive the additive attention mask for `ids` in one call.

  Composes `make_ones_mask` (mask built from the ids) with `extend_mask`
  (broadcastable additive form, cast to `dtype`).
  """
  return extend_mask(make_ones_mask(ids), dtype)

class ExportNthLayer(torch.nn.Module):
    """Wrapper whose forward pass runs the embeddings and only the first N encoder layers.

    `base_model` may either be the BERT model itself or an object exposing it as
    `.bert`. N is `align_layer_max`, capped at `config.num_hidden_layers`.
    """

    def __init__(self, base_model, align_layer_max=8):
        super().__init__()
        bert = base_model.bert if hasattr(base_model, 'bert') else base_model
        self.bert = bert
        self.embeddings = bert.embeddings
        # For BERT, num_hidden_layers is in config
        self.config = bert.config
        self.num_layers = min(bert.config.num_hidden_layers, align_layer_max)
        # Use `.encoder` when present, otherwise expect `.layer` directly on the model.
        encoder = bert.encoder if hasattr(bert, 'encoder') else bert
        self.encoder = encoder
        self.layer = encoder.layer[:self.num_layers]
        print(f'{self.layer}')

    def forward(self, ids, attention_mask=None, position_ids=None):
        """Return the hidden states produced by the last retained layer.

        When `attention_mask` is omitted it is derived from `ids` with
        `make_ones_mask`. Token type ids are always all zeros.
        """
        if attention_mask is None:
            attention_mask = make_ones_mask(ids)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = extend_mask(attention_mask, guess_dtype(self.bert))
        token_type_ids = torch.zeros(ids.size(), dtype=torch.long, device=ids.device)
        hidden_states = self.embeddings(ids, token_type_ids=token_type_ids, position_ids=position_ids)

        # Fallback kept from the original: without a layer list, defer to the wrapped model.
        if self.layer is None:
            return self.bert(ids, attention_mask)

        for layer in self.layer:
            hidden_states = layer(hidden_states, attention_mask=extended_attention_mask)
        return hidden_states

def to_onnx(model, onnx_file_path, inputs=['input_ids', 'attention_mask'], outputs=['output'], dynamic=True, batch=True, align_layer=None, opset_version=14, return_tensor_names=True):
  """Export `model` to `onnx_file_path` with `torch.onnx.export`.

  Args:
    model: model to export; its `.bert` attribute is exported when present.
    onnx_file_path: destination file.
    inputs / outputs: names given to the graph inputs and outputs.
    dynamic: when true, every named input and output gets dynamic axes.
    batch: when true the dummy inputs are (1, 128) and axes 0/1 are named
      batch_size/sequence_length; otherwise they are (128,) with axis 0 named
      sequence_length.
    align_layer: when not None (and the notebook-level USE_AWESOME_ALIGN is
      true) the model is wrapped in `ExportNthLayer(model, align_layer)`.
    opset_version: forwarded to `torch.onnx.export`.
    return_tensor_names: selects the return value, see below.

  Returns:
    With `return_tensor_names`, the list of node output names of the exported
    graph (initializer names are printed along the way); otherwise a
    confirmation string.
  """
  if batch:
    axis_names = {0 : 'batch_size', 1: 'sequence_length'}
  else:
    axis_names = {0 : 'sequence_length'}
  # Inputs first, then outputs; all of them share the same axis-name dict.
  dynamic_axes = {name: axis_names for name in [*inputs, *outputs]} if dynamic else {}

  # Dummy inputs for tracing: random token ids for 'input_ids', ones for everything else.
  batch_size, sequence_length = 1, 128
  dims = (batch_size, sequence_length) if batch else (sequence_length,)
  dummy_inputs = tuple(
      torch.randint(0, model.config.vocab_size, dims) if name == 'input_ids' else torch.ones(dims)
      for name in inputs
  )

  hasbert = hasattr(model, 'bert')
  print(f'hasbert={hasbert}')
  if hasbert:
    model = model.bert
  # TODO: figure out how to do first nth encoder layers for non-awesome-align bert
  if USE_AWESOME_ALIGN and align_layer is not None:
    model = ExportNthLayer(model, align_layer)

  # Export the model to ONNX
  torch.onnx.export(
      model,
      dummy_inputs,
      onnx_file_path,
      export_params=True,
      opset_version=opset_version,
      do_constant_folding=True,
      input_names=inputs,
      output_names=outputs,
      dynamic_axes=dynamic_axes,
  )

  if not return_tensor_names:
    return f"Model exported to {onnx_file_path}"

  exported = onnx_helper.load_onnx_model(onnx_file_path)
  print('initializers')
  for node in exported.graph.initializer:
    print(f'{node.name}')
  return list(onnx_helper.enumerate_model_node_outputs(exported))


In [73]:

# Set to False to skip the export step.
DO_ONNX_EXPORT=True
# Number of leading encoder layers to keep; passed to to_onnx as `align_layer`
# and reused by later cells (Drive file name, layer sweep).
align_layer_max=9
onnxpath="model.onnx"
if DO_ONNX_EXPORT:
    #onnxruntime.python.tools.transformers.export_onnx_model_from_pt(...)
  # `False and ...` disables the run_align.py command-line export; the in-notebook to_onnx path below always runs.
  if False and USE_AWESOME_ALIGN:
    !CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/content/awesome-align python /content/awesome-align/run_align.py --model_name_or_path=bert-base-multilingual-cased --output_onnx=$onnxpath --max_layer=$align_layer_max
  else:
    # to_onnx returns the exported graph's node output names by default; print one per line.
    for x in to_onnx(model, onnxpath, align_layer=align_layer_max): print(str(x))


hasbert=True
ModuleList(
  (0-8): 9 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
initializers
bert.embeddings.word_embeddi

In [77]:
# Optional: copy the exported ONNX file to a mounted Google Drive. Off by default.
COPY_TO_DRIVE=False
if COPY_TO_DRIVE:
  # Imported here because google.colab is only needed when this branch runs.
  from google.colab import drive
  driveroot = '/content/gdrive'
  drive.mount(driveroot, force_remount=True)
  # NOTE(review): the destination is the mount root itself — confirm it is writable
  # (user files may live in a subfolder of the mount).
  onnxto=f'{driveroot}/mbert-cased-layer{align_layer_max}.onnx'
  print(onnxto)
  print(onnxpath)
  !cp $onnxpath $onnxto


In [7]:
# pre-processing
def wstok(x):
  """Split `x` into whitespace-separated tokens, ignoring leading/trailing whitespace."""
  trimmed = x.strip()
  return trimmed.split()
def subwords(xs):
  """Tokenize every item of `xs` separately with the notebook-level `tokenizer`; one result per item."""
  return list(map(tokenizer.tokenize, xs))
def ids(xs):
  """Convert every token list in `xs` to ids with the notebook-level `tokenizer`; one result per item."""
  return list(map(tokenizer.convert_tokens_to_ids, xs))
# Whitespace-split words -> per-word sub-token lists -> per-word id lists.
sent_src, sent_tgt = wstok(src), wstok(tgt)
token_src, token_tgt = subwords(sent_src), subwords(sent_tgt)
wid_src, wid_tgt = ids(token_src), ids(token_tgt)
#def tokenizer_max_len(tokenizer): return tokenizer.max_len_single_sentence if hasattr(tokenizer, 'max_len_single_sentence') else tokenizer.model_max_length
# Extra keyword arguments for tokenizer.prepare_for_model (consumed by ids_for_model),
# chosen by which length attribute the tokenizer exposes.
maxlenkw = {}
if hasattr(tokenizer, 'model_max_length'):
  # NOTE(review): confirm prepare_for_model accepts a `model_max_length` keyword
  # (as opposed to `max_length`) for tokenizers that take this branch.
  maxlenkw['model_max_length'] = tokenizer.model_max_length
  maxlenkw['truncation'] = True
else:
  maxlenkw['max_length'] = tokenizer.max_len
def ids_for_model(ids, model, tokenizer):
  """Flatten the per-word id lists in `ids` and return the 'input_ids' tensor from `tokenizer.prepare_for_model`.

  `model` is accepted but not used. Length-related keyword arguments come from
  the notebook-level `maxlenkw` dict.
  """
  flat_ids = [wid for word_ids in ids for wid in word_ids]
  encoded = tokenizer.prepare_for_model(flat_ids, return_tensors='pt', **maxlenkw)
  return encoded['input_ids']
# Number of words on each side.
print(f'wid {len(wid_src)} x {len(wid_tgt)}')
ids_src, ids_tgt = ids_for_model(wid_src, model, tokenizer), ids_for_model(wid_tgt, model, tokenizer)
print(f'{ids_src}')
print(f'{ids_tgt}')
# sub2word_map_*[k] is the index of the word that produced the k-th sub-token
# (the word index is repeated once per sub-token of that word).
sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]

wid 14 x 15
tensor([[  101,   146, 28870,   169, 10751, 13000, 12373,   146, 10134, 19090,
         11222,   169, 15607, 57156, 22859,   119,   102]])
tensor([[  101, 16680, 52302, 10333, 10119, 18257, 15249, 16348, 14645, 46481,
         10183, 10153, 22859, 10104, 10109, 16689, 22757,   119,   102]])


In [57]:
!ls -l $onnxpath
import onnxruntime as ort
import onnxruntime as rt
# Configure graph optimizations and hand the options to the session;
# options that are not passed to InferenceSession have no effect.
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(onnxpath, sess_options=sess_options)


-rw-r--r-- 1 root root 624160622 Apr  5 13:22 model.onnx


In [60]:
# Read the graph's input/output names from the session; onnx_word_encs uses these lists.
input_names = [x.name for x in session.get_inputs()]
output_names = [x.name for x in session.get_outputs()]
print(input_names)
print(output_names)

['input_ids', 'attention_mask']
['output']


In [61]:
import pdb
def onnxmask(ids):
  """Return the attention mask to feed to the exported graph's `attention_mask` input.

  The result has the same shape as `ids`, dtype float32, with 1.0 for real
  tokens and 0.0 where `ids == 0` (padding).

  The graph exported through `ExportNthLayer` already converts its mask input
  with `extend_mask`, so it must receive the plain 1/0 mask. Passing a mask
  that has already been extended (0 / -10000) applies the transform twice, and
  slicing it with `[0, 0, :, :]` keeps only the first row of a batch.
  """
  return (ids != 0).to(torch.float32)
def onnx_word_encs(session, ids_src):
  """Run the ONNX `session` on `ids_src` and return the encodings as a torch tensor.

  The ids and the mask from `onnxmask` are fed under the notebook-level
  `input_names[0]` and `input_names[1]`; the first requested output is used.
  The first and last positions along axis 1 are dropped (presumably the
  [CLS]/[SEP] positions — the id tensors printed above start with 101 and end with 102).
  """
  feeds = {
      input_names[0]: ids_src.numpy(),
      input_names[1]: onnxmask(ids_src).numpy(),
  }
  first_output = session.run(output_names, feeds)[0]
  return torch.tensor(first_output[:, 1:-1, :])

# Encode both sentences with the ONNX session; first and last positions are already stripped.
osrc = onnx_word_encs(session, ids_src)
otgt = onnx_word_encs(session, ids_tgt)

def alignpairs(out_src, out_tgt, sub2word_map_src, sub2word_map_tgt, threshold):
    """Extract word-level alignment links from two sets of sub-token encodings.

    The dot-product similarity matrix is softmax-normalized along the target
    axis and, separately, along the source axis. A (source, target) sub-token
    pair is kept when both probabilities exceed `threshold`. Each kept pair is
    mapped to word indices through `sub2word_map_src` / `sub2word_map_tgt`.

    Returns:
        A sorted list of unique `(src_word, tgt_word)` tuples.
    """
    similarity = out_src @ out_tgt.transpose(-1, -2)

    src_to_tgt = torch.softmax(similarity, dim=-1)
    tgt_to_src = torch.softmax(similarity, dim=-2)
    # tryalso entmax15(dot_prod, dim=...)? also TODO: before softmax mask off cls sep pad tokens

    keep = (src_to_tgt > threshold) & (tgt_to_src > threshold)

    # The last two coordinates of every nonzero position are the (source, target) sub-token indices.
    links = {
        (sub2word_map_src[pos[-2]], sub2word_map_tgt[pos[-1]])
        for pos in torch.nonzero(keep, as_tuple=False).tolist()
    }
    return sorted(links)

# Align using the ONNX encodings, with the same 1e-3 threshold given to run_align.py above.
align_words = alignpairs(osrc, otgt, sub2word_map_src, sub2word_map_tgt, 1e-3)
# Reference links written by run_align.py (align.txt), unsorted:
# 13-14 2-1 10-8 0-0 5-4 4-2 12-9 11-12 3-3 8-6 1-0 7-5 9-7
# Links recorded from this ONNX path (no 11-12 link):
# [(0, 0), (1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (12, 9), (13, 14)]
print_align(align_words, f'onnx')

onnx 12 links [(0, 0), (1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (12, 9), (13, 14)] for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
[1m[94mI[0m===[1m[91mCompré[0m
[1m[94mbought[0m===[1m[91mCompré[0m
[1m[94ma[0m===[1m[91mun[0m
[1m[94mnew[0m===[1m[91mnuevo[0m
[1m[94mcar[0m===[1m[91mauto[0m
[1m[94mbecause[0m===[1m[91mporque[0m
[1m[94mwas[0m===[1m[91mestaba[0m
[1m[94mgoing[0m===[1m[91mpasando[0m
[1m[94mthrough[0m===[1m[91mpor[0m
[1m[94ma[0m===[1m[91muna[0m
[1m[94mcrisis[0m===[1m[91mcrisis[0m
[1m[94m.[0m===[1m[91m.[0m


In [62]:



# alignment

def sents_without_startend(batch):
  """Return `batch` without its first and last position along dimension 1."""
  inner = slice(1, -1)
  return batch[:, inner]
# `hiddens(model, ids, align_layer)` returns the hidden states at `align_layer` with the
# first and last positions removed; its definition depends on the model implementation.
if USE_AWESOME_ALIGN:
  def hiddens(model, ids, align_layer):
    # The awesome-align model is asked for the layer directly; positions with id 0 are masked out.
    return sents_without_startend(model.bert(ids, align_layer=align_layer, attention_mask=(ids!=0)))
else:
  # Element [2] of the model output is indexed by layer below.
  # NOTE(review): ids_for_model already returns a batched (1, n) tensor in this notebook,
  # so the extra unsqueeze(0) here may add an unwanted dimension — confirm.
  def hidden(model, ids): return model(ids.unsqueeze(0), output_hidden_states=True)[2]
  def hiddens(model, ids, align_layer):
    return sents_without_startend(hidden(model, ids)[align_layer])


# When True (and USE_AWESOME_ALIGN is True) the model's own get_aligned_word is used
# instead of the notebook's alignpairs.
DECODE_AWESOME_ALIGN=False

# Sweep the last three layers up to align_layer_max; for each, try six thresholds
# (1e-3 down to 1e-8) and print an alignment only when it differs from the previous threshold's.
for align_layer in range(max(0,align_layer_max - 2),align_layer_max+1):
 last_align = None
 threshold = 1e-3
 for it in range(6):
  if DECODE_AWESOME_ALIGN and USE_AWESOME_ALIGN:
    # get_aligned_word takes a batch.
    # print(f'{len(ids_src[0])} x {len(ids_tgt[0])}')
    # NOTE(review): len(ids_src) / len(ids_tgt) are the batch dimension (1) here, not the
    # sequence lengths printed by the commented line above — confirm which one get_aligned_word expects.
    align_words = model.get_aligned_word(ids_src, ids_tgt, (sub2word_map_src,), (sub2word_map_tgt,), 'cpu', len(ids_src), len(ids_tgt), align_layer, 'softmax', threshold, True)[0]
  else:
    with torch.no_grad():
      out_src = hiddens(model, ids_src, align_layer)
      out_tgt = hiddens(model, ids_tgt, align_layer)
      #pdb.set_trace()
      #out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
      #out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
      align_words = alignpairs(out_src, out_tgt, sub2word_map_src, sub2word_map_tgt, threshold)

  # Normalize to a sorted list so the comparison with the previous result is order-independent.
  align_words = sorted(list(align_words))
  if align_words != last_align:
    print_align(align_words,desc = f' (layer {align_layer} > {threshold:.3g})')

  last_align = align_words
  # Next iteration uses a threshold ten times smaller.
  threshold = threshold * 1e-1

 (layer 7 > 0.001) 11 links [(1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (12, 9), (13, 14)] for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
[1m[94mbought[0m===[1m[91mCompré[0m
[1m[94ma[0m===[1m[91mun[0m
[1m[94mnew[0m===[1m[91mnuevo[0m
[1m[94mcar[0m===[1m[91mauto[0m
[1m[94mbecause[0m===[1m[91mporque[0m
[1m[94mwas[0m===[1m[91mestaba[0m
[1m[94mgoing[0m===[1m[91mpasando[0m
[1m[94mthrough[0m===[1m[91mpor[0m
[1m[94ma[0m===[1m[91muna[0m
[1m[94mcrisis[0m===[1m[91mcrisis[0m
[1m[94m.[0m===[1m[91m.[0m
 (layer 7 > 1e-06) 12 links [(1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (11, 12), (12, 9), (13, 14)] for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
[1m[94mbough