In [6]:
!hostname

966187f2b744


In [7]:
import os
from pathlib import Path

print(os.getcwd())
# home = os.path.expanduser("~")
# home_path = Path(home)
# print(home_path)
project_path = Path("/workspace/mlx-week7")
print(project_path)
os.chdir(project_path)

/workspace/mlx-week7
/workspace/mlx-week7


In [8]:
!pwd

/workspace/mlx-week7


In [9]:
import os

os.getcwd()

import dotenv

dotenv.load_dotenv(project_path / ".env")

True

In [10]:
!which python

/root/miniconda3/envs/mlx-week7/bin/python


In [11]:
import torch
import transformers
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    AutoModel,
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

In [12]:
import os

assert os.environ["HF_TOKEN"]

In [13]:
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    hf_token=os.environ["HF_TOKEN"],
    cache_dir=project_path / "cache",
)

In [14]:
tokenizer.special_tokens_map  # NOTE

{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

In [15]:
print(tokenizer.pad_token_id)

None


## Tokenization routines for training and test data generation

In [16]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mixtral-8x7B-Instruct-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [17]:
CUTOFF_LEN = 2048

INST_START_TOKEN = "[INST]"
INST_END_TOKEN = "[/INST]"
NAV_START_TAG = "[NAV]"
NAV_END_TAG = "[/NAV]"
NL = "\n"
LANGUAGE = "English"
REASON_START_TAG = "[REASON]"
REASON_END_TAG = "[/REASON]"
NUM_INSTRUCTIONS = 9
BOS_TOKEN = "<s>"
assert tokenizer.bos_token == BOS_TOKEN
EOS_TOKEN = "</s>"
assert tokenizer.eos_token == EOS_TOKEN

"""
There is a lot of important config here that fix_tokenizer will set up. IN the end it 
will be like:

LlamaTokenizerFast(name_or_path='mistralai/Mixtral-8x7B-Instruct-v0.1', 
    vocab_size=32000, 
    model_max_length=1000000000000000019884624838656,
    is_fast=True, 
    padding_side='left', 
    truncation_side='right', 
    special_tokens={
        'bos_token': '<s>', 
        'eos_token': '</s>', 
        'unk_token': '<unk>', 
        'pad_token': '[PAD]', 
        'additional_special_tokens': ['[INST]', '[/INST]', '[NAV]', '[/NAV]', '[REASON]', '[/REASON]']
    }, 
    clean_up_tokenization_spaces=False),  
    added_tokens_decoder={
            0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32000: AddedToken("[INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32001: AddedToken("[/INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32000: AddedToken("[INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32001: AddedToken("[/INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32002: AddedToken("[NAV]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32003: AddedToken("[/NAV]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32004: AddedToken("[REASON]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32005: AddedToken("[/REASON]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    }
"""


def fix_tokenizer(tokenizer):
    tokenizer.pad_token = "[PAD]"
    # https://huggingface.co/docs/transformers/en/pad_truncation#padding-and-truncation
    # NOTE: special tokens added by default, padding si
    tokenize = lambda prompt: tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding="max_length",
        # THIS IS IMPORTANT for calling the base tokenizer's build_inputs_with_special_tokens as
        # explained https://huggingface.co/docs/transformers/v4.40.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
        # We already add it in the prompt preparation so we don't want to add it twice.
        add_special_tokens=False,
    )

    special_tokens_dict = {
        "additional_special_tokens": [
            INST_START_TOKEN,
            INST_END_TOKEN,
            NAV_START_TAG,
            NAV_END_TAG,
            REASON_START_TAG,
            REASON_END_TAG,
        ]
    }
    tokenizer.add_special_tokens(special_tokens_dict)

    return tokenize


def generate_train_prompt(user_query):
    sys_msg = "Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using {NAV_START_TAG} and {NAV_END_TAG} tags as well as the reason tags:  {REASON_START_TAG} and {REASON_END_TAG}:as shown in the examples that follow:"
    p = (
        f"{BOS_TOKEN} {INST_START_TOKEN} "
        + sys_msg
        + "\n"
        + user_query["chunk"].strip()
        + f" {INST_END_TOKEN} "
        + user_query["navs"].strip()
        + f" {EOS_TOKEN}"
    )
    return p


def generate_test_prompt(user_query):
    sys_msg = f"Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using {NAV_START_TAG} and {NAV_END_TAG} tags as well as the reason tags: {REASON_START_TAG} and {REASON_END_TAG}:"
    p = (
        f"{BOS_TOKEN} {INST_START_TOKEN} "
        + sys_msg
        + "\n"
        + user_query["chunk"].strip()
        + f" {INST_END_TOKEN} "
    )
    return p


tokenize = fix_tokenizer(tokenizer)

## Lets make sure we can encode a test sentence and decode it back the way we want it.

