Add example to fine-tune StarCoder for chat-based applications #17
Merged
10 commits
0a0e065 Add StarChat files (lewtun)
07a3418 Clean up (lewtun)
cc8785c Fix readme (lewtun)
bd5e167 Tweak (lewtun)
18b32a3 Clean up (lewtun)
bde49ab Final polish (lewtun)
daf5ad3 Fix steps (lewtun)
3c6cc84 Final tweaks (lewtun)
f660186 Delete dead code (lewtun)
83b3b78 Fix typo (lewtun)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# pdm | ||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
#pdm.lock | ||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
# in version control. | ||
# https://pdm.fming.dev/#use-with-ide | ||
.pdm.toml | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# PyCharm | ||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ | ||
|
||
data/ | ||
wandb/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Fine-tuning StarCoder for chat-based applications | ||
|
||
This is a fully working example that fine-tunes `StarCoder` on a corpus of multi-turn dialogues to create a chatty, helpful coding assistant. The resulting model is quite good at generating code for plots and other programming tasks. For example, given the prompt: | |
|
||
``` | ||
Draw me a map of the world using geopandas. Make it so that only Germany and Spain are colored red. | ||
``` | ||
|
||
the model will produce: | ||
|
||
```python | ||
import pandas as pd | ||
from matplotlib import pyplot as plt | ||
import geopandas as gpd | ||
from shapely.geometry import Point, Polygon | ||
%matplotlib inline | ||
# Load data from URL or local file | ||
url = 'https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson' | ||
df = gpd.read_file(url) | ||
|
||
# Create plot | ||
fig, ax = plt.subplots() | ||
ax.set_aspect('equal') | ||
world = gpd.GeoDataFrame({'name': ['World'], | ||
'geometry': [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]}) | ||
germany = df[df['ADMIN'] == "Germany"] | ||
spain = df[df['ADMIN'] == "Spain"] | ||
other = df[(df['ADMIN']!= "Germany") & (df['ADMIN']!= "Spain")] | ||
world.plot(color='lightgrey', edgecolor='white', ax=ax) | ||
germany.plot(color="red", ax=ax) | ||
spain.plot(color="red", ax=ax) | ||
other.plot(color="skyblue", ax=ax) | ||
plt.title("European Countries") | ||
plt.show() | ||
``` | ||
|
||
Check out our [blog post](https://huggingface.co/blog/starchat-alpha) for more details. | ||
|
||
## Getting started | ||
|
||
To run the `train.py` script, first create a Python virtual environment using e.g. Conda: | ||
|
||
```shell | ||
conda create -n chat python=3.10 && conda activate chat | ||
``` | ||
|
||
Next, install PyTorch v1.13.1. Since this is hardware-dependent, we direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/previous-versions/#v1131) for this step. Then install the rest of the project dependencies: | |
|
||
```shell | ||
pip install -r requirements.txt | ||
``` | ||
|
||
You'll also need to be logged into your Hugging Face account. To do so, run: | |
|
||
```shell | ||
huggingface-cli login | ||
``` | ||
|
||
Finally, install Git LFS with: | ||
|
||
```shell | ||
sudo apt-get install git-lfs | ||
``` | ||
|
||
## Prepare your dataset | ||
|
||
For training and inference, we use _dialogue templates_ to format each message in a conversation. For example, a typical dialogue between a human user and AI assistant takes the form: | ||
|
||
```json | ||
{ | |
  "messages": [ | |
    { | |
      "content": "Is it possible to imagine a society without law?", | |
      "role": "user" | |
    }, | |
    { | |
      "content": "It is difficult to imagine a society that is able to be maintained without any semblance of Law.", | |
      "role": "assistant" | |
    }, | |
    { | |
      "content": "It seems like you consider the absence of law equal to the absence of anything that could guide the behaviour of the individual.", | |
      "role": "user" | |
    }, | |
    { | |
      "content": "You are correct that there are other factors that can guide behavior in a society and play a role in shaping individuals' behavior and interactions with each other. However, even in societies where these factors are present, laws still serve an important role in maintaining social order and resolving conflicts.", | |
      "role": "assistant" | |
    } | |
  ] | |
} | |
``` | ||
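
As a rough illustration of what a dialogue template does, the sketch below flattens a `messages` list like the one above into a single string. The special tokens (`<|user|>`, `<|assistant|>`, `<|end|>`) and the `render_dialogue` helper are assumptions chosen for illustration; the templates actually used by `train.py` may differ.

```python
# Hypothetical special tokens; the ones used by the training script may differ.
ROLE_TOKENS = {"user": "<|user|>", "assistant": "<|assistant|>"}
END_TOKEN = "<|end|>"


def render_dialogue(messages, for_inference=False):
    """Flatten a list of {"role", "content"} messages into one prompt string."""
    parts = []
    for message in messages:
        role_token = ROLE_TOKENS[message["role"]]  # KeyError on unknown roles
        parts.append(f"{role_token}\n{message['content']}{END_TOKEN}\n")
    if for_inference:
        # Leave the assistant turn open so the model writes the next reply.
        parts.append(ROLE_TOKENS["assistant"] + "\n")
    return "".join(parts)


example = [
    {"role": "user", "content": "Is it possible to imagine a society without law?"},
    {"role": "assistant", "content": "It is difficult to imagine one."},
]
print(render_dialogue(example))
```

Passing `for_inference=True` appends an open assistant turn, which is one common way to prompt a chat model for its next reply.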
|
||
Make sure you convert your dataset to this schema; in particular, it needs a `messages` column like the one above. You can adjust the model, dataset, and hyperparameters in the `config.yaml` file. | |
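
If your data is stored as flat prompt/response pairs, the conversion might look like the sketch below. The input field names `prompt` and `response` and both helper functions are hypothetical, and the set of allowed roles is assumed from the example above; adapt them to your own dataset.

```python
# Roles seen in the example above; extend this set if your data uses others.
VALID_ROLES = {"user", "assistant"}


def pair_to_messages(record):
    """Convert a hypothetical {"prompt", "response"} record to the `messages` schema."""
    return {
        "messages": [
            {"content": record["prompt"], "role": "user"},
            {"content": record["response"], "role": "assistant"},
        ]
    }


def validate_example(example):
    """Raise ValueError if an example does not follow the `messages` schema."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError("expected a non-empty `messages` list")
    for i, message in enumerate(messages):
        if message.get("role") not in VALID_ROLES:
            raise ValueError(f"message {i}: unknown role {message.get('role')!r}")
        if not isinstance(message.get("content"), str):
            raise ValueError(f"message {i}: `content` must be a string")


row = pair_to_messages({"prompt": "Plot a sine wave.", "response": "Use matplotlib."})
validate_example(row)  # passes silently when the schema is respected
```

Running a check like this over the whole dataset before training can catch malformed rows early, when they are cheap to fix.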
|
||
## Launch training | ||
|
||
We use DeepSpeed ZeRO-3 to shard the model and optimizer across 8 x A100 (80GB) GPUs. To fine-tune, run: | |
|
||
``` | ||
TRANSFORMERS_VERBOSITY=info torchrun --nproc_per_node=8 train.py config.yaml --deepspeed=deepspeed_z3_config_bf16.json | ||
``` | ||
|
||
By default, this will save the model checkpoint in the `data/` directory and also push it to the Hugging Face Hub. | ||
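
The launch command passes a DeepSpeed config file via `--deepspeed`. The contents of `deepspeed_z3_config_bf16.json` are not shown here, so the sketch below is only an assumption about what a minimal ZeRO-3 config with bf16 might contain. The `"auto"` values are placeholders that the Hugging Face Trainer integration fills in from its own arguments.

```python
import json

# Assumed minimal ZeRO-3 + bf16 config; the file shipped with the PR may differ.
config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,  # shard parameters, gradients, and optimizer state
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Gather the sharded weights so a full checkpoint can be saved.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

print(json.dumps(config, indent=2))
```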
|
||
|
||
## Generate samples | ||
|
||
To generate a few coding examples from your model, run: | ||
|
||
```shell | ||
python generate.py --model_id path/to/your/model | ||
``` | ||
|
||
lewtun marked this conversation as resolved.
maybe some examples of what the model can do after tuning would be cool to show here.