In [None]:
import random
from datetime import datetime

import numpy as np
import scipy
import scipy.stats

import MLCE_CWBO2025.virtual_lab as virtual_lab

In [None]:
# ---------------- Objective wrapper ----------------
def objective_func(X):
    """Run the virtual-lab experiment for the design points in ``X``.

    The raw result of ``virtual_lab.conduct_experiment`` is copied into a
    float NumPy array so downstream code can concatenate and reduce it.
    """
    titres = virtual_lab.conduct_experiment(X)
    return np.array(titres, dtype=float)


# ---------------- Constants / bounds ----------------
# Categorical cell-type labels; each design point carries one as its 6th entry.
CELLTYPES = [f"celltype_{idx}" for idx in (1, 2, 3)]

# Box constraints of the five continuous inputs, one [lower, upper] row each.
BOUNDS_5 = np.asarray(
    [
        (30.0, 40.0),  # T
        (6.0, 8.0),    # pH
        (0.0, 50.0),   # F1
        (0.0, 50.0),   # F2
        (0.0, 50.0),   # F3
    ],
    dtype=np.float64,
)


def norm5(X):
    """Rescale raw 5-D inputs so that the box ``BOUNDS_5`` maps onto [0, 1]^5.

    ``X`` may be a single point of length 5 or an array whose last axis has
    length 5; the bounds are broadcast along the last axis.
    """
    lower = BOUNDS_5[:, 0]
    upper = BOUNDS_5[:, 1]
    return (X - lower) / (upper - lower)


def denorm5(Xn):
    """Inverse of ``norm5``: map unit-box coordinates back to raw units.

    ``Xn`` may be a single point of length 5 or an array whose last axis has
    length 5; the bounds are broadcast along the last axis.
    """
    lower = BOUNDS_5[:, 0]
    upper = BOUNDS_5[:, 1]
    return Xn * (upper - lower) + lower

In [None]:

# ---------------- GP helpers ----------------
def rbf_kernel(X1, X2, lengthscales, signal_var):
    """Squared-exponential (RBF) kernel matrix with per-dimension lengthscales.

    Parameters
    ----------
    X1 : ndarray, shape (n1, d)
    X2 : ndarray, shape (n2, d)
    lengthscales : scalar or ndarray broadcastable to (d,)
    signal_var : float
        Kernel amplitude, i.e. the value of k(x, x).

    Returns
    -------
    ndarray, shape (n1, n2)
        ``signal_var * exp(-0.5 * sum_d ((x1_d - x2_d) / l_d)^2)``.
    """
    # Pairwise differences via broadcasting: (n1, 1, d) - (1, n2, d).
    delta = X1[:, np.newaxis, :] - X2[np.newaxis, :, :]
    delta = delta / lengthscales
    sq_norm = np.square(delta).sum(axis=2)
    return signal_var * np.exp(-0.5 * sq_norm)


def gp_posterior_std(X_train, y_train, X_test, lengthscales, signal_var, noise_var):
    """
    GP posterior on standardised y.

    The targets are standardised as ``y_s = (y - y_mean) / y_std`` (with
    ``y_std`` replaced by 1.0 when the targets are numerically constant), an
    RBF-kernel GP is conditioned on ``(X_train, y_s)`` and evaluated at
    ``X_test``.

    Parameters
    ----------
    X_train : ndarray, shape (n, d)
    y_train : array-like, n values (flattened internally)
    X_test : ndarray, shape (m, d)
    lengthscales, signal_var : RBF kernel hyperparameters
    noise_var : float
        Observation-noise variance added to the training-kernel diagonal.

    Returns
    -------
    mean_s : ndarray, shape (m,)   posterior mean in standardised units
    var_s : ndarray, shape (m,)    posterior variance, floored at 1e-12
    y_mean : float
    y_std : float

    Raises
    ------
    numpy.linalg.LinAlgError
        If the kernel matrix cannot be factorised even after the largest
        jitter has been added.
    """
    y_train = np.asarray(y_train, dtype=float).reshape(-1)
    y_mean = float(np.mean(y_train))
    y_std = float(np.std(y_train))
    if y_std < 1e-8:
        # Constant targets: avoid dividing by ~0 and keep the original scale.
        y_std = 1.0
    y_s = (y_train - y_mean) / y_std

    K = rbf_kernel(X_train, X_train, lengthscales, signal_var)
    n = X_train.shape[0]
    K[np.diag_indices(n)] += noise_var

    # Cholesky with escalating jitter. A single fixed 1e-8 retry is smaller
    # than the noise already on the diagonal, so it seldom rescues a matrix
    # made ill-conditioned by near-duplicate points; the ladder starts at the
    # same 1e-8 and grows only as far as needed.
    L = None
    eye = np.eye(n)
    for jitter in (0.0, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3):
        try:
            L = np.linalg.cholesky(K + jitter * eye)
            break
        except np.linalg.LinAlgError:
            continue
    if L is None:
        raise np.linalg.LinAlgError(
            "GP kernel matrix is not positive definite even with jitter 1e-3"
        )

    # alpha = K^{-1} y_s via the two triangular systems L (L^T alpha) = y_s.
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_s))
    K_s = rbf_kernel(X_train, X_test, lengthscales, signal_var)

    mean_s = K_s.T @ alpha

    v = np.linalg.solve(L, K_s)
    K_ss_diag = np.full(X_test.shape[0], signal_var, dtype=float)  # RBF: k(x,x)=signal_var
    var_s = K_ss_diag - np.sum(v * v, axis=0)
    # Round-off can push the variance slightly negative; clamp it.
    var_s = np.maximum(var_s, 1e-12)

    return mean_s, var_s, y_mean, y_std


