In [1]:
import torch
import numpy as np
from torch.utils.data import WeightedRandomSampler, TensorDataset, DataLoader

numDataPoints = 1000
data_dim = 5
bs = 100

positive_class_proportion = 0.1
negative_class_proportion = 0.9

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * negative_class_proportion), dtype=np.int32),
                    np.ones(int(numDataPoints * positive_class_proportion), dtype=np.int32)))

print(f'target train 0/1: {len(np.where(target == 0)[0])}/{len(np.where(target == 1)[0])}')

class_sample_count = np.unique(target, return_counts=True)[1]
# weight = 1. / class_sample_count
# samples_weight = weight[target]

new_majority_proportion = 3
class_sample_count[0] /= new_majority_proportion
weight = 1. / class_sample_count
samples_weight = weight[target]

samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = TensorDataset(data, target)

train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
    print(f"batch index {i}, 0/1: {len(np.where(target.numpy() == 0)[0])}/{len(np.where(target.numpy() == 1)[0])}")

target train 0/1: 900/100
batch index 0, 0/1: 75/25
batch index 1, 0/1: 73/27
batch index 2, 0/1: 72/28
batch index 3, 0/1: 82/18
batch index 4, 0/1: 84/16
batch index 5, 0/1: 78/22
batch index 6, 0/1: 79/21
batch index 7, 0/1: 75/25
batch index 8, 0/1: 73/27
batch index 9, 0/1: 67/33


In [2]:
weights = 1. / torch.tensor(class_sample_count, dtype=torch.float)
samples_weights = weights[target]

sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True)

In [24]:
from sklearn.model_selection import RepeatedKFold, KFold
X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# y = np.array([0, 0, 1, 1])

# rkf = RepeatedKFold(n_splits=10, n_repeats=10, random_state=2652124)
# kf = KFold(n_splits=9)
# for train_index, test_index in rkf.split(X):
#     # print(train_index, test_index)
#     for subtrain_index, val_index in kf.split(train_index):
#         print(subtrain_index, val_index)

# kf10 = KFold(n_splits=10)
# kf9  = KFold(n_splits=9)
# for train_idx, test_idx in kf10.split(X):
#     print(train_idx[:-1], train_idx[-1], test_idx)
#     # for subtrain_idx, val_idx in kf9.split(train_idx):
#     #     print(subtrain_idx, val_idx, test_idx)
    
kFold = KFold(n_splits=10)
for i, (train_index, test_index) in enumerate(kFold.split(X)):
    val_index = [train_index[-1]]
    train_index = train_index[:-1]
    print(i, train_index, val_index, test_index)
    

0 [1 2 3 4 5 6 7 8] [9] [0]
1 [0 2 3 4 5 6 7 8] [9] [1]
2 [0 1 3 4 5 6 7 8] [9] [2]
3 [0 1 2 4 5 6 7 8] [9] [3]
4 [0 1 2 3 5 6 7 8] [9] [4]
5 [0 1 2 3 4 6 7 8] [9] [5]
6 [0 1 2 3 4 5 7 8] [9] [6]
7 [0 1 2 3 4 5 6 8] [9] [7]
8 [0 1 2 3 4 5 6 7] [9] [8]
9 [0 1 2 3 4 5 6 7] [8] [9]
