In [None]:
# Enable IPython's autoreload extension so that edits to imported modules
# (e.g. the project's `lib` package) are picked up without a kernel restart.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [None]:
import mlflow
import datetime

from lib.reproduction import major_oxides
from lib.full_flow_dataloader import load_full_flow_data
from lib.norms import Norm1Scaler, Norm3Scaler
from sklearn.cross_decomposition import PLSRegression
from lib.get_preprocess_fn import get_preprocess_fn
from lib.cross_validation import CustomKFoldCrossValidator, get_cross_validation_metrics, perform_cross_validation
from lib.metrics import rmse_metric, std_dev_metric

# Load the processed data once for the whole notebook; the loader's two
# results are used below as the training and the held-out test inputs.
(
    train_processed,
    test_processed,
) = load_full_flow_data()

In [None]:
# Normalization variant used for this experiment; it selects the scaler and
# is embedded in the MLflow experiment name further down.
norm = 3

# Columns handed to the preprocessing step as `drop_cols`: every oxide target
# column plus the two identifier columns.
drop_cols = [*major_oxides, "ID", "Sample Name"]

In [None]:
# Fitted models, appended in the same order as `major_oxides`.
models = []

# Hyperparameters shared by every PLSRegression instance (cross-validation
# and final fit) and logged verbatim to MLflow.
pls_params = {
    "n_components": 34,
    "scale": True,
    "max_iter": 500,
}

# Resolve the scaler class up front and fail loudly on an unsupported `norm`.
# Previously any value other than 1 silently fell back to Norm3Scaler, which
# would make the experiment name and the logged "norm" param misleading.
scalers_by_norm = {1: Norm1Scaler, 3: Norm3Scaler}
if norm not in scalers_by_norm:
    raise ValueError(
        f"Unsupported norm {norm!r}; expected one of {sorted(scalers_by_norm)}"
    )
scaler_cls = scalers_by_norm[norm]

# One timestamped experiment per notebook execution; one run per target oxide.
mlflow.set_experiment(f'PLS_Norm{norm}_{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')

for target in major_oxides:
    with mlflow.start_run(run_name=f"PLS_{target}"):
        # Log the configuration first so that a run which fails during
        # cross-validation or training still records what was attempted.
        mlflow.log_params({
            **pls_params,
            "target": target,
            "norm": norm
        })

        # == Cross Validation ==
        # NOTE(review): group_by="Sample Name" presumably keeps all rows of a
        # sample in the same fold -- confirm in lib.cross_validation.
        kf = CustomKFoldCrossValidator(k=5, random_state=42, group_by="Sample Name")
        # A fresh scaler instance per target, shared by the CV and the
        # final-training preprocessing functions (as in the original flow).
        scaler = scaler_cls()

        cv_metrics = perform_cross_validation(
            model=PLSRegression(**pls_params),
            preprocess_fn=get_preprocess_fn(target_col=target, drop_cols=drop_cols, preprocessor=scaler),
            kf=kf,
            data=train_processed,
            metric_fns=[rmse_metric, std_dev_metric],
        )

        mlflow.log_metrics(get_cross_validation_metrics(cv_metrics).as_dict())

        # == Training ==
        preprocess_fn = get_preprocess_fn(target_col=target, drop_cols=drop_cols, preprocessor=scaler)
        X_train, y_train, X_test, y_test = preprocess_fn(train_processed, test_processed)

        # Train the final model on the full training set
        model = PLSRegression(**pls_params)
        model.fit(X_train, y_train)
        models.append(model)

        # Evaluate on the held-out test set
        pred = model.predict(X_test)
        rmse = rmse_metric(y_test, pred)
        std_dev = std_dev_metric(y_test, pred)

        # Log test metrics and the fitted model artifact
        mlflow.log_metrics({"rmse": rmse, "std_dev": std_dev})
        mlflow.sklearn.log_model(model, f"model_{target}")