In [None]:
import os.path

import torch

from src.classes.data.DataPreprocessor import DataPreprocessor
from src.classes.data.DataSplitter import DataSplitter
from src.classes.data.DatasetLoader import DatasetLoader
from src.classes.evaluation.periodicity.ModelFactory import ModelFactory
from src.classes.utils.Logger import Logger
from src.classes.utils.PeriodicityDetector import PeriodicityDetector
from src.config import DATASET_ID, CLASSIFICATION, MODEL

# Project logger used for progress / diagnostic messages in the cells below.
logger = Logger()

# Saved model weights. The path is relative, so it resolves against the
# kernel's current working directory (typically the notebook's own folder).
PATH_TO_MODEL = os.path.join("..", "results", "benchmark_num_reg")

In [None]:
# Load features/target and the dataset's configuration through one loader.
dataset_loader = DatasetLoader()
x, y = dataset_loader.load_dataset(DATASET_ID)
dataset_config = dataset_loader.get_dataset_config(DATASET_ID)

# Partition the columns into numerical and categorical, keeping each group in
# its original column order. A set gives O(1) membership checks per column.
cat_cols = dataset_config["cat_cols"]
cat_col_set = set(cat_cols)
numerical_columns = [col for col in x.columns if col not in cat_col_set]
categorical_columns = [col for col in x.columns if col in cat_col_set]

# Numerical columns first, then categorical: the positional indices computed
# later with `x.columns.get_loc` refer to this layout.
ordered_columns = numerical_columns + categorical_columns
x = x[ordered_columns]

In [None]:
# Run the ACF-based periodicity test on every numerical feature except 'month'
# and record each feature's positional index in `x` as periodic or
# non-periodic. 'month' is still part of `numerical_columns`, so it keeps a
# slot in `idx_num` but falls in neither periodicity group.
# NOTE(review): presumably 'month' is treated specially elsewhere — confirm.
x_num_cols = [col for col in numerical_columns if col != 'month']
idx_periodic, idx_non_periodic = [], []

logger.info(f"Analyzing periodicity for {len(x_num_cols)} numerical features.")

# One detector shared by all features instead of a new one per iteration.
# Assumes detect_periodicity_acf keeps no state between calls — TODO confirm.
periodicity_detector = PeriodicityDetector()
for column in x_num_cols:
    col_idx = x.columns.get_loc(column)
    if periodicity_detector.detect_periodicity_acf(x[column].values):
        logger.debug(f"Feature '{column}' detected as periodic.")
        idx_periodic.append(col_idx)
    else:
        logger.debug(f"Feature '{column}' detected as non-periodic.")
        idx_non_periodic.append(col_idx)

In [None]:
# Positional indices of the numerical / categorical columns in the reordered
# DataFrame `x`. This must run before the preprocessing cell, which rebinds
# `x` to the preprocessor's output.
idx_num = list(map(x.columns.get_loc, numerical_columns))
idx_cat = list(map(x.columns.get_loc, categorical_columns))

In [None]:
# Fit the feature preprocessing pipeline and replace `x` with its transformed
# output; the pre-transform shape is recorded first.
# NOTE(review): this cell rebinds `x` (and `y` when CLASSIFICATION), so
# re-running it without re-running the loading cells would transform
# already-transformed data.
preprocessor = DataPreprocessor(x, y, cat_cols)
x_original_shape = x.shape
feature_pipeline = preprocessor.make_preprocessor()
x = feature_pipeline.fit_transform(x)

# Classification targets are encoded by the preprocessor; otherwise `y` is
# left unchanged.
y = preprocessor.encode_target() if CLASSIFICATION else y

In [None]:
# Split the preprocessed data to obtain per-branch input sizes, build the
# configured model architecture, and load its trained weights.
data_splitter = DataSplitter(x, y, idx_num, idx_cat, idx_periodic, idx_non_periodic)
split_data = data_splitter.split()

input_sizes = split_data['input_sizes']
# NOTE(review): `y.nunique()` assumes encode_target() returns a pandas object;
# if it returns a NumPy array this raises AttributeError — confirm.
output_size = y.nunique() if CLASSIFICATION else 1

model_factory = ModelFactory(
    num_periodic_input_size=input_sizes['num_periodic_input_size'],
    num_non_periodic_input_size=input_sizes['num_non_periodic_input_size'],
    cat_input_size=input_sizes['cat_input_size'],
    output_size=output_size,
    dataset_config=dataset_config
)
model = model_factory.get_model(MODEL)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the tensors into the model's own
# parameters, whatever device those are on.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13 for plain state dicts).
# NOTE(review): PATH_TO_MODEL has no file extension; confirm it names the saved
# state-dict file and not a results directory.
# TODO(review): call model.eval() before inference if a later cell does not.
model.load_state_dict(torch.load(PATH_TO_MODEL, map_location="cpu"))