<a href="https://colab.research.google.com/github/graehl/awesome-align/blob/master/awesome_align_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AWESOME: Aligning Word Embedding Spaces of Multilingual Encoders

[``awesome-align``](https://github.com/neulab/awesome-align) is a tool that can extract word alignments from multilingual BERT (mBERT) and allows you to fine-tune mBERT on parallel corpora for better alignment quality (see [our paper](https://arxiv.org/abs/2101.08231) for more details).

This is a simple demo of how `awesome-align` extracts word alignments from mBERT.

First, install and import the following packages. (Note that the original `awesome-align` tool does not require the `transformers` package.)

In [46]:
!pwd
!git clone https://github.com/graehl/awesome-align.git || (cd awesome-align && git pull)
USE_AWESOME_ALIGN = True

/content
fatal: destination path 'awesome-align' already exists and is not an empty directory.
Already up to date.


In [47]:
!pip install onnx
!pip install onnxruntime
!pip install -r awesome-align/requirements.txt
import sys
sys.path.append('/content/awesome-align')
sys.path.append('/content')
import torch
import itertools

!pip install transformers






In [48]:
# printing
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

def print_align(align_words, desc=''):
    print(f"{desc} {len(align_words)} links {align_words} for '{src}' to '{tgt}'")
    for x in align_words:
      i, j = x
      print(f'{color.BOLD}{color.BLUE}{sent_src[i]}{color.END}==={color.BOLD}{color.RED}{sent_tgt[j]}{color.END}')

Load the multilingual BERT model and its tokenizer.

In [49]:
model_name_or_path='bert-base-multilingual-cased'
model_name = model_name_or_path.split('/')[-1]

import transformers
import awesome_align
from awesome_align import modeling
from awesome_align.configuration_bert import BertConfig
from awesome_align.modeling import BertForMaskedLM
from awesome_align.tokenization_bert import BertTokenizer
from awesome_align.tokenization_utils import PreTrainedTokenizer
from awesome_align.modeling_utils import PreTrainedModel

def init_model_and_tokenizer(
    model_name_or_path,
    config_name = None,
    cache_dir = None,
    tokenizer_name = None,
):
  config_class, model_class, tokenizer_class = BertConfig, BertForMaskedLM, BertTokenizer
  if config_name:
      config = config_class.from_pretrained(config_name, cache_dir=cache_dir)
  elif model_name_or_path:
      config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  else:
      config = config_class()

  if tokenizer_name:
      tokenizer = tokenizer_class.from_pretrained(tokenizer_name, cache_dir=cache_dir)
  elif model_name_or_path:
      tokenizer = tokenizer_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  else:
      raise ValueError(
          "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
          "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
      )

  # pad is actually always 0
  modeling.PAD_ID = tokenizer.pad_token_id
  modeling.CLS_ID = tokenizer.cls_token_id
  modeling.SEP_ID = tokenizer.sep_token_id

  if model_name_or_path:
      model = model_class.from_pretrained(
          model_name_or_path,
          from_tf=bool(".ckpt" in model_name_or_path),
          config=config,
          cache_dir=cache_dir,
      )
  else:
      model = model_class(config=config)

  return model, tokenizer


# True caused, in export to onnx, `Boolean value of Tensor with more than one value is ambiguous`, but we fixed with ExportNthLayer wrapper
if USE_AWESOME_ALIGN:
  model, tokenizer = init_model_and_tokenizer(model_name_or_path)
else:
  model, tokenizer = transformers.AutoModel.from_pretrained(model_name_or_path), transformers.AutoTokenizer.from_pretrained(model_name_or_path)


Input *tokenized* source and target sentences.

In [50]:

import onnx
model.eval()
# just sets mode of model, probably doesn't need to be under no_grad

def extend_mask(attention_mask, dtype=torch.float32):
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
             "Wrong shape for input_ids or attention_mask"
        )
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask

def guess_dtype(model):
  if hasattr(model, 'get_parameter_dtype'):
    return model.get_parameter_dtype()
  elif hasattr(model, 'parameters'):
    return next(model.parameters()).dtype
  else:
    return torch.float32

