# Fine-tuning Mistral

Thanks to [our first notebook](a_DatasetCreation.ipynb), we now have a training dataset containing *30 sentences* at each of the levels $\{A2, B1, B2, C1, C2\}$, together with their simplified versions at levels $\{A1, A2, B1, B2, C1\}$. We will now fine-tune a version of **Mistral** specially designed for French on this dataset, in order to obtain a model capable of simplifying French sentences.

In [1]:
# ---------------------------- PREPARING NOTEBOOK ---------------------------- #
# Notebook-wide setup: auto-reload edited modules, seed NumPy, configure logging,
# and make the root of the enclosing git repository the working directory.

# Autoreload
%load_ext autoreload
%autoreload 2

# Random seed
# NOTE(review): only NumPy's global RNG is seeded here; Python's `random` and
# torch are not seeded by this cell.
import numpy as np
np.random.seed(42)

# External modules
import os

# Set global log level
import logging
logging.basicConfig(level=logging.INFO)
# Turn off Hugging Face `tokenizers` parallelism through its environment variable
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Define PWD as the current git repository
# (search_parent_directories=True lets this run from any sub-folder of the repo;
# `pwd` is reused by later cells to build absolute paths)
import git
repo = git.Repo('.', search_parent_directories=True)
pwd = repo.working_dir
os.chdir(pwd)

# import

In [2]:
# -------------------------- LOAD PREVIOUS NOTEBOOKS ------------------------- #
import json
import __main__
import black

# Notebooks whose "# import"-tagged code cells are re-executed in this kernel so
# that the functions they define (e.g. `a_DatasetCreation.download_data`) are
# available here.
paths = [
    os.path.join(pwd, "notebooks", "text_simplification", "a_DatasetCreation.ipynb"),
]

# Read notebooks
code_dict = {}
for path in paths:
    code = ""
    # .ipynb files are JSON and the code contains accented (French) text, so
    # read as UTF-8 explicitly instead of relying on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        temp = json.load(f)

    # Keep only code cells whose last source line is exactly the "# import" marker
    cells = [
        cell
        for cell in temp["cells"]
        if cell["cell_type"] == "code"
        and len(cell["source"]) > 0
        and cell["source"][-1] == "# import"
    ]
    # Concatenate those cells, dropping the marker and IPython magics ("%...")
    notebook_code = "\n".join(
        line
        for cell in cells
        for line in cell["source"]
        if line != "# import" and len(line) > 0 and line[0] != "%"
    )
    # Create something like a header
    code += f"# {'-'*76} #\n"
    code += f"# {os.path.basename(path).upper():^76} #\n"
    code += f"# {'-'*76} #\n"
    code += notebook_code

    # Add "Module Creation": after exec, every public class/function currently in
    # the namespace is also exposed as an attribute of an object named after the
    # notebook (e.g. `a_DatasetCreation.download_data`).
    notebook_name = (
        os.path.basename(path).replace("imported_", "").replace(".ipynb", "")
    )
    code += """
# --------------------------------- IMPORTER --------------------------------- #
import types


class MyNotebook:
    pass


NOTEBOOK_NAME = MyNotebook()
# Put every function defined in the notebook in the class
NOTEBOOK_NAME.__dict__.update(
    {
        name: obj
        for name, obj in locals().items()
        if isinstance(obj, (type, types.FunctionType))
        if not (name.startswith("_") or name == "MyNotebook")
    }
)
    """.replace(
        "NOTEBOOK_NAME", notebook_name
    )

    # Remove empty lines
    code = "\n".join([line for line in code.split("\n") if len(line) > 0])
    # Format code
    code = black.format_str(code, mode=black.FileMode())

    # Write scratch file. Use a separate name so the loop variable `path`
    # (the source notebook) is not overwritten.
    scratch_path = os.path.join(
        pwd, "scratch", f"imported_{os.path.basename(path).replace('ipynb', 'py')}"
    )
    os.makedirs(os.path.dirname(scratch_path), exist_ok=True)
    with open(scratch_path, "w", encoding="utf-8") as f:
        f.write(code)
    # Keyed by the scratch file so tracebacks from exec point at a real file
    code_dict[scratch_path] = code


# Mainify code: execute in __main__'s namespace so the definitions behave as if
# they had been typed into this notebook.
for path, code in code_dict.items():
    compiled = compile(code, path, "exec")
    exec(compiled, __main__.__dict__)

# import

## Load the model

We create a function that loads the **Mistral-7B**-based *Vigostral* model and applies a **LoRA** (Low-Rank Adaptation) configuration to it.

In [3]:
import torch
from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

