# Output a LaTeX table with dataset statistics for the appendix

In [1]:
# Automatically re-import edited modules (e.g. src.dataset) before running code.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline (none of the cells shown here produce figures).
%matplotlib inline

In [2]:
import os, sys
# Root of the feature-wise-active-learning repository. The default is the original
# author's checkout; set FWAL_PROJECT_DIR to run the notebook from another machine.
project_dir = os.environ.get(
    'FWAL_PROJECT_DIR', r'/home/er647/projects/feature-wise-active-learning/'
)

# Put <project>/src first and <project> second on sys.path. This is the same priority
# order as two successive sys.path.insert(0, ...) calls, but earlier copies are dropped,
# so re-running this cell does not keep growing sys.path.
# The project root lets `from src.dataset import ...` resolve. The src entry presumably
# serves bare-name imports between modules inside src/ — TODO confirm.
_front = [os.path.join(project_dir, 'src'), project_dir]
sys.path[:] = _front + [p for p in sys.path if p not in _front]
del _front

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import datetime
import time

from src.dataset import load_ASU_dataset, load_PBMC_small, load_finance, load_mice

In [7]:
def create_parser(data_dir='/home/er647/data/fwal-data'):
    """Build the argument namespace passed to the ``src.dataset`` loaders.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser. It parses an empty argument list (``parse_args(args=[])``), so any
    arguments present in ``sys.argv`` are ignored.

    Parameters
    ----------
    data_dir : str, optional
        Data directory stored on the namespace as ``data_dir``. The loaders presumably
        read datasets from it (TODO confirm). Defaults to ``'/home/er647/data/fwal-data'``.

    Returns
    -------
    argparse.Namespace
        Namespace whose only attribute is ``data_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Compute dataset statistics for the appendix table."
    )
    # Use an argparse default instead of overwriting the attribute after parsing.
    parser.add_argument('--data_dir', type=str, default=data_dir, help='Directory for data')

    # Empty argv: no real command line is parsed; only the defaults above apply.
    return parser.parse_args(args=[])

# Build the module-level `args` namespace once. load_dataset() below passes it to the
# src.dataset loaders (presumably they read args.data_dir — TODO confirm).
args = create_parser()

def load_dataset(dataset):
    """Load one benchmark dataset by name.

    Dispatches to the loaders imported from ``src.dataset`` and passes each of them
    the module-level ``args`` namespace created from ``create_parser()``.

    Parameters
    ----------
    dataset : str
        Case-sensitive name: one of "COIL20", "gisette", "Isolet", "madelon",
        "USPS" (served by ``load_ASU_dataset``), "PBMC", "finance" or "mice_protein".

    Returns
    -------
    tuple
        ``(X, y)`` as returned by the underlying loader. The statistics cell uses
        ``X.shape`` and treats ``y`` as a label sequence; presumably these are
        numpy arrays — TODO confirm.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the recognised names.
    """
    if dataset in ("COIL20", "gisette", "Isolet", "madelon", "USPS"):
        X, y = load_ASU_dataset(args, dataset)
    elif dataset == "PBMC":
        X, y = load_PBMC_small(args)
    elif dataset == "finance":
        X, y = load_finance(args)
    elif dataset == "mice_protein":
        X, y = load_mice(args)
    else:
        # Fail loudly: falling through to `return X, y` would raise a confusing
        # UnboundLocalError instead of naming the bad dataset.
        raise ValueError(
            f"{dataset!r} not found; expected one of COIL20, gisette, Isolet, "
            "madelon, USPS, PBMC, finance, mice_protein"
        )
    return X, y
    

In [9]:
"""
==== Statistics to store for each dataset ====

- dataset name
- # samples (n)
- # features (d)
- # classes
- # samples per class
"""
# === initialise the vectors ===
dataset_names = ['COIL20','Isolet', 'PBMC','USPS', 'finance','madelon','mice_protein']
# samples = [1440, 1560, 1037,9298,2664,2600, 1080]
# features = [1024, 617,21932, 256, 155, 500,77]

# dataset_names = []
samples = []
features = []
ratio_d_n = []
ratio_n_d = []
classes = []
samples_per_class = []

for dataset in dataset_names:
    X, y = load_dataset(dataset)

    samples.append(X.shape[0])
    features.append(X.shape[1])
    ratio_d_n.append(float(features[-1]) / samples[-1])
    ratio_n_d.append(float(samples[-1]) / features[-1])
    classes.append(len(np.unique(y)))
    samples_per_class.append(sorted(pd.Series(y).value_counts().values))

# === Create dataframe ===
dataset_statistics = pd.DataFrame({
    'Dataset': map(str.lower, dataset_names),
    '# samples (N)': samples,
    '# features (D)': features,
    '# classes': classes,
    'N/D': [round(num) for num in ratio_n_d],
    # '# samples per class': [str(x)[1:-1] for x in samples_per_class],
    '# min samples per class': [min(x) for x in samples_per_class],
    '# max samples per class': [max(x) for x in samples_per_class]
})

dataset_statistics = dataset_statistics.sort_values(by='Dataset')
dataset_statistics

Unnamed: 0,Dataset,# samples (N),# features (D),# classes,N/D,# min samples per class,# max samples per class
0,coil20,1440,1024,20,1,72,72
4,finance,2664,154,2,17,1195,1469
1,isolet,1560,617,26,3,60,60
5,madelon,2600,500,2,5,1300,1300
6,mice_protein,1080,77,8,14,105,150
2,pbmc,1038,21932,2,0,514,524
3,usps,9298,256,10,36,708,1553


In [10]:
# === Print to LaTeX ===
# print() (not rich display) so the raw LaTeX source can be copied into the paper.
# escape=True is explicit because pandas >= 2.0 changed the default to False. Without
# escaping, the '#' in the headers and the '_' in 'mice_protein' would break LaTeX.
print(dataset_statistics.to_latex(index=False, header=True, escape=True))

\begin{tabular}{lrrrrrr}
\toprule
     Dataset &  \# samples (N) &  \# features (D) &  \# classes &  N/D &  \# min samples per class &  \# max samples per class \\
\midrule
      coil20 &           1440 &            1024 &         20 &    1 &                       72 &                       72 \\
     finance &           2664 &             154 &          2 &   17 &                     1195 &                     1469 \\
      isolet &           1560 &             617 &         26 &    3 &                       60 &                       60 \\
     madelon &           2600 &             500 &          2 &    5 &                     1300 &                     1300 \\
mice\_protein &           1080 &              77 &          8 &   14 &                      105 &                      150 \\
        pbmc &           1038 &           21932 &          2 &    0 &                      514 &                      524 \\
        usps &           9298 &             256 &         10 &   36 &       

  print(dataset_statistics.to_latex(index=False, header=True))