def expected_improvement_std(mean_s, var_s, best_s, xi):
    """Expected improvement (maximisation) in standardised units.

    Parameters
    ----------
    mean_s, var_s : array-like
        Posterior mean and variance at the candidate points.
    best_s : float
        Incumbent value, expressed in the same standardised units.
    xi : float
        Exploration margin subtracted from the improvement.

    Returns
    -------
    ndarray
        ``(mean - best - xi) * Phi(z) + std * phi(z)`` with
        ``z = (mean - best - xi) / std``. Where the standard deviation is
        numerically zero the exact limit ``max(mean - best - xi, 0)`` is
        returned instead of a flat 0, so a certain improvement is not
        discarded.
    """
    mean_s = np.asarray(mean_s, dtype=float)
    std_s = np.sqrt(np.asarray(var_s, dtype=float))
    improvement = mean_s - best_s - xi

    degenerate = std_s < 1e-12
    # Divide by a harmless placeholder where std is ~0; those entries are
    # overwritten with the closed-form limit below.
    safe_std = np.where(degenerate, 1.0, std_s)

    with np.errstate(divide="ignore", invalid="ignore"):
        z = improvement / safe_std
        Phi = scipy.stats.norm.cdf(z)
        phi = scipy.stats.norm.pdf(z)
        ei = improvement * Phi + std_s * phi

    return np.where(degenerate, np.maximum(improvement, 0.0), ei)

