In [125]:
import logging
from pathlib import Path
from typing import Dict

# import warnings filter
from warnings import simplefilter

import mlflow
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# from config import logger
from lib.data_handling import CustomSpectralPipeline, load_split_data  # type: ignore
from lib.outlier_removal import (
    calculate_leverage_residuals,
    identify_outliers,
    plot_leverage_residuals,
)
from lib.reproduction import (
    major_oxides,
    masks,
    optimized_blending_ranges,
    oxide_ranges,
    paper_individual_sm_rmses,
    spectrometer_wavelength_ranges,
    training_info,
)
from lib.utils import custom_kfold_cross_validation, filter_data_by_compositional_range
from PLS_SM.inference import predict_composition_with_blending

# ignore all future warnings
simplefilter(action="ignore", category=FutureWarning)

env = dotenv_values()
comp_data_loc = env.get("COMPOSITION_DATA_PATH")
dataset_loc = env.get("DATA_PATH")

if not comp_data_loc:
    print("Please set COMPOSITION_DATA_PATH in .env file")
    exit(1)

if not dataset_loc:
    print("Please set DATA_PATH in .env file")
    exit(1)

logger = logging.getLogger("train")

mlflow.set_tracking_uri("http://localhost:5000")

preformatted_data_path = Path("./data/_preformatted_sm/")
train_path = preformatted_data_path / "train.csv"
test_path = preformatted_data_path / "test.csv"

train_n1_path = preformatted_data_path / "train_n1.csv"
train_n3_path = preformatted_data_path / "train_n3.csv"
test_n1_path = preformatted_data_path / "test_n1.csv"
test_n3_path = preformatted_data_path / "test_n3.csv"

if (
    not preformatted_data_path.exists()
    or not train_path.exists()
    or not test_path.exists()
):
    take_samples = None

    logger.info("Loading data from location: %s", dataset_loc)
    # data = load_data(str(dataset_loc))
    train_data, test_data = load_split_data(
        str(dataset_loc), split_loc="./train_test_split.csv", average_shots=True
    )
    logger.info("Data loaded successfully.")

    logger.info("Initializing CustomSpectralPipeline.")
    pipeline = CustomSpectralPipeline(
        masks=masks,
        composition_data_loc=comp_data_loc,
        major_oxides=major_oxides,
    )
    logger.info("Pipeline initialized. Fitting and transforming data.")
    train_processed = pipeline.fit_transform(train_data)
    test_processed = pipeline.fit_transform(test_data)
    logger.info("Data processing complete.")

    preformatted_data_path.mkdir(parents=True, exist_ok=True)

    train_processed.to_csv(train_path, index=False)
    test_processed.to_csv(test_path, index=False)
else:
    logger.info("Loading preformatted data from location: %s", preformatted_data_path)
    train_processed = pd.read_csv(train_path)
    test_processed = pd.read_csv(test_path)

In [126]:
from sklearn.base import BaseEstimator, TransformerMixin
import enum
from typing import Dict, Tuple

class Norm3Scaler(BaseEstimator, TransformerMixin):
    def __init__(
        self, wavelength_ranges: Dict[str, Tuple[float, float]], reshaped=False
    ):
        self.scaler = (
            Norm3ScalerReshapedData(wavelength_ranges)
            if reshaped
            else Norm3ScalerOriginalData(wavelength_ranges)
        )

    def fit(self, df):
        return self.scaler.fit(df)

    def transform(self, df):
        return self.scaler.transform(df)


class Norm3ScalerOriginalData(BaseEstimator, TransformerMixin):
    def __init__(self, wavelength_ranges: Dict[str, Tuple[float, float]]):
        self.wavelength_ranges = wavelength_ranges
        self.totals = None

    def fit(self, df):
        """
        Compute the total intensity for each spectrometer range.
        """
        self.totals = {}
        shot_columns = df.columns[df.columns.str.startswith("shot")]
        for key, (start, end) in self.wavelength_ranges.items():
            mask = (df["wave"] >= start) & (df["wave"] <= end)
            self.totals[key] = df.loc[mask, shot_columns].sum().sum()
        return self

    def transform(self, df):
        """
        Apply norm3 normalization to the DataFrame.
        """
        if self.totals is None:
            raise ValueError("The fit method must be called before transform.")

        shot_columns = df.columns[df.columns.str.startswith("shot")]
        for key, (start, end) in self.wavelength_ranges.items():
            mask = (df["wave"] >= start) & (df["wave"] <= end)
            df.loc[mask, shot_columns] = df.loc[mask, shot_columns].div(
                self.totals[key], axis=1
            )
        return df


