# Initial Data Visualization for Black-Box Functions

This notebook loads the provided initial design points for functions 1-8 and visualises the relationship between inputs and outputs. The function descriptions come from the companion notebook `Summry_&_Initial_Analysis_of_Black_Box_Functions.ipynb`.

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid", context="talk", palette="viridis")
%matplotlib inline

In [None]:
# Root folder of the initial design points (one `function_<id>` sub-folder per function)
DATA_DIR = Path("../Data/InitialData")

# Catalogue keyed by function id; each entry carries a short description and the
# number of input dimensions. Rows below are (function id, input_dim, description).
FUNCTION_METADATA = {
    function_id: {"description": description, "input_dim": input_dim}
    for function_id, input_dim, description in (
        (1, 2, "Sparse 2D contamination sources with mostly zero readings."),
        (2, 2, "Noisy 2D landscape with local optima."),
        (3, 3, "3D drug formulation experiments with noisy adverse reaction counts."),
        (4, 4, "Dynamic 4D warehouse placement surrogate model with local optima."),
        (5, 4, "Unimodal 4D chemical yield optimisation."),
        (6, 5, "5D cake recipe scoring where outputs are negative by design."),
        (7, 6, "6D hyperparameter tuning with expensive evaluations."),
        (8, 8, "8D black-box optimisation aiming for strong local maxima."),
    )
}


In [None]:
def load_function_data(function_id: int, data_dir: "Path | str | None" = None) -> pd.DataFrame:
    """Load inputs/outputs for a specific function into a DataFrame.

    Parameters
    ----------
    function_id : int
        Identifier used to build the ``function_{function_id}`` sub-folder name.
    data_dir : Path or str, optional
        Root folder that holds the per-function sub-folders. When omitted, the
        notebook-level ``DATA_DIR`` is used, so existing calls behave as before.

    Returns
    -------
    pd.DataFrame
        One row per loaded point: input columns ``x1`` ... ``xd`` followed by
        an ``output`` column.

    Raises
    ------
    FileNotFoundError
        If ``initial_inputs.npy`` or ``initial_outputs.npy`` is missing.
    ValueError
        If the inputs array is not 2-D, or the number of outputs differs from
        the number of input rows.
    """
    root = DATA_DIR if data_dir is None else Path(data_dir)
    func_dir = root / f"function_{function_id}"
    inputs = np.load(func_dir / "initial_inputs.npy")
    # Flatten so that (n,) and (n, 1) shaped outputs are handled the same way.
    outputs = np.load(func_dir / "initial_outputs.npy").reshape(-1)

    # Fail early with a readable message instead of an IndexError on shape[1].
    if inputs.ndim != 2:
        raise ValueError(
            f"function {function_id}: expected 2-D inputs (points x dims), got shape {inputs.shape}"
        )
    if outputs.shape[0] != inputs.shape[0]:
        raise ValueError(
            f"function {function_id}: {inputs.shape[0]} input rows but {outputs.shape[0]} outputs"
        )

    input_cols = [f"x{i + 1}" for i in range(inputs.shape[1])]
    df = pd.DataFrame(inputs, columns=input_cols)
    df["output"] = outputs
    return df


In [None]:
# Collect one summary record per function, then build the table in a single step
summary_rows = []
for fid, meta in FUNCTION_METADATA.items():
    df = load_function_data(fid)
    output_values = df["output"]
    summary_rows.append(
        dict(
            function=fid,
            input_dim=meta["input_dim"],
            num_points=df.shape[0],
            output_min=output_values.min(),
            output_max=output_values.max(),
            output_mean=output_values.mean(),
        )
    )

# Index the table by function id so each row reads as "function N"
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.set_index("function")
display(summary_df)


## Helper plotting utilities

The plots aim to highlight structure in each dataset:

- **2D functions (1-2):** scatter plots of inputs coloured by the output, revealing local maxima or sparse responses.
- **3D function (3):** pairwise scatter matrix with a hue based on output quantiles to show how outputs vary across the design space.
- **4D functions (4-5):** input correlation heatmap to reveal dependencies, alongside a scatter plot of the output against the three most variable inputs.
- **5D+ functions (6-8):** input correlation heatmap alongside scatter plots for the three most variable inputs against the output, surfacing influential dimensions.


