C. Toy model results.  Show a Pareto front of Growth vs. Cost, as in Fig 2b of your report.

goal: 2025-04-07_BayesOpt_textbook_growth-cost_qPAREGO_30it_round_1_pareto_batch_colour.png

In [12]:
# imports
import torch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import colormaps as cmaps # for scientific colormaps

from botorch.utils.multi_objective.pareto import is_non_dominated

In [None]:
# function
def plot_pareto_batch_colour(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    figname = "figure.png",
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None,
    ):
    """
    Plots the performance category indicated by "xax" (default: growth rate) on the x-axis
    against the performance category indicated by "yax" (default: medium cost) on the y-axis
    for each candidate medium.
    Each dot is colour-coded according to the batch it resulted from: colour 0 is used for
    the first "n_start" points, colours 1..n_iter for the subsequent batches of
    "n_candidates" points each.
    Plots the Pareto front deduced from the data (cost axes are minimised, all other
    axes are maximised).
    When the metabolic model (MetModel), the baseline medium (initial_medium) and 
    the corresponding baseline costs (initial_costs) are given, the performance of the
    baseline medium is plotted onto the graph to allow for visual performance comparison.
    Saves the figure under "figname" and then displays it.

    NOTE: when the baseline is plotted, MetModel is modified in place (its medium and
    objective are overwritten and the model is optimised).

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; the keys read here are xax, yax,
      "n_iter", "n_start", "n_candidates" and, for the baseline only, "model objective",
      "biomass objective" and "production objective"
    * xax - string - variable to plot on x-axis; one of "growth rate tensors",
      "cost tensors", "production tensors"
    * yax - string - variable to plot on y-axis; same choices as xax
    * figname - string - name under which to save the figure
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective to set on MetModel before simulating the baseline;
      if None, results["model objective"] is used

    RETURNS
    - None

    RAISES
    * ValueError - if xax or yax is not one of the supported result keys
    """
    valid_values = {"growth rate tensors", "cost tensors", "production tensors"}
    if xax not in valid_values or yax not in valid_values:
        raise ValueError(f"xax and yax must be one of {valid_values}, but got xax='{xax}' and yax='{yax}'")

    # Extract the two performance measures from results
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()

    # Stack the two objectives into a single 2D array
    # rows: candidates
    # columns: xax values, yax values
    points = np.column_stack([x_np, y_np])
    n_points = points.shape[0]

    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (9, 7))

    # Define colour mapping - 1 for random initial points, 1 each per batch (n_iter)
    n_batch = results["n_iter"]
    n_start = results["n_start"]
    n_candidates = results["n_candidates"]

    # Generate n_batch + 1 distinct colours: index 0 for the initial points,
    # indices 1..n_batch for the batches
    colours = np.asarray(cmaps.batlow(np.linspace(0, 1, n_batch + 1)))
    # Create a custom colourmap for the colour bar
    cmap = mcolors.ListedColormap(colours)

    # Batch index of every data point; points before n_start keep index 0
    batch_index = np.zeros(n_points, dtype = int)
    for i in range(n_batch):
        start_idx = n_start + i * n_candidates
        end_idx = start_idx + n_candidates
        batch_index[start_idx:end_idx] = i + 1
    # Look up one RGBA row per data point -> array of shape (n_points, 4)
    point_colours = colours[batch_index]

    # Set boundaries between each batch, from -0.5 to n_batch + 0.5
    # (n_batch + 1 bins, matching the n_batch + 1 colours)
    boundaries = np.arange(n_batch + 2) - 0.5
    norm = mcolors.BoundaryNorm(boundaries, cmap.N)

    # Scatter plot with custom colours, applying transparency (alpha = 0.8)
    axes.scatter(points[:, 0], points[:, 1], c = point_colours, alpha = 0.8)

    # Set y-axis to log scale if it contains costs
    if yax == "cost tensors":
        axes.set_yscale("log")

    # ---- Pareto front ----
    # is_non_dominated assumes maximisation, so cost axes are negated before the
    # dominance check (on a copy; the plotted values stay untouched)
    signs = np.array([-1.0 if axis == "cost tensors" else 1.0 for axis in (xax, yax)])

    # Sort points by the x-axis objective
    # -> allows to plot the front in order of increasing xax
    sorted_points = points[np.argsort(points[:, 0])]

    # Compute non-dominated (Pareto front) points; i.e. optimal trade-offs
    is_pareto = is_non_dominated(torch.as_tensor(sorted_points * signs)).cpu().numpy()
    front = sorted_points[is_pareto]

    # Plot the Pareto front using the original (un-negated) values
    axes.plot(
        front[:, 0],
        front[:, 1],
        label="Pareto Front",
        color="r",
        linewidth=2,
    )

    # ---- Baseline medium ----
    # only if model, initial medium and costs are given
    if MetModel is not None and initial_medium and initial_costs:
        # Set the initial medium as medium in the model
        MetModel.medium = initial_medium
        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective
        # Run optimisation
        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_cost = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()

        # -1 is a placeholder used when results defines no production objective
        initial_production = -1
        if results["production objective"] is not None:
            initial_production = solution.fluxes[results["production objective"]]

        # map initial data to corresponding axes depending on call
        baseline_values = {
            "growth rate tensors": initial_growth_rate,
            "cost tensors": initial_cost,
            "production tensors": initial_production
        }
        ini_x = baseline_values.get(xax)
        ini_y = baseline_values.get(yax)

        # Plot the baseline medium as a red cross
        axes.scatter(
            ini_x, # x-axis
            ini_y, # y-axis
            color = "red", marker = "x", label = "Original Medium", s = 100, zorder = 5
        )

    # ---- Labels and title ----
    mapping_with_units = {
        "growth rate tensors": "Growth Rate [1/h]",
        "cost tensors": "Medium Cost [$10^{-3}$ £/gDW·h]",
        "production tensors": "Production Rate [mmol/gDW·h]"
    }
    # axes
    axes.set_xlabel(mapping_with_units.get(xax), fontsize = 14)
    axes.set_ylabel(mapping_with_units.get(yax), fontsize = 14)
    axes.xaxis.set_tick_params(width = 2, labelsize = 10)
    axes.yaxis.set_tick_params(width = 2, labelsize = 10)
    # title
    axis_titles = {
        "growth rate tensors": "Growth Rate",
        "cost tensors": "Medium Cost",
        "production tensors": "Production Rate"
    }
    axes.set_title(f"{axis_titles.get(xax)} vs. {axis_titles.get(yax)} With Pareto Front", fontsize = 16)

    # Add the colour bar (one tick every 5 iterations)
    tick_positions = np.arange(0, n_batch + 1, 5)
    sm = cm.ScalarMappable(cmap = cmap, norm = norm)
    cbar = fig.colorbar(sm, ax = axes, ticks = tick_positions, pad=0.12)
    cbar.ax.set_title("Iteration", fontsize = 10)
    cbar.ax.tick_params(which = 'minor', size = 0) # turn off minor ticks at colour boundaries

    # Display the legend
    axes.legend()

    # Save the figure before showing it, so the saved file does not depend on
    # what the backend does to the figure once it has been displayed
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()