In [5]:
!hostname

e87c14e3184b


In [1]:
import os
from pathlib import Path

print(os.getcwd())
# home = os.path.expanduser("~")
# home_path = Path(home)
# print(home_path)
project_path = Path("/workspace/mlx-week7")
# project_path = Path(os.path.expanduser("~") + "/workspace/mlx-week7")
print(project_path)
os.chdir(project_path)

/workspace/mlx-week7
/workspace/mlx-week7


In [2]:
!pwd

/workspace/mlx-week7


In [3]:
import os

os.getcwd()

import dotenv

dotenv.load_dotenv(project_path / ".env")

True

In [4]:
import torch
import transformers
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    AutoModel,
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

In [6]:
import os

assert os.environ["HF_TOKEN"]

In [30]:
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    hf_token=os.environ["HF_TOKEN"],
    cache_dir=project_path / "cache",
)

In [31]:
tokenizer.special_tokens_map  # NOTE

{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

In [8]:
print(tokenizer.pad_token_id)

None


## Tokenization routines for training and test data generation

In [8]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mixtral-8x7B-Instruct-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [9]:
INST_START_TOKEN = "[INST]"
INST_END_TOKEN = "[/INST]"
NAV_START_TAG = "[NAV]"
NAV_END_TAG = "[/NAV]"
NL = "\n"
LANGUAGE = "English"
REASON_START_TAG = "[REASON]"
REASON_END_TAG = "[/REASON]"
NUM_INSTRUCTIONS = 9
BOS_TOKEN = "<s>"

In [32]:
CUTOFF_LEN = 2048


assert tokenizer.bos_token == BOS_TOKEN
EOS_TOKEN = "</s>"
assert tokenizer.eos_token == EOS_TOKEN

"""
There is a lot of important config here that fix_tokenizer will set up. IN the end it 
will be like:

LlamaTokenizerFast(name_or_path='mistralai/Mixtral-8x7B-Instruct-v0.1', 
    vocab_size=32000, 
    model_max_length=1000000000000000019884624838656,
    is_fast=True, 
    padding_side='left', 
    truncation_side='right', 
    special_tokens={
        'bos_token': '<s>', 
        'eos_token': '</s>', 
        'unk_token': '<unk>', 
        'pad_token': '[PAD]', 
        'additional_special_tokens': ['[INST]', '[/INST]', '[NAV]', '[/NAV]', '[REASON]', '[/REASON]']
    }, 
    clean_up_tokenization_spaces=False),  
    added_tokens_decoder={
            0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32000: AddedToken("[INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32001: AddedToken("[/INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32000: AddedToken("[INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32001: AddedToken("[/INST]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32002: AddedToken("[NAV]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32003: AddedToken("[/NAV]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32004: AddedToken("[REASON]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
            32005: AddedToken("[/REASON]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    }
"""


def fix_tokenizer(tokenizer):
    tokenizer.pad_token = "[PAD]"
    # https://huggingface.co/docs/transformers/en/pad_truncation#padding-and-truncation
    # NOTE: special tokens added by default, padding si
    tokenize = lambda prompt: tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding="max_length",
        # THIS IS IMPORTANT for calling the base tokenizer's build_inputs_with_special_tokens as
        # explained https://huggingface.co/docs/transformers/v4.40.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
        # We already add it in the prompt preparation so we don't want to add it twice.
        add_special_tokens=False,
    )

    special_tokens_dict = {
        "additional_special_tokens": [
            # INST_START_TOKEN,
            # INST_END_TOKEN,
            NAV_START_TAG,
            NAV_END_TAG,
            REASON_START_TAG,
            REASON_END_TAG,
        ]
    }
    # tokenizer.add_special_tokens(special_tokens_dict)

    return tokenize


def generate_train_prompt(user_query):
    sys_msg = "Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using {NAV_START_TAG} and {NAV_END_TAG} tags as well as the reason tags:  {REASON_START_TAG} and {REASON_END_TAG}:as shown in the examples that follow:"
    p = (
        f"{BOS_TOKEN} {INST_START_TOKEN} "
        + sys_msg
        + "\n"
        + user_query["chunk"].strip()
        + f" {INST_END_TOKEN} "
        + user_query["navs"].strip()
        + f" {EOS_TOKEN}"
    )
    return p


def generate_test_prompt(user_query, sys_msg):
    p = (
        f"{BOS_TOKEN} {INST_START_TOKEN} "
        + sys_msg
        + "\n"
        + user_query["chunk"].strip()
        + f" {INST_END_TOKEN} "
    )
    return p


tokenize = fix_tokenizer(tokenizer)

In [33]:
tokenizer.special_tokens_map

{'bos_token': '<s>',
 'eos_token': '</s>',
 'unk_token': '<unk>',
 'pad_token': '[PAD]'}

# The prompt