def load_model(model_name: str):
    """Load a causal language model and wrap it with a LoRA adapter.

    Args:
        model_name: Model id or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The PEFT model produced by ``get_peft_model``, with gradient
        checkpointing enabled and ``config.use_cache`` set to ``False``.
    """
    # Load model: first let the library choose the placement ("auto"); if that
    # fails for any reason, retry on a single explicit device.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", use_cache=False
        )
    except Exception as err:  # not a bare except: keep Ctrl-C / SystemExit working
        import logging

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.getLogger(__name__).warning(
            "Loading %s with device_map='auto' failed (%s); retrying with device_map=%r",
            model_name,
            err,
            device,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device, use_cache=False
        )

    # Configure model: LoRA adapters on the attention and MLP projections and
    # on the output head.
    config = LoraConfig(
        r=64,  # The larger r is, the more accurate the model but the slower it is
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )
    model.gradient_checkpointing_enable()
    # NOTE(review): the model is not loaded with a quantization config here, yet
    # it goes through prepare_model_for_kbit_training — confirm this is intended.
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, config)
    # KV-cache is not used during training (and conflicts with checkpointing)
    model.config.use_cache = False

    return model

# import

## Cluster Fine-Tuning function

First we're going to create the function that will run on the cluster.

In [4]:
# --------------------------- FINE-TUNING FUNCTION --------------------------- #
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
)
import os
import torch
from datetime import datetime

MODEL = "bofenghuang/vigostral-7b-chat"


def train_mistral(
    pwd: str = ".",
    use_ray: bool = False,
    max_training_epochs: int = 20,
):
    """Fine-tune ``MODEL`` with LoRA on the sentence-simplification dataset.

    Args:
        pwd: Repository root. Data is loaded relative to it, and checkpoints are
            written under ``<pwd>/models/text_simplification/<MODEL>`` (with
            "/" replaced by "_").
        use_ray: When True, attach Ray Train's report callback and wrap the
            trainer with ``prepare_trainer`` so it can run inside a Ray
            ``TorchTrainer`` worker.
        max_training_epochs: Maximum number of epochs. Early stopping
            (patience 3 on the per-epoch evaluation) may stop training sooner.

    Returns:
        The value returned by ``trainer.train()``.
    """
    if use_ray:
        # Fix partial import bug: import the submodules explicitly so that the
        # attribute chain ``ray.train.huggingface.transformers`` used below
        # resolves. Only done (and Ray only required) when running under Ray.
        import ray.train.huggingface
        import ray.train.huggingface.transformers

    # Load data, tokenizer and build the encoded dataset (helpers come from the
    # re-executed a_DatasetCreation notebook)
    df = a_DatasetCreation.download_data(pwd=pwd)
    tokenizer = a_DatasetCreation.download_tokenizer()
    dataset = a_DatasetCreation.format_data(df, tokenizer, training=True)
    encoded_dataset = a_DatasetCreation.encode_dataset(dataset, tokenizer)

    # Create train and validation datasets: the final 25% of rows (no shuffle).
    # NOTE(review): if the dataset is ordered by level, the validation set only
    # covers the last level(s) — confirm ordering or enable shuffling.
    split = encoded_dataset.train_test_split(test_size=0.25, shuffle=False)
    train_dataset = split["train"]
    validation_dataset = split["test"]

    # Load model
    model = load_model(MODEL)

    # Create model folder if it doesn't exist (exist_ok avoids a check/create race)
    path = os.path.join(
        pwd,
        "models",
        "text_simplification",
        MODEL.replace("/", "_"),
    )
    os.makedirs(path, exist_ok=True)

    # Configure WandB (process-wide environment variable)
    os.environ["WANDB_PROJECT"] = "mistral_sentence_simplification"

    # Early stopping after 3 evaluations without improvement; relies on
    # load_best_model_at_end=True below
    early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

    # Create the Trainer
    training_args = TrainingArguments(
        output_dir=path,
        warmup_steps=1,
        num_train_epochs=max_training_epochs,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=1,
        learning_rate=2.5e-5,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_steps=25,
        logging_dir=os.path.join(path, "logs"),
        save_strategy="epoch",
        evaluation_strategy="epoch",
        do_eval=True,
        load_best_model_at_end=True,
        run_name=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        # Original note: "Necessary for FSDP to work".
        # NOTE(review): this flag configures DDP — confirm which backend needs it.
        ddp_find_unused_parameters=False,
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,  # the model to train
        args=training_args,  # the training arguments
        train_dataset=train_dataset,  # the training dataset
        eval_dataset=validation_dataset,  # the evaluation dataset
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        callbacks=[early_stopping],
    )

    if use_ray:
        # Add ray tune callback
        callback = ray.train.huggingface.transformers.RayTrainReportCallback()
        trainer.add_callback(callback)
        trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)

    # Train the model
    result = trainer.train()

    # Write checkpoint
    checkpoint_path = os.path.join(path, "mistral_simplification_trained")
    trainer.save_model(checkpoint_path)

    return result

