In [None]:
# MLflow experiment whose runs are loaded below; a blank answer falls back to this.
DEFAULT_EXPERIMENT_ID = "10"

# Strip surrounding whitespace so a stray space in the prompt answer does not yield
# an experiment id that matches no experiment.
experiment_id = input("Enter the experiment id: ").strip()
experiment_id = experiment_id or DEFAULT_EXPERIMENT_ID

In [None]:
import mlflow
import pandas as pd

from lib.reproduction import major_oxides

import warnings


def _run_mentions_oxide(run: pd.Series, oxide: str) -> bool:
    """Return True if ``oxide`` appears in any field name or value of a ``search_runs`` row.

    Each field is checked on its own rather than via ``str(run)``: pandas truncates
    the repr of a long Series (``display.max_rows``), so with many params/tags the
    oxide name could be elided from ``str(run)`` and the run silently missed.
    """
    return any(oxide in str(label) or oxide in str(value) for label, value in run.items())


runs = mlflow.search_runs(experiment_ids=[experiment_id])
client = mlflow.tracking.MlflowClient()

# oxide -> {'actual_vs_predicted_df': DataFrame, 'rmse': ..., 'std_dev': ...}
data = {}

for oxide in major_oxides:
    matching_run_ids = [run['run_id'] for _, run in runs.iterrows() if _run_mentions_oxide(run, oxide)]
    if not matching_run_ids:
        warnings.warn(f"No run in experiment {experiment_id} mentions {oxide}; it will not be loaded.")
        continue
    if len(matching_run_ids) > 1:
        warnings.warn(f"{len(matching_run_ids)} runs mention {oxide}; using run {matching_run_ids[-1]}.")

    # Keep the last match in search_runs order. search_runs sorts by start_time DESC
    # by default, so this is the oldest matching run -- TODO confirm that is intended.
    run_id = matching_run_ids[-1]

    # A runs:/ URI resolves against the run's own artifact root, so this does not
    # depend on the tracking server's artifact-store directory layout.
    artifact_path = f'actual_vs_predicted_{oxide}.json'
    actual_vs_predicted = mlflow.artifacts.load_dict(f"runs:/{run_id}/{artifact_path}")

    # Summary metrics logged on the run.
    metrics = client.get_run(run_id).data.metrics
    rmse = metrics['rmse']
    std_dev = metrics['std_dev']

    # The artifact holds "columns" and "data" keys; lower-case the column names so
    # they can be looked up as 'actual' / 'predicted' regardless of logged casing.
    actual_vs_predicted_df = pd.DataFrame(actual_vs_predicted["data"], columns=actual_vs_predicted["columns"])
    actual_vs_predicted_df.columns = actual_vs_predicted_df.columns.str.lower()

    data[oxide] = {
        'actual_vs_predicted_df': actual_vs_predicted_df,
        'rmse': rmse,
        'std_dev': std_dev,
    }

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

FULL_PAGE = True  # True: portrait grid with 2 columns; False: landscape grid with 4 columns


def _plot_actual_vs_predicted(ax: plt.Axes, oxide: str, oxide_data: dict) -> None:
    """Draw one oxide's test-set actual-vs-predicted panel on ``ax``.

    The panel shows the predictions, the 1:1 line, a least-squares line fitted to
    predicted-vs-actual, and an RMSE / Std Dev text box.

    Args:
        ax: Axes to draw on.
        oxide: Oxide name, used in the title.
        oxide_data: One entry of ``data``: a dict with 'actual_vs_predicted_df'
            (columns 'actual' and 'predicted'), 'rmse' and 'std_dev'.
    """
    df = oxide_data['actual_vs_predicted_df']
    # Plain numpy arrays: a pandas Series has no .reshape, which sklearn needs below.
    actual_values = df['actual'].to_numpy()
    predicted_values = df['predicted'].to_numpy()

    ax.scatter(actual_values, predicted_values, facecolors='none', edgecolors='black', s=20, label='Predictions')

    lo, hi = actual_values.min(), actual_values.max()
    ax.plot([lo, hi], [lo, hi], 'k-', label='Perfect Predictions (1:1)')

    # sklearn expects a 2-D feature matrix, hence the single-column reshape.
    model = LinearRegression().fit(actual_values.reshape(-1, 1), predicted_values)
    line_x = np.linspace(lo, hi, 100)
    ax.plot(line_x, model.predict(line_x.reshape(-1, 1)), 'r--', label='Regression Line')

    # Metrics box anchored near the bottom-right corner in axes coordinates.
    textstr = f"RMSE: {oxide_data['rmse']:.4f}\nStd Dev: {oxide_data['std_dev']:.4f}"
    props = dict(boxstyle='round', facecolor='white', alpha=0.5)
    ax.text(0.975, 0.1, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', horizontalalignment='right', bbox=props)

    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    ax.set_title(f'Test Set Predictions vs. Certificate Values for {oxide}')
    ax.legend()
    ax.axis('equal')
    ax.grid(True)


# Size the grid from the number of oxides (4x2 / 2x4 when there are 8).
n_cols = 2 if FULL_PAGE else 4
n_rows = max(1, -(-len(major_oxides) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 24) if FULL_PAGE else (24, 12), squeeze=False)
axes_flat = axes.ravel()

for ax, oxide in zip(axes_flat, major_oxides):
    if oxide in data:
        _plot_actual_vs_predicted(ax, oxide, data[oxide])
    else:
        # No run was loaded for this oxide: leave a labelled blank panel instead of failing.
        ax.set_axis_off()
        ax.set_title(f'No run found for {oxide}')

# Hide grid cells left over when the oxides do not fill the last row.
for ax in axes_flat[len(major_oxides):]:
    ax.set_axis_off()

fig.tight_layout()
plt.show()
