In [None]:
import mlflow
from lib.config import AppConfig

# Point the MLflow client at the tracking URI taken from the application config.
mlflow.set_tracking_uri(AppConfig().mlflow_tracking_uri)

In [None]:
# MLflow experiment id recorded for each oxide; None marks an oxide that has
# no experiment yet.
experiments_by_oxide = {
    "SiO2": "47",
    "TiO2": None,
    "Al2O3": "51",
    "FeO": None,
    "MgO": "81",  # an earlier entry for MgO used '54'
    "CaO": "56",
    "Na2O": "50",
    "K2O": "55",
}

# Oxide name shown in plot titles and in the error message below.
oxide = "MgO"

# NOTE(review): the experiment id is pinned by hand instead of being looked up
# with experiments_by_oxide[oxide] (an earlier revision pinned '99'). Confirm
# that experiment '113' really belongs to `oxide`, since every title uses it.
experiment_id = "113"

# Number of cross-validation folds; used to build the metrics.rmse_cv_<i>
# column names when the per-fold results are plotted.
n_splits = 4

# Stop early when there is no experiment to analyse for the selected oxide.
if experiment_id is None:
    print(f"No experiment found for {oxide}")
    raise ValueError(f"No experiment found for {oxide}")

In [None]:
# Fetch the runs logged under the selected experiment.
runs = mlflow.search_runs(experiment_ids=[experiment_id])

# Preview the first rows of the result.
runs.head()

In [None]:
# Total number of runs retrieved for the experiment.
len(runs)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Keep only the runs whose cross-validated RMSE is at most 50. Rows where the
# metric is missing fail the comparison and are therefore dropped as well.
within_rmse_limit = runs['metrics.rmse_cv'] <= 50
filtered_runs = runs[within_rmse_limit]
total_runs = len(runs)

# Box plot of the logged RMSE per model type for the retained runs.
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=filtered_runs, x='params.model_type', y='metrics.rmse', ax=ax)
ax.set_title(f"{oxide}: RMSE for each model type - {len(filtered_runs)} runs out of {total_runs} total runs")
ax.set_xlabel("Model Type")
ax.set_ylabel("RMSEP")
plt.show()



In [None]:
# Metrics for which the best (lowest-valued) run is looked up.
metric_columns = ['metrics.rmse', 'metrics.rmse_cv', 'metrics.std_dev', 'metrics.std_dev_cv']

# idxmin() yields, per metric, the index label of the run with the smallest
# value; .loc then pulls those rows, so a run that is best on several metrics
# appears once per metric.
best_run_labels = filtered_runs[metric_columns].idxmin()
optimal_runs = filtered_runs.loc[best_run_labels]

# Show the four metrics next to the model type of each selected run.
optimal_runs[metric_columns + ['params.model_type']]


In [None]:
# Apply the seaborn whitegrid theme to the figures drawn in this cell.
sns.set(style="whitegrid")

# One box plot per cross-validation metric: (column, title, y-axis label).
cv_boxplots = [
    (
        'metrics.rmse_cv',
        f'{oxide}: Average Cross-Validation RMSE by Model Type',
        'Average RMSE (Cross-Validation)',
    ),
    (
        'metrics.std_dev_cv',
        f'{oxide}: Standard Deviation of Errors (Cross-Validation) by Model Type',
        'Standard Deviation of Errors (CV)',
    ),
]

# Each metric gets its own figure, grouped by model type.
for metric_column, plot_title, y_label in cv_boxplots:
    fig, ax = plt.subplots(figsize=(12, 7))
    sns.boxplot(data=filtered_runs, x='params.model_type', y=metric_column, ax=ax)
    ax.set_title(plot_title)
    ax.set_ylabel(y_label)
    plt.show()


In [None]:
# Per-fold RMSE columns: metrics.rmse_cv_1 .. metrics.rmse_cv_<n_splits>.
fold_columns = [f'metrics.rmse_cv_{fold}' for fold in range(1, n_splits + 1)]

# Reshape to long form: one row per (run, fold), keeping the model type as the
# identifier column so seaborn can group on it.
melted_df = filtered_runs.melt(
    id_vars=['params.model_type'],
    value_vars=fold_columns,
    var_name='CV Fold',
    value_name='Fold RMSE',
)

# Grouped box plot per model type with one box per fold; outlier markers are
# hidden via showfliers=False.
fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    data=melted_df,
    x='params.model_type',
    y='Fold RMSE',
    hue='CV Fold',
    showfliers=False,
    ax=ax,
)
ax.set_title(f'{oxide}: Distribution of RMSE Across CV Folds by Model Type')
plt.show()


In [None]:
# Columns needed for the per-parameter analysis: the CV RMSE plus the four
# params.* columns describing each run's configuration.
cv_columns = [
    'metrics.rmse_cv',
    'params.model_type',
    'params.scaler_type',
    'params.transformer_type',
    'params.pca_type',
]

