Error when training a hypernetwork #6

Open
jubgjf opened this issue Jun 5, 2024 · 1 comment

jubgjf commented Jun 5, 2024

I tried to train a hypernetwork on an English and Chinese dataset in order to transfer TinyLlama to a bilingual tokenizer.

My hardware is 2x A100 80GB with CUDA driver version 12.2.

My config is:

{
    "output_dir": "output-debug",
    "train_directory": "data/train",
    "valid_directory": "data/valid",
    "langs": "data/langs.txt",
    "model_name_or_path": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    "revision": "refs/pr/8",
    "loss": "clm",
    "n_embd": 2048,
    "n_token_subsample": null,
    "random_warmup_steps": 0,
    "identity_n_subsample": 16384,
    "identity_steps": 0,
    "warmup_steps": [
        10000
    ],
    "steps": 200000,
    "dtype": "bfloat16",
    "use_unigram_bias": true,
    "learning_rate": [
        6e-5
    ],
    "max_grad_norm": 0.1,
    "extra_valid_tokenizer_names": [
        "models/TinyLlama-1.1B-intermediate-step-1431k-3T-Ext"
    ],
    "extra_valid_files": [
        "data/valid/en.parquet",
        "data/valid/zh.parquet"
    ],
    "extra_lang_codes": [
        "en",
        "zh"
    ],
    "n_valid_subsample": 4000,
    "do_tokenizer_sampling": true,
    "hn_rescale_embeddings": true,
    "hn_surface_maxlen": 15,
    "tokenizer_sample_mean": 32768,
    "tokenizer_sample_max": 32768,
    "tokenizer_sample_std": 0,
    "tokenizer_batch_size": 32,
    "weight_decay": 0.01,
    "adam_beta2": 0.95,
    "hn_model_name_or_path": "roberta-base",
    "tokenizer_noise_mean": 1e-5,
    "tokenizer_noise_std": 4,
    "hn_embed_lang_id": true,
    "hn_add_inter_token_attention": false,
    "hn_embed_target_priors": false,
    "hn_inter_token_attention_bias_by_priors": true,
    "hn_embed_using_source_embeddings": true,
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "hn_hidden_size": 2048,
    "hn_intermediate_size": 4096,
    "gradient_accumulation_steps": 1,
    "learnable_bias": false,
    "add_target_priors_to_bias": false,
    "lexical_loss_weight": 0.5,
    "debug": false,
    "dataloader_num_workers": 64,
    "mix_languages": false,
    "logging_steps": 10
}

data/langs.txt contains:

en,1
zh,3
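For reference, the langs file is a plain code,weight list. The sketch below is a hypothetical parser, not zett's own loader: it assumes the second column is a relative sampling weight and normalizes the weights into probabilities. Under that assumption, the file above would make Chinese about three times as likely to be sampled as English.

```python
def parse_langs(text: str) -> dict[str, float]:
    """Parse 'code,weight' lines into normalized sampling probabilities.

    Assumption (not confirmed by zett): the weight column is a relative,
    unnormalized sampling weight.
    """
    weights: dict[str, float] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        code, weight = line.split(",")
        weights[code.strip()] = float(weight)
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("language weights must sum to a positive number")
    # Divide each weight by the total so the probabilities sum to 1.
    return {code: w / total for code, w in weights.items()}


probs = parse_langs("en,1\nzh,3\n")
print(probs)  # {'en': 0.25, 'zh': 0.75}
```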

The main training loop runs fine, but training crashes with the following error as soon as it reaches a logging step (logging_steps):

Traceback (most recent call last):
  File "/home/jnguan/code/zett/train.py", line 1605, in <module>
    main()
  File "/home/jnguan/code/zett/train.py", line 1516, in main
    lambda x: x.flatten(), stack_forest(train_metrics)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 68, in <lambda>
    stack_args = lambda *args: np.stack(args)
                               ^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in stack
    arrays = [asanyarray(arr) for arr in arrays]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
              ^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 390, in __array__
    return np.asarray(self._value, dtype=dtype)
                      ^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 588, in _value
    if self.is_fully_replicated:
       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 354, in is_fully_replicated
    return self.sharding.is_fully_replicated
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

AttributeError: 'UnspecifiedValue' object has no attribute 'is_fully_replicated'

Full log:
zett-142044.log
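The failing call is stack_forest(train_metrics), which, per the traceback, tree-maps np.stack over the per-step metric trees. np.stack converts each jax.Array to NumPy, and that conversion reads sharding.is_fully_replicated, which the UnspecifiedValue object does not have. To see which metrics carry such a sharding, a small hypothetical helper can walk the metrics before they are stacked. This is a sketch under stated assumptions: it assumes the metrics are nested dicts/lists/tuples, and it only inspects the .sharding attribute, which the traceback shows can still be read. The stand-in classes at the bottom are fakes for a quick check, not JAX types.

```python
from collections.abc import Mapping
from typing import Any, Iterator


def iter_leaves(tree: Any, path: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (path, leaf) pairs from nested mappings, lists and tuples."""
    if isinstance(tree, Mapping):
        for key, value in tree.items():
            yield from iter_leaves(value, f"{path}/{key}")
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            yield from iter_leaves(value, f"{path}[{i}]")
    else:
        yield path, tree


def find_bad_shardings(tree: Any) -> list[str]:
    """Return 'path: ShardingType' for leaves whose .sharding has no
    is_fully_replicated attribute (the lookup that fails in the traceback)."""
    bad = []
    for path, leaf in iter_leaves(tree):
        sharding = getattr(leaf, "sharding", None)  # None for plain Python numbers
        if sharding is not None and not hasattr(sharding, "is_fully_replicated"):
            bad.append(f"{path}: {type(sharding).__name__}")
    return bad


# Hypothetical stand-ins for a quick check; these are NOT real JAX types.
class _OkSharding:
    is_fully_replicated = True


class UnspecifiedValue:  # mimics the class named in the error message
    pass


class _FakeArray:
    def __init__(self, sharding):
        self.sharding = sharding


demo_metrics = [
    {"loss": _FakeArray(_OkSharding()), "learning_rate": 6e-5},
    {"loss": _FakeArray(UnspecifiedValue()), "learning_rate": 6e-5},
]
print(find_bad_shardings(demo_metrics))  # ['[1]/loss: UnspecifiedValue']
```

In train.py this would be something like print(find_bad_shardings(train_metrics)) placed just before the stack_forest call at line 1516. It only narrows down which leaves are affected; it does not fix the underlying sharding problem.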

My environment:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
accelerate               0.30.1
aiohttp                  3.9.5
aiosignal                1.3.1
appdirs                  1.4.4
attrs                    23.2.0
certifi                  2024.2.2
charset-normalizer       3.3.2
chex                     0.1.86
click                    8.1.7
cmake                    3.29.3
contourpy                1.2.1
cycler                   0.12.1
datasets                 2.19.1
dill                     0.3.8
docker-pycreds           0.4.0
etils                    1.8.0
filelock                 3.14.0
flax                     0.8.0
fonttools                4.52.4
frozenlist               1.4.1
fsspec                   2024.5.0
gitdb                    4.0.11
GitPython                3.1.43
h5py                     3.8.0
huggingface-hub          0.23.2
idna                     3.7
importlib_resources      6.4.0
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lit                      18.1.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.2
maturin                  1.3.0
mdurl                    0.1.2
ml-dtypes                0.4.0
mpmath                   1.3.0
msgpack                  1.0.8
multidict                6.0.5
multiprocess             0.70.16
nest-asyncio             1.6.0
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvcc-cu12    12.5.40
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu11        10.9.0.58
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu11       10.2.10.91
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu11     11.7.4.91
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu11         2.14.3
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.40
nvidia-nvtx-cu11         11.7.91
nvidia-nvtx-cu12         12.1.105
opt-einsum               3.3.0
optax                    0.1.5
orbax-checkpoint         0.5.14
packaging                24.0
pandas                   2.0.3
pathtools                0.1.2
pillow                   10.3.0
pip                      24.0
protobuf                 4.25.3
psutil                   5.9.8
pyahocorasick            2.0.0
pyarrow                  16.1.0
pyarrow-hotfix           0.6
Pygments                 2.18.0
pyparsing                3.0.9
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.32.3
rich                     13.7.1
rust_utils               0.14.1.dev0
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.10.1
sentry-sdk               2.3.1
setproctitle             1.3.3
setuptools               69.5.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12.1
tensorstore              0.1.60
threadpoolctl            3.5.0
tokenizers               0.19.1
toolz                    0.12.1
torch                    2.3.0
tqdm                     4.66.4
transformers             4.41.1
triton                   2.3.0
typing_extensions        4.12.0
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.15.4
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4
zipp                     3.19.0
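Because the error is raised inside jax/_src/array.py, it can help to confirm that the JAX stack was installed as a consistent set. The sketch below is hypothetical: the helper names and the package lists are assumptions for illustration. It pulls the JAX-related rows out of a pip list dump like the one above (or reads them from the live environment via importlib.metadata) and flags core packages whose version differs from jax.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages worth comparing for this traceback (assumed list).
JAX_PACKAGES = ("jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin", "flax", "optax")
# Packages typically installed at the same version as jax itself.
JAX_CORE = ("jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin")


def versions_from_pip_list(pip_list: str) -> dict[str, str]:
    """Pick the JAX-related rows out of `pip list` output."""
    found = {}
    for line in pip_list.splitlines():
        parts = line.split()
        if len(parts) == 2 and parts[0].lower() in JAX_PACKAGES:
            found[parts[0].lower()] = parts[1]
    return found


def installed_versions() -> dict[str, str]:
    """Same information, read from the current Python environment."""
    found = {}
    for name in JAX_PACKAGES:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass  # not installed; leave it out
    return found


def core_mismatches(found: dict[str, str]) -> list[str]:
    """Names from JAX_CORE whose version differs from the jax version."""
    jax_version = found.get("jax")
    return [name for name in JAX_CORE if name in found and found[name] != jax_version]


report = """\
flax                     0.8.0
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23
optax                    0.1.5
"""
found = versions_from_pip_list(report)
print(found)
print(core_mismatches(found))  # []
```

For the environment in this report, jax, jaxlib, jax-cuda12-pjrt and jax-cuda12-plugin are all at 0.4.23, so this particular check comes back clean. A difference here would be a hint worth following up, not proof of an incompatibility.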
@kdcyberdude

Hi @jubgjf, can you try the branch mentioned in #8?