In [None]:
# ---------------- V2.1 BO ----------------
class BO:
    """
    V2.1 batch Bayesian optimiser.

    - Separate GP per celltype (5D continuous only)
    - Standardise y per celltype GP
    - Greedy batch: choose highest EI across celltypes each pick
      (EI is rescaled to raw objective units before comparing celltypes)
    - Kriging Believer fantasy update within the selected celltype
    - Scoring-aware schedule: exploration early, exploitation later
    - Time bookkeeping: dt_ms per objective call + zero padding, kept the
      same length as Y

    The whole optimisation runs inside ``__init__``; afterwards the results
    are available as attributes:

    X : list of [T, pH, F1, F2, F3, celltype] design points
    Y : 1-D float ndarray of objective values
    time : list of float, milliseconds; the first value of each objective
        call carries that call's duration, the remaining values are 0.0
    best_so_far : float, largest objective value recorded
    """

    def __init__(self, X_initial, iterations, batch, objective_func):
        """Evaluate ``X_initial`` and then run ``iterations`` batch rounds.

        Parameters
        ----------
        X_initial : sequence of design points ``[T, pH, F1, F2, F3, celltype]``
        iterations : number of optimisation rounds (one objective call each)
        batch : number of points proposed per round
        objective_func : callable taking a list of design points and
            returning one value per point
        """
        self.iterations = int(iterations)
        self.batch = int(batch)

        # Hyperparams in normalised [0,1] space
        self.lengthscales = np.array([0.22, 0.22, 0.22, 0.22, 0.22], dtype=float)
        self.signal_var = 1.0
        self.noise_var = 1e-6

        self.X = [row[:] for row in X_initial]
        self.Y = np.array([], dtype=float)
        self.time = []  # ms, one entry per evaluation
        self.best_so_far = -np.inf

        # --- Evaluate initial points (1 objective call) ---
        Y_init, dt_ms = self._timed_call(objective_func, self.X)
        self._record(Y_init, dt_ms)

        # --- Optimisation rounds: each round is 1 objective call ---
        for r in range(self.iterations):
            X_batch = self._propose_batch(round_index=r)
            Y_batch, dt_ms = self._timed_call(objective_func, X_batch)

            # NOTE(review): X and Y are paired positionally later on, which
            # assumes the objective returns exactly one value per proposed
            # point, in order. Confirm against the virtual lab's contract.
            self.X.extend(X_batch)
            self._record(Y_batch, dt_ms)

    @staticmethod
    def _timed_call(objective_func, X):
        """Call the objective once; return (flat float array, elapsed ms)."""
        t0 = datetime.timestamp(datetime.now())
        Y = objective_func(X)
        dt_ms = 1000.0 * (datetime.timestamp(datetime.now()) - t0)
        return np.asarray(Y, dtype=float).reshape(-1), dt_ms

    def _record(self, Y_new, dt_ms):
        """Append new objective values and keep ``time`` aligned with ``Y``.

        An empty result adds nothing: previously it still appended one time
        entry, leaving ``time`` longer than ``Y``.
        """
        n_new = len(Y_new)
        if n_new == 0:
            return

        self.Y = np.concatenate([self.Y, Y_new])
        # The whole call's duration is booked on its first evaluation.
        self.time += [dt_ms] + [0.0] * (n_new - 1)

        batch_best = float(np.max(Y_new))
        if batch_best > self.best_so_far:
            self.best_so_far = batch_best

    def _split_by_celltype(self):
        """Group the evaluated data by celltype.

        Returns a dict ``{celltype: {"Xn": (n_c, 5) array, "y": (n_c,) array}}``
        where ``Xn`` holds the five continuous inputs normalised with
        ``norm5``. Celltypes without data get empty arrays of those shapes.
        """
        data = {c: {"Xn": [], "y": []} for c in CELLTYPES}
        for x, y in zip(self.X, self.Y):
            c = x[5]
            x5 = np.array([x[0], x[1], x[2], x[3], x[4]], dtype=float)
            data[c]["Xn"].append(norm5(x5))
            data[c]["y"].append(float(y))
        for c in CELLTYPES:
            # reshape keeps the (0, 5) / (0,) shapes when a celltype is empty
            data[c]["Xn"] = np.array(data[c]["Xn"], dtype=float).reshape(-1, 5)
            data[c]["y"] = np.array(data[c]["y"], dtype=float).reshape(-1)
        return data

    def _schedule(self, round_index):
        """
        Scoring-aware schedule:
        - rounds 0-2: explore more (higher xi, more global)
        - later rounds: exploit more (lower xi, more local focus)

        ``force_explore_every = m`` makes every pick whose index within the
        batch is a multiple of m (including the first pick) a uniformly
        random point.
        """
        if round_index <= 2:
            return {
                "xi": 0.05,
                "n_global": 8000,
                "n_local": 2000,
                "local_sigma": 0.08,
                "force_explore_every": 2,
            }
        else:
            return {
                "xi": 0.01,
                "n_global": 6000,
                "n_local": 5000,
                "local_sigma": 0.05,
                "force_explore_every": 4,  # still keep some exploration
            }

    @staticmethod
    def _design_point(xn, celltype):
        """Convert a normalised 5-vector plus celltype into a raw design point."""
        x5 = denorm5(xn)
        return [float(x5[0]), float(x5[1]), float(x5[2]), float(x5[3]), float(x5[4]), celltype]

    @staticmethod
    def _fallback_fantasy(y):
        """Fantasy value when there is too little data for a GP: mean of y, else 0."""
        return float(np.mean(y)) if y.shape[0] else 0.0

    @staticmethod
    def _append_fantasy(Xn_f, y_f, celltype, xn, y_fant):
        """Add a fantasised observation to the working copies of one celltype."""
        Xn_f[celltype] = np.vstack([Xn_f[celltype], xn.reshape(1, 5)])
        y_f[celltype] = np.concatenate([y_f[celltype], [y_fant]])

    def _propose_batch(self, round_index):
        """Build one batch of ``self.batch`` design points.

        Each pick is either a forced uniformly random point (per the
        schedule) or the candidate with the highest expected improvement
        over all celltypes. After every pick the chosen point is added to
        that celltype's working data with a fantasised value (the GP mean),
        so later picks in the same batch see it.
        """
        data = self._split_by_celltype()

        # Fantasy copies for batch construction
        Xn_f = {c: data[c]["Xn"].copy() for c in CELLTYPES}
        y_f = {c: data[c]["y"].copy() for c in CELLTYPES}

        cfg = self._schedule(round_index)
        xi = cfg["xi"]
        n_global = cfg["n_global"]
        n_local = cfg["n_local"]
        local_sigma = cfg["local_sigma"]
        force_explore_every = cfg["force_explore_every"]

        n_cells = len(CELLTYPES)
        chosen = []

        for k in range(self.batch):
            # Force an occasional pure exploration point (cheap anti-stuck)
            if force_explore_every > 0 and (k % force_explore_every == 0):
                c = CELLTYPES[k % n_cells]
                xn = np.random.rand(5)
                # Fantasy: GP mean if a model can be fitted, else a simple fallback
                if y_f[c].shape[0] >= 2:
                    mu_s, _, y_mean, y_std = gp_posterior_std(
                        Xn_f[c], y_f[c], xn.reshape(1, 5),
                        self.lengthscales, self.signal_var, self.noise_var
                    )
                    y_fant = float(mu_s[0]) * float(y_std) + float(y_mean)
                else:
                    y_fant = self._fallback_fantasy(y_f[c])
                chosen.append(self._design_point(xn, c))
                self._append_fantasy(Xn_f, y_f, c, xn, y_fant)
                continue

            best_score = -1.0
            best_cell = None
            best_xn = None
            best_y_fant = None

            for c in CELLTYPES:
                Xtr = Xn_f[c]
                ytr = y_f[c]

                # If too little data in this celltype, treat as exploration
                if ytr.shape[0] < 2:
                    # Give exploration a baseline score so it competes early
                    score = 1e-3
                    if score > best_score:
                        best_score = score
                        best_cell = c
                        # One uniform draw; no need for a full candidate set.
                        best_xn = np.random.rand(5)
                        best_y_fant = self._fallback_fantasy(ytr)
                    continue

                # Local candidates are scattered around this celltype's best
                # point so far (fantasies included).
                centre = Xtr[int(np.argmax(ytr))]

                global_cand = np.random.rand(n_global, 5)
                local_cand = centre + local_sigma * np.random.randn(n_local, 5)
                local_cand = np.clip(local_cand, 0.0, 1.0)
                cand = np.vstack([global_cand, local_cand])

                mu_s, var_s, y_mean, y_std = gp_posterior_std(
                    Xtr, ytr, cand,
                    self.lengthscales, self.signal_var, self.noise_var
                )

                # Incumbent is the global best, expressed in this GP's units.
                best_s = (self.best_so_far - y_mean) / y_std
                ei = expected_improvement_std(mu_s, var_s, best_s, xi=xi)
                j = int(np.argmax(ei))

                # Each GP standardises y with its own y_std, so standardised
                # EI values are in different units per celltype. Multiplying
                # by y_std converts to raw objective units, making the
                # cross-celltype comparison meaningful.
                score = float(ei[j]) * float(y_std)

                if score > best_score:
                    best_score = score
                    best_cell = c
                    best_xn = cand[j]
                    best_y_fant = float(mu_s[j]) * float(y_std) + float(y_mean)

            if best_xn is None:
                # Every score was NaN (e.g. NaN observations): fall back to a
                # random point rather than failing on a missing candidate.
                best_cell = CELLTYPES[k % n_cells]
                best_xn = np.random.rand(5)
                best_y_fant = self._fallback_fantasy(y_f[best_cell])

            chosen.append(self._design_point(best_xn, best_cell))

            # Kriging Believer fantasy update
            self._append_fantasy(Xn_f, y_f, best_cell, best_xn, best_y_fant)

        return chosen

