In [10]:
import os
import pandas as pd
import numpy as np
import setuptools
import openml
from sklearn.linear_model import LinearRegression 
import lightgbm as lgbm
import lightgbmlss
import optuna
from scipy.spatial.distance import mahalanobis
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestRegressor
from sklearn.gaussian_process.kernels import Matern
from engression import engression, engression_bagged
import torch
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import mahalanobis
from scipy.stats import norm
from sklearn.metrics import mean_squared_error
from rtdl_revisiting_models import MLP, ResNet, FTTransformer
from properscoring import crps_gaussian, crps_ensemble
import random
import gpytorch
import tqdm.auto as tqdm
from lightgbmlss.model import *
from lightgbmlss.distributions.Gaussian import *
from drf import drf
from pygam import LinearGAM, s, f
import gower
from utils import EarlyStopping, train, train_trans, train_no_early_stopping, train_trans_no_early_stopping
from torch.utils.data import TensorDataset, DataLoader

# Optionally authenticate with OpenML; read the key from the environment rather than hard-coding it.
# openml.config.apikey = os.environ["OPENML_API_KEY"]

# Benchmark suite selection (exactly one SUITE_ID is active):
# SUITE_ID = 336  # Regression on numerical features
# SUITE_ID = 337  # Classification on numerical features
SUITE_ID = 335  # Regression on numerical and categorical features
# SUITE_ID = 334  # Classification on numerical and categorical features
benchmark_suite = openml.study.get_suite(SUITE_ID)  # obtain the benchmark suite

task_id = 361093
# Absolute pairwise correlation above which one feature of the pair is dropped.
CORR_THRESHOLD = 0.9

# Printed once (it was previously printed twice).
print(f"Task {task_id}")

# Create the checkpoint directory if it doesn't exist
os.makedirs('CHECKPOINTS/GOWER', exist_ok=True)
CHECKPOINT_PATH = f'CHECKPOINTS/GOWER/task_{task_id}.pt'

task = openml.tasks.get_task(task_id)  # download the OpenML task
dataset = task.get_dataset()

X, y, categorical_indicator, attribute_names = dataset.get_data(
    dataset_format="dataframe", target=dataset.default_target_attribute)

# Drop one feature from every pair whose absolute correlation exceeds CORR_THRESHOLD.
# Only the strict upper triangle (k=1) is inspected: the diagonal (self-correlation) is
# ignored and each pair is seen once, so the later column of each offending pair is dropped.
corr_matrix = X.corr().abs()
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
high_corr_features = [
    column for column in upper_tri.columns
    if (upper_tri[column] > CORR_THRESHOLD).any()
]
X = X.drop(columns=high_corr_features)

Task 361093
Task 361093


Starting from Version 0.15.0 `download_splits` will default to ``False`` instead of ``True`` and be independent from `download_data`. To disable this message until version 0.15 explicitly set `download_splits` to a bool.
Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.


In [11]:
# Inspect the feature matrix after the highly correlated columns were dropped.
X

Unnamed: 0,Actions_taken,Liberal,Unconstitutional,Precedent_alteration,Unanimous,Year_of_decision,Lower_court_disagreement
0,0,1,0,0,1,1953.0,1
1,0,0,0,0,1,1953.0,0
2,0,0,0,0,0,1953.0,0
3,0,0,0,0,1,1953.0,1
4,0,1,0,0,0,1953.0,0
...,...,...,...,...,...,...,...
4047,0,0,0,0,0,1988.0,1
4048,0,1,1,0,0,1988.0,0
4049,0,1,0,0,0,1988.0,1
4050,0,0,0,0,0,1988.0,0


In [12]:
# Experiment configuration constants.
N_TRIALS = 100
N_SAMPLES = 100
PATIENCE = 40
N_EPOCHS = 1000
GP_ITERATIONS = 1000
BATCH_SIZE = 1024
# Quantile of the mean Gower distance separating "close" (< q) from "far" (>= q) points.
FAR_QUANTILE = 0.8

