# Mistral Fine-tuning API

Check out the docs: https://docs.mistral.ai/capabilities/finetuning/

In [1]:
!pip install -q mistralai datasets nltk sacrebleu rouge-score

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.0/145.0 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━

## Prepare the dataset

Get the dataset

In [62]:
! wget https://raw.githubusercontent.com/habibadoum/lingua-franca/main/data/fr_sg_eval_data.jsonl
! wget https://raw.githubusercontent.com/habibadoum/lingua-franca/main/data/fr_sg_train_data.jsonl

--2024-06-30 22:25:49--  https://raw.githubusercontent.com/habibadoum/lingua-franca/main/data/fr_sg_eval_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 127232 (124K) [text/plain]
Saving to: ‘fr_sg_eval_data.jsonl’


2024-06-30 22:25:50 (9.14 MB/s) - ‘fr_sg_eval_data.jsonl’ saved [127232/127232]

--2024-06-30 22:25:50--  https://raw.githubusercontent.com/habibadoum/lingua-franca/main/data/fr_sg_train_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 25400615 (24M) [text/plain]
Saving to: ‘fr_sg_train_data.jsonl

## Reformat dataset
If you upload this ultrachat_chunk_train.jsonl to Mistral API, you might encounter an error message “Invalid file format” due to data formatting issues. To reformat the data into the correct format, you can download the reformat_dataset.py script and use it to validate and reformat both the training and evaluation data:

In [4]:
# download the validation and reformat script
!wget https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py

--2024-06-29 20:17:08--  https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3381 (3.3K) [text/plain]
Saving to: ‘reformat_data.py’


2024-06-29 20:17:08 (31.2 MB/s) - ‘reformat_data.py’ saved [3381/3381]



In [5]:
# validate and reformat the training data
!python reformat_data.py /content/fr_sg_train_data.jsonl

In [6]:
# validate the reformat the eval data
!python reformat_data.py /content/fr_sg_eval_data.jsonl

## Upload dataset

In [7]:
import os
from mistralai.client import MistralClient

api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)

train_file = "fr_sg_train_data.jsonl"
eval_file = "fr_sg_eval_data.jsonl"
train_data_splits = []

In [8]:
with open(eval_file, "rb") as f:
    ultrachat_chunk_eval = client.files.create(file=(eval_file, f))

In [9]:
with open(train_file, "rb") as f:
    ultrachat_chunk_train = client.files.create(file=(train_file, f))

In [10]:
import json
def pprint(obj):
    print(json.dumps(obj.dict(), indent=4))

In [11]:
pprint(ultrachat_chunk_train)

{
    "id": "ca63c194-cb59-4d2e-ae1a-f3b0ee6a538c",
    "object": "file",
    "bytes": 21193320,
    "created_at": 1719692240,
    "filename": "fr_sg_train_data.jsonl",
    "purpose": "fine-tune"
}


In [12]:
pprint(ultrachat_chunk_eval)

{
    "id": "4b01fdec-95f3-434c-925f-c01818daa08f",
    "object": "file",
    "bytes": 106298,
    "created_at": 1719692235,
    "filename": "fr_sg_eval_data.jsonl",
    "purpose": "fine-tune"
}


## Create a fine-tuning job

In [23]:
from mistralai.models.jobs import TrainingParameters

created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ultrachat_chunk_train.id],
    validation_files=[ultrachat_chunk_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=200,
        learning_rate=0.0001,
        )
)

In [24]:
pprint(created_jobs)