In [18]:
test_prompt = generate_train_prompt(
    {
        "chunk": "Test me",
        "navs": f"{NAV_START_TAG} test nav {NAV_END_TAG}{REASON_START_TAG} some reason {REASON_END_TAG}",
    }
)
print(test_prompt)
test_tokens = tokenize(test_prompt)
print(f"# tokens with padded ids {len(test_tokens['input_ids'])}")
input_ids = torch.tensor(test_tokens["input_ids"])
attn_mask = torch.tensor(test_tokens["attention_mask"])

unmasked_text = torch.masked_select(input_ids, attn_mask.bool())
print(unmasked_text)
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using {NAV_START_TAG} and {NAV_END_TAG} tags as well as the reason tags:  {REASON_START_TAG} and {REASON_END_TAG}:as shown in the examples that follow:
Test me [/INST] [NAV] test nav [/NAV][REASON] some reason [/REASON] </s>
# tokens with padded ids 2048
tensor([    1,   259, 32000, 28705, 12628,   272,  2245,   477,   264,  9863,
          288,  3884,  1820, 18063,   264,  2948,  7103,  1059,   272,  2990,
          302,  4222, 28725,  9131,  5099,   302,   272,  2245,   369,  6685,
         2948, 18132, 11382,   297,   378,  1413,   371,  3384, 28790, 28730,
        12241, 28730, 12137, 28752,   304,   371,  3384, 28790, 28730,  5000,
        28730, 12137, 28752, 12944,   390,  1162,   390,   272,  2611, 12944,
        28747, 28705,   371,   896,  2109,   832, 28730, 12241, 28730, 12137,
        28752

In [19]:
# Now for test tokenization
test_prompt = generate_test_prompt(
    {
        "chunk": "Test me",
    }
)
print(test_prompt)
test_tokens = tokenize(test_prompt)
print(f"# tokens with padded ids {len(test_tokens['input_ids'])}")
input_ids = torch.tensor(test_tokens["input_ids"])
attn_mask = torch.tensor(test_tokens["attention_mask"])

unmasked_text = torch.masked_select(input_ids, attn_mask.bool())
print(unmasked_text)
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using [NAV] and [/NAV] tags as well as the reason tags: [REASON] and [/REASON]:
Test me [/INST] 
# tokens with padded ids 2048
tensor([    1,   259, 32000, 28705, 12628,   272,  2245,   477,   264,  9863,
          288,  3884,  1820, 18063,   264,  2948,  7103,  1059,   272,  2990,
          302,  4222, 28725,  9131,  5099,   302,   272,  2245,   369,  6685,
         2948, 18132, 11382,   297,   378,  1413, 28705, 32002, 28705,   304,
        28705, 32003, 28705, 12944,   390,  1162,   390,   272,  2611, 12944,
        28747, 28705, 32004, 28705,   304, 28705, 32005,   714,    13,  1963,
          528, 28705, 32001,   259])
.
<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instruction

# Load the dataset for walking tours

In [20]:
import json

serialized_data_gpt_four = json.load(open("gpt_four_annotations.json"))
serialized_data_gpt_four_test = json.load(open("gpt_four_annotations_test.json"))
serialized_data_gpt_four[0], serialized_data_gpt_four_test[0]

({'chunk': 'Our route starts at the south side of Victoria Park, at Bonner Hall Bridge – best accessed by getting the Tube to Bethnal Green, a ten-minute walk away. The first portion of this route offers clues about the nature of east London before the arrival of the Regent’s Canal, taking in several buildings that predate the waterways, and whose fates were inexorably changed by its arrival. Before moving on, take a moment to reflect on Victoria Park. In the eighteenth century, this was all open pasture, interspersed with the odd brick kiln and market garden. The one notable feature was Bonner Hall, so called after the sixteenth-century bishop of London Edmund Bonner. All this was to change in the nineteenth century. As London expanded, calls for public parks grew; in 1840, Queen Victoria was presented with a petition signed by 30,000 residents. The Crown estate purchased 218 acres in the area and, over the next few years, converted it into Victoria Park. The park shares a family rese

In [21]:
max(
    [
        len(s["chunk"].strip().split()) + len(s["navs"].strip().split())
        for s in serialized_data_gpt_four
    ]
)

673

# Create the model
[ ] Train for a bit on this dataset
[ ] Save the model.
[ ] Test on some canned held out data from bermondsey street

In [22]:
# del model
# del trainer
import gc

gc.collect()
torch.cuda.empty_cache()

In [23]:
model_save_dir = project_path / "mixtral-moe-lora-instruct-walking-tour-london"

In [24]:
model_save_dir

PosixPath('/workspace/mlx-week7/mixtral-moe-lora-instruct-walking-tour-london')

### Change this flag to reload a pretrained model checkpoint.

In [25]:
load_pretrained = True

In [26]:
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=str(project_path / "cache"),
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

## Resize the model based on the tokenizer

In [27]:
model.resize_token_embeddings(len(tokenizer))

Embedding(32006, 4096)

### Alternatively load the PEFT model from adapter
See: https://huggingface.co/docs/transformers/en/peft#load-a-peft-adapter

In [28]:
# if load_pretrained:
    # This can't work here. Cause the model has a changed shape due to fine tuning we did. 
    # To acually load it like this we would have to save our additional model on github and then 
    # use it
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_save_dir / "best_saved_model",
    #     load_in_4bit=True,
    #     torch_dtype=torch.float16,
    #     device_map="auto",
    #     cache_dir=project_path / "cache",
    # )

    # But Even this is not quite right, we need to load a PEFT model using the LoraConfig
    # and add the saved adapter
    # model.load_adapter(model_save_dir / "best_saved_model")

In [29]:
model

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32006, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_

## Now tokenize the dataset

In [30]:
from datasets import Dataset

walking_tour_dataset = Dataset.from_list(serialized_data_gpt_four)
walking_tour_dataset_test = Dataset.from_list(serialized_data_gpt_four_test)

In [31]:
walking_tour_dataset_tokens = walking_tour_dataset.shuffle().map(
    lambda x: tokenize(generate_train_prompt(x)), remove_columns=["chunk", "navs"]
)
walking_tour_dataset_tokens_test = walking_tour_dataset_test.shuffle().map(
    lambda x: tokenize(generate_test_prompt(x)), remove_columns=["chunk"]
)

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

In [32]:
type(walking_tour_dataset_tokens[0]["input_ids"])

list

## get the PEFT model

In [33]:
# Prepare model for k-bit training
from peft import inject_adapter_in_model
model = prepare_model_for_kbit_training(model)
# model = prepare_model_for_kbit_training(model)
LORA_R = 8
LORA_ALPHA = 2 * LORA_R
LORA_DROPOUT = 0.1
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["w1", "w2", "w3"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
if not load_pretrained:
    model = get_peft_model(model, config)
else:
    # Hmm a shame the in place thing did not work. The issue is that thought it modifies the model 
    # it does not change its type. 
    # inject_adapter_in_model(config, model, str(model_save_dir / "best_saved_model"))
    model = get_peft_model(model, config)


NameError: name 'inject_adapter_in_model' is not defined

In [34]:
from peft import inject_adapter_in_model
inject_adapter_in_model(config, model, str(model_save_dir / "checkpoint-74"))


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MixtralForCausalLM(
      (model): MixtralModel(
        (embed_tokens): Embedding(32006, 4096)
        (layers): ModuleList(
          (0-31): 32 x MixtralDecoderLayer(
            (self_attn): MixtralSdpaAttention(
              (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MixtralRotaryEmbedding()
            )
            (block_sparse_moe): MixtralSparseMoeBlock(
              (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
              (experts): ModuleList(
                (0-7): 8 x MixtralBlockSparseTop2MLP(
                  (w1): lora.Linear4bit(
                    (base_layer): Linear4bit(in_feat

In [None]:
model

# Start the training loop

In [36]:
trainer = Trainer(
    model=model,
    train_dataset=walking_tour_dataset_tokens,
    eval_dataset=walking_tour_dataset_tokens_test,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,

        resume_from_checkpoint=str(model_save_dir),

        num_train_epochs=20,
        learning_rate=1e-4,
        logging_steps=2,
        optim="adamw_torch",
        save_total_limit=5,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        overwrite_output_dir=True,
        output_dir=str(model_save_dir),
        load_best_model_at_end=True,
        # resume_from_checkpoint=str(model_save_dir)
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)


model.config.use_cache = False
#

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [38]:
trainer.train(resume_from_checkpoint=str(model_save_dir/ "checkpoint-74"))



Epoch,Training Loss,Validation Loss
9,1.2283,2.391732
10,0.9941,2.380301
11,0.7543,2.452753
13,0.6589,2.572316


config.json:   0%|          | 0.00/720 [00:00<?, ?B/s]



KeyboardInterrupt: 

In [None]:
model_save_dir

## Save a pretrained model
See https://huggingface.co/docs/transformers/v4.40.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained


In [None]:
trainer.save_model(model_save_dir / "best_saved_model")
# model.save_model(model_save_dir)

In [None]:
trainer.save_state()

In [None]:
import os

os.getcwd()

# Load a model from a checkpoint

In [None]:
project_path

In [None]:
model

In [None]:
del model
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
trainer.train(
    resume_from_checkpoint="mixtral-moe-lora-instruct-walking-tour-london/checkpoint-35"
)

In [None]:
trainer.save_model("mixtral-moe-lora-instruct-walking-tour-london")

# Inference

In [None]:
import json

test_dataset_chunks = json.load(open("./test_dataset.json"))

In [None]:
from datasets import Dataset

test_dataset = Dataset.from_list(test_dataset_chunks)
test_dataset[0:2]

In [None]:
# test_prompt_gen = lambda text: "<s> [INST]" + sys_msg +"\n"+ user_query["chunk"] + "[/INST]" +  user_query["navs"] + "</s>

## The prompt

In [None]:
test_prompt = generate_test_prompt(test_dataset_chunks[0])
print(test_prompt)

## Different tokenization strategies for inference


###  Tokenize using encode method

In [None]:
encoded = tokenizer.encode(test_prompt, add_special_tokens=False)
print(encoded[0:100])
print(tokenizer.decode(encoded))

In [None]:
test_prompt = """<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV>  and </NAV>  tags as well as the reason tags: <REASON>  </REASON>  as shown in the examples that follow
This walk starts at one of the most famous landmarks in Britain: Tower
Bridge."""

In [None]:
test_prompt

In [None]:
test_tokens = tokenize(test_prompt)
print(len(test_tokens["input_ids"]))
test_input_ids = torch.tensor(test_tokens["input_ids"])
test_attn_mask = torch.tensor(test_tokens["attention_mask"])
print(len(test_input_ids), len(test_attn_mask))
unmasked_text = torch.masked_select(test_input_ids, test_attn_mask.bool())
print(unmasked_text[0:100])
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

In [None]:
test_tokens = tokenizer(
    [test_prompt],
    padding="max_length",
    max_length=200,
    truncation=True,
    return_tensors="pt",
    add_special_tokens=False,
).to("cuda")

In [None]:
from tqdm import tqdm


def generate_text(text, max_length=50):
    # Start with initial encoded text
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # Generate text
    for _ in tqdm(range(max_length)):
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # Append the newly generated token ID to the existing input_ids tensor
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        # Print the updated text at each step
        # print("Updated text:", tokenizer.decode(input_ids.squeeze()))
        # print(tokenizer.decode(next_token_id))
        # # Optional: Stop on specific conditions, e.g., end of sentence
        # 1f next_token_id in tokenizer.encode(l'.', '?'

    # Return the final generated text
    return tokenizer.decode(input_ids.squeeze(), skip_special_tokens=False)

In [None]:
test_prompt

In [None]:
print(test_dataset_chunks[0]["chunk"])

In [None]:
op = generate_text(test_dataset_chunks[0]["chunk"])

In [None]:
print(op)

In [None]:
next_token_ids = output.logits[:, -1, :]
torch.argmax(next_token_ids, dim=-1, keepdim=True)

In [None]:
test_prompt

In [None]:
input_with_no_padding = tokenizer(
    test_prompt + f"{NAV_START_TAG}",
    add_special_tokens=False,
)
input_with_no_padding["input_ids"]

In [None]:
input_ids = torch.tensor(input_with_no_padding["input_ids"]).unsqueeze(0)

In [None]:
# Generate text
for _ in tqdm(range(200)):
    with torch.no_grad():
        input_ids = model.generate(
            input_ids,
            max_new_tokens=1,
            do_sample=False,
            temperature=0.1,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )
        token = tokenizer.decode(input_ids[:, -1].item())
        if token == NAV_END_TAG:
            input_ids = torch.cat(
                (
                    input_ids,
                    torch.tensor(
                        [
                            tokenizer.convert_tokens_to_ids(REASON_START_TAG),
                        ]
                    ).unsqueeze(0),
                ),
                dim=-1,
            )
        elif token == REASON_END_TAG:
            input_ids = torch.cat(
                (
                    input_ids,
                    torch.tensor(
                        [
                            tokenizer.convert_tokens_to_ids(NAV_START_TAG),
                        ]
                    ).unsqueeze(0),
                ),
                dim=-1,
            )
        elif token == EOS_TOKEN:
            break
        # next_token_logits = outputs.logits[:, -1, :]
        # next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    # Append the newly generated token ID to the existing input_ids tensor

    # input_ids = torch.cat([input_ids, outputs], dim=-1)

In [None]:
# Now lets do what daniel taught me
print(tokenizer.decode(input_ids.squeeze()))

In [None]:
torch.tensor([tokenizer.convert_tokens_to_ids(NAV_END_TAG)]).unsqueeze(0)

In [None]:
torch.cat(
    (
        input_ids,
        torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(NAV_END_TAG),
                tokenizer.convert_tokens_to_ids(REASON_START_TAG),
            ]
        ).unsqueeze(0),
    ),
    dim=-1,
)

In [None]:
for i in range(-1, -200, -1):
    if tokenizer.decode(input_ids[:, i].item()) == NAV_END_TAG:
        print("..")
    print(tokenizer.decode(input_ids[:, i].item()), end=" ")

In [None]:
input_ids.shape