
Add train script that runs on the TPU VM #16

Merged
merged 15 commits into from
Jun 14, 2021

Conversation

nostalgebraist
Contributor

@nostalgebraist nostalgebraist commented Jun 11, 2021

Adds a script for training on the TPU VM, similar to the existing device_sample.py, etc.

My use case: I'm fine-tuning the model on a single TPU v3-8. I initially got this working with the existing train.py, but

  • I had to work around its expectations about file structure (e.g. meta.json)
  • that path has a lot of complexity that's needless here: using ray, using an extra VM outside the TPU, the TPUCluster wrapper, etc.

Note: I slapped this together hastily and it's not perfect. I haven't yet verified it works on this branch, just on a personal dev branch with other changes.

Comments

Step

  • Compared to train.py, I changed when step is incremented, so that the values in the checkpoint file names match the exact number of gradient steps that have been run.
    • I increment the step after compiling model_fn
    • Inside the loop, we do (update weights) --> (increment step) --> (save checkpoint if necessary), whereas train.py saves in between the update and the increment
  • As a result, the val steps are shifted by one, i.e. we'll do the first val at step 1 (because step 1 was taken while compiling model_fn)
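The ordering described above can be sketched as follows. This is a minimal sketch of the control flow only: `update` and `save_ckpt` are illustrative stand-ins for the real JAX update and checkpoint-saving functions, not device_train.py's actual API.

```python
def train(num_steps, ckpt_every, update, save_ckpt):
    """Sketch of the step-increment ordering: increment *before*
    saving, so checkpoint names match the number of gradient steps."""
    step = 0
    update()   # first call compiles model_fn ...
    step += 1  # ... and counts as step 1
    while step < num_steps:
        update()             # take a gradient step
        step += 1            # increment first, so that ...
        if step % ckpt_every == 0:
            save_ckpt(step)  # ... the filename matches the steps taken
    return step
```

With `num_steps=5, ckpt_every=2` this saves checkpoints named 2 and 4, whereas train.py's save-then-increment ordering would name them 1 and 3.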

Data loader

  • The changes to the TFRecord reader class cover 2 issues:
    • On my small fine-tuning dataset, I'm doing multiple epochs. After save/load, the iterator would recurse infinitely because every file was already in the loaded used list.
      • Fixed with logic to discard used entries from earlier epochs
    • I added a reset() method to try to make val data sampling deterministic (I have a tiny val dataset)

Script

  • Does fine-tuning if tune-model-path is provided; otherwise trains from scratch (just for completeness)
  • Compared to train.py, adds some more wandb metrics, print statements, etc. You can take or leave these; they're just what I wanted for my project
  • save() is mostly a copy-paste of TPUCluster.save, which has some useful saving logic. Ideally this logic would be defined only once, but I really wanted to avoid breaking TPUCluster
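The fine-tune-vs-scratch branch amounts to something like the following sketch. All names here (`choose_init`, `load_ckpt`, `init_params`) are illustrative stand-ins, not device_train.py's real functions.

```python
def choose_init(tune_model_path, load_ckpt, init_params):
    """Sketch of the branch described above: restore pretrained
    weights when a tune path is given, otherwise initialize fresh.
    load_ckpt/init_params are hypothetical stand-in callables."""
    if tune_model_path:
        return load_ckpt(tune_model_path)  # fine-tuning
    return init_params()                   # training from scratch
```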

Owner

@kingoflolz kingoflolz left a comment


Looks great, thanks for unwrapping a few layers of this onion and writing some simple code that trains properly haha. Would you be able to write up the required config keys, either in the readme or in a commented fine-tune config file?

Would also be good if you could note the rough tok/s you get on a v3-8 somewhere for reference


unique = set(self.used)
last_pass_start_idx = len(self.used) % len(unique)
self.used = self.used[-last_pass_start_idx:]
Owner


I'm not sure why this is necessary; can you just reset self.used around the except StopIteration: line instead?

Contributor Author


Oh, yeah, that's a better idea

Changed it to just do self.reset() in that except block
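The agreed-upon fix lands as something like the sketch below, assuming a reader class of the same rough shape. `LoopingReader` and its internals are illustrative, not the actual TFRecord reader in the repo.

```python
class LoopingReader:
    """Sketch of an epoch-looping reader: on exhaustion, reset the
    `used` list and start a fresh pass instead of recursing forever."""

    def __init__(self, files):
        self.files = list(files)
        self.reset()

    def reset(self):
        # clear state so a restored reader doesn't treat every
        # file as already consumed
        self.used = []
        self._iter = iter(self.files)

    def sample(self):
        try:
            f = next(self._iter)
        except StopIteration:
            # epoch boundary: reset and take the first file of the new pass
            self.reset()
            f = next(self._iter)
        self.used.append(f)
        return f
```

Because `used` is cleared at each epoch boundary, a save/load mid-epoch only sees the current pass's entries, which was the point of the original dedup logic.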

device_train.py

# set up datasets
print("setting up datasets")
tpu_size = jax.device_count()
Owner


you should probs check that tpu_size <= cores per replica here

Contributor Author


What situation would cause this check to fail?

(Trying to understand what I'm preventing so I can, e.g., write an error message about it)

Owner


sometimes JAX doesn't detect the TPUs properly (and detects a single CPU device), or if there is a future model with more than 8 shards

Contributor Author


Doesn't that mean we check that tpu_size >= cores per replica, not <=?

Owner


Oh yes, my mistake
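So the corrected check is tpu_size >= cores per replica. A sketch, with the device count passed in for clarity (in the script itself it would come from jax.device_count(); the error message is illustrative):

```python
def check_device_count(tpu_size, cores_per_replica):
    """Sketch of the agreed-upon check: fail early if JAX sees fewer
    devices than the model needs shards, e.g. when JAX misdetects
    the TPU and reports a single CPU device."""
    if tpu_size < cores_per_replica:
        raise RuntimeError(
            f"only {tpu_size} device(s) detected, but the model needs "
            f"{cores_per_replica} shards; JAX may not have found the TPU"
        )
```

On a correctly detected v3-8 with an 8-shard model, `check_device_count(8, 8)` passes; a misdetected single CPU device fails with a readable message.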

@kingoflolz kingoflolz merged commit 4ea1a1a into kingoflolz:master Jun 14, 2021
@kingoflolz
Owner

Thanks!
