In [1]:
# Reference: https://discuss.pytorch.org/t/some-problems-with-weightedrandomsampler/23242/34

In [22]:
import torch
import numpy as np

In [23]:
# Create an imbalanced toy dataset: 700 / 200 / 100 samples for classes 0 / 1 / 2.

SEED = 0
# Seed the global torch RNG once, early, so the random features below and every
# random sampler later in the notebook are reproducible under Restart & Run All.
torch.manual_seed(SEED)

class_counts = torch.tensor([700, 200, 100])   # number of samples per class
num_data_points = int(class_counts.sum())      # plain Python int rather than a 0-d tensor
data_dim = 5                                   # feature dimension of each sample
bs = 100                                       # batch size shared by all DataLoaders
data = torch.randn(num_data_points, data_dim)  # random features; only the labels matter in this demo

In [24]:
# One integer label per data point, laid out class by class:
# class_counts[0] zeros, then class_counts[1] ones, then class_counts[2] twos.
target = torch.repeat_interleave(torch.arange(len(class_counts)), class_counts)

In [25]:
# Sanity check: count how many samples carry each of the labels 0, 1 and 2.
per_class = [int((target == label).sum()) for label in range(3)]
print('target train 0/1/2: {}/{}/{}'.format(*per_class))

target train 0/1/2: 700/200/100


In [26]:
# Compute per-sample weights: every sample gets 1 / (number of samples in its class),
# so each class contributes the same total weight (1.0) to the sampler.

# bincount puts the count for label k at index k, so the per-class weights below can be
# indexed directly by label value, even if some label in 0..max(target) is absent.
class_sample_count = torch.bincount(target)
# Absent labels would get weight inf here, but no sample ever indexes them.
weight = 1. / class_sample_count.float()
# Vectorized lookup: one weight per sample, replacing a Python loop over the dataset.
samples_weight = weight[target]

In [27]:
# Wrap features and labels in a map-style dataset: index i -> (data[i], target[i]).
train_dataset = torch.utils.data.TensorDataset(
    data,
    target,
)

In [33]:
# Three DataLoaders over the same dataset that differ only in how indices are drawn:
#   - in dataset order (DataLoader's default sequential sampling),
#   - uniformly shuffled (shuffle=True),
#   - weighted by samples_weight (WeightedRandomSampler).
loader_kwargs = dict(batch_size=bs, num_workers=0)

dataloader_basic_sampler = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

dataloader_basic_randomsampler = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# One "epoch" draws as many samples as the dataset holds. replacement=True is the
# default; it is spelled out because it is why minority samples can repeat.
wrs = torch.utils.data.WeightedRandomSampler(
    weights=samples_weight,
    num_samples=len(samples_weight),
    replacement=True,
)
dataloader_weighterrandomsampler = torch.utils.data.DataLoader(train_dataset, sampler=wrs, **loader_kwargs)


In [34]:
# Iterate each DataLoader for one epoch and check the class balance of every batch.

def report_class_balance(loader, title, num_classes=3):
    """Print per-batch label counts for one pass over `loader` and return the epoch totals.

    Args:
        loader: iterable yielding ``(x, y)`` batches, where ``y`` is a tensor of
            integer labels (assumed to lie in ``[0, num_classes)``).
        title: heading printed before the per-batch lines.
        num_classes: number of label values to count.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_classes`` holding the total
        count of each label over the epoch.
    """
    header = '/'.join(str(label) for label in range(num_classes))
    totals = torch.zeros(num_classes, dtype=torch.long)
    print('\n {}'.format(title))
    for i, (_, y) in enumerate(loader):
        # One vectorized count per batch; minlength keeps absent classes as explicit
        # zeros so the columns always line up with the header.
        batch_counts = torch.bincount(y, minlength=num_classes)
        print("batch index {}, {}: {}".format(
            i, header, '/'.join(str(n) for n in batch_counts.tolist())))
        totals += batch_counts
    print("Label distribution at the end \nof one epoch, {}: {}".format(
        header, '/'.join(str(n) for n in totals.tolist())))
    return totals


for title, loader in [('Basic sampler', dataloader_basic_sampler),
                      ('Basic random sampler', dataloader_basic_randomsampler),
                      ('Weighted random sampler', dataloader_weighterrandomsampler)]:
    report_class_balance(loader, title, num_classes=len(class_counts))



 Basic sampler
batch index 0, 0/1/2: 100/0/0
batch index 1, 0/1/2: 100/0/0
batch index 2, 0/1/2: 100/0/0
batch index 3, 0/1/2: 100/0/0
batch index 4, 0/1/2: 100/0/0
batch index 5, 0/1/2: 100/0/0
batch index 6, 0/1/2: 100/0/0
batch index 7, 0/1/2: 0/100/0
batch index 8, 0/1/2: 0/100/0
batch index 9, 0/1/2: 0/0/100
Label distribution at the end 
of one epoch, 0/1/2: 700/200/100

 Basic random sampler
batch index 0, 0/1/2: 78/12/10
batch index 1, 0/1/2: 75/16/9
batch index 2, 0/1/2: 67/22/11
batch index 3, 0/1/2: 69/20/11
batch index 4, 0/1/2: 61/27/12
batch index 5, 0/1/2: 65/25/10
batch index 6, 0/1/2: 76/13/11
batch index 7, 0/1/2: 72/19/9
batch index 8, 0/1/2: 69/23/8
batch index 9, 0/1/2: 68/23/9
Label distribution at the end 
of one epoch, 0/1/2: 700/200/100

 Weighted random sampler
batch index 0, 0/1/2: 29/34/37
batch index 1, 0/1/2: 32/32/36
batch index 2, 0/1/2: 38/28/34
batch index 3, 0/1/2: 32/35/33
batch index 4, 0/1/2: 34/30/36
batch index 5, 0/1/2: 35/32/33
batch index 6, 

#### Observations

- The basic (sequential) sampler loads the data in dataset order. Here the dataset is sorted by class, so each batch contains a single class. It covers the entire dataset exactly once per epoch.

- The basic random sampler shuffles the indices, so each batch roughly mirrors the overall 70/20/10 class ratio. Like the sequential sampler, it covers the entire dataset exactly once per epoch.

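- The weighted random sampler draws each sample with probability proportional to that sample's weight. Each weight is the inverse of its class's size, so every class is drawn with roughly equal probability, giving batches of about 33/33/33. Sampling is done with replacement (`replacement=True`, the default). As a result, an epoch of `len(samples_weight)` draws does not visit every data point: minority-class samples are typically seen several times per epoch, while many majority-class samples are not seen at all.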