class Norm3ScalerReshapedData(BaseEstimator, TransformerMixin):
    """
    This class is used to normalize the data in the same way as the
    Norm3Scaler class, but it is used for the reshaped data. This is
    necessary because the reshaped data has a different format than
    the original data.

    The reshaped data has the following format:
    - Each row represents a single shot
    - Each column represents a single wavelength
    - The column names are the wavelengths
    """

    def __init__(self, wavelength_ranges: Dict[str, Tuple[float, float]]):
        self.wavelength_ranges = wavelength_ranges
        self.totals = None

    def fit(self, df):
        """
        Compute the total intensity for each spectrometer range.
        """
        self.totals = {}

        # Convert column names to floats. If conversion fails, assign NaN
        float_cols = pd.to_numeric(df.columns, errors='coerce')

        for key, (start, end) in self.wavelength_ranges.items():
            # Use boolean indexing to select columns in the specified range
            selected_columns = df.columns[(float_cols >= start) & (float_cols <= end)]

            # Compute the sum of intensities in these columns
            self.totals[key] = df[selected_columns].sum().sum()
        
        assert len(self.totals) == 3, "Expected 3 spectrometer ranges"
        print(self.totals)
        sum_of_totals = sum(self.totals.values())
        print(sum_of_totals)
        return self

    def transform(self, df):
        """
        Apply norm3 normalization to the DataFrame.
        """
        if self.totals is None:
            raise ValueError("The fit method must be called before transform.")

        for key, (start, end) in self.wavelength_ranges.items():
            # Select columns in the specified range and ignore non-float columns
            selected_columns = []
            for col in df.columns:
                try:
                    if start <= float(col) <= end:
                        selected_columns.append(col)
                except ValueError:
                    # Ignore columns that cannot be converted to float
                    continue

            # Normalize intensities in these columns
            df[selected_columns] = df[selected_columns].div(self.totals[key], axis=0)
        return df


drop_cols = major_oxides + ["Sample Name", "ID"]

In [127]:
train_processed.drop(columns=drop_cols).T.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1528,1529,1530,1531,1532,1533,1534,1535,1536,1537
count,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,...,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0
mean,91277520000.0,91844840000.0,101847100000.0,104285900000.0,103772600000.0,82290760000.0,84370580000.0,80619010000.0,79861650000.0,84488460000.0,...,72386010000.0,84968030000.0,79351800000.0,74604560000.0,73490040000.0,111477400000.0,119116500000.0,120181200000.0,117146500000.0,107845100000.0
std,224676100000.0,229504700000.0,256740100000.0,249957900000.0,253480700000.0,191430200000.0,205488600000.0,195422900000.0,192445700000.0,205145700000.0,...,156950000000.0,187041700000.0,170207300000.0,161900900000.0,161647200000.0,253393700000.0,245284400000.0,264735200000.0,253777400000.0,241301300000.0
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,7512402000.0,7645420000.0,8049826000.0,8411631000.0,8322361000.0,8571419000.0,8231139000.0,7810937000.0,7749104000.0,8166487000.0,...,8322698000.0,9212789000.0,8629386000.0,8210420000.0,8484050000.0,11251100000.0,11936160000.0,11004460000.0,11089860000.0,10910290000.0
50%,28409750000.0,29281020000.0,31391230000.0,33182970000.0,32606930000.0,29883500000.0,28664480000.0,27611840000.0,26618560000.0,28956920000.0,...,27561520000.0,30935310000.0,29478620000.0,27460750000.0,28120130000.0,38147360000.0,40728480000.0,37410680000.0,37999300000.0,35877060000.0
75%,85653430000.0,86693890000.0,93546470000.0,97938590000.0,98698730000.0,80802990000.0,80447000000.0,78202020000.0,78096040000.0,81286860000.0,...,73919030000.0,82366540000.0,78878920000.0,74654250000.0,74411980000.0,109849000000.0,113445900000.0,110004700000.0,109361500000.0,106624800000.0
max,4265688000000.0,4592867000000.0,5009349000000.0,4616491000000.0,4809581000000.0,2832522000000.0,3141137000000.0,2822526000000.0,2953743000000.0,2980790000000.0,...,2513769000000.0,2527600000000.0,2305864000000.0,2346277000000.0,2505947000000.0,4375193000000.0,3105788000000.0,3561074000000.0,3592492000000.0,4133483000000.0