# Short column names: keep only the text after the last dot, which drops the
# 'metrics.' / 'params.' prefix.
rename_dict = {column: column.rsplit('.', 1)[-1] for column in cv_columns}

# NOTE: this rebinds `filtered_runs` to a narrower frame with renamed columns,
# so earlier cells that read 'metrics.*' / 'params.*' columns from
# `filtered_runs` must be run before this cell.
filtered_runs = (
    runs.loc[runs['metrics.rmse_cv'] <= 50, cv_columns]
    .rename(columns=rename_dict)
)

In [None]:
# Apply the seaborn whitegrid theme to the figures drawn in this cell.
sns.set(style="whitegrid")

# Parameter columns analysed below (short names, i.e. without the 'params.' prefix).
param_columns = ['model_type', 'scaler_type', 'transformer_type', 'pca_type']

# Individual parameters: one bar chart of the average CV RMSE per parameter value.
for parameter in param_columns:
    plt.figure(figsize=(10, 6))
    chart = sns.barplot(x=parameter, y='rmse_cv', data=filtered_runs)
    # Rotate the tick labels seaborn already placed instead of re-creating the
    # ticks from `unique()`: that count treats a missing value as a category
    # of its own, so it may not match the categories actually drawn.
    plt.setp(chart.get_xticklabels(), rotation=45, horizontalalignment='right')
    plt.title(f'{oxide}: Average RMSE (CV) by {parameter.capitalize()}')
    plt.ylabel('Average RMSE (CV)')
    plt.show()

# Combinations of parameters.
# There can be many distinct combinations, so only the ones with the lowest
# average RMSE are reported. The metric column is selected before aggregating
# so that only 'rmse_cv' is averaged.
combination_data = (
    filtered_runs.groupby(param_columns)['rmse_cv']
    .mean()
    .reset_index()
    .sort_values(by='rmse_cv', ascending=True)
)

# Display the top 10 combinations.
print(combination_data.head(10))

# Visualise the same top 10 as a horizontal bar chart. Each label joins the
# parameter values of the combination, leaving out values equal to 'none'.
plt.figure(figsize=(14, 8))
combination_data_top10 = combination_data.head(10)
combination_labels = combination_data_top10.apply(
    lambda row: ', '.join(str(row[param]) for param in param_columns if row[param] != 'none'),
    axis=1,
)
sns.barplot(x='rmse_cv', y=combination_labels, data=combination_data_top10, orient='h')
plt.title(f'{oxide}: Top 10 Combinations for RMSE Performance')
plt.xlabel('Average RMSE (Cross-Validation)')
plt.ylabel('Combinations')
plt.show()

In [None]:
# Summarise every parameter combination by the mean and the standard deviation
# of its CV RMSE. Lower is better for both: the mean for accuracy, the standard
# deviation for consistency.
config_keys = ['model_type', 'scaler_type', 'transformer_type', 'pca_type']
rmse_summary = filtered_runs.groupby(config_keys)['rmse_cv'].agg(['mean', 'std'])

# Move the group keys back into ordinary columns and give every column a
# display name (four keys followed by the two statistics).
aggregated_data = rmse_summary.reset_index()
aggregated_data.columns = ['Model Type', 'Scaler Type', 'Transformer Type', 'PCA Type', 'Mean RMSE', 'STD RMSE']

# Rank by mean RMSE and break ties with the standard deviation, both ascending.
sorted_data = aggregated_data.sort_values(by=['Mean RMSE', 'STD RMSE'], ascending=True)

# Display the top 10 consistently good configurations.
print(sorted_data.head(10))

In [None]:
# Apply the seaborn whitegrid theme to the figures drawn in this cell.
sns.set(style="whitegrid")

# Number of best-ranked configurations (by mean RMSE) included in each chart.
top_n = 50
top_rows = sorted_data.head(top_n)

# One horizontal bar chart per parameter, built from the top-ranked configurations.
for parameter in ['Model Type', 'Scaler Type', 'Transformer Type', 'PCA Type']:
    # Create the figure inside the loop: plt.show() ends the current figure, so
    # a single figure created before the loop would only size the first chart
    # and the remaining ones would fall back to the default figure size.
    plt.figure(figsize=(12, 8))
    top_configurations = sns.barplot(x='Mean RMSE', y=parameter, hue=parameter, data=top_rows, dodge=False)
    plt.title(f'{oxide}: Top {top_n} Configurations by Mean RMSE and Their Consistency')
    plt.xlabel('Mean RMSE (Cross-Validation)')
    plt.ylabel(parameter)
    # Annotate each bar with its Mean RMSE value, just right of the bar end.
    # NOTE(review): the extra +0.2 moves the label off the bar's vertical
    # centre (get_y() + get_height() / 2) - confirm this offset is intended.
    for p in top_configurations.patches:
        width = p.get_width()
        plt.text(width + 0.01, p.get_y()+0.2 + p.get_height() / 2, f'{width:.2f}', ha='left', va='center')
    plt.show()
