In [1]:
# Reference: https://discuss.pytorch.org/t/some-problems-with-weightedrandomsampler/23242/34

In [2]:
import torch

In [3]:
# Creating a dataset with imbalance: 700 / 200 / 100 samples for classes 0 / 1 / 2.

# Seed the global torch RNG so the random features below, the shuffling
# DataLoader and the WeightedRandomSampler draws (which all use the global
# generator) are reproducible when the notebook is run top to bottom.
SEED = 0
torch.manual_seed(SEED)

class_counts = torch.tensor([700, 200, 100])  # samples per class, indexed by class label
num_data_points = int(class_counts.sum())     # plain Python int (1000) for use as a size
data_dim = 5                                   # feature dimension of each sample
bs = 100                                       # batch size shared by all DataLoaders
data = torch.randn(num_data_points, data_dim)  # random features; labels are built separately

In [4]:
# Integer class labels matching class_counts: class_counts[k] copies of label k,
# laid out in class order (all 0s, then all 1s, ...). repeat_interleave handles
# any number of classes instead of one hand-written term per class.
target = torch.repeat_interleave(
    torch.arange(len(class_counts), dtype=torch.long), class_counts)

In [5]:
# Sanity check: how many samples carry each label in the full target vector.
per_class = [(target == label).sum() for label in (0, 1, 2)]
print('target train 0/1/2: {}/{}/{}'.format(*per_class))

target train 0/1/2: 700/200/100


In [6]:
# Compute samples weight (each sample should get its own weight).
#
# Each class gets weight 1 / (number of samples with that label), and every
# sample inherits the weight of its own label, so every present class carries
# the same total weight mass for WeightedRandomSampler.
#
# torch.bincount (target must be a 1-D tensor of non-negative ints) indexes
# counts by the label value itself, so weight[label] always belongs to `label`,
# even if some label in 0..max is absent. Counting over torch.unique(target)
# would instead index by position among the present labels. An absent label
# gets an inf weight, but no sample carries it, so samples_weight stays finite.
class_sample_count = torch.bincount(target)
weight = 1. / class_sample_count.float()
# Vectorised gather: one weight per sample, in the same order as `target`.
samples_weight = weight[target]

In [7]:
# Torch dataset: item i is the pair (data[i], target[i]); TensorDataset
# requires both tensors to share the same first dimension.
train_dataset = torch.utils.data.TensorDataset(
    data,    # features
    target,  # class labels
)

In [8]:
# Dataloaders with different sampling techniques.
# All three read the same dataset with the same batch size; they differ only
# in how sample indices are chosen.

# 1) Sequential order: batches follow dataset order. Because the labels are
#    sorted by class, each batch here contains a single class.
dataloader_basic_sampler = torch.utils.data.DataLoader(
    train_dataset, batch_size=bs, num_workers=0)

# 2) Uniform shuffling: each batch roughly mirrors the overall class imbalance.
dataloader_basic_randomsampler = torch.utils.data.DataLoader(
    train_dataset, batch_size=bs, num_workers=0, shuffle=True)

# 3) Weighted sampling: indices are drawn with probability proportional to
#    samples_weight. replacement=True (the default, spelled out here) lets
#    minority-class samples be drawn repeatedly, so batches come out roughly
#    class-balanced. num_samples keeps one epoch as long as the dataset.
wrs = torch.utils.data.WeightedRandomSampler(
    samples_weight, num_samples=len(samples_weight), replacement=True)
dataloader_weightedrandomsampler = torch.utils.data.DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=wrs)
# Alias under the original (misspelled) name so existing references keep working.
dataloader_weighterrandomsampler = dataloader_weightedrandomsampler


In [9]:
# Iterate each DataLoader once and print how many samples of every class land
# in each batch, so the three sampling strategies can be compared side by side.
for title, loader in (
        ('\n Basic sampler', dataloader_basic_sampler),
        ('\n Basic random sampler', dataloader_basic_randomsampler),
        ('\n Weighted random sampler', dataloader_weighterrandomsampler)):
    print(title)
    for batch_idx, (_, labels) in enumerate(loader):
        counts = [(labels == k).sum() for k in (0, 1, 2)]
        print("batch index {}, 0/1/2: {}/{}/{}".format(batch_idx, *counts))


 Basic sampler
batch index 0, 0/1/2: 100/0/0
batch index 1, 0/1/2: 100/0/0
batch index 2, 0/1/2: 100/0/0
batch index 3, 0/1/2: 100/0/0
batch index 4, 0/1/2: 100/0/0
batch index 5, 0/1/2: 100/0/0
batch index 6, 0/1/2: 100/0/0
batch index 7, 0/1/2: 0/100/0
batch index 8, 0/1/2: 0/100/0
batch index 9, 0/1/2: 0/0/100

 Basic random sampler
batch index 0, 0/1/2: 70/15/15
batch index 1, 0/1/2: 67/25/8
batch index 2, 0/1/2: 68/22/10
batch index 3, 0/1/2: 70/22/8
batch index 4, 0/1/2: 83/13/4
batch index 5, 0/1/2: 77/14/9
batch index 6, 0/1/2: 71/17/12
batch index 7, 0/1/2: 64/27/9
batch index 8, 0/1/2: 64/25/11
batch index 9, 0/1/2: 66/20/14

 Weighted random sampler
batch index 0, 0/1/2: 26/39/35
batch index 1, 0/1/2: 34/31/35
batch index 2, 0/1/2: 34/31/35
batch index 3, 0/1/2: 29/39/32
batch index 4, 0/1/2: 33/35/32
batch index 5, 0/1/2: 28/33/39
batch index 6, 0/1/2: 37/32/31
batch index 7, 0/1/2: 30/39/31
batch index 8, 0/1/2: 39/32/29
batch index 9, 0/1/2: 39/34/27
