In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [8]:
from lib.full_flow_dataloader import load_full_flow_data

# Fetch both partitions from the project loader (presumably pandas DataFrames
# sharing the same columns -- confirm against lib.full_flow_dataloader), then
# stack them row-wise. ignore_index=True discards the per-partition indices so
# the combined object gets a fresh 0..n-1 index.
train, test = load_full_flow_data()
full_data = pd.concat(
    objs=[train, test],
    axis=0,
    ignore_index=True,
)

In [None]:
# Build a compact working frame for the fold-assignment walkthrough: keep only
# the sample identifier, the grouping key and the regression target, renamed to
# generic column names.
full_data_relevant = full_data[['ID', 'Sample Name', 'SiO2']].copy()
full_data_relevant.columns = ['id', 'group_by', 'target']

# Simplify the IDs for visualization purposes (plain 0..n-1 positions).
full_data_relevant['id'] = range(len(full_data_relevant))
# No fold has been assigned at this stage.
full_data_relevant['fold'] = np.nan

# Parameters
percentile = 0.05   # fraction of the target distribution treated as "extreme" in each tail
n_folds = 5         # total number of folds assigned round-robin
random_seed = 42    # fixed seed so Restart-and-Run-All reproduces the same fold assignment

# Step 1: Keep one row per group and sort by target, so that consecutive rows
# have neighbouring target values. IDs are re-issued to follow the sorted order.
sorted_full_data = (
    full_data_relevant
    .drop_duplicates(subset=['group_by'])
    .sort_values(by='target')
    .reset_index(drop=True)
)
sorted_full_data['id'] = range(len(sorted_full_data))

# Step 2: Assign fold numbers sequentially (round-robin over the sorted rows),
# which spreads the target range evenly across the folds.
sorted_full_data['fold'] = np.arange(len(sorted_full_data)) % n_folds

# Step 3: Identify extremes -- rows at or beyond the lower/upper percentile.
quantiles_full = sorted_full_data['target'].quantile([percentile, 1 - percentile])
lower_cut, upper_cut = quantiles_full.iloc[0], quantiles_full.iloc[-1]
extremes_full = sorted_full_data[
    (sorted_full_data['target'] <= lower_cut) | (sorted_full_data['target'] >= upper_cut)
]

# Step 4: Re-assign extreme values to training folds. Folds are drawn from
# 0 .. n_folds - 2, so the last fold never receives an extreme (presumably it
# is the held-out fold -- confirm with the downstream split logic).
# A seeded Generator replaces the unseeded global np.random state so the
# assignment is reproducible across kernel restarts.
rng = np.random.default_rng(random_seed)
sorted_full_data_extreme = sorted_full_data.copy()
sorted_full_data_extreme.loc[extremes_full.index, 'fold'] = rng.integers(
    0, n_folds - 1, size=len(extremes_full)
)

# Invariant: no extreme row may remain in the last fold.
assert not sorted_full_data_extreme.loc[extremes_full.index, 'fold'].eq(n_folds - 1).any()

# Visualization of every stage of the fold assignment, including the final
# stage where extremes are moved into the training folds. The two private
# helpers below hold the formatting shared by the individual figures.
def _label_axes(ax, title, show_legend=False):
    """Apply the shared title and axis labels to ``ax``; optionally add the fold legend."""
    ax.set_title(title, fontsize=18)
    ax.set_xlabel('ID', fontsize=14)
    ax.set_ylabel('SiO2', fontsize=14)
    if show_legend:
        ax.legend(title='Fold', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)


def _draw_percentile_lines(ax, quantiles):
    """Draw dashed horizontal lines on ``ax`` at the first and last entries of ``quantiles``."""
    levels = list(quantiles.keys())
    lower_q, upper_q = levels[0], levels[-1]
    ax.axhline(y=quantiles[lower_q], color='green', linestyle='--',
               label=f'{float(lower_q) * 100}th Percentile')
    ax.axhline(y=quantiles[upper_q], color='blue', linestyle='--',
               label=f'{float(upper_q) * 100}th Percentile')


