diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json
new file mode 100644
index 00000000..aa036d74
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json
@@ -0,0 +1,35 @@
+{
+ "fp16": {
+ "enabled": "auto"
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": 5e8,
+ "stage3_prefetch_bucket_size": 5e8,
+ "stage3_param_persistence_threshold": 1e6,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true,
+ "round_robin_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json
new file mode 100644
index 00000000..23c70b4f
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json
@@ -0,0 +1,28 @@
+{
+ "fp16": {
+ "enabled": false
+ },
+ "bf16": {
+ "enabled": true
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": false
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "gather_16bit_weights_on_model_save": true,
+ "round_robin_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json
new file mode 100644
index 00000000..f1ddac17
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json
@@ -0,0 +1,35 @@
+{
+ "fp16": {
+ "enabled": "auto"
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": 5e8,
+ "stage3_prefetch_bucket_size": 5e8,
+ "stage3_param_persistence_threshold": 1e6,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true,
+ "round_robin_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json
new file mode 100644
index 00000000..9130e09f
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json
@@ -0,0 +1,32 @@
+{
+ "fp16": {
+ "enabled": "auto"
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": false
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": false
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "gather_16bit_weights_on_model_save": true,
+ "round_robin_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 10,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
diff --git a/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json b/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json
new file mode 100644
index 00000000..b953a4c9
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json
@@ -0,0 +1,11 @@
+{
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": [],
+ "bias": "none",
+ "fan_in_fan_out": false,
+ "init_lora_weights": true
+}
\ No newline at end of file
diff --git a/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml b/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml
new file mode 100644
index 00000000..1c3a96a0
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml
@@ -0,0 +1,12 @@
+pip:
+ - transformers==4.44.0
+ - accelerate==0.31.0
+ - peft==0.11.1
+ - deepspeed==0.16.2
+env_vars:
+ LIBRARY_PATH: "$CUDA_HOME/lib64:$LIBRARY_PATH"
+ PROJECT_DIR: "/home/yarnapp/hopsfs"
+ TRAINED_MODEL_STORAGE_PATH: "${PROJECT_DIR}/Resources/llama_finetuning/fine-tuned-model"
+ TRAINING_DATA_DIR: "${PROJECT_DIR}/Resources/llama_finetuning/datasets"
+ TRAINING_CONFIGURATION_DIR: "${PROJECT_DIR}/Resources/llama_finetuning/configs"
+
\ No newline at end of file
diff --git a/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb b/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb
new file mode 100644
index 00000000..04be6584
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb
@@ -0,0 +1,576 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "4392c184-cc3b-4ded-8d2e-31c12c49ecf6",
+ "metadata": {},
+ "source": [
+ "## Fine-tune Llama 3.1 (8B parameter) using Ray Framework on Hopsworks\n",
+ "This tutorial demonstrates how to perform fine-tuning (with LoRA and deepspeed) of a Llama 3.1 (8B) using the Ray framework on Hopsworks. Ray is an industry-leading distributed computing framework. This tutorial was run on OVH cluster but you can use any cloud provider of your choice."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bce9d0a6-3cda-4882-82e6-d3eb6b1f5a79",
+ "metadata": {},
+ "source": [
+ "### Pre-requisites\n",
+ "To perform the steps in this tutorial, you need to create a Hopsworks Kubernetes cluster with Ray enabled. For the fine-tuning task demonstrated in this example, these are the minimum resources required:\n",
+ "* 1 x B3-64 (16 CPU 64 GB RAM) for the Ray head\n",
+ "* 8 x T2-LE-90 (30 CPU, 90 GB RAM, 2x 32 GRAM Tesla V100S) for the workers\n",
+ "Let's get started!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f400ddf9-5b27-4786-b165-0dea0e109618",
+ "metadata": {},
+ "source": [
+ "### Step 1: Dataset preparation\n",
+ "We are going to fine-tune the model for question answering. We need to prepare the dataset that will be used for supervised fine-tuning in a certain format. There is no specific prompt format required for the pre-trained Llama 3.1 so the dataset preprocessing can follow any prompt-completion style. The instruction-tuned models (Meta-Llama-3.1-{8,70,405}B-Instruct) use a multi-turn conversation prompt format that structures the conversation between the users and the models.\n",
+ "\n",
+ "The dataset for QA typically includes the following fields:\n",
+ "\n",
+ "* Question: The input question to the model.\n",
+ "* Context (optional): A passage or text providing information the model should use to answer.\n",
+ "* Answer: The correct response.\n",
+ "\n",
+ "This example is configured to fine-tune the Llama 3.1 8B pre-trained model on the GSM8K dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "adf6a503",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "import tempfile\n",
+ "import os\n",
+ "import json\n",
+ "import shutil"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "4de50ea7-9c67-4ebe-8fea-0956d50312bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "llama_dir = \"Resources/llama_finetuning\"\n",
+ "HOPSFS_STORAGE_PATH = os.path.join(os.environ.get(\"PROJECT_PATH\"), llama_dir)\n",
+ "if not os.path.exists(HOPSFS_STORAGE_PATH):\n",
+ " os.mkdir(HOPSFS_STORAGE_PATH)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "3baa764d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<_io.TextIOWrapper name='/hopsfs/Resources/llama_finetuning/datasets/tokens.json' mode='w' encoding='UTF-8'>\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = load_dataset(\"openai/gsm8k\", \"main\")\n",
+ "dataset_splits = {\"train\": dataset[\"train\"], \"test\": dataset[\"test\"]}\n",
+ "dataset_dir = os.path.join(HOPSFS_STORAGE_PATH, \"datasets\")\n",
+ "if not os.path.exists(dataset_dir):\n",
+ " os.mkdir(dataset_dir)\n",
+ " \n",
+ "with open(os.path.join(dataset_dir, \"tokens.json\"), \"w\") as f:\n",
+ " tokens = {}\n",
+ " print(f)\n",
+ " tokens[\"tokens\"] = [\"\", \"\", \"\", \"\"]\n",
+ " f.write(json.dumps(tokens))\n",
+ " for key, ds in dataset_splits.items():\n",
+ " with open(os.path.join(dataset_dir, f\"{key}.jsonl\"), \"w\") as f:\n",
+ " for item in ds:\n",
+ " newitem = {}\n",
+ " newitem[\"input\"] = (\n",
+ " f\"{item['question']}\"\n",
+ " f\"{item['answer']}\"\n",
+ " )\n",
+ " f.write(json.dumps(newitem) + \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0686f075-a33e-4b1e-b20e-1b05dbede50a",
+ "metadata": {},
+ "source": [
+ "### Step 2: Download the pre-trained model\n",
+ "The next step is to download the pre-trained Llama model from hugging face. For this you will need the hugging face token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "cb3b68c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers.utils.hub import TRANSFORMERS_CACHE\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "b5effab1-04b8-49fc-88e1-a02df4e89456",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.environ[\"HF_TOKEN\"] = \"\"\n",
+ "model_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b278bb45-c87d-440e-8649-dcf68621997d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "16c9b79aa6d44e00ae339658b7193ee3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer_config.json: 0%| | 0.00/55.4k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bb85d1b4eace4654b81fb61cc327bec7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/9.09M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "133f3ba47fad49dab555132d2c44e54b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/296 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b7b6538b78f34431a7df398f66344199",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/855 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f60803ab38f0406eb863c3d2019d1bf9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors.index.json: 0%| | 0.00/23.9k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6c97c887d00e4668b4436691637dcfef",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "823151aa537041e1a62a00abe61290f6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a342280ddd724ddca4083b5649122ca7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a312753604f14dc6ac06fd34df389474",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7ac77dfd4b4b4b78867d8db30c579155",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "05249250193b47ebad345baf71529e2b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9f4dd545de0f412a95e6205efef0e7d7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/184 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# download the pre-trained model from Hugging face\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "4ec0badb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['special_tokens_map.json',\n",
+ " 'model-00002-of-00004.safetensors',\n",
+ " 'model-00001-of-00004.safetensors',\n",
+ " 'model-00003-of-00004.safetensors',\n",
+ " 'config.json',\n",
+ " 'tokenizer.json',\n",
+ " 'model-00004-of-00004.safetensors',\n",
+ " 'tokenizer_config.json',\n",
+ " 'generation_config.json',\n",
+ " 'model.safetensors.index.json']"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "local_model_dir = os.path.join(TRANSFORMERS_CACHE, f\"models--{model_id.replace('/', '--')}\")\n",
+ "snapshots_dir = os.path.join(local_model_dir, \"snapshots\")\n",
+ "blobs_dir = os.path.join(snapshots_dir, next(d for d in os.listdir(snapshots_dir) if os.path.isdir(os.path.join(snapshots_dir, d))))\n",
+ "os.listdir(blobs_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "c036501d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hopsfs_model_dir = os.path.join(HOPSFS_STORAGE_PATH, f\"models--{model_id.replace('/', '--')}\")\n",
+ "if not os.path.exists(hopsfs_model_dir):\n",
+ " os.mkdir(hopsfs_model_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "b4bf6ab1",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AssertionError",
+ "evalue": "Failed to copy pre-trained model files to hopsfs",
+ "output_type": "error",
+ "traceback": [
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mAssertionError\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[26], line 4\u001B[0m\n\u001B[1;32m 2\u001B[0m cp_cmd \u001B[38;5;241m=\u001B[39m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcp -L -r \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mblobs_dir\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/* \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mhopsfs_model_dir\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 3\u001B[0m result \u001B[38;5;241m=\u001B[39m os\u001B[38;5;241m.\u001B[39msystem(cp_cmd)\n\u001B[0;32m----> 4\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m result \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mFailed to copy pre-trained model files to hopsfs\u001B[39m\u001B[38;5;124m\"\u001B[39m\n",
+ "\u001B[0;31mAssertionError\u001B[0m: Failed to copy pre-trained model files to hopsfs"
+ ]
+ }
+ ],
+ "source": [
+ "# copy the downloaded model to hopsfs\n",
+ "cp_cmd = f\"cp -L -r {blobs_dir}/* {hopsfs_model_dir}\"\n",
+ "result = os.system(cp_cmd)\n",
+ "assert result != 0, \"Failed to copy pre-trained model files to hopsfs\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1492db64-88a7-431b-9bee-217aacb5916b",
+ "metadata": {},
+ "source": [
+ "### Step 3: Create the ray job for the fine-tuning task\n",
+ "We are going to use the hopsworks jobs api to create and run the job for the fine-tuning task"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "60e48c95-f236-4ebe-be27-f6f585813843",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Connection closed.\n",
+ "2025-01-09 07:01:22,956 INFO: Python Engine initialized.\n",
+ "\n",
+ "Logged in to project, explore it here https://hopsworks.ai.local/p/119\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bff06133e9884523a6cbdfbf1b46070e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Uploading /hopsfs/Jupyter/ray_llm_finetuning.py: 0.000%| | 0/28956 elapsed<00:00 remaining"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "10493ca0e0b44d38bfa2dcfabf056880",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Uploading /hopsfs/Jupyter/llama_fine_tune_runtime_env.yaml: 0.000%| | 0/341 elapsed<00:00 remaining"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import hopsworks\n",
+ "\n",
+ "project = hopsworks.login()\n",
+ "\n",
+ "dataset_api = project.get_dataset_api()\n",
+ "\n",
+ "app_file_path = dataset_api.upload(\"ray_llm_finetuning.py\", llama_dir, overwrite=True)\n",
+ "environment_config_yaml_path = dataset_api.upload(\"llama_fine_tune_runtime_env.yaml\", llama_dir, overwrite=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "865ad7ec-0822-4893-ab85-e78aab0ee372",
+ "metadata": {},
+ "source": [
+ "### About the runtime environment file\n",
+ "The runtime environment file contains the dependencies required for the Ray job including files, packages, environment variables, and more. This is useful when you need to install specific packages and set environment variables for this particular Ray job. It should be provided as a YAML file. In this example, the runtime environment file has the following configuration.\n",
+ "```\n",
+ "pip:\n",
+ " - transformers==4.44.0\n",
+ " - accelerate==0.31.0\n",
+ " - peft==0.11.1\n",
+ " - deepspeed==0.16.2\n",
+ "env_vars:\n",
+ " LIBRARY_PATH: \"$CUDA_HOME/lib64:$LIBRARY_PATH\"\n",
+ " PROJECT_DIR: \"/home/yarnapp/hopsfs\"\n",
+ " TRAINED_MODEL_STORAGE_PATH: \"${PROJECT_DIR}/Resources/llama_finetuning/fine-tuned-model\" # Where the fine-tuned model will be saved\n",
+ " TRAINING_DATA_DIR: \"${PROJECT_DIR}/Resources/llama_finetuning/datasets\" # dataset location\n",
+ " TRAINING_CONFIGURATION_DIR: \"${PROJECT_DIR}/Resources/llama_finetuning/configs\" # location for deepspeed and lora configuration files\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "30caad87-ba34-430b-8400-3cfdb6c21c70",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Job created successfully, explore it at https://hopsworks.ai.local/p/119/jobs/named/ray_llama_finetuning\n"
+ ]
+ }
+ ],
+ "source": [
+ "jobs_api = project.get_jobs_api()\n",
+ "\n",
+ "ray_config = jobs_api.get_configuration(\"RAY\")\n",
+ "pretrained_path = \"/home/yarnapp\" + hopsfs_model_dir\n",
+ "ray_config['appPath'] = os.path.join('/Projects/'+project.name, app_file_path)\n",
+ "ray_config['environmentName'] = \"ray-torch-training-pipeline\"\n",
+ "ray_config['driverCores'] = 8\n",
+ "ray_config['driverMemory'] = 34816\n",
+ "ray_config['workerCores'] = 28\n",
+ "ray_config['workerMemory'] = 34816\n",
+ "ray_config['minWorkers'] = 8\n",
+ "ray_config['maxWorkers'] = 8\n",
+ "ray_config['workerGpus'] = 2\n",
+ "ray_config['runtimeEnvironment'] = os.path.join('/Projects/'+project.name, environment_config_yaml_path)\n",
+ "ray_config['defaultArgs'] = f\"--model-name models-meta-llama-Meta-Llama-3.1-8B-Instruct --mx fp16 --lora --num-devices=16 --num-epochs=1 --lr=5e-4 --batch-size-per-device=16 --eval-batch-size-per-device=16 --pre-trained-path {pretrained_path}\"\n",
+ "\n",
+ "job = jobs_api.create_job(job_name, ray_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf7f58da-c240-45d4-ab23-2ecb89152fcc",
+ "metadata": {},
+ "source": [
+ "## Step 4: Run the job"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "e4d0cfa5-afe5-4682-a2b5-dad97a7ef280",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "finetuning_job = jobs_api.get_job(job_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5840d9f2-32ae-41b0-a829-ecd72247ddc0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "finetuning_job.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9fc8c427-2f2c-4e8e-b36e-83c213a77669",
+ "metadata": {},
+ "source": [
+ "After the job is run you can go to the hopsworks UI to monitor the job execution. From executions page, you can open the Ray dashboard. In the Ray Dashboard, you can monitor the resources used by the job, the number of workers, logs, and the tasks that are running. \n",
+ "\n",
+ "After the job finishes running successfully, the fine-tuned model will be saved in the directory specified in the TRAINED_MODEL_STORAGE_PATH variable defined in the "
+ ]
+ },
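+  {
+   "cell_type": "markdown",
+   "id": "load-finetuned-adapter-md",
+   "metadata": {},
+   "source": [
+    "The cell below is an optional, minimal sketch of how the fine-tuned LoRA adapter could be loaded back for inference once the job has finished. It assumes the adapter and tokenizer were copied to the `fine-tuned-model` directory under `HOPSFS_STORAGE_PATH` (the notebook-side view of `TRAINED_MODEL_STORAGE_PATH`) and that `peft` is installed in the notebook environment; adjust the path to your setup."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "load-finetuned-adapter",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch only: load the fine-tuned LoRA adapter saved by the Ray job (the path is an assumption)\n",
+    "from peft import PeftModel\n",
+    "\n",
+    "adapter_path = os.path.join(HOPSFS_STORAGE_PATH, \"fine-tuned-model\")\n",
+    "ft_tokenizer = AutoTokenizer.from_pretrained(adapter_path)\n",
+    "base_model = AutoModelForCausalLM.from_pretrained(model_id)\n",
+    "# The training script added special tokens, so resize the embeddings before attaching the adapter\n",
+    "base_model.resize_token_embeddings(len(ft_tokenizer))\n",
+    "finetuned_model = PeftModel.from_pretrained(base_model, adapter_path)"
+   ]
+  },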
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cd5c7a6d-8fdd-4a3f-bc45-d4175d54a895",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/llm-ai-systems/fine-tuning/ray_llm_finetuning.py b/llm-ai-systems/fine-tuning/ray_llm_finetuning.py
new file mode 100644
index 00000000..d977ee11
--- /dev/null
+++ b/llm-ai-systems/fine-tuning/ray_llm_finetuning.py
@@ -0,0 +1,775 @@
+import argparse
+import functools
+import json
+import math
+import os
+import shutil
+import string
+import tempfile
+import time
+import random
+
+import tree
+from typing import Tuple
+import urllib
+from urllib.parse import urljoin
+
+try:
+ import deepspeed # noqa: F401
+except ImportError as e:
+ raise RuntimeError(
+ "Please install deepspeed with `pip install --user deepspeed`."
+ ) from e
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import DummyOptim, DummyScheduler, set_seed
+import torch
+import torch.nn as nn
+import tqdm
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ get_linear_schedule_with_warmup,
+)
+
+from peft import LoraConfig, get_peft_model, PeftModel
+import ray
+from ray import train
+import ray.util.scheduling_strategies
+from ray.train.torch import TorchTrainer
+from ray.train import Checkpoint
+
+urllib.parse.uses_relative.append("s3")
+urllib.parse.uses_netloc.append("s3")
+
+OPTIM_BETAS = (0.9, 0.999)
+OPTIM_EPS = 1e-8
+NUM_WARMUP_STEPS = 10
+OPTIM_WEIGHT_DECAY = 0.0
+ATTENTION_LAYER_NAME = "self_attn"
+
+TRAINING_DATA_DIR = os.environ.get("TRAINING_DATA_DIR")
+if TRAINING_DATA_DIR is None:
+ TRAINING_DATA_DIR = ""
+TRAINING_CONFIGURATION_DIR = os.environ.get("TRAINING_CONFIGURATION_DIR")
+if TRAINING_CONFIGURATION_DIR is None:
+ TRAINING_CONFIGURATION_DIR = ""
+
+
+def generate_random_dir_name(length=16):
+ letters = string.ascii_letters + string.digits
+    return ''.join(random.choice(letters) for _ in range(length))
+
+
+def get_expected_lora_num_parameters(
+ model, lora_config: LoraConfig, attn_layer_name: str = ATTENTION_LAYER_NAME
+):
+ """Calculate the expected number of parameters for lora fine-tuning."""
+ sum_params = 0
+ num_attention_layers = 0
+ modules = model.named_modules()
+ loraified_modules = 0
+ # We calculate the number of parameters we need for lora fine-tuning by calculating
+ # the sizes of the decomposed weight matrices according to the paper.
+ for full_name, target in modules:
+ layer_name = full_name.split(".")[-1]
+
+ if layer_name == attn_layer_name:
+ # Detected another attention layer (for example, llama 2 70b should have 80
+ # of these)
+ num_attention_layers += 1
+ elif layer_name in lora_config.modules_to_save:
+ # Detect another non-lora module to save, which will also contribute to the
+ # number of check-pointed parameters. This will result in one set of
+ # trainable parameters ".original_module.weight" and another one with
+ # ".modules_to_save.default.weight"
+ # Therefore, each layer contributes 2 x the number of actual elements in
+ # that layer.
+ sum_params += 2 * target.weight.numel()
+ print(
+ "Found non-lora-layer to checkpoint: ",
+ layer_name,
+ " with num params ",
+ target.weight.numel(),
+ )
+ else:
+ for module_name in lora_config.target_modules:
+ if layer_name == module_name:
+ loraified_modules += 1
+ if isinstance(target, nn.Linear):
+                        # Target is a linear projection weight (attention or MLP)
+ sum_params += (
+ target.in_features + target.out_features
+ ) * lora_config.r
+ elif isinstance(target, nn.Embedding):
+                        # Target is an embedding weight
+ sum_params += (
+ target.embedding_dim + target.num_embeddings
+ ) * lora_config.r
+
+ print(
+ f"Detected {num_attention_layers} attention layers, containing"
+ f" {loraified_modules} modules to modify according to LoRA's `target_modules`."
+ f" This should yield {sum_params} trainable parameters."
+ )
+
+ return sum_params
+
+
+def get_number_of_params(model: nn.Module):
+    num_params = 0
+    for _, param in model.named_parameters():
+        if param.requires_grad:
+            num_params += param.numel()
+    return num_params
+
+
+def collate_fn(batch, tokenizer, block_size, device):
+ out_batch = tokenizer(
+ list(batch["input"]),
+ padding="max_length",
+ max_length=block_size,
+ truncation=True,
+ return_tensors="pt",
+ )
+ out_batch["labels"] = out_batch["input_ids"].clone()
+
+ out_batch = tree.map_structure(lambda x: x.to(device), out_batch)
+
+ return out_batch
+
+
+def get_tokenizer(pretrained_path, special_tokens):
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.add_tokens(special_tokens, special_tokens=True)
+
+ return tokenizer
+
+
+def evaluate(
+ *, model, eval_ds, accelerator, bsize, ds_kwargs, as_test: bool = False
+) -> Tuple[float, float]:
+ model.eval()
+ losses = []
+
+ eval_dataloader = eval_ds.iter_torch_batches(batch_size=bsize, **ds_kwargs)
+ eval_ds_len = len(list(eval_ds.iter_batches(batch_size=1)))
+ for step, batch in tqdm.tqdm(
+        enumerate(eval_dataloader), total=eval_ds_len // bsize + 1
+ ):
+ with torch.no_grad():
+ outputs = model(**batch)
+
+ loss = outputs.loss
+ # The tensors are gathered by concatenating them on the first dimension, so we
+ # add a new dimension to the scalar loss to get a tensor of shape (K,) for K
+ # workers.
+ losses.append(accelerator.gather(loss[None]))
+
+ if as_test:
+ break
+
+ # We stack losses so that we have a tensor of shape (T, K) where T is the number of
+ # steps and K is the number of workers.
+ losses = torch.stack(losses)
+ try:
+ eval_loss = torch.mean(losses).item()
+ perplexity = math.exp(eval_loss)
+ except OverflowError:
+ perplexity = float("inf")
+ return perplexity, eval_loss
+
+
+def copy_model_to_hopsfs(local_path, hopsfs_path):
+ if not os.path.exists(hopsfs_path):
+ os.makedirs(os.path.dirname(hopsfs_path), exist_ok=True)
+ shutil.copytree(local_path, hopsfs_path, dirs_exist_ok=True)
+
+def copy_model_to_local(pretrained_path, local_path):
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
+ shutil.copytree(pretrained_path, local_path)
+
+
+def _test_tokenizer(pretrained_path):
+ # This function tests that adding special tokens does not
+ # result in un-expected tokenization
+ # Context: https://github.com/huggingface/transformers/issues/25176
+ tokenizer = get_tokenizer(pretrained_path=pretrained_path, special_tokens=[""])
+ testoutput = tokenizer("inform")["input_ids"]
+ expected = tokenizer("inform")["input_ids"]
+ assert testoutput[-1] == expected[-1], (
+ "The tokenizer is not working as expected with special tokens, "
+ f"testoutput={testoutput}, expected={expected}"
+ )
+
+
+def checkpoint_model(
+ checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs
+):
+ """Utility function for checkpointing model + optimizer dictionaries
+ The main purpose for this is to be able to resume training from that instant again.
+ """
+ checkpoint_state_dict = {
+ "epoch": epoch,
+ "last_global_step": last_global_step,
+ }
+ # Add extra kwargs too
+ checkpoint_state_dict.update(kwargs)
+
+ # In here model will be a DeepspeedEngine object
+ model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
+ status_msg = (
+ f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
+ )
+ print(status_msg)
+
+
+def training_function(kwargs: dict):
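+    """Per-worker training loop executed by Ray Train on each GPU worker."""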
+ print("training_function called")
+
+    # Ray Train has a bug somewhere that causes ACCELERATE_TORCH_DEVICE to not be set
+ # properly on multi-gpu nodes
+ cuda_visible_device = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+ print("CUDA_VISIBLE_DEVICES", cuda_visible_device)
+ local_rank = int(os.environ["LOCAL_RANK"])
+ device_id = cuda_visible_device[local_rank]
+ os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{device_id}"
+ print("ACCELERATE_TORCH_DEVICE", os.environ["ACCELERATE_TORCH_DEVICE"])
+
+ config = kwargs["config"]
+ args = argparse.Namespace(**kwargs["args"])
+ chat_template = kwargs.get("chat_template", [])
+ special_tokens = kwargs.get("special_tokens", [])
+ model_id = config["model_name"]
+
+ # Each worker should download its own model
+ pre_trained_path = "/tmp/" + model_id + "/" + generate_random_dir_name()
+
+ # start by copying the model to local. Calculate the time it takes to copy the model
+ start_copy_time = time.time()
+ print("Copying model to local")
+ copy_model_to_local(args.pre_trained_path, pre_trained_path)
+ end_copy_time = time.time()
+ print("Done copying model to local in ", end_copy_time - start_copy_time, " seconds")
+
+    # We need to download the model weights on this machine if they don't exist.
+ # We need to acquire a lock to ensure that only one process downloads the model
+ # bucket_uri = get_mirror_link(model_id)
+ # download_path = get_download_path(model_id)
+ # base_path = Path(download_path).parent
+ # base_path.mkdir(parents=True, exist_ok=True)
+ # lock_file = str(base_path / f'{model_id.replace("/", "--")}.lock')
+ # with FileLock(lock_file):
+ # download_model(
+ # model_id=model_id, bucket_uri=bucket_uri, s3_sync_args=["--no-sign-request"]
+ # )
+
+ # Sample hyperparameters for learning rate, batch size, seed and a few other HPs
+ lr = config["lr"]
+ num_epochs = int(config["num_epochs"])
+ seed = int(config["seed"])
+ batch_size = int(config["batch_size"])
+ gradient_accumulation_steps = int(config["gradient_accumulation_steps"])
+
+ # Get deepspeed config to set up the batch size per device
+ ds_plugin = config["ds_plugin"]
+ ds_plugin.hf_ds_config.config["train_micro_batch_size_per_gpu"] = batch_size
+
+ # Initialize accelerator
+ accelerator = Accelerator(
+ deepspeed_plugin=ds_plugin,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ mixed_precision=args.mx,
+ )
+
+ set_seed(seed)
+
+ # train_ds is the local shard for this model
+ train_ds = train.get_dataset_shard("train")
+ valid_ds = train.get_dataset_shard("valid")
+
+ train_ds_len = len(list(train_ds.iter_batches(batch_size=1)))
+
+ _test_tokenizer(pre_trained_path)
+ tokenizer = get_tokenizer(pretrained_path=pre_trained_path, special_tokens=special_tokens)
+ collate_partial = functools.partial(
+ collate_fn,
+ tokenizer=tokenizer,
+ block_size=config["block_size"],
+ device=accelerator.device,
+ )
+
+ # pre_trained_path = config["pre-trained-path"]
+ print(f"Loading model from {pre_trained_path} ...")
+ s = time.time()
+ model = AutoModelForCausalLM.from_pretrained(
+ pre_trained_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ # low_cpu_mem_usage=True,
+ # `use_cache=True` is incompatible with gradient checkpointing.
+ use_cache=False,
+ # device_map={"": device_id},
+ # attn_implementation="flash_attention_2",
+ )
+ print(f"Done loading model in {time.time() - s} seconds.")
+
+ model.resize_token_embeddings(len(tokenizer))
+
+ if config["lora"]:
+ # Apply LoRA
+ s = time.time()
+ lora_config = LoraConfig(**config["lora_config"])
+
+ expected_num_parameters = get_expected_lora_num_parameters(
+ lora_config=lora_config, model=model
+ )
+
+ print(f"Attempting to apply LoRA config: {lora_config}")
+
+ model.enable_input_require_grads()
+ # model = PeftModel.from_pretrained(model, config=lora_config)
+ model = get_peft_model(model, lora_config)
+
+ num_parameters = get_number_of_params(model)
+
+ if num_parameters != expected_num_parameters:
+ raise ValueError(
+ f"Expected {expected_num_parameters} parameters, got {num_parameters} "
+ f"parameters. LoRA-ification failed."
+ )
+
+ print(
+ f"LoRA-ification done in {time.time() - s} seconds. Estimated checkpoint "
+ f"size (fp16): {num_parameters * 2 / 1e6} MB"
+ )
+
+ print(f"Number of check-pointed parameters: {get_number_of_params(model)}")
+
+ print("Model initialized with pretrained weights. Training starting...")
+ if not args.no_grad_ckpt:
+ model.gradient_checkpointing_enable()
+
+ optimizer_cls = (
+ torch.optim.AdamW
+ if accelerator.state.deepspeed_plugin is None
+ or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
+ else DummyOptim
+ )
+
+ optimizer = optimizer_cls(
+ model.parameters(),
+ lr=lr,
+ betas=OPTIM_BETAS,
+ weight_decay=OPTIM_WEIGHT_DECAY,
+ eps=OPTIM_EPS,
+ )
+
+ # Instantiate scheduler
+ # Creates Dummy Scheduler if `scheduler` was specified in the config file or
+ # else, creates `args.lr_scheduler_type` Scheduler
+ # get train and valid dataset lengths
+
+ num_steps_per_epoch = math.ceil(train_ds_len / args.batch_size_per_device)
+ total_training_steps = (
+ num_steps_per_epoch * num_epochs // gradient_accumulation_steps
+ )
+
+ if (
+ accelerator.state.deepspeed_plugin is None
+ or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+ ):
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=NUM_WARMUP_STEPS * args.num_devices,
+ num_training_steps=total_training_steps * args.num_devices,
+ )
+ else:
+ lr_scheduler = DummyScheduler(
+ optimizer,
+ warmup_num_steps=NUM_WARMUP_STEPS * args.num_devices,
+ total_num_steps=total_training_steps * args.num_devices,
+ )
+
+ # Prepare everything
+ # There is no specific order to remember, we just need to unpack the objects in the
+ # same order we gave them to the prepare method.
+ s = time.time()
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+ print(f"Prepare done in {time.time() - s} seconds.")
+
+ # Now we train the model
+ if accelerator.is_main_process:
+ print("Starting training ...")
+ print("Number of batches on main process", train_ds_len // batch_size)
+
+ for epoch in range(num_epochs):
+ fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0
+ s_epoch = time.time()
+ model.train()
+ loss_sum = torch.tensor(0.0).to(accelerator.device)
+
+ train_dataloader = train_ds.iter_torch_batches(
+ batch_size=batch_size,
+ collate_fn=collate_partial,
+ )
+
+ for step, batch in tqdm.tqdm(
+ enumerate(train_dataloader), total=train_ds_len // batch_size + 1
+ ):
+
+            # Batches are already placed on the right device by collate_fn.
+ with accelerator.accumulate(model):
+ s_fwd = time.time()
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss_sum += loss.item()
+ e_fwd = time.time()
+ fwd_time = e_fwd - s_fwd
+ fwd_time_sum += fwd_time
+ s_bwd = time.time()
+ accelerator.backward(loss)
+ e_bwd = time.time()
+ bwd_time = e_bwd - s_bwd
+ bwd_time_sum += bwd_time
+
+ s_opt_step = time.time()
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ e_opt_step = time.time()
+ optim_step_time_sum += e_opt_step - s_opt_step
+
+ if accelerator.is_main_process:
+ accelerator.print(
+ f"[epoch {epoch} step {step}] "
+ f"loss: {loss.item()} step-time: {e_opt_step - s_fwd}"
+ )
+
+ aggregated_loss = torch.mean(accelerator.gather(loss[None])).item()
+
+ if config["as_test"]:
+ break
+
+            # As long as this is not the last step, report intermediate metrics here
+ if step != (train_ds_len // batch_size - 1):
+ train.report(
+ {
+ "epoch": epoch,
+ "iteration": step,
+ "train_loss_batch": aggregated_loss,
+ "avg_train_loss_epoch": None,
+ "eval_loss": None,
+ "perplexity": None,
+ "num_iterations": step + 1,
+ "train_time_per_epoch": None,
+ "eval_time_per_epoch": None,
+ "fwd_time": fwd_time,
+ "bwd_time": bwd_time,
+ "avg_fwd_time_per_epoch": None,
+ "avg_bwd_time_per_epoch": None,
+ "learning_rate": lr_scheduler.get_lr()[0],
+ }
+ )
+
+ e_epoch = time.time()
+ accelerator.print("Train time per epoch: ", e_epoch - s_epoch)
+
+ eval_s_epoch = time.time()
+ print("Running evaluation ...")
+ perplex, eloss = evaluate(
+ model=model,
+ eval_ds=valid_ds,
+ accelerator=accelerator,
+ bsize=config["eval_batch_size"],
+ ds_kwargs={"collate_fn": collate_partial},
+ as_test=config["as_test"],
+ )
+ accelerator.print("Eval result loss", eloss)
+ accelerator.print("Eval perplex", perplex)
+
+ eval_e_epoch = time.time()
+ accelerator.print("Eval time per epoch: ", eval_e_epoch - eval_s_epoch)
+ accelerator.print("avg fwd time: ", fwd_time_sum / (step + 1))
+ accelerator.print("avg bwd time: ", bwd_time_sum / (step + 1))
+ accelerator.print("avg opt step time: ", optim_step_time_sum / (step + 1))
+
+ metrics = {
+ "epoch": epoch,
+ "iteration": step,
+ "train_loss_batch": aggregated_loss,
+ "avg_train_loss_epoch": loss_sum.item() / (step + 1),
+ "eval_loss": eloss,
+ "perplexity": perplex,
+ "num_iterations": step + 1,
+ "train_time_per_epoch": e_epoch - s_epoch,
+ "eval_time_per_epoch": eval_e_epoch - eval_s_epoch,
+ "fwd_time": fwd_time,
+ "bwd_time": bwd_time,
+ "avg_fwd_time_per_epoch": fwd_time_sum / (step + 1),
+ "avg_bwd_time_per_epoch": bwd_time_sum / (step + 1),
+ "learning_rate": lr_scheduler.get_lr()[0],
+ }
+
+ with tempfile.TemporaryDirectory(dir=args.output_dir) as temp_checkpoint_dir:
+ accelerator.print(f"Saving the model locally at {temp_checkpoint_dir}")
+ accelerator.wait_for_everyone()
+
+ checkpoint_save_start = time.perf_counter()
+
+ if accelerator.is_main_process:
+ print("Saving tokenizer and config.")
+ tokenizer.save_pretrained(temp_checkpoint_dir)
+
+ accelerator.wait_for_everyone()
+
+ # Checkpointing strategy 1: Distributed checkpointing
+ # This checkpointing method makes deepspeed checkpoints on each node
+ # and then Ray Train will aggregate them to a central s3 bucket.
+ # It should be done on all processes (not just the Rank 0)
+ # aggregate_on_rank_0 = False
+ # checkpoint_model(
+ # checkpoint_folder=tempdir,
+ # ckpt_id=epoch,
+ # model=model,
+ # epoch=epoch,
+ # last_global_step=step
+ # )
+
+ # Checkpointing strategy 2: Aggregate model on the rank 0 worker then upload
+ aggregate_on_rank_0 = True
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(
+ temp_checkpoint_dir,
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ safe_serialization=True,
+ state_dict=accelerator.get_state_dict(model),
+ )
+ accelerator.wait_for_everyone()
+ print("Checkpoint save time: ", time.perf_counter() - checkpoint_save_start)
+
+ checkpoint_upload_start = time.perf_counter()
+
+ # Create the checkpoint object to report to Ray Train and upload to storage.
+ # If we aggregated the model on rank 0, we only need to report
+ # the checkpoint from the rank 0 worker, since all other checkpoint
+ # directories are empty (`save_pretrained` was a noop for other workers).
+ if aggregate_on_rank_0:
+ checkpoint = (
+ Checkpoint.from_directory(temp_checkpoint_dir)
+ if accelerator.is_main_process
+ else None
+ )
+ else:
+ # Distributed checkpointing should upload shards from each worker.
+ checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+ # Note: After `train.report`, in the case of remote storage,
+ # the checkpoint directory will be uploaded to the remote storage.
+ train.report(metrics, checkpoint=checkpoint)
+
+ print(
+ "Checkpoint upload time: ",
+ time.perf_counter() - checkpoint_upload_start,
+ )
+ print(
+ "Total checkpointing time: ",
+ time.perf_counter() - checkpoint_save_start,
+ )
+
+ hopsfs_upload_start = time.time()
+ copy_model_to_hopsfs(temp_checkpoint_dir, os.environ.get("TRAINED_MODEL_STORAGE_PATH"))
+ print("HopsFS upload time: ", time.time() - hopsfs_upload_start)
+
+ if perplex < args.stop_perplexity:
+ print(f"Perplexity reached {perplex} < {args.stop_perplexity}. Stopping.")
+ break
+
+ if config["as_test"]:
+ break
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="LLM fine-tuning with DeepSpeed")
+
+ parser.add_argument("--model-name", type=str, default="meta-llama/Meta-Llama-3.1-8B")
+
+ parser.add_argument("--train-path", type=str, default=os.path.join(TRAINING_DATA_DIR, "train.jsonl"),
+ help="Path to training jsonl file")
+
+ parser.add_argument("--test-path", type=str, default=os.path.join(TRAINING_DATA_DIR, "test.jsonl"),
+ help="Path to testing jsonl file")
+
+ parser.add_argument("--dataset-config", type=str, default=os.path.join(TRAINING_DATA_DIR, "config.json"),
+ help="Path to the config file")
+
+ parser.add_argument("--num-devices", "-nd", type=int, default=3,
+ help="Number of devices to use.")
+
+ parser.add_argument("--mx", type=str, choices=["no", "fp16", "bf16", "fp8"], default="bf16",
+ help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
+ "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.")
+
+ parser.add_argument("--ds-config", type=str, default=os.path.join(TRAINING_CONFIGURATION_DIR,
+ "deepspeed_configs/zero_3_offload_optim_param.json"),
+ help="Deepspeed config json to use.")
+
+ parser.add_argument("--lora", action="store_true", default=False,
+ help="If passed, will enable parameter efficient fine-tuning with LoRA.")
+
+ parser.add_argument("--lora-config", type=str, default=os.path.join(TRAINING_CONFIGURATION_DIR,
+ "lora_configs/lora.json"),
+ help="Lora config json to use.")
+
+ parser.add_argument("--num-epochs", type=int, default=1,
+ help="Number of epochs to train for.")
+
+ parser.add_argument("--lr", type=float, default=1e-4,
+ help="Learning rate to use.")
+
+ parser.add_argument("--ctx-len", type=int, default=512,
+ help="Maximum context length for the model input sequences.")
+
+ parser.add_argument("--batch-size-per-device", "-bs", type=int, default=16,
+ help="Batch size to use per device.")
+
+ parser.add_argument("--eval-batch-size-per-device", type=int, default=64,
+ help="Batch size to use per device (For evaluation).")
+
+ parser.add_argument("--grad-accum", type=int, default=1,
+ help="Gradient accumulation steps.")
+
+ parser.add_argument("--output-dir", type=str, default="/tmp",
+ help="Path to output directory.")
+
+ parser.add_argument("--storage-path", type=str,
+ help="Path to results and checkpoints storage")
+
+ parser.add_argument("--no-grad-ckpt", action="store_true",
+ help="If passed, will not use gradient checkpointing.")
+
+ parser.add_argument("--num-checkpoints-to-keep", type=int, default=1,
+ help="Number of checkpoints to keep, if None, all checkpoints will be kept, "
+ "if set to n>=1, the top n checkpoint with min. evaluation perplexity "
+ "will be kept.")
+
+ parser.add_argument("--stop-perplexity", type=float, default=0,
+ help="Target perplexity to reach after which to stop training. Default is 0. "
+ "If 0, training will not stop on perplexity.")
+
+ parser.add_argument("--as-test", action="store_true",
+ help="If passed, will run the script in test mode.")
+
+ parser.add_argument("--pre-trained-path", type=str, help="Path to pretrained model")
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
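+    """Parse CLI args, load LoRA/DeepSpeed configs, set up Ray Data, and launch the TorchTrainer."""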
+ args = parse_args()
+
+ if not args.output_dir:
+ raise ValueError("--output-dir must be specified")
+
+ # update the config with args so that we have access to them.
+ config = vars(args)
+ config.update(
+ **{
+ "lr": args.lr,
+ "num_epochs": args.num_epochs,
+ "seed": 42,
+ "batch_size": args.batch_size_per_device,
+ "gradient_accumulation_steps": args.grad_accum,
+ "model_name": args.model_name,
+ "block_size": args.ctx_len,
+ "eval_batch_size": args.eval_batch_size_per_device,
+ }
+ )
+
+ # Add LoRA config if needed
+ if args.lora:
+ with open(args.lora_config, "r") as json_file:
+ lora_config = json.load(json_file)
+ config["lora_config"] = lora_config
+
+ # Add deepspeed plugin to the config
+ ds_plugin = DeepSpeedPlugin(hf_ds_config=config.get("ds_config"))
+ config.update(ds_plugin=ds_plugin)
+
+ ray.init()
+
+ # Read data
+ train_ds = ray.data.read_json(args.train_path)
+ if args.test_path is not None:
+ valid_ds = ray.data.read_json(args.test_path)
+ else:
+ valid_ds = None
+
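+    # Special tokens for the dataset were written to tokens.json by the dataset-preparation step in the notebook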
+ special_tokens = ray.data.read_json(TRAINING_DATA_DIR + "/tokens.json").take_all()[0]["tokens"]
+
+ # Config file
+ # chat_template = None
+ # special_tokens = None
+ # if os.path.isfile(args.dataset_config):
+ # with open(args.dataset_config, "r") as json_file:
+ # dataset_config = json.load(json_file)
+ # chat_template = dataset_config.get("chat_template", None)
+ # special_tokens = dataset_config.get("special_tokens", None)
+
+ trial_name = f"{args.model_name}".split("/")[-1]
+ if args.lora:
+ trial_name += "-lora"
+
+ storage_path = os.environ.get(
+ "PROJECT_DIR") + "/Resources/ft_llms_with_deepspeed/" + args.model_name + "/" + trial_name
+ trainer = TorchTrainer(
+ training_function,
+ train_loop_config={
+ "config": config,
+ "args": vars(args),
+ # "chat_template": chat_template,
+ "special_tokens": special_tokens,
+ },
+ run_config=train.RunConfig(
+ storage_path=storage_path,
+ checkpoint_config=train.CheckpointConfig(
+ num_to_keep=args.num_checkpoints_to_keep,
+ checkpoint_score_attribute="perplexity",
+ checkpoint_score_order="min",
+ # storage_path=os.environ["PROJECT_DIR"] + "/Resources/llm-checkpoint-dir",
+ ),
+ ),
+ scaling_config=train.ScalingConfig(
+ num_workers=args.num_devices,
+ use_gpu=True,
+ resources_per_worker={"GPU": 1, "CPU": 13},
+ ),
+ datasets={"train": train_ds, "valid": valid_ds},
+ dataset_config=ray.train.DataConfig(datasets_to_split=["train", "valid"]),
+ )
+
+ result: train.Result = trainer.fit()
+ # `best_checkpoints` are sorted in increasing score order.
+ # (Ex: in this case, negative perplexity, since we set `checkpoint_score_order=min`)
+ best_checkpoint, best_checkpoint_metrics = result.best_checkpoints[-1]
+
+ print("Results are stored at:")
+ print(result.path)
+ print("Best checkpoint is stored at:")
+ print(best_checkpoint)
+ print(f"With perplexity: {best_checkpoint_metrics['perplexity']}")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file