In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules
# are re-imported before each cell executes (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [2]:
import logging

# Configure the root logger (getLogger() without a name) to pass INFO and above,
# then keep a module-level handle to it for later cells.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()

In [3]:
from typing import Literal
from pydantic import BaseModel

class Args(BaseModel):
    """Run parameters for this notebook, validated by pydantic on construction."""

    # Company code; restricted to the four listed values.
    company: Literal["LMK", "AMK", "GL", "RT"]
    # Target environment; either "dev" or "prod".
    env: Literal["dev", "prod"]
    # Whether the notebook is executing on Databricks; later cells branch on this flag.
    is_running_on_databricks: bool

In [4]:
# Parameters for this run; the Args model validates them on construction.
args = Args(company="LMK", env="dev", is_running_on_databricks=False)

# Module-level alias that later cells use to branch on the execution environment.
is_running_on_databricks = args.is_running_on_databricks

In [5]:
from databricks_env import auto_setup_env
# Environment setup is only invoked when running on Databricks; local runs skip it.
# NOTE(review): what auto_setup_env() configures is not visible here — see databricks_env.
if is_running_on_databricks:
    auto_setup_env()

In [6]:
from constants.companies import get_company_by_code

# Resolve the company object from the code chosen in args, then expose the
# code and the id as separate names for the cells below.
company = get_company_by_code(company_code=args.company)
company_code = args.company
company_id = company.company_id

In [7]:
# Create the Spark handle that is passed to the training-set and tuning calls below.
# NOTE(review): presumably returns a SparkSession — confirm in dishes_forecasting.spark_context.
from dishes_forecasting.spark_context import create_spark_context
spark = create_spark_context()

In [None]:
from dishes_forecasting.train.configs.feature_lookup_config import feature_lookup_config_list
from dishes_forecasting.train.training_set import create_training_set
from databricks.feature_store import FeatureStoreClient

from dishes_forecasting.train.configs.train_configs import get_training_configs
# Training configuration looked up by company code.
train_config = get_training_configs(company_code=args.company)

# A feature store client is only instantiated on Databricks; otherwise None is passed through.
fs = FeatureStoreClient() if is_running_on_databricks else None

# Build the training set. is_use_feature_store is fixed to False here
# regardless of whether a client was created above.
training_set, df_training_pk_target = create_training_set(
    env=args.env,
    company_id=company_id,
    spark=spark,
    train_config=train_config,
    feature_lookup_config_list=feature_lookup_config_list,
    fs=fs,
    is_use_feature_store=False,
)

In [None]:
from dishes_forecasting.train.tune import tune_pipeline

# Hyperparameter tuning run with 30 trials; the two returned values are bound
# to best_params and best_mae.
# The environment comes from args.env instead of a hard-coded "dev", so tuning
# always targets the same environment the training set was built for.
best_params, best_mae = tune_pipeline(
    company=company,
    env=args.env,
    spark=spark,
    training_set=training_set,
    train_config=train_config,
    n_trials=30
)

In [None]:
# Shorter tuning run (5 trials) using the environment from args.env.
# NOTE(review): this repeats the tuning call from the previous cell and rebinds
# best_params / best_mae, so on a top-to-bottom run the 30-trial results are
# replaced by the results of this 5-trial run. Confirm which run should be kept.
from dishes_forecasting.train.tune import tune_pipeline
best_params, best_mae = tune_pipeline(
    company=company,
    env=args.env,
    spark=spark,
    train_config=train_config,
    training_set=training_set,
    n_trials=5
)