def visualize_all_steps(data, sorted_data, extremes, quantiles, sorted_data_extreme):
    """Plot each stage of the fold assignment as its own figure.

    Five figures are shown in order: the initial data, the sorted data, the
    sequential fold assignment, the extremes highlighted against the
    percentile thresholds, and one facet per fold after the extremes were
    re-assigned.

    Parameters
    ----------
    data : DataFrame
        Initial data; must provide the columns ``'id'`` and ``'target'``.
    sorted_data : DataFrame
        Sorted data with columns ``'id'``, ``'target'`` and ``'fold'``.
    extremes : DataFrame
        Rows to highlight as extremes; needs ``'id'`` and ``'target'``.
    quantiles : Series or mapping
        Threshold values keyed by quantile level. The first entry is drawn as
        the lower threshold and the last entry as the upper one. The levels
        are read from this object, so the function does not depend on any
        notebook-level ``percentile`` variable.
    sorted_data_extreme : DataFrame
        Same layout as ``sorted_data`` but with the final fold assignment.

    Returns
    -------
    None
        The figures are rendered with ``plt.show()``.
    """
    # Define a consistent style (note: this changes the global seaborn theme).
    sns.set(style="whitegrid", context="talk")

    # Step 0: Initial Data
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=data, x='id', y='target', ax=ax, color='dodgerblue', s=50)
    _label_axes(ax, 'Step 0: Initial Data')
    plt.tight_layout()
    plt.show()

    # Step 1: Sort the Data by SiO2 (fold information is deliberately not shown yet)
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=sorted_data, x='id', y='target', ax=ax, color='dodgerblue', s=50)
    _label_axes(ax, 'Step 1: Sorted Data by SiO2')
    plt.tight_layout()
    plt.show()

    # Step 2: Assign Fold Numbers Sequentially
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=sorted_data, x='id', y='target', hue='fold', palette='Set1', ax=ax, s=50)
    _label_axes(ax, 'Step 2: Assign Fold Numbers Sequentially', show_legend=True)
    plt.tight_layout()
    plt.show()

    # Step 3: Ensure Extremes in Training (Highlight Extremes and Percentiles)
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=sorted_data, x='id', y='target', hue='fold', palette='Set1', ax=ax, s=50)
    # Hollow red rings mark the extreme rows on top of the fold-coloured points.
    sns.scatterplot(data=extremes, x='id', y='target', edgecolor='red', facecolor='none',
                    linewidth=2, s=100, ax=ax, legend=False)
    _draw_percentile_lines(ax, quantiles)
    _label_axes(ax, 'Step 3: Identify Extremes (Highlight Extremes and Percentiles)', show_legend=True)
    plt.tight_layout()
    plt.show()

    # Step 4: Assign Extremes to Training Folds
    # Colours are keyed by the fold values of ``sorted_data`` so every fold keeps
    # the colour it had in Steps 2 and 3.
    fold_levels = sorted(sorted_data['fold'].dropna().unique())
    palette = sns.color_palette('Set1', n_colors=len(fold_levels))
    fold_colors = dict(zip(fold_levels, palette))

    g = sns.FacetGrid(sorted_data_extreme, col="fold", col_wrap=3, height=5, aspect=1.5)
    # Pair each facet with the fold it was actually created for (g.col_names)
    # instead of assuming the facets are exactly folds 0..4 in order; otherwise a
    # missing fold would shift every later fold under the wrong title.
    for ax, fold in zip(g.axes.flat, g.col_names):
        subset = sorted_data_extreme[sorted_data_extreme['fold'] == fold]
        sns.scatterplot(data=subset, x='id', y='target', color=fold_colors.get(fold), s=50, ax=ax)
        _draw_percentile_lines(ax, quantiles)
    g.set_axis_labels('ID', 'SiO2')
    g.set_titles(col_template='Fold {col_name}', size=16)
    plt.tight_layout()
    plt.show()

# Render the full walkthrough: raw data, sorted data, sequential folds,
# highlighted extremes, and the per-fold view after the extremes were moved
# into the training folds.
visualize_all_steps(
    data=full_data_relevant,
    sorted_data=sorted_full_data,
    extremes=extremes_full,
    quantiles=quantiles_full,
    sorted_data_extreme=sorted_full_data_extreme,
)