Skip to content

Commit

Permalink
Merge pull request #69 from tharapalanivel/trainer_image
Browse files Browse the repository at this point in the history
Initial commit for trainer image
  • Loading branch information
anhuong committed Mar 5, 2024
2 parents 0f09dab + 4312222 commit 0e60ecd
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ durations/*
coverage*.xml
dist
htmlcov
build
test

# IDEs
Expand Down
118 changes: 118 additions & 0 deletions build/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Tag of the UBI 9 base image. The default keeps the previous behaviour
# (implicit "latest"); pass --build-arg BASE_UBI_IMAGE_TAG=<tag> to pin a
# specific release for reproducible builds.
ARG BASE_UBI_IMAGE_TAG=latest

FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS release

# Build-time knobs: CUDA version and the name/UID of the non-root runtime user.
ARG CUDA_VERSION=11.8.0
ARG USER=tuning
ARG USER_UID=1000

USER root

# Removed packages:
#   subscription-manager
#   python3.11-requests  - we install newer version of requests via pip
# Installed packages:
#   make
#   procps               - to help with debugging
RUN dnf remove -y --disableplugin=subscription-manager \
        subscription-manager \
        python3.11-requests \
    && dnf install -y make \
        procps \
    && dnf clean all

ENV LANG=C.UTF-8 \
    LC_ALL=C.UTF-8

# Versions of the CUDA packages installed below, plus the NVIDIA_* settings.
ENV CUDA_VERSION=$CUDA_VERSION \
    NV_CUDA_LIB_VERSION=11.8.0-1 \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
    NV_CUDA_CUDART_VERSION=11.8.89-1 \
    NV_CUDA_COMPAT_VERSION=520.61.05-1

# Register the NVIDIA CUDA repo, install the CUDA runtime and compat packages,
# and add the /usr/local/nvidia library directories to the dynamic linker
# configuration.
RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-cudart-11-8-${NV_CUDA_CUDART_VERSION} \
        cuda-compat-11-8-${NV_CUDA_COMPAT_VERSION} \
    && echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf \
    && echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf \
    && dnf clean all

# CUDA_HOME is set in its own ENV instruction on purpose: variable
# substitution inside a single ENV instruction uses the values from *before*
# that instruction, so referencing ${CUDA_HOME} in the same instruction that
# defines it would expand to an empty string.
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/usr/local/nvidia/bin:${CUDA_HOME}/bin:${PATH}" \
    LD_LIBRARY_PATH="/usr/local/nvidia/lib:/usr/local/nvidia/lib64:${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}"


# Pinned versions of the CUDA runtime libraries installed in the next step.
ENV NV_LIBCUBLAS_VERSION=11.11.3.6-1 \
    NV_LIBNCCL_PACKAGE_VERSION=2.15.5-1+cuda11.8 \
    NV_LIBNPP_VERSION=11.8.0.86-1 \
    NV_NVTX_VERSION=11.8.86-1

# CUDA runtime libraries (packages listed alphabetically).
RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-libraries-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-nvtx-11-8-${NV_NVTX_VERSION} \
        libcublas-11-8-${NV_LIBCUBLAS_VERSION} \
        libnccl-${NV_LIBNCCL_PACKAGE_VERSION} \
        libnpp-11-8-${NV_LIBNPP_VERSION} \
    && dnf clean all

# Pinned versions of the CUDA development packages installed in the next step.
ENV NV_CUDA_CUDART_DEV_VERSION=11.8.89-1 \
    NV_LIBCUBLAS_DEV_VERSION=11.11.3.6-1 \
    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.15.5-1+cuda11.8 \
    NV_LIBNPP_DEV_VERSION=11.8.0.86-1 \
    NV_NVML_DEV_VERSION=11.8.86-1

# CUDA command line tools, headers and build tooling (listed alphabetically).
RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
        cuda-command-line-tools-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-cudart-devel-11-8-${NV_CUDA_CUDART_DEV_VERSION} \
        cuda-libraries-devel-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-minimal-build-11-8-${NV_CUDA_LIB_VERSION} \
        cuda-nvml-devel-11-8-${NV_NVML_DEV_VERSION} \
        libcublas-devel-11-8-${NV_LIBCUBLAS_DEV_VERSION} \
        libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
        libnpp-devel-11-8-${NV_LIBNPP_DEV_VERSION} \
    && dnf clean all

# Point the link-time library search path at the CUDA stub libraries.
ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"

# Python 3.11 (exposed as `python`) plus git, which is used to clone the
# tuning repo. The dnf cache is cleaned in the same layer so that it does not
# persist in the image.
RUN dnf install -y python3.11 git && \
    ln -s /usr/bin/python3.11 /bin/python && \
    python -m ensurepip --upgrade && \
    dnf clean all

# /app has to exist as a directory before the launch script is copied to it.
RUN mkdir /app

WORKDIR /tmp

# Base Python build tooling and torch. --no-cache-dir keeps pip's download
# cache out of the image layer.
RUN python -m pip install --no-cache-dir packaging && \
    python -m pip install --no-cache-dir --upgrade pip && \
    python -m pip install --no-cache-dir torch && \
    python -m pip install --no-cache-dir wheel

# TODO Move to installing wheel once we have proper releases setup instead of cloning the repo
# The checkout is only needed to install the package and its requirements, so
# it is deleted in the same layer to keep the sources out of the image.
# --no-cache-dir keeps pip's download cache out of the layer as well.
RUN git clone https://github.com/foundation-model-stack/fms-hf-tuning.git && \
    cd fms-hf-tuning && \
    python -m pip install --no-cache-dir -r requirements.txt && \
    python -m pip install --no-cache-dir -r flashattn_requirements.txt && \
    python -m pip install --no-cache-dir -U datasets && \
    python -m pip install --no-cache-dir /tmp/fms-hf-tuning && \
    cd /tmp && \
    rm -rf /tmp/fms-hf-tuning

# Ship the project license inside the image.
RUN mkdir -p /licenses
COPY LICENSE /licenses/

# Training entry script; mark it executable.
COPY launch_training.py /app
RUN chmod +x /app/launch_training.py

# Need a better way to address this hack
# Creates a world-writable /.aim_profile file and /.cache directory.
RUN mkdir /.cache && \
    touch /.aim_profile && \
    chmod -R 777 /.aim_profile /.cache

# create the runtime user and give ownership to dirs
# The account name is taken from the USER build arg (default "tuning") so that
# overriding it with --build-arg USER=<name> keeps useradd, chown and the
# final USER instruction in agreement. The user is placed in group 0 and /app
# is made group-writable.
RUN useradd -u "${USER_UID}" "${USER}" -m -g 0 --system && \
    chown -R "${USER}:0" /app && \
    chmod -R g+rwX /app

WORKDIR /app
USER ${USER}

# Default command only keeps the container running idle.
CMD [ "tail", "-f", "/dev/null" ]
168 changes: 168 additions & 0 deletions build/launch_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright The SFT Trainer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script wraps SFT Trainer to run for Train Conductor.
Read SFTTrainer configuration via environment variable `SFT_TRAINER_CONFIG_JSON_PATH`
for the path to the JSON config file with parameters or `SFT_TRAINER_CONFIG_JSON_ENV_VAR`
for the encoded config string to parse.
"""

# Standard
import base64
import os
import pickle
import json
import tempfile
import shutil
import glob

# First Party
import logging
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.utils.merge_model_utils import create_merged_model

# Third Party
import transformers


def txt_to_obj(txt):
    """Rebuild a Python object from its base64-encoded pickle text.

    Args:
        txt: str
            ASCII text holding the base64 encoding of a pickle payload.

    Returns:
        The object produced by unpickling the decoded bytes.

    NOTE(review): ``pickle.loads`` can execute arbitrary code while
    unpickling, so the text must come from a trusted source.
    """
    raw_payload = base64.b64decode(txt.encode("ascii"))
    return pickle.loads(raw_payload)


def get_highest_checkpoint(dir_path):
    """Return the name of the highest-numbered checkpoint entry in a directory.

    Every entry of ``dir_path`` whose name starts with "checkpoint" is a
    candidate. Candidates are ranked by the integer that follows the last "-"
    in their name (e.g. "checkpoint-120" ranks as 120). Names without a
    parsable number rank below every numbered one instead of raising, so a
    stray entry such as "checkpoint_tmp" can no longer abort the lookup.

    Args:
        dir_path: str
            Directory to scan.

    Returns:
        str: Name (not the full path) of the selected entry, or "" when the
        directory holds no entry starting with "checkpoint".
    """

    def _step_number(entry_name):
        # Rank entries without a numeric suffix lowest rather than crashing.
        try:
            return int(entry_name.split("-")[-1])
        except ValueError:
            return -1

    candidates = [
        entry for entry in os.listdir(dir_path) if entry.startswith("checkpoint")
    ]
    if not candidates:
        return ""
    # max() keeps the first of equally ranked entries.
    return max(candidates, key=_step_number)


def main():
    """Run SFT training from a configuration supplied through the environment.

    The configuration is read from the JSON file named by
    `SFT_TRAINER_CONFIG_JSON_PATH` or, if that is unset, from the encoded
    string in `SFT_TRAINER_CONFIG_JSON_ENV_VAR`. Training writes into a
    temporary directory. Afterwards, for "lora" the highest checkpoint is
    merged with the base model into `LORA_MERGE_MODELS_EXPORT_PATH` (defaulting
    to the configured output dir); otherwise the highest checkpoint is copied
    into the configured output dir. Any `*loss.jsonl` files are copied into the
    configured output dir in both cases.

    Raises:
        ValueError: If neither configuration environment variable is set.
    """
    LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
    logging.basicConfig(level=LOGLEVEL)

    logging.info("Attempting to launch training script")
    parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )
    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

    # accepts either path to JSON file or encoded string config
    if json_path:
        (
            model_args,
            data_args,
            training_args,
            lora_config,
            prompt_tuning_config,
        ) = parser.parse_json_file(json_path, allow_extra_keys=True)

        # Re-read the raw JSON to get "peft_method", which is taken from the
        # raw config rather than from the parsed dataclasses.
        with open(json_path, "r", encoding="utf-8") as f:
            contents = json.load(f)
        peft_method_parsed = contents.get("peft_method")
        logging.debug(f"Input params parsed: {contents}")
    elif json_env_var:
        # NOTE(review): txt_to_obj unpickles the env var contents; the value
        # must come from a trusted source.
        job_config_dict = txt_to_obj(json_env_var)
        logging.debug(f"Input params parsed: {job_config_dict}")

        (
            model_args,
            data_args,
            training_args,
            lora_config,
            prompt_tuning_config,
        ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

        peft_method_parsed = job_config_dict.get("peft_method")
    else:
        raise ValueError(
            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )

    # Any peft_method other than "lora" or "pt" (including a missing key)
    # leaves tune_config as None.
    tune_config = None
    merge_model = False
    if peft_method_parsed == "lora":
        tune_config = lora_config
        merge_model = True
    elif peft_method_parsed == "pt":
        tune_config = prompt_tuning_config

    logging.debug(
        f"Parameters used to launch training: model_args {model_args}, data_args {data_args}, training_args {training_args}, tune_config {tune_config}"
    )

    # Train into a temporary directory; only the selected artifacts are moved
    # to the user-configured output location afterwards.
    original_output_dir = training_args.output_dir
    with tempfile.TemporaryDirectory() as tempdir:
        training_args.output_dir = tempdir
        sft_trainer.train(model_args, data_args, training_args, tune_config)

        if merge_model:
            export_path = os.getenv(
                "LORA_MERGE_MODELS_EXPORT_PATH", original_output_dir
            )

            # get the highest checkpoint dir (last checkpoint)
            lora_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
            full_checkpoint_dir = os.path.join(
                training_args.output_dir, lora_checkpoint_dir
            )

            logging.info(
                f"Merging lora tuned checkpoint {lora_checkpoint_dir} with base model into output path: {export_path}"
            )

            create_merged_model(
                checkpoint_models=full_checkpoint_dir,
                export_path=export_path,
                base_model=model_args.model_name_or_path,
                save_tokenizer=True,
            )
        else:
            # copy last checkpoint into mounted output dir
            pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
            logging.info(
                f"Copying last checkpoint {pt_checkpoint_dir} into output dir {original_output_dir}"
            )
            shutil.copytree(
                os.path.join(training_args.output_dir, pt_checkpoint_dir),
                original_output_dir,
                dirs_exist_ok=True,
            )

        # copy over any loss logs
        loss_logs = glob.glob(f"{training_args.output_dir}/*loss.jsonl")
        if loss_logs:
            # In the merge path nothing above is guaranteed to have created
            # the output dir (the export path may differ); without it,
            # shutil.copy would write a plain file named like the directory.
            os.makedirs(original_output_dir, exist_ok=True)
        for file in loss_logs:
            shutil.copy(file, original_output_dir)


if __name__ == "__main__":
main()

0 comments on commit 0e60ecd

Please sign in to comment.