
Is training steps relevant with batch size? #4

Closed
fortunechen opened this issue Jul 28, 2021 · 2 comments

Comments

@fortunechen
fortunechen commented Jul 28, 2021

Hi,

Thanks for your open-source code!

I find that the number of training steps per epoch does not depend on the batch size. In line 343:

steps_per_epoch = num_train_examples // (jax.local_device_count() * config.d_step_per_g_step)

maybe it should be

steps_per_epoch = num_train_examples // (jax.local_device_count() * config.d_step_per_g_step * config.batch_size)
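For concreteness, here is a minimal sketch contrasting the two formulas, with made-up values standing in for `config` and `jax.local_device_count()`: because each step consumes `batch_size` examples per device, the original expression overestimates the steps per epoch by a factor of the batch size.

```python
# Hypothetical values standing in for config / jax.local_device_count().
num_train_examples = 1000
local_device_count = 1   # assume a single accelerator
d_step_per_g_step = 2
batch_size = 8           # per-device batch size

# Original expression: the batch size is missing from the denominator.
steps_per_epoch_old = num_train_examples // (
    local_device_count * d_step_per_g_step
)

# Proposed fix: each step consumes batch_size examples per device.
steps_per_epoch_new = num_train_examples // (
    local_device_count * d_step_per_g_step * batch_size
)

print(steps_per_epoch_old)  # 500
print(steps_per_epoch_new)  # 62
```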

@kohjingyu
Contributor

Thanks, you are correct. However, we don't use steps_per_epoch in the training loop (we only use num_train_steps rather than counting epochs), so this does not affect training.

@jayati-naik

Hi,

But num_train_steps is calculated based on steps_per_epoch:

num_train_steps = steps_per_epoch * config.num_epochs

In my case, with 1000 training examples and 200 validation examples, if I set
batch_size = 8, d_step_per_g_step = 2
then num_train_steps = 2500, and training fails around step 1600 with an OutOfRangeError at the step below:

batch = jax.tree_map(np.asarray, next(train_iter))

Am I missing something here?
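A likely mechanism for the failure (a toy sketch with made-up numbers, not the exact counts reported above): if the input pipeline is finite and non-repeating, and each training step pulls more than one batch from it, `next(train_iter)` exhausts the iterator well before the requested number of steps. Here a plain Python generator stands in for the real `train_iter`; the exact step at which a real run fails depends on the actual pipeline.

```python
# Toy sketch: a plain Python generator standing in for train_iter.
def make_train_iter(total_batches):
    for _ in range(total_batches):
        yield object()  # placeholder batch

batches_available = 2500  # finite, non-repeating input pipeline
steps_requested = 2500    # num_train_steps from the uncorrected formula
batches_per_step = 2      # e.g. d_step_per_g_step batches consumed per step

train_iter = make_train_iter(batches_available)
steps_completed = 0
try:
    for _ in range(steps_requested):
        for _ in range(batches_per_step):
            next(train_iter)  # raises StopIteration once exhausted
        steps_completed += 1
except StopIteration:
    pass

print(steps_completed)  # 1250: the iterator runs dry before 2500 steps
```

Repeating the dataset in the pipeline, or computing num_train_steps from the number of batches the iterator can actually yield, avoids the exhaustion.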
