# Base


## Setup

* Reload modules automatically
* Load environment variables from .env
* Imports
* Load configs
* Set up logging

In [None]:
%load_ext autoreload
%autoreload

In [None]:
# Load the dotenv IPython extension, then run its magic so variables from
# the .env file become available to this notebook session.
%load_ext dotenv
%dotenv

In [None]:
import logging
import sys

import mlflow
from omegaconf import OmegaConf
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
# Notebook-wide settings; the cells below read `cfg.logging.*` and
# `cfg.train.*` from this object.
config_path = "../config.yaml"
cfg = OmegaConf.load(config_path)

In [None]:
# Configure the root logger from the loaded config and send records to
# stdout, then create the logger used by the rest of the notebook.
logging.basicConfig(
    level=cfg.logging.level,
    format=cfg.logging.format,
    datefmt=cfg.logging.date_format,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

## Load data

In [None]:
# `load_wine(as_frame=True)` returns a Bunch container; its `.frame`
# attribute is the single DataFrame holding the feature columns together
# with the target column, which is the tabular shape the preparation step
# (copy, then pop the target column) needs.
input_data = load_wine(as_frame=True).frame

## Prepare data

In [None]:
# Work on a copy so `input_data` itself is left untouched.
X = input_data.copy()
# `pop` removes the entry named by the config from X and returns it, leaving
# the remaining entries in X.
# NOTE(review): this presumes `input_data` is a DataFrame that contains the
# target column - confirm against the loading cell.
y = X.pop(cfg.train.target_name)

In [None]:
# Split size and seed come from the config; stratifying on y keeps the
# class proportions the same in the train and test parts.
split_options = {
    "test_size": cfg.train.test_size,
    "random_state": cfg.train.random_state,
    "stratify": y,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

## Define model

In [None]:
# Two-step model: scale the features, then fit a random forest.
# The forest uses the configured seed and all available cores (n_jobs=-1).
scaler = StandardScaler()
classifier = RandomForestClassifier(
    random_state=cfg.train.random_state,
    n_jobs=-1,
)
pipeline = make_pipeline(scaler, classifier)

## Set up experiment tracking

In [None]:
# Activate the experiment named in the config, if one is given.
# `mlflow.set_experiment` creates the experiment when no experiment with
# that name exists yet, so a separate lookup/create step is not needed.
if cfg.train.experiment_name:
    mlflow.set_experiment(experiment_name=cfg.train.experiment_name)

# Enable MLflow autologging for the training run below.
mlflow.autolog()

## Train & evaluate

In [None]:
# Fit the pipeline and evaluate it on the held-out data inside one MLflow run.
with mlflow.start_run() as run:
    run_info = run.info
    logger.info("Running experiment with id: %s", run_info.run_id)

    logger.info("Fitting model.")
    pipeline.fit(X_train, y_train)

    logger.info("Evaluating trained model.")
    # URI of the "model" artifact of the active run.
    # NOTE(review): assumes autologging stored the fitted pipeline under the
    # "model" artifact path - confirm for the MLflow version in use.
    model_path = mlflow.get_artifact_uri(artifact_path="model")
    # mlflow.evaluate receives features and labels in one table, so attach
    # the labels to a copy of the test features.
    test_data = X_test.copy()
    test_data["target"] = y_test
    mlflow.evaluate(
        model=model_path,
        data=test_data,
        targets="target",
        # Fixed typo: the settings live under `cfg.train`, not `cfg.tran`.
        model_type=cfg.train.model_type
    )
    logger.info("Finished experiment run %s", run_info)