Skip to content

Add example to fine-tune StarCoder for chat-based applications #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data/
wandb/
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
# What is this about?
💫 StarCoder is a language model (LM) trained on source code and natural language text. Its training data incorporates more that 80 different programming languages as well as text extracted from github issues and commits and from notebooks. This repository showcases how we get an overview of this LM's capabilities.

# News

* **May 9, 2023:** We've fine-tuned StarCoder to act as a helpful coding assistant 💬! Check out the `chat/` directory for the training code and play with the model [here](https://huggingface.co/spaces/HuggingFaceH4/starchat-playground).

# Disclaimer

Before you can use the model go to `hf.co/bigcode/starcoder` and accept the agreement.
Expand Down
111 changes: 111 additions & 0 deletions chat/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Fine-tuning StarCoder for chat-based applications

This is a fully-working example to fine-tune `StarCoder` on a corpus of multi-turn dialogues and thus create a coding assistant that is chatty and helpful. The resulting model is quite good at generating code for plots and other programming tasks. For example, given the prompt:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe some examples of what the model can do after tuning would be cool to show here.

```
Draw me a map of the world using geopandas. Make it so that only Germany and Spain are colored red.
```

the model will produce:

```python
import pandas as pd
from matplotlib import pyplot as plt
import geopandas as gpd
from shapely.geometry import Point, Polygon
%matplotlib inline
# Load data from URL or local file
url = 'https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson'
df = gpd.read_file(url)

# Create plot
fig, ax = plt.subplots()
ax.set_aspect('equal')
world = gpd.GeoDataFrame({'name': ['World'],
'geometry': [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]})
germany = df[df['ADMIN'] == "Germany"]
spain = df[df['ADMIN'] == "Spain"]
other = df[(df['ADMIN']!= "Germany") & (df['ADMIN']!= "Spain")]
world.plot(color='lightgrey', edgecolor='white', ax=ax)
germany.plot(color="red", ax=ax)
spain.plot(color="red", ax=ax)
other.plot(color="skyblue", ax=ax)
plt.title("European Countries")
plt.show()
```

Check out our [blog post](https://huggingface.co/blog/starchat-alpha) for more details.

## Getting started

To run the `train.py` script, first create a Python virtual environment using e.g. Conda:

```shell
conda create -n chat python=3.10 && conda activate chat
```

Next, install PyTorch v1.13.1. Since this is hardware-dependent, we direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/previous-versions/#v1131) for this step. Next, install the rest of the project dependencies:

```shell
pip install -r requirements.txt
```

You'll also need to be logged into both your Hugging Face account. To do so, run:

```shell
huggingface-cli login
```

Finally, install Git LFS with:

```shell
sudo apt-get install git-lfs
```

## Prepare your dataset

For training and inference, we use _dialogue templates_ to format each message in a conversation. For example, a typical dialogue between a human user and AI assistant takes the form:

```json
{
"messages": [
{
"content": "Is it possible to imagine a society without law?",
"role": "user"},
{
"content": "It is difficult to imagine a society that is able to be maintained without any semblance of Law.",
"role": "assistant",
},
{
"content": "It seems like you consider the absence of law equal to the absence of anything that could guide the behaviour of the individual.",
"role": "user",
},
{
"content": "You are correct that there are other factors that can guide behavior in a society and play a role in shaping individuals' behavior and interactions with each other. However, even in societies where these factors are present, laws still serve an important role in maintaining social order and resolving conflicts.",
"role": "assistant",
}
]
}
```

Make sure you convert your dataset according to this schema, in particular you need to include a `messages` column like the above. You can adjust the model, dataset, and hyperparamters in the `config.yaml` file.

## Launch training

We use DeepSpeed ZeRO-3 to shard the model and optimizer across 8 x A100 (80GB) GPUs. To fine-tune run:

```
TRANSFORMERS_VERBOSITY=info torchrun --nproc_per_node=8 train.py config.yaml --deepspeed=deepspeed_z3_config_bf16.json
```

By default, this will save the model checkpoint in the `data/` directory and also push it to the Hugging Face Hub.


## Generate samples

To generate a few coding examples from your model, run:

```shell
python generate.py --model_id path/to/your/model
```

Loading