## Cluster Execution function

We now need to create the function that Slurmray executes on the cluster. We need a way to start Ray correctly and retrieve the training results.

In [5]:
# ------------------------ CLUSTER EXECUTION FUNCTION ------------------------ #
from ray.train import ScalingConfig, CheckpointConfig, RunConfig
from ray.train.torch import TorchTrainer


def ray_launcher(config):
    """Ray Train loop entry point that calls ``config["f"](**config["kwargs"])``.

    ``TorchTrainer`` hands its ``train_loop_config`` dict to this function. It
    unpacks the target callable and its keyword arguments, then returns
    whatever the callable returns.
    """
    # Look up "kwargs" before "f" (same lookup order as before)
    target_kwargs, target = config["kwargs"], config["f"]
    return target(**target_kwargs)


def slurmray_function(
    pwd: str = "/scratch/hjamet",
    max_training_epochs: int = 20,
):
    """Run ``train_mistral`` inside a single-GPU Ray ``TorchTrainer``.

    This is the function Slurmray executes on the cluster.

    Args:
        pwd: Working directory on the cluster. It is forwarded to
            ``train_mistral`` and also used as Ray's ``storage_path``.
        max_training_epochs: Forwarded to ``train_mistral``.

    Returns:
        The object returned by ``TorchTrainer.fit()``, so the training results
        reach the caller. Previously this value was computed and then
        discarded, so the function returned ``None``.
    """
    # Create ray launcher: one GPU worker running ray_launcher, which calls
    # train_mistral(**kwargs)
    ray_trainer = TorchTrainer(
        ray_launcher,
        scaling_config=ScalingConfig(num_workers=1, use_gpu=True),
        run_config=RunConfig(
            checkpoint_config=CheckpointConfig(num_to_keep=1),
            storage_path=pwd,
        ),
        train_loop_config={
            "f": train_mistral,
            "kwargs": {
                "pwd": pwd,
                "use_ray": True,
                "max_training_epochs": max_training_epochs,
            },
        },
    )

    # Start training and hand the result back to the caller
    result = ray_trainer.fit()
    return result

## Local Launcher

Now that we have defined the code to be run on the cluster, we can create a function that launches the training. To do this, we use the **Slurmray** module.

In [6]:
# ----------------------------- SLURMRAY LAUNCHER ---------------------------- #
from slurmray.RayLauncher import RayLauncher

# Slurmray launcher: serializes `slurmray_function` with its arguments and
# submits it to the Slurm cluster reached over SSH.
launcher = RayLauncher(
    # What to run
    project_name="mistral_sentence_simplification",
    func=slurmray_function,
    args={"max_training_epochs": 20},
    # Requested resources
    # NOTE(review): units presumably GB (memory) and minutes (max_running_time,
    # i.e. 2 h) — confirm against the slurmray documentation.
    node_nbr=1,
    use_gpu=True,
    memory=128,
    max_running_time=2 * 60,
    modules=["cuda/11.8.0"],
    # Remote submission
    server_run=True,
    server_ssh="curnagl.dcsr.unil.ch",
    server_username="hjamet",
)

## Launch the training

We can now launch the training on the cluster using the launcher defined above!

In [8]:
# Serialize the function and arguments, connect to the cluster and submit the
# job (see the log below). `result` presumably receives the value returned by
# `slurmray_function` on the cluster — confirm with the slurmray documentation.
result = launcher()

Serializing function and arguments...
Connecting to the cluster...


INFO:paramiko.transport:Connected (version 2.0, client OpenSSH_8.0)
INFO:paramiko.transport:Authentication (password) successful!
INFO:paramiko.transport.sftp:[chan 0] Opened sftp connection (server version 3)


Writing slurmray server script...
Downloading server...
Running server...
Installing slurmray server
Writing python script...
Writing slurm script...
No serialization done.
Cluster detected, running on cluster...
Canceling old jobs...
Start to submit job!
Job submitted! Script file is at: </users/hjamet/slurmray-server/.slogs/server/sbatch.sh>. Log file is at: </users/hjamet/slurmray-server/.slogs/server/server_1103-13h45.log>
Start to monitor the queue... You can check the queue at: </users/hjamet/slurmray-server/.slogs/server/server_1103-13h45_queue.log>
Submitted batch job 40078576
IP Head: 10.203.101.86:6379
STARTING HEAD at dnagpu006
2024-03-11 13:45:54,216	INFO usage_lib.py:449 -- Usage stats collection is enabled by default without user confirmation because this terminal is detected to be non-interactive. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See ht