TypeError: Cannot subclass <class 'typing._SpecialForm'> while fine-tuning #222

Open
samyakai opened this issue Apr 23, 2022 · 9 comments

Comments

@samyakai

samyakai commented Apr 23, 2022

I am trying to fine-tune GPT-J on custom data using a TPU. When I run device_train.py with the suggested command, "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:

Traceback (most recent call last):
  File "device_train.py", line 13, in <module>
    from mesh_transformer import util
  File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>

OS: Ubuntu 20.04
TPU: v3-8
Python: 3.8 and 3.7 (both give the error)

I have no idea what this error means. Any help would be appreciated!
Thank you.
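
A minimal sketch of what this traceback means, assuming (this is not confirmed anywhere in the thread) that the installed optax defines OptState as a typing alias rather than a real class. Python refuses to use a special typing form such as Union as a base class, so the class statement at util.py line 36 fails as soon as the module is imported. StateAlias and try_subclass are hypothetical names used only for illustration, and the exact error text differs between Python versions.

```python
from typing import NamedTuple, Union

# Hypothetical stand-in for OptState in newer optax releases
# (assumption: it is a Union-style type alias, not a class).
StateAlias = Union[int, str]


def try_subclass(base):
    """Return the TypeError message raised when `base` is used as a base class, or None."""
    try:
        # Mirrors the failing line in mesh_transformer/util.py.
        class ClipByGlobalNormState(base):
            pass
    except TypeError as err:
        return str(err)
    return None


# A typing alias cannot be subclassed; the message text varies by Python version.
print(try_subclass(StateAlias))
# For contrast, an ordinary class-like base such as typing.NamedTuple is accepted.
print(try_subclass(NamedTuple))
```

On Python 3.8 the first call should report the same message as the traceback above, because the class statement ends up calling `typing._SpecialForm.__new__`, which is exactly the frame shown at typing.py line 317.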

@jagruti-samyak

I'm getting the same issue and haven't been able to solve it.
Please help us.
Thank you.

@mosmos6

mosmos6 commented May 5, 2022

Maybe this works:

pip install dm-haiku==0.0.5
and then revert optax to its default version.

@shrey10926

shrey10926 commented May 8, 2022

@mosmos6 Nope, it doesn't work.

@anon-mouse-1

#202 (comment)
Please follow this solution! It works.

@jagruti-samyak

@anon-mouse-1 Which v2 version of the TPU should I use? There are two options for the TPU: the TPU VM architecture and the TPU Node architecture.

@Tylersuard

After the fifth error, I just gave up on this notebook.

@samyakai (author)

@Tylersuard Yes. Also, no one is providing a solution to these errors, which is a shame, as I really want to train on a TPU rather than a GPU.

@dhruv2601

Downgrading optax got rid of this error for me:

pip install optax==0.0.9

@mosmos6

mosmos6 commented Jul 18, 2022

#202 (comment) Please follow this solution! It works.

In addition to that, I ran:
pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5

and it worked for me.
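
To confirm which of the versions from the comment above are actually installed before rerunning device_train.py, a small check like the following may help. It is a sketch and not part of mesh-transformer-jax: SUGGESTED_PINS and check_pins are hypothetical names, and it covers only the three pins listed in that comment (whatever the #202 solution recommends is not reproduced here).

```python
from importlib import metadata
from typing import Callable, Dict, Optional

# Pins taken from the comment above (hypothetical constant name).
SUGGESTED_PINS = {
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def check_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version_or_None} for every pin that is not satisfied."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pins.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    for pkg, found in check_pins(SUGGESTED_PINS).items():
        print(f"{pkg}: expected {SUGGESTED_PINS[pkg]}, found {found}")
```

It compares version strings exactly, which matches the `==` pins used in this thread; a range-aware check would need a proper version parser instead of string equality.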


7 participants