In [1]:
import logging
from pathlib import Path
from typing import Dict

import mlflow
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# from config import logger
from lib.data_handling import CustomSpectralPipeline, load_split_data  # type: ignore
from lib.norms import Norm1Scaler, Norm3Scaler
from lib.outlier_removal import (
    calculate_leverage_residuals,
    identify_outliers,
    plot_leverage_residuals,
)
from lib.reproduction import (
    major_oxides,
    masks,
    optimized_blending_ranges,
    oxide_ranges,
    paper_individual_sm_rmses,
    spectrometer_wavelength_ranges,
    training_info,
)
from lib.utils import custom_kfold_cross_validation, filter_data_by_compositional_range
from PLS_SM.inference import predict_composition_with_blending

# Locations of the raw spectra and of the composition table come from the .env file.
env = dotenv_values()
comp_data_loc = env.get("COMPOSITION_DATA_PATH")
dataset_loc = env.get("DATA_PATH")

# Fail fast if either path is missing. A bare `exit(1)` is not enough here:
# inside a Jupyter/IPython kernel `exit` is IPython's own exit helper, which
# does not abort the running cell, so execution would carry on with a `None`
# path. Raising SystemExit stops the cell in a notebook and still exits with
# status 1 (message on stderr) when this code is run as a plain script.
if not comp_data_loc:
    raise SystemExit("Please set COMPOSITION_DATA_PATH in .env file")

if not dataset_loc:
    raise SystemExit("Please set DATA_PATH in .env file")

logger = logging.getLogger("train")
# NOTE(review): no handler or level is configured for this logger here. Unless
# one is set up elsewhere, INFO-level messages logged through it fall through to
# logging's last-resort handler, which only emits WARNING and above.

# MLflow tracking server: a local instance by default, overridable through
# MLFLOW_TRACKING_URI in the .env file.
mlflow.set_tracking_uri(env.get("MLFLOW_TRACKING_URI") or "http://localhost:5000")

# On-disk cache of the processed spectra. The expensive load + pipeline step only
# runs when a cache file is missing; otherwise everything is read back from CSV.
preformatted_data_path = Path("./data/_preformatted_sm/")
train_path = preformatted_data_path / "train.csv"
test_path = preformatted_data_path / "test.csv"

train_n1_path = preformatted_data_path / "train_n1.csv"
train_n3_path = preformatted_data_path / "train_n3.csv"
test_n1_path = preformatted_data_path / "test_n1.csv"
test_n3_path = preformatted_data_path / "test_n3.csv"

# Every file the cache must contain. All of them are required to take the cached
# branch, and all of them are loaded there, so the same variables exist whether
# or not the cache was hit (previously the n1/n3 frames existed only on a miss).
cache_paths = (
    train_path,
    test_path,
    train_n1_path,
    train_n3_path,
    test_n1_path,
    test_n3_path,
)

if not all(path.exists() for path in cache_paths):
    logger.info("Loading data from location: %s", dataset_loc)
    train_data, test_data = load_split_data(
        str(dataset_loc), split_loc="./train_test_split.csv", average_shots=True
    )
    logger.info("Data loaded successfully.")

    logger.info("Initializing CustomSpectralPipeline.")
    pipeline = CustomSpectralPipeline(
        masks=masks,
        composition_data_loc=comp_data_loc,
        major_oxides=major_oxides,
    )
    logger.info("Pipeline initialized. Fitting and transforming data.")
    # NOTE(review): the pipeline (and the scalers below) are re-fit on the test
    # split. If any of them learns state in `fit`, this leaks test-set statistics;
    # use `.transform(...)` on the test data instead if these classes support it
    # — TODO confirm against lib.data_handling / lib.norms.
    train_processed = pipeline.fit_transform(train_data)
    test_processed = pipeline.fit_transform(test_data)
    logger.info("Data processing complete.")

    preformatted_data_path.mkdir(parents=True, exist_ok=True)

    train_processed.to_csv(train_path, index=False)
    test_processed.to_csv(test_path, index=False)

    n1_scaler = Norm1Scaler(reshaped=True)
    n3_scaler = Norm3Scaler(spectrometer_wavelength_ranges, reshaped=True)

    # Column labels are captured before scaling so the scaled outputs can be
    # wrapped back into DataFrames with the same headers.
    train_cols = train_processed.columns
    test_cols = test_processed.columns

    train_processed_n1 = n1_scaler.fit_transform(train_processed)
    train_processed_n3 = n3_scaler.fit_transform(train_processed)
    test_processed_n1 = n1_scaler.fit_transform(test_processed)
    test_processed_n3 = n3_scaler.fit_transform(test_processed)

    # turn back into dataframe
    train_processed_n1 = pd.DataFrame(train_processed_n1, columns=train_cols)
    train_processed_n3 = pd.DataFrame(train_processed_n3, columns=train_cols)
    test_processed_n1 = pd.DataFrame(test_processed_n1, columns=test_cols)
    test_processed_n3 = pd.DataFrame(test_processed_n3, columns=test_cols)

    train_processed_n1.to_csv(train_n1_path, index=False)
    train_processed_n3.to_csv(train_n3_path, index=False)

    test_processed_n1.to_csv(test_n1_path, index=False)
    test_processed_n3.to_csv(test_n3_path, index=False)

    logger.info("Preformatted data saved to %s", preformatted_data_path)
else:
    logger.info("Loading preformatted data from location: %s", preformatted_data_path)
    train_processed = pd.read_csv(train_path)
    test_processed = pd.read_csv(test_path)

    train_processed_n1 = pd.read_csv(train_n1_path)
    train_processed_n3 = pd.read_csv(train_n3_path)
    test_processed_n1 = pd.read_csv(test_n1_path)
    test_processed_n3 = pd.read_csv(test_n3_path)

# Run-mode switches; toggle these before running the remaining cells.
SHOULD_TRAIN, SHOULD_PREDICT = True, False

DO_OUTLIER_REMOVAL = True

In [3]:
# Report how many training samples fall into each compositional range of every
# major oxide, as returned by filter_data_by_compositional_range.
range_pairs = (
    (oxide, comp_range)
    for oxide in major_oxides
    for comp_range in oxide_ranges[oxide].keys()
)

for oxide, comp_range in range_pairs:
    filtered = filter_data_by_compositional_range(
        train_processed, comp_range, oxide, oxide_ranges
    )
    print(f"{oxide} {comp_range}: {len(filtered)}")

SiO2 Full: 1538
SiO2 Low: 449
SiO2 Mid: 1258
SiO2 High: 615
TiO2 Full: 1538
TiO2 Low: 1349
TiO2 Mid: 403
TiO2 High: 45
Al2O3 Full: 1538
Al2O3 Low: 354
Al2O3 Mid: 1193
Al2O3 High: 255
FeOT Full: 1538
FeOT Low: 1433
FeOT Mid: 963
FeOT High: 110
MgO Full: 1538
MgO Low: 1000
MgO Mid: 1493
MgO High: 130
CaO Full: 1538
CaO Low: 1055
CaO Mid: 1418
CaO High: 40
Na2O Full: 1538
Na2O Low: 1303
Na2O High: 350
K2O Full: 1538
K2O Low: 763
K2O High: 935
