#### Methods for dealing with imbalanced datasets:
- Oversampling
- Class weighting

In [1]:
import torch 
import torchvision.datasets as datasets
import os 
from torch.utils.data import WeightedRandomSampler, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn

In [2]:
# Class weighting
root_dir = r'dataset\imbalanced_dataset'
# The first class contains 50 images while the second class contains only one.
# When computing the loss, each class's contribution is multiplied by its own
# weight so that the rare class counts as much as the frequent one.
# The weight tensor must be floating point: an integer tensor is rejected with
# a dtype mismatch once the loss is evaluated on float logits.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 50.0]))

In [5]:
# Oversampling
def get_loader(root_dir, batch_size):
    """Build a DataLoader that oversamples under-represented classes.

    Every sample is weighted by ``1 / (number of samples in its class)``, so
    each class has the same total sampling weight regardless of its size.

    Args:
        root_dir: Dataset root passed to ``ImageFolder``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over the dataset whose ``WeightedRandomSampler``
        draws ``len(dataset)`` indices per epoch, with replacement.
    """
    my_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    dataset = datasets.ImageFolder(root=root_dir, transform=my_transforms)

    # Count samples per class from the labels ImageFolder already indexed.
    # This keeps the counts aligned with the dataset's own class indices
    # (a separate directory walk has no guaranteed order and also counts
    # files the dataset ignores) and avoids decoding every image just to
    # read its label.
    targets = dataset.targets
    class_counts = [0] * len(dataset.classes)
    for label in targets:
        class_counts[label] += 1

    # Inverse-frequency weight for each individual sample.
    sample_weights = [1 / class_counts[label] for label in targets]

    # replacement=True allows the same sample to be drawn more than once,
    # which is what lets the rare class be oversampled.
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    return loader


In [7]:
# Build the oversampling loader for the imbalanced dataset.
loader = get_loader(batch_size=8, root_dir=root_dir)

# Tally how many times each class label is drawn across 10 passes over the
# loader; with balanced sampling the two totals should come out close.
num_retrievers, num_elkhounds = 0, 0
for _epoch in range(10):
    for _images, labels in loader:
        num_retrievers = num_retrievers + (labels == 0).sum()
        num_elkhounds = num_elkhounds + (labels == 1).sum()

# The totals are 0-dim tensors, so they print as ``tensor(N)``.
print(num_retrievers)
print(num_elkhounds)

tensor(258)
tensor(252)
