<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Mistral_Finetune_API_T2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install colab-env -q
!pip install mistralai -q
!pip install datasets -q
import colab_env

In [None]:
from datasets import load_dataset

# Convert dataset to OAI messages
system_message = """You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
{schema}"""

def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message.format(schema=sample["context"])},
      {"role": "user", "content": sample["question"]},
      {"role": "assistant", "content": sample["answer"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("b-mc2/sql-create-context", split="train")
dataset = dataset.shuffle().select(range(70000))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)

dataset = dataset.train_test_split(test_size=0.15)  # Use 0.15 for 15% test size

print(dataset["train"][345]["messages"])

# save datasets to disk
dataset["train"].to_json("train_dataset.jsonl", orient="records", lines=True)
dataset["test"].to_json("test_dataset.jsonl", orient="records", lines=True)

In [None]:
# download the validation and reformat script
!wget https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py

In [187]:
!python reformat_data.py train_dataset.jsonl
!python reformat_data.py test_dataset.jsonl

In [188]:
dataset

DatasetDict({
    train: Dataset({
        features: ['messages'],
        num_rows: 59500
    })
    test: Dataset({
        features: ['messages'],
        num_rows: 10500
    })
})

## Reformat dataset
If you upload this ultrachat_chunk_train.jsonl to Mistral API, you might encounter an error message “Invalid file format” due to data formatting issues. To reformat the data into the correct format, you can download the reformat_dataset.py script and use it to validate and reformat both the training and evaluation data:

In [5]:
# download the validation and reformat script
!wget https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py

--2025-03-19 05:21:44--  https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3381 (3.3K) [text/plain]
Saving to: ‘reformat_data.py’


2025-03-19 05:21:44 (31.3 MB/s) - ‘reformat_data.py’ saved [3381/3381]



In [189]:
# validate the reformat the  data
!python reformat_data.py /content/train_dataset.jsonl
!python reformat_data.py /content/test_dataset.jsonl

In [190]:
dataset['train'][45678]['messages']

[{'content': 'You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\nSCHEMA:\nCREATE TABLE table_name_89 (detriment VARCHAR, domicile VARCHAR, sign VARCHAR)',
  'role': 'system'},
 {'content': 'Which detriment has a domicile of mercury and Virgo as a sign?',
  'role': 'user'},
 {'content': 'SELECT detriment FROM table_name_89 WHERE domicile = "mercury" AND sign = "virgo"',
  'role': 'assistant'}]

## Upload dataset

In [191]:
from mistralai import Mistral
import os

api_key = os.environ["MISTRAL_API_KEY"]

client = Mistral(api_key=api_key)

chunk_train = client.files.upload(file={
    "file_name": "/content/train_dataset.jsonl",
    "content": open("/content/train_dataset.jsonl", "rb"),
})
chunk_eval = client.files.upload(file={
    "file_name": "/content/test_dataset.jsonl",
    "content": open("/content/test_dataset.jsonl", "rb"),
})

In [192]:
import json
def pprint(obj):
    print(json.dumps(obj.model_dump(), indent=4))

In [193]:
pprint(chunk_train)

{
    "id": "b0352eff-6dd1-45c2-852e-753c38dfae10",
    "object": "file",
    "bytes": null,
    "created_at": 1742368415,
    "filename": "train_dataset.jsonl",
    "purpose": "fine-tune",
    "sample_type": "instruct",
    "source": "upload",
    "num_lines": 59500
}


In [194]:
pprint(chunk_eval)

{
    "id": "e2948523-72c1-437a-8647-2bd2c868719e",
    "object": "file",
    "bytes": null,
    "created_at": 1742368417,
    "filename": "test_dataset.jsonl",
    "purpose": "fine-tune",
    "sample_type": "instruct",
    "source": "upload",
    "num_lines": 10500
}


## Create a fine-tuning job

In [199]:
created_jobs = client.fine_tuning.jobs.create(
    model="open-mistral-7b",
    training_files=[{"file_id": chunk_train.id, "weight": 1}],
    validation_files=[chunk_eval.id],
    hyperparameters={
    "training_steps": 10,
    "learning_rate":0.0001
    },
    auto_start=True
)
created_jobs

JobOut(id='42325165-05e2-42e5-9759-2a7eaf7d9286', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.0001, weight_decay=0.1, warmup_fraction=0.05, epochs=None, fim_ratio=None, seq_len=32768), model='open-mistral-7b', status='QUEUED', job_type='completion', created_at=1742368691, modified_at=1742368691, training_files=['b0352eff-6dd1-45c2-852e-753c38dfae10'], validation_files=['e2948523-72c1-437a-8647-2bd2c868719e'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=None, cost=0.0, cost_currency=None, train_tokens_per_step=None, train_tokens=None, data_tokens=None, estimated_start_time=None))

