A question related to batch size and training speed #323

Open
paidaxinbao opened this issue May 27, 2024 · 2 comments
@paidaxinbao

Hi!

The code in this repository has helped me a lot!

I found that as the batch size increases, the training time increases dramatically. When I set the batch size to 4 (the dataset has 25k images) the training time is about 2 days, but when the batch size is set to 128, the training time increases to 800 hours!

I don't know much about this.
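A likely explanation (an assumption based on the configuration below, not confirmed in the thread): `train_num_steps` counts optimizer steps, not epochs, so with the step count held fixed the total number of images processed grows linearly with the batch size. A quick back-of-the-envelope check:

```python
# Total images processed = steps * batch_size * gradient_accumulate_every
steps = 700_000
accum = 4

images_small = steps * 4 * accum     # batch size 4  -> 11_200_000 images
images_large = steps * 128 * accum   # batch size 128 -> 358_400_000 images

print(images_large / images_small)   # 32.0 -- 32x more work for the same step count
```

With 32x more data per run, 32 × 48 h ≈ 1536 h of naive scaling, so the observed 800 hours is in the expected range (larger batches are usually somewhat more GPU-efficient per image). To keep total compute constant, `train_num_steps` would need to shrink by the same factor the batch size grows.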

My training configuration is as follows:
model = Unet(
    dim=64,
    out_dim=1,
    dim_mults=(1, 2, 4, 8),
    channels=2
)

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,           # number of steps
    sampling_timesteps=250,   # number of sampling timesteps (using DDIM for faster inference [see citation for DDIM paper])
)

trainer = Trainer(
    diffusion,
    '/home/pxy/ML_work/train_picset/',
    train_batch_size=4,
    train_lr=8e-5,
    train_num_steps=700000,        # total training steps
    gradient_accumulate_every=4,   # gradient accumulation steps
    ema_decay=0.995,               # exponential moving average decay
    amp=True,                      # turn on mixed precision
    calculate_fid=False
)

trainer.train()
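For reference, with gradient accumulation the effective batch size is `train_batch_size * gradient_accumulate_every`, and the fixed step count implies a large number of passes over a 25k-image dataset. A rough sketch using the numbers above:

```python
dataset_size = 25_000      # images in the training set
steps = 700_000            # train_num_steps
effective_batch = 4 * 4    # train_batch_size * gradient_accumulate_every

epochs = steps * effective_batch / dataset_size
print(epochs)  # 448.0 -- roughly 448 passes over the dataset at batch size 4
```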

@gggah
Copy link

gggah commented May 28, 2024


May I ask what format your data is in? Mine isn't being recognized. Also, should the error be greater than 100? My dataset is only 1200 pictures.

@paidaxinbao
Copy link
Author


Hi, my data is grayscale. I used the conditioning strategy from SR3: the original grayscale image is concatenated with the noisy input as a condition to the Unet (which is why channels=2). I didn't understand what you mean by "error" — do you mean the loss?
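The concatenation described above can be sketched as follows (a minimal NumPy illustration of SR3-style channel conditioning; the array names and shapes are hypothetical, chosen to match image_size=128 and channels=2 in the config):

```python
import numpy as np

# Hypothetical batch of 1-channel grayscale images, 128x128
noisy = np.random.randn(4, 1, 128, 128).astype(np.float32)      # noisy image at timestep t
condition = np.random.randn(4, 1, 128, 128).astype(np.float32)  # original grayscale image

# SR3-style conditioning: concatenate along the channel axis,
# producing the 2-channel input the Unet above expects
unet_input = np.concatenate([condition, noisy], axis=1)
print(unet_input.shape)  # (4, 2, 128, 128)
```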
