# How to fine-tune Mistral models on AWS Sagemaker

This sample notebook explains how to fine-tune and deploy a custom Mistral model on AWS Sagemaker.

In [None]:
import boto3
import sagemaker as sage
import os
import logging
import re
import json

from typing import (
    Optional,
    Dict,
    Any,
    List,
    Union
)
from botocore.exceptions import ClientError
from botocore.response import StreamingBody
from datetime import datetime
from pathlib import Path

# Logging setup: emit INFO-and-above records formatted as
# "<timestamp> [<LEVEL>] - <message>" with second-resolution timestamps.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Named logger used by the helper functions defined in this notebook.
logger = logging.getLogger("notebook")

In [None]:
# DEFAULT VALUES (change only if/when needed)
# Sent as "VolumeSizeInGB" in the training job's ResourceConfig.
DEFAULT_TRAINING_INSTANCE_VOLUME_SIZE_GB = 200
# Sent as "MaxRuntimeInSeconds" in the training job's StoppingCondition (3600 s = 1 hour).
DEFAULT_TRAINING_MAX_RUNTIME_S = 3600


def build_s3_uri(s3_bucket: str, s3_data_dir: str, object_path: str) -> str:
    """Build an ``s3://`` URI from a bucket name, a directory and an object path.

    The three parts are joined with ``pathlib``, so duplicate and trailing
    slashes are dropped (``"dir/"`` + ``"lora_adapters/"`` gives
    ``dir/lora_adapters``).

    Args:
        s3_bucket: Bucket name, without the ``s3://`` scheme.
        s3_data_dir: Directory (key prefix) inside the bucket.
        object_path: Object or sub-directory path relative to ``s3_data_dir``.

    Returns:
        The assembled URI, always using forward slashes.
    """
    # as_posix() forces "/" separators; str(Path(...)) would produce
    # backslashes when the notebook runs on Windows, yielding an invalid URI.
    return "s3://" + Path(s3_bucket, s3_data_dir, object_path).as_posix()


def generate_ts_str() -> str:
    """Return the current local time as a 12-digit ``YYYYmmddHHMM`` string."""
    now = datetime.now()
    return f"{now:%Y%m%d%H%M}"


def get_model_name_from_arn(arn: str) -> str:
    """Extract the model name embedded in a resource ARN.

    The name is the first substring shaped like ``<family>-<N>b-<digits>``,
    where ``<family>`` is one of mistral, codestral, pixtral or ministral
    (e.g. ``ministral-3b-2410``).

    Args:
        arn: Resource identifier to search.

    Returns:
        The matched model name.

    Raises:
        ValueError: If ``arn`` contains no such substring.
    """
    found = re.search(r"(mistral|codestral|pixtral|ministral)-\d+b-\d+", arn)
    if found is None:
        raise ValueError(f"Could not extract model name from resource {arn}!")
    return found.group(0)

        