In [None]:
# ---------------- Visualisation ----------------
def plot_results(bo, submission="group1_local_v2_1.py", out_png="group1_v2_1_individual.png"):
    """Plot timing and objective traces of a finished run and save them.

    Five stacked panels are drawn from ``bo.Y`` and ``bo.time``: per-evaluation
    time, objective value, cumulative time, cumulative objective, and
    cumulative objective against cumulative time.

    Parameters
    ----------
    bo : object exposing ``Y`` (objective values) and ``time`` (ms per
        evaluation), both flattened to 1-D here
    submission : str
        Name shown in the figure title.
    out_png : str
        Path the figure is written to (200 dpi).

    If matplotlib cannot be imported a message is printed and nothing is
    plotted.
    """
    # matplotlib is treated as optional, hence the local best-effort import.
    try:
        import matplotlib.pyplot as plt
    except Exception:
        print("matplotlib not available; skipping plots.")
        return

    Y = np.asarray(bo.Y).reshape(-1)
    t_ms = np.asarray(bo.time).reshape(-1)
    cum_t = np.cumsum(t_ms)
    cum_Y = np.cumsum(Y)

    fig = plt.figure(figsize=(10, 12))
    fig.suptitle(
        f"ML4CE 2025/26 Data-driven Optimisation Coursework Results\nSubmission: {submission}",
        fontsize=14
    )

    ax1, ax2, ax3, ax4, ax5 = [fig.add_subplot(5, 1, row) for row in range(1, 6)]

    ax1.plot(np.arange(len(t_ms)), t_ms, marker="o", linestyle="None", markersize=3)
    ax1.set_ylabel("Time [ms]")
    ax1.set_xlabel("Iterations")

    ax2.plot(np.arange(len(Y)), Y, linewidth=2)
    ax2.set_ylabel("Titre Conc. [g/L]")
    ax2.set_xlabel("Iterations")

    ax3.plot(np.arange(len(cum_t)), cum_t, linewidth=2)
    ax3.set_ylabel("Cumulative Time [ms]")
    ax3.set_xlabel("Iterations")

    ax4.plot(np.arange(len(cum_Y)), cum_Y, linewidth=2)
    # Label previously carried the unit twice ("[g/L] ... [g/L]").
    ax4.set_ylabel("Cumulative Titre Conc. [g/L]")
    ax4.set_xlabel("Iterations")

    ax5.plot(cum_t, cum_Y, linewidth=3, color="red")
    ax5.set_ylabel("Cumulative Titre Conc. [g/L]")
    ax5.set_xlabel("Cumulative Time [ms]")

    # Leave room at the top for the two-line suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(out_png, dpi=200)
    plt.show()


