# Fine-tune Mixtral 8x7B on Amazon SageMaker

In this SageMaker example, we are going to learn how to fine-tune [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) using [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314). Compared to Mistral AI's previous model, [Mistral-7b](https://huggingface.co/mistralai/Mistral-7B-v0.1), Mixtral is based on a Mixture-of-Experts (MoE) transformer architecture, which uses multiple expert networks together with a gating layer that routes each input token to a small subset of specialized experts, so only part of the model's parameters is active for any given token.

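The Mixtral 8x7B model comprises 8 experts per layer. Because the experts share the attention layers, the total size is roughly 47 billion parameters rather than 8 × 7 billion, and about 13 billion of them are active for each token.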
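
To make the gating idea concrete, here is a minimal pure-Python sketch of top-2 routing for a single token. The expert functions and router scores are made-up toy values, and the helper names (`route_top_k`, `moe_layer`) are hypothetical; this illustrates the general mechanism and is not Mixtral's actual implementation.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(gate_logits, k=2):
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    top = sorted(range(len(gate_logits)), key=lambda i: gate_logits[i], reverse=True)[:k]
    weights = softmax([gate_logits[i] for i in top])
    return list(zip(top, weights))


def moe_layer(x, gate_logits, experts, k=2):
    """Weighted sum of the outputs of the selected experts for a scalar input x."""
    return sum(w * experts[i](x) for i, w in route_top_k(gate_logits, k))


# Toy setup: 8 "experts", each a simple scalar function standing in for a feed-forward block
experts = [lambda x, s=s: s * x for s in range(1, 9)]
gate_logits = [0.1, 2.0, -1.0, 0.3, 2.0, 0.0, -0.5, 1.0]  # made-up router scores for one token

routing = route_top_k(gate_logits)
output = moe_layer(1.0, gate_logits, experts)
print(routing)
print(output)
```

Only the two selected experts are evaluated for the token; the remaining six are skipped, which is why the active parameter count is much smaller than the total.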

QLoRA is an efficient finetuning technique that quantizes a pretrained language model to 4 bits and attaches small “Low-Rank Adapters” which are fine-tuned. This enables fine-tuning of models with up to 65 billion parameters on a single GPU; despite its efficiency, QLoRA matches the performance of full-precision fine-tuning and achieves state-of-the-art results on language tasks.
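
As a rough illustration of why the adapters are "small", the sketch below counts the trainable parameters of a rank-`r` adapter for a single weight matrix. The shapes and rank are illustrative assumptions, not values taken from the training script.

```python
def lora_trainable_params(d_out, d_in, rank):
    """Parameters in the two low-rank factors B (d_out x rank) and A (rank x d_in)."""
    return rank * (d_out + d_in)


d_out, d_in, rank = 4096, 4096, 64  # illustrative shapes, not taken from run_clm.py
full = d_out * d_in
lora = lora_trainable_params(d_out, d_in, rank)

print(f"full matrix:  {full:,} params")
print(f"LoRA factors: {lora:,} params")
print(f"fraction trained: {lora / full}")
```

The fraction shrinks further as the matrix grows or the rank is reduced, since the adapter scales with `d_out + d_in` rather than `d_out * d_in`.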

In our example, we are going to leverage Hugging Face [Transformers](https://huggingface.co/docs/transformers/index), [Accelerate](https://huggingface.co/docs/accelerate/index), and [PEFT](https://github.com/huggingface/peft). 

In detail, you will learn how to:
1. Setup Development Environment
2. Load and prepare the dataset
3. Fine-Tune Mixtral-8x7B with QLoRA on Amazon SageMaker
4. Deploy Fine-tuned LLM on Amazon SageMaker

### Quick intro: PEFT or Parameter Efficient Fine-tuning

[PEFT](https://github.com/huggingface/peft), or Parameter Efficient Fine-tuning, is a new open-source library from Hugging Face to enable efficient adaptation of pre-trained language models (PLMs) to various downstream applications without fine-tuning all the model's parameters. PEFT currently includes techniques for:

- (Q)LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf)
- Prefix Tuning: [P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)
- P-Tuning: [GPT Understands, Too](https://arxiv.org/pdf/2103.10385.pdf)
- Prompt Tuning: [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf)
- IA3: [Infused Adapter by Inhibiting and Amplifying Inner Activations](https://arxiv.org/abs/2205.05638)



## 1. Setup Development Environment

In [None]:
!pip install transformers sagemaker s3fs "fsspec==2023.9.0" "datasets[s3]==2.13.0"  --upgrade

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. See the AWS documentation on SageMaker execution roles for details.

In [1]:
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

#gets role
role = sagemaker.get_execution_role()

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml
sagemaker role arn: arn:aws:iam::631450739534:role/service-role/AmazonSageMaker-ExecutionRole-20231204T140306
sagemaker bucket: sagemaker-us-east-1-631450739534
sagemaker session region: us-east-1


## 2. Load and prepare the dataset

We will use [dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open-source dataset of instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.


```python
{
  "instruction": "What is world of warcraft",
  "context": "",
  "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment"
}
```

To load the `databricks-dolly-15k` dataset, we use the `load_dataset()` method from the 🤗 Datasets library.

In [2]:
from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])

dataset size: 15011
{'instruction': 'Suggest some sports I can do solo?', 'context': '', 'response': 'You can run, swim, cycle, dance - all by yourself.', 'category': 'brainstorming'}


To instruction-tune our model, we need to convert our structured examples into a collection of tasks described via instructions. We define a formatting function, `create_prompt`, that takes a sample and returns a string in our instruction format.

In [3]:
def create_prompt(sample):
    bos_token = "<s>"
    system_message = "[INST] Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    instruction = sample["instruction"]
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    question = "\n".join([i for i in [instruction, context] if i is not None])
    answer = sample['response']
    eos_token = "</s>"

    full_prompt = ""
    full_prompt += bos_token
    full_prompt += system_message
    full_prompt += question
    full_prompt += " [/INST]\n\n"
    full_prompt += answer
    full_prompt += eos_token

    return full_prompt

In [4]:
# let's test our formatting function on a random example
create_prompt(dataset[randrange(len(dataset))])

'<s>[INST] Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\nWhat is a Pythagorean triple? [/INST]\n\nIn mathematics, a Pythagorean triple consists of three positive integers, a, b, c, such that a² + b² = c². These integers can form the sides of a right triangle, with c as the hypotenuse. For example, (3, 4, 5) is a Pythagorean triple because 3² + 4² = 5².</s>'

In [5]:
from transformers import AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
mixtral_tokenizer = AutoTokenizer.from_pretrained(model_id)
mixtral_tokenizer.pad_token = mixtral_tokenizer.eos_token

We define some helper functions to pack our samples into sequences of a given length and then tokenize them.

In [6]:
from random import randint
from itertools import chain
from functools import partial


# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = f"{create_prompt(sample)}"
    return sample


# apply prompt template per sample
hf_df = dataset.map(template_dataset, remove_columns=list(dataset.features))
# print random sample
print(hf_df[randint(0, len(hf_df) - 1)]["text"])

# empty list to save remainder from batches to use in next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

def chunk(sample, chunk_length=2048):
    # define global remainder variable to save remainder from batches to use in next batch
    global remainder
    # Concatenate all texts and add remainder from previous batch
    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
    concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
    # get total number of tokens for batch
    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])

    # number of tokens that fit into full chunks (0 if the batch is shorter than one chunk)
    batch_chunk_length = (batch_total_length // chunk_length) * chunk_length

    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # add remainder to global variable for next batch
    remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result

# tokenize and chunk dataset
lm_dataset = hf_df.map(
    lambda sample: mixtral_tokenizer(sample["text"]), batched=True, remove_columns=list(hf_df.features)
).map(
    partial(chunk, chunk_length=2048),
    batched=True,
)



<s>[INST] Below is an instruction that describes a task. Write a response that appropriately completes the request.

Extract the names of all of the albums that Taylor Swift has released. Separate them with a comma.
### Context
Swift signed a record deal with Big Machine Records in 2005 and released her eponymous debut album the following year. With 157 weeks on the Billboard 200 by December 2009, the album was the longest-charting album of the 2000s decade. Swift's second studio album, Fearless (2008), topped the Billboard 200 for 11 weeks and was the only album from the 2000s decade to spend one year in the top 10. The album was certified Diamond by the RIAA. It also topped charts in Australia and Canada, and has sold 12 million copies worldwide. Her third studio album, the self-written Speak Now (2010), spent six weeks atop the Billboard 200 and topped charts in Australia, Canada, and New Zealand.

Her fourth studio album, Red (2012), was her first number-one album in the United Kin

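
The packing step can be easier to follow on toy data. The sketch below is a simplified, self-contained version of the same idea: concatenate tokenized samples, cut the stream into fixed-size chunks, and carry the leftover tokens into the next batch. The `pack` helper is hypothetical and passes the remainder explicitly instead of using a global variable.

```python
def pack(token_lists, chunk_length, remainder=None):
    """Concatenate token lists and cut them into fixed-size chunks.

    Returns (chunks, remainder); the remainder holds leftover tokens for the next batch.
    """
    stream = list(remainder or [])
    for tokens in token_lists:
        stream.extend(tokens)
    # number of tokens that fit into full chunks (0 if the stream is too short)
    usable = (len(stream) // chunk_length) * chunk_length
    chunks = [stream[i : i + chunk_length] for i in range(0, usable, chunk_length)]
    return chunks, stream[usable:]


batch = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # three toy "tokenized samples"
chunks, rest = pack(batch, chunk_length=4)
print(chunks)
print(rest)

# the leftover token is prepended to the next batch
next_chunks, next_rest = pack([[10, 11, 12]], chunk_length=4, remainder=rest)
print(next_chunks)
```

Packing this way avoids padding, at the cost of chunks that can span sample boundaries.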

In [9]:

# save train_dataset to s3
training_input_path = f's3://{sagemaker_session_bucket}/train'
local_path = './train_data'
lm_dataset.save_to_disk(local_path)


In [11]:
!aws s3 sync ./train_data $training_input_path
print("uploaded data to:")
print(f"training dataset to: {training_input_path}")


upload: train_data/state.json to s3://sagemaker-us-east-1-631450739534/train/state.json
upload: train_data/dataset_info.json to s3://sagemaker-us-east-1-631450739534/train/dataset_info.json
upload: train_data/data-00000-of-00001.arrow to s3://sagemaker-us-east-1-631450739534/train/data-00000-of-00001.arrow
uploaded data to:
training dataset to: s3://sagemaker-us-east-1-631450739534/train


## 3. Fine-Tune Mixtral 8x7B with QLoRA on Amazon SageMaker

We are going to use the method introduced in the paper "[QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)" by Tim Dettmers et al. QLoRA is a technique to reduce the memory footprint of large language models during fine-tuning without sacrificing performance. The TL;DR of how QLoRA works is:

* Quantize the pretrained model to 4 bits and freeze it.
* Attach small, trainable adapter layers. (LoRA)
* Finetune only the adapter layers, while using the frozen quantized model for context.
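
The three steps above can be sketched on a single toy neuron in pure Python. This is a heavily simplified illustration under stated assumptions: plain absmax rounding stands in for the NF4 data type used by QLoRA, the adapter is rank 1, and all names and values are made up. Real training relies on the library implementations inside `run_clm.py`.

```python
def quantize_4bit(weights):
    """Round floats to integers in [-7, 7] with one absmax scale (a simplified stand-in for NF4)."""
    largest = max(abs(w) for w in weights)
    scale = largest / 7 if largest > 0 else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(q_weights, scale):
    return [q * scale for q in q_weights]


def neuron_output(x, q_weights, scale, lora_a, lora_b):
    """One output neuron: frozen 4-bit base weights plus a trainable rank-1 update lora_b * (lora_a . x)."""
    base = sum(w * xi for w, xi in zip(dequantize(q_weights, scale), x))
    update = lora_b * sum(a * xi for a, xi in zip(lora_a, x))
    return base + update


weights = [0.7, -0.28, 0.14, 0.0]            # made-up pretrained weights
q_weights, scale = quantize_4bit(weights)     # step 1: quantize and freeze
lora_a, lora_b = [0.1, 0.1, 0.1, 0.1], 0.0    # step 2: attach the adapter (B starts at zero)
x = [1.0, 1.0, 1.0, 1.0]

before = neuron_output(x, q_weights, scale, lora_a, lora_b)
lora_b = 1.0                                  # step 3: training changes only the adapter values
after = neuron_output(x, q_weights, scale, lora_a, lora_b)
print(q_weights, before, after)
```

Because the second adapter factor starts at zero, the adapted model initially behaves exactly like the quantized base model; training then moves only `lora_a` and `lora_b`.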

We prepared a [run_clm.py](./scripts/run_clm.py) script, which implements QLoRA using PEFT to train our model. The script also merges the LoRA weights into the model weights after training. That way, you can use the model as a normal model without any additional code. The model will be temporarily offloaded to disk if it is too large to fit into memory.

In order to create a SageMaker training job, we need a `HuggingFace` Estimator. The Estimator handles end-to-end Amazon SageMaker training and deployment tasks and manages the infrastructure for us.
SageMaker takes care of starting and managing all the required EC2 instances, provides the correct Hugging Face container, uploads the provided scripts, and downloads the data from our S3 bucket into the container at `/opt/ml/input/data`. Then, it starts the training job by running the entry point script with the hyperparameters passed as command-line arguments.
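
In script mode, SageMaker hands the hyperparameters to the entry point as command-line arguments. The sketch below shows roughly what the resulting command looks like, assuming plain `--key value` formatting; `build_command` is a hypothetical helper for illustration and the example values are a subset of the real hyperparameters.

```python
import shlex


def build_command(entry_point, hyperparameters):
    """Render a hyperparameter dict as the command line the entry point would receive."""
    args = ["python", entry_point]
    for key, value in hyperparameters.items():
        args += [f"--{key}", str(value)]
    return shlex.join(args)


example_hyperparameters = {
    "model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "epochs": 1,
    "lr": 2e-4,
}
command = build_command("run_clm.py", example_hyperparameters)
print(command)
```

This is why the training script needs an argument parser whose option names match the keys of the `hyperparameters` dictionary.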


In [49]:
import time
from sagemaker.huggingface import HuggingFace
from huggingface_hub import HfFolder

# define Training Job Name
job_name = f'huggingface-qlora-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'

# hyperparameters, which are passed into the training job
hyperparameters ={
  'model_id': model_id,                             # pre-trained model
  'dataset_path': '/opt/ml/input/data/training',    # path where sagemaker will save training dataset
  'epochs': 1,                                      # number of training epochs
  'per_device_train_batch_size': 1,                 # batch size for training
  'lr': 2e-4,                                       # learning rate used during training
  'hf_token': HfFolder.get_token(),                 # huggingface token, needed if the model is gated
  'merge_weights': True,                            # whether to merge LoRA into the model (needs more memory)
}

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point          = 'run_clm.py',      # train script
    source_dir           = 'scripts',         # directory which includes all the files needed for training
    instance_type        = 'ml.g5.12xlarge',   # instances type used for the training job
    instance_count       = 1,                 # the number of instances used for training
    base_job_name        = job_name,          # the name of the training job
    role                 = role,              # IAM role used in training job to access AWS resources, e.g. S3
    volume_size          = 300,               # the size of the EBS volume in GB
    transformers_version = '4.28',            # the transformers version used in the training job
    pytorch_version      = '2.0',             # the pytorch version used in the training job
    py_version           = 'py310',           # the python version used in the training job
    hyperparameters      =  hyperparameters,  # the hyperparameters passed to the training job
    environment          = { "HUGGINGFACE_HUB_CACHE": "/tmp/.cache" }, # set env variable to cache models in /tmp
    disable_output_compression=True,
)

We can now start our training job, with the `.fit()` method passing our S3 path to the training script.

In [None]:
# define a data input dictionary with our uploaded s3 uris
data = {'training': training_input_path}

# starting the train job with our uploaded datasets as input
huggingface_estimator.fit(data, wait=True)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: huggingface-qlora-2024-02-07-03-35-57-2024-02-07-03-35-59-202


2024-02-07 03:36:00 Starting - Starting the training job...
2024-02-07 03:36:19 Pending - Training job waiting for capacity...
2024-02-07 03:36:43 Pending - Preparing the instances for training.........
2024-02-07 03:38:09 Downloading - Downloading input data...

In [51]:
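training_job_name = "huggingface-qlora-2024-02-07-03-35-57-2024-02-07-03-35-59-202"
response = sess.describe_training_job(training_job_name)
# poll once per minute until the job reaches a terminal state
while response["TrainingJobStatus"] not in ["Completed", "Failed", "Stopped"]:
    print(response["TrainingJobStatus"])
    time.sleep(60)
    response = sess.describe_training_job(training_job_name)
print(f'training job finished with status {response["TrainingJobStatus"]}')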
    

training job finished with status Completed


In our example for Mixtral 8x7B, the SageMaker training job took `18946` seconds, which is about `5.3` hours. The ml.g5.12xlarge instance we used costs `$7.09` per hour for on-demand usage. As a result, the total cost for training our fine-tuned Mixtral 8x7B model was only ~`$37`.
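
The cost arithmetic can be checked directly from the training time and the hourly price quoted above. On-demand prices vary by region and over time, so treat the price as an input to replace with your own.

```python
training_seconds = 18946   # billable training time reported for the job
price_per_hour = 7.09      # on-demand price for ml.g5.12xlarge quoted above, in USD

hours = training_seconds / 3600
cost = hours * price_per_hour
print(f"{hours:.2f} hours, ${cost:.2f}")
```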

## 4. Deploy Fine-tuned LLM on Amazon SageMaker
You can deploy your fine-tuned Mixtral 8x7B model to a SageMaker endpoint and use it for inference.

In [52]:
model_url = response['ModelArtifacts']['S3ModelArtifacts']
model_url

's3://sagemaker-us-east-1-631450739534/huggingface-qlora-2024-02-07-03-35-57-2024-02-07-03-35-59-202/output/model'

In [60]:
import jinja2
import time
import json
from pathlib import Path
from sagemaker import Model, image_uris, serializers, deserializers

jinja_env = jinja2.Environment()
pretrained_model_location = model_url

bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = "hf-large-model-djl/mixtral8-7b-finetuned"

In [54]:
!mkdir -p mymodel

In [64]:
%%writefile ./mymodel/serving.properties
engine=Python
option.model_id={{s3url}}
option.tensor_parallel_degree=8
option.max_rolling_batch_size=16
option.rolling_batch=vllm
# option.max_model_len=25456
option.dtype=fp16

Overwriting ./mymodel/serving.properties


In [65]:
# we plug the S3 location of the fine-tuned model into our `serving.properties` file
properties_path = Path("mymodel/serving.properties")
template = jinja_env.from_string(properties_path.read_text())
properties_path.write_text(template.render(s3url=pretrained_model_location))
!pygmentize mymodel/serving.properties | cat -n

     1	engine=Python
     2	option.model_id=s3://sagemaker-us-east-1-631450739534/huggingface-qlora-2024-02-07-03-35-57-2024-02-07-03-35-59-202/output/model
     3	option.tensor_parallel_degree=8
     4	option.max_rolling_batch_size=16
     5	option.rolling_batch=vllm
     6	# option.max_model_len=25456
     7	option.dtype=fp16


In [66]:
%%sh
rm -f mymodel.tar.gz
rm -rf mymodel/.ipynb_checkpoints
tar czvf mymodel.tar.gz -C mymodel .

./
./serving.properties


### Build the SageMaker endpoint
#### Getting the container image URI

See the available Large Model Inference DLCs [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers).

In [58]:
image_uri = image_uris.retrieve(
        framework="djl-deepspeed",
        region=sess.boto_session.region_name,
        version="0.26.0"
    )

INFO:sagemaker.image_uris:Ignoring unnecessary instance type: None.


#### Upload artifact on S3 and create SageMaker model

In [67]:
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-631450739534/hf-large-model-djl/mixtral8-7b-finetuned/mymodel.tar.gz


#### Create SageMaker endpoint with a specified instance type

In [68]:
instance_type = "ml.g5.48xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model-mixtral-8x7b-finetuned-48x")
print(f"endpoint_name: {endpoint_name}")

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             container_startup_health_check_timeout=1800
            )

# our requests are sent as JSON, so we specify the serializer; responses are returned as raw bytes
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
)

INFO:sagemaker:Creating model with name: djl-inference-2024-02-07-10-55-18-338


endpoint_name: lmi-model-mixtral-8x7b-finetuned-48x-2024-02-07-10-55-18-300


INFO:sagemaker:Creating endpoint-config with name lmi-model-mixtral-8x7b-finetuned-48x-2024-02-07-10-55-18-300
INFO:sagemaker:Creating endpoint with name lmi-model-mixtral-8x7b-finetuned-48x-2024-02-07-10-55-18-300


----------------!

### Run inference
Below are a few example requests sent to the deployed endpoint.

In [69]:
predictor.predict(
    {"inputs": "The future of Artificial Intelligence is", "parameters": {"max_new_tokens":128, "do_sample":True}}
)

b'{"generated_text": " already here \xe2\x80\x94 automation is increasing and human productivity is growing at an exponential rate. There\'s never been a more important time to invest in business automation than now... and to invest in the skills that will allow you to take advantage of these opportunities. In this course, learn the fundamentals of Artificial Intelligence from Peter Norvig, Google\'s Director of Research, and Professor Sebastian Thrun. Dr. Norvig is well-known within the field of AI and has a fantastic gift for teaching. Rachel, a former student, had this to say about the course: \\"I am not a CS major,"}'
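
Because the predictor returns raw bytes, the generated text has to be decoded and parsed before use. A minimal sketch, assuming the response is UTF-8 encoded JSON with a `generated_text` field as in the output above; the sample payload is a shortened, hypothetical stand-in.

```python
import json


def extract_generated_text(raw_response: bytes) -> str:
    """Decode the endpoint's byte response and return the generated text."""
    return json.loads(raw_response.decode("utf-8"))["generated_text"]


sample = b'{"generated_text": " already here"}'  # shortened, hypothetical payload in the same shape
text = extract_generated_text(sample)
print(text)
```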

In [72]:
predictor.predict(
    {"inputs": "what is the derivative of x squared", "parameters": {"max_new_tokens":128, "do_sample":True}}
)

b'{"generated_text": " x squared with respect to x plus 2\\n### Context\\nThe second derivative can be written in terms of the original function using the first derivative as \xc2\xa0d2y\xc2\xa0/\xc2\xa0dx2\xc2\xa0=\xc2\xa0d2y\xc2\xa0/\xc2\xa0dx\xc2\xa0d/dx.\xc2\xa0This notation is compatible with the symbolic form.\xc2\xa0In simple cases, it is possible to derive a formula for the second derivative in terms of the original variable.For example, let y\xc2\xa0=\xc2\xa0f(x)=x2;   The first derivative is y\'=2x.\xc2\xa0Differentiating this again,\xc2\xa0 The second derivative"}'
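
The requests above send plain text, while the model was fine-tuned on prompts wrapped in the instruction template from `create_prompt`. Formatting inference inputs the same way may give responses closer to the training data. The helper below is a hypothetical sketch that mirrors the training template up to the end of the instruction block; whether the leading `<s>` should be included depends on whether the serving stack already adds it, which is an assumption to verify for your setup.

```python
SYSTEM_MESSAGE = (
    "[INST] Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
)


def build_inference_prompt(instruction, context=""):
    """Mirror the training template, stopping where the model's answer should begin."""
    question = instruction if not context else f"{instruction}\n### Context\n{context}"
    return f"<s>{SYSTEM_MESSAGE}{question} [/INST]\n\n"


payload = {
    "inputs": build_inference_prompt("What is the derivative of x squared?"),
    "parameters": {"max_new_tokens": 128, "do_sample": True},
}
print(payload["inputs"])
```

The resulting `payload` has the same shape as the dictionaries passed to `predictor.predict` above.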