In [59]:
# Load libraries
import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
# Turn interactive plotting off, show plot only when plt.show() is called
plt.ioff()

# INPUT is the folder the raw dataset files are read from; OUTPUT is the
# sub-folder of INPUT that receives every figure written by this notebook.
INPUT = "datasets"
OUTPUT = os.path.join(INPUT, "exploratory data analysis")

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(OUTPUT, exist_ok=True)

from warnings import filterwarnings
# Silence the DeprecationWarning whose message starts with
# "`np.bool` is a deprecated alias".
filterwarnings(action='ignore', category=DeprecationWarning, message='`np.bool` is a deprecated alias')

# Import datasets

In [45]:
# preprocessing method
def cat_2_num(df:pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with every object-dtype column integer-encoded.

    Each object column is converted to a pandas categorical and replaced by
    its ``cat.codes``: one integer per category (categories are sorted when
    their values are comparable) and -1 for missing values. Columns of any
    other dtype are passed through unchanged.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to encode. It is NOT modified; the encoding is applied to a copy.

    Returns
    -------
    pd.DataFrame
        A new frame with the same columns and index as ``df``.
    """
    cat_columns = df.select_dtypes(['object']).columns
    # Work on a copy so the caller's frame keeps its original values.
    encoded = df.copy()
    # With no object columns the loop is skipped and the copy is returned as is.
    for column in cat_columns:
        encoded[column] = encoded[column].astype('category').cat.codes
    return encoded

## UCI datasets

In [46]:
#German dataset
def German():
    """Load the whitespace-separated ``german.data-numeric`` file from INPUT.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except column 24.
        ``y``: column 24 shifted down by one and cast to bool.
    """
    path = os.path.join(INPUT, "german.data-numeric")
    # sep=r"\s+" is the documented replacement for delim_whitespace=True,
    # which is deprecated since pandas 2.2.
    df = pd.read_table(path, sep=r"\s+", header=None)
    labels = (df[24] - 1).astype(bool)    # change label from 1,2 to 0,1
    return df.drop(columns=[24]), labels

In [47]:
def Australian():
    """Load the whitespace-separated ``australian.dat`` file from INPUT.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except column 14.
        ``y``: column 14 cast to bool (any non-zero value becomes True).
    """
    path = os.path.join(INPUT, "australian.dat")
    # sep=r"\s+" is the documented replacement for delim_whitespace=True,
    # which is deprecated since pandas 2.2.
    df = pd.read_table(path, sep=r"\s+", header=None)
    return df.drop(columns=[14]), df[14].astype(bool)

In [48]:
def Crx():
    """Load ``crx.data`` from INPUT, drop incomplete rows and encode categoricals.

    "?" is parsed as a missing value while reading, so a column that is
    numeric apart from "?" placeholders keeps a numeric dtype. Replacing "?"
    after loading would leave such a column as object dtype, and cat_2_num
    would then turn its numbers into category codes.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except column 15, object columns integer-encoded.
        ``y``: encoded column 15 cast to bool.
    """
    df = pd.read_csv(os.path.join(INPUT,"crx.data"), header = None, na_values = "?")
    # drop entries with ? (now NaN) or any other missing value
    df = df.dropna()
    # convert category data to numerical data
    df = cat_2_num(df)
    # NOTE(review): if the labels are "+"/"-" as in the UCI file, "+" sorts
    # first and gets code 0, so "+" maps to False - confirm the intended polarity.
    return df.drop(columns = [15]), df[15].astype(bool)

In [49]:
def Hepatitis():
    """Load ``hepatitis.data`` from INPUT and integer-encode its object columns.

    Unlike Crx and Mushroom, "?" entries are not removed here: any column
    containing them is read as text and encoded by cat_2_num, with "?" as one
    of the categories.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except column 19.
        ``y``: column 19 shifted down by one and cast to bool.
    """
    source = os.path.join(INPUT,"hepatitis.data")
    encoded = cat_2_num(pd.read_csv(source, header = None))
    # NOTE(review): in the original UCI hepatitis.data the class label is the
    # first column (index 0) and column 19 is HISTOLOGY - confirm that column
    # 19 is really the intended target for this file.
    target = (encoded[19] - 1).astype(bool)   # change to 0 or 1
    features = encoded.drop(columns = [19])
    return features, target

In [50]:
def Ionosphere():
    """Load ``ionosphere.data`` from INPUT and integer-encode its object columns.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except column 34.
        ``y``: encoded column 34 cast to bool (code 0 becomes False).
    """
    source = os.path.join(INPUT, "ionosphere.data")
    encoded = cat_2_num(pd.read_csv(source, header=None))
    features = encoded.drop(columns = [34])
    target = encoded[34].astype(bool)
    return features, target

## Additional Kaggle datasets

In [51]:
def Pumpkin():
    """Load the ``Pumpkin_Seeds_Dataset`` sheet of the pumpkin workbook in INPUT.

    Object columns are integer-encoded with cat_2_num.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except ``Class``.
        ``y``: encoded ``Class`` column cast to bool (code 0 becomes False).
    """
    # Read from INPUT like every other loader instead of repeating "datasets".
    workbook = os.path.join(INPUT, 'Pumpkin_Seeds_Dataset.xlsx')
    df = pd.read_excel(workbook, sheet_name='Pumpkin_Seeds_Dataset', engine='openpyxl')
    df = cat_2_num(df)
    return df.drop(columns = ['Class']), df['Class'].astype(bool)

In [52]:
# 5644 samples, relatively large dataset
def Mushroom():
    """Load ``mushrooms.csv`` from INPUT, drop incomplete rows and encode it.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except ``class``, object columns integer-encoded.
        ``y``: encoded ``class`` column cast to bool (code 0 becomes False).
    """
    csv_path = os.path.join(INPUT,'mushrooms.csv')
    # "?" marks a missing value; rows with any missing value are discarded.
    complete_rows = pd.read_csv(csv_path).replace("?", np.nan).dropna()
    encoded = cat_2_num(complete_rows)
    features = encoded.drop(columns = ['class'])
    target = encoded['class'].astype(bool)
    return features, target

In [53]:
def Diabetes():
    """Load the semicolon-separated ``diabetes_data.csv`` from INPUT and encode it.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``X``: every column except ``class``, object columns integer-encoded.
        ``y``: encoded ``class`` column cast to bool (code 0 becomes False).
    """
    csv_path = os.path.join(INPUT,'diabetes_data.csv')
    encoded = cat_2_num(pd.read_csv(csv_path, sep=';'))
    features = encoded.drop(columns = ['class'])
    target = encoded['class'].astype(bool)
    return features, target

# Exploratory Data Analysis

In [60]:
# For every dataset, write three figures to OUTPUT: per-feature histograms,
# a lower-triangle correlation heatmap, and the label distribution.
dataset_getters = [German, Australian, Crx, Hepatitis, Ionosphere, Pumpkin, Mushroom, Diabetes]
for getter in dataset_getters:
    X, y = getter()
    name = getter.__name__

    # Figure sizes scale with the number of features (one inch per column).
    width = len(X.columns)

    # basic hist plot; max(1, ...) keeps the figure height positive for a
    # single-feature dataset, where int(width*0.6) would be 0.
    X.hist(figsize=(width, max(1, int(width*0.6))))
    plt.savefig(os.path.join(OUTPUT, f"{name}_X_hist.png"))
    plt.close()

    corr_mat = X.corr().round(2)
    f, ax = plt.subplots(figsize=(width,width))
    # Hide the upper triangle and the diagonal, which only mirror the lower
    # triangle. The builtin `bool` is used because the `np.bool` alias was
    # removed in NumPy 1.24.
    mask = np.triu(np.ones_like(corr_mat, dtype=bool))
    sns.heatmap(corr_mat, mask=mask, vmin=-1, vmax=1, center=0,
                cmap='plasma', square=False, linewidths=2, annot=True, cbar=False, ax=ax)
    ax.set_title(f"{name} Correlation Map")
    f.savefig(os.path.join(OUTPUT, f"{name}_X_corr.png"))
    plt.close(f)

    # Draw on an explicit figure so the saved image never depends on
    # whichever figure happens to be current.
    fig, ax = plt.subplots()
    y.value_counts().plot(kind='bar', ax=ax, title = f"{name} Y Distribution")
    fig.savefig(os.path.join(OUTPUT, f"{name}_y.png"))
    plt.close(fig)