Weight conversion testing and other features #27

Merged: 10 commits, Aug 17, 2023
23 changes: 12 additions & 11 deletions docs/guide/getting_started.md
@@ -42,11 +42,16 @@ First we need to install the dependencies.
 cd ../../
 ```

+(download_weights)=
 ## Downloading LLaMa2 weights

-1. Request access to the weights directly to meta: https://ai.meta.com/resources/models-and-libraries/llama-downloads/
-1. Follow the instructions given by meta to download `llama-2-7b/` weights.
-1. Make sure you have also downloaded the `tokenizer.model` file.
+1. Request access to the weights directly from meta: https://ai.meta.com/resources/models-and-libraries/llama-downloads/.
+1. Request access to the LLaMa2 huggingface model: https://huggingface.co/meta-llama/Llama-2-7b-hf.
+1. Create a new huggingface token (or use an existing one): https://huggingface.co/settings/tokens.
+1. Run the huggingface login CLI, and enter the token created in the previous step when asked:
+```
+huggingface-cli login
+```

 ## Preparing the raw data

@@ -65,11 +70,7 @@ In this case, skip to the [data preprocessing](#data-preprocessing) section.
 :::

 1. Accept starcoder's terms of use via the huggingface portal: https://huggingface.co/datasets/bigcode/starcoderdata
-1. Create a new huggingface token (or use an existing one): https://huggingface.co/settings/tokens.
-1. Run the huggingface login CLI, and enter the token created on the previous step when asked:
-```
-huggingface-cli login
-```
+1. Create a huggingface token (or use an existing one) and log in using `huggingface-cli` (see [Downloading LLaMa2 weights](#download_weights) for more information).
 1. Download and save the starcoder dataset.
 In this tutorial we will use the `julia` data, but feel free to use any other subset.
 This data contains around 500M tokens.
@@ -141,7 +142,7 @@ torchrun $DISTRIBUTED_ARGS verify_correctness.py \
     --load=/path/to/megatron/weights/ \
     --data_path=/path/to/tokenized/starcoder \
     --tokenizer_type=SentencePieceTokenizer \
-    --vocab_file=/path/to/tokenizer.model \
+    --vocab_file=/path/to/megatron/weights/tokenizer.model \
     --huggingface_cache=/path/to/meta/llama-2-7b/ \
     --huggingface_device=cuda:1 \
     $COMMON_ARGS $LLAMA_ARGS  # don't include LLAMA_ARGS if using Falcon
@@ -185,7 +186,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --data_path /path/to/tokenized/starcoder \
     --model_name llama2 \
     --tokenizer_type SentencePieceTokenizer \
-    --vocab_file /path/to/tokenizer.model \
+    --vocab_file=/path/to/megatron/weights/tokenizer.model \
     --bf16 \
     --use_flash_attn \
     --micro_batch_size 5 \
@@ -245,7 +246,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 pipeline = transformers.pipeline(
     "text-generation",
     model=LlamaForCausalLM.from_pretrained("/path/to/hf/weights/"),
-    tokenizer=LlamaTokenizer("/path/to/tokenizer.model"),
+    tokenizer=LlamaTokenizer.from_pretrained("/path/to/hf/weights/"),
     torch_dtype=torch.bfloat16,
     device_map="auto"
 )
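The updated guide assumes the gated `meta-llama/Llama-2-7b-hf` checkpoint is available locally once `huggingface-cli login` has been run and access has been granted. As a side note (not part of the diff above), the same download can be scripted; a minimal sketch, assuming a `huggingface_hub` version that supports `local_dir`, with placeholder paths:

```python
from huggingface_hub import snapshot_download

# Fetch the gated Llama-2-7b-hf checkpoint into a local directory.
# The token stored by `huggingface-cli login` is picked up automatically.
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-hf",
    local_dir="/path/to/llama-2-7b-hf/",  # placeholder target directory
)
```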
3 changes: 1 addition & 2 deletions megatron/tokenizer/tokenizer.py
@@ -15,7 +15,7 @@ def build_tokenizer(args):
         print('> building {} tokenizer ...'.format(args.tokenizer_type),
               flush=True)

-    if args.tokenizer_type not in {'SentencePieceTokenizer', 'FalconTokenizer'}:
+    if args.tokenizer_type != 'FalconTokenizer':
         assert args.vocab_file is not None

     # Select and instantiate the tokenizer.
@@ -31,7 +31,6 @@ def build_tokenizer(args):
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
     elif args.tokenizer_type == 'SentencePieceTokenizer':
-        assert args.vocab_file is not None
         tokenizer = _SentencePieceTokenizer(args.vocab_file, vocab_extra_ids=args.vocab_extra_ids,
                                             vocab_extra_ids_list=args.vocab_extra_ids_list, new_tokens=args.new_tokens)
     elif args.tokenizer_type == 'FalconTokenizer':
58 changes: 58 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,58 @@
import pytest
from pathlib import Path


_test_failed_incremental: dict[str, dict[tuple[int, ...], str]] = {}


def pytest_addoption(parser):
    parser.addoption("--cache_path", type=Path,
                     help="Huggingface cache path (optional)")
    parser.addoption("--llama2_path", type=Path, required=True,
                     help="Path where the raw llama-2-7b weights are located")
    parser.addoption("--tmp_dir", type=Path,
                     help="Prefix of the tempdir to create (optional)")
    parser.addoption("--data_path", type=Path, required=True,
                     help="Path where the megatron dataset is located")
    parser.addoption("--vocab_path", type=Path, required=True,
                     help="Meta's vocabfile")


def pytest_runtest_makereport(item, call):
    if "incremental" in item.keywords:
        # incremental marker is used
        if call.excinfo is not None:
            # the test has failed
            # retrieve the class name of the test
            cls_name = str(item.cls)
            # retrieve the index of the test (if parametrize is used in combination with incremental)
            parametrize_index = (
                tuple(item.callspec.indices.values())
                if hasattr(item, "callspec")
                else ()
            )
            # retrieve the name of the test function
            test_name = item.originalname or item.name
            # store in _test_failed_incremental the original name of the failed test
            _test_failed_incremental.setdefault(cls_name, {}).setdefault(
                parametrize_index, test_name
            )


def pytest_runtest_setup(item):
    if "incremental" in item.keywords:
        # retrieve the class name of the test
        cls_name = str(item.cls)
        # check if a previous test has failed for this class
        if cls_name in _test_failed_incremental:
            # retrieve the index of the test (if parametrize is used in combination with incremental)
            parametrize_index = (
                tuple(item.callspec.indices.values())
                if hasattr(item, "callspec")
                else ()
            )
            # retrieve the name of the first test function to fail for this class name and index
            test_name = _test_failed_incremental[cls_name].get(parametrize_index, None)
            # if name found, test has failed for the combination of class name & test name
            if test_name is not None:
                pytest.xfail("previous test failed ({})".format(test_name))
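The two hooks above implement pytest's incremental-testing recipe: when a test inside a class marked `incremental` fails, its name is recorded, and every later test of that class is xfailed instead of being run against a half-converted checkpoint. A minimal sketch of how the marker behaves (hypothetical test class, not part of the diff):

```python
import pytest


@pytest.mark.incremental
class TestExample:
    def test_convert(self):
        # If this test fails...
        assert False

    def test_verify(self):
        # ...this one is reported as xfail ("previous test failed (test_convert)")
        # instead of being executed.
        assert True
```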
3 changes: 3 additions & 0 deletions tests/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
    incremental: mark a test as incremental
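`pytest.ini` registers the `incremental` marker so the class-level mark used by the new test module is not flagged as an unknown marker. Running the suite requires the options declared in `conftest.py`; a minimal sketch of an equivalent programmatic invocation from the repository root, with placeholder paths (`--cache_path` and `--tmp_dir` stay optional):

```python
import pytest

# Placeholder paths; point them at the Meta weights, the tokenized dataset,
# and the tokenizer.model file.
exit_code = pytest.main([
    "tests",
    "--llama2_path=/path/to/meta/llama-2-7b/",
    "--data_path=/path/to/tokenized/starcoder",
    "--vocab_path=/path/to/tokenizer.model",
])
```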
180 changes: 180 additions & 0 deletions tests/test_llama_weights.py
@@ -0,0 +1,180 @@
import re
import pytest
import shutil
from pathlib import Path
from typing import Optional, Iterator
from tempfile import TemporaryDirectory
from subprocess import PIPE, Popen


# ===
# = Arguments
# ===
@pytest.fixture(scope="session")
def llama_meta(pytestconfig) -> Path:
    return pytestconfig.getoption("llama2_path")


@pytest.fixture(scope="session")
def cache_dir(pytestconfig) -> Optional[Path]:
    return pytestconfig.getoption("cache_path")


@pytest.fixture(scope="session")
def data(pytestconfig) -> Path:
    return pytestconfig.getoption("data_path")


@pytest.fixture(scope="session")
def vocab(pytestconfig) -> Path:
    return pytestconfig.getoption("vocab_path")

@pytest.fixture(scope="session")
def root_dir(pytestconfig) -> TemporaryDirectory:
    prefix = pytestconfig.getoption("tmp_dir")
    prefix = None if prefix is None else str(prefix)
    return TemporaryDirectory(prefix=prefix)


# ===
# = Paths
# ===
@pytest.fixture(scope="session")
def root(root_dir) -> Path:
    return Path(root_dir.name)

@pytest.fixture(scope="session")
def llama_meta2mega(root: Path) -> Path:
    return root/"llama-meta2mega"

@pytest.fixture(scope="session")
def llama_hf2mega(root: Path) -> Path:
    return root/"llama-hf2mega"

@pytest.fixture(scope="session")
def vocab_hf2mega(llama_hf2mega: Path) -> Path:
    return llama_hf2mega/"tokenizer.model"

@pytest.fixture(scope="session")
def llama_sharded(root: Path) -> Path:
    return root/"llama-sharded"

@pytest.fixture(scope="session")
def llama_unsharded(root: Path) -> Path:
    return root/"llama-unsharded"

@pytest.fixture(scope="session")
def llama_mega2hf(root: Path) -> Path:
    return root/"llama-mega2hf"

@pytest.fixture(scope="session")
def llama_unsharded2hf(root: Path) -> Path:
    return root/"llama-unsharded2hf"


# ===
# = Utils
# ===
def execute(cmd: list[str]) -> Iterator[str]:
    with Popen(cmd, stdout=PIPE, text=True) as proc:
        yield from map(lambda line: line.strip(), iter(proc.stdout.readline, ""))
        assert proc.wait() == 0


def verify_correctness(our_path: Path, cache_dir: Optional[Path], data: Path,
                       vocab: Path) -> list[float]:
    distributed_args = ["--nproc_per_node=1", "--nnodes=1",
                        "--node_rank=0", "--master_addr=localhost",
                        "--master_port=8000"]
    main_args = ["--model_name=llama2", f"--load={our_path}",
                 f"--data_path={data}", "--no_new_tokens",
                 "--tokenizer_type=SentencePieceTokenizer",
                 "--model_size=7", f"--vocab_file={vocab}"]
    extra_args = ["--hidden_dropout=0.0", "--attention_dropout=0.0",
                  "--no_bias_dropout_fusion", "--no_bias_gelu_fusion"]
    cmd = ["torchrun"] + distributed_args + ["verify_correctness.py"] \
        + main_args + extra_args
    if cache_dir is not None:
        cmd.append(f"--huggingface_cache={cache_dir}")

    max_errors = []
    for line in execute(cmd):
        if any(key in line for key in ["Iteration", "Max abs", "Abs loss"]):
            print(line)
        if rmatch := re.match(fr"^.*max=([0-9]+\.[0-9]+).*$", line):
            max_errors.append(float(rmatch.group(1)))
    assert sum(max_errors)/len(max_errors) <= 0.001, "Avg max error exceeds tolerance (0.001)"
    return max_errors


def shard(load_dir: Path, save_dir: Path, tp: int = 1, pp: int = 1):
    cmd = ["python", "tools/checkpoint_util.py", f"--load_dir={load_dir}",
           f"--save_dir={save_dir}", "--model_type=llama2", "--true_vocab_size=32000",
           f"--target_tensor_parallel_size={tp}", f"--target_pipeline_parallel_size={pp}"]
    ignores = {"---", "...", "Setting"}
    for line in execute(cmd):
        if all(avoid not in line for avoid in ignores):
            print(line)


def mega2hf(load_dir: Path, out_dir: Path):
    with Popen(["python", "weights2megatron/megatron2hf.py",
                f"--input_dir={load_dir}", f"--output_dir={out_dir}"]) as proc:
        assert proc.wait() == 0


# ===
# = Tests
# ===
@pytest.mark.incremental
class TestLlamaWeights:
    def test_path_exists(self, llama_meta: Path):
        assert llama_meta.exists() and llama_meta.is_dir()

    def test_meta2mega(self, llama_meta2mega: Path, llama_meta: Path,
                       cache_dir: Optional[Path], data: Path, vocab: Path):
        assert not llama_meta2mega.exists()
        with Popen(["python", Path("weights2megatron")/"weights2megatron.py",
                    "llama2", "--size=7", f"--out={llama_meta2mega}",
                    f"--cache-dir={llama_meta}"]) as proc:
            assert proc.wait() == 0
        assert llama_meta2mega.exists()
        verify_correctness(llama_meta2mega, cache_dir, data, vocab)
        shutil.rmtree(llama_meta2mega)  # all future tests will only use llama_hf2mega

    def test_hf2mega(self, llama_hf2mega: Path, cache_dir: Optional[Path],
                     data: Path, vocab_hf2mega: Path):
        assert not llama_hf2mega.exists()
        cmd = ["python", Path("weights2megatron")/"weights2megatron.py",
               "llama2", "--size=7", f"--out={llama_hf2mega}"]
        if cache_dir is not None:
            cmd.append(f"--cache-dir={cache_dir}")
        with Popen(cmd) as proc:
            assert proc.wait() == 0
        assert llama_hf2mega.exists()
        verify_correctness(llama_hf2mega, cache_dir, data, vocab_hf2mega)

    def test_metallama_verification(self, llama_hf2mega: Path, llama_meta: Path,
                                    data: Path, vocab: Path):
        verify_correctness(llama_hf2mega, llama_meta, data, vocab)

    def test_shard_unshard(self, llama_hf2mega: Path, llama_sharded: Path,
                           llama_unsharded: Path, cache_dir: Optional[Path],
                           data: Path, vocab_hf2mega: Path):
        print("sharding to tp=2, pp=2")
        shard(llama_hf2mega, llama_sharded, tp=2, pp=2)
        assert llama_sharded.exists()
        print("merging back to tp=1, pp=1")
        shard(llama_sharded, llama_unsharded, tp=1, pp=1)
        assert llama_unsharded.exists()
        verify_correctness(llama_unsharded, cache_dir, data, vocab_hf2mega)

    def test_mega2hf(self, llama_hf2mega: Path, llama_mega2hf: Path,
                     cache_dir: Optional[Path], data: Path, vocab_hf2mega: Path):
        mega2hf(llama_hf2mega, llama_mega2hf)
        verify_correctness(llama_mega2hf, cache_dir, data, vocab_hf2mega)

    def test_unsharded2hf(self, llama_unsharded: Path, llama_unsharded2hf: Path,
                          cache_dir: Optional[Path], data: Path, vocab_hf2mega: Path):
        mega2hf(llama_unsharded, llama_unsharded2hf)
        verify_correctness(llama_unsharded2hf, cache_dir, data, vocab_hf2mega)
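For reference, `verify_correctness` accepts a converted checkpoint only when the average of the `max=<float>` values scraped from the script's output is at most 0.001. A small illustration of that extraction step on a made-up line (the real format is whatever `verify_correctness.py` prints):

```python
import re

# Hypothetical output line; only the "max=<float>" field is consumed.
line = "Iteration 3: max abs error max=0.00042"
if match := re.match(r"^.*max=([0-9]+\.[0-9]+).*$", line):
    print(float(match.group(1)))  # 0.00042
```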