## Creating Imbalanced MNIST Dataset

### Purpose
The script generates imbalanced versions of the MNIST dataset. It keeps every sample of a chosen target digit and downsamples the other nine digits so that the target digit makes up a specified percentage of the result.

### Key Steps

1. **Data Loading**:
    - MNIST training data is loaded from a CSV file (`mnist_dataset/mnist_train.csv`); the first field of each row is the digit label.

2. **Data Organization**:
    - Rows are grouped by label (0-9) into separate lists.

3. **Specifying Imbalance**:
    - A target digit and an imbalance percentage are specified.
    - The imbalance percentage is the share of the final dataset that the target digit should make up; the other digits are downsampled to reach it.

4. **Creating Imbalance**:
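    - All samples of the target digit are kept, and their count is recorded.
    - The total count for the remaining digits is calculated as `target_count * (100 - percentage) / percentage`, so that the target digit makes up the specified percentage of the result.
    - That total is divided by 9 (rounding down), and the same number of samples is drawn at random, without replacement, from each of the other 9 digits.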

5. **Output**:
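    - The imbalanced dataset is saved to a new CSV file named `mnist_dataset/mnist_imbalanced_<digit>_<percentage>.csv`, with all target-digit rows first, followed by the sampled rows of the other digits.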
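
The count arithmetic in step 4 can be sketched as follows. The function name `per_digit_count` and the figure of 6,000 target samples are illustrative assumptions, and the sketch uses integer division where the script uses `int()` on a float division.

```python
def per_digit_count(target_count: int, target_percentage: int) -> int:
    """Samples to draw from each of the 9 non-target digits.

    Hypothetical helper mirroring the arithmetic described above.
    """
    # Total non-target samples so the target makes up target_percentage % of the output
    others_total = target_count * (100 - target_percentage) // target_percentage
    # Split evenly across the 9 other digits, rounding down
    return others_total // 9


# Illustrative numbers only: 6,000 target samples at a 20% target share
target_count = 6000
each = per_digit_count(target_count, 20)
achieved = target_count / (target_count + 9 * each)
print(each, round(achieved, 4))
```

With these numbers each of the other digits contributes 2,666 samples, and the target digit ends up at very nearly 20% of the dataset; the small deviation comes from rounding down.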

### Iterative Imbalance Creation
An imbalanced dataset is created for every combination of target digit (0-9) and imbalance percentage (15%, 20%, 50%, 80%), giving 40 files in total.
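
Not every digit and percentage pair is necessarily feasible: because the other digits are sampled without replacement, each of them needs at least as many rows as the per-digit quota. The sketch below checks this for a grid of settings. The helper `infeasible_settings` and the class counts are hypothetical, chosen only to illustrate the check.

```python
from itertools import product


def infeasible_settings(class_counts, percentages):
    """Return (digit, percentage) pairs whose per-digit quota exceeds some class size.

    Hypothetical helper; class_counts maps each label to its number of rows.
    """
    bad = []
    for digit, pct in product(sorted(class_counts), percentages):
        # Per-digit quota: non-target total split evenly across 9 digits
        quota = class_counts[digit] * (100 - pct) // pct // 9
        # The smallest non-target class limits how many rows can be drawn
        smallest_other = min(n for label, n in class_counts.items() if label != digit)
        if quota > smallest_other:
            bad.append((digit, pct))
    return bad


# Made-up, perfectly balanced counts (1,000 rows per digit), for illustration only
counts = {label: 1000 for label in range(10)}
print(infeasible_settings(counts, [5, 15, 20, 50, 80]))
```

With perfectly balanced classes the quota fits only when the target share is at least 10%, so in this toy example only the 5% setting is reported.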

### Output
A confirmation message containing the output path is printed after each imbalanced dataset is written.
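
Beyond the confirmation message, a generated file can be checked by counting its labels. This is a minimal sketch, assuming the label is the first comma-separated field of each row, as in the input file; `target_share` is a hypothetical helper and the rows below are made up.

```python
from collections import Counter


def target_share(rows, target_digit):
    """Fraction of rows whose label (first CSV field) equals target_digit."""
    # Skip blank lines, then tally the label of every remaining row
    labels = Counter(int(row.split(",")[0]) for row in rows if row.strip())
    total = sum(labels.values())
    return labels[target_digit] / total if total else 0.0


# Made-up rows standing in for lines read from a generated CSV file
rows = ["3,0,0\n", "3,0,255\n", "7,12,0\n", "1,0,0\n"]
print(target_share(rows, 3))
```

For a real file, an open file object can be passed in place of `rows`, since iterating over it yields one line at a time.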


In [4]:
import random


def create_imbalance(digit, balance):
    # Path to the MNIST training CSV file (input)
    input_path = "mnist_dataset/mnist_train.csv"
    
    # Target digit and the percentage of the output dataset it should make up
    TARGET_DIGIT = digit
    IMBALANCE_PERCENTAGE = balance
    
    # Output file name encodes the target digit and the percentage
    output_path = f"mnist_dataset/mnist_imbalanced_{TARGET_DIGIT}_{IMBALANCE_PERCENTAGE}.csv"
    
    # Load the MNIST data
    with open(input_path, 'r') as file:
        data_list = file.readlines()
    
    # Separate the data by label
    data_by_label = {i: [] for i in range(10)}
    for record in data_list:
        label = int(record.split(',')[0])
        data_by_label[label].append(record)
    
    # Count the occurrences of the target digit
    target_digit_count = len(data_by_label[TARGET_DIGIT])
    
    # Total count for the other digits so that the target digit makes up
    # IMBALANCE_PERCENTAGE percent of the final dataset
    other_digits_total_count = int((target_digit_count * (100 - IMBALANCE_PERCENTAGE)) / IMBALANCE_PERCENTAGE)
    
    # Calculate count for each of the other digits
    each_digit_count = other_digits_total_count // 9  # Divide equally among 9 other digits
    
    # Start from a copy of all target digit records, so data_by_label is not mutated
    imbalance_data = list(data_by_label[TARGET_DIGIT])
    # Randomly sample the same number of records from each of the other 9 digits
    for label, records in data_by_label.items():
        if label != TARGET_DIGIT:
            imbalance_data.extend(random.sample(records, each_digit_count))
    
    # Write the imbalanced dataset to a new CSV file
    with open(output_path, 'w') as file:
        for record in imbalance_data:
            file.write(record)
    
    print(f"Imbalanced dataset written to {output_path}")

# One dataset per target digit (0-9) and imbalance percentage
for i in range(10):
    for j in [15, 20, 50, 80]:
        create_imbalance(i, j)

Imbalanced dataset written to mnist_dataset/mnist_imbalanced_0_15.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_0_20.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_0_50.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_0_80.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_1_15.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_1_20.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_1_50.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_1_80.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_2_15.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_2_20.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_2_50.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_2_80.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_3_15.csv
Imbalanced dataset written to mnist_dataset/mnist_imbalanced_3_20.csv
Imbalanced dataset w