# Data exploration and preprocessing


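This notebook loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits, keeps a two-class subset of it, and saves the preprocessed result.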

In [1]:
import pandas as pd

In [2]:
# Seed used as the default `random_state` when the dataset is shuffled, so the
# row order is the same on every run.
MANUAL_SEED: int = 42

## Data loading

In [3]:
# Read the raw MNIST table from disk.
initial_df = pd.read_csv(filepath_or_buffer="../data/raw/mnist.csv")

# Preview the first five rows (the default for `head`).
initial_df.head()

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Data preprocessing

In [4]:
# Labels kept for the two-class dataset. Used as the default `labels` argument
# below; `transform_label` maps the first entry to -1 and the second to 1.
VALID_LABELS: list[int] = [0, 1]
# Default maximum number of rows kept per label by `concat_labels`.
LABEL_SIZE_LIMIT: int = 1000


def concat_labels(
    df: pd.DataFrame,
    labels: list[int] = VALID_LABELS,
    label_size_limit: int = LABEL_SIZE_LIMIT,
) -> pd.DataFrame:
    """Keep only the requested labels, capped at a fixed number of rows each.

    Args:
        df: Frame with a ``label`` column.
        labels: Label values to keep; the output groups rows in this order.
        label_size_limit: Upper bound on the rows taken for each label. Rows
            are taken from the top of ``df`` in their existing order.

    Returns:
        The selected rows stacked label by label, with their original index
        values preserved.
    """
    per_label_frames = []
    for wanted_label in labels:
        matching_rows = df.loc[df["label"].eq(wanted_label)]
        per_label_frames.append(matching_rows.iloc[:label_size_limit])
    return pd.concat(per_label_frames)


def shuffle_dataset(df: pd.DataFrame, seed: int = MANUAL_SEED) -> pd.DataFrame:
    """Return the rows of ``df`` in a seeded random order with a fresh index.

    Args:
        df: Frame to shuffle.
        seed: Value passed to ``DataFrame.sample`` as ``random_state`` so the
            permutation is reproducible.

    Returns:
        A frame containing every row of ``df`` once, reordered, with the index
        renumbered from zero.
    """
    # frac=1 samples the whole frame, which amounts to a random permutation.
    permuted = df.sample(frac=1, random_state=seed)
    # Drop the old index labels so they do not leak the pre-shuffle order.
    return permuted.reset_index(drop=True)


def normalize_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with every column after the first divided by 255.

    The first column is passed through unchanged (in this notebook it is the
    ``label`` column). All remaining columns are rescaled; they are assumed to
    hold pixel intensities in 0..255, which is not checked here.

    The result is assembled from freshly computed pieces instead of writing
    the float values back into the existing columns through ``.iloc``. That
    in-place assignment only works by implicitly upcasting integer columns to
    float, a behaviour recent pandas releases deprecate.

    Args:
        df: Frame whose first column is kept as is and whose other columns are
            numeric.

    Returns:
        A new frame with the same index and column order as ``df``. ``df``
        itself is not modified.
    """
    # Positional split: column 0 stays untouched, columns 1.. are scaled.
    untouched_part = df.iloc[:, :1]
    scaled_part = df.iloc[:, 1:] / 255
    # Both parts share df's index, so the column-wise concat lines up row for row.
    return pd.concat([untouched_part, scaled_part], axis=1)


def transform_label(df: pd.DataFrame, labels: list[int] = VALID_LABELS) -> pd.DataFrame:
    """Recode the ``label`` column of a two-class frame to -1 / 1.

    Args:
        df: Frame with a ``label`` column.
        labels: The two class values; ``labels[0]`` becomes -1 and
            ``labels[1]`` becomes 1. Entries beyond the second are ignored.

    Returns:
        A copy of ``df`` with the recoded ``label`` column.

    Raises:
        KeyError: If ``label`` contains a value other than the two given.
    """
    negative_class, positive_class = labels[0], labels[1]
    signed_by_class = {negative_class: -1, positive_class: 1}

    def to_signed(original_label):
        # Plain dict lookup: an unexpected label raises KeyError.
        return signed_by_class[original_label]

    return df.assign(label=df["label"].apply(to_signed))


def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Run the full preprocessing pipeline with each step's default settings.

    Steps, in order: select and cap the rows per label, shuffle, scale the
    non-label columns, then recode the labels to -1 / 1.

    Args:
        df: Raw frame with a ``label`` column first, followed by the pixel
            columns.

    Returns:
        The preprocessed frame.
    """
    return (
        df.pipe(concat_labels)
        .pipe(shuffle_dataset)
        .pipe(normalize_dataset)
        .pipe(transform_label)
    )

In [5]:
# Apply the preprocessing pipeline to the raw frame and report the resulting size.
df = preprocess(initial_df)
print(f"{df.shape=}")

# Preview the first five processed rows (the default for `head`).
df.head()

df.shape=(2000, 785)


Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Data saving

In [6]:
# Write the preprocessed frame to the interim data folder; index=False leaves
# the row index out of the CSV.
df.to_csv("../data/interim/mnist_binary.csv", index=False)