In [34]:
NAV_TYPES = (
    """
T1. Navigation instructions always imply to walk one way or another with phrases.

One example of T1
```
One of the more charming sections stands round the corner from Tooley Street, in St Thomas Street and Crucifix Lane.
```
Another example is:
```
Fitting, then, that at the time the neighbouring 191 Bermondsey Street was the rectory for St Mary Magdalen church, which stands a little down the road.
```

""",
    """
T2. Navigation instructions are of different lengths and range over several
sentences. Make sure you capture the entirity of these sentence within the `Nav
Tags`.
""",
    """
T4. Often the instructions mix North, South East, West along with the other Types. These are valid and should be included in the `Nav Tags`. 

An example is:
```
From the south-western end of Shad Thames, go west along Tooley Street and cross Tower Bridge Road.
```
""",
    """
T5. Often the instructions include traveling across other streets, buildings or
junctions. These should be included in the `Nav Tags`. An example is:
```
Leave the churchyard by the gate on Abbey Street, which runs perpendicular to Bermondsey Street.
```
""",
    """

T3. *If and ONLY if* you find a valid navigation in T1,2,4,5,6 above, you may have a visual markers or features specified with the navigation instructions that can help a user but are useless in of themselves. Include these in those `Nav Tags`.

So an example of combination of T2 and T3:
```
The route to take, though, is the thoroughfare that intersects with Wheler
Street.  This is Quaker Street.  On the north side is a row of gabled former
railway warehouses dating from the late nineteenth century and now gutted and
being converted into an economy hotel; on the south side is an early
twentieth-century block of industrial dwellings, portions of the former Truman’s
Brewery, and a large interwar public-housing block named Wheler House.
```
The `gabled former railway warehouses dating from the late nineteenth century and now gutted and
being converted into an economy hotel` is the snippet that is a visual marker and should be included in the `Nav Tag`.

""",
)
START_TAG = "[NAV]"
END_TAG = "[/NAV]"
NUM_TYPES = len(NAV_TYPES)
NL = "\n"
LANGUAGE = "English"
REASON_START_TAG = "[REASON]"
REASON_END_TAG = "[/REASON]"
NUM_INSTRUCTIONS = 9
EXAMPLES = [
    f"""
    Example 1:
    ```
    From the south-western end of Shad Thames, go west along Tooley Street and cross
    Tower Bridge Road. This portion of the walk takes us through the mercantile hub
    of Victorian Bermondsey – from London Bridge station, via the warehouses of
    Bermondsey Street district, to the centre of London’s leather industry. Our
    first port of call is the doleful wastes of Potters Field Park, now a windswept
    and unlovely public space including a somewhat trampled lawn and ‘amphitheatre’.

    ```

    Output 1:
    ```
    {START_TAG}From the south-western end of Shad Thames, go west along Tooley Street and cross
    Tower Bridge Road.{END_TAG}  
    {REASON_START_TAG} The annotation was done  because it seems to have directions (N, S, E, W) and T5 as well where "cross Tower Bridge" follows the pattern of crossing a street. {REASON_END_TAG}
    {START_TAG} Our first port of call is the doleful wastes of Potters Field Park, {END_TAG} 
    {REASON_START_TAG}  The  annotation was done because it has an implicit walk direction to follow and stop at Potters field. {REASON_END_TAG}
    ```
    Note how the each NAV tag was followed by its REASON tag.
    Note that the second sentence in the Example 1: above:
        > This portion of the walk takes us through the mercantile hub
        of Victorian Bermondsey – from London Bridge station, via the 
        warehouses of Bermondsey Street district, to the centre of London’s leather industry.

    does not follow type T3 close enough and thus does not have the `Nav Tags`. 
    """,
    f"""
    Example 2:
    ```
    Walk south-west along More London Place, a geometrical and not unpleasing sliver
    of a passage, which leads back to Tooley Street and the remains of a more
    vigorous and muscular world. Tooley Street was a great mercantile thoroughfare
    in the nineteenth century, lined with ware-houses, offices and railway
    structures related to London Bridge station which sits, at high level,
    immediately to its south.
    ```
    Output 2:
    ```
    {START_TAG}Walk south-west along More London Place, a geometrical and not unpleasing sliver
    of a passage, which leads back to Tooley Street and the remains of a more
    vigorous and muscular world.{END_TAG}
    {REASON_START_TAG} This was annotated because it asks you to walk south-west towards tooly steet and 
    there is also a visual feature: `unpleasing sliver of a passage`  to guide you there{REASON_END_TAG}
    ```
    Note that sentence 2 in the Example text is not telling you how to get to Tooley Street but instead is descibing something about offices
    and railway structures and the London Bridge station; Even though it says "immediately to its south." its something to 
    be seen but not a navigation instruction.
""",
]

NEGATIVE_EXAMPLES = [
    # Examples that will have no valid `Nav Tags`, May or may not need to be included.
    f"""

    Example text 1:
    ```
    When faced with a bomb packed with explosives, the relatively thin brickwork is
    horribly vulnerable and easily penetrated; on the night of 25 October 1940,
    during the Blitz, a bomb crashed through the roof a little to the east of this
    spot, at the intersection of Tanner Street and Druid Street. Seventy-seven of
    the people sheltering inside were killed.
    ```
    """,
    f"""

    Example text 2: 
    ```
    Samuel Beazley, born 1786, was one of the most famed and productive individuals
    in London theatre, writing nearly a hundred plays and designing and enlarging a
    number of theatres, including two notable structures of the 1830s: the long-lost
    neoclassical City of London Theatre in Norton Folgate, Spitalfields, and the
    cast-iron Ionic colonnade that still embellishes the Drury Lane Theatre in
    Covent Garden.
    ```

    In BOTH the above examples the Output is empty because they match no navigation Types in the text satisfying 
    the T1-T{NUM_TYPES} types mentioned. SKIP TEXT SNIPPETS and don't add delimiters to them.
    """,
]