In [200]:
pprint(created_jobs)

{
    "id": "42325165-05e2-42e5-9759-2a7eaf7d9286",
    "auto_start": true,
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.0001,
        "weight_decay": 0.1,
        "warmup_fraction": 0.05,
        "epochs": null,
        "fim_ratio": null,
        "seq_len": 32768
    },
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "completion",
    "created_at": 1742368691,
    "modified_at": 1742368691,
    "training_files": [
        "b0352eff-6dd1-45c2-852e-753c38dfae10"
    ],
    "validation_files": [
        "e2948523-72c1-437a-8647-2bd2c868719e"
    ],
    "fine_tuned_model": null,
    "suffix": null,
    "integrations": [],
    "trained_tokens": null,
    "repositories": [],
    "metadata": {
        "expected_duration_seconds": null,
        "cost": 0.0,
        "cost_currency": null,
        "train_tokens_per_step": null,
        "train_tokens": null,
        "data_tokens": null,
        "estimated_start_time": null
    }
}


In [201]:
jobs = client.fine_tuning.jobs.list()
print(jobs)

total=10 data=[JobOut(id='42325165-05e2-42e5-9759-2a7eaf7d9286', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.0001, weight_decay=0.1, warmup_fraction=0.05, epochs=0.17660251163285778, fim_ratio=None, seq_len=32768), model='open-mistral-7b', status='RUNNING', job_type='completion', created_at=1742368691, modified_at=1742368703, training_files=['b0352eff-6dd1-45c2-852e-753c38dfae10'], validation_files=['e2948523-72c1-437a-8647-2bd2c868719e'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=300, cost=2.63, cost_currency='USD', train_tokens_per_step=131072, train_tokens=1310720, data_tokens=7421865, estimated_start_time=None)), JobOut(id='11d93414-8a31-4669-9683-276ef92a79cc', auto_start=True, hyperparameters=TrainingParameters(training_steps=1000, learning_rate=0.0001, weight_decay=0.1, warmup_fraction=0.05, epochs=17.66025116328578, fim_rati

In [202]:
retrieved_jobs = client.fine_tuning.jobs.get(job_id = created_jobs.id)
retrieved_jobs

DetailedJobOut(id='42325165-05e2-42e5-9759-2a7eaf7d9286', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.0001, weight_decay=0.1, warmup_fraction=0.05, epochs=0.17660251163285778, fim_ratio=None, seq_len=32768), model='open-mistral-7b', status='RUNNING', job_type='completion', created_at=1742368691, modified_at=1742368703, training_files=['b0352eff-6dd1-45c2-852e-753c38dfae10'], validation_files=['e2948523-72c1-437a-8647-2bd2c868719e'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=300, cost=2.63, cost_currency='USD', train_tokens_per_step=131072, train_tokens=1310720, data_tokens=7421865, estimated_start_time=None), events=[EventOut(name='status-updated', created_at=1742368691, data={'status': 'QUEUED'}), EventOut(name='status-updated', created_at=1742368698, data={'status': 'VALIDATED'}), EventOut(name='status-updated', created_at=1742368

In [203]:
import time

retrieved_job = client.fine_tuning.jobs.get(job_id = created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.fine_tuning.jobs.get(job_id = created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)



In [205]:
# List jobs
jobs = client.fine_tuning.jobs.list()
pprint(jobs)

{
    "total": 10,
    "data": [
        {
            "id": "42325165-05e2-42e5-9759-2a7eaf7d9286",
            "auto_start": true,
            "hyperparameters": {
                "training_steps": 10,
                "learning_rate": 0.0001,
                "weight_decay": 0.1,
                "warmup_fraction": 0.05,
                "epochs": 0.17660251163285778,
                "fim_ratio": null,
                "seq_len": 32768
            },
            "model": "open-mistral-7b",
            "status": "SUCCESS",
            "job_type": "completion",
            "created_at": 1742368691,
            "modified_at": 1742368986,
            "training_files": [
                "b0352eff-6dd1-45c2-852e-753c38dfae10"
            ],
            "validation_files": [
                "e2948523-72c1-437a-8647-2bd2c868719e"
            ],
            "fine_tuned_model": "ft:open-mistral-7b:9c9073bb:20250319:42325165",
            "suffix": null,
            "integrations": [],
            

In [206]:
# Retrieve a jobs
retrieved_jobs = client.fine_tuning.jobs.get(job_id = created_jobs.id)
pprint(retrieved_jobs)

{
    "id": "42325165-05e2-42e5-9759-2a7eaf7d9286",
    "auto_start": true,
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.0001,
        "weight_decay": 0.1,
        "warmup_fraction": 0.05,
        "epochs": 0.17660251163285778,
        "fim_ratio": null,
        "seq_len": 32768
    },
    "model": "open-mistral-7b",
    "status": "SUCCESS",
    "job_type": "completion",
    "created_at": 1742368691,
    "modified_at": 1742368986,
    "training_files": [
        "b0352eff-6dd1-45c2-852e-753c38dfae10"
    ],
    "validation_files": [
        "e2948523-72c1-437a-8647-2bd2c868719e"
    ],
    "fine_tuned_model": "ft:open-mistral-7b:9c9073bb:20250319:42325165",
    "suffix": null,
    "integrations": [],
    "trained_tokens": 1310720,
    "repositories": [],
    "metadata": {
        "expected_duration_seconds": 300,
        "cost": 2.63,
        "cost_currency": "USD",
        "train_tokens_per_step": 131072,
        "train_tokens": 1310720,
        "d

## Use a fine-tuned model

In [268]:
#'expedted answer': 'SELECT detriment FROM table_name_89 WHERE domicile = "mercury" AND sign = "virgo"'


chat_response = client.chat.complete(
    model = retrieved_jobs.fine_tuned_model,
    messages = [{"role":'system', "content": 'You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA. SCHEMA: CREATE TABLE table_name_89 (detriment VARCHAR, domicile VARCHAR, sign VARCHAR)'}] +
               [{"role":'user', "content":'Which detriment has a domicile of mercury and Virgo as a sign?'}]
)

# Extract the content from the response
content = chat_response.choices[0].message.content

# Now you can use replace() on the content string
modified_content = content.replace('\\"', "'")

print(f"Generated query: {modified_content}")


Generated query: SELECT DISTINCT detriment FROM table_name_89 WHERE domicile = "mercury" AND sign = "virgo"


In [259]:
client.fine_tuning.jobs.create(
    model="open-mistral-7b",
    training_files=[{"file_id": chunk_train.id, "weight": 1}],
    validation_files=[chunk_eval.id],
    hyperparameters={"training_steps": 10, "learning_rate": 0.0001},
    integrations=[
        {
            "project": "<value>",
            "api_key": "<value>",
        }
    ]
)

SDKError: API error occurred: Status 422
{"detail": [{"type": "string_too_short", "loc": ["body", "job_in", "integrations", 0, "wandb", "api_key"], "msg": "String should have at least 40 characters", "ctx": {"min_length": 40}}]}