def create_mistral_ft_job(job_params: Dict[str, Any]) -> str:
    """Create and submit a SageMaker training job that fine-tunes a Mistral model.

    The job name is derived from the model name found in the algorithm ARN and
    the current timestamp: ``<model>-ft-<YYYYmmddHHMM>-training-job``.

    Args:
        job_params: Job settings. Required keys:
            ``execution_role_arn``, ``algorithm_arn``, ``training_dataset``,
            ``validation_dataset``, ``config_file``, ``lora_dir`` and
            ``instance_type``. The dataset, config and LoRA entries are S3 URIs.

    Returns:
        The ARN of the submitted training job.

    Raises:
        KeyError: If a required key is missing from ``job_params``.
        ValueError: If no model name can be extracted from ``algorithm_arn``.
        Exception: Any error raised by the SageMaker API call is logged and
            re-raised unchanged.
    """
    logger.info("Starting fine-tuning training job creation...")

    # Unpack parameters
    logger.info("Unpacking job_params...")
    try:
        execution_role_arn = job_params["execution_role_arn"]
        algorithm_arn = job_params["algorithm_arn"]
        training_dataset = job_params["training_dataset"]
        validation_dataset = job_params["validation_dataset"]
        config_file = job_params["config_file"]
        lora_dir = job_params["lora_dir"]
        instance_type = job_params["instance_type"]
    except KeyError as e:
        raise KeyError(f"Missing/incorrect parameter: {e}") from e

    ft_job_name = f"{get_model_name_from_arn(algorithm_arn)}-ft-{generate_ts_str()}-training-job"
    logger.info(f"Starting preparation for training job {ft_job_name}...")

    # Input data configuration: one uncompressed, fully replicated channel per input.
    channels = (
        ("training_data", training_dataset),
        ("validation_data", validation_dataset),
        ("configuration", config_file),
    )
    input_data_config: List[Dict[str, Any]] = [
        {
            "ChannelName": channel_name,
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3_uri,
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None"
        }
        for channel_name, s3_uri in channels
    ]

    # Output data configuration
    output_data_config = {
        "S3OutputPath": lora_dir,
        "CompressionType": "GZIP"
    }
    resource_config = {
        "InstanceCount": 1,
        "InstanceType": instance_type,
        "VolumeSizeInGB": DEFAULT_TRAINING_INSTANCE_VOLUME_SIZE_GB
    }
    stopping_condition = {"MaxRuntimeInSeconds": DEFAULT_TRAINING_MAX_RUNTIME_S}

    # API call for training job startup
    logger.info(f"Launching training job {ft_job_name}...")
    # The client is created outside the try block because the except clause
    # reads client.exceptions: if boto3.client() itself failed inside the try,
    # the handler would raise UnboundLocalError and hide the real error.
    client = boto3.client("sagemaker")
    try:
        ft_job = client.create_training_job(
            TrainingJobName=ft_job_name,
            RoleArn=execution_role_arn,
            AlgorithmSpecification={"AlgorithmName": algorithm_arn, "TrainingInputMode": "File"},
            InputDataConfig=input_data_config,
            OutputDataConfig=output_data_config,
            ResourceConfig=resource_config,
            StoppingCondition=stopping_condition
        )
        ft_job_arn = ft_job["TrainingJobArn"]
        logger.info(f"Training job succesfully submitted (ARN: {ft_job_arn})")
        return ft_job_arn
    except (
        client.exceptions.ResourceInUse,
        client.exceptions.ResourceLimitExceeded,
        client.exceptions.ResourceNotFound
    ) as e:
        logger.error(f"Sagemaker error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise

    
def get_mistral_ft_job_info(ft_job_arn) -> Dict[str, Any]:
    """Describe a training job and return a condensed summary of it.

    Args:
        ft_job_arn: ARN of the training job; the job name is taken as the text
            after the last ``/``.

    Returns:
        A dict with the keys ``job_name``, ``status``, ``events`` (the job's
        secondary status transitions), ``model_artifact_uri`` (``None`` unless
        the status is ``"Completed"``) and ``algorithm_arn``.
    """
    logger.info(f"Fetching information for training job {ft_job_arn}...")
    sm_client = boto3.client("sagemaker")
    job_name = ft_job_arn.rsplit('/', 1)[-1]
    description = sm_client.describe_training_job(TrainingJobName=job_name)
    job_status = description["TrainingJobStatus"]
    return {
        "job_name": job_name,
        "status": job_status,
        "events": list(description["SecondaryStatusTransitions"]),
        # The artifact location is only read once the job has completed.
        "model_artifact_uri": (
            description["ModelArtifacts"]["S3ModelArtifacts"]
            if job_status == "Completed"
            else None
        ),
        "algorithm_arn": description["AlgorithmSpecification"]["AlgorithmName"]
    }


def create_mistral_ft_model_package(ft_job_info: Dict[str, Any]) -> str:
    """Create a SageMaker model package from a completed fine-tuning job.

    The package name is the training job name with ``training-job`` replaced
    by ``model-package``.

    Args:
        ft_job_info: Job summary as returned by ``get_mistral_ft_job_info()``.
            The keys ``job_name``, ``algorithm_arn`` and ``model_artifact_uri``
            are read.

    Returns:
        The ARN of the model package whose creation was submitted.

    Raises:
        ValueError: If ``ft_job_info`` carries no model artifact URI, which is
            the case while the training job has not completed.
        Exception: Any error raised by the SageMaker API call is logged and
            re-raised unchanged.
    """
    # Fail fast with a clear message instead of sending a request with a
    # missing ModelDataUrl.
    model_artifact_uri = ft_job_info["model_artifact_uri"]
    if not model_artifact_uri:
        raise ValueError(
            f"Training job {ft_job_info['job_name']} has no model artifact "
            f"(status={ft_job_info.get('status')}); "
            "wait for the job to complete before creating a model package"
        )

    logger.info("Starting model package creation...")
    package_name = ft_job_info["job_name"].replace("training-job", "model-package")
    client = boto3.client("sagemaker")
    try:
        response = client.create_model_package(
            ModelPackageName=package_name,
            SourceAlgorithmSpecification={
                "SourceAlgorithms": [
                    {
                        "AlgorithmName": ft_job_info["algorithm_arn"],
                        "ModelDataUrl": model_artifact_uri
                    }
                ]
            }
        )
        model_package_arn = response["ModelPackageArn"]
        logger.info(f"Model package creation successfully submitted (ARN: {model_package_arn})")
        return model_package_arn
    except (
        client.exceptions.ConflictException,
        client.exceptions.ResourceLimitExceeded
    ) as e:
        logger.error(f"Sagemaker error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise
        
    
def create_mistral_ft_endpoint(model_package_arn: str, instance_type: str) -> str:
    """Deploy a model package as a SageMaker endpoint.

    Three resources are created in sequence: a model, an endpoint
    configuration and an endpoint. The model and the endpoint share one name,
    derived from the last segment of the package ARN with ``model-package``
    replaced by ``model``; the configuration is named ``<model_name>-config``.
    The function returns once the endpoint creation request is accepted; it
    does not poll the endpoint status.

    Args:
        model_package_arn: ARN of the model package to deploy. Its status must
            be ``"Completed"``.
        instance_type: Instance type hosting the endpoint.

    Returns:
        The ARN of the endpoint being created.

    Raises:
        ValueError: If the model package status is not ``"Completed"``.
        Exception: Any error raised while creating the model, the endpoint
            configuration or the endpoint is logged and re-raised unchanged.
    """
    logger.info(f"Starting model endpoint preparation from package {model_package_arn}...")
    client = boto3.client("sagemaker")
    model_package_status = client.describe_model_package(
        ModelPackageName=model_package_arn
    ).get("ModelPackageStatus")
    if model_package_status != "Completed":
        raise ValueError(f"Model package {model_package_arn} is not ready (status={model_package_status})")

    # Create model
    logger.info("Creating model...")
    try:
        model_name = model_package_arn.split('/')[-1].replace("model-package", "model")
        client.create_model(
            ModelName=model_name,
            ExecutionRoleArn=sage.get_execution_role(),
            PrimaryContainer={"ModelPackageName": model_package_arn}
        )
        logger.info(f"Model successfully created (model_name: {model_name})")
    except client.exceptions.ResourceLimitExceeded as e:
        logger.error(f"Sagemaker error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unknown error: {e}")
        raise

    # Create endpoint configuration
    logger.info("Creating endpoint configuration...")
    endpoint_config_name = f"{model_name}-config"
    endpoint_config = {
        "EndpointConfigName": endpoint_config_name,
        "ProductionVariants": [
            {
                "VariantName": "model",
                "ModelName": model_name,
                "InstanceType": instance_type,
                "InitialInstanceCount": 1,
                "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
                "ModelDataDownloadTimeoutInSeconds": 3600
            }
        ]
    }
    try:
        endpoint_config_resp = client.create_endpoint_config(**endpoint_config)
        endpoint_config_arn = endpoint_config_resp["EndpointConfigArn"]
        logger.info(f"Endpoint configuration successfully created (ARN: {endpoint_config_arn})")
    except client.exceptions.ResourceLimitExceeded as e:
        logger.error(f"Sagemaker error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unknown error: {e}")
        raise

    # Create endpoint
    logger.info("Creating endpoint...")
    try:
        endpoint_resp = client.create_endpoint(
            EndpointName=model_name,
            EndpointConfigName=endpoint_config_name
        )
        endpoint_arn = endpoint_resp["EndpointArn"]
        logger.info(f"Endpoint successfully created (ARN: {endpoint_arn})")
        return endpoint_arn
    except client.exceptions.ResourceLimitExceeded as e:
        logger.error(f"Sagemaker error: {e}")
        # Re-raise like the two steps above; swallowing the error made the
        # function return None despite its "-> str" contract.
        raise
    except Exception as e:
        logger.error(f"Unknown error: {e}")
        raise
            
            
def query_mistral_ft_endpoint(
    endpoint_arn: str,
    payload: Union[str, bytes, Dict[str, Any]]
) -> Dict[str, Any]:
    """Send an inference request to a deployed endpoint and return the parsed reply.

    Args:
        endpoint_arn: ARN of the endpoint; the endpoint name is taken as the
            text after the last ``/``.
        payload: Request body. A ``dict`` (or ``list``) is serialized to JSON;
            any other value, such as an already serialized JSON string or
            bytes, is sent as is.

    Returns:
        The endpoint response body decoded as UTF-8 and parsed as JSON.

    Raises:
        ValueError: If the endpoint status is not ``"InService"``.
        Exception: Any error raised by the SageMaker Runtime call is logged and
            re-raised unchanged.
    """
    runtime_client = boto3.client("sagemaker-runtime")
    client = boto3.client("sagemaker")

    logger.info(f"Checking endpoint status (ARN: {endpoint_arn})")
    endpoint_name = endpoint_arn.split('/')[-1]
    endpoint_status = client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
    if endpoint_status != "InService":
        raise ValueError(f"Endpoint {endpoint_name} is not ready to be queried (status={endpoint_status})")

    # The request body must be serialized before it is sent; only plain JSON
    # containers are converted so pre-serialized payloads keep working.
    body = json.dumps(payload) if isinstance(payload, (dict, list)) else payload

    logger.info("Endpoint is running. Sending inference request...")
    try:
        inference_out = runtime_client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=body
        )
        inference_resp_str = inference_out["Body"].read().decode("utf-8")
        return json.loads(inference_resp_str)
    except (
        runtime_client.exceptions.InternalFailure,
        runtime_client.exceptions.ServiceUnavailable,
        runtime_client.exceptions.ValidationError,
        runtime_client.exceptions.ModelError,
        runtime_client.exceptions.InternalDependencyException,
        runtime_client.exceptions.ModelNotReadyException
    ) as e:
        logger.error(f"Sagemaker Runtime error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unknown error: {e}")
        raise

---
## User-defined values

Edit the values below to match your own development environment:


In [None]:
# Name of the S3 bucket holding the fine-tuning data (build_s3_uri() adds the "s3://" prefix).
USER_S3_BUCKET = "mistral-development"
# Directory inside the bucket that contains the input files and receives the LoRA adapters.
USER_S3_DATA_DIR = "harizo-tests/"
# ARN of the SageMaker algorithm used for fine-tuning; the model name is extracted from it.
USER_SAGEMAKER_ALGORITHM_ARN = "arn:aws:sagemaker:us-west-2:777356365391:algorithm/mistral-finetune-for-ministral-3b-2410-1733402468"
# Instance type used by the training job.
USER_SAGEMAKER_TRAINING_INSTANCE_TYPE = "ml.p5.48xlarge"
# Instance type used by the inference endpoint.
USER_SAGEMAKER_INFERENCE_INSTANCE_TYPE = "ml.p4d.24xlarge"

## Pre-requisites

> **IMPORTANT**: This notebook must run on a SageMaker Notebook instance to directly leverage the underlying execution role for performing actions and reading data on your AWS tenant.

### Data

The SageMaker fine-tuning capabilities rely on the AWS S3 service to host the input data and output artifact files.

The rest of this notebook will assume that your data resides in an S3 bucket (`USER_S3_BUCKET`) within a specified directory (`USER_S3_DATA_DIR`) following this layout:
```
.
|_$USER_S3_DATA_DIR
    |_config.yaml
    |_training.jsonl
    |_validation.jsonl
    |_lora_adapters/
```

A brief description of these items is provided below.

#### Training Data & Configuration Files

To fine-tune a Mistral model, the key ingredient you need is training data, split in two distinct files:

- `training.jsonl`: Contains the records based on which the model will update its weights.
- `validation.jsonl`: Contains a smaller set of records used to assess the model's performance at a given time during training.

In both files, each record follows the same structure as described in the [Mistral documentation](https://docs.mistral.ai/capabilities/finetuning/). Examples of training data files are available in the `data/` directory.

The `config.yaml` configuration file contains additional settings specific to the Mistral fine-tuning codebase.

#### LoRA Adapters Directory

The result of a fine-tuning algorithm run is a small artifact called a _LoRA adapter_. When you deploy your fine-tuned model, you first create a _model package_. During this process, the LoRA adapter is merged into the base model to create a new variant that can then be deployed as a regular endpoint. When running a training job, you need to specify a target directory where adapters will be stored. In your case, this directory will be `lora_adapters/`. 


### IAM permissions

To run this example end-to-end, the AWS IAM role you will use requires the following permissions:

- All permissions from the `AmazonSageMakerFullAccess` managed policy
- Authority to make AWS Marketplace subscriptions in the AWS account used:
  - `aws-marketplace:ViewSubscriptions`
  - `aws-marketplace:Unsubscribe`
  - `aws-marketplace:Subscribe`

## Create and start a fine-tuning job 

The following function creates a training job to fine-tune the base model from `USER_SAGEMAKER_ALGORITHM_ARN` using the `training.jsonl` and `validation.jsonl` datasets:

In [None]:
# Settings consumed by create_mistral_ft_job(). Every S3 location is built
# from USER_S3_BUCKET and USER_S3_DATA_DIR.
job_params = {
    "execution_role_arn": sage.get_execution_role(),
    "algorithm_arn": USER_SAGEMAKER_ALGORITHM_ARN,
    "training_dataset": build_s3_uri(USER_S3_BUCKET, USER_S3_DATA_DIR, "training.jsonl"),
    "validation_dataset": build_s3_uri(USER_S3_BUCKET, USER_S3_DATA_DIR, "validation.jsonl"),
    "config_file": build_s3_uri(USER_S3_BUCKET, USER_S3_DATA_DIR, "config.yaml"),
    "lora_dir": build_s3_uri(USER_S3_BUCKET, USER_S3_DATA_DIR, "lora_adapters/"),
    "instance_type": USER_SAGEMAKER_TRAINING_INSTANCE_TYPE
}

# Submit the training job and keep its ARN for the cells that follow.
ft_job_arn = create_mistral_ft_job(job_params=job_params)

Once the training job has started, you can get its status and summary information with this function, which can directly reuse the output of `create_mistral_ft_job()`:

In [None]:
# Fetch the job summary (status, events, artifact location) and display it.
ft_job_info = get_mistral_ft_job_info(ft_job_arn)
print(ft_job_info)

## Deploy your fine-tuned model endpoint

Once your fine-tuning job has successfully completed, you will need to complete a few steps before using the resulting model.

### Create a model package

First, you need to make the model deployable by creating a _model package resource_. LoRA adapters cannot be deployed directly; instead, they have to be recombined with the base model to create an operational model variant.
The following operation will execute these steps and trigger the creation of a model package resource:

In [None]:
# Keep the returned model package ARN: create_mistral_ft_endpoint() takes it as
# input and requires the package status to be "Completed".
model_package = create_mistral_ft_model_package(ft_job_info=ft_job_info)

### Deploy the endpoint

Next, you need to provision the resources required to host your endpoint, then deploy the endpoint. This is a 3-step process where you create distinct resources:

1. A _model_ which contains the model artifacts and the inference code
2. An _endpoint configuration_ where you specify the configuration details for deploying your model (e.g. the instance type)
3. A _model endpoint_ which is the actual instance of the model that serves inference requests.

The `create_mistral_ft_endpoint()` function combines all these steps into a single entrypoint:

In [None]:
# Create the model, the endpoint configuration and the endpoint in one call.
endpoint_arn = create_mistral_ft_endpoint(
    model_package_arn=model_package,
    instance_type=USER_SAGEMAKER_INFERENCE_INSTANCE_TYPE,
)

### Perform inference

Once your endpoint is up and running, you can test it by running the following code:

In [None]:
# `endpoint_arn` is the value returned by create_mistral_ft_endpoint() in the
# deployment cell. To query an endpoint deployed in an earlier session instead,
# assign its ARN to `endpoint_arn` before running this cell.
messages = [
    {
        "role": "user",
        "content": "Who is the best French painter? Answer in one short sentence."
    },
]

# The request body is sent to the endpoint as a JSON string.
payload = json.dumps({
    "model": "ft:model",
    "messages": messages,
    "temperature": 0,
    "stream": False
})

inference_resp = query_mistral_ft_endpoint(endpoint_arn=endpoint_arn,
                                           payload=payload)

print(inference_resp["choices"][0]["message"]["content"])

Congratulations, you have completed an end-to-end training and deployment of a fine-tuned Mistral model on AWS Sagemaker!

## Clean-up

Once you have finished your inference tests, you should delete the deployed endpoint to avoid excessive infrastructure charges:

In [None]:
# Delete the endpoint; its name is the last "/"-separated segment of the ARN.

client = boto3.client("sagemaker")

endpoint_name = endpoint_arn.split('/')[-1]
client.delete_endpoint(EndpointName=endpoint_name)

You can also optionally delete the other resources created when executing this notebook:
- the endpoint configuration,
- the model,
- the model package,
- the model artifact in your S3 bucket.

## Going further

Here are some useful links to better understand some of the topics introduced in this notebook:

- [boto3 reference for Sagemaker](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html)
- [boto3 reference for Sagemaker Runtime (inference)](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime.html)
- [Mistral AI reference documentation for the chat completion API](https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post)