In [1]:
# Install/upgrade the SageMaker SDK into the environment this kernel imports
# from (%pip targets the kernel's interpreter; !pip may hit another one).
# Bounded to the 2.x line because later cells rely on the
# sagemaker.pytorch.estimator.PyTorch API; an unbounded -U could pull in a new
# major release with breaking changes.
%pip install -U "sagemaker>=2,<3"



In [None]:
import sagemaker
from sagemaker import get_execution_role

# Execution role resolved by the SageMaker SDK; it is handed to the estimator
# further down as its `role` argument.
role = sagemaker.get_execution_role()

In [3]:
# AWS region used to build the boto3 session in the next cell.
region = "us-east-1"

In [4]:
import boto3

# A boto3 session pinned to `region`, wrapped in a SageMaker session so that
# SDK calls made through `sagemaker_session` use that same region.
boto_session = boto3.session.Session(region_name=region)
sagemaker_session = sagemaker.session.Session(boto_session=boto_session)

In [6]:
# --- S3 layout -------------------------------------------------------------
# Default bucket of the region-pinned SageMaker session; `s3_folder` is the key
# prefix under which the job output and checkpoints are written.
s3_bucket = sagemaker_session.default_bucket()
s3_folder = "Mistral-7B-v0.1_SFT"

# --- Paths -----------------------------------------------------------------
# `output_dir` is used as the estimator's checkpoint_local_path;
# `path` is the local directory passed as the estimator's source_dir.
output_dir = "/checkpoints/sft"
path = "re_sft_alignment"

# --- Run settings ----------------------------------------------------------
# NOTE(review): these three values are not referenced by any cell visible in
# this notebook (hyperparameters is empty) — confirm how train_sft.py gets them.
max_seq_len = 2048
train_file = "sft_train.jsonl"
wandb_project = "Mistral-7B-v0.1_SFT"

In [7]:
# Fetch the Hugging Face and Weights & Biases credentials from AWS Secrets
# Manager. The client is created from `boto_session` (not the module-level
# boto3 default session) so the lookup runs in the same pinned `region` as the
# rest of the notebook rather than in whatever region the ambient
# configuration happens to select.
# NOTE(review): assumes each secret is stored as a plain string, not a JSON
# key/value document — confirm against how the secrets were created.
secrets = boto_session.client("secretsmanager")
hf_token = secrets.get_secret_value(SecretId="hf_token")["SecretString"]
wandb_api_key = secrets.get_secret_value(SecretId="wandb_api_key")["SecretString"]

In [None]:
from sagemaker.pytorch.estimator import PyTorch

# Configure the distributed SFT training job: two ml.g5.xlarge hosts running
# train_sft.py from the local `path` directory, with job output and
# checkpoints stored under s3://{s3_bucket}/{s3_folder}/.
# NOTE(review): the tokens below are passed as plain environment variables of
# the training job — confirm this exposure is acceptable for these secrets.
estimator = PyTorch(
    entry_point='train_sft.py',
    source_dir=path,
    role=role,
    # Reuse the region-pinned session so the job is created in the same region
    # as the bucket returned by sagemaker_session.default_bucket(); without it
    # the estimator builds its own session from the ambient default region.
    sagemaker_session=sagemaker_session,
    instance_count=2,
    instance_type='ml.g5.xlarge',
    output_path=f's3://{s3_bucket}/{s3_folder}/sagemaker_output/',
    # Files written to checkpoint_local_path are synced to checkpoint_s3_uri.
    checkpoint_s3_uri=f"s3://{s3_bucket}/{s3_folder}/checkpoints",
    checkpoint_local_path=output_dir,
    dependencies=[],
    distribution={
        "torch_distributed": {
            "enabled": True,
            "processes_per_host": 1,     # 1 GPU on each host
        }
    },
    environment={
        # huggingface_hub authenticates from HF_TOKEN; the original
        # HUGGINGFACE_HUB_TOKEN key is kept in case train_sft.py reads it
        # directly.
        "HF_TOKEN": hf_token,
        "HUGGINGFACE_HUB_TOKEN": hf_token,
        "HF_HUB_DISABLE_TELEMETRY": "1",
        "WANDB_API_KEY": wandb_api_key,
        "HF_HUB_ENABLE_HF_TRANSFER": "0",
        "HF_HOME": "/tmp/.cache/huggingface",
        "HYDRA_FULL_ERROR": "1"
    },
    hyperparameters={
    },
    py_version='py311',
    framework_version='2.3.0',
    # enable_sagemaker_metrics=True,
)

In [None]:
# Submit the training job and block until it finishes.
# NOTE(review): no input channels are passed here, so train_sft.py is
# presumably responsible for locating its own training data — confirm.
estimator.fit(wait=True)