In [1]:
import torch
from transformers import BertTokenizer

In [2]:
# Tokenizer for the "bert-base-uncased" checkpoint; used below to encode the
# example sentences into token-id tensors.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Example sentences, written as (original, edit) rows so the index pairing
# between the two lists is visible; the rows are then split column-wise into
# `orig_text` and `edit_text`.
orig_text, edit_text = (
    list(column)
    for column in zip(
        ("I like bananas.", "Do you?"),
        ("Yesterday the mailman came by!", "He delivered a mystery package."),
        ("Do you enjoy cookies?", "My grandma just baked some!"),
    )
)

In [3]:
# Length of the final encoded sequence, counting the three special tokens
# ([CLS] and two [SEP]) and any trailing padding.
MAX_LEN = 20

In [4]:
def truncate_and_merge(tokens_a, tokens_b, max_lenth):
    """Trim two token tensors to a shared length budget and concatenate them.

    Entries are dropped from the end of whichever input is currently longer
    (the second input loses the entry on a tie) until the combined length
    along dimension 0 fits within the budget.

    Args:
        tokens_a: First tensor; truncated along dimension 0.
        tokens_b: Second tensor; truncated along dimension 0.
        max_lenth: Maximum combined length of the result along dimension 0.
            (The spelling is kept so existing keyword callers keep working.)

    Returns:
        ``torch.cat`` of the kept prefix of ``tokens_a`` followed by the kept
        prefix of ``tokens_b`` along dimension 0.

    Raises:
        ValueError: If either input is empty, or if meeting the budget would
            require trimming one of the inputs down to nothing.
    """
    keep_a = len(tokens_a)
    keep_b = len(tokens_b)
    while True:
        # Checked before the budget test, so an empty input is rejected even
        # when the combined length already fits.
        if keep_a <= 0 or keep_b <= 0:
            raise ValueError("Concat Error. One of the String Len is 0")
        if keep_a + keep_b <= max_lenth:
            break
        if keep_a > keep_b:
            keep_a -= 1
        else:
            keep_b -= 1

    return torch.cat((tokens_a[:keep_a], tokens_b[:keep_b]), 0)

In [5]:
# Encode the first original/edit pair one sentence at a time, with no special
# tokens and no padding, so the two encodings can be trimmed and joined by hand.
sent_1, sent_2 = orig_text[0], edit_text[0]
sent_1_tokens, sent_2_tokens = (
    tokenizer.encode_plus(
        sentence,
        add_special_tokens=False,
        padding=False,
        return_tensors="pt",
    )
    for sentence in (sent_1, sent_2)
)

In [6]:
# sent_1_tokens["input_ids"].size(1)

In [7]:
# x = truncate_and_merge(sent_1_tokens, sent_2_tokens, 5)
# Token budget left for the two sentences once three positions are reserved
# for the special tokens ([CLS] and two [SEP]).
max_length = MAX_LEN - 3
i = sent_1_tokens["input_ids"].size(1)
j = sent_2_tokens["input_ids"].size(1)
total_length = i + j
while True:
    if not (i > 0 and j > 0):
        raise ValueError("Concat Error. One of the String Len is 0")
    if total_length <= max_length:
        break
    # Trim the longer sentence; on a tie the second sentence loses a token.
    if i > j:
        i -= 1
    else:
        j -= 1
    total_length = i + j
# Whatever is left of the budget is filled with padding.
total_padding_required = max_length - total_length

In [8]:
# Kept token counts for sentence 1 and sentence 2, and the number of padding
# positions still needed to reach MAX_LEN.
i, j, total_padding_required

(4, 3, 10)

In [9]:
# dir(sent_1_tokens)
# Cut every field of each encoding (ids, type ids, mask) down to the number of
# tokens kept for that sentence: `i` for the first, `j` for the second.
for k, v in sent_1_tokens.items():
    sent_1_tokens[k] = v[:, :i]

for k, v in sent_2_tokens.items():
    sent_2_tokens[k] = v[:, :j]

In [10]:
# Inspect the first sentence's encoding after truncation.
sent_1_tokens

{'input_ids': tensor([[ 1045,  2066, 26191,  1012]]), 'token_type_ids': tensor([[0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [11]:
# Inspect the second sentence's encoding after truncation.
sent_2_tokens

{'input_ids': tensor([[2079, 2017, 1029]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}

In [12]:
# For Special Tokens [CLS]=[101] and [SEP]=[102]:
# token_type_ids=0
# attention_mask=1
# For Padding Token [PAD]=0:
# token_type_ids=0
# attention_mask=0

In [13]:
# Assemble the final fixed-length inputs by hand:
#   [CLS] sentence-1 [SEP] sentence-2 [SEP] [PAD]...
concatenated_tokens = {}
zero_t = torch.tensor([[0]])
one_t = torch.tensor([[1]])
cls_t = torch.tensor([[101]])  # [CLS] id
sep_t = torch.tensor([[102]])  # [SEP] id
# Built with an explicit integer dtype: `torch.tensor([[0] * 0])` would be an
# empty *float* tensor, which would turn the concatenated ids and masks into
# floats (or fail) whenever no padding is required.
pad_tokens = torch.zeros((1, total_padding_required), dtype=torch.long)

# Per field: (value placed at the [CLS] position, value placed at each [SEP]).
# The zero-filled `pad_tokens` doubles as the padding for every field, since
# [PAD] has id 0 and padded positions get type id 0 and attention mask 0.
# NOTE(review): every position ends up with token_type_id 0, including the
# second sentence; BERT's own pair encoding marks the second segment with 1 —
# confirm this is intended.
special_fill = {
    "input_ids": (cls_t, sep_t),
    "token_type_ids": (zero_t, zero_t),
    "attention_mask": (one_t, one_t),
}
for k, v in sent_1_tokens.items():
    if k not in special_fill:
        # Fields other than the three above are not carried over.
        continue
    lead_t, boundary_t = special_fill[k]
    concatenated_tokens[k] = torch.cat(
        (lead_t, v, boundary_t, sent_2_tokens[k], boundary_t, pad_tokens),
        dim=1,
    )

In [14]:
# Inspect the assembled ids, type ids and attention mask.
concatenated_tokens

{'input_ids': tensor([[  101,  1045,  2066, 26191,  1012,   102,  2079,  2017,  1029,   102,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [18]:
# Sequence length of each assembled field; all three should equal MAX_LEN.
tuple(
    concatenated_tokens[field].size(1)
    for field in ("input_ids", "token_type_ids", "attention_mask")
)

(20, 20, 20)