## Balancing the MNIST Dataset

### Purpose
The script creates a balanced version of the MNIST training set in which every digit class (0-9) has the same number of samples.

### Key Steps

1. **Data Loading**:
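    - The MNIST training data is read from a CSV file in which each row is one image: the digit label in the first column, followed by the pixel values.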
    
2. **Data Organization**:
    - Records are grouped by label (0-9) into separate lists, one per digit.
    
3. **Balancing**:
    - The size of the smallest class (the fewest samples of any label) is determined.
    - Each class is then randomly undersampled to that size, drawing samples without replacement, so every digit class ends up with the same number of samples.
    
4. **Output**:
    - The balanced dataset is written to a new CSV file.

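The balancing step above is random undersampling. Below is a minimal sketch of it as a reusable function. It assumes rows are CSV strings whose first field is the integer label. `balance_by_label` and its `seed` parameter are hypothetical names that do not appear in the original script.

The sketch also deliberately differs from the script in two ways. First, it raises an error when a class has no samples. The script would instead compute `min_count = 0` and write an empty file. Second, it uses a seeded `random.Random` so the sample can be reproduced.

```python
import random


def balance_by_label(rows, num_classes=10, seed=None):
    """Randomly undersample CSV rows so that every label keeps as many
    rows as the smallest class. Hypothetical helper, not from the script."""
    rng = random.Random(seed)  # a fixed seed makes the sample reproducible
    by_label = {c: [] for c in range(num_classes)}
    for row in rows:
        # The label is the first comma-separated field; a label outside
        # 0..num_classes-1 raises KeyError, as in the original script.
        by_label[int(row.split(",", 1)[0])].append(row)

    # An empty class would give min_count == 0 and an empty result, so fail early.
    empty = [c for c, group in by_label.items() if not group]
    if empty:
        raise ValueError(f"no samples for labels {empty}")

    min_count = min(len(group) for group in by_label.values())
    balanced = []
    for group in by_label.values():
        balanced.extend(rng.sample(group, min_count))  # sampling without replacement
    return balanced


# Toy data with 3 classes: label 0 has 3 rows, label 1 has 2, label 2 has 4.
rows = ["0,a", "0,b", "0,c", "1,d", "1,e", "2,f", "2,g", "2,h", "2,i"]
result = balance_by_label(rows, num_classes=3, seed=0)
print(len(result))  # 6: two rows per class
```

Passing the same `seed` returns the same subset on every run. With `seed=None`, the generator is seeded from system entropy, which matches the script's unseeded use of the `random` module.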
### Output
After writing the file, the script prints a message giving the path of the balanced CSV.


```python
import random

# Paths to the MNIST training CSV and the balanced output file.
# Each row is expected to be "label,pixel,pixel,..." with no header row.
input_path = "mnist_dataset/mnist_train.csv"
output_path = "mnist_dataset/mnist_balanced.csv"

# Load the MNIST data, stripping line endings and skipping blank lines
with open(input_path, 'r') as file:
    data_list = [line.strip() for line in file if line.strip()]

# Group the records by label (the first comma-separated field)
data_by_label = {i: [] for i in range(10)}
for record in data_list:
    label = int(record.split(',')[0])
    data_by_label[label].append(record)

# Size of the smallest class
min_count = min(len(records) for records in data_by_label.values())

# Randomly sample min_count records (without replacement) from each class
balanced_data = []
for records in data_by_label.values():
    balanced_data.extend(random.sample(records, min_count))

# Write the balanced dataset to a new CSV file, one record per line.
# Records stay grouped by label; shuffle them first if order matters downstream.
with open(output_path, 'w') as file:
    for record in balanced_data:
        file.write(record + '\n')

print(f"Balanced dataset written to {output_path}")
```

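To confirm that the output is actually balanced, you can count the labels in the written file. The following is a minimal sketch using `collections.Counter`. `label_counts` is a hypothetical helper, and the commented usage lines assume the output path used in the script above.

```python
from collections import Counter


def label_counts(lines):
    """Count rows per label, taking the label from the first CSV field.
    Blank lines are skipped. Hypothetical helper for checking the output."""
    counts = Counter()
    for line in lines:
        if line.strip():
            counts[int(line.split(",", 1)[0])] += 1
    return counts


# Assumed usage against the file written by the script:
# with open("mnist_dataset/mnist_balanced.csv") as f:
#     counts = label_counts(f)
# assert len(set(counts.values())) == 1, "classes are not balanced"

sample = ["0,1,2\n", "1,3,4\n", "0,5,6\n", "1,7,8\n", "\n"]
counts = label_counts(sample)
print(counts)
print(len(set(counts.values())) == 1)  # True: every label appears equally often
```

For a correctly balanced MNIST file, all ten counts should equal the `min_count` computed by the script.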