In [8]:
import mlflow
from lib.reproduction import major_oxides

# Set MLFlow tracking URI if your MLFlow server is not running locally
# mlflow.set_tracking_uri('your_mlflow_tracking_uri')

experiment_id = "935520882997737302"  # Replace with your MLFlow experiment ID
submodels = ['Low', 'Mid', 'High', 'Full']

# Initialize the dictionary to store the results
results = {oxide: {submodel: [] for submodel in submodels} for oxide in major_oxides}

def fetch_runs(experiment_id):
    # Fetch all runs for the given experiment ID
    runs = mlflow.search_runs(experiment_ids=[experiment_id])
    results = {}  # Initialize the dictionary to store the results
    
    for _, row in runs.iterrows():
        run_id = row['run_id']
        run_name = row['tags.mlflow.runName']
        params = mlflow.get_run(run_id).data.params
        metrics = mlflow.get_run(run_id).data.metrics
        
        # Parse the run name to get the oxide and submodel
        parts = run_name.split('_')
        if len(parts) == 2:
            major_oxide, submodel = parts
            
            # Initialize submodel dictionary if it doesn't exist
            if major_oxide not in results:
                results[major_oxide] = {}
            if submodel not in results[major_oxide]:
                results[major_oxide][submodel] = {}

            # Fetch the metric history for 'outliers_removed'
            metric_history = mlflow.tracking.MlflowClient().get_metric_history(run_id, 'outliers_removed')
            # Sum the values of the 'outliers_removed' metric across all steps
            total_outliers_removed = sum(m.value for m in metric_history)

            # Retrieve the other metrics and the parameter 'n_spectra'
            RMSECV = metrics.get('RMSECV', None)
            RMSECV_MIN = metrics.get('RMSECV_MIN', None)
            RMSEP = metrics.get('RMSEP', None)
            n_spectra = params.get('n_spectra', None)  # n_spectra is retrieved from parameters
            outlier_removal_iterations = metrics.get('outlier_removal_iterations', None)
            
            # Store the metrics, the parameter 'n_spectra', and the total outliers removed in the results dictionary
            results[major_oxide][submodel] = {
                'RMSECV': RMSECV,
                'RMSECV_MIN': RMSECV_MIN,
                'RMSEP': RMSEP,
                'n_spectra': n_spectra,
                'outlier_removal_iterations': outlier_removal_iterations,
                'outliers_removed': total_outliers_removed
            }

    return results

# Fetch runs and print the results
results = fetch_runs(experiment_id)
results


{'K2O': {'High': {'RMSECV': 0.9814821908117807,
   'RMSECV_MIN': 0.5830669073290804,
   'RMSEP': 0.7103206230713095,
   'n_spectra': '920',
   'outlier_removal_iterations': 2.0,
   'outliers_removed': 13.0},
  'Low': {'RMSECV': 0.38695108904609343,
   'RMSECV_MIN': 0.28119934159319815,
   'RMSEP': 0.34477671142578986,
   'n_spectra': '773',
   'outlier_removal_iterations': 3.0,
   'outliers_removed': 5.0},
  'Full': {'RMSECV': 1.0264282583240156,
   'RMSECV_MIN': 0.9223400777956894,
   'RMSEP': 0.8115648681258109,
   'n_spectra': '1538',
   'outlier_removal_iterations': 8.0,
   'outliers_removed': 48.0}},
 'Na2O': {'High': {'RMSECV': 1.2739483442129398,
   'RMSECV_MIN': 0.6499749247431731,
   'RMSEP': 0.5988507690588442,
   'n_spectra': '375',
   'outlier_removal_iterations': 1.0,
   'outliers_removed': 0.0},
  'Low': {'RMSECV': 0.5738871292977707,
   'RMSECV_MIN': 0.4814070065657079,
   'RMSEP': 0.5660600283295928,
   'n_spectra': '1278',
   'outlier_removal_iterations': 3.0,
   'outl

In [27]:
# format for latex
for oxide, submodel_data in list(results.items()).__reversed__():
    print(f"{oxide} &&&&&& \\\\")
    for submodel in submodels:
        data = submodel_data.get(submodel, None)
        if data is None:
            continue
        print(f"  {submodel} & {data['RMSECV']:.2f} & {data['RMSECV_MIN']:.2f} & {data['RMSEP']:.2f} & {data['n_spectra']} & {data['outliers_removed']:.0f} & {data['outlier_removal_iterations']:.0f} \\\\")
    print("\\\\")

SiO2 &&&&&& \\
  Low & 8.55 & 6.14 & 6.57 & 439 & 0 & 2 \\
  Mid & 4.71 & 3.64 & 4.20 & 1268 & 15 & 1 \\
  High & 3.59 & 3.01 & 4.26 & 605 & 1 & 2 \\
  Full & 5.98 & 4.75 & 7.18 & 1538 & 8 & 1 \\
\\
TiO2 &&&&&& \\
  Low & 0.29 & 0.28 & 0.39 & 1359 & 8 & 1 \\
  Mid & 0.62 & 0.57 & 0.44 & 418 & 6 & 4 \\
  High & 0.74 & 0.11 & 0.09 & 40 & 0 & 2 \\
  Full & 0.48 & 0.37 & 0.50 & 1538 & 16 & 4 \\
\\
Al2O3 &&&&&& \\
  Low & 2.40 & 1.68 & 1.99 & 324 & 0 & 2 \\
  Mid & 2.27 & 1.57 & 2.04 & 1198 & 12 & 2 \\
  High & 3.97 & 1.99 & 2.03 & 240 & 0 & 2 \\
  Full & 3.31 & 2.59 & 2.43 & 1538 & 9 & 1 \\
\\
FeOT &&&&&& \\
  Low & 1.81 & 1.72 & 1.55 & 1438 & 1 & 3 \\
  Mid & 2.64 & 2.15 & 1.69 & 978 & 28 & 9 \\
  High & 4.00 & 1.52 & 11.89 & 105 & 0 & 2 \\
  Full & 2.86 & 2.67 & 4.08 & 1538 & 23 & 5 \\
\\
MgO &&&&&& \\
  Low & 0.49 & 0.45 & 0.63 & 1000 & 7 & 4 \\
  Mid & 1.32 & 0.88 & 1.16 & 1488 & 10 & 1 \\
  High & 4.42 & 1.99 & 3.14 & 135 & 0 & 1 \\
  Full & 1.74 & 1.36 & 1.25 & 1538 & 41 & 6 \\
\\
Ca