TpuEmbeddingEngine_WriteParameters not available in this library. #202

Closed
nikhilanayak opened this issue Feb 23, 2022 · 11 comments

Comments

@nikhilanayak

I followed all of the instructions in the training guide but when I run the device_train script, I get this error:

2022-02-23 07:56:56.271731: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.

This is my exact command for the training process:

python3 device_train.py --config=configs/6B.json --tune-model-path=gs://nnrap/step_383500
@whoislimshady

Check your jax version. I guess you are using 0.2.16, but the correct version for running the training is 0.2.12.

@nikhilanayak
Author

If I run pip3 list | grep jax, it returns:

jax                          0.2.12
jaxlib                       0.3.0
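A minimal sketch for checking the installed jax/jaxlib pair without importing either package, in case it helps others compare environments. The expected pair (jax 0.2.12 with jaxlib 0.1.67) is taken from the package list shared later in this thread, not from a verified requirement, and the function names installed_versions and differences are hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def differences(found, expected):
    """List the packages whose installed version differs from the expected one."""
    return [
        f"{name}: found {found.get(name)}, expected {want}"
        for name, want in expected.items()
        if found.get(name) != want
    ]


if __name__ == "__main__":
    # Assumption: the pair reported as working in this thread's package list.
    expected = {"jax": "0.2.12", "jaxlib": "0.1.67"}
    found = installed_versions(tuple(expected))
    for line in differences(found, expected) or ["versions match the expected pair"]:
        print(line)
```

Reading the metadata instead of importing jax means the check still runs in an environment where the import itself aborts.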

@nikhilanayak
Author

Also @whoislimshady when I run it, it crashes and prints out

2022-02-23 17:35:25.749201: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.
Aborted (core dumped)

@mrseeker

jax 0.2.12 won't run and crashes with this error. Using

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

works but introduces new issues.

@whoislimshady

whoislimshady commented Feb 24, 2022

@nikhilanayak I also faced the same issue, but somehow it got resolved just by changing the jax version from 0.2.16 to 0.2.12.
I am sharing the packages I have in my env, with which I am able to train the model.
Hope this helps.
absl-py==0.12.0 aiohttp==3.8.1 aiohttp-cors==0.7.0 aioredis==2.0.1 aiosignal==1.2.0 anyio==3.5.0 appdirs==1.4.4 asgiref==3.5.0 astunparse==1.6.3 async-timeout==4.0.2 attrs==19.3.0 Automat==0.8.0 bcrypt==3.2.0 best-download==0.0.9 black==22.1.0 blessings==1.7 BLEURT @ https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip blinker==1.4 cachetools==4.2.2 certifi==2020.12.5 cffi==1.15.0 chardet==4.0.0 charset-normalizer==2.0.11 chex==0.1.0 clang==5.0 click==7.1.2 cloud-init==21.1 cloud-tpu-client==0.10 cloudpickle==1.3.0 colorama==0.4.3 colorful==0.5.4 command-not-found==0.3 configobj==5.0.6 constantly==15.1.0 cryptography==2.8 Cython==0.29.23 DataProperty==0.54.2 datasets==1.15.1 dbus-python==1.2.16 Deprecated==1.2.13 dill==0.3.4 distlib==0.3.1 distro==1.4.0 distro-info===0.23ubuntu1 dm-haiku==0.0.5 dm-tree==0.1.6 docker-pycreds==0.4.0 dyNET38==2.1 einops==0.3.2 entrypoints==0.3 fabric==2.6.0 fastapi==0.73.0 filelock==3.0.12 Flask==1.1.4 flatbuffers==1.12 frozenlist==1.3.0 fsspec==2022.1.0 ftfy==6.1.1 func-timeout==4.3.5 future==0.18.2 gast==0.4.0 gitdb==4.0.9 GitPython==3.1.26 google-api-core==1.28.0 google-api-python-client==1.8.0 google-auth==1.30.1 google-auth-httplib2==0.1.0 google-auth-oauthlib==0.4.4 google-cloud-core==1.7.2 google-cloud-storage==1.36.2 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==1.3.3 googleapis-common-protos==1.53.0 gpustat==0.6.0 grpcio==1.38.0 h11==0.13.0 h5py==3.1.0 httplib2==0.19.1 huggingface-hub==0.4.0 hyperlink==19.0.0 idna==2.10 importlib-metadata==1.5.0 incremental==16.10.1 iniconfig==1.1.1 invoke==1.6.0 itsdangerous==1.1.0 jax==0.2.12 jaxlib==0.1.67 jieba==0.42.1 Jinja2==2.10.1 jmp==0.0.2 joblib==1.1.0 jsonlines==2.0.0 jsonpatch==1.22 jsonpointer==2.0 jsonschema==3.2.0 keras==2.6.0 Keras-Applications==1.0.8 keras-nightly==2.6.0.dev2021052400 Keras-Preprocessing==1.1.2 keyring==18.0.1 language-selector==0.1 launchpadlib==1.10.13 lazr.restfulclient==0.14.2 lazr.uri==1.0.3 
libclang==13.0.0 lm-dataformat==0.0.20 lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness/@c74ca57c51c5fb0889b955878f00fe6d60ba393c Markdown==3.3.4 MarkupSafe==1.1.0 mbstrdecoder==1.1.0 mesh-transformer @ file:///home/harsh/gptj mock==4.0.3 more-itertools==4.2.0 msgfy==0.2.0 msgpack==1.0.3 multidict==6.0.2 multiprocess==0.70.12.2 mypy-extensions==0.4.3 nagisa==0.2.7 netifaces==0.10.4 nltk==3.7 numexpr==2.7.2 numpy==1.22.2 nvidia-ml-py3==7.352.0 oauth2client==4.1.3 oauthlib==3.1.0 openai==0.6.4 opencensus==0.8.0 opencensus-context==0.1.2 opt-einsum==3.3.0 optax==0.0.9 packaging==20.9 pandas==1.4.0 paramiko==2.9.2 pathlib2==2.3.7.post1 pathspec==0.9.0 pathtools==0.1.2 pathvalidate==2.5.0 pathy==0.6.1 pbr==5.8.1 pexpect==4.6.0 Pillow==8.2.0 platformdirs==2.5.0 pluggy==0.13.1 portalocker==2.3.2 prometheus-client==0.13.1 promise==2.3 protobuf==3.17.1 psutil==5.9.0 py==1.11.0 py-spy==0.3.11 pyarrow==7.0.0 pyasn1==0.4.8 pyasn1-modules==0.2.8 pybind11==2.6.2 pycountry==20.7.3 pycparser==2.21 pydantic==1.9.0 PyGObject==3.36.0 PyHamcrest==1.9.0 PyJWT==1.7.1 pymacaroons==0.13.0 PyNaCl==1.3.0 pyOpenSSL==19.0.0 pyparsing==2.4.7 pyrsistent==0.15.5 pyserial==3.4 pytablewriter==0.58.0 pytest==6.2.3 python-apt==2.0.0+ubuntu0.20.4.4 python-dateutil==2.8.2 python-debian===0.1.36ubuntu1 pytz==2021.1 PyYAML==5.4.1 ray==1.4.1 redis==4.1.3 regex==2022.1.18 rehash==1.0.0 requests==2.25.1 requests-oauthlib==1.3.0 requests-unixsocket==0.2.0 rouge-score==0.0.4 rsa==4.7.2 sacrebleu==1.5.0 sacremoses==0.0.47 scikit-learn==1.0.2 scipy==1.6.3 SecretStorage==2.3.1 sentencepiece==0.1.96 sentry-sdk==1.5.5 service-identity==18.1.0 shortuuid==1.0.8 simplejson==3.16.0 six==1.15.0 smart-open==5.2.1 smmap==5.0.0 sniffio==1.2.0 sos==4.1 sqlitedict==1.6.0 ssh-import-id==5.10 starlette==0.17.1 systemd-python==234 tabledata==1.3.0 tabulate==0.8.9 tb-nightly==2.6.0a20210524 tcolorpy==0.1.1 tensorboard==2.6.0 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorflow==2.8.0 
tensorflow-cpu==2.6.3 tensorflow-estimator==2.6.0 tensorflow-io-gcs-filesystem==0.24.0 termcolor==1.1.0 testresources==2.0.1 tf-estimator-nightly==2.8.0.dev2021122109 tf-slim==1.1.0 threadpoolctl==3.1.0 tokenizers==0.11.4 toml==0.10.2 tomli==2.0.1 toolz==0.11.2 torch==1.8.1 torch-xla==1.8.1 torchvision==0.9.1 tqdm==4.62.3 tqdm-multiprocess==0.0.11 transformers==4.16.2 Twisted==18.9.0 typepy==1.3.0 typer==0.4.0 typing-extensions==3.10.0.2 ubuntu-advantage-tools==20.3 ufw==0.36 ujson==5.1.0 unattended-upgrades==0.1 uritemplate==3.0.1 urllib3==1.26.4 uvicorn==0.17.4 virtualenv==20.4.7 wadllib==1.3.3 wandb==0.12.10 wcwidth==0.2.5 Werkzeug==1.0.1 wrapt==1.12.1 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==1.0.0 zope.interface==4.7.1 zstandard==0.15.0

@mrseeker

@whoislimshady It looks like you are using torch 1.8.1? I got 1.11.1...

@mrseeker

Just to be sure that everyone is on the same page: what version of the TPU VM are you actually running? It might be that the version people are running is incorrect, and that an older version actually works better than the newer ones. I am running with --version tpu-vm-tf-2.8.0, but I think it might actually need to be a lower one (2.6.3).

@safeeazeem

For fine-tuning I have always used TPU version v2-alpha. Haven't come across any errors so far.

@nikhilanayak
Author

nikhilanayak commented Feb 25, 2022

Here's a full set of commands to reproduce the error:
(The GPT-J-6B/step_383500 weights are already uploaded to gcloud)

gcloud alpha compute tpus tpu-vm create my_tpu_vm --zone=us-central1-a --accelerator-type=v3-8 --version=v2-alpha
gcloud alpha compute tpus tpu-vm ssh my_tpu_vm --zone us-central1-a --project [MY PROJECT ID]
$ git clone https://github.com/kingoflolz/mesh-transformer-jax
$ cd mesh-transformer-jax
$ echo gs://[my_tfrecord] > data/main.train.index
$ cat configs/6B_roto_256.json (edit the file so that it matches the following)

{
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 16,

  "warmup_steps": 3000,
  "anneal_steps": 300000,
  "lr": 5e-5,
  "end_lr": 1e-5,
  "weight_decay": 0.1,
  "total_steps": 350000,

  "tpu_size": 8,

  "bucket": "[my bucket (without the gs://)]",
  "model_dir": "mesh_jax_pile_6B_rotary",

  "train_set": "main.train.index",
  "val_set": {},

  "eval_harness_tasks": [],

  "val_batches": 100,
  "val_every": 350001,
  "ckpt_every": 500,
  "keep_every": 10000,

  "name": "GPT3_6B_pile_rotary",
  "wandb_project": "mesh-transformer-jax",
  "comment": ""
}

$ pip install -r requirements.txt --use-deprecated=legacy-resolver
$ pip install "jax[tpu]>=0.2.12" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install jaxlib==0.1.68
$ export LD_LIBRARY_PATH=/usr/local/lib
$ python3 device_train.py --config=configs/6B_roto_256.json --tune-model-path=gs://MY-BUCKET/step_383500/
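Before launching, the config shown above could be sanity-checked with a sketch like the following. The checks come from the config's own notes (the bucket is given without the gs:// prefix), and the batch arithmetic is an assumption about how the training script combines these fields, not something confirmed in this thread. The function names check_config and sequences_per_step are hypothetical.

```python
import json


def check_config(params):
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    bucket = str(params.get("bucket", ""))
    if bucket.startswith("gs://"):
        problems.append("bucket should be given without the gs:// prefix")
    if not bucket or "[" in bucket:
        problems.append("bucket still looks like a placeholder")
    tpu_size = params.get("tpu_size", 0)
    cores = params.get("cores_per_replica", 0)
    if tpu_size <= 0 or cores <= 0 or tpu_size % cores != 0:
        problems.append("tpu_size should be a positive multiple of cores_per_replica")
    return problems


def sequences_per_step(params):
    # Assumption: replicas = tpu_size // cores_per_replica, and every replica
    # sees per_replica_batch sequences for each gradient accumulation step.
    replicas = params["tpu_size"] // params["cores_per_replica"]
    return params["per_replica_batch"] * replicas * params.get("gradient_accumulation_steps", 1)


# Excerpt of the config above, with a made-up bucket name for illustration.
example = json.loads("""
{
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 16,
  "tpu_size": 8,
  "bucket": "my-bucket"
}
""")
print(check_config(example))        # an empty list means no problems were found
print(sequences_per_step(example))  # sequences consumed per optimizer step
```

To check the real file, load it with json.load(open("configs/6B_roto_256.json")) and pass the result to the same two functions.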

@mrseeker

mrseeker commented Feb 25, 2022

Okay, I fixed the issue this way:

  • Use version v2-alpha instead of the tpu-vm-tf-2.8.0
  • Install the package "python3-venv"
  • Build & start a virtualenv environment (python3 -m venv venv and source venv/bin/activate)
  • Install wheel (needed for some)
  • Install the requirements
  • Install Jax (as told by Google: pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html)
  • Reinstall tensorflow because of numpy (pip install -U tensorflow)

This seemed to fix a LOT of the issues I was having, and it's now working :)
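The install steps above (everything after choosing the v2-alpha image) can be sketched as a small dry-run script. This is only an illustration under assumptions: the apt-get invocation assumes a Debian/Ubuntu image, the venv's pip is called directly instead of sourcing venv/bin/activate (a subprocess cannot activate a shell environment), and the names setup_commands and run_all are hypothetical.

```python
import shlex
import subprocess

LIBTPU_INDEX = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"


def setup_commands(venv_dir="venv"):
    """Build the install steps as argument lists, in the order listed above."""
    pip = f"{venv_dir}/bin/pip"
    return [
        # Assumption: a Debian/Ubuntu image where apt-get provides python3-venv.
        ["sudo", "apt-get", "install", "-y", "python3-venv"],
        ["python3", "-m", "venv", venv_dir],
        [pip, "install", "wheel"],
        [pip, "install", "-r", "requirements.txt"],
        [pip, "install", "jax[tpu]>=0.2.16", "-f", LIBTPU_INDEX],
        [pip, "install", "-U", "tensorflow"],
    ]


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for argv in commands:
        print("$", shlex.join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(setup_commands())  # dry run: prints the commands, changes nothing
```

Running it with dry_run=False from the repository root would execute the steps in order and stop at the first failing command.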

@mosmos6

mosmos6 commented Feb 26, 2022

@mrseeker Your solution fixed my problem too. Thank you for sharing it!
