Add train script that runs on the TPU VM #16
Conversation
Looks great, thanks for unwrapping a few layers of this onion and writing some simple code that trains properly haha. Would you be able to write up the config keys required either in the readme or with a commented fine tune config file?
Would also be good if you could write the rough tok/s you get on a v3-8 somewhere for reference
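Regarding documenting the config keys: something like the following commented sketch could serve as a starting point. All key names and values here are illustrative assumptions (written as a Python dict so it can carry comments), not the repo's actual config schema:

```python
# Illustrative fine-tune config sketch -- every key and value below is an
# assumption for documentation purposes, not the repo's real schema.
finetune_config = {
    "tpu_size": 8,              # total devices expected (a single v3-8)
    "cores_per_replica": 8,     # model-parallel shards per replica
    "tune_model_path": "gs://your-bucket/ckpt/",  # checkpoint to fine-tune from
    "train_set": "your_data.train.index",         # training data index
    "val_set": "your_data.val.index",             # validation data index
    "lr": 1e-5,                 # fine-tuning learning rate
    "ckpt_every": 500,          # save a checkpoint every N steps
}
```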
tfrecord_loader.py (Outdated)

```python
unique = set(self.used)
last_pass_start_idx = len(self.used) % len(unique)
self.used = self.used[-last_pass_start_idx:]
```
I'm not sure why this is necessary; can you just reset `self.used` around the `except StopIteration:` line?
Oh, yeah, that's a better idea
Changed it to just do self.reset() in that except block
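A minimal sketch of the agreed-on fix, using a toy stand-in for the real loader (class and method names here are assumptions, not the repo's actual code). The point is that resetting everything at the epoch boundary replaces the modular-arithmetic trimming of `self.used`:

```python
class TFRecordLoaderSketch:
    """Toy stand-in for the tfrecord loader (names are assumptions)."""

    def __init__(self, items):
        self.items = list(items)
        self.used = []              # items consumed so far this epoch
        self.it = iter(self.items)

    def reset(self):
        # clear bookkeeping and start a fresh epoch
        self.used = []
        self.it = iter(self.items)

    def get_sample(self):
        try:
            item = next(self.it)
        except StopIteration:
            # epoch boundary: instead of trimming self.used with
            # modular arithmetic, just reset everything
            self.reset()
            item = next(self.it)
        self.used.append(item)
        return item
```

After wrapping around, `self.used` then only ever contains items from the current pass over the data.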
device_train.py (Outdated)

```python
# set up datasets
print("setting up datasets")
tpu_size = jax.device_count()
```
you should probs check that tpu_size <= cores per replica here
What situation would cause this check to fail?
(Trying to understand what I'm preventing so I can, e.g., write an error message about it)
Sometimes JAX doesn't detect the TPUs properly (and only detects a single CPU device); it would also catch a future model with more than 8 shards
Doesn't that mean we check that tpu_size >= cores per replica, not <=?
Oh yes, my mistake
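The corrected check could look roughly like this. The helper name and error message are assumptions; `tpu_size` would come from `jax.device_count()` and `cores_per_replica` from the model config, as in the discussion above:

```python
def check_device_count(tpu_size, cores_per_replica):
    """Fail fast when JAX sees fewer devices than the model needs.

    tpu_size: result of jax.device_count()
    cores_per_replica: number of model-parallel shards per replica
    """
    if tpu_size < cores_per_replica:
        raise RuntimeError(
            f"only {tpu_size} device(s) detected but the model needs "
            f"{cores_per_replica} shards; JAX may have fallen back to CPU"
        )
```

On a healthy v3-8, `check_device_count(8, 8)` passes; if JAX falls back to a single CPU device, `check_device_count(1, 8)` raises with a hint about what went wrong.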
Thanks!
Adds a script for training on the TPU VM, similar to the existing `device_sample.py`, etc.

My use case: I'm fine-tuning the model on a single TPU v3-8. I initially got this working with the existing `train.py`, but that requires `meta.json`, `ray`, an extra VM outside the TPU, the `TPUCluster` wrapper, etc.

Note: I slapped this together hastily and it's not perfect. I haven't yet verified it works on this branch, just on a personal dev branch with other changes.
Comments

Step

- Relative to `train.py`, I changed when `step` is incremented, to try to make the values in the checkpoint file names match the exact number of gradient steps that have been run.

Data loader

- When resuming, restores the `used` list that was loaded.
- Drops `used` entries from earlier epochs.
- Added a `reset()` method to try to make val data sampling deterministic. (I have a tiny val dataset)

Script

- Fine-tunes if `tune-model-path` is provided, otherwise trains from scratch (just for completeness)
- Compared to `train.py`, adds some more wandb metrics, print statements, etc. You can take or leave these, they're just what I wanted for my project
- `save()` is mostly a copy-paste of `TPUCluster.save`, which has some useful saving logic. Ideally this logic would be defined only once, but I really wanted to avoid breaking `TPUCluster`