In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from experiment_analysis.experiment_data_dl import get_full_runs_df
from lib.config import AppConfig

# Project configuration; supplies the location of the exported optimisation results.
config = AppConfig()
results_path = config.optimization_experiment_results_path
# NOTE(review): a later cell defines its own `get_full_runs_df()` (shadowing this
# import) and rebinds `runs`, so this frame only feeds the two inspection cells
# that follow — confirm both loaders are still needed.
runs = get_full_runs_df(results_path)

In [None]:
# Quick sanity check: (number of runs, number of columns) in the frame loaded above.
runs.shape

In [None]:
# Number of runs recorded for each model type.
runs["params.model_type"].value_counts()

In [None]:
import mlflow
import pandas as pd
from pathlib import Path

from lib.config import AppConfig

import os

# Use the tracking URI from the project config for every MLflow query below.
mlflow.set_tracking_uri(AppConfig().mlflow_tracking_uri)

# MLflow experiment id (as a string) to query for each oxide.
experiments_by_oxide = {
    'SiO2': '123',
    'TiO2': '125',
    'Al2O3': '131',
    'FeO': '130',
    'MgO': '126',
    'CaO': '128',
    'Na2O': '127',
    'K2O': '129',
}

# NOTE(review): not referenced by any cell in this notebook — TODO confirm it is still needed.
n_splits = 4

# CSV cache of the concatenated MLflow runs. The default is the original author's
# local path; set the OPTUNA_RUNS_CSV environment variable to relocate it on
# other machines without editing the notebook.
runs_file_path = os.environ.get(
    "OPTUNA_RUNS_CSV", "/home/christian/projects/p9/baseline/optuna_runs.csv"
)

In [None]:
def get_runs_across_oxides():
    """Query MLflow once per oxide.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Maps each oxide key of ``experiments_by_oxide`` (in that dict's order)
        to the result of ``mlflow.search_runs`` for its experiment id.
    """
    return {
        oxide: mlflow.search_runs(experiment_ids=[experiment_id])
        for oxide, experiment_id in experiments_by_oxide.items()
    }


def get_full_runs_df(path=None):
    """Return all oxide runs as one DataFrame, using a CSV file as a cache.

    If the cache file exists it is read back; otherwise every oxide experiment
    is fetched from MLflow, concatenated, and written to the cache.

    Parameters
    ----------
    path : str or pathlib.Path, optional
        Cache file location. Defaults to the notebook-level ``runs_file_path``.

    Returns
    -------
    pandas.DataFrame
        One row per run, with the same columns whether it came from MLflow or
        from the cache.
    """
    cache_path = Path(runs_file_path if path is None else path)

    if cache_path.exists():
        cached = pd.read_csv(cache_path)
        # Older caches were written with the RangeIndex as an unnamed first
        # column; drop it so cached and freshly fetched frames match.
        return cached.drop(columns=['Unnamed: 0'], errors='ignore')

    runs = pd.concat(get_runs_across_oxides().values(), ignore_index=True)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    # index=False: the RangeIndex carries no information and would otherwise
    # reappear as an extra 'Unnamed: 0' column on the next read.
    runs.to_csv(cache_path, index=False)

    return runs

# Load every oxide's runs (from the CSV cache when it exists), replacing the
# `runs` frame loaded in the first cell.
runs = get_full_runs_df()

In [None]:
# Separate runs MLflow marked as FAILED. "Successful" below means any status
# other than FAILED.
# NOTE(review): this rebinds `runs`, so re-running the cell on its own reports
# 0 failed runs — reload `runs` first.
is_failed = runs['status'] == 'FAILED'
failed_runs = runs[is_failed]
runs = runs[~is_failed]

