# **Imbalanced Dataset Handling**

* This notebook addresses imbalanced datasets, where some classes have significantly fewer samples than others.
A model trained on such data tends to favor the majority classes and perform poorly on the less common ones.
* Two common remedies are oversampling and class weighting.
Oversampling draws samples from underrepresented classes more often (with replacement, so existing images are repeated), while class weighting assigns a higher weight to the loss on samples from underrepresented classes. The code below implements oversampling.
* The `get_dataloader` function takes a dataset root directory and a batch size, and returns a PyTorch `DataLoader`.
* It resizes each image to 224x224 and converts it to a tensor, then computes a weight for each class as the inverse of the number of samples in that class.

* It then creates a `WeightedRandomSampler`, which draws sample indices with replacement, each with probability proportional to its weight, so that every class is drawn about equally often.

* The remaining cells iterate over the data loader for 10 epochs and count how many samples of each class were drawn.
The final output is the count for each class, which should come out nearly equal.
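
Class weighting is described above but not implemented in this notebook. A minimal sketch of how it might look follows, assuming inverse-frequency weights passed to `nn.CrossEntropyLoss` through its `weight` argument. The class counts, logits, and labels are made up for illustration and do not come from the dataset used here.

```python
import torch
import torch.nn as nn

# Hypothetical class counts: 50 samples of class 0 and 1 sample of class 1.
class_counts = torch.tensor([50.0, 1.0])

# Assumption: inverse-frequency weights, one per class.
class_weights = 1.0 / class_counts

# Dummy batch in which the model always favors the majority class (class 0).
logits = torch.tensor([[2.0, 0.0]] * 4)
labels = torch.tensor([0, 0, 0, 1])

# Plain cross-entropy treats every sample equally.
unweighted_loss = nn.CrossEntropyLoss()(logits, labels)

# Weighted cross-entropy scales each sample's loss by the weight of its class,
# so the misclassified minority sample dominates the result.
weighted_loss = nn.CrossEntropyLoss(weight=class_weights)(logits, labels)

print(unweighted_loss.item(), weighted_loss.item())
```

The weighted loss comes out larger than the unweighted one here because the single minority-class sample, which the dummy model gets wrong, carries most of the weight.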




In [18]:
import torch
import torchvision.datasets as datasets
import os
from torch.utils.data import WeightedRandomSampler, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn

In [19]:
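def get_dataloader(root_dir, batch_size):
  # Resize every image to 224x224 and convert it to a tensor.
  my_transforms = transforms.Compose(
      [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
      ]
  )

  dataset = datasets.ImageFolder(root=root_dir, transform = my_transforms)

  # dataset.targets holds the class index of every sample, so the counts
  # cover exactly the files ImageFolder loaded and no image has to be opened.
  targets = torch.tensor(dataset.targets)
  class_counts = torch.bincount(targets, minlength=len(dataset.classes))

  # Weight each class by the inverse of its sample count.
  class_weights = 1.0 / class_counts.float()

  # Each sample inherits the weight of its class.
  sample_weights = class_weights[targets]

  # Draw len(dataset) indices per epoch, with replacement,
  # so samples from rare classes are repeated.
  sampler = WeightedRandomSampler(
      sample_weights, num_samples = len(sample_weights), replacement = True
  )

  # The sampler decides the order, so shuffle is left at its default (False).
  loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler)

  return loader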
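
To see why inverse-count weights balance the classes, the sampler can be tried on synthetic labels without any images. This is only a sketch: the 50-to-1 class split and the 1000 draws are assumptions chosen for illustration, not properties of the dataset used here.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

# Hypothetical labels: 50 samples of class 0 and a single sample of class 1.
labels = torch.tensor([0] * 50 + [1])

# Inverse-count weight per class, then one weight per sample.
class_counts = torch.bincount(labels)
sample_weights = (1.0 / class_counts.float())[labels]

# Total sampling mass per class is count * (1 / count) = 1 for every class,
# so each class is equally likely to be drawn.
class_mass = torch.zeros(len(class_counts)).scatter_add(0, labels, sample_weights)
class_prob = class_mass / class_mass.sum()

# Draw 1000 indices with replacement and measure how often class 1 appears.
sampler = WeightedRandomSampler(sample_weights, num_samples=1000, replacement=True)
drawn = labels[list(sampler)]
minority_fraction = (drawn == 1).float().mean().item()

print(class_prob, minority_fraction)
```

Although class 1 makes up under 2% of the synthetic labels, it should account for roughly half of the draws, because the single minority sample is repeated many times.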

In [20]:
loader = get_dataloader(root_dir="/content/retrivers_elkhounds", batch_size=10)
num_retrivers = 0
num_elkhounds = 0

In [21]:
for epoch in range(10):
  for data, labels in loader:
    num_retrivers += torch.sum(labels == 0)
    num_elkhounds += torch.sum(labels == 1)

In [22]:
print(num_retrivers.item())
print(num_elkhounds.item())

254
256
