In [1]:
from transformers import XLMRobertaTokenizer, RobertaTokenizer, AutoTokenizer
from itertools import chain
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
en_sentence: str = "Hello my good friend! It is good to see you again."
zh_sentence: str = "你好，我的朋友！很高兴再次见到你。"  # translation of en sentence

tok = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
tok2 = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
tok3 = AutoTokenizer.from_pretrained("bert-base-cased")

In [3]:
sep_token: str = "<SEP>"

In [4]:
en_line: list[str] = en_sentence.strip().split()
zh_line: list[str] = zh_sentence.strip().split()

In [5]:
# XLM Tokenizer adds "__" at start of each tokenized word for some reason
# en_tokens: list[list[str]] = [tok.tokenize(word) for word in en_line]
zh_tokens: list[list[str]] = [tok.tokenize(word) for word in zh_line]


In [6]:
# XLM tokenizer is the only one that can process chinese tokens. Roberta will tokenize into strange characters, while Deberta cannot even recognise Chinese characters.
zh_tokens

[['▁', '你好', ',', '我', '的朋友', '!', '很高兴', '再次', '见到', '你', '。']]

In [225]:
# enumerate will process sentence word by word as opposed to char by char
wbwexamples = []
for i, unit in enumerate(en_line):
    wbwexamples.append(
        en_line[:i] + [sep_token] + [en_line[i]] + [sep_token] + en_line[i + 1 :]
    )
print(wbwexamples)  ## lists of lists