{
    "id": "d82784af-b4fa-41fa-a70e-44692e7e3dd6",
    "hyperparameters": {
        "training_steps": 200,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1719692942,
    "modified_at": 1719692942,
    "training_files": [
        "ca63c194-cb59-4d2e-ae1a-f3b0ee6a538c"
    ],
    "validation_files": [
        "4b01fdec-95f3-434c-925f-c01818daa08f"
    ],
    "object": "job",
    "integrations": []
}


In [25]:
import time

retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.jobs.retrieve(created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
                "valid_loss": 2.069736,
                "valid_mean_token_accuracy": 4.198099
            },
            "step_number": 40,
            "created_at": 1719693170
        }
    ],
    "estimated_start_time": null
}
Job is RUNNING, waiting 10 seconds
{
    "id": "d82784af-b4fa-41fa-a70e-44692e7e3dd6",
    "hyperparameters": {
        "training_steps": 200,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1719692942,
    "modified_at": 1719692943,
    "training_files": [
        "ca63c194-cb59-4d2e-ae1a-f3b0ee6a538c"
    ],
    "validation_files": [
        "4b01fdec-95f3-434c-925f-c01818daa08f"
    ],
    "object": "job",
    "integrations": [],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "c

In [26]:
# List jobs
jobs = client.jobs.list()
pprint(jobs)

{
    "data": [
        {
            "id": "d82784af-b4fa-41fa-a70e-44692e7e3dd6",
            "hyperparameters": {
                "training_steps": 200,
                "learning_rate": 0.0001
            },
            "fine_tuned_model": "ft:open-mistral-7b:66e9a26f:20240629:d82784af",
            "model": "open-mistral-7b",
            "status": "SUCCESS",
            "job_type": "FT",
            "created_at": 1719692942,
            "modified_at": 1719694033,
            "training_files": [
                "ca63c194-cb59-4d2e-ae1a-f3b0ee6a538c"
            ],
            "validation_files": [
                "4b01fdec-95f3-434c-925f-c01818daa08f"
            ],
            "object": "job",
            "integrations": []
        },
        {
            "id": "14992c62-77d1-4db1-8bc6-0a5267fbf3da",
            "hyperparameters": {
                "training_steps": 100,
                "learning_rate": 0.0001
            },
            "fine_tuned_model": "ft:open-mistral-7b:66e9

In [27]:
# Retrieve a jobs
retrieved_jobs = client.jobs.retrieve(created_jobs.id)
pprint(retrieved_jobs)

{
    "id": "d82784af-b4fa-41fa-a70e-44692e7e3dd6",
    "hyperparameters": {
        "training_steps": 200,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": "ft:open-mistral-7b:66e9a26f:20240629:d82784af",
    "model": "open-mistral-7b",
    "status": "SUCCESS",
    "job_type": "FT",
    "created_at": 1719692942,
    "modified_at": 1719694033,
    "training_files": [
        "ca63c194-cb59-4d2e-ae1a-f3b0ee6a538c"
    ],
    "validation_files": [
        "4b01fdec-95f3-434c-925f-c01818daa08f"
    ],
    "object": "job",
    "integrations": [],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "SUCCESS"
            },
            "created_at": 1719694033
        },
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1719692943
        },
        {
            "name": "status-updated",
            "data": {
             

In [29]:
retrieved_jobs.fine_tuned_model

'ft:open-mistral-7b:66e9a26f:20240629:d82784af'

## Use a fine-tuned model

In [2]:
from mistralai.client import MistralClient

api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)

In [60]:
from mistralai.models.chat_completion import ChatMessage

chat_response = client.chat(
    model=retrieved_jobs.fine_tuned_model,
    messages=[ChatMessage(role='user', content="Traduis en sango: S'il vous offre une bière, refusez-la.")]
)

In [61]:
message = chat_response.choices[0].message
res = dict(message)
res['content']

'sango : Mû biëre na mo, kîri na.'

In [5]:
def read_txt_file_into_list(filename):
  """
  Reads the contents of a text file into a list.

  Args:
    filename: The name of the text file to read.

  Returns:
    A list containing the lines of the text file.
  """

  with open(filename, 'r') as f:
    lines = f.readlines()
  return lines

french_file = 'dev.fra_Latn' # Corpus from FLORES-200
sango_file = 'dev.sag_Latn' # Corpus from FLORES-200
french_data = read_txt_file_into_list(french_file)
sango_data = read_txt_file_into_list(sango_file)

print(f"Number of lines in {french_file}: {len(french_data)}")
print(f"Number of lines in {sango_file}: {len(sango_data)}")

Number of lines in dev.fra_Latn: 997
Number of lines in dev.sag_Latn: 997


**Note :** For the `dev.fra_Latn` and `dev.sag_Latn` files, you'll need to download them manually from the [official repository](https://github.com/openlanguagedata/flores) of FLORES-200.

The latest version of the dataset can be downloaded in the Releases tab of this repository. It is available as a zip archive, with password multilingual machine translation. The data is only available in this format in order to avoid it being picked up by crawlers, which would lead to it being accidentally included in the sort of web corpora often used to train LLMs and large scale machine translation models, rendering it useless as a benchmark.

In [7]:
import time
from mistralai.models.chat_completion import ChatMessage

model_predictions = []

for text in french_data[:100]:
  chat_response = client.chat(
      model=retrieved_jobs.fine_tuned_model,
      messages=[ChatMessage(role='user', content=f"Traduis la phrase suivante en Sango: {text}")]
  )

  pred = chat_response.choices[0].message.content.split('sango :')[-1].strip()
  print(f"Prediction: {pred}")
  model_predictions.append(pred)

Prediction: Alîngbi tî fadësû tî âwandarä tî dalisoro tî Stanford akônda na pekô tî lâsô. Âla sû âgerê tî hînga wala âkêtê kêsi tî sâra tënë sô alöndö na lêgë nî tî hînga sëndëngö sêngê sêngê, sô ayeke tene ngâ tî sâra na gerê tî sû wala tî kutukutu tî mbëtïngû.
Prediction: Na lêgë tî âmokondö kûê, sô asâra sï, ngbanga tî kobêla tî kankûi, tî kötä-mbs, tî VIH na tî palüsïi na yâ tî âködörö tî nginza tî nî sô ayeke âkutukutu, sï kobêla tî kankûi tî bôbanga tî âwâlï so ague na ngangü, azîa na yâ tî âködörö tî nginza ahön ndönî.
Prediction: Sango : Masïni-lêgë tî laparaäo sô ayeke lô na yâ tî laparaäo ngâ na tângo tî kä na 9 h 30, lo tïgbïngö, ngâ lo tïngo ngâ lo bira, sï kua tî laparaäo tî wöngö-kôbe a yeke zîa.
Prediction: Sango : Pilôto atö ndâ tî lo na gbe tî komandâgbo Tilokrit Pattavee.
Prediction: Lêndo tî wërë sô atene : Mopepe tî lëkëngö-wâ afâ na hürüngö
Prediction: Sango : Vidal, balë-ôko na mbâgë tî o menë, sî bîakü asï na Barça.
Prediction: Sango : Lekere ngoi na ngoi sô Vida

In [10]:
clean_predictions = [pred.replace("Sango : ", "").strip() for pred in model_predictions]
clean_predictions

['Alîngbi tî fadësû tî âwandarä tî dalisoro tî Stanford akônda na pekô tî lâsô. Âla sû âgerê tî hînga wala âkêtê kêsi tî sâra tënë sô alöndö na lêgë nî tî hînga sëndëngö sêngê sêngê, sô ayeke tene ngâ tî sâra na gerê tî sû wala tî kutukutu tî mbëtïngû.',
 'Na lêgë tî âmokondö kûê, sô asâra sï, ngbanga tî kobêla tî kankûi, tî kötä-mbs, tî VIH na tî palüsïi na yâ tî âködörö tî nginza tî nî sô ayeke âkutukutu, sï kobêla tî kankûi tî bôbanga tî âwâlï so ague na ngangü, azîa na yâ tî âködörö tî nginza ahön ndönî.',
 'Masïni-lêgë tî laparaäo sô ayeke lô na yâ tî laparaäo ngâ na tângo tî kä na 9 h 30, lo tïgbïngö, ngâ lo tïngo ngâ lo bira, sï kua tî laparaäo tî wöngö-kôbe a yeke zîa.',
 'Pilôto atö ndâ tî lo na gbe tî komandâgbo Tilokrit Pattavee.',
 'Lêndo tî wërë sô atene : Mopepe tî lëkëngö-wâ afâ na hürüngö',
 'Vidal, balë-ôko na mbâgë tî o menë, sî bîakü asï na Barça.',
 'Lekere ngoi na ngoi sô Vidal alöndö na ndokua tî Katalâna, lo sâra kâmba tî bângâ tî bôsôngbi 49.',
 'Yê sô angbâ tî 

In [11]:
print(f"Len of model_predictions: {len(model_predictions)}")
print(f"First prediction: {model_predictions[0]}")
print(f"Last prediction: {model_predictions[-1]}")

Len of model_predictions: 100
First prediction: Alîngbi tî fadësû tî âwandarä tî dalisoro tî Stanford akônda na pekô tî lâsô. Âla sû âgerê tî hînga wala âkêtê kêsi tî sâra tënë sô alöndö na lêgë nî tî hînga sëndëngö sêngê sêngê, sô ayeke tene ngâ tî sâra na gerê tî sû wala tî kutukutu tî mbëtïngû.
Last prediction: Kürü bätängö mabôko tî wâkua tî kpöngö ndo na ngangü tî gerê-wâ sô, sï ngangü tî kängö kä ayeke sïgî na yâ tî mangbôkô tî kutukutu sô na yâ tî gbätä tî âzo tî Fort Greely.


In [41]:
from datasets import load_metric

references = sango_data[:100]

bleu_metric = load_metric('bleu')

predictions_list = [prediction.split() for prediction in clean_predictions]
references_list = [[reference.split()] for reference in references]

bleu_result = bleu_metric.compute(predictions=predictions_list, references=references_list)
print("BLEU Score:", bleu_result)

BLEU Score: {'bleu': 0.0049036996543233, 'precisions': [0.16145092460881935, 0.01696165191740413, 0.0015308075009567547, 0.000397456279809221], 'brevity_penalty': 0.7675279602865839, 'length_ratio': 0.7907761529808774, 'translation_length': 2812, 'reference_length': 3556}


In [None]:
import nltk

# Download necessary NLTK resources
nltk.download('wordnet')
nltk.download('omw-1.4')  # Optional, for additional languages

In [49]:
import nltk
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sacrebleu import corpus_ter as sacre_corpus_ter
from nltk.translate.meteor_score import meteor_score


def compute_metrics(references, predictions):
    # Tokenize the sentences
    references_tokens = [[ref.split()] for ref in references]
    predictions_tokens = [pred.split() for pred in predictions]

    # BLEU score
    smoothie = SmoothingFunction().method4
    bleu_score = corpus_bleu(references_tokens, predictions_tokens, smoothing_function=smoothie)

    # ROUGE score
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = [rouge.score(ref, pred) for ref, pred in zip(references, predictions)]
    avg_rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores)
    avg_rouge2 = sum([score['rouge2'].fmeasure for score in rouge_scores]) / len(rouge_scores)
    avg_rougeL = sum([score['rougeL'].fmeasure for score in rouge_scores]) / len(rouge_scores)

    # METEOR score
    meteor_scores = [meteor_score([ref.split()], pred.split()) for ref, pred in zip(references, predictions)]
    avg_meteor = sum(meteor_scores) / len(meteor_scores)

    # TER score
    ter_score = sacre_corpus_ter(predictions, [references])

    return {
        "BLEU": bleu_score,
        "ROUGE-1": avg_rouge1,
        "ROUGE-2": avg_rouge2,
        "ROUGE-L": avg_rougeL,
        "METEOR": avg_meteor,
        "TER": ter_score.score
    }


metrics = compute_metrics(references, clean_predictions)
metrics

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


{'BLEU': 0.004899409054457557,
 'ROUGE-1': 0.25038912565100213,
 'ROUGE-2': 0.03715914406924841,
 'ROUGE-L': 0.18183284738555094,
 'METEOR': 0.07671600339921887,
 'TER': 95.78177727784028}