# Set the random seeds for reproducibility; each generator only needs seeding once.
seed = 10
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def _as_gower_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with its category columns cast to object dtype.

    The input frame is not modified. NOTE(review): the cast is presumably needed so that
    ``gower.gower_matrix`` treats these columns as categorical; confirm against the gower docs.
    """
    converted = df.copy()
    for col in converted.select_dtypes(['category']).columns:
        converted[col] = converted[col].astype('object')
    return converted


def _split_by_mean_gower(dist_matrix: np.ndarray, index: pd.Index, quantile: float = FAR_QUANTILE):
    """Split ``index`` by each row's mean Gower distance to all rows.

    Args:
        dist_matrix: pairwise Gower distance matrix whose rows align with ``index``.
        index: row labels of the frame the matrix was computed from.
        quantile: quantile of the per-row mean distances used as the threshold.

    Returns:
        ``(mean_dist, far, close)``:
        - ``mean_dist`` is the per-row mean distance as a Series indexed by ``index``.
        - ``far`` holds the labels whose mean distance is >= the threshold.
        - ``close`` holds the labels strictly below it.
        Both label sets keep the original row order.
    """
    mean_dist = pd.Series(np.mean(dist_matrix, axis=1), index=index)
    threshold = np.quantile(mean_dist, quantile)
    far = mean_dist.index[(mean_dist >= threshold).to_numpy()]
    close = mean_dist.index[(mean_dist < threshold).to_numpy()]
    return mean_dist, far, close


# Test split: points whose mean Gower distance is in the top (1 - FAR_QUANTILE) share are "far".
X_gower = _as_gower_frame(X)
gower_dist_matrix = gower.gower_matrix(X_gower)
gower_dist, far_index, close_index = _split_by_mean_gower(gower_dist_matrix, X.index)

# Validation split: repeat the same procedure inside the (close) training set.
X_train = X.loc[close_index, :]
X_gower_ = _as_gower_frame(X_train)
gower_dist_matrix_train = gower.gower_matrix(X_gower_)
gower_dist_train, far_index_train, close_index_train = _split_by_mean_gower(
    gower_dist_matrix_train, X_train.index)

# One-hot encode categorical features (dropping the first level). This runs after the Gower
# splits above, which are computed on the raw categorical columns.
# NOTE(review): X is rebound here, so re-running this cell on its own computes the Gower
# splits on the already-encoded frame; re-run the loading cell first.
X = pd.get_dummies(X, drop_first=True)

X_train = X.loc[close_index, :]
X_test = X.loc[far_index, :]
y_train = y.loc[close_index]
y_test = y.loc[far_index]

X_train_ = X_train.loc[close_index_train, :]
X_val = X_train.loc[far_index_train, :]
y_train_ = y_train.loc[close_index_train]
y_val = y_train.loc[far_index_train]


def _scale(frame: pd.DataFrame, cols: pd.Index, mean: pd.Series, std: pd.Series) -> pd.DataFrame:
    """Return a copy of ``frame`` with ``cols`` replaced by ``(frame[cols] - mean) / std``."""
    scaled = frame.copy()
    scaled[cols] = (frame[cols] - mean) / std
    return scaled


# Standardize the non-dummy (non-bool) columns using statistics of the fitting split only.
# ddof=0 matches the population std that np.std computed previously. A zero std (constant
# column) is replaced by 1 so the column maps to 0 instead of NaN/inf.
non_dummy_cols = X.select_dtypes(exclude=['bool']).columns

mean_X_train_ = X_train_[non_dummy_cols].mean(axis=0)
std_X_train_ = X_train_[non_dummy_cols].std(axis=0, ddof=0).replace(0, 1)
X_train__scaled = _scale(X_train_, non_dummy_cols, mean_X_train_, std_X_train_)
X_val_scaled = _scale(X_val, non_dummy_cols, mean_X_train_, std_X_train_)

mean_X_train = X_train[non_dummy_cols].mean(axis=0)
std_X_train = X_train[non_dummy_cols].std(axis=0, ddof=0).replace(0, 1)
X_train_scaled = _scale(X_train, non_dummy_cols, mean_X_train, std_X_train)
X_test_scaled = _scale(X_test, non_dummy_cols, mean_X_train, std_X_train)

In [13]:
# Dtypes of the feature matrix after it was rebound to the pd.get_dummies encoding in the split cell.
X.dtypes

Actions_taken_1               bool
Actions_taken_10              bool
Actions_taken_11              bool
Actions_taken_2               bool
Actions_taken_3               bool
Actions_taken_4               bool
Actions_taken_5               bool
Actions_taken_6               bool
Actions_taken_7               bool
Liberal_1                     bool
Unconstitutional_1            bool
Precedent_alteration_1        bool
Unanimous_1                   bool
Year_of_decision_1954.0       bool
Year_of_decision_1955.0       bool
Year_of_decision_1956.0       bool
Year_of_decision_1957.0       bool
Year_of_decision_1958.0       bool
Year_of_decision_1959.0       bool
Year_of_decision_1960.0       bool
Year_of_decision_1961.0       bool
Year_of_decision_1962.0       bool
Year_of_decision_1963.0       bool
Year_of_decision_1964.0       bool
Year_of_decision_1965.0       bool
Year_of_decision_1966.0       bool
Year_of_decision_1967.0       bool
Year_of_decision_1968.0       bool
Year_of_decision_196

In [14]:
# Columns that the split cell standardizes. For this task it is empty (every feature became a
# bool dummy column), so the standardization step is effectively a no-op.
non_dummy_cols

Index([], dtype='object')

In [15]:
# Full dummy-encoded feature matrix (all rows, before the close/far split).
X

Unnamed: 0,Actions_taken_1,Actions_taken_10,Actions_taken_11,Actions_taken_2,Actions_taken_3,Actions_taken_4,Actions_taken_5,Actions_taken_6,Actions_taken_7,Liberal_1,...,Year_of_decision_1980.0,Year_of_decision_1981.0,Year_of_decision_1982.0,Year_of_decision_1983.0,Year_of_decision_1984.0,Year_of_decision_1985.0,Year_of_decision_1986.0,Year_of_decision_1987.0,Year_of_decision_1988.0,Lower_court_disagreement_1
0,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,False,True
1,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,True
4,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4047,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,True
4048,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,True,False
4049,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,True,True
4050,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,False


In [16]:
# Training features (the "close" Gower subset) after standardization. For this task they equal
# X_train, because non_dummy_cols is empty.
X_train_scaled

Unnamed: 0,Actions_taken_1,Actions_taken_10,Actions_taken_11,Actions_taken_2,Actions_taken_3,Actions_taken_4,Actions_taken_5,Actions_taken_6,Actions_taken_7,Liberal_1,...,Year_of_decision_1980.0,Year_of_decision_1981.0,Year_of_decision_1982.0,Year_of_decision_1983.0,Year_of_decision_1984.0,Year_of_decision_1985.0,Year_of_decision_1986.0,Year_of_decision_1987.0,Year_of_decision_1988.0,Lower_court_disagreement_1
1,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,False,False
6,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,False,False
8,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4042,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,False
4044,False,False,False,False,False,False,False,False,False,True,...,False,False,False,False,False,False,False,False,True,False
4045,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,False
4046,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,False