PROMPT = f"""
You are an expert at annotating {LANGUAGE} natural language text I give you with tags per my instructions below. 


# Context
I will give you some text from a Walking tour book referring to a specific route through the city of London. 
The task is to annotate parts of the text that describe specific navigation instructions.


# Instructions:
I want you to follow these instructions to do this:
1. Place a start and end tag {START_TAG}, {END_TAG} delimiting the navigation instruction. lets call these `Nav Tags` for future reference.

2.  Below are the Types of navigation tags that you need to use to annotate the text I provide.  The types are defined between the `----` delims for clarity:
----
{NL.join(NAV_TYPES)}
----

Lets call them `T1`, `T2`... `TN` for N navigation types - above we have N={NUM_TYPES}  -  I will use these to refer in the examples below.

3. The types of navigation instructions are not mutually exclusive. You may find multiple Types(T1-T{NUM_TYPES}) in a single navigation instruction.

4. If the text only talks about a place but not how to get there or what to do there, it is not a navigation instruction and should not be annotated.

5. There are likely several of `Nav Tags` in the text so try your best to find them all.

6. Output the EXACT text that is between the `Nav Tags` including any punctuation, capitalization, and line breaks. DO NOT SUMMARIZE IT.

7. You MUST include a reason for each `Nav Tag` in the response be one of the {NUM_TYPES} types of navigation instructions and that the text between the `Nav Tags` satisfies that criteria.

8. Make sure each NAV tag delimited text is followed by its individual REASON tag. So if you output, say, N {START_TAG}...{END_TAG} pairs you should have N {REASON_START_TAG}...{REASON_END_TAG} pairs as well.

9. When you phrase the reason text don't mention T1, T2.. T{NUM_TYPES} instead just expand those reasons inline.

Make sure 200% that you follow **all** the {NUM_INSTRUCTIONS} instructions above to the letter. 


 # Examples:
    Here are two examples of the text I will provide and the expected output and some exaplanation. 
    Note that the below examples and the outpouts are surrounded by ``` delimiters and the rest is clarifications:
===============START SECTION OF INPUT EXAMPLES AND EXPECTED OUTPUT====================
 {NL.join(EXAMPLES)}
===============END SECTION OF INPUT EXAMPLES AND EXPECTED OUTPUT====================
# Given the above annotated the following text
{{PASSAGE}}

"""

## Lets make sure we can encode a test sentence and decode it back the way we want it.

In [21]:
test_prompt = generate_test_prompt(
    {
        "chunk": "Test me",
        "navs": f"{NAV_START_TAG} test nav {NAV_END_TAG}{REASON_START_TAG} some reason {REASON_END_TAG}",
    },
    "Ssystem prompt",
)
print(test_prompt)
test_tokens = tokenize(test_prompt)
print(f"# tokens with padded ids {len(test_tokens['input_ids'])}")
input_ids = torch.tensor(test_tokens["input_ids"])
attn_mask = torch.tensor(test_tokens["attention_mask"])

unmasked_text = torch.masked_select(input_ids, attn_mask.bool())
print(unmasked_text)
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

<s> [INST] Ssystem prompt
Test me [/INST] 
# tokens with padded ids 2048
tensor([    1, 28705,   733, 16289, 28793,   318,  6574, 11510,    13,  1963,
          528,   733, 28748, 16289, 28793, 28705])
.
<s>  [INST] Ssystem prompt
Test me [/INST] 


In [22]:
# Now for test tokenization
test_prompt = generate_test_prompt(
    {
        "chunk": "Test me",
    },
    "System Prompt",
)
print(test_prompt)
test_tokens = tokenize(test_prompt)
print(f"# tokens with padded ids {len(test_tokens['input_ids'])}")
input_ids = torch.tensor(test_tokens["input_ids"])
attn_mask = torch.tensor(test_tokens["attention_mask"])

unmasked_text = torch.masked_select(input_ids, attn_mask.bool())
print(unmasked_text)
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

<s> [INST] System Prompt
Test me [/INST] 
# tokens with padded ids 2048
tensor([    1, 28705,   733, 16289, 28793,  2135, 12948,   447,    13,  1963,
          528,   733, 28748, 16289, 28793, 28705])
.
<s>  [INST] System Prompt
Test me [/INST] 


# Load the dataset for walking tours

In [22]:
import json

serialized_data_gpt_four = json.load(open("./gpt_four_annotations.json"))
serialized_data_gpt_four_test = json.load(
    open("./gpt_four_annotations_test.json")
)
serialized_data_gpt_four[0], serialized_data_gpt_four_test[0]

