# Fine tuning of mistral model using metabolic pathway information

In [1]:
import os
import json
import random
from dotenv import load_dotenv, dotenv_values 
from mistralai.client import MistralClient
from mistralai.models.jobs import WandbIntegrationIn, TrainingParameters

In [2]:
load_dotenv()
api_key = os.getenv("MISTRAL_API_KEY")
wandb_api_key = os.getenv("WANDB_API_KEY")

## Data preparation

We have to split the dataset in train (95%) and val (5%).
If the data still have the 'target' entry (used for dataset verification), we need to remove it.

In [16]:
def process_split_jsonl(datafile, outf_train, outf_val):

    entries = []
    with open(datafile, 'r') as infile:
        for line in infile:
            try:
                entry = json.loads(line)
                processed_entry = {'messages': entry['messages']}
                entries.append(processed_entry)
            except json.JSONDecodeError:
                print(f"Ignoring invalid JSON: {line}")
                
    random.shuffle(entries)
    t_size = int(0.95*len(entries))
    
    with open(outf_train, 'w') as outfile:
        for entry in entries[:t_size]:
            json.dump(entry, outfile)
            outfile.write('\n')
    with open(outf_val, 'w') as outfile:
        for entry in entries[t_size:]:
            json.dump(entry, outfile)
            outfile.write('\n')

In [17]:
process_split_jsonl("dataset/train_dataset_lab.jsonl", "train/train.jsonl", "train/val.jsonl" )

## Training

First, dataset for train and validation are saved on the client.
Then, a job is created to fine-tune model open-mistral-7b with the metabolic pathway dataset.
Weight and biases (wandb) integration is added to record metrics from the training. This is done via Weight and biases website registration, to obtain an API key.

After a few minutes, we can check the job state to see if it is finished (SUCCESS). If so, the fine-tuned model can then be used for [evaluation](Evaluation.ipynb).

In [18]:
client = MistralClient(api_key=api_key)

In [19]:
with open("train/train.jsonl", "rb") as f:
    dbexamples_train = client.files.create(file=("dbex1000_train.jsonl", f))
with open("train/val.jsonl", "rb") as f:
    dbexamples_eval = client.files.create(file=("dbex1000_eval.jsonl", f))

id='2ac15edd-3e5f-428f-a9b6-6f82ca081d1f' object='file' bytes=1282888 created_at=1719571742 filename='dbex1000_train.jsonl' purpose='fine-tune'
id='94855080-ca1b-4161-82de-7cd00c1a842f' object='file' bytes=68452 created_at=1719571742 filename='dbex1000_eval.jsonl' purpose='fine-tune'


In [20]:
print(dbexamples_train)
print(dbexamples_eval)

id='2ac15edd-3e5f-428f-a9b6-6f82ca081d1f' object='file' bytes=1282888 created_at=1719571742 filename='dbex1000_train.jsonl' purpose='fine-tune'
id='94855080-ca1b-4161-82de-7cd00c1a842f' object='file' bytes=68452 created_at=1719571742 filename='dbex1000_eval.jsonl' purpose='fine-tune'


In [26]:
created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[dbexamples_train.id],
    validation_files=[dbexamples_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=11,
        learning_rate=0.0001,
    ),
    integrations=[
        WandbIntegrationIn(
            project="metabo_7B",
            run_name="dbex1000_20240628",
            api_key=wandb_api_key,
        ).dict()
    ]
)
created_jobs

Job(id='de14ae4a-beae-40ce-a6ca-ab97c9a65f06', hyperparameters=TrainingParameters(training_steps=11, learning_rate=0.0001), fine_tuned_model=None, model='open-mistral-7b', status='QUEUED', job_type='FT', created_at=1719573071, modified_at=1719573071, training_files=['2ac15edd-3e5f-428f-a9b6-6f82ca081d1f'], validation_files=['94855080-ca1b-4161-82de-7cd00c1a842f'], object='job', integrations=[WandbIntegration(type='wandb', project='metabo_7B', name=None, run_name='dbex1000_20240628')])

In [28]:
retrieved_jobs = client.jobs.retrieve('e63dc09f-04a9-4678-acb7-b3b38bf8e7d2')

print(retrieved_jobs.checkpoints)

[Checkpoint(metrics=Metric(train_loss=0.41102, valid_loss=0.439575, valid_mean_token_accuracy=1.356204), step_number=10, created_at=1719527313)]
