In [1]:
import numpy as np
import pandas as pd

from tensorflow.keras.datasets import mnist
from datasets import load_dataset

In [2]:
# Load the MNIST dataset through Keras and unpack its two splits (train, test),
# each into an image array and the matching labels.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [3]:
# Report the dimensions of the whole training array and of one sample from it.
print(
    f"Train images shape : {train_images.shape}",
    f"Shape of single image : {train_images[0].shape}",
    sep="\n",
)

Train images shape : (60000, 28, 28)
Shape of single image : (28, 28)


In [4]:
# Convert the pixel values to float32 and scale them by 1/255.
# The dtype guard keeps this cell idempotent: the assignment overwrites its own
# input, so re-running the cell without the guard would divide the already
# scaled float32 data by 255 a second time.
if train_images.dtype != np.float32:
    train_images = train_images.astype('float32') / 255.0

In [5]:
# One row per training image: the pixel grid converted to nested Python lists
# ('imgs') next to its label ('labels').
df = pd.DataFrame({
    'imgs': train_images.tolist(),
    'labels': train_labels,
})

In [6]:
# Count the images per label once and reuse the result for both report lines
# (the counts were previously recomputed for every printed value).
label_counts = df['labels'].value_counts()
print('Max number of images is' , label_counts.max(), 'for label : ', label_counts.idxmax())
print('Min number of images is' , label_counts.min(), 'for label : ', label_counts.idxmin())

Max number of images is 6742 for label :  1
Min number of images is 5421 for label :  5


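Each generated sample pairs four consecutive digits (inputs) with the digit that follows them (output):

| Inputs         | Output |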
|:--------------:|:------:|
| 0 - 1 - 2 - 3  |   4    |
| 1 - 2 - 3 - 4  |   5    |
| 2 - 3 - 4 - 5  |   6    |
| 3 - 4 - 5 - 6  |   7    |
| 4 - 5 - 6 - 7  |   8    |
| 5 - 6 - 7 - 8  |   9    |


In [7]:
class DataGenerator:
    """Assemble samples of five consecutive digit labels from a labelled frame.

    Every generated row has four input cells and one output cell. Each cell
    stores the complete source row (a Series carrying that row's columns, e.g.
    ``imgs`` and ``labels``) of one sample with the corresponding label.
    Rows accumulate in ``self.new_df`` across calls.
    """

    # Column layout of the generated frame; its length is the sequence length.
    COLUMNS = ['input1', 'input2', 'input3', 'input4', 'output']
    # Digit labels run from 0 to NUM_CLASSES - 1.
    NUM_CLASSES = 10

    def __init__(self):
        # Accumulator for the rows produced by the generate_* methods.
        self.new_df = pd.DataFrame(columns=self.COLUMNS)

    @staticmethod
    def filter_by_label(df, label):
        """Return the rows of ``df`` whose ``labels`` value equals ``label``."""
        return df[df['labels'] == label]

    @staticmethod
    def _rows_as_cells(rows, count):
        """Return the first ``count`` rows of ``rows`` as an object array of Series."""
        # Fill a pre-sized object array element by element so that each row
        # Series is stored intact as a single cell value.
        cells = np.empty(count, dtype=object)
        for position in range(count):
            cells[position] = rows.iloc[position]
        return cells

    def generate_data_for_labels(self, df, label_numbers):
        """Append one generated row per available sample of a label sequence.

        The i-th generated row combines the i-th sample of every requested
        label, so the rarest label bounds the number of rows that are added.

        Args:
            df: Frame with a ``labels`` column; each row describes one sample.
            label_numbers: Sequence of labels, one per entry of ``COLUMNS``
                (any further entries are ignored).

        Returns:
            ``self.new_df`` including the newly appended rows. Nothing is
            appended when one of the labels does not occur in ``df``.

        Raises:
            IndexError: If fewer labels than ``COLUMNS`` entries are given.
        """
        width = len(self.COLUMNS)
        if len(label_numbers) < width:
            raise IndexError(
                f"expected {width} labels, got {len(label_numbers)}"
            )

        # Filter once per label instead of once per label per generated row.
        subsets = [
            self.filter_by_label(df, label) for label in label_numbers[:width]
        ]
        # A label that is absent from df gives a count of zero here, instead
        # of an out-of-bounds .iloc lookup further down.
        rows_per_label = min(len(subset) for subset in subsets)
        if rows_per_label == 0:
            return self.new_df

        # Build the whole block at once and append it in a single step;
        # enlarging a frame through .loc row by row copies it on every insert.
        block = pd.DataFrame({
            column: self._rows_as_cells(subset, rows_per_label)
            for column, subset in zip(self.COLUMNS, subsets)
        })
        if self.new_df.empty:
            self.new_df = block
        else:
            self.new_df = pd.concat([self.new_df, block], ignore_index=True)
        return self.new_df

    @staticmethod
    def generate_formula_data(n):
        """Return the five consecutive integers starting at ``n`` as a tuple."""
        return n, (n + 1), (n + 2), (n + 3), (n + 4)

    def generate_data(self, df):
        """Generate rows for every label sequence that fits into the digit range.

        Sequences start at 0 and stop at the last start whose final label is
        still a valid digit (0-1-2-3-4 up to 5-6-7-8-9). Later starts would
        request labels above 9, which do not exist.

        Args:
            df: Frame with a ``labels`` column; each row describes one sample.

        Returns:
            ``self.new_df`` holding all generated rows.
        """
        last_start = self.NUM_CLASSES - len(self.COLUMNS)
        for start in range(last_start + 1):
            self.generate_data_for_labels(
                df, list(self.generate_formula_data(start))
            )
        return self.new_df

In [10]:
# Build the digit sequences from the labelled training frame. The generator
# accumulates rows in generator.new_df, which is also what generate_data returns.
generator = DataGenerator()
generated_df = generator.generate_data(df)