In [1]:
import torch
import torch.nn as nn
from collections import Counter

In [22]:
# Pick the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

**Weighted CrossEntropy**

In [34]:
import torch
from collections import Counter

# Build a synthetic, imbalanced label set and derive inverse-frequency class
# weights for nn.CrossEntropyLoss.
# NOTE: no seed is set in this cell, so the sampled labels (and therefore the
# counts and weights) can differ between runs unless one is set elsewhere.

# Per-class sampling probabilities; the unequal values introduce the imbalance.
probabilities = torch.tensor([0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.15, 0.15, 0.1, 0.1])

# Take the number of classes from the probability vector as a plain Python int,
# so it can be used directly with range() and stays in sync with `probabilities`.
num_classes = probabilities.numel()
labels = torch.arange(num_classes)

# Draw class indices (with replacement) according to `probabilities`.
data = torch.multinomial(probabilities, num_samples=10000, replacement=True).tolist()

print(f"Generated {len(data)} labels with class imbalance.")
print(f"First 10 labels: {data[:10]}")


# Count how often each class was drawn.
counts = Counter(data)
print(f"Class counts: {counts}")

# Per-class frequency in class-index order; a class that was never drawn counts as 0.
freqs = torch.tensor([counts.get(i, 0) for i in range(num_classes)], dtype=torch.float32)
# Inverse-frequency weights; the small epsilon keeps the division finite for a zero count.
weights = 1 / (freqs + 1e-5)

# Rescale so the weights sum to num_classes, i.e. their mean is 1.
weights = weights / weights.sum() * num_classes

# Class-weighted loss; the weight tensor is moved to the selected device.
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

Generated 10000 labels with class imbalance.
First 10 labels: [9, 4, 4, 4, 7, 9, 7, 9, 6, 7]
Class counts: Counter({6: 1523, 7: 1492, 8: 1037, 3: 1029, 9: 1026, 2: 982, 5: 970, 4: 945, 0: 505, 1: 491})


In [35]:
# Show the per-class weights and their mean (1 after the normalisation step).
# Both values are passed to a single display() call: display() returns None, so
# putting it inside a trailing tuple would echo a stray `None` as the cell result.
display(weights, weights.mean())

tensor([1.7428, 1.7925, 0.8963, 0.8553, 0.9314, 0.9074, 0.5779, 0.5899, 0.8487,
        0.8578])

(None, tensor(1.))

**Oversampling with WeightedRandomSampler**

In [44]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# Give every sample a draw weight equal to the inverse of its class count, so
# the total weight of each class is the same.
sample_weights = [1.0 / counts[y] for y in data]


# Draw (with replacement) as many indices as there are samples.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# Usage sketch -- needs a train_dataset and collate_fn, which this cell does not define:
# train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler, collate_fn=collate_fn)

print(f"Length of data: {len(data)}")
print(f"Length of sample_weights: {len(sample_weights)}")
print(f"First 10 sample weights: {sample_weights[:10]}")

Length of data: 10000
Length of sample_weights: 10000
First 10 sample weights: [0.0009746588693957114, 0.0010582010582010583, 0.0010582010582010583, 0.0010582010582010583, 0.0006702412868632708, 0.0009746588693957114, 0.0006702412868632708, 0.0009746588693957114, 0.0006565988181221273, 0.0006702412868632708]
