In [3]:
from model_definitions import *
import os
import pandas as pd
import pmlb
from pmlb import fetch_data

def initialize_all_models(input_dimension: int,
                          seed_val: int,
                          output_dim: int = 1,
                          hidden_units_wide: int = 1000,
                          hidden_units_deep: int = 16,
                          hidden_layers: int = 8,
                          num_exps: int = 6,
                          partition_nums: tuple = (1, 2, 4, 8, 10)) -> list:
    """Build the list of models to compare, all sharing dimensions and seed.

    The returned list starts with four baseline models (linear model, wide
    ReLU network, deep ReLU network, single-partition lookup table). Then,
    for each entry of ``partition_nums`` in order, it contains a
    ``SplineANN``, a ``LookupTableModel`` and an ``ANNEXSpline`` built with
    that partition count.

    Args:
        input_dimension: Passed to every model as ``input_dim``.
        seed_val: Passed to every model as ``seed``.
        output_dim: Passed to every model as ``output_dim``.
        hidden_units_wide: ``hidden_units`` of the wide ReLU network.
        hidden_units_deep: ``hidden_units`` of the deep ReLU network.
        hidden_layers: ``hidden_layers`` of the deep ReLU network.
        num_exps: ``num_exps`` of every ``ANNEXSpline``.
        partition_nums: Partition counts for which the three partitioned
            model variants are created. Defaults to the previously
            hard-coded ``(1, 2, 4, 8, 10)``.

    Returns:
        The four baseline models followed by three models per entry of
        ``partition_nums``.
    """
    # Keyword arguments every model constructor/factory receives.
    common_args = {
        'input_dim': input_dimension,
        'output_dim': output_dim,
        'seed': seed_val,
    }

    models = [
        create_linear_model(**common_args),
        create_wide_relu_ann(hidden_units=hidden_units_wide, **common_args),
        create_deep_relu_ann(hidden_units=hidden_units_deep,
                             hidden_layers=hidden_layers,
                             **common_args),
        # Lookup table with a single partition, used as the constant model.
        LookupTableModel(partition_num=1, default_val=-1., **common_args),
    ]

    # NOTE(review): the default partition_nums contains 1, so the loop adds a
    # second LookupTableModel(partition_num=1) next to the constant model
    # above. Kept as-is for backward compatibility; confirm it is intended.
    for partition_num in partition_nums:
        models.extend([
            SplineANN(partition_num=partition_num, **common_args),
            LookupTableModel(partition_num=partition_num, default_val=-1.,
                             **common_args),
            ANNEXSpline(partition_num=partition_num, num_exps=num_exps,
                        **common_args),
        ])

    return models

# The summary table is looked up inside the installed pmlb package directory.
metadata_path = os.path.join(os.path.dirname(pmlb.__file__), 'all_summary_stats.tsv')

# Load the tab-separated summary table.
metadata = pd.read_csv(metadata_path, sep='\t')

# Build the selection criteria one by one so each can be read in isolation.
few_features = metadata['n_features'] < 5
only_continuous_features = (
    (metadata['n_continuous_features'] == metadata['n_features'])
    & (metadata['n_binary_features'] == 0)
    & (metadata['n_categorical_features'] == 0)
)
# Row count between 500 and 1000; both bounds are inclusive.
moderate_size = metadata['n_instances'].between(500, 1000)
continuous_regression = (
    (metadata['endpoint_type'] == 'continuous')
    & (metadata['task'] == 'regression')
)

# Keep only the datasets that satisfy every criterion.
filtered_datasets = metadata.loc[
    few_features
    & only_continuous_features
    & moderate_size
    & continuous_regression
]

# Bare expression: a notebook renders it as the cell output.
filtered_datasets


Unnamed: 0,dataset,n_instances,n_features,n_binary_features,n_categorical_features,n_continuous_features,endpoint_type,n_classes,imbalance,task
2,1029_LEV,1000,4,0,0,4,continuous,5.0,0.111245,regression
3,1030_ERA,1000,4,0,0,4,continuous,9.0,0.031251,regression
