# Mistral Fine-tuning API

Check out the docs: https://docs.mistral.ai/capabilities/finetuning/

In [None]:
!pip install mistralai pandas

## Prepare the dataset

In this example, we use the ultrachat_200k dataset. We load a chunk of the data into a pandas DataFrame, split it into training and validation sets, and save both sets in the JSONL format required for fine-tuning.

In [None]:
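import pandas as pd

# Load one parquet shard of ultrachat_200k directly from the Hugging Face Hub
df = pd.read_parquet('https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/data/test_gen-00000-of-00001-3d4cd8309148a71f.parquet')

# Keep 99.5% of the rows for training and the rest for evaluation;
# the fixed random_state makes the split reproducible
df_train = df.sample(frac=0.995, random_state=200)
df_eval = df.drop(df_train.index)

# Write one JSON object per line (JSONL)
df_train.to_json("ultrachat_chunk_train.jsonl", orient="records", lines=True)
df_eval.to_json("ultrachat_chunk_eval.jsonl", orient="records", lines=True)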

In [None]:
!ls -lh

total 147M
drwxr-xr-x 1 root root 4.0K May 31 13:30 sample_data
-rw-r--r-- 1 root root 698K Jun  4 09:06 ultrachat_chunk_eval.jsonl
-rw-r--r-- 1 root root 146M Jun  4 09:06 ultrachat_chunk_train.jsonl


## Reformat dataset
If you upload this ultrachat_chunk_train.jsonl to the Mistral API as is, you may get an “Invalid file format” error because of data formatting issues. To fix this, download the reformat_data.py script and use it to validate and reformat both the training and evaluation data:

In [None]:
# download the validation and reformat script
!wget https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py

--2024-06-04 09:07:07--  https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3378 (3.3K) [text/plain]
Saving to: ‘reformat_data.py’


2024-06-04 09:07:07 (37.1 MB/s) - ‘reformat_data.py’ saved [3378/3378]



In [None]:
# validate and reformat the training data
!python reformat_data.py ultrachat_chunk_train.jsonl

Skip 3674th sample
Skip 9176th sample
Skip 10559th sample
Skip 13293th sample
Skip 13973th sample
Skip 15219th sample


In [None]:
# validate and reformat the eval data
!python reformat_data.py ultrachat_chunk_eval.jsonl

In [None]:
# inspect one of the training samples that the script reported as skipped
df_train.iloc[3674]['messages']

array([{'content': 'What are the dimensions of the cavity, product, and shipping box of the Sharp SMC1662DS microwave?: With innovative features like preset controls, Sensor Cooking and the Carousel® turntable system, the Sharp® SMC1662DS 1.6 cu. Ft. Stainless Steel Carousel Countertop Microwave makes reheating your favorite foods, snacks and beverages easier than ever. Use popcorn and beverage settings for one-touch cooking. Express Cook allows one-touch cooking up to six minutes. The convenient and flexible "+30 Sec" key works as both instant start option and allows you to add more time during cooking.\nThe Sharp SMC1662DS microwave is a bold design statement in any kitchen. The elegant, grey interior and bright white, LED interior lighting complements the stainless steel finish of this premium appliance.\nCavity Dimensions (w x h x d): 15.5" x 10.2" x 17.1"\nProduct Dimensions (w x h x d): 21.8" x 12.8" x 17.7"\nShipping Dimensions (w x h x d) : 24.4" x 15.0" x 20.5"', 'role': 'user
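
The skipped samples above suggest that some records do not match the expected chat format. As a rough illustration, here is a minimal sketch of the kind of check one could run locally before uploading. The rules are assumptions made for this example (every record has a non-empty `messages` list, every message has a known `role` and a string `content`, and the conversation ends with an assistant turn). It is not a reimplementation of reformat_data.py, and `check_record` and `scan_jsonl` are hypothetical helper names.

```python
import json

# Assumption for this sketch: the roles a chat record may contain.
ALLOWED_ROLES = {"system", "user", "assistant", "tool"}


def check_record(record):
    """Return a list of problems found in one parsed JSONL record (empty if none)."""
    if not isinstance(record, dict):
        return ["record is not a JSON object"]
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]
    problems = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            problems.append(f"message {i} is not an object")
            continue
        if msg.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i} has unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i} has no string content")
    last = messages[-1]
    # Assumed rule: a training conversation should end with an assistant turn.
    if isinstance(last, dict) and last.get("role") != "assistant":
        problems.append("last message is not from the assistant")
    return problems


def scan_jsonl(lines):
    """Yield (line_index, problems) for every line that fails a check."""
    for idx, line in enumerate(lines):
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            yield idx, [f"invalid JSON: {exc.msg}"]
            continue
        problems = check_record(record)
        if problems:
            yield idx, problems


# Two hand-written records: the second one ends with a user turn.
sample = [
    '{"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}',
    '{"messages": [{"role": "user", "content": "Hi"}]}',
]
for idx, problems in scan_jsonl(sample):
    print(idx, problems)
```

To scan a real file, pass an open file handle instead of the list, for example `scan_jsonl(open("ultrachat_chunk_train.jsonl"))`.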

## Upload dataset

In [None]:
import os
from mistralai.client import MistralClient

api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)

with open("ultrachat_chunk_train.jsonl", "rb") as f:
    ultrachat_chunk_train = client.files.create(file=("ultrachat_chunk_train.jsonl", f))
with open("ultrachat_chunk_eval.jsonl", "rb") as f:
    ultrachat_chunk_eval = client.files.create(file=("ultrachat_chunk_eval.jsonl", f))

In [None]:
import json
def pprint(obj):
    print(json.dumps(obj.dict(), indent=4))

In [None]:
pprint(ultrachat_chunk_train)

{
    "id": "872aa786-538f-4510-a7b0-2cffee00b71d",
    "object": "file",
    "bytes": 147449694,
    "created_at": 1717492198,
    "filename": "ultrachat_chunk_train.jsonl",
    "purpose": "fine-tune"
}


In [None]:
pprint(ultrachat_chunk_eval)

{
    "id": "d0ffa827-aca2-42ab-99a8-d449a2cd4da3",
    "object": "file",
    "bytes": 691885,
    "created_at": 1717492199,
    "filename": "ultrachat_chunk_eval.jsonl",
    "purpose": "fine-tune"
}


## Create a fine-tuning job

In [None]:
from mistralai.models.jobs import TrainingParameters

created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ultrachat_chunk_train.id],
    validation_files=[ultrachat_chunk_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=10,
        learning_rate=0.0001,
    )
)

In [None]:
pprint(created_jobs)

{
    "id": "63be4026-cb98-4dc0-becd-83c424a42e0e",
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1717492375,
    "modified_at": 1717492376,
    "training_files": [
        "872aa786-538f-4510-a7b0-2cffee00b71d"
    ],
    "validation_files": [
        "d0ffa827-aca2-42ab-99a8-d449a2cd4da3"
    ],
    "object": "job",
    "integrations": []
}


In [None]:
import time

# Poll the job every 10 seconds until it is no longer queued or running
retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.jobs.retrieve(created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)



{
    "id": "63be4026-cb98-4dc0-becd-83c424a42e0e",
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1717492375,
    "modified_at": 1717492377,
    "training_files": [
        "872aa786-538f-4510-a7b0-2cffee00b71d"
    ],
    "validation_files": [
        "d0ffa827-aca2-42ab-99a8-d449a2cd4da3"
    ],
    "object": "job",
    "integrations": [],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1717492377
        },
        {
            "name": "status-updated",
            "data": {
                "status": "QUEUED"
            },
            "created_at": 1717492375
        }
    ],
    "checkpoints": []
}
Job is RUNNING, waiting 10 seconds
{
    "id": "63be4026-cb98-4dc0-becd-83c424a42e0e",
    "hyperp
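
The loop above keeps polling for as long as the job is queued or running. For longer jobs it can help to bound the wait. The following is a minimal sketch of the same polling pattern with a timeout. `wait_for_job` and its parameters are hypothetical names introduced here, and the status source is passed in as a plain callable so the helper does not depend on any client library.

```python
import time


def wait_for_job(fetch_status, active=("QUEUED", "RUNNING"), interval=10.0,
                 timeout=3600.0, sleep=time.sleep, clock=time.monotonic):
    """Poll fetch_status() until it returns a status outside `active`.

    Raises TimeoutError if the job is still active once `timeout` seconds
    have passed. `sleep` and `clock` are injectable to make testing easy.
    """
    deadline = clock() + timeout
    status = fetch_status()
    while status in active:
        if clock() >= deadline:
            raise TimeoutError(f"job still {status} after {timeout} seconds")
        sleep(interval)
        status = fetch_status()
    return status


# Simulated job that goes QUEUED -> RUNNING -> SUCCESS, without real waiting.
fake_statuses = iter(["QUEUED", "RUNNING", "SUCCESS"])
final_status = wait_for_job(lambda: next(fake_statuses), sleep=lambda seconds: None)
print(final_status)
```

With the client used in this notebook, the callable could be `lambda: client.jobs.retrieve(created_jobs.id).status`.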

In [None]:
# List jobs
jobs = client.jobs.list()
pprint(jobs)

{
    "data": [
        {
            "id": "63be4026-cb98-4dc0-becd-83c424a42e0e",
            "hyperparameters": {
                "training_steps": 10,
                "learning_rate": 0.0001
            },
            "fine_tuned_model": "ft:open-mistral-7b:8e2706f0:20240604:63be4026",
            "model": "open-mistral-7b",
            "status": "SUCCESS",
            "job_type": "FT",
            "created_at": 1717492375,
            "modified_at": 1717492452,
            "training_files": [
                "872aa786-538f-4510-a7b0-2cffee00b71d"
            ],
            "validation_files": [
                "d0ffa827-aca2-42ab-99a8-d449a2cd4da3"
            ],
            "object": "job",
            "integrations": []
        },
        {
            "id": "54e86210-7eee-4486-a47f-a00c84880b02",
            "hyperparameters": {
                "training_steps": 100,
                "learning_rate": 0.0001
            },
            "fine_tuned_model": null,
            "model"

In [None]:
# Retrieve a job
retrieved_jobs = client.jobs.retrieve(created_jobs.id)
pprint(retrieved_jobs)


{
    "id": "63be4026-cb98-4dc0-becd-83c424a42e0e",
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": "ft:open-mistral-7b:8e2706f0:20240604:63be4026",
    "model": "open-mistral-7b",
    "status": "SUCCESS",
    "job_type": "FT",
    "created_at": 1717492375,
    "modified_at": 1717492452,
    "training_files": [
        "872aa786-538f-4510-a7b0-2cffee00b71d"
    ],
    "validation_files": [
        "d0ffa827-aca2-42ab-99a8-d449a2cd4da3"
    ],
    "object": "job",
    "integrations": [],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "SUCCESS"
            },
            "created_at": 1717492452
        },
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1717492377
        },
        {
            "name": "status-updated",
            "data": {
              

## Use a fine-tuned model

In [None]:
from mistralai.models.chat_completion import ChatMessage

chat_response = client.chat(
    model=retrieved_jobs.fine_tuned_model,
    messages=[ChatMessage(role='user', content='What is the best French cheese?')]
)

In [None]:
pprint(chat_response)

{
    "id": "a0b12a4ba1b9402ba8fe541526f26749",
    "object": "chat.completion",
    "created": 1717492478,
    "model": "ft:open-mistral-7b:8e2706f0:20240604:63be4026",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "It's impossible to definitively say which is the best French cheese, as it depends on one's personal taste and preferences. Some popular French cheeses include Camembert, Brie, Comt\u00e9, and Roquefort. These cheeses vary in flavor, texture, and intensity, so it is best to try a few to determine which you prefer.",
                "name": null,
                "tool_calls": null,
                "tool_call_id": null
            },
            "finish_reason": "stop"
        }
    ],
    "usage": {
        "prompt_tokens": 10,
        "total_tokens": 87,
        "completion_tokens": 77
    }
}
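
In practice you usually want only the answer text and the token counts rather than the full dump. Below is a minimal sketch that pulls those fields out of a response dictionary shaped like the output above. `summarize_chat_response` is a hypothetical helper, and the `content` string in the example payload is shortened to its first sentence.

```python
def summarize_chat_response(payload):
    """Extract the first choice's answer and the token usage from a response dict."""
    first_choice = payload["choices"][0]
    usage = payload["usage"]
    return {
        "model": payload["model"],
        "answer": first_choice["message"]["content"],
        "finish_reason": first_choice["finish_reason"],
        "prompt_tokens": usage["prompt_tokens"],
        "completion_tokens": usage["completion_tokens"],
    }


# Trimmed copy of the response printed above (content shortened for brevity).
payload = {
    "model": "ft:open-mistral-7b:8e2706f0:20240604:63be4026",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "It's impossible to definitively say which is the best French cheese, as it depends on one's personal taste and preferences.",
            },
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 10, "total_tokens": 87, "completion_tokens": 77},
}

summary = summarize_chat_response(payload)
print(summary["answer"])
print(summary["prompt_tokens"], summary["completion_tokens"])
```

Because the `pprint` helper in this notebook serializes objects with `.dict()`, the same function should also work on `chat_response.dict()`.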


## Integration with Weights and Biases
The fine-tuning API also integrates with Weights & Biases (W&B), so you can monitor and track metrics and statistics for your fine-tuning jobs. To enable the integration, create a W&B account and add your W&B information to the “integrations” section of the job creation request:



In [None]:
from mistralai.models.jobs import WandbIntegrationIn

WANDB_API_KEY = "XXX"  # placeholder: replace with your own W&B API key

created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ultrachat_chunk_train.id],
    validation_files=[ultrachat_chunk_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=100,
        learning_rate=0.0001,
    ),
    integrations=[
        WandbIntegrationIn(
            project="test_ft_api",
            run_name="test",
            api_key=WANDB_API_KEY,
        ).dict()
    ],
)
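
Rather than pasting the W&B key into the notebook, one option is to read it from the environment, as is done for `MISTRAL_API_KEY` earlier. This is a minimal sketch under that assumption. `load_wandb_key` is a hypothetical helper, and the choice of the `WANDB_API_KEY` environment variable name is a convention assumed here.

```python
import os


def load_wandb_key(env=os.environ):
    """Return the W&B API key from the environment, or fail with a clear message."""
    key = env.get("WANDB_API_KEY", "").strip()
    if not key:
        raise RuntimeError("Set the WANDB_API_KEY environment variable before creating the job")
    return key


# Demonstration with an explicit mapping instead of the real environment.
print(load_wandb_key({"WANDB_API_KEY": "  example-key  "}))
```

The returned value can then be passed as `api_key` to `WandbIntegrationIn` in place of the hardcoded placeholder.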