pytorch-balanced-batch

A PyTorch dataset sampler that always yields class-balanced batches.

Be sure to use a batch_size that is an integer multiple of the number of classes.

For example, if your train_dataset has 10 classes and you use batch_size=30 with the BalancedBatchSampler

train_loader = torch.utils.data.DataLoader(train_dataset, sampler=BalancedBatchSampler(train_dataset), batch_size=30)

You will obtain a train_loader in which every batch contains 3 samples from each of the 10 classes.
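
As a quick sanity check, you can count the class labels in each batch. This is a minimal sketch, not part of the repository; it assumes train_dataset yields (input, label) pairs with integer class labels (e.g. a torchvision-style dataset), so the DataLoader collates the labels into a tensor.

from collections import Counter

# Count the class labels in the first few batches. With 10 classes and
# batch_size=30, every batch should contain exactly 3 samples per class.
for i, (inputs, labels) in enumerate(train_loader):
    counts = Counter(labels.tolist())
    print(f"batch {i}: {dict(sorted(counts.items()))}")
    assert all(n == 3 for n in counts.values())
    if i == 2:  # only inspect the first three batches
        break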