[['<SEP>', 'Hello', '<SEP>', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.'], ['Hello', '<SEP>', 'my', '<SEP>', 'good', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.'], ['Hello', 'my', '<SEP>', 'good', '<SEP>', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.'], ['Hello', 'my', 'good', '<SEP>', 'friend!', '<SEP>', 'It', 'is', 'good', 'to', 'see', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', '<SEP>', 'It', '<SEP>', 'is', 'good', 'to', 'see', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', 'It', '<SEP>', 'is', '<SEP>', 'good', 'to', 'see', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', 'It', 'is', '<SEP>', 'good', '<SEP>', 'to', 'see', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', '<SEP>', 'to', '<SEP>', 'see', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', '<SEP>', 'see', '<SEP>', 'you', 'again.'], ['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', 's

In [226]:
for example in wbwexamples:
    print(example)

['<SEP>', 'Hello', '<SEP>', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.']
['Hello', '<SEP>', 'my', '<SEP>', 'good', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.']
['Hello', 'my', '<SEP>', 'good', '<SEP>', 'friend!', 'It', 'is', 'good', 'to', 'see', 'you', 'again.']
['Hello', 'my', 'good', '<SEP>', 'friend!', '<SEP>', 'It', 'is', 'good', 'to', 'see', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', '<SEP>', 'It', '<SEP>', 'is', 'good', 'to', 'see', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', 'It', '<SEP>', 'is', '<SEP>', 'good', 'to', 'see', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', 'It', 'is', '<SEP>', 'good', '<SEP>', 'to', 'see', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', '<SEP>', 'to', '<SEP>', 'see', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', '<SEP>', 'see', '<SEP>', 'you', 'again.']
['Hello', 'my', 'good', 'friend!', 'It', 'is', 'good', 'to', 'see', '<SEP

In [227]:
line_len: int = len(en_line)
for example in wbwexamples:
    assert line_len == len(example) - 2

In [228]:
assert len(wbwexamples) == len(en_line)

In [229]:
en_tokens: list[list[list[str]]] = [
    [tok.tokenize(word) for word in example] for example in wbwexamples
]
zh_tokens: list[list[str]] = [tok.tokenize(word) for word in zh_sentence]

In [230]:
type(en_tokens), type(en_tokens[0]), type(en_tokens[0][0]), type(en_tokens[0][0][0])

(list, list, list, str)

In [231]:
type(zh_tokens), type(zh_tokens[0]), type(zh_tokens[0][0])

(list, list, str)

In [232]:
# for i in range(len(en_tokens)):
#     print(f"Element {i} of highest-level list")
#     for j in range(len(en_tokens[i])):
#         print(f"Element {i, j}, {type(en_tokens[i][j])} of {type(en_tokens[i][j][0])}")
#         print(en_tokens[i][j])

In [233]:
# for i in range(len(zh_tokens)):
#     print(f"Element {i}, {type(zh_tokens[i])} of {type(zh_tokens[i][0])}")
#     print(zh_tokens[i])

In [234]:
en_w2id_1: list[list[list[str]]] = [
    [tok.convert_tokens_to_ids(token) for token in tokens] for tokens in en_tokens
]

In [235]:
# en_w2id_1

In [236]:
en_w2id_2: list[list[list[int]]] = [
    [[tok.convert_tokens_to_ids(token) for token in word] for word in sent]
    for sent in en_tokens
]

In [257]:
en_w2id_2

[[[4426, 294, 21290, 2740],
  [35378],
  [4426, 294, 21290, 2740],
  [759],
  [4127],
  [34391, 38],
  [1650],
  [83],
  [4127],
  [47],
  [1957],
  [398],
  [13438, 5]],
 [[35378],
  [4426, 294, 21290, 2740],
  [759],
  [4426, 294, 21290, 2740],
  [4127],
  [34391, 38],
  [1650],
  [83],
  [4127],
  [47],
  [1957],
  [398],
  [13438, 5]],
 [[35378],
  [759],
  [4426, 294, 21290, 2740],
  [4127],
  [4426, 294, 21290, 2740],
  [34391, 38],
  [1650],
  [83],
  [4127],
  [47],
  [1957],
  [398],
  [13438, 5]],
 [[35378],
  [759],
  [4127],
  [4426, 294, 21290, 2740],
  [34391, 38],
  [4426, 294, 21290, 2740],
  [1650],
  [83],
  [4127],
  [47],
  [1957],
  [398],
  [13438, 5]],
 [[35378],
  [759],
  [4127],
  [34391, 38],
  [4426, 294, 21290, 2740],
  [1650],
  [4426, 294, 21290, 2740],
  [83],
  [4127],
  [47],
  [1957],
  [398],
  [13438, 5]],
 [[35378],
  [759],
  [4127],
  [34391, 38],
  [1650],
  [4426, 294, 21290, 2740],
  [83],
  [4426, 294, 21290, 2740],
  [4127],
  [47],
  [1957]

Sanity check: why are the lists equivalent? Something to do with transformers tokenizers?

In [238]:
assert en_w2id_1 == en_w2id_2

In [239]:
zh_w2id_1: list[list[int]] = [
    [tok.convert_tokens_to_ids(token) for token in tokens] for tokens in zh_tokens
]

In [240]:
zh_w2id_2: list[list[int]] = [tok.convert_tokens_to_ids(token) for token in zh_tokens]

In [241]:
assert zh_w2id_1 == zh_w2id_2

In [242]:
chained_en: list[list[int]] = [list(chain(*ids)) for ids in en_w2id_2]

In [243]:
en_input_ids: list[list[int]] = [
    tok.prepare_for_model(list(chain(*word_ids)), truncation=True, max_length=256)[
        "input_ids"
    ]
    for word_ids in en_w2id_2
]

In [260]:
# en_input_ids

In [245]:
en_att_mask = [
    tok.prepare_for_model(list(chain(*word_ids)), truncation=True, max_length=256)[
        "attention_mask"
    ]
    for word_ids in en_w2id_2
]

In [246]:
en_att_mask

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

In [247]:
en_input_ids[0]

[0,
 4426,
 294,
 21290,
 2740,
 35378,
 4426,
 294,
 21290,
 2740,
 759,
 4127,
 34391,
 38,
 1650,
 83,
 4127,
 47,
 1957,
 398,
 13438,
 5,
 2]

In [262]:
en_input_ids_pt: torch.Tensor = torch.Tensor(
    [
        tok.prepare_for_model(list(chain(*word_ids)), truncation=True, max_length=256)[
            "input_ids"
        ]
        for word_ids in en_w2id_2
    ]
)

# torch.ones_like(en_input_ids_pt[1])
en_input_ids_pt

tensor([[0.0000e+00, 4.4260e+03, 2.9400e+02, 2.1290e+04, 2.7400e+03, 3.5378e+04,
         4.4260e+03, 2.9400e+02, 2.1290e+04, 2.7400e+03, 7.5900e+02, 4.1270e+03,
         3.4391e+04, 3.8000e+01, 1.6500e+03, 8.3000e+01, 4.1270e+03, 4.7000e+01,
         1.9570e+03, 3.9800e+02, 1.3438e+04, 5.0000e+00, 2.0000e+00],
        [0.0000e+00, 3.5378e+04, 4.4260e+03, 2.9400e+02, 2.1290e+04, 2.7400e+03,
         7.5900e+02, 4.4260e+03, 2.9400e+02, 2.1290e+04, 2.7400e+03, 4.1270e+03,
         3.4391e+04, 3.8000e+01, 1.6500e+03, 8.3000e+01, 4.1270e+03, 4.7000e+01,
         1.9570e+03, 3.9800e+02, 1.3438e+04, 5.0000e+00, 2.0000e+00],
        [0.0000e+00, 3.5378e+04, 7.5900e+02, 4.4260e+03, 2.9400e+02, 2.1290e+04,
         2.7400e+03, 4.1270e+03, 4.4260e+03, 2.9400e+02, 2.1290e+04, 2.7400e+03,
         3.4391e+04, 3.8000e+01, 1.6500e+03, 8.3000e+01, 4.1270e+03, 4.7000e+01,
         1.9570e+03, 3.9800e+02, 1.3438e+04, 5.0000e+00, 2.0000e+00],
        [0.0000e+00, 3.5378e+04, 7.5900e+02, 4.1270e+03, 4.42

In [None]:
zh_input_ids: list[int] = tok.prepare_for_model(
    list(chain(*zh_w2id_2)), truncation=True, max_length=256
)["input_ids"]

In [250]:
zh_input_ids

[0,
 73675,
 6,
 1322,
 6,
 4,
 13129,
 6,
 43,
 6,
 182529,
 6,
 16157,
 711,
 6,
 2165,
 6,
 1395,
 6,
 29738,
 6,
 2058,
 6,
 4465,
 6,
 11415,
 6,
 789,
 73675,
 6,
 30,
 2]

In [251]:
len(en_input_ids)

11

In [252]:
en_input_ids_pt.shape

torch.Size([11, 23])

In [None]:
zh_input_ids: torch.Tensor = tok.prepare_for_model(
    list(chain(*zh_w2id_2)), truncation=True, max_length=256, return_tensors="pt"
)["input_ids"][1:]

In [254]:
zh_input_ids_pt = zh_input_ids.repeat(len(en_input_ids), 1)
zh_input_ids_pt.shape

torch.Size([11, 31])

In [255]:
zh_bpe2word = []
# recall: zh_tokens is list[list[str]]
for k, word_list in enumerate(zh_tokens):
    # word_list is list[str]
    zh_bpe2word += [k for _ in word_list]

In [256]:
zh_bpe2word

[0,
 1,
 1,
 2,
 2,
 3,
 4,
 4,
 5,
 5,
 6,
 6,
 7,
 8,
 8,
 9,
 9,
 10,
 10,
 11,
 11,
 12,
 12,
 13,
 13,
 14,
 14,
 15,
 16,
 16]

In [263]:
en_input_ids[0]

[0,
 4426,
 294,
 21290,
 2740,
 35378,
 4426,
 294,
 21290,
 2740,
 759,
 4127,
 34391,
 38,
 1650,
 83,
 4127,
 47,
 1957,
 398,
 13438,
 5,
 2]

In [268]:
en_input_ids_pt.shape

torch.Size([11, 23])