In [128]:
train_processed.drop(columns=drop_cols).T.sum().sum()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1528,1529,1530,1531,1532,1533,1534,1535,1536,1537
246.688,181627800000.0,200483900000.0,210447400000.0,252546700000.0,217290700000.0,66119040000.0,56355740000.0,60970750000.0,81868120000.0,61177120000.0,...,79050390000.0,138022900000.0,134586100000.0,108071000000.0,69427780000.0,236602900000.0,533769500000.0,652755000000.0,474004000000.0,287651800000.0
246.741,162356700000.0,179342100000.0,188949800000.0,227186900000.0,196582800000.0,62236820000.0,50571590000.0,55470340000.0,73567380000.0,56290540000.0,...,71832320000.0,120862600000.0,118665600000.0,95284980000.0,63341940000.0,211775800000.0,489513800000.0,596030600000.0,440249200000.0,259989200000.0
246.79401,137014300000.0,152217400000.0,160787800000.0,193114900000.0,169036700000.0,57228660000.0,43815290000.0,48678990000.0,62416130000.0,50493290000.0,...,62810940000.0,98472780000.0,98737380000.0,79196170000.0,54842810000.0,174610100000.0,429474600000.0,518891400000.0,390208500000.0,220368800000.0
246.847,111209300000.0,125571200000.0,131232400000.0,156445400000.0,139365200000.0,52703960000.0,38334120000.0,43829650000.0,51545960000.0,44628750000.0,...,53704870000.0,76711480000.0,79103220000.0,64771340000.0,46508530000.0,134548500000.0,355757000000.0,422272800000.0,327585800000.0,174430600000.0
246.89999,99880300000.0,113342800000.0,115454300000.0,137305400000.0,125594800000.0,52131360000.0,37869520000.0,43072150000.0,47598420000.0,44289440000.0,...,50367120000.0,68368750000.0,71918560000.0,59958560000.0,43176720000.0,119445200000.0,324007800000.0,385175100000.0,299039300000.0,155413500000.0


In [129]:
scaler = Norm3Scaler(wavelength_ranges=spectrometer_wavelength_ranges, reshaped=True)
train_processed = scaler.fit_transform(train_processed)

{'UV': 6.176395627352914e+17, 'VIO': 2.6972361404241165e+17, 'VNIR': 5.452579200577558e+16}
9.418889687834787e+17


In [130]:
train_processed.head()

Unnamed: 0,246.688,246.741,246.79401,246.847,246.89999,246.953,247.007,247.06,247.11301,247.166,...,SiO2,TiO2,Al2O3,FeOT,MgO,CaO,Na2O,K2O,Sample Name,ID
0,2.940677e-07,2.628665e-07,2.218354e-07,1.800553e-07,1.617129e-07,1.721856e-07,1.985312e-07,2.240468e-07,2.363474e-07,2.364594e-07,...,79.35,0.3,9.95,2.18,1.0,1.2,2.75,1.84,201426,201426_2013_11_06_161336_ccs
1,3.245969e-07,2.903669e-07,2.464502e-07,2.033082e-07,1.835096e-07,1.933978e-07,2.212464e-07,2.489649e-07,2.61788e-07,2.600798e-07,...,79.35,0.3,9.95,2.18,1.0,1.2,2.75,1.84,201426,201426_2013_11_06_161134_ccs
2,3.407285e-07,3.059225e-07,2.603263e-07,2.124742e-07,1.869284e-07,1.941817e-07,2.236271e-07,2.5417e-07,2.677841e-07,2.649219e-07,...,79.35,0.3,9.95,2.18,1.0,1.2,2.75,1.84,201426,201426_2013_11_06_162544_ccs
3,4.0889e-07,3.678309e-07,3.126661e-07,2.532956e-07,2.223067e-07,2.316652e-07,2.680076e-07,3.05278e-07,3.215675e-07,3.166276e-07,...,79.35,0.3,9.95,2.18,1.0,1.2,2.75,1.84,201426,201426_2013_11_06_161514_ccs
4,3.518082e-07,3.182807e-07,2.736818e-07,2.256416e-07,2.033465e-07,2.133621e-07,2.437775e-07,2.755768e-07,2.89062e-07,2.842021e-07,...,79.35,0.3,9.95,2.18,1.0,1.2,2.75,1.84,201426,201426_2013_11_06_160941_ccs


In [131]:
train_processed = train_processed.drop(columns=drop_cols).T