print(f"Failed runs: {len(failed_runs)}")
print(f"Successful runs: {len(runs)}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Fail fast if the runs frame lacks any column the plots below depend on.
required_columns = [
    'params.oxide',
    'metrics.rmse_cv',
    'params.model_type',
    'params.transformer_type',
    'params.pca_type',
    'params.scaler_type',
]
absent_columns = [name for name in required_columns if name not in runs.columns]
if absent_columns:
    raise ValueError(f"Missing required column: {absent_columns[0]}")

# Keep only trials with rmse_cv <= 50 (rows with NaN rmse_cv are dropped too).
runs_filtered = runs.loc[runs['metrics.rmse_cv'] <= 50]

def create_visualizations(group_by_col):
    """Scatter ``metrics.rmse_cv`` against ``group_by_col``, one figure per oxide.

    Reads the notebook-level ``runs_filtered`` frame. Points are coloured by
    transformer type, marker-shaped by PCA type and sized by scaler type.

    Parameters
    ----------
    group_by_col : str
        Column plotted on the x axis, e.g. ``'params.model_type'``.
    """
    x_label = group_by_col.replace('params.', '').replace('_', ' ').title()

    for oxide in runs_filtered['params.oxide'].unique():
        oxide_runs = runs_filtered.loc[runs_filtered['params.oxide'] == oxide]

        fig, ax = plt.subplots(figsize=(16, 10))
        sns.scatterplot(
            data=oxide_runs,
            x=group_by_col,
            y='metrics.rmse_cv',
            hue='params.transformer_type',
            style='params.pca_type',
            size='params.scaler_type',
            sizes=(20, 200),
            ax=ax,
        )
        ax.set_title(f'RMSE CV Performance for {oxide} grouped by {group_by_col}')
        ax.set_xlabel(x_label)
        ax.set_ylabel('RMSE CV')
        # Anchor the legend outside the axes so it does not hide any points.
        ax.legend(
            title='Transformer Type / PCA Type / Scaler Type',
            bbox_to_anchor=(1.05, 1),
            loc='upper left',
        )
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        plt.show()

# One scatter figure per oxide, with model type on the x axis.
create_visualizations('params.model_type')


In [None]:
def create_box_plots(df, params=None):
    """Draw one figure per oxide summarising ``metrics.rmse_cv`` per hyper-parameter.

    Each figure has one row per hyper-parameter column. The left panel is a box
    plot of ``metrics.rmse_cv`` grouped by that column. The right panel is a
    table of the per-level mean/std/min/max, rounded to 2 decimals.

    Parameters
    ----------
    df : pandas.DataFrame
        Runs with ``params.oxide``, ``metrics.rmse_cv`` and every column in
        ``params``.
    params : sequence of str, optional
        Hyper-parameter columns to plot, one figure row each. Defaults to the
        model, transformer, PCA and scaler type columns.
    """
    titles = {
        'params.model_type': 'Model Type',
        'params.transformer_type': 'Transformer Type',
        'params.pca_type': 'PCA Type',
        'params.scaler_type': 'Scaler Type',
    }
    if params is None:
        params = list(titles)

    for oxide, group in df.groupby('params.oxide'):
        # 16 wide and 6 tall per row (16 x 24 for the default four parameters).
        fig, axes = plt.subplots(len(params), 2, figsize=(16, 6 * len(params)), squeeze=False)
        fig.suptitle(f'Distribution of RMSE for {oxide}', fontsize=16)

        for (box_ax, table_ax), param in zip(axes, params):
            _draw_param_row(group, param, titles.get(param, param), box_ax, table_ax)

        fig.tight_layout(rect=[0, 0.03, 1, 0.95], w_pad=0, h_pad=0)
        plt.show()
        plt.close(fig)


def _draw_param_row(group, param, title, box_ax, table_ax):
    """Draw the box plot and the summary table for one hyper-parameter column."""
    sns.boxplot(x=param, y='metrics.rmse_cv', data=group, ax=box_ax)
    box_ax.set_title(title)
    box_ax.tick_params(axis='x', labelrotation=45)

    table_ax.axis('off')
    summary = (
        group.groupby(param)['metrics.rmse_cv']
        .agg(['mean', 'std', 'min', 'max'])
        .reset_index()
        .round(2)
    )
    # A table cannot be built from an empty cellText (e.g. every value of
    # `param` is NaN for this oxide), so leave the panel blank in that case.
    if not summary.empty:
        table_ax.table(
            cellText=summary.values,
            colLabels=summary.columns,
            cellLoc='center',
            loc='center',
        )

# One box-plot figure per oxide, using the runs with rmse_cv <= 50.
create_box_plots(runs_filtered)

In [None]:
import pandas as pd
import scipy.stats as stats

def analyze_correlations(df):
    """Kruskal-Wallis H test of ``metrics.rmse_cv`` across each hyper-parameter's levels, per oxide.

    For every oxide, each hyper-parameter column is tested separately. The one
    with the smallest p-value is reported: its levels differ most significantly
    in RMSE CV. This makes it the most *influential* parameter, not necessarily
    the best configuration.

    Parameters
    ----------
    df : pandas.DataFrame
        Runs with ``params.oxide``, ``metrics.rmse_cv`` and the model,
        transformer, PCA and scaler type columns.

    Returns
    -------
    pandas.DataFrame
        One row per oxide with columns ``oxide``, ``best_param`` and ``p_value``.
        ``best_param`` is None and ``p_value`` is NaN when no parameter could be
        tested for that oxide. This happens when every parameter has fewer than
        two non-NaN levels, all values are identical, or the test yields NaN.
    """
    params = ['params.model_type', 'params.transformer_type', 'params.pca_type', 'params.scaler_type']
    results = []

    for oxide, group in df.groupby('params.oxide'):
        kw_results = {}
        for param in params:
            # groupby skips NaN levels. Comparing against a NaN level with `==`
            # would produce an empty sample.
            samples = [values.to_numpy() for _, values in group.groupby(param)['metrics.rmse_cv']]
            if len(samples) < 2:  # Kruskal-Wallis needs at least two groups
                continue
            try:
                _, p_value = stats.kruskal(*samples)
            except ValueError:
                # scipy raises when every observation is identical.
                continue
            if not pd.isna(p_value):  # e.g. NaN rmse values propagate to NaN
                kw_results[param] = p_value

        if kw_results:
            best_param = min(kw_results, key=kw_results.get)
            best_p_value = kw_results[best_param]
        else:
            best_param, best_p_value = None, float('nan')

        results.append({
            'oxide': oxide,
            'best_param': best_param,
            'p_value': best_p_value,
        })

    # Explicit columns keep the schema stable even when `df` is empty.
    return pd.DataFrame(results, columns=['oxide', 'best_param', 'p_value'])

# Kruskal-Wallis summary per oxide, computed on the runs with rmse_cv <= 50.
results_df = analyze_correlations(runs_filtered)


In [None]:
# Per oxide: the hyper-parameter with the smallest Kruskal-Wallis p-value.
results_df

In [None]:
# Run counts per oxide for every (transformer, model) combination.
combo_counts = (
    runs.groupby(['params.transformer_type', 'params.model_type', 'params.oxide'])
    .size()
    .reset_index(name='count')
)
pivot_table = combo_counts.pivot(
    index='params.oxide',
    columns=['params.transformer_type', 'params.model_type'],
    values='count',
)

# Append one 'total_count' column per transformer: all runs with that transformer
# for each oxide, aligned to the pivot's oxide index.
transformer_totals = (
    runs.groupby(['params.transformer_type', 'params.oxide'])
    .size()
    .reset_index(name='total_count')
)
for transformer_type in combo_counts['params.transformer_type'].drop_duplicates():
    per_oxide = transformer_totals.loc[transformer_totals['params.transformer_type'] == transformer_type]
    pivot_table[(transformer_type, 'total_count')] = per_oxide.set_index('params.oxide')['total_count']

pivot_table

In [None]:
# Runs with rmse_cv at or above this value are treated as outliers and dropped.
threshold = 15

# Keep runs that have both a transformer and a model type and a non-outlier rmse_cv.
filtered_data = (
    runs
    .dropna(subset=['params.transformer_type', 'params.model_type'])
    .loc[lambda frame: frame['metrics.rmse_cv'] < threshold]
)

group_keys = ['params.transformer_type', 'params.model_type']

# Mean rmse_cv per (transformer, model, oxide).
mean_rmse_cv = (
    filtered_data.groupby(group_keys + ['params.oxide'])['metrics.rmse_cv']
    .mean()
    .reset_index()
)

# Mean rmse_cv per (transformer, model), pooling all oxides' runs together.
overall_mean_rmse_cv = (
    filtered_data.groupby(group_keys)['metrics.rmse_cv']
    .mean()
    .reset_index()
)

mean_rmse_cv


In [None]:
# Mean rmse_cv per (transformer, model) pair, pooled across oxides.
overall_mean_rmse_cv

In [None]:
# Average of the per-model means for each transformer: every model type counts
# equally, regardless of how many runs it has.
overall_mean_rmse_cv.groupby('params.transformer_type', as_index=False)['metrics.rmse_cv'].mean()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Mean rmse_cv per (transformer, oxide), averaging the per-model means, reshaped to
# one row per oxide and one column per transformer.
per_transformer_oxide = (
    mean_rmse_cv
    .groupby(['params.transformer_type', 'params.oxide'], as_index=False)['metrics.rmse_cv']
    .mean()
)
pivot_table = per_transformer_oxide.pivot(
    index='params.oxide',
    columns='params.transformer_type',
    values='metrics.rmse_cv',
)
pivot_table

In [None]:
# Divide each oxide's row by its largest value, so colours compare transformers
# within an oxide. The annotations still show the absolute mean RMSE CV.
row_max = pivot_table.max(axis=1)
pivot_table_normalized = pivot_table.div(row_max, axis=0)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    pivot_table_normalized,
    annot=pivot_table,
    cmap='coolwarm',
    cbar_kws={'label': 'Relative Mean RMSE CV'},
    ax=ax,
)
ax.set_title('Heatmap of Mean RMSE CV by Transformer Type and Oxide (Row Normalized)')
ax.set_xlabel('Transformer Type')
ax.set_ylabel('Oxide')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One box plot per oxide: rmse_cv by transformer type, with model types side by side.
oxides = filtered_data['params.oxide'].unique()

for oxide in oxides:
    oxide_runs = filtered_data[filtered_data['params.oxide'] == oxide]

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.boxplot(
        x='params.transformer_type',
        y='metrics.rmse_cv',
        hue='params.model_type',
        data=oxide_runs,
        ax=ax,
    )
    ax.set_title(f'RMSE CV by Transformer Type and Model Type for {oxide}')
    ax.set_xlabel('Transformer Type')
    ax.set_ylabel('RMSE CV')
    ax.legend(title='Model Type')
    plt.show()


In [None]:

import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(12, 8))

# rmse_cv distribution per transformer type, pooled over all oxides, split by model type.
# NOTE(review): on seaborn < 0.13, split=True requires exactly two hue levels —
# confirm the installed version if there are more than two model types.
sns.violinplot(
    x='params.transformer_type',
    y='metrics.rmse_cv',
    hue='params.model_type',
    data=filtered_data,
    split=True,
    ax=ax,
)

ax.set_title('RMSE CV by Transformer Type and Model Type')
ax.set_xlabel('Transformer Type')
ax.set_ylabel('RMSE CV')
ax.legend(title='Model Type')
plt.show()

