In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution and edits to project code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Literal
from pydantic import BaseModel

class Args(BaseModel):
    """Run parameters for this notebook, validated by pydantic on construction."""

    # Company code; only the four literals listed here are accepted.
    company: Literal["LMK", "AMK", "GL", "RT"]
    # Environment name; either "dev" or "prod".
    env: Literal["dev", "prod"]
    # Flag passed on to the environment setup and training cells.
    is_running_on_databricks: bool


In [3]:
# Parameters for this run; Args validates the literals on construction.
args = Args(company="RT", env="dev", is_running_on_databricks=False)

# Module-level shorthand read by the environment-setup cell below.
is_running_on_databricks = args.is_running_on_databricks

In [4]:
from databricks_env import auto_setup_env
# Run the environment bootstrap only when the run is flagged as a Databricks run.
# NOTE(review): what auto_setup_env() configures is not visible here — confirm.
if is_running_on_databricks:
    auto_setup_env()

In [5]:
from dishes_forecasting.train.configs.train_configs import get_training_configs
# Fetch the training configuration for the company code chosen in `args`.
train_config = get_training_configs(company_code=args.company)

In [6]:
import logging

from constants.companies import get_company_by_code

# Resolve the company object for the selected code and keep its id, which the
# data-loading calls further down take as `company_id`.
company_code = args.company
company = get_company_by_code(company_code=company_code)
company_id = company.company_id

In [7]:
# Derive the environment from the validated notebook arguments instead of a
# hard-coded literal, so `env` always agrees with `Args(env=...)`.
env = args.env

In [None]:
from dishes_forecasting.train.training_set import create_training_set
from dishes_forecasting.train.configs.feature_lookup_config import feature_lookup_config_list

# Build the training set for this company from the training config and the
# feature lookup configs. The call returns two values; the second
# (df_training_pk_target) is not used by any cell visible in this notebook.
# NOTE(review): is_drop_ignored_columns=True presumably drops columns the
# config marks as ignored — confirm against create_training_set.
training_set, df_training_pk_target = create_training_set(
    company_id=company_id,
    train_config=train_config,
    feature_lookup_config_list=feature_lookup_config_list,
    is_drop_ignored_columns=True,
)

# Show the resulting training set as the cell output.
training_set

In [None]:
import mlflow

# Point MLflow tracking at the Databricks workspace BEFORE selecting the
# experiment, so the experiment lookup and the runs logged by the training
# cell go to Databricks rather than to the default local tracking store.
# NOTE(review): off-Databricks this relies on Databricks credentials being
# configured for the kernel — confirm.
mlflow.set_tracking_uri("databricks")

# Use the Unity Catalog model registry for model registration.
mlflow.set_registry_uri("databricks-uc")

# Select the shared experiment that runs from this notebook are logged under.
mlflow.set_experiment("/Shared/ml_experiments/dishes-forecasting")

In [None]:
from dishes_forecasting.train.train_pipeline import train_model
from dishes_forecasting.train.configs.hyper_params import load_hyperparams
# Load the three hyperparameter sets for the selected company.
params_lgb, params_rf, params_xgb = load_hyperparams(company=args.company)

# Train the model. `env` and `is_running_on_databricks` are taken from the
# validated notebook arguments rather than being re-hard-coded in this cell,
# so changing `Args(...)` at the top is honoured here and the module-level
# `is_running_on_databricks` is no longer silently reset to False.
# The call returns nine values: the fitted pipeline, the train/test splits,
# two scalar metrics (mape, mae) and two test-result tables.
(
    custom_pipeline,
    X_train,
    X_test,
    y_train,
    y_test,
    mape,
    mae,
    df_test_metrics,
    df_test_binned,
) = train_model(
    training_set=training_set,
    params_lgb=params_lgb,
    params_rf=params_rf,
    params_xgb=params_xgb,
    env=args.env,
    is_running_on_databricks=args.is_running_on_databricks,
    train_config=train_config,
    company=company,
    # NOTE(review): both flags are left enabled as in the original cell, so a
    # run presumably logs and registers a model — confirm this is intended
    # for exploratory runs.
    is_register_model=True,
    is_log_model=True,
)

In [None]:
# Display the `mae` value returned by train_model.
mae

In [None]:
# Display the `mape` value returned by train_model.
mape

In [None]:
# Display the `df_test_binned` table returned by train_model.
df_test_binned

In [56]:
import numpy as np
# Predict on X_test with the fitted pipeline, then exponentiate the outputs.
# NOTE(review): presumably the pipeline predicts on a log scale and np.exp
# maps the predictions back to the original scale — confirm.
raw_test_predictions = custom_pipeline.predict(X_test)
y_pred_transformed = np.exp(raw_test_predictions)

In [None]:
from dishes_forecasting.train.training_set import get_training_pk_target
# Fetch the primary-key/target rows for this company over the training week
# range taken from train_config (train_start_yyyyww .. train_end_yyyyww).
# NOTE(review): is_training_set=False is passed even though the training
# range is used — confirm what this flag selects.
df_training_target = get_training_pk_target(
    company_id=company_id,
    min_yyyyww=train_config["train_start_yyyyww"],
    max_yyyyww=train_config["train_end_yyyyww"],
    is_training_set=False,
)

In [9]:
from dishes_forecasting.schema import feature_schema
# Switch on coercion on the imported schema object. This mutates the shared
# module-level `feature_schema`, so it affects every later use in this kernel.
# NOTE(review): presumably makes validate() cast columns to the declared
# dtypes instead of failing on a dtype mismatch — confirm.
feature_schema.coerce = True

In [None]:
# Validate the training set against the feature schema; rows containing
# missing values are dropped first, so only complete rows are checked.
feature_schema.validate(training_set.dropna())

In [None]:
# Display the `X_test` features returned by train_model.
X_test

In [82]:
# Point MLflow tracking at Databricks.
# NOTE(review): in notebook order this cell comes after the experiment setup
# and training cells, so on a fresh top-to-bottom run it cannot influence
# where those calls log — confirm whether it should run before
# mlflow.set_experiment instead.
mlflow.set_tracking_uri("databricks")