def make_ones_mask(ids):
  shape = ids.size()
  device = ids.device
  attention_mask = torch.ones(shape, device=device)
  attention_mask[ids==0] = 0
  return attention_mask

def make_extended_mask(ids, dtype=torch.float32):
  attention_mask = make_ones_mask(ids)
  return extend_mask(attention_mask, dtype)

class ExportNthLayer(torch.nn.Module):
    def __init__(self, base_model, align_layer_max=8):
        super().__init__()
        e = base_model.bert if hasattr(base_model, 'bert') else base_model
        self.bert = e
        self.embeddings = e.embeddings
        # For BERT, num_hidden_layers is in config
        self.config = e.config
        self.num_layers = min(e.config.num_hidden_layers, align_layer_max)
        e = e.encoder if hasattr(e, 'encoder') else e
        self.encoder = e
        self.layer = e.layer[:self.num_layers]
        print(f'{self.layer}')

    def forward(self, ids, attention_mask=None, position_ids=None):
      shape = ids.size()
      device = ids.device
      if attention_mask is None:
        attention_mask = make_ones_mask(ids)

      # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
      # ourselves in which case we just need to make it broadcastable to all heads.
      extended_attention_mask = extend_mask(attention_mask, guess_dtype(self.bert))
      input_shape = ids.size()
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
      hidden_states = self.embeddings(ids, token_type_ids=token_type_ids, position_ids=position_ids)

      if self.layer is not None:
        for i, layer in enumerate(self.layer):
          hidden_states = layer(hidden_states, attention_mask=extended_attention_mask)
        return hidden_states
      else:
        return self.bert(ids, attention_mask)

USE_ONNX_OPSET=18 #12-13 fails checker (shape missing). 14 passes checker. 16 has improved bert perf. 18 has best (ok for onnxruntime 1.15 which we use on linux)
def to_onnx(model, onnx_file_path, inputs=['input_ids', 'attention_mask'], outputs=['output'], dynamic=True, batch=True, align_layer=None, opset_version=USE_ONNX_OPSET, return_tensor_names=True):
  captions = {0 : 'batch_size', 1: 'sequence_length'} if batch else {0 : 'sequence_length'}
  dynamic_axes = {}
  if dynamic:
    for k in inputs:
      dynamic_axes[k] = captions
    for k in outputs:
      dynamic_axes[k] = captions

  # Create dummy input data
  batch_size = 1
  sequence_length = 128
  dims = (batch_size, sequence_length) if batch else (sequence_length,)
  inputs_ones = tuple(torch.ones(dims) if x != 'input_ids' else torch.randint(0, model.config.vocab_size, dims) for x in inputs)

  hasbert = hasattr(model, 'bert')
  #print(f'hasbert={hasbert}')
  model = model.bert if hasbert else model
  # TODO: figure out how to do first nth encoder layers for non-awesome-align bert
  model = ExportNthLayer(model, align_layer) if USE_AWESOME_ALIGN and align_layer is not None else model
  # Export the model to ONNX
  torch.onnx.export(
      model,
      inputs_ones, #(input_ids, attention_mask),
      onnx_file_path,
      export_params=True,
      opset_version=opset_version,
      do_constant_folding=True,
      input_names = inputs,
      output_names = outputs,
      dynamic_axes=dynamic_axes,
  )

  if return_tensor_names:
    om = onnx.load(onnx_file_path)

    print('initializers: ...')
    for node in om.graph.initializer[-19:]:
      print(f'{node.name}')
    return list(x.output for x in om.graph.node)
  else:
    return f"Model exported to {onnx_file_path}"


DO_ONNX_EXPORT=True

def onnxpathm(x):
  return f'{model_name}-nlayer{x}.onnx'
align_layer_max=10
onnxpath=onnxpathm(align_layer_max)
if DO_ONNX_EXPORT:
    #onnxruntime.python.tools.transformers.export_onnx_model_from_pt(...)
  if False and USE_AWESOME_ALIGN:
    !CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/content/awesome-align python /content/awesome-align/run_align.py --model_name_or_path=bert-base-multilingual-cased --output_onnx=$onnxpath --max_layer=$align_layer_max
  else:
    names = to_onnx(model, onnxpath, align_layer=align_layer_max)
    print(f'... ({len(names)})')
    for x in names[-199:]: print(str(x))
  !du -h $onnxpath