In [132]:
train_processed.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1528,1529,1530,1531,1532,1533,1534,1535,1536,1537
246.688,2.940677e-07,3.245969e-07,3.407285e-07,4.0889e-07,3.518082e-07,1.070512e-07,9.124374e-08,9.871574e-08,1.3255e-07,9.904988e-08,...,1.279879e-07,2.234684e-07,2.17904e-07,1.749742e-07,1.124082e-07,3.830759e-07,8.642088e-07,1.056854e-06,7.674444e-07,4.657276e-07
246.741,2.628665e-07,2.903669e-07,3.059225e-07,3.678309e-07,3.182807e-07,1.007656e-07,8.187881e-08,8.981021e-08,1.191105e-07,9.113817e-08,...,1.163014e-07,1.956847e-07,1.921276e-07,1.542728e-07,1.025549e-07,3.428793e-07,7.925558e-07,9.650137e-07,7.12793e-07,4.2094e-07
246.79401,2.218354e-07,2.464502e-07,2.603263e-07,3.126661e-07,2.736818e-07,9.265705e-08,7.09399e-08,7.881457e-08,1.010559e-07,8.175204e-08,...,1.016951e-07,1.594341e-07,1.598625e-07,1.282239e-07,8.87942e-08,2.827055e-07,6.953482e-07,8.4012e-07,6.317739e-07,3.56792e-07
246.847,1.800553e-07,2.033082e-07,2.124742e-07,2.532956e-07,2.256416e-07,8.533125e-08,6.206551e-08,7.096315e-08,8.345638e-08,7.225694e-08,...,8.695179e-08,1.242011e-07,1.280734e-07,1.048691e-07,7.530043e-08,2.178431e-07,5.759945e-07,6.836881e-07,5.303834e-07,2.824148e-07
246.89999,1.617129e-07,1.835096e-07,1.869284e-07,2.223067e-07,2.033465e-07,8.440418e-08,6.131329e-08,6.97367e-08,7.706504e-08,7.170758e-08,...,8.154775e-08,1.106936e-07,1.16441e-07,9.707695e-08,6.990602e-08,1.933899e-07,5.245905e-07,6.236243e-07,4.841648e-07,2.516249e-07


In [133]:
train_processed.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1528,1529,1530,1531,1532,1533,1534,1535,1536,1537
count,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,...,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0,5485.0
mean,3972716000.0,3914676000.0,4317048000.0,4458777000.0,4566679000.0,4136896000.0,3847693000.0,3549741000.0,3472842000.0,3742581000.0,...,3420914000.0,3747758000.0,3586045000.0,3415412000.0,3483498000.0,6134766000.0,6680944000.0,5490599000.0,6750890000.0,5648024000.0
std,21852850000.0,21428860000.0,24850190000.0,25556270000.0,25261240000.0,21583750000.0,20722950000.0,18659960000.0,18168740000.0,19803330000.0,...,17978250000.0,19672170000.0,18774670000.0,17921340000.0,18464400000.0,34051160000.0,38033780000.0,28566130000.0,38581370000.0,30040210000.0
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,4.892421e-08,4.88878e-08,5.379712e-08,5.568947e-08,5.517769e-08,5.201454e-08,4.879736e-08,4.66318e-08,4.494762e-08,4.913957e-08,...,4.774892e-08,5.464539e-08,5.155393e-08,4.827826e-08,4.941489e-08,6.589037e-08,7.250331e-08,6.630114e-08,6.606661e-08,6.402867e-08
50%,1.321549e-07,1.34438e-07,1.435041e-07,1.531491e-07,1.493125e-07,1.409196e-07,1.322831e-07,1.265039e-07,1.266605e-07,1.31166e-07,...,1.242498e-07,1.389097e-07,1.316795e-07,1.251013e-07,1.278612e-07,1.848263e-07,1.904109e-07,1.829485e-07,1.820597e-07,1.75043e-07
75%,3.604905e-07,3.597304e-07,3.960229e-07,4.062325e-07,4.078397e-07,3.456523e-07,3.55101e-07,3.358881e-07,3.335302e-07,3.517755e-07,...,3.180134e-07,3.567809e-07,3.383934e-07,3.146635e-07,3.294619e-07,4.755008e-07,5.017162e-07,4.712282e-07,4.972831e-07,4.562919e-07
max,372907200000.0,363347700000.0,457248400000.0,471351600000.0,437117200000.0,248666000000.0,319740600000.0,206418100000.0,219423200000.0,254360900000.0,...,170103400000.0,248313000000.0,185367600000.0,189049100000.0,178647400000.0,593902800000.0,703179800000.0,364830400000.0,734397100000.0,449807200000.0


In [134]:
train_processed.sum().sum()

4.62967405368569e+16