({'chunk': 'Our route starts at the south side of Victoria Park, at Bonner Hall Bridge – best accessed by getting the Tube to Bethnal Green, a ten-minute walk away. The first portion of this route offers clues about the nature of east London before the arrival of the Regent’s Canal, taking in several buildings that predate the waterways, and whose fates were inexorably changed by its arrival. Before moving on, take a moment to reflect on Victoria Park. In the eighteenth century, this was all open pasture, interspersed with the odd brick kiln and market garden. The one notable feature was Bonner Hall, so called after the sixteenth-century bishop of London Edmund Bonner. All this was to change in the nineteenth century. As London expanded, calls for public parks grew; in 1840, Queen Victoria was presented with a petition signed by 30,000 residents. The Crown estate purchased 218 acres in the area and, over the next few years, converted it into Victoria Park. The park shares a family rese

In [23]:
def generate_baselining_prompt(sys_msg, user_query):
    p = (
        f"{BOS_TOKEN} {INST_START_TOKEN} "
        + sys_msg.format(PASSAGE=user_query.strip())
        + f" {INST_END_TOKEN} "
    )
    return p

1. Create some passages with the prompt above as demonstration: Add INST tags to it.

Chat template: ```<s>[INST] Instruction [/INST] Model answer</s>[INST] Follow-up instruction [/INST]```

My prompt would be ```<s> [INST] SYSTEM_PROMPT(all the examples everything) and an example passage.[/INST]```

2. Lets test that.

In [24]:
# Using the prompt as a system message we want to see how well mistral can performn on a few test passages I have
baseline_example1 = generate_baselining_prompt(
    PROMPT, serialized_data_gpt_four[0]["chunk"]
)  # answer is serialized_data_gpt_four[0]['navs']
print(baseline_example1)

<s> [INST] 
You are an expert at annotating English natural language text I give you with tags per my instructions below. 


# Context
I will give you some text from a Walking tour book referring to a specific route through the city of London. 
The task is to annotate parts of the text that describe specific navigation instructions.


# Instructions:
I want you to follow these instructions to do this:
1. Place a start and end tag [NAV], [/NAV] delimiting the navigation instruction. lets call these `Nav Tags` for future reference.

2.  Below are the Types of navigation tags that you need to use to annotate the text I provide.  The types are defined between the `----` delims for clarity:
----

T1. Navigation instructions always imply to walk one way or another with phrases.

One example of T1
```
One of the more charming sections stands round the corner from Tooley Street, in St Thomas Street and Crucifix Lane.
```
Another example is:
```
Fitting, then, that at the time the neighbouring 

In [14]:
max(
    [
        len(s["chunk"].strip().split()) + len(s["navs"].strip().split())
        for s in serialized_data_gpt_four
    ]
)

673

# Create the model
[ ] Train for a bit on this dataset
[ ] Save the model.
[ ] Test on some canned held out data from bermondsey street

In [15]:
# del model
# del trainer
import gc

gc.collect()
torch.cuda.empty_cache()

In [16]:
model_save_dir = project_path / "mixtral-moe-lora-instruct-walking-tour-london"

In [17]:
model_save_dir

PosixPath('/workspace/mlx-week7/mixtral-moe-lora-instruct-walking-tour-london')

### Change this flag to reload a pretrained model checkpoint.

In [18]:
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=str(project_path / "cache"),
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [25]:
model

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_

In [40]:
# Inference 
encoded_text = tokenizer(baseline_example1, return_tensors="pt", add_special_tokens=False)
model_inputs = encoded_text.to("cuda")
# model_inputs 



In [43]:
output = model.generate(**model_inputs,
                            max_length=4000,
                            use_cache=True,
                            early_stopping=True,
                            bos_token_id=model.config.bos_token_id,
                            eos_token_id=model.config.eos_token_id,
                            pad_token_id=model.config.eos_token_id,
                            temperature=0.1,
                            do_sample=True)



In [47]:
print(tokenizer.batch_decode(output)[0])


<s>  [INST] 
You are an expert at annotating English natural language text I give you with tags per my instructions below. 


# Context
I will give you some text from a Walking tour book referring to a specific route through the city of London. 
The task is to annotate parts of the text that describe specific navigation instructions.


# Instructions:
I want you to follow these instructions to do this:
1. Place a start and end tag [NAV], [/NAV] delimiting the navigation instruction. lets call these `Nav Tags` for future reference.

2.  Below are the Types of navigation tags that you need to use to annotate the text I provide.  The types are defined between the `----` delims for clarity:
----

T1. Navigation instructions always imply to walk one way or another with phrases.

One example of T1
```
One of the more charming sections stands round the corner from Tooley Street, in St Thomas Street and Crucifix Lane.
```
Another example is:
```
Fitting, then, that at the time the neighbouring

In [49]:
print(serialized_data_gpt_four[0]['navs'])

[NAV] Our route starts at the south side of Victoria Park, at Bonner Hall Bridge – best accessed by getting the Tube to Bethnal Green, a ten-minute walk away. [/NAV] [REASON] This was annotated because it provides specific instructions on where to start the route and how to get there, implying a direction to walk. [/REASON] [NAV] For those ready to plunge into Georgian industrial architecture, however, it is best to walk to the western end of the park and through the canal gate on to the towpath. [/NAV] [REASON] This was annotated because it provides specific instructions on where to walk next, indicating a direction to follow. [/REASON] 


## get the PEFT model

In [None]:
load_pretrained = False

In [26]:
# Prepare model for k-bit training
from peft import inject_adapter_in_model

model = prepare_model_for_kbit_training(model)
# model = prepare_model_for_kbit_training(model)
LORA_R = 8
LORA_ALPHA = 2 * LORA_R
LORA_DROPOUT = 0.1
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["w1", "w2", "w3"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
if not load_pretrained:
    model = get_peft_model(model, config)
else:
    # Hmm a shame the in place thing did not work. The issue is that thought it modifies the model
    # it does not change its type.
    # inject_adapter_in_model(config, model, str(model_save_dir / "best_saved_model"))
    model = get_peft_model(model, config)
    inject_adapter_in_model(config, model, str(model_save_dir / "best_saved_model"))

# Start the training loop

In [33]:
trainer = Trainer(
    model=model,
    train_dataset=walking_tour_dataset_tokens,
    eval_dataset=walking_tour_dataset_tokens_test,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        # To resume from training you need to increaswe the number of epochs.
        # https://github.com/huggingface/transformers/blob/8c12690cecbb97e187861e386f7a0ac790e4236c/src/transformers/trainer.py#L2069
        num_train_epochs=10,
        learning_rate=1e-4,
        logging_steps=2,
        optim="adamw_torch",
        save_total_limit=5,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        overwrite_output_dir=True,
        output_dir=str(model_save_dir),
        load_best_model_at_end=True,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)


model.config.use_cache = False
# trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [29]:
model_save_dir

PosixPath('/workspace/mlx-week7/mixtral-moe-lora-instruct-walking-tour-london')

## Save a pretrained model
See https://huggingface.co/docs/transformers/v4.40.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained


In [30]:
trainer.save_model(model_save_dir / "best_saved_model")
# model.save_model(model_save_dir)

config.json:   0%|          | 0.00/720 [00:00<?, ?B/s]



In [31]:
trainer.save_state()

In [87]:
import os

os.getcwd()

'/root/mlx_week_7'

# Load a model from a checkpoint

In [20]:
project_path

PosixPath('/workspace/mlx-week7')

In [28]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MixtralForCausalLM(
      (model): MixtralModel(
        (embed_tokens): Embedding(32006, 4096)
        (layers): ModuleList(
          (0-31): 32 x MixtralDecoderLayer(
            (self_attn): MixtralSdpaAttention(
              (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MixtralRotaryEmbedding()
            )
            (block_sparse_moe): MixtralSparseMoeBlock(
              (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
              (experts): ModuleList(
                (0-7): 8 x MixtralBlockSparseTop2MLP(
                  (w1): lora.Linear4bit(
                    (base_layer): Linear4bit(in_feat

In [28]:
model_save_dir

PosixPath('/workspace/mlx-week7/mixtral-moe-lora-instruct-walking-tour-london')

In [35]:
trainer.train(resume_from_checkpoint=str(model_save_dir / "checkpoint-48"))



Epoch,Training Loss,Validation Loss
6,1.3278,2.30953
7,1.194,2.374886
8,1.186,2.352052
9,1.2756,2.360156




TrainOutput(global_step=80, training_loss=0.5246684074401855, metrics={'train_runtime': 2153.1202, 'train_samples_per_second': 0.163, 'train_steps_per_second': 0.037, 'total_flos': 1.8401815833870336e+17, 'train_loss': 0.5246684074401855, 'epoch': 9.657142857142857})

In [None]:
trainer.train(resume_from_checkpoint=str(model_save_dir / "checkpoint-48"))

In [89]:
trainer.save_model("mixtral-moe-lora-instruct-walking-tour-london")


Cannot access gated repo for url https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/resolve/main/config.json.
Repo model mistralai/Mixtral-8x7B-Instruct-v0.1 is gated. You must be authenticated to access it. - silently ignoring the lookup for the file config.json in mistralai/Mixtral-8x7B-Instruct-v0.1.


# Inference

In [53]:
import json

test_dataset_chunks = json.load(open("./test_dataset.json"))

In [54]:
from datasets import Dataset

test_dataset = Dataset.from_list(test_dataset_chunks)
test_dataset[0:2]

{'chunk': ['This walk starts at one of the most famous landmarks in Britain: Tower\nBridge. This bridge is a rare thing – a much-admired structure of great\ncharacter, even though it was forged through endless committee meetings,\ncompromise, contradiction, uncertainty and a fair degree of absurdity. It has a\nstory worth considering at some length as you ascend from Tower Hill tube and\nstroll south across the bridge. The tale starts in the mid 1870s, when it was\nargued that a river crossing to the east of London Bridge would both improve\ncommunications in the City and, by providing direct access to industry on the\nsouth bank, increase the commercial potential of the docks on the north. There\nwas, for practical reasons, only one site possible – a strip of land just east\nof the Tower of London – and a competition to find a suitable design was held.\nThis was supervised by the City of London Corporation, which grabbed the\ninitiative to build the new bridge. There was one key stipu

In [14]:
# test_prompt_gen = lambda text: "<s> [INST]" + sys_msg +"\n"+ user_query["chunk"] + "[/INST]" +  user_query["navs"] + "</s>

## The prompt

In [209]:
test_prompt = generate_test_prompt(test_dataset_chunks[0])
print(test_prompt)

<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV> and </NAV> tags as well as the reason tags: <REASON> </REASON>:
This walk starts at one of the most famous landmarks in Britain: Tower
Bridge. This bridge is a rare thing – a much-admired structure of great
character, even though it was forged through endless committee meetings,
compromise, contradiction, uncertainty and a fair degree of absurdity. It has a
story worth considering at some length as you ascend from Tower Hill tube and
stroll south across the bridge. The tale starts in the mid 1870s, when it was
argued that a river crossing to the east of London Bridge would both improve
communications in the City and, by providing direct access to industry on the
south bank, increase the commercial potential of the docks on the north. There
was, for practical reasons, only one site possible 

## Different tokenization strategies for inference


###  Tokenize using encode method

In [132]:
encoded = tokenizer.encode(test_prompt, add_special_tokens=False)
print(encoded[0:100])
print(tokenizer.decode(encoded))

[1, 259, 32000, 28705, 12628, 272, 2245, 477, 264, 9863, 288, 3884, 1820, 18063, 264, 2948, 7103, 1059, 272, 2990, 302, 4222, 28725, 9131, 5099, 302, 272, 2245, 369, 6685, 2948, 18132, 11382, 297, 378, 1413, 259, 32002, 28705, 304, 28705, 32003, 28705, 12944, 390, 1162, 390, 272, 2611, 12944, 28747, 28705, 32004, 259, 32005, 28705, 390, 4894, 297, 272, 9254, 369, 1372, 13, 3260, 2338, 8383, 438, 624, 302, 272, 1080, 8376, 2533, 17181, 297, 10174, 28747, 19895, 13, 28760, 9163, 28723, 851, 9850, 349, 264, 9964, 1970, 764, 264, 1188, 28733, 316, 28719, 1360, 4693, 302, 1598, 13]
<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV>  and </NAV>  tags as well as the reason tags: <REASON>  </REASON>  as shown in the examples that follow
This walk starts at one of the most famous landmarks in Britain: Tower
Bridge. This bridge is a rare thing – a 

In [92]:
test_prompt = """<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV>  and </NAV>  tags as well as the reason tags: <REASON>  </REASON>  as shown in the examples that follow
This walk starts at one of the most famous landmarks in Britain: Tower
Bridge."""

In [146]:
test_prompt

'<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV> and </NAV> tags as well as the reason tags: <REASON> </REASON> as shown in the examples that follow\nThis walk starts at one of the most famous landmarks in Britain: Tower\nBridge. This bridge is a rare thing – a much-admired structure of great\ncharacter, even though it was forged through endless committee meetings,\ncompromise, contradiction, uncertainty and a fair degree of absurdity. It has a\nstory worth considering at some length as you ascend from Tower Hill tube and\nstroll south across the bridge. The tale starts in the mid 1870s, when it was\nargued that a river crossing to the east of London Bridge would both improve\ncommunications in the City and, by providing direct access to industry on the\nsouth bank, increase the commercial potential of the docks on the north. There\nwas,

In [147]:
test_tokens = tokenize(test_prompt)
print(len(test_tokens["input_ids"]))
test_input_ids = torch.tensor(test_tokens["input_ids"])
test_attn_mask = torch.tensor(test_tokens["attention_mask"])
print(len(test_input_ids), len(test_attn_mask))
unmasked_text = torch.masked_select(test_input_ids, test_attn_mask.bool())
print(unmasked_text[0:100])
# Now lets decode the text
print(".")
print(tokenizer.decode(unmasked_text, add_special_tokens=False))

2048
2048 2048
tensor([    1,   259, 32000, 28705, 12628,   272,  2245,   477,   264,  9863,
          288,  3884,  1820, 18063,   264,  2948,  7103,  1059,   272,  2990,
          302,  4222, 28725,  9131,  5099,   302,   272,  2245,   369,  6685,
         2948, 18132, 11382,   297,   378,  1413,   259, 32002, 28705,   304,
        28705, 32003, 28705, 12944,   390,  1162,   390,   272,  2611, 12944,
        28747, 28705, 32004,   259, 32005, 28705,   390,  4894,   297,   272,
         9254,   369,  1372,    13,  3260,  2338,  8383,   438,   624,   302,
          272,  1080,  8376,  2533, 17181,   297, 10174, 28747, 19895,    13,
        28760,  9163, 28723,   851,  9850,   349,   264,  9964,  1970,   764,
          264,  1188, 28733,   316, 28719,  1360,  4693,   302,  1598,    13])
.
<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV>  

In [149]:
test_tokens = tokenizer(
    [test_prompt],
    padding="max_length",
    max_length=200,
    truncation=True,
    return_tensors="pt",
    add_special_tokens=False,
).to("cuda")

In [161]:
from tqdm import tqdm


def generate_text(text, max_length=50):
    # Start with initial encoded text
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # Generate text
    for _ in tqdm(range(max_length)):
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # Append the newly generated token ID to the existing input_ids tensor
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        # Print the updated text at each step
        # print("Updated text:", tokenizer.decode(input_ids.squeeze()))
        # print(tokenizer.decode(next_token_id))
        # # Optional: Stop on specific conditions, e.g., end of sentence
        # 1f next_token_id in tokenizer.encode(l'.', '?'

    # Return the final generated text
    return tokenizer.decode(input_ids.squeeze(), skip_special_tokens=False)

In [151]:
test_prompt

'<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV> and </NAV> tags as well as the reason tags: <REASON> </REASON> as shown in the examples that follow\nThis walk starts at one of the most famous landmarks in Britain: Tower\nBridge. This bridge is a rare thing – a much-admired structure of great\ncharacter, even though it was forged through endless committee meetings,\ncompromise, contradiction, uncertainty and a fair degree of absurdity. It has a\nstory worth considering at some length as you ascend from Tower Hill tube and\nstroll south across the bridge. The tale starts in the mid 1870s, when it was\nargued that a river crossing to the east of London Bridge would both improve\ncommunications in the City and, by providing direct access to industry on the\nsouth bank, increase the commercial potential of the docks on the north. There\nwas,

In [None]:
print(test_dataset_chunks[0]["chunk"])

no more than a strong and utilitarian steel-made structural frame – an honest
expression of the means of construction and of its function. The basic design
strategy is still apparent: the twin towers rise on piers that are connected to
the land by carriageways, which are supported from above by wrought-iron members
that curve down to lower towers on the banks on either side. Between the tall
main towers are the drawbridge sections of carriageway. Jones’s original
chain-drawn mechanism, ultimately rejected as too slow, was replaced by a
‘bascule’ or see-saw system, around which the two centre towers were built. In
the bascule system, each section of road is counterbalanced by a mighty weight,
which descends into a huge chamber when the bridge is opened. At first this was
all powered by a cutting-edge coal-fuelled hydraulic engine, which has long
since replaced by electricity. But the whole mechanism means that, like a ship,
Tower Bridge has always had a crew and a bridgemaster, an engin

: 

In [162]:
op = generate_text(test_dataset_chunks[0]["chunk"])

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:22<00:00,  1.64s/it]


In [164]:
print(op)

<s> This walk starts at one of the most famous landmarks in Britain: Tower
Bridge. This bridge is a rare thing – a much-admired structure of great
character, even though it was forged through endless committee meetings,
compromise, contradiction, uncertainty and a fair degree of absurdity. It has a
story worth considering at some length as you ascend from Tower Hill tube and
stroll south across the bridge. The tale starts in the mid 1870s, when it was
argued that a river crossing to the east of London Bridge would both improve
communications in the City and, by providing direct access to industry on the
south bank, increase the commercial potential of the docks on the north. There
was, for practical reasons, only one site possible – a strip of land just east
of the Tower of London – and a competition to find a suitable design was held.
This was supervised by the City of London Corporation, which grabbed the
initiative to build the new bridge. There was one key stipulation. It was not t

In [130]:
next_token_ids = output.logits[:, -1, :]
torch.argmax(next_token_ids, dim=-1, keepdim=True)

tensor([[415]])

In [216]:
test_prompt

'<s> [INST] Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV> and </NAV> tags as well as the reason tags: <REASON> </REASON>:\nThis walk starts at one of the most famous landmarks in Britain: Tower\nBridge. This bridge is a rare thing – a much-admired structure of great\ncharacter, even though it was forged through endless committee meetings,\ncompromise, contradiction, uncertainty and a fair degree of absurdity. It has a\nstory worth considering at some length as you ascend from Tower Hill tube and\nstroll south across the bridge. The tale starts in the mid 1870s, when it was\nargued that a river crossing to the east of London Bridge would both improve\ncommunications in the City and, by providing direct access to industry on the\nsouth bank, increase the commercial potential of the docks on the north. There\nwas, for practical reasons, only one sit

device(type='cuda', index=0)

In [219]:
input_with_no_padding = tokenizer(
    test_prompt + f"{NAV_START_TAG}",
    add_special_tokens=False,
)
input_with_no_padding["input_ids"]

[1,
 259,
 32000,
 28705,
 12628,
 272,
 2245,
 477,
 264,
 9863,
 288,
 3884,
 1820,
 18063,
 264,
 2948,
 7103,
 1059,
 272,
 2990,
 302,
 4222,
 28725,
 9131,
 5099,
 302,
 272,
 2245,
 369,
 6685,
 2948,
 18132,
 11382,
 297,
 378,
 1413,
 259,
 32002,
 28705,
 304,
 28705,
 32003,
 28705,
 12944,
 390,
 1162,
 390,
 272,
 2611,
 12944,
 28747,
 28705,
 32004,
 259,
 32005,
 714,
 13,
 3260,
 2338,
 8383,
 438,
 624,
 302,
 272,
 1080,
 8376,
 2533,
 17181,
 297,
 10174,
 28747,
 19895,
 13,
 28760,
 9163,
 28723,
 851,
 9850,
 349,
 264,
 9964,
 1970,
 764,
 264,
 1188,
 28733,
 316,
 28719,
 1360,
 4693,
 302,
 1598,
 13,
 19933,
 28725,
 1019,
 2070,
 378,
 403,
 354,
 2560,
 1059,
 18284,
 13414,
 13251,
 28725,
 13,
 3086,
 6187,
 28725,
 15235,
 3033,
 28725,
 18110,
 304,
 264,
 4968,
 6153,
 302,
 22976,
 472,
 28723,
 661,
 659,
 264,
 13,
 18387,
 4407,
 9868,
 438,
 741,
 3575,
 390,
 368,
 13294,
 416,
 477,
 19895,
 7442,
 17735,
 304,
 13,
 303,
 1584,
 6287,
 2673,
 

torch.Size([1, 17792])

In [247]:
input_ids = torch.tensor(input_with_no_padding["input_ids"]).unsqueeze(0)

In [248]:
# Generate text
for _ in tqdm(range(200)):
    with torch.no_grad():
        input_ids = model.generate(
            input_ids,
            max_new_tokens=1,
            do_sample=False,
            temperature=0.1,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )
        token = tokenizer.decode(input_ids[:, -1].item())
        if token == NAV_END_TAG:
            input_ids = torch.cat(
                (
                    input_ids,
                    torch.tensor(
                        [
                            tokenizer.convert_tokens_to_ids(REASON_START_TAG),
                        ]
                    ).unsqueeze(0),
                ),
                dim=-1,
            )
        elif token == REASON_END_TAG:
            input_ids = torch.cat(
                (
                    input_ids,
                    torch.tensor(
                        [
                            tokenizer.convert_tokens_to_ids(NAV_START_TAG),
                        ]
                    ).unsqueeze(0),
                ),
                dim=-1,
            )
        elif token == EOS_TOKEN:
            break
        # next_token_logits = outputs.logits[:, -1, :]
        # next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    # Append the newly generated token ID to the existing input_ids tensor

    # input_ids = torch.cat([input_ids, outputs], dim=-1)

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [06:08<00:00,  1.84s/it]


In [250]:
# Now lets do what daniel taught me
print(tokenizer.decode(input_ids.squeeze()))

<s>  [INST]  Given the text from a Walking tour book describing a specific route through the city of London, extract parts of the text that describe specific navigation instructions in it using  <NAV>  and </NAV>  tags as well as the reason tags: <REASON>  </REASON> :
This walk starts at one of the most famous landmarks in Britain: Tower
Bridge. This bridge is a rare thing – a much-admired structure of great
character, even though it was forged through endless committee meetings,
compromise, contradiction, uncertainty and a fair degree of absurdity. It has a
story worth considering at some length as you ascend from Tower Hill tube and
stroll south across the bridge. The tale starts in the mid 1870s, when it was
argued that a river crossing to the east of London Bridge would both improve
communications in the City and, by providing direct access to industry on the
south bank, increase the commercial potential of the docks on the north. There
was, for practical reasons, only one site pos

In [242]:
torch.tensor([tokenizer.convert_tokens_to_ids(NAV_END_TAG)]).unsqueeze(0)

tensor([[32003]])

In [244]:
torch.cat(
    (
        input_ids,
        torch.tensor(
            [
                tokenizer.convert_tokens_to_ids(NAV_END_TAG),
                tokenizer.convert_tokens_to_ids(REASON_START_TAG),
            ]
        ).unsqueeze(0),
    ),
    dim=-1,
)

tensor([[    1,   259, 32000, 28705, 12628,   272,  2245,   477,   264,  9863,
           288,  3884,  1820, 18063,   264,  2948,  7103,  1059,   272,  2990,
           302,  4222, 28725,  9131,  5099,   302,   272,  2245,   369,  6685,
          2948, 18132, 11382,   297,   378,  1413,   259, 32002, 28705,   304,
         28705, 32003, 28705, 12944,   390,  1162,   390,   272,  2611, 12944,
         28747, 28705, 32004,   259, 32005,   714,    13,  3260,  2338,  8383,
           438,   624,   302,   272,  1080,  8376,  2533, 17181,   297, 10174,
         28747, 19895,    13, 28760,  9163, 28723,   851,  9850,   349,   264,
          9964,  1970,   764,   264,  1188, 28733,   316, 28719,  1360,  4693,
           302,  1598,    13, 19933, 28725,  1019,  2070,   378,   403,   354,
          2560,  1059, 18284, 13414, 13251, 28725,    13,  3086,  6187, 28725,
         15235,  3033, 28725, 18110,   304,   264,  4968,  6153,   302, 22976,
           472, 28723,   661,   659,   264,    13, 1

In [232]:
for i in range(-1, -200, -1):
    if tokenizer.decode(input_ids[:, i].item()) == NAV_END_TAG:
        print("..")
    print(tokenizer.decode(input_ids[:, i].item()), end=" ")

by , and City the in ications commun 
 improve both would Bridge London of east the to crossing river a that ued arg 
 was it when , s 0 7 8 1  mid the in starts tale The . bridge the across south roll st 
 and tube Hill Tower from end asc you as length some at considering worth story 
 a has It . ity absurd of degree fair a and uncertainty , iction contrad , romise comp 
 , meetings committee endless through ged for was it though even , character 
 great of structure ired m ad - much a – thing rare a is bridge This . ridge B 
 Tower : Britain in marks land famous most the of one at starts walk This 
   ..
</NAV> 
 } % block end % { 
 > p </ 
 > section </ 
 > p </ 
 . bridge the across south roll st and Bridge Tower at Start 
 > p < 
 > 3 h :</ ruction Inst Navigation > 3 h < 
 > section < 
 > p < 
 } % action extr block % 

In [201]:
input_ids.shape

torch.Size([1, 565])