ModuleList(
  (0-9): 10 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
initializers: ...
onnx::MatMul_1359
onnx::MatMul_1360

In [51]:
COPY_TO_DRIVE=False
def cpdrive(onnxpath):
 if COPY_TO_DRIVE:
  from google.colab import drive
  driveroot = '/content/gdrive'
  drive.mount(driveroot, force_remount=True)
  drivedir=f'{driveroot}/MyDrive'
  subdir=f'{drivedir}/awesome'
  !mkdir -p $subdir
  onnxname=onnxpath.split('/')[-1]
  onnxto=f'{subdir}/{onnxname}'
  print(onnxto)
  print(onnxpath)
  !cp $onnxpath $onnxto
  !du -h $onnxto
cpdrive(onnxpath)

In [52]:
#src = 'I bought a new car because I was going through a midlife crisis .'
tgt = 'Я купил новую тачку , потому что я переживал кризис среднего возраста .'
tgt = 'Compré un auto nuevo porque estaba pasando por una crisis de la mediana edad .'

src = 'Hello world'
tgt = 'Salut le monde'
src = 'I love you'
tgt = "Je t ' aime"
psrc = ''
ssrc = ''
ptgt = ''
stgt = ''
if False:
  psrc = 'He said : " '
  ssrc = ' " .'
  ptgt = "Il a dit : « "
  stgt = " » ."
src = f'{psrc}{src}{ssrc}'
tgt = f'{ptgt}{tgt}{stgt}'
srctgt = f'{src} ||| {tgt}'
fpar = 'srctgt.txt'
with open(fpar, 'w') as f:
  f.write(srctgt)
if False:
  !rm align.txt
  !CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/content/awesome-align python /content/awesome-align/run_align.py --output_file=align.txt --model_name_or_path="$model_name_or_path" --data_file="$fpar" --extraction='softmax' --softmax_threshold=1e-3 --batch_size=32
  !cat align.txt

# pre-processing
def wstok(x): return x.strip().split()
def subwords(xs): return [tokenizer.tokenize(x) for x in xs]
def ids(xs): return [tokenizer.convert_tokens_to_ids(x) for x in xs]
sent_src, sent_tgt = wstok(src), wstok(tgt)
token_src, token_tgt = subwords(sent_src), subwords(sent_tgt)

wid_src, wid_tgt = ids(token_src), ids(token_tgt)
#def tokenizer_max_len(tokenizer): return tokenizer.max_len_single_sentence if hasattr(tokenizer, 'max_len_single_sentence') else tokenizer.model_max_length
maxlenkw = {}
if hasattr(tokenizer, 'model_max_length'):
  maxlenkw['model_max_length'] = tokenizer.model_max_length
  maxlenkw['truncation'] = True
  maxlenkw['padding'] = True
else:
  maxlenkw['max_length'] = tokenizer.max_len

def ids_for_model(ids, model, tokenizer): return tokenizer.prepare_for_model(list(itertools.chain(*ids)), return_tensors='pt', **maxlenkw)['input_ids']
print(f'wid {len(wid_src)} x {len(wid_tgt)}')
ids_src, ids_tgt = ids_for_model(wid_src, model, tokenizer), ids_for_model(wid_tgt, model, tokenizer)
#assert ids_src[0] == tokenizer.cls_token_id
#assert ids_src[-1] == tokenizer.sep_token_id
#assert ids_tgt[0] == tokenizer.cls_token_id
#assert ids_tgt[-1] == tokenizer.sep_token_id
print(f'{token_src}')
print(f'{ids_src}')
print(f'{token_tgt}')
print(f'{ids_tgt}')
sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]

wid 3 x 4
[['I'], ['love'], ['you']]
tensor([[  101,   146, 16138, 13028,   102]])
[['Je'], ['t'], ["'"], ['aime']]
tensor([[  101, 13796,   188,   112, 62691,   102]])