In [None]:
def add_quantile_band(df: pd.DataFrame) -> pd.DataFrame:
    """Attach a categorical quantile label for colouring plots.

    The ``output`` column is split into four quantile bands labelled
    ``Q1 (low)``, ``Q2``, ``Q3`` and ``Q4 (high)``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame containing an ``output`` column. It is not modified.

    Returns
    -------
    pd.DataFrame
        A new frame with an extra categorical ``output_band`` column.

    Notes
    -----
    When many outputs share the same value, the quartile edges coincide and
    ``pd.qcut`` raises ``ValueError``. In that case the bands are computed on
    the ranks of the outputs instead (ties broken by order of appearance), so
    tied values may land in neighbouring bands but the call always yields four
    labelled bands. Inputs for which the direct split succeeds are unaffected.
    """
    band_labels = ["Q1 (low)", "Q2", "Q3", "Q4 (high)"]
    try:
        quantiles = pd.qcut(df["output"], q=4, labels=band_labels)
    except ValueError:
        # Duplicate bin edges: ranking makes every value distinct while
        # preserving order, so the quartile edges become unique.
        ranked_output = df["output"].rank(method="first")
        quantiles = pd.qcut(ranked_output, q=4, labels=band_labels)
    return df.assign(output_band=quantiles)


def plot_function(df: pd.DataFrame, function_id: int, meta: dict):
    """Draw the standard figure for one function, chosen by its input count.

    Input columns are the columns of ``df`` whose name starts with ``"x"``.

    - Two inputs: scatter of the first input against the second, coloured by
      ``output`` with a colour bar.
    - Three inputs: seaborn corner pair plot of the inputs, hued by the
      output quantile band from ``add_quantile_band``.
    - Any other count: input correlation heatmap next to a scatter of
      ``output`` against the (up to) three highest-variance inputs.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with the input columns and an ``output`` column.
    function_id : int
        Used only in the figure titles.
    meta : dict
        Accepted for call-site symmetry; not read by this function.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``.
    """
    feature_names = [name for name in df.columns if name.startswith("x")]
    n_inputs = len(feature_names)

    if n_inputs == 2:
        x_name, y_name = feature_names
        fig, ax = plt.subplots(figsize=(8, 6))
        points = ax.scatter(
            df[x_name],
            df[y_name],
            c=df["output"],
            cmap="viridis",
            s=80,
            edgecolor="black",
        )
        ax.set_title(f"Function {function_id}: 2D response landscape")
        ax.set_xlabel(x_name)
        ax.set_ylabel(y_name)
        fig.colorbar(points, ax=ax).set_label("Output")
        plt.show()
    elif n_inputs == 3:
        grid = sns.pairplot(
            add_quantile_band(df),
            vars=feature_names,
            hue="output_band",
            diag_kind="hist",
            corner=True,
        )
        # y slightly above 1 keeps the title clear of the top row of panels
        grid.fig.suptitle(
            f"Function {function_id}: pairwise relationships with output quantiles",
            y=1.02,
        )
        plt.show()
    else:
        # Every remaining input count gets the two-panel layout
        fig, (heatmap_ax, scatter_ax) = plt.subplots(1, 2, figsize=(16, 6))

        # Left panel: correlation between the input columns
        sns.heatmap(
            df[feature_names].corr(),
            ax=heatmap_ax,
            vmin=-1,
            vmax=1,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
        )
        heatmap_ax.set_title(f"Function {function_id}: input correlation")

        # Right panel: output against the inputs with the largest variance (at most three)
        variance_ranking = df[feature_names].var().sort_values(ascending=False)
        top_inputs = variance_ranking.index[:3]
        long_form = df.melt(
            id_vars="output",
            value_vars=top_inputs,
            var_name="input",
            value_name="value",
        )
        sns.scatterplot(data=long_form, x="value", y="output", hue="input", ax=scatter_ax)
        scatter_ax.set_title(f"Function {function_id}: output vs key inputs")
        scatter_ax.set_xlabel("Input value")
        scatter_ax.set_ylabel("Output")

        plt.tight_layout()
        plt.show()


## Visualisations per function

Run the cell below to generate visuals for each function in sequence. Adjust the plotting logic above if you want different insights per function.


In [None]:
# For each catalogued function: show its id and description, then draw its figure
for fid, meta in FUNCTION_METADATA.items():
    display(dict(function=fid, description=meta["description"]))
    df = load_function_data(fid)
    plot_function(df, function_id=fid, meta=meta)
