<a href="https://colab.research.google.com/github/graehl/awesome-align/blob/master/awesome_align_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AWESOME: Aligning Word Embedding Spaces of Multilingual Encoders

[``awesome-align``](https://github.com/neulab/awesome-align) is a tool that can extract word alignments from multilingual BERT (mBERT) and allows you to fine-tune mBERT on parallel corpora for better alignment quality (see [our paper](https://arxiv.org/abs/2101.08231) for more details).

This is a simple demo of how `awesome-align` extracts word alignments from mBERT.

First, clone the repository, then install and import the required packages. (Note that the original `awesome-align` tool does not require the `transformers` package.)

In [87]:
# Show the current working directory, then fetch the awesome-align fork:
# clone it on the first run; if the clone fails (e.g. the directory already
# exists), fall back to pulling the latest changes into the existing checkout.
!pwd
!git clone https://github.com/graehl/awesome-align.git || (cd awesome-align && git pull)


/content
fatal: destination path 'awesome-align' already exists and is not an empty directory.
remote: Enumerating objects: 11, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 7 (delta 4), reused 0 (delta 0), pack-reused 0 (from 0)[K
Unpacking objects: 100% (7/7), 9.51 KiB | 1.19 MiB/s, done.
From https://github.com/graehl/awesome-align
   68c1ff2..7128d0a  master     -> origin/master
Updating 68c1ff2..7128d0a
Fast-forward
 awesome_align/modeling.py |    2 [32m+[m[31m-[m
 awesome_align_demo.ipynb  | 2098 [32m++++++[m[31m-----------------------------------------------------------[m
 2 files changed, 193 insertions(+), 1907 deletions(-)


In [None]:
# Install the dependencies listed by the cloned repository.
!pip install -r awesome-align/requirements.txt
import sys
# Make the cloned checkout importable.
# NOTE(review): these absolute paths assume the repository was cloned under
# /content (the Colab working directory) — adjust when running elsewhere.
sys.path.append('/content/awesome-align')
sys.path.append('/content')

# Extra packages used by this notebook: transformers for the optional
# AutoModel path, onnx / skl2onnx for the ONNX export helper.
!pip install transformers
!pip install onnx
!pip install skl2onnx
import torch
import itertools
import onnx
from skl2onnx.helpers import onnx_helper


In [19]:
# printing
class color:
    """ANSI escape sequences used to style printed text.

    Concatenate one or more of these codes before a string and finish with
    ``END`` to reset the styling.
    """

    # Foreground colours.
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset all styling.
    END = '\033[0m'


Load the multilingual BERT model and its tokenizer.

In [88]:
# Name (or path) of the pretrained checkpoint; the same value is handed to
# every from_pretrained call in this notebook (config, tokenizer and model).
model_name_or_path='bert-base-multilingual-cased'

import transformers

from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer
from awesome_align.modeling_utils import PreTrainedModel

def init_model_and_tokenizer(
    model_name_or_path,
    config_name = None,
    cache_dir = None,
    tokenizer_name = None,
):
  """Build the awesome-align BERT masked-LM model and its tokenizer.

  Args:
    model_name_or_path: Checkpoint name or path. Used for the model weights
      and, when the more specific arguments are not given, for the config and
      tokenizer as well. When falsy, a freshly initialised model is created.
    config_name: Optional name/path of the config to load instead of the one
      belonging to ``model_name_or_path``.
    cache_dir: Optional cache directory forwarded to every ``from_pretrained``.
    tokenizer_name: Optional name/path of the tokenizer to load instead of the
      one belonging to ``model_name_or_path``.

  Returns:
    A ``(model, tokenizer)`` tuple.

  Raises:
    ValueError: If neither ``tokenizer_name`` nor ``model_name_or_path`` is
      given, because a tokenizer cannot be created from scratch here.

  Side effects:
    Overwrites ``modeling.PAD_ID``, ``modeling.CLS_ID`` and ``modeling.SEP_ID``
    with the special-token ids of the loaded tokenizer.
  """
  config_class, model_class, tokenizer_class = BertConfig, BertForMaskedLM, BertTokenizer

  # An explicit config name wins over the model checkpoint; with neither,
  # fall back to the default configuration.
  config_source = config_name or model_name_or_path
  if config_source:
    config = config_class.from_pretrained(config_source, cache_dir=cache_dir)
  else:
    config = config_class()

  # Same precedence for the tokenizer, except that there is no default.
  tokenizer_source = tokenizer_name or model_name_or_path
  if not tokenizer_source:
    raise ValueError(
        # Trailing space added: the two literals used to join as "save it,and".
        "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it, "
        "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
    )
  tokenizer = tokenizer_class.from_pretrained(tokenizer_source, cache_dir=cache_dir)

  # The modeling module reads these ids from module-level globals.
  modeling.PAD_ID = tokenizer.pad_token_id
  modeling.CLS_ID = tokenizer.cls_token_id
  modeling.SEP_ID = tokenizer.sep_token_id

  if model_name_or_path:
    model = model_class.from_pretrained(
        model_name_or_path,
        # Checkpoints whose name contains ".ckpt" are loaded as TensorFlow weights.
        from_tf=".ckpt" in model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )
  else:
    model = model_class(config=config)

  return model, tokenizer

# Selects the implementation: True loads awesome-align's own BERT classes,
# False loads the stock transformers Auto* classes.
# Known issue with True: exporting to ONNX fails with
# `Boolean value of Tensor with more than one value is ambiguous`.
USE_AWESOME_ALIGN = True
if not USE_AWESOME_ALIGN:
  model, tokenizer = (
      transformers.AutoModel.from_pretrained(model_name_or_path),
      transformers.AutoTokenizer.from_pretrained(model_name_or_path),
  )
else:
  model, tokenizer = init_model_and_tokenizer(model_name_or_path)


Input *tokenized* source and target sentences.

In [104]:
src = 'I bought a new car because I was going through a midlife crisis .'
tgt = 'Я купил новую тачку , потому что я переживал кризис среднего возраста .'
tgt = 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
srctgt = f'{src} ||| {tgt}'
fpar = 'srctgt.txt'
with open(fpar, 'w') as f:
  f.write(srctgt)
if False:
  !rm align.txt
  !CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/content/awesome-align python /content/awesome-align/run_align.py --output_file=align.txt --model_name_or_path="$model_name_or_path" --data_file="$fpar" --extraction='softmax' --softmax_threshold=1e-3 --batch_size=32
!cat align.txt

13-14 2-1 10-8 0-0 5-4 4-2 12-9 11-12 3-3 8-6 1-0 7-5 9-7


In [81]:


# Put the model into evaluation mode before running or exporting it.
# NOTE(review): this presumably only sets the module's mode flag and so does
# not need to run under torch.no_grad() — confirm.
model.eval()

class ExportHidden(torch.nn.Module):
    """Wrapper module that exposes the hidden states of a wrapped model.

    The wrapped model is always called with ``output_hidden_states=True`` and
    the third element of its output is treated as the hidden states. With an
    ``align_layer`` set, that value is also forwarded to the wrapped model as a
    keyword argument and only the hidden state at that index is returned;
    with ``align_layer=None`` all hidden states are returned.
    """

    def __init__(self, base_model, align_layer=8):
        super().__init__()
        self.base_model = base_model
        # Number of encoder layers as declared by the wrapped model's config.
        self.num_layers = base_model.config.num_hidden_layers
        self.align_layer = align_layer

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        layer = self.align_layer
        # Only pass `align_layer` on when one was requested.
        extra_kwargs = {} if layer is None else {'align_layer': layer}
        model_output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            **extra_kwargs
        )
        all_hidden = model_output[2]
        if layer is None:
            return all_hidden
        return all_hidden[layer]

def to_onnx(model, onnx_file_path, inputs=['input_ids', 'attention_mask'], outputs=['output'], dynamic=True, batch=True, align_layer=None, opset_version=14, return_tensor_names=True, batch_size=1, sequence_length=128):
  """Export `model` (or its `.bert` sub-module, when it has one) to an ONNX file.

  Args:
    model: Model to export; must expose `config.vocab_size` when
      'input_ids' is among `inputs`.
    onnx_file_path: Destination path of the ONNX file.
    inputs: Names of the graph inputs, in positional order. 'input_ids' gets
      random token ids as dummy data, every other input a tensor of ones.
    outputs: Names of the graph outputs.
    dynamic: When True, mark the axes of all inputs and outputs as dynamic.
    batch: When True the dummy inputs are (batch_size, sequence_length),
      otherwise (sequence_length,).
    align_layer: Accepted for compatibility but currently unused; wrapping the
      model in ExportHidden for a specific layer is not enabled.
    opset_version: ONNX opset passed to the exporter.
    return_tensor_names: When True, reload the exported file and return the
      list of its node output names; otherwise return a status message.
    batch_size: Batch dimension of the dummy inputs (used only when `batch`).
    sequence_length: Sequence dimension of the dummy inputs.

  Returns:
    A list of node output names, or a status string (see `return_tensor_names`).
  """
  axis_names = {0 : 'batch_size', 1: 'sequence_length'} if batch else {0 : 'sequence_length'}
  dynamic_axes = {}
  if dynamic:
    for tensor_name in (*inputs, *outputs):
      dynamic_axes[tensor_name] = axis_names

  # Dummy inputs used to trace the model.
  dims = (batch_size, sequence_length) if batch else (sequence_length,)
  inputs_ones = tuple(
      torch.randint(0, model.config.vocab_size, dims) if name == 'input_ids' else torch.ones(dims)
      for name in inputs
  )

  # Export the bare encoder when the model wraps one as `.bert`.
  hasbert = hasattr(model, 'bert')
  print(f'hasbert={hasbert}')
  model = model.bert if hasbert else model

  torch.onnx.export(
      model,
      inputs_ones,
      onnx_file_path,
      export_params=True,
      opset_version=opset_version,
      do_constant_folding=True,
      input_names = inputs,
      output_names = outputs,
      dynamic_axes=dynamic_axes,
  )

  if return_tensor_names:
    om = onnx_helper.load_onnx_model(onnx_file_path)
    return list(onnx_helper.enumerate_model_node_outputs(om))
  else:
    return f"Model exported to {onnx_file_path}"

# Switch for the ONNX export; when enabled, the model is written to
# "model.onnx" and each value returned by to_onnx is printed on its own line.
DO_ONNX_EXPORT = False
if DO_ONNX_EXPORT:
  for x in to_onnx(model, "model.onnx"):
    print(str(x))


Run the model and print the resulting alignments.

In [103]:
import pdb
# pre-processing
def wstok(x):
  """Split a sentence into its whitespace-separated words."""
  return x.split()

def subwords(xs):
  """Tokenize every word in `xs` with the global tokenizer (one list per word)."""
  return list(map(tokenizer.tokenize, xs))

def ids(xs):
  """Convert every token list in `xs` to ids with the global tokenizer."""
  return list(map(tokenizer.convert_tokens_to_ids, xs))
# Words -> subword tokens -> vocabulary ids, kept grouped per word.
sent_src, sent_tgt = wstok(src), wstok(tgt)
token_src, token_tgt = subwords(sent_src), subwords(sent_tgt)
wid_src, wid_tgt = ids(token_src), ids(token_tgt)

# Length-limit keyword arguments for tokenizer.prepare_for_model.
# prepare_for_model takes the limit as `max_length`; `model_max_length` is the
# tokenizer attribute holding the limit, not an argument name.
maxlenkw = {}
if hasattr(tokenizer, 'model_max_length'):
  maxlenkw['max_length'] = tokenizer.model_max_length
  maxlenkw['truncation'] = True
else:
  # Tokenizers without `model_max_length` expose the limit as `max_len`.
  maxlenkw['max_length'] = tokenizer.max_len

def ids_for_model(ids, model, tokenizer):
  """Flatten per-word id lists and return the tokenizer's model-ready 'input_ids'.

  `ids` is a list of id lists (one per word). `model` is accepted but not
  used. The length limits come from the module-level `maxlenkw`.
  """
  flat_ids = list(itertools.chain.from_iterable(ids))
  return tokenizer.prepare_for_model(flat_ids, return_tensors='pt', **maxlenkw)['input_ids']

print(f'wid {len(wid_src)} x {len(wid_tgt)}')
ids_src, ids_tgt = ids_for_model(wid_src, model, tokenizer), ids_for_model(wid_tgt, model, tokenizer)
print(f'ids {len(ids_src[0])} x {len(ids_tgt[0])}')
print(f'{ids_src}')
print(f'{ids_tgt}')

# For every subword position, the index of the word it came from.
sub2word_map_src = [word_index for word_index, pieces in enumerate(token_src) for _ in pieces]
sub2word_map_tgt = [word_index for word_index, pieces in enumerate(token_tgt) for _ in pieces]


# alignment

def sent_without_startend(batch, sent=0):
  """Return row `sent` of `batch` with its first and last positions dropped."""
  inner_positions = slice(1, -1)
  return batch[sent, inner_positions]
if not USE_AWESOME_ALIGN:
  def hidden(model, ids):
    """Run `model` on `ids` with a leading axis added; return element 2 of its output."""
    model_output = model(ids.unsqueeze(0), output_hidden_states=True)
    return model_output[2]

  def alignvec(batch, align_layer=8, sent=0):
    """Pick entry `align_layer` of `batch` and strip the first/last position of sentence `sent`."""
    layer_states = batch[align_layer]
    return sent_without_startend(layer_states, sent=sent)

  def hiddens(model, ids, align_layer):
    """Hidden vectors of `ids` at `align_layer`, without the first and last positions."""
    return alignvec(hidden(model, ids), align_layer)
else:
  def hiddens(model, ids, align_layer):
    """Output of `model.bert` at `align_layer`, without the first and last positions.

    Positions whose id is 0 are excluded through the attention mask.
    """
    mask = (ids != 0)
    encoded = model.bert(ids, align_layer=align_layer, attention_mask=mask)
    return encoded[:, 1:-1]

# For each layer, extract alignments at successively smaller thresholds and
# print a result only when it differs from the previous threshold's result.
for align_layer in range(8, 9):
  last_align = None
  threshold = 1e-3
  for it in range(6):
    if USE_AWESOME_ALIGN:
      # get_aligned_word handles a batch.
      print(f'{len(ids_src)} x {len(ids_tgt)}')
      align_words = model.get_aligned_word(
          ids_src, ids_tgt,
          (sub2word_map_src,), (sub2word_map_tgt,),
          'cpu',
          len(ids_src), len(ids_tgt),
          align_layer, 'softmax', threshold, True,
      )[0]
    else:
      with torch.no_grad():
        out_src = hiddens(model, ids_src, align_layer)
        out_tgt = hiddens(model, ids_tgt, align_layer)

        # Similarity of every source subword with every target subword.
        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

        softmax_srctgt = torch.softmax(dot_prod, dim=-1)
        softmax_tgtsrc = torch.softmax(dot_prod, dim=-2)
        # TODO: try entmax15(dot_prod, dim=...); also mask off cls/sep/pad
        # tokens before the softmax.

        # Keep a pair only when it clears the threshold in both directions.
        softmax_inter = (softmax_srctgt > threshold) & (softmax_tgtsrc > threshold)

        align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
        # Map subword pairs back to word pairs; the set removes duplicates.
        align_words = {
            (sub2word_map_src[i], sub2word_map_tgt[j])
            for i, j in align_subwords
        }
    align_words = sorted(align_words)
    if align_words != last_align:
      print(f" (layer {align_layer} > {threshold:.3g}) {len(align_words)} links {align_words} for '{src}' to '{tgt}'")
      for i, j in align_words:
        print(f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}==={color.BOLD}{color.RED}{sent_tgt[j]}{color.END}')
    last_align = align_words
    threshold *= 1e-1

wid 14 x 15
ids 17 x 19
tensor([[  101,   146, 28870,   169, 10751, 13000, 12373,   146, 10134, 19090,
         11222,   169, 15607, 57156, 22859,   119,   102]])
tensor([[  101, 16680, 52302, 10333, 10119, 18257, 15249, 16348, 14645, 46481,
         10183, 10153, 22859, 10104, 10109, 16689, 22757,   119,   102]])
1 x 1
 (layer 8 > 0.001) 13 links [(0, 0), (1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (11, 12), (12, 9), (13, 14)] for 'I bought a new car because I was going through a midlife crisis .' to 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'
[1m[94mI[0m===[1m[91mCompré[0m
[1m[94mbought[0m===[1m[91mCompré[0m
[1m[94ma[0m===[1m[91mun[0m
[1m[94mnew[0m===[1m[91mnuevo[0m
[1m[94mcar[0m===[1m[91mauto[0m
[1m[94mbecause[0m===[1m[91mporque[0m
[1m[94mwas[0m===[1m[91mestaba[0m
[1m[94mgoing[0m===[1m[91mpasando[0m
[1m[94mthrough[0m===[1m[91mpor[0m
[1m[94ma[0m===[1m[91muna[0