In [53]:
# pre-processing
def wstok(x): return x.strip().split()
def subwords(xs): return [tokenizer.tokenize(x) for x in xs]
def ids(xs): return [tokenizer.convert_tokens_to_ids(x) for x in xs]
sent_src, sent_tgt = wstok(src), wstok(tgt)
token_src, token_tgt = subwords(sent_src), subwords(sent_tgt)

wid_src, wid_tgt = ids(token_src), ids(token_tgt)
#def tokenizer_max_len(tokenizer): return tokenizer.max_len_single_sentence if hasattr(tokenizer, 'max_len_single_sentence') else tokenizer.model_max_length
maxlenkw = {}
if hasattr(tokenizer, 'model_max_length'):
  maxlenkw['model_max_length'] = tokenizer.model_max_length
  maxlenkw['truncation'] = True
else:
  maxlenkw['max_length'] = tokenizer.max_len

def ids_for_model(ids, model, tokenizer): return tokenizer.prepare_for_model(list(itertools.chain(*ids)), return_tensors='pt', **maxlenkw)['input_ids']
print(f'wid {len(wid_src)} x {len(wid_tgt)}')
ids_src, ids_tgt = ids_for_model(wid_src, model, tokenizer), ids_for_model(wid_tgt, model, tokenizer)
print(f'{token_src}')
print(f'{ids_src}')
print(f'{token_tgt}')
print(f'{ids_tgt}')
sub2word_map_src = []
for i, word_list in enumerate(token_src):
  sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
  sub2word_map_tgt += [i for x in word_list]

wid 3 x 4
[['I'], ['love'], ['you']]
tensor([[  101,   146, 16138, 13028,   102]])
[['Je'], ['t'], ["'"], ['aime']]
tensor([[  101, 13796,   188,   112, 62691,   102]])


In [54]:
def encoder_output_layer(x, opset_version=18):
  prefix = '' if USE_AWESOME_ALIGN else '/encoder'
  n = 'LayerNormalization_output_0' if opset_version >= 18 else 'Add_1_output_0'
  return f'{prefix}/layer.7/output/LayerNorm/{n}'
  # /encoder/layer.10/output/LayerNorm/LayerNormalization_output_0

layername = encoder_output_layer(7, USE_ONNX_OPSET)
print(layername)


/layer.7/output/LayerNorm/LayerNormalization_output_0


In [55]:


import onnxruntime as ort
import onnxruntime as rt

align_layer_max = 10
onnxpath = onnxpathm(align_layer_max)
!ls -l $onnxpath

def onnx_inputs(path, inputs=None):
  if True or inputs is None:
    return [x.name for x in onnx.load(path).graph.input]
  return inputs

import onnx

def modify_onnx_outputs(path, onnxpathout, outputs, inputs=None, checker=True):
  onnx.utils.extract_model(path, onnxpathout, onnx_inputs(path, inputs), outputs)
  if checker:
    onnx.checker.check_model(onnxpathout)
  return onnxpathout


def onnxmask(ids):
  return (ids != 0).to(torch.float32)
  #return make_extended_mask(ids)[0, 0, :, :]


def alignpairs(out_src, out_tgt, sub2word_map_src, sub2word_map_tgt, threshold, debug=False):
      dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
      print(f'#src={out_src.size()} {out_src}')
      print(f'#tgt={out_tgt.size()} {out_tgt}')
      print(f'#prod={dot_prod.size()} {dot_prod}')
      softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
      softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)

      # tryalso entmax15(dot_prod, dim=...)? also TODO: before softmax mask off cls sep pad tokens
      srctgt = softmax_srctgt > threshold
      tgtsrc = softmax_tgtsrc > threshold
      softmax_inter = srctgt * tgtsrc
      if debug:
        print(f'> {threshold}:\n {softmax_inter}')

      align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
      align_words = set()
      for xyz in align_subwords:
        i, j = xyz[-2], xyz[-1]
        #print(f'subword: {i}-{j}')
        align_words.add( (sub2word_map_src[i], sub2word_map_tgt[j]) )
      return sorted(list(align_words))

sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL


