diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json new file mode 100644 index 00000000..aa036d74 --- /dev/null +++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_13b.json @@ -0,0 +1,35 @@ +{ + "fp16": { + "enabled": "auto" + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": 5e8, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true, + "round_robin_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json new file mode 100644 index 00000000..23c70b4f --- /dev/null +++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_70b.json @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": false + }, + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": false + }, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "gather_16bit_weights_on_model_save": true, + "round_robin_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json new file mode 100644 index 00000000..f1ddac17 --- /dev/null +++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_llama_2_7b.json @@ -0,0 +1,35 @@ +{ + "fp16": { + "enabled": "auto" + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": 5e8, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true, + "round_robin_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json new file mode 100644 
index 00000000..9130e09f --- /dev/null +++ b/llm-ai-systems/fine-tuning/configs/deepspeed_configs/zero_3_offload_optim_param.json @@ -0,0 +1,32 @@ +{ + "fp16": { + "enabled": "auto" + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": false + }, + "offload_param": { + "device": "cpu", + "pin_memory": false + }, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "gather_16bit_weights_on_model_save": true, + "round_robin_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 10, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json b/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json new file mode 100644 index 00000000..b953a4c9 --- /dev/null +++ b/llm-ai-systems/fine-tuning/configs/lora_configs/lora.json @@ -0,0 +1,11 @@ +{ + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"], + "task_type": "CAUSAL_LM", + "modules_to_save": [], + "bias": "none", + "fan_in_fan_out": false, + "init_lora_weights": true +} \ No newline at end of file diff --git a/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml b/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml new file mode 100644 index 00000000..1c3a96a0 --- /dev/null +++ b/llm-ai-systems/fine-tuning/llama_fine_tune_runtime_env.yaml @@ -0,0 +1,12 @@ +pip: + - transformers==4.44.0 + - accelerate==0.31.0 + - peft==0.11.1 + - deepspeed==0.16.2 +env_vars: + LIBRARY_PATH: "$CUDA_HOME/lib64:$LIBRARY_PATH" + PROJECT_DIR: "/home/yarnapp/hopsfs" + TRAINED_MODEL_STORAGE_PATH: "${PROJECT_DIR}/Resources/llama_finetuning/fine-tuned-model" + TRAINING_DATA_DIR: "${PROJECT_DIR}/Resources/llama_finetuning/datasets" + TRAINING_CONFIGURATION_DIR: "${PROJECT_DIR}/Resources/llama_finetuning/configs" + \ No newline at end of file diff --git a/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb b/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb new file mode 100644 index 00000000..04be6584 --- /dev/null +++ b/llm-ai-systems/fine-tuning/llama_fine_tuning_with_ray.ipynb @@ -0,0 +1,576 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4392c184-cc3b-4ded-8d2e-31c12c49ecf6", + "metadata": {}, + "source": [ + "## Fine-tune Llama 3.1 (8B parameter) using Ray Framework on Hopsworks\n", + "This tutorial demonstrates how to perform fine-tuning (with LoRA and deepspeed) of a Llama 3.1 (8B) using the Ray framework on Hopsworks. Ray is an industry-leading distributed computing framework. This tutorial was run on OVH cluster but you can use any cloud provider of your choice." + ] + }, + { + "cell_type": "markdown", + "id": "bce9d0a6-3cda-4882-82e6-d3eb6b1f5a79", + "metadata": {}, + "source": [ + "### Pre-requisites\n", + "To perform the steps in this tutorial, you need to create a Hopsworks Kubernetes cluster with Ray enabled. For the fine-tuning task demonstrated in this example, these are the minimum resources required:\n", + "* 1 x B3-64 (16 CPU 64 GB RAM) for the Ray head\n", + "* 8 x T2-LE-90 (30 CPU, 90 GB RAM, 2x 32 GRAM Tesla V100S) for the workers\n", + "Let's get started!" 
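The JSON files added above are consumed directly by the fine-tuning script: the DeepSpeed ZeRO-3 config is wrapped in accelerate's `DeepSpeedPlugin` (whose `"auto"` entries are resolved when the model is prepared), and `lora.json` is unpacked straight into a peft `LoraConfig`. A minimal sketch of that wiring, assuming the repository-relative paths shown in this diff (at runtime the script resolves them from `TRAINING_CONFIGURATION_DIR` instead):

```python
import json

from accelerate import DeepSpeedPlugin
from peft import LoraConfig

# Paths as they appear in this diff (assumption: run from the fine-tuning directory).
ds_config_path = "configs/deepspeed_configs/zero_3_offload_optim_param.json"
lora_config_path = "configs/lora_configs/lora.json"

# accelerate wraps the raw DeepSpeed JSON; "auto" values such as the micro batch
# size and precision are filled in when Accelerator.prepare() runs.
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_path)
ds_plugin.hf_ds_config.config["train_micro_batch_size_per_gpu"] = 16  # mirrors what training_function does

# lora.json maps one-to-one onto peft's LoraConfig fields.
with open(lora_config_path) as f:
    lora_config = LoraConfig(**json.load(f))
print(lora_config)
```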
+ ] + }, + { + "cell_type": "markdown", + "id": "f400ddf9-5b27-4786-b165-0dea0e109618", + "metadata": {}, + "source": [ + "### Step 1: Dataset preparation\n", + "We are going to fine-tune the model for question answering. We need to prepare the dataset that will be used for supervised fine-tuning in a certain format. There is no specific prompt format required for the pre-trained Llama 3.1 so the dataset preprocessing can follow any prompt-completion style. The instruction-tuned models (Meta-Llama-3.1-{8,70,405}B-Instruct) use a multi-turn conversation prompt format that structures the conversation between the users and the models.\n", + "\n", + "The dataset for QA typically includes the following fields:\n", + "\n", + "* Question: The input question to the model.\n", + "* Context (optional): A passage or text providing information the model should use to answer.\n", + "* Answer: The correct response.\n", + "\n", + "This example is configured to fine-tune the Llama 3.1 8B pre-trained model on the GSM8K dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "adf6a503", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "import tempfile\n", + "import os\n", + "import json\n", + "import shutil" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "4de50ea7-9c67-4ebe-8fea-0956d50312bb", + "metadata": {}, + "outputs": [], + "source": [ + "llama_dir = \"Resources/llama_finetuning\"\n", + "HOPSFS_STORAGE_PATH = os.path.join(os.environ.get(\"PROJECT_PATH\"), llama_dir)\n", + "if not os.path.exists(HOPSFS_STORAGE_PATH):\n", + " os.mkdir(HOPSFS_STORAGE_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3baa764d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<_io.TextIOWrapper name='/hopsfs/Resources/llama_finetuning/datasets/tokens.json' mode='w' encoding='UTF-8'>\n" + ] + } + ], + "source": [ + "dataset = load_dataset(\"openai/gsm8k\", \"main\")\n", + "dataset_splits = {\"train\": dataset[\"train\"], \"test\": dataset[\"test\"]}\n", + "dataset_dir = os.path.join(HOPSFS_STORAGE_PATH, \"datasets\")\n", + "if not os.path.exists(dataset_dir):\n", + " os.mkdir(dataset_dir)\n", + " \n", + "with open(os.path.join(dataset_dir, \"tokens.json\"), \"w\") as f:\n", + " tokens = {}\n", + " print(f)\n", + " tokens[\"tokens\"] = [\"\", \"\", \"\", \"\"]\n", + " f.write(json.dumps(tokens))\n", + " for key, ds in dataset_splits.items():\n", + " with open(os.path.join(dataset_dir, f\"{key}.jsonl\"), \"w\") as f:\n", + " for item in ds:\n", + " newitem = {}\n", + " newitem[\"input\"] = (\n", + " f\"{item['question']}\"\n", + " f\"{item['answer']}\"\n", + " )\n", + " f.write(json.dumps(newitem) + \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "0686f075-a33e-4b1e-b20e-1b05dbede50a", + "metadata": {}, + "source": [ + "### Step 2: Download the pre-trained model\n", + "The next step is to download the pre-trained Llama model from hugging face. For this you will need the hugging face token." 
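As a quick sanity check on Step 1's output before moving on, the generated JSONL and token files can be inspected directly. A minimal sketch, assuming the same `PROJECT_PATH` / `datasets` layout created by the cells above:

```python
import json
import os

dataset_dir = os.path.join(os.environ.get("PROJECT_PATH"), "Resources/llama_finetuning/datasets")

# Each record has a single "input" field: the question followed by its answer,
# plus whatever delimiter tokens were written to tokens.json.
with open(os.path.join(dataset_dir, "train.jsonl")) as f:
    sample = json.loads(next(f))
print(list(sample.keys()), sample["input"][:200])

with open(os.path.join(dataset_dir, "tokens.json")) as f:
    print(json.load(f)["tokens"])
```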
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "cb3b68c6", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers.utils.hub import TRANSFORMERS_CACHE\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b5effab1-04b8-49fc-88e1-a02df4e89456", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"HF_TOKEN\"] = \"\"\n", + "model_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b278bb45-c87d-440e-8649-dcf68621997d", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "16c9b79aa6d44e00ae339658b7193ee3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/55.4k [00:00 4\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m result \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mFailed to copy pre-trained model files to hopsfs\u001B[39m\u001B[38;5;124m\"\u001B[39m\n", + "\u001B[0;31mAssertionError\u001B[0m: Failed to copy pre-trained model files to hopsfs" + ] + } + ], + "source": [ + "# copy the downloaded model to hopsfs\n", + "cp_cmd = f\"cp -L -r {blobs_dir}/* {hopsfs_model_dir}\"\n", + "result = os.system(cp_cmd)\n", + "assert result != 0, \"Failed to copy pre-trained model files to hopsfs\"" + ] + }, + { + "cell_type": "markdown", + "id": "1492db64-88a7-431b-9bee-217aacb5916b", + "metadata": {}, + "source": [ + "### Step 3: Create the ray job for the fine-tuning task\n", + "We are going to use the hopsworks jobs api to create and run the job for the fine-tuning task" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "60e48c95-f236-4ebe-be27-f6f585813843", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Connection closed.\n", + "2025-01-09 07:01:22,956 INFO: Python Engine initialized.\n", + "\n", + "Logged in to project, explore it here https://hopsworks.ai.local/p/119\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bff06133e9884523a6cbdfbf1b46070e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Uploading /hopsfs/Jupyter/ray_llm_finetuning.py: 0.000%| | 0/28956 elapsed<00:00 remaining.original_module.weight" and another one with + # ".modules_to_save.default.weight" + # Therefore, each layer contributes 2 x the number of actual elements in + # that layer. + sum_params += 2 * target.weight.numel() + print( + "Found non-lora-layer to checkpoint: ", + layer_name, + " with num params ", + target.weight.numel(), + ) + else: + for module_name in lora_config.target_modules: + if layer_name == module_name: + loraified_modules += 1 + if isinstance(target, nn.Linear): + # Target is attention weight + sum_params += ( + target.in_features + target.out_features + ) * lora_config.r + elif isinstance(target, nn.Embedding): + # Target is linear weight + sum_params += ( + target.embedding_dim + target.num_embeddings + ) * lora_config.r + + print( + f"Detected {num_attention_layers} attention layers, containing" + f" {loraified_modules} modules to modify according to LoRA's `target_modules`." + f" This should yield {sum_params} trainable parameters." 
+ ) + + return sum_params + + +def get_number_of_params(model: nn.Module): + sum = 0 + for name, param in model.named_parameters(): + if param.requires_grad: + sum += param.numel() + return sum + + +def collate_fn(batch, tokenizer, block_size, device): + out_batch = tokenizer( + list(batch["input"]), + padding="max_length", + max_length=block_size, + truncation=True, + return_tensors="pt", + ) + out_batch["labels"] = out_batch["input_ids"].clone() + + out_batch = tree.map_structure(lambda x: x.to(device), out_batch) + + return out_batch + + +def get_tokenizer(pretrained_path, special_tokens): + tokenizer = AutoTokenizer.from_pretrained(pretrained_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.add_tokens(special_tokens, special_tokens=True) + + return tokenizer + + +def evaluate( + *, model, eval_ds, accelerator, bsize, ds_kwargs, as_test: bool = False +) -> Tuple[float, float]: + model.eval() + losses = [] + + eval_dataloader = eval_ds.iter_torch_batches(batch_size=bsize, **ds_kwargs) + eval_ds_len = len(list(eval_ds.iter_batches(batch_size=1))) + for step, batch in tqdm.tqdm( + enumerate(eval_dataloader), total=eval_ds_len // (bsize + 1) + ): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + # The tensors are gathered by concatenating them on the first dimension, so we + # add a new dimension to the scalar loss to get a tensor of shape (K,) for K + # workers. + losses.append(accelerator.gather(loss[None])) + + if as_test: + break + + # We stack losses so that we have a tensor of shape (T, K) where T is the number of + # steps and K is the number of workers. + losses = torch.stack(losses) + try: + eval_loss = torch.mean(losses).item() + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + return perplexity, eval_loss + + +def copy_model_to_hopsfs(local_path, hopsfs_path): + if not os.path.exists(hopsfs_path): + os.makedirs(os.path.dirname(hopsfs_path), exist_ok=True) + shutil.copytree(local_path, hopsfs_path, dirs_exist_ok=True) + +def copy_model_to_local(pretrained_path, local_path): + os.makedirs(os.path.dirname(local_path), exist_ok=True) + shutil.copytree(pretrained_path, local_path) + + +def _test_tokenizer(pretrained_path): + # This function tests that adding special tokens does not + # result in un-expected tokenization + # Context: https://github.com/huggingface/transformers/issues/25176 + tokenizer = get_tokenizer(pretrained_path=pretrained_path, special_tokens=[""]) + testoutput = tokenizer("inform")["input_ids"] + expected = tokenizer("inform")["input_ids"] + assert testoutput[-1] == expected[-1], ( + "The tokenizer is not working as expected with special tokens, " + f"testoutput={testoutput}, expected={expected}" + ) + + +def checkpoint_model( + checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs +): + """Utility function for checkpointing model + optimizer dictionaries + The main purpose for this is to be able to resume training from that instant again. 
+ """ + checkpoint_state_dict = { + "epoch": epoch, + "last_global_step": last_global_step, + } + # Add extra kwargs too + checkpoint_state_dict.update(kwargs) + + # In here model will be a DeepspeedEngine object + model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict) + status_msg = ( + f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}" + ) + print(status_msg) + + +def training_function(kwargs: dict): + print("training_function called") + + # Train has a bug somewhere that causes ACCELERATE_TORCH_DEVICE to not be set + # properly on multi-gpu nodes + cuda_visible_device = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + print("CUDA_VISIBLE_DEVICES", cuda_visible_device) + local_rank = int(os.environ["LOCAL_RANK"]) + device_id = cuda_visible_device[local_rank] + os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{device_id}" + print("ACCELERATE_TORCH_DEVICE", os.environ["ACCELERATE_TORCH_DEVICE"]) + + config = kwargs["config"] + args = argparse.Namespace(**kwargs["args"]) + chat_template = kwargs.get("chat_template", []) + special_tokens = kwargs.get("special_tokens", []) + model_id = config["model_name"] + + # Each worker should download its own model + pre_trained_path = "/tmp/" + model_id + "/" + generate_random_dir_name() + + # start by copying the model to local. Calculate the time it takes to copy the model + start_copy_time = time.time() + print("Copying model to local") + copy_model_to_local(args.pre_trained_path, pre_trained_path) + end_copy_time = time.time() + print("Done copying model to local in ", end_copy_time - start_copy_time, " seconds") + + # We need to download the model weights on this machine if they don't exit. + # We need to acquire a lock to ensure that only one process downloads the model + # bucket_uri = get_mirror_link(model_id) + # download_path = get_download_path(model_id) + # base_path = Path(download_path).parent + # base_path.mkdir(parents=True, exist_ok=True) + # lock_file = str(base_path / f'{model_id.replace("/", "--")}.lock') + # with FileLock(lock_file): + # download_model( + # model_id=model_id, bucket_uri=bucket_uri, s3_sync_args=["--no-sign-request"] + # ) + + # Sample hyperparameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + gradient_accumulation_steps = int(config["gradient_accumulation_steps"]) + + # Get deepspeed config to set up the batch size per device + ds_plugin = config["ds_plugin"] + ds_plugin.hf_ds_config.config["train_micro_batch_size_per_gpu"] = batch_size + + # Initialize accelerator + accelerator = Accelerator( + deepspeed_plugin=ds_plugin, + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=args.mx, + ) + + set_seed(seed) + + # train_ds is the local shard for this model + train_ds = train.get_dataset_shard("train") + valid_ds = train.get_dataset_shard("valid") + + train_ds_len = len(list(train_ds.iter_batches(batch_size=1))) + + _test_tokenizer(pre_trained_path) + tokenizer = get_tokenizer(pretrained_path=pre_trained_path, special_tokens=special_tokens) + collate_partial = functools.partial( + collate_fn, + tokenizer=tokenizer, + block_size=config["block_size"], + device=accelerator.device, + ) + + # pre_trained_path = config["pre-trained-path"] + print(f"Loading model from {pre_trained_path} ...") + s = time.time() + model = AutoModelForCausalLM.from_pretrained( + pre_trained_path, + trust_remote_code=True, + 
torch_dtype=torch.bfloat16, + # low_cpu_mem_usage=True, + # `use_cache=True` is incompatible with gradient checkpointing. + use_cache=False, + # device_map={"": device_id}, + # attn_implementation="flash_attention_2", + ) + print(f"Done loading model in {time.time() - s} seconds.") + + model.resize_token_embeddings(len(tokenizer)) + + if config["lora"]: + # Apply LoRA + s = time.time() + lora_config = LoraConfig(**config["lora_config"]) + + expected_num_parameters = get_expected_lora_num_parameters( + lora_config=lora_config, model=model + ) + + print(f"Attempting to apply LoRA config: {lora_config}") + + model.enable_input_require_grads() + # model = PeftModel.from_pretrained(model, config=lora_config) + model = get_peft_model(model, lora_config) + + num_parameters = get_number_of_params(model) + + if num_parameters != expected_num_parameters: + raise ValueError( + f"Expected {expected_num_parameters} parameters, got {num_parameters} " + f"parameters. LoRA-ification failed." + ) + + print( + f"LoRA-ification done in {time.time() - s} seconds. Estimated checkpoint " + f"size (fp16): {num_parameters * 2 / 1e6} MB" + ) + + print(f"Number of check-pointed parameters: {get_number_of_params(model)}") + + print("Model initialized with pretrained weights. Training starting...") + if not args.no_grad_ckpt: + model.gradient_checkpointing_enable() + + optimizer_cls = ( + torch.optim.AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + + optimizer = optimizer_cls( + model.parameters(), + lr=lr, + betas=OPTIM_BETAS, + weight_decay=OPTIM_WEIGHT_DECAY, + eps=OPTIM_EPS, + ) + + # Instantiate scheduler + # Creates Dummy Scheduler if `scheduler` was specified in the config file or + # else, creates `args.lr_scheduler_type` Scheduler + # get train and valid dataset lengths + + num_steps_per_epoch = math.ceil(train_ds_len / args.batch_size_per_device) + total_training_steps = ( + num_steps_per_epoch * num_epochs // gradient_accumulation_steps + ) + + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=NUM_WARMUP_STEPS * args.num_devices, + num_training_steps=total_training_steps * args.num_devices, + ) + else: + lr_scheduler = DummyScheduler( + optimizer, + warmup_num_steps=NUM_WARMUP_STEPS * args.num_devices, + total_num_steps=total_training_steps * args.num_devices, + ) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the + # same order we gave them to the prepare method. 
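    # Worked example of the step math above (illustrative assumptions, not this
    # run's actual settings): GSM8K's train split has 7,473 rows, so with 16 GPU
    # workers each shard holds ~467 rows; at the default batch_size_per_device=16
    # that is ceil(467 / 16) = 30 steps per epoch per worker, and with
    # gradient_accumulation_steps=1 the scheduler sees
    # total_training_steps = 30 * num_epochs.
    # Note also that the default zero_3_offload_optim_param.json defines neither
    # an "optimizer" nor a "scheduler" block, so the AdamW /
    # get_linear_schedule_with_warmup branch above is the one taken.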
+ s = time.time() + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + print(f"Prepare done in {time.time() - s} seconds.") + + # Now we train the model + if accelerator.is_main_process: + print("Starting training ...") + print("Number of batches on main process", train_ds_len // batch_size) + + for epoch in range(num_epochs): + fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0 + s_epoch = time.time() + model.train() + loss_sum = torch.tensor(0.0).to(accelerator.device) + + train_dataloader = train_ds.iter_torch_batches( + batch_size=batch_size, + collate_fn=collate_partial, + ) + + for step, batch in tqdm.tqdm( + enumerate(train_dataloader), total=train_ds_len // batch_size + 1 + ): + + # We could avoid this line since we set the accelerator with + # `device_placement=True`. + with accelerator.accumulate(model): + s_fwd = time.time() + outputs = model(**batch) + loss = outputs.loss + loss_sum += loss.item() + e_fwd = time.time() + fwd_time = e_fwd - s_fwd + fwd_time_sum += fwd_time + s_bwd = time.time() + accelerator.backward(loss) + e_bwd = time.time() + bwd_time = e_bwd - s_bwd + bwd_time_sum += bwd_time + + s_opt_step = time.time() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + e_opt_step = time.time() + optim_step_time_sum += e_opt_step - s_opt_step + + if accelerator.is_main_process: + accelerator.print( + f"[epoch {epoch} step {step}] " + f"loss: {loss.item()} step-time: {e_opt_step - s_fwd}" + ) + + aggregated_loss = torch.mean(accelerator.gather(loss[None])).item() + + if config["as_test"]: + break + + # as long as this is not the last step report here + if step != (train_ds_len // batch_size - 1): + train.report( + { + "epoch": epoch, + "iteration": step, + "train_loss_batch": aggregated_loss, + "avg_train_loss_epoch": None, + "eval_loss": None, + "perplexity": None, + "num_iterations": step + 1, + "train_time_per_epoch": None, + "eval_time_per_epoch": None, + "fwd_time": fwd_time, + "bwd_time": bwd_time, + "avg_fwd_time_per_epoch": None, + "avg_bwd_time_per_epoch": None, + "learning_rate": lr_scheduler.get_lr()[0], + } + ) + + e_epoch = time.time() + accelerator.print("Train time per epoch: ", e_epoch - s_epoch) + + eval_s_epoch = time.time() + print("Running evaluation ...") + perplex, eloss = evaluate( + model=model, + eval_ds=valid_ds, + accelerator=accelerator, + bsize=config["eval_batch_size"], + ds_kwargs={"collate_fn": collate_partial}, + as_test=config["as_test"], + ) + accelerator.print("Eval result loss", eloss) + accelerator.print("Eval perplex", perplex) + + eval_e_epoch = time.time() + accelerator.print("Eval time per epoch: ", eval_e_epoch - eval_s_epoch) + accelerator.print("avg fwd time: ", fwd_time_sum / (step + 1)) + accelerator.print("avg bwd time: ", bwd_time_sum / (step + 1)) + accelerator.print("avg opt step time: ", optim_step_time_sum / (step + 1)) + + metrics = { + "epoch": epoch, + "iteration": step, + "train_loss_batch": aggregated_loss, + "avg_train_loss_epoch": loss_sum.item() / (step + 1), + "eval_loss": eloss, + "perplexity": perplex, + "num_iterations": step + 1, + "train_time_per_epoch": e_epoch - s_epoch, + "eval_time_per_epoch": eval_e_epoch - eval_s_epoch, + "fwd_time": fwd_time, + "bwd_time": bwd_time, + "avg_fwd_time_per_epoch": fwd_time_sum / (step + 1), + "avg_bwd_time_per_epoch": bwd_time_sum / (step + 1), + "learning_rate": lr_scheduler.get_lr()[0], + } + + with tempfile.TemporaryDirectory(dir=args.output_dir) as temp_checkpoint_dir: + accelerator.print(f"Saving the model 
locally at {temp_checkpoint_dir}") + accelerator.wait_for_everyone() + + checkpoint_save_start = time.perf_counter() + + if accelerator.is_main_process: + print("Saving tokenizer and config.") + tokenizer.save_pretrained(temp_checkpoint_dir) + + accelerator.wait_for_everyone() + + # Checkpointing strategy 1: Distributed checkpointing + # This checkpointing method makes deepspeed checkpoints on each node + # and then Ray Train will aggregate them to a central s3 bucket. + # It should be done on all processes (not just the Rank 0) + # aggregate_on_rank_0 = False + # checkpoint_model( + # checkpoint_folder=tempdir, + # ckpt_id=epoch, + # model=model, + # epoch=epoch, + # last_global_step=step + # ) + + # Checkpointing strategy 2: Aggregate model on the rank 0 worker then upload + aggregate_on_rank_0 = True + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + temp_checkpoint_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + safe_serialization=True, + state_dict=accelerator.get_state_dict(model), + ) + accelerator.wait_for_everyone() + print("Checkpoint save time: ", time.perf_counter() - checkpoint_save_start) + + checkpoint_upload_start = time.perf_counter() + + # Create the checkpoint object to report to Ray Train and upload to storage. + # If we aggregated the model on rank 0, we only need to report + # the checkpoint from the rank 0 worker, since all other checkpoint + # directories are empty (`save_pretrained` was a noop for other workers). + if aggregate_on_rank_0: + checkpoint = ( + Checkpoint.from_directory(temp_checkpoint_dir) + if accelerator.is_main_process + else None + ) + else: + # Distributed checkpointing should upload shards from each worker. + checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) + + # Note: After `train.report`, in the case of remote storage, + # the checkpoint directory will be uploaded to the remote storage. + train.report(metrics, checkpoint=checkpoint) + + print( + "Checkpoint upload time: ", + time.perf_counter() - checkpoint_upload_start, + ) + print( + "Total checkpointing time: ", + time.perf_counter() - checkpoint_save_start, + ) + + hopsfs_upload_start = time.time() + copy_model_to_hopsfs(temp_checkpoint_dir, os.environ.get("TRAINED_MODEL_STORAGE_PATH")) + print("HopsFS upload time: ", time.time() - hopsfs_upload_start) + + if perplex < args.stop_perplexity: + print(f"Perplexity reached {perplex} < {args.stop_perplexity}. Stopping.") + break + + if config["as_test"]: + break + + +def parse_args(): + parser = argparse.ArgumentParser(description="LLM fine-tuning with DeepSpeed") + + parser.add_argument("--model-name", type=str, default="meta-llama/Meta-Llama-3.1-8B") + + parser.add_argument("--train-path", type=str, default=os.path.join(TRAINING_DATA_DIR, "train.jsonl"), + help="Path to training jsonl file") + + parser.add_argument("--test-path", type=str, default=os.path.join(TRAINING_DATA_DIR, "test.jsonl"), + help="Path to testing jsonl file") + + parser.add_argument("--dataset-config", type=str, default=os.path.join(TRAINING_DATA_DIR, "config.json"), + help="Path to the config file") + + parser.add_argument("--num-devices", "-nd", type=int, default=3, + help="Number of devices to use.") + + parser.add_argument("--mx", type=str, choices=["no", "fp16", "bf16", "fp8"], default="bf16", + help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
" + "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.") + + parser.add_argument("--ds-config", type=str, default=os.path.join(TRAINING_CONFIGURATION_DIR, + "deepspeed_configs/zero_3_offload_optim_param.json"), + help="Deepspeed config json to use.") + + parser.add_argument("--lora", action="store_true", default=False, + help="If passed, will enable parameter efficient fine-tuning with LoRA.") + + parser.add_argument("--lora-config", type=str, default=os.path.join(TRAINING_CONFIGURATION_DIR, + "lora_configs/lora.json"), + help="Lora config json to use.") + + parser.add_argument("--num-epochs", type=int, default=1, + help="Number of epochs to train for.") + + parser.add_argument("--lr", type=float, default=1e-4, + help="Learning rate to use.") + + parser.add_argument("--ctx-len", type=int, default=512, + help="Maximum context length for the model input sequences.") + + parser.add_argument("--batch-size-per-device", "-bs", type=int, default=16, + help="Batch size to use per device.") + + parser.add_argument("--eval-batch-size-per-device", type=int, default=64, + help="Batch size to use per device (For evaluation).") + + parser.add_argument("--grad-accum", type=int, default=1, + help="Gradient accumulation steps.") + + parser.add_argument("--output-dir", type=str, default="/tmp", + help="Path to output directory.") + + parser.add_argument("--storage-path", type=str, + help="Path to results and checkpoints storage") + + parser.add_argument("--no-grad-ckpt", action="store_true", + help="If passed, will not use gradient checkpointing.") + + parser.add_argument("--num-checkpoints-to-keep", type=int, default=1, + help="Number of checkpoints to keep, if None, all checkpoints will be kept, " + "if set to n>=1, the top n checkpoint with min. evaluation perplexity " + "will be kept.") + + parser.add_argument("--stop-perplexity", type=float, default=0, + help="Target perplexity to reach after which to stop training. Default is 0. " + "If 0, training will not stop on perplexity.") + + parser.add_argument("--as-test", action="store_true", + help="If passed, will run the script in test mode.") + + parser.add_argument("--pre-trained-path", type=str, help="Path to pretrained model") + + args = parser.parse_args() + + return args + + +def main(): + # if TRAINING_DATA_DIR is None: + # TRAINING_DATA_DIR = os.environ.get("TRAINING_DATA_DIR") + + # if os.environ.get("TRAINING_CONFIGURATION_DIR") is not None: + # TRAINING_CONFIGURATION_DIR = os.environ.get("TRAINING_CONFIGURATION_DIR") + + args = parse_args() + + if not args.output_dir: + raise ValueError("--output-dir must be specified") + + # update the config with args so that we have access to them. 
+ config = vars(args) + config.update( + **{ + "lr": args.lr, + "num_epochs": args.num_epochs, + "seed": 42, + "batch_size": args.batch_size_per_device, + "gradient_accumulation_steps": args.grad_accum, + "model_name": args.model_name, + "block_size": args.ctx_len, + "eval_batch_size": args.eval_batch_size_per_device, + } + ) + + # Add LoRA config if needed + if args.lora: + with open(args.lora_config, "r") as json_file: + lora_config = json.load(json_file) + config["lora_config"] = lora_config + + # Add deepspeed plugin to the config + ds_plugin = DeepSpeedPlugin(hf_ds_config=config.get("ds_config")) + config.update(ds_plugin=ds_plugin) + + ray.init() + + # Read data + train_ds = ray.data.read_json(args.train_path) + if args.test_path is not None: + valid_ds = ray.data.read_json(args.test_path) + else: + valid_ds = None + + special_tokens = ray.data.read_json(TRAINING_DATA_DIR + "/tokens.json").take_all()[0]["tokens"] + + # Config file + # chat_template = None + # special_tokens = None + # if os.path.isfile(args.dataset_config): + # with open(args.dataset_config, "r") as json_file: + # dataset_config = json.load(json_file) + # chat_template = dataset_config.get("chat_template", None) + # special_tokens = dataset_config.get("special_tokens", None) + + trial_name = f"{args.model_name}".split("/")[-1] + if args.lora: + trial_name += "-lora" + + storage_path = os.environ.get( + "PROJECT_DIR") + "/Resources/ft_llms_with_deepspeed/" + args.model_name + "/" + trial_name + trainer = TorchTrainer( + training_function, + train_loop_config={ + "config": config, + "args": vars(args), + # "chat_template": chat_template, + "special_tokens": special_tokens, + }, + run_config=train.RunConfig( + storage_path=storage_path, + checkpoint_config=train.CheckpointConfig( + num_to_keep=args.num_checkpoints_to_keep, + checkpoint_score_attribute="perplexity", + checkpoint_score_order="min", + # storage_path=os.environ["PROJECT_DIR"] + "/Resources/llm-checkpoint-dir", + ), + ), + scaling_config=train.ScalingConfig( + num_workers=args.num_devices, + use_gpu=True, + resources_per_worker={"GPU": 1, "CPU": 13}, + ), + datasets={"train": train_ds, "valid": valid_ds}, + dataset_config=ray.train.DataConfig(datasets_to_split=["train", "valid"]), + ) + + result: train.Result = trainer.fit() + # `best_checkpoints` are sorted in increasing score order. + # (Ex: in this case, negative perplexity, since we set `checkpoint_score_order=min`) + best_checkpoint, best_checkpoint_metrics = result.best_checkpoints[-1] + + print("Results are stored at:") + print(result.path) + print("Best checkpoint is stored at:") + print(best_checkpoint) + print(f"With perplexity: {best_checkpoint_metrics['perplexity']}") + + +if __name__ == "__main__": + main() \ No newline at end of file
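Once the job completes, the aggregated checkpoint (adapter weights plus tokenizer) is reported to Ray Train's `storage_path` and, via `copy_model_to_hopsfs`, also lands under `TRAINED_MODEL_STORAGE_PATH`. A minimal sketch of loading a `--lora` checkpoint for inference; the paths are placeholders, and because the training script resized the embeddings after adding special tokens, the base model must be resized to the saved tokenizer before the adapter is attached:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Placeholder paths: the base model copied to HopsFS in Step 2 and the directory
# populated by copy_model_to_hopsfs (TRAINED_MODEL_STORAGE_PATH in the runtime env).
base_model_path = "/hopsfs/Resources/llama_finetuning/Meta-Llama-3.1-8B-Instruct"
adapter_path = "/hopsfs/Resources/llama_finetuning/fine-tuned-model"

tokenizer = AutoTokenizer.from_pretrained(adapter_path)   # tokenizer was saved with the checkpoint
base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
base.resize_token_embeddings(len(tokenizer))              # match the tokens added during training
model = PeftModel.from_pretrained(base, adapter_path)
model = model.merge_and_unload()                          # optionally fold the LoRA weights into the base model

prompt = "A question in the same style as the GSM8K training data"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```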