In [None]:

# ---------------- Runner ----------------
def make_initial_points(n_init=6):
    """Draw ``n_init`` random starting design points.

    Each continuous input is sampled uniformly inside its row of
    ``BOUNDS_5`` (in the order T, pH, F1, F2, F3), and celltypes are assigned
    round-robin over ``CELLTYPES``. Using the shared constants keeps the
    initial design in step with the bounds used for normalisation.

    Returns
    -------
    list of ``[T, pH, F1, F2, F3, celltype]`` with plain Python floats.
    """
    pts = []
    for i in range(n_init):
        # float() keeps the entries plain Python floats rather than np.float64.
        point = [float(lo + (hi - lo) * random.random()) for lo, hi in BOUNDS_5]
        point.append(CELLTYPES[i % len(CELLTYPES)])
        pts.append(point)
    return pts


def main(n_init=6, iterations=15, batch=5, seed=None):
    """Run the optimiser end to end, print a summary and plot the results.

    Parameters
    ----------
    n_init : int
        Number of random initial design points.
    iterations : int
        Number of optimisation rounds.
    batch : int
        Points proposed per round.
    seed : int or None
        If given, seeds both ``random`` and ``numpy.random`` so the proposals
        made in this file are repeatable. NOTE(review): any randomness inside
        the virtual lab itself is outside this function's control.

    Returns
    -------
    BO
        The finished optimiser, with ``X``, ``Y`` and ``time`` populated.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    X_initial = make_initial_points(n_init)

    bo = BO(X_initial=X_initial, iterations=iterations, batch=batch, objective_func=objective_func)

    best = float(np.max(bo.Y))
    print("Best objective value found:", best)
    print("Total evaluations:", len(bo.Y))

    # Best-so-far after the initial design and after each round. The slice
    # ends are derived from the run settings instead of repeated literals,
    # so they cannot drift from the values passed to BO.
    y = bo.Y
    checkpoints = [n_init + r * batch for r in range(iterations + 1)]
    best_curve = [float(np.max(y[:end])) for end in checkpoints]
    print(f"Best-so-far per round (init + {iterations} rounds):")
    print(best_curve)
    print("Total runtime measured (ms):", float(np.sum(bo.time)))

    plot_results(bo)
    return bo


if __name__ == "__main__":
    main()