def onnx_word_encs(session, ids_src):
  if len(ids_src.size()) < 2:
    ids_src = ids_src.reshape(1, -1)
  msrc = onnxmask(ids_src)
  print(f'ids #ids={ids_src.size()} {ids_src}')
  print(f'mask (# {msrc.size()}): {msrc}')
  osrc = torch.tensor(session.run(output_names, {input_names[0]: ids_src.numpy(), input_names[1]: msrc.numpy()})[0][:,1:-1,:])
  print(f'output #in={ids_src.size()} #out={osrc.size()} {osrc}')
  return osrc

#'output',
for use_output in [layername]:
  print(use_output)

  session_path = onnxpath
  if use_output is not None and use_output != 'output':
    onnxpathout = f'{onnxpath}.out.{use_output.replace("/","_")}'
    assert onnxpathout != onnxpath
    print(onnxpathout)
    session_path = modify_onnx_outputs(onnxpath, onnxpathout, outputs=[use_output], inputs=['input_ids', 'attention_mask'], checker=USE_ONNX_OPSET>13)
    cpdrive(onnxpathout)

  !du -h $session_path
  session = ort.InferenceSession(session_path, sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

  input_names = [x.name for x in session.get_inputs()]
  output_names = [x.name for x in session.get_outputs()]
  print(input_names)
  print(output_names)

  osrc = onnx_word_encs(session, ids_src)
  print(f'osrc:{osrc}')
  osrc = onnx_word_encs(session, ids_src)
  print(f'osrc:{osrc}')
  otgt = onnx_word_encs(session, ids_tgt)
  print(f'otgt:{otgt}')
  otgt = onnx_word_encs(session, ids_tgt)
  print(f'otgt:{otgt}')
  del session
  print(f'{osrc.size()} x {otgt.size()}')

  # , 1e-8
  for threshold in [1e-3, 1e-3]:
    align_words = alignpairs(osrc, otgt, sub2word_map_src, sub2word_map_tgt, threshold, debug=True)
    # 13-14 2-1 10-8 0-0 5-4 4-2 12-9 11-12 3-3 8-6 1-0 7-5 9-7
    # [(0, 0), (1, 0), (2, 1), (3, 3), (4, 2), (5, 4), (7, 5), (8, 6), (9, 7), (10, 8), (12, 9), (13, 14)]
    print_align(align_words, f'onnx {use_output} {threshold}')

-rw-r--r-- 1 root root 652498330 May  1 01:22 bert-base-multilingual-cased-nlayer10.onnx
/layer.7/output/LayerNorm/LayerNormalization_output_0
bert-base-multilingual-cased-nlayer10.onnx.out._layer.7_output_LayerNorm_LayerNormalization_output_0
569M	bert-base-multilingual-cased-nlayer10.onnx.out._layer.7_output_LayerNorm_LayerNormalization_output_0




['input_ids', 'attention_mask']
['/layer.7/output/LayerNorm/LayerNormalization_output_0']
ids #ids=torch.Size([1, 5]) tensor([[  101,   146, 16138, 13028,   102]])
mask (# torch.Size([1, 5])): tensor([[1., 1., 1., 1., 1.]])
output #osrc=torch.Size([1, 3, 768]) tensor([[[ 0.8144,  0.2974, -0.0447,  ...,  1.3720,  0.3257,  0.3161],
         [ 0.9315,  0.7958, -1.1756,  ...,  1.5676,  0.7013, -0.3384],
         [ 0.7285,  1.1777, -0.7042,  ...,  0.8818,  0.0295, -0.2367]]])
osrc:tensor([[[ 0.8144,  0.2974, -0.0447,  ...,  1.3720,  0.3257,  0.3161],
         [ 0.9315,  0.7958, -1.1756,  ...,  1.5676,  0.7013, -0.3384],
         [ 0.7285,  1.1777, -0.7042,  ...,  0.8818,  0.0295, -0.2367]]])
ids #ids=torch.Size([1, 5]) tensor([[  101,   146, 16138, 13028,   102]])
mask (# torch.Size([1, 5])): tensor([[1., 1., 1., 1., 1.]])
output #osrc=torch.Size([1, 3, 768]) tensor([[[ 0.8144,  0.2974, -0.0447,  ...,  1.3720,  0.3257,  0.3161],
         [ 0.9315,  0.7958, -1.1756,  ...,  1.5676,  0.7013, -

In [56]:



# alignment

def sents_without_startend(batch): return batch[:, 1:-1]
if False or USE_AWESOME_ALIGN:
  def hiddens(model, ids, align_layer):
    return sents_without_startend(model.bert(ids, align_layer=align_layer, attention_mask=(ids!=0)))
else:
  def hidden(model, ids): return model(ids.unsqueeze(0), output_hidden_states=True)[2]
  def hiddens(model, ids, align_layer):
    return sents_without_startend(hidden(model, ids)[align_layer])


DECODE_AWESOME_ALIGN=False
USE_GAW=DECODE_AWESOME_ALIGN and USE_AWESOME_ALIGN
print(f'USE_AWESOME_ALIGN={USE_AWESOME_ALIGN} decode:{DECODE_AWESOME_ALIGN} USE_GAW:{USE_GAW}')
for align_layer in range(max(0,align_layer_max - 3),align_layer_max+1):
 last_align = None
 threshold = 1e-3
 with torch.no_grad():
   if not USE_GAW:
        out_src = hiddens(model, ids_src, align_layer)
        out_tgt = hiddens(model, ids_tgt, align_layer)
        print(f'{align_layer} src: {out_src.size()} {out_src}\n tgt: {out_tgt.size()} {out_tgt}'
   for it in range(6):
    for repeat in range(2):
      if USE_GAW:
        # get_aligned_word takes a batch.
        # print(f'{len(ids_src[0])} x {len(ids_tgt[0])}')
        align_words = model.get_aligned_word(ids_src, ids_tgt, (sub2word_map_src,), (sub2word_map_tgt,), 'cpu', len(ids_src), len(ids_tgt), align_layer, 'softmax', threshold, True)[0]
      else:
          #pdb.set_trace()
          #out_src = hiddens(model, ids_src, align_layer)
          #out_tgt = hiddens(model, ids_tgt, align_layer)
          #out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
          #out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
          align_words = alignpairs(out_src, out_tgt, sub2word_map_src, sub2word_map_tgt, threshold)

      align_words = sorted(list(align_words))
      if align_words != last_align:
        print_align(align_words,desc = f' (layer {align_layer} > {threshold:.3g})')
      print(f". {threshold} {repeat}\n")

      last_align = align_words
      threshold = threshold * 1e-1

USE_AWESOME_ALIGN=True decode:False USE_GAW:False
#src=torch.Size([1, 3, 768]) tensor([[[ 0.9197, -0.0453,  1.0196,  ...,  1.9355, -0.0070,  0.6302],
         [ 0.6683, -0.0320, -0.1091,  ...,  2.2523,  0.6345,  0.3268],
         [ 0.3858,  0.3171, -0.4968,  ...,  0.9330, -0.0409,  0.3151]]])
#tgt=torch.Size([1, 4, 768]) tensor([[[ 0.2208,  0.5041, -0.5499,  ...,  1.8713,  0.3883,  0.8834],
         [ 0.4861,  0.4773, -0.6770,  ...,  0.5610,  0.3776,  1.5606],
         [ 0.0660,  0.1112,  0.7822,  ...,  0.2562,  0.7453,  2.0105],
         [-0.0746,  0.7246, -0.7382,  ...,  0.9855,  0.4589,  0.1077]]])
#prod=torch.Size([1, 3, 4]) tensor([[[463.5422, 415.5837, 373.3694, 393.3855],
         [350.9561, 418.6241, 359.9503, 413.1902],
         [375.4362, 395.8914, 359.2048, 440.2242]]])
 (layer 7 > 0.001) 3 links [(0, 0), (1, 1), (2, 3)] for 'I love you' to 'Je t ' aime'
[1m[94mI[0m===[1m[91mJe[0m
[1m[94mlove[0m===[1m[91mt[0m
[1m[94myou[0m===[1m[91maime[0m
. 0.001 0

#src=t