# Plotting functions for the various BayesOpt Implementations

In [None]:
# imports
import torch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
import matplotlib.colors as mcolors
import matplotlib.cm as cm

import colormaps as cmaps # for scientific colourmaps

from botorch.utils.multi_objective.pareto import is_non_dominated

In [None]:
# Helper functions to be used across notebooks
%run HelperFunctions_MOBO_II.ipynb

In [None]:
# imports for .py version
#from BayesOpt_MOBO_II import *
#from HelperFunctions_MOBO_II import *

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, the CPU otherwise.
# (On a multi-GPU machine, change the ordinal in "cuda:0" to select another GPU.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Keyword arguments shared by tensor constructors and `.to()` calls.
# torch.double is an alias of torch.float64, so printing tkwargs shows torch.float64.
tkwargs = dict(device = device, dtype = torch.double)

In [None]:
# colour gradient is batch to batch
def plot_pareto_batch_colour(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    figname = "figure.png",
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None,
    ):
    """
    Plots the performance category indicated by "xax" (default: growth rate) on the x-axis against 
    the performance category indicated by "yax" (default: medium cost) on the y-axis
    for each candidate medium.
    Each dot is colour-coded according to the batch it resulted from (the initial
    points share colour 0, batch i of the optimisation gets colour i + 1).
    Plots the Pareto front deduced from the data; costs are treated as "to be minimised"
    on whichever axis they are plotted, all other categories as "to be maximised".
    Cost values are multiplied by 1e-3 before plotting.
    When the metabolic model (MetModel), the baseline medium (initial_medium) and 
    the corresponding baseline costs (initial_costs) are given, the performance of the baseline
    medium is plotted onto the graph to allow for visual performance comparison.
    Saves the figure under "figname" and shows it.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; read entries: the tensors named by
      xax / yax, "n_iter", "n_start", "n_candidates" and, for the baseline medium,
      "model objective", "biomass objective", "production objective"
    * xax - string - variable to plot on x-axis
    * yax - string - variable to plot on y-axis
    * figname - string - name under which to save the figure
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
      (its medium and objective are overwritten when the baseline is plotted)
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective assigned to MetModel.objective for the baseline simulation;
      if None, results["model objective"] is used
    
    RETURNS
    - 

    RAISES
    * ValueError - if xax or yax is not one of the supported categories
    """
    valid_values = {"growth rate tensors", "cost tensors", "production tensors"}
    if xax not in valid_values or yax not in valid_values:
        raise ValueError(f"xax and yax must be one of {valid_values}, but got xax='{xax}' and yax='{yax}'")

    cost_key = "cost tensors"
    cost_scale = 1e-3 # factor applied to every cost value before plotting

    # Extract the data from results.
    # For CPU tensors .cpu().numpy() shares memory with the tensor stored in results,
    # so the scaling is done out-of-place; an in-place "*=" would shrink the stored
    # costs a little more on every call of this function.
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    if xax == cost_key:
        x_np = x_np * cost_scale
    if yax == cost_key:
        y_np = y_np * cost_scale
    
    # Stack the two categories into a single 2D array (a fresh copy)
    # rows: candidates
    # columns: x-axis category, y-axis category
    points = np.column_stack([x_np, y_np])
    n_points = len(points)
    
    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (9, 7))

    # Define colour mapping - 1 for random initial points, 1 each per batch (n_iter)
    n_batch = results["n_iter"]
    n_start = results["n_start"]
    n_candidates = results["n_candidates"]
    
    # Generate distinct colours and a matching colourmap for the colour bar
    colours = cmaps.bamako.cut(0.05, "right")(np.linspace(0, 1, n_batch + 1))
    cmap = mcolors.ListedColormap(colours)

    # Colour index per data point: 0 for the first n_start points, i + 1 for batch i.
    # Computed from the actual number of points so that a shorter or longer run
    # than n_start + n_batch * n_candidates does not raise; surplus points get
    # the colour of the last batch.
    batch_index = np.zeros(n_points, dtype = int)
    n_optimised = max(n_points - n_start, 0)
    batch_index[n_start:] = np.minimum(np.arange(n_optimised) // n_candidates + 1, n_batch)
    point_colours = colours[batch_index] # one RGBA row per data point
    
    # Set boundaries between each batch, from -0.5 to n_batch + 0.5
    boundaries = np.arange(n_batch + 2) - 0.5
    norm = mcolors.BoundaryNorm(boundaries, cmap.N)
    
    # Scatter plot with custom colours, applying transparence (alpha = 0.8)
    axes.scatter(
        x = points[:, 0], 
        y = points[:, 1], 
        c = point_colours, alpha = 0.8)

    # Set y-axis to log scale if it contains costs
    if yax == cost_key:
        axes.set_yscale("log")
    
    # --- Pareto front ---
    # is_non_dominated assumes maximisation, so costs are negated for the dominance
    # check on whichever axis they appear; the plotted coordinates stay untouched.
    signs = np.array([
        -1.0 if xax == cost_key else 1.0,
        -1.0 if yax == cost_key else 1.0,
    ])

    # Sort points by the x-axis value so the front is drawn in order of increasing x
    sorted_points = points[np.argsort(points[:, 0])]
    
    # Compute non-dominated (Pareto front) points; i.e. optimal trade-offs
    is_pareto = is_non_dominated(torch.tensor(sorted_points * signs).to(**tkwargs))
    front = sorted_points[is_pareto.cpu().numpy()]

    # Plot the Pareto front
    axes.plot(
        front[:, 0],
        front[:, 1],
        label = "Pareto Front",
        color = "r",
        linewidth = 2,
    )

    # --- Baseline (original) medium ---
    # only if model, initial medium and costs are given
    if MetModel and initial_medium and initial_costs:
        # Set the initial medium as medium in the model
        MetModel.medium = initial_medium
        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective
        # Run optimisation
        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_cost = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()
        
        # -1 is used as placeholder when no production objective was optimised
        initial_production = -1        
        if results["production objective"] is not None:
            initial_production = solution.fluxes[results["production objective"]]

        # map initial data to corresponding axes depending on call
        baseline = {
            "growth rate tensors": initial_growth_rate,
            "cost tensors": initial_cost,
            "production tensors": initial_production
        }
        ini_x = baseline.get(xax)
        ini_y = baseline.get(yax)
        
        # scale costs by the same factor as the candidates (out-of-place)
        if xax == cost_key:
            ini_x = ini_x * cost_scale
        if yax == cost_key:
            ini_y = ini_y * cost_scale

        # Plot initial point as a red cross
        axes.scatter(
            ini_x, # x-axis
            ini_y, # y-axis
            color = "red", marker = "x", 
            label = "Original Medium", s = 150, zorder = 5
        )

    # --- Labels and styling ---
    axis_labels = {
        "growth rate tensors": "Growth Rate [1/h]",
        "cost tensors": "Medium Cost [£/gDW·h]",
        "production tensors": "Production Rate [mmol/gDW·h]"
    }
    axes.set_xlabel(axis_labels.get(xax), fontsize = 14)
    axes.set_ylabel(axis_labels.get(yax), fontsize = 14)
    axes.xaxis.set_tick_params(width = 2, labelsize = 10)
    axes.yaxis.set_tick_params(width = 2, labelsize = 10)
    axes.spines["top"].set_visible(False)
    axes.spines["right"].set_visible(False)
    axes.spines["bottom"].set_linewidth(1.5)
    axes.spines["left"].set_linewidth(1.5)

    # Add the colour bar (one tick every 5 batches)
    tick_positions = np.arange(0, n_batch + 1, 5)
    sm = cm.ScalarMappable(cmap = cmap, norm = norm)
    cbar = fig.colorbar(sm, ax = axes, ticks = tick_positions, pad = 0.01)
    cbar.ax.set_title("Iteration", fontsize = 10)
    cbar.ax.tick_params(which = 'minor', size = 0) # turn off minor ticks at colour boundaries
    
    # Display the legend
    axes.legend(fontsize = 12)

    # Save the figure before showing it, so the file is written even if the
    # figure window gets closed or resized during plt.show()
    fig.set_size_inches(9, 7)  # Consistent physical size in inches
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()

In [None]:
# colour gradient is datapoint to datapoint
# obsolete compared to plot_pareto_batch_colour
def plot_pareto(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    figname = "figure.png",
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None,
    ):
    """
    Plots the performance category indicated by "xax" (default: growth rate) on the x-axis against 
    the performance category indicated by "yax" (default: medium cost) on the y-axis
    for each candidate medium.
    Each dot is colour-coded according to its position in the data (sample 1, 2, ...).
    Plots the Pareto front deduced from the data; costs are treated as "to be minimised"
    on whichever axis they are plotted, all other categories as "to be maximised".
    When the metabolic model (MetModel), the baseline medium (initial_medium) and 
    the corresponding baseline costs (initial_costs) are given, the performance of the baseline
    medium is plotted onto the graph to allow for visual performance comparison.
    Saves the figure under "figname" and shows it.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; read entries: the tensors named by
      xax / yax and, for the baseline medium, "model objective", "biomass objective",
      "production objective"
    * xax - string - variable to plot on x-axis
    * yax - string - variable to plot on y-axis
    * figname - string - name under which to save the figure
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
      (its medium and objective are overwritten when the baseline is plotted)
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective assigned to MetModel.objective for the baseline simulation;
      if None, results["model objective"] is used
    
    RETURNS
    - 

    RAISES
    * ValueError - if xax or yax is not one of the supported categories
    """

    # make sure that x-axis and y-axis are legal choices
    valid_values = {"growth rate tensors", "cost tensors", "production tensors"}
    if xax not in valid_values or yax not in valid_values:
        raise ValueError(f"xax and yax must be one of {valid_values}, but got xax='{xax}' and yax='{yax}'")

    cost_key = "cost tensors"
    
    # extract data from results
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    
    # Stack the two categories into a single 2D array (a fresh copy)
    # rows: candidates
    # columns: x-axis category, y-axis category
    points = np.column_stack([x_np, y_np])
    
    # One running number per candidate medium, used for the colour gradient
    n_samples = len(x_np)
    iterations = np.arange(1, n_samples + 1)
    
    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (9, 7))
    
    # get the colormap (named so that it does not shadow the matplotlib.cm module)
    colour_map = plt.colormaps.get_cmap('viridis')
    
    # Scatter plot of all points, colour-coded by sample number (c = iterations),
    # with transparence (alpha = 0.8)
    axes.scatter(points[:, 0], points[:, 1], c = iterations, cmap = colour_map, alpha = 0.8)
    # Set y-axis to log scale if it contains costs
    if yax == cost_key:
        axes.set_yscale("log")
    
    # --- Pareto front ---
    # is_non_dominated assumes maximisation, so costs are negated for the dominance
    # check on whichever axis they appear; the plotted coordinates stay untouched.
    signs = np.array([
        -1.0 if xax == cost_key else 1.0,
        -1.0 if yax == cost_key else 1.0,
    ])

    # Sort points by the x-axis value so the front is drawn in order of increasing x
    sorted_points = points[np.argsort(points[:, 0])]
    
    # Compute non-dominated (Pareto front) points; i.e. optimal trade-offs
    is_pareto = is_non_dominated(torch.tensor(sorted_points * signs).to(**tkwargs))
    front = sorted_points[is_pareto.cpu().numpy()]

    # Plot the Pareto front
    axes.plot(
        front[:, 0],
        front[:, 1],
        label="Pareto Front",
        color="r",
        linewidth=2,
    )

    # --- Baseline (original) medium ---
    # only if model, initial medium and costs are given
    if MetModel and initial_medium and initial_costs:
        # Set the initial medium as medium in the model
        MetModel.medium = initial_medium
        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective
        # Run optimisation
        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_cost = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()
        
        # -1 is used as placeholder when no production objective was optimised
        initial_production = -1        
        if results["production objective"] is not None:
            initial_production = solution.fluxes[results["production objective"]]

        # map initial data to corresponding axes depending on call
        baseline = {
            "growth rate tensors": initial_growth_rate,
            "cost tensors": initial_cost,
            "production tensors": initial_production
        }
        ini_x = baseline.get(xax)
        ini_y = baseline.get(yax)
        
        # Plot initial point as a black cross
        axes.scatter(
            ini_x, # x-axis
            ini_y, # y-axis
            color = "black", marker = "x", label = "Original Medium", s = 100, zorder = 5
        )

    # --- Labels and title ---
    axis_labels = {
        "growth rate tensors": "Growth Rate [1/h]",
        "cost tensors": "Medium Cost [$10^{-3}$ £/gDW·h]",
        "production tensors": "Production Rate [mmol/gDW·h]"
    }
    axes.set_xlabel(axis_labels.get(xax))
    axes.set_ylabel(axis_labels.get(yax))
    axes.set_title(f"{axis_labels.get(xax)} vs. {axis_labels.get(yax)} with Pareto Front")

    # Normalize the colour bar according to sample number
    norm = plt.Normalize(iterations.min(), iterations.max())
    sm = ScalarMappable(norm=norm, cmap=colour_map)
    sm.set_array([])
    
    # Add the colour bar
    cbar = fig.colorbar(sm, ax=axes, pad=0.12)
    cbar.ax.set_title("Iteration")
    
    # Display the legend
    axes.legend()

    # Save the figure before showing it, so the file is written even if the
    # figure window gets closed or resized during plt.show()
    fig.set_size_inches(9, 7)  # Consistent physical size in inches
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()

In [None]:
# the pareto surface isn't visually helpful
def plot_pareto_3D(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    zax = "production tensors",
    figname = "figure.png",
    n_candidates = 1,
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None,
    ):
    """
    Plots the performance category indicated by "xax" (default: growth rate) on the x-axis against 
    the performance category indicated by "yax" (default: medium cost) on the y-axis against 
    the performance category indicated by "zax" (default: production rate) on the z-axis
    for each candidate medium.
    Each dot is colour-coded according to the batch it resulted from (the initial
    points share colour 0, batch i of the optimisation gets colour i + 1).
    Draws the Pareto front as a triangulated surface when it has at least three points;
    costs are treated as "to be minimised", all other categories as "to be maximised".
    When the metabolic model (MetModel), the baseline medium (initial_medium) and 
    the corresponding baseline costs (initial_costs) are given, the performance of the
    baseline medium is plotted onto the graph to allow for visual performance comparison.
    Saves the figure under "figname" and shows it.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; read entries: the tensors named by
      xax / yax / zax, "n_iter", "n_start", "n_candidates" and, for the baseline medium,
      "model objective", "biomass objective", "production objective"
    * xax - string - variable to plot on x-axis
    * yax - string - variable to plot on y-axis
    * zax - string - variable to plot on z-axis
    * figname - string - name under which to save the figure
    * n_candidates - int - batch-size; only used when results has no "n_candidates" entry
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
      (its medium and objective are overwritten when the baseline is plotted)
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective assigned to MetModel.objective for the baseline simulation;
      if None, results["model objective"] is used
    
    RETURNS
    -

    RAISES
    * ValueError - if xax, yax or zax is not one of the supported categories
    """
    # make sure that every axis is a legal choice
    options = ["growth rate tensors", "cost tensors", "production tensors"]
    if xax not in options:
        raise ValueError(f"Invalid xax. Expected one of: {options}")
    if yax not in options:
        raise ValueError(f"Invalid yax. Expected one of: {options}")
    if zax not in options:
        raise ValueError(f"Invalid zax. Expected one of: {options}")

    cost_key = "cost tensors"
    
    # extract data from results
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    z_np = results[zax].cpu().numpy()
    
    # Stack the three categories into a single 2D array (a fresh copy)
    # rows: candidates
    # columns: x-axis, y-axis, z-axis category
    points = np.column_stack([x_np, y_np, z_np])
    n_points = len(points)

    # Create figure and add 3D subplot
    fig = plt.figure(figsize = (9, 7))
    axes = fig.add_subplot(projection='3d')
    
    # Define colour mapping - 1 for random initial points, 1 each per batch (n_iter)
    n_batch = results["n_iter"]
    n_start = results["n_start"]
    # the value stored with the results wins over the function argument
    n_candidates = results.get("n_candidates", n_candidates)

    # Generate distinct colours and a matching colourmap for the colour bar
    colours = plt.cm.viridis(np.linspace(0, 1, n_batch + 1))
    cmap = mcolors.ListedColormap(colours)
    norm = mcolors.BoundaryNorm(np.arange(n_batch + 2) - 0.5, cmap.N)

    # Colour index per data point: 0 for the first n_start points, i + 1 for batch i.
    # Computed from the actual number of points so that a shorter or longer run
    # than n_start + n_batch * n_candidates does not raise; surplus points get
    # the colour of the last batch.
    batch_index = np.zeros(n_points, dtype = int)
    n_optimised = max(n_points - n_start, 0)
    batch_index[n_start:] = np.minimum(np.arange(n_optimised) // n_candidates + 1, n_batch)
    point_colours = colours[batch_index] # one RGBA row per data point
    
    # Scatter plot with custom colours, applying transparence (alpha = 0.8)
    axes.scatter(points[:, 0], points[:, 1], points[:, 2], c = point_colours, alpha = 0.8)

    # --- Pareto front ---
    # is_non_dominated assumes maximisation, so whatever axis holds the costs is
    # negated for the dominance check; the plotted coordinates stay untouched.
    signs = np.array([-1.0 if axis == cost_key else 1.0 for axis in (xax, yax, zax)])
    is_pareto = is_non_dominated(torch.tensor(points * signs).to(**tkwargs))
    front = points[is_pareto.cpu().numpy()]
    
    # Plotting the Pareto surface requires at least three datapoints
    if len(front) >= 3:
        axes.plot_trisurf(
            front[:, 0], front[:, 1], front[:, 2],
            color = 'red', alpha=0.5, linewidth=0.2)
    
    # --- Baseline (original) medium ---
    # only if model, initial medium and costs are given
    if MetModel and initial_medium and initial_costs:
        # Set the initial medium as medium in the model
        MetModel.medium = initial_medium
        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective

        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_cost = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()

        # -1 is used as placeholder when no production objective was optimised
        # (same convention as in the 2D Pareto plots)
        initial_production = -1
        if results["production objective"] is not None:
            initial_production = solution.fluxes[results["production objective"]]
        
        # map initial data to corresponding axes depending on call
        baseline = {
            "growth rate tensors": initial_growth_rate,
            "cost tensors": initial_cost,
            "production tensors": initial_production
        }
        
        # Plot initial point as a red cross in 3D
        axes.scatter(
            baseline.get(xax), # x-axis
            baseline.get(yax), # y-axis
            baseline.get(zax), # z-axis
            color = "red", marker = "x", label = "Original Medium", s = 100, zorder = 5
        )
    
    # --- Labels and title ---
    axis_labels = {
        "growth rate tensors": "Growth Rate [1/h]",
        "cost tensors": "Medium Cost [$10^{-3}$ £/gDW·h]",
        "production tensors": "Production Rate [mmol/gDW·h]"
    }
    axes.set_xlabel(axis_labels.get(xax))
    axes.set_ylabel(axis_labels.get(yax))
    axes.set_zlabel(axis_labels.get(zax))
    axes.set_title("Growth Rate vs Production Rate vs Costs")

    # Add the colour bar
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    cbar = fig.colorbar(sm, ax=axes, pad=0.12)
    cbar.ax.set_title("Iteration")
    
    # Display the legend
    axes.legend()

    # Save the figure before showing it, so the file is written even if the
    # figure window gets closed or resized during plt.show()
    fig.set_size_inches(9, 7)  # Consistent physical size in inches
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()

In [None]:
def plot_3D(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    zax = "production tensors",
    figname = "figure.png",
    n_candidates = 1,
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None,
    ):
    """
    Plots the performance category indicated by "xax" (default: growth rate) on the x-axis against 
    the performance category indicated by "yax" (default: medium cost) on the y-axis against
    the performance category indicated by "zax" (default: production rate) on the z-axis
    for each candidate medium.
    Each dot is colour-coded according to the batch it resulted from (the initial
    points share colour 0, batch i of the optimisation gets colour i + 1).
    Cost values are multiplied by 1e-3 before plotting.
    If the metabolic model (MetModel), the baseline medium (initial_medium) and 
    the corresponding baseline costs (initial_costs) are given, the performance of the baseline
    medium is plotted onto the graph to allow for visual performance comparison.
    Saves the figure under "figname" and shows it.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; read entries: the tensors named by
      xax / yax / zax, optionally "n_iter", "n_start", "n_candidates" and, for the baseline
      medium, "model objective", "biomass objective", "production objective"
    * xax - string - variable to plot on x-axis
    * yax - string - variable to plot on y-axis
    * zax - string - variable to plot on z-axis
    * figname - string - name under which to save the figure
    * n_candidates - int - batch-size; accepted for backward compatibility but not used,
      the batch size is taken from results (see below)
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
      (its medium and objective are overwritten when the baseline is plotted)
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective assigned to MetModel.objective for the baseline simulation;
      if None, results["model objective"] is used
    
    RETURNS
    - 

    RAISES
    * ValueError - if xax, yax or zax is not one of the supported categories
    """

    # make sure that every axis is a legal choice
    options = ["growth rate tensors", "cost tensors", "production tensors"]
    if xax not in options:
        raise ValueError(f"Invalid xax. Expected one of: {options}")
    if yax not in options:
        raise ValueError(f"Invalid yax. Expected one of: {options}")
    if zax not in options:
        raise ValueError(f"Invalid zax. Expected one of: {options}")

    cost_key = "cost tensors"
    cost_scale = 1e-3 # factor applied to every cost value before plotting

    def _scaled(key, values):
        # Out-of-place on purpose: for CPU tensors .numpy() shares memory with the
        # tensor in results, an in-place "*=" would alter the stored costs.
        return values * cost_scale if key == cost_key else values
    
    # extract data from results, scaling costs
    x_np = _scaled(xax, results[xax].cpu().numpy())
    y_np = _scaled(yax, results[yax].cpu().numpy())
    z_np = _scaled(zax, results[zax].cpu().numpy())
    
    # Stack the three categories into a single 2D array (a fresh copy)
    # rows: candidates
    # columns: x-axis, y-axis, z-axis category
    points = np.column_stack([x_np, y_np, z_np])
    n_points = len(points)
    
    # Create figure and add 3D subplot
    fig = plt.figure(figsize = (9, 7))
    ax = fig.add_subplot(projection='3d')
    
    # Define colour mapping - 1 for random initial points, 1 each per batch (n_iter).
    # The run layout is read from results; the fallbacks (100 batches, 50 initial
    # points, 10 candidates per batch) are the values that used to be hard-coded
    # here for the data from 2024-11-19.
    n_batch = results.get("n_iter", 100)
    n_start = results.get("n_start", 50)
    batch_size = results.get("n_candidates", 10)

    # Generate distinct colours and a matching colourmap for the colour bar
    colours = cmaps.bamako.cut(0.05, "right")(np.linspace(0, 1, n_batch + 1))
    cmap = mcolors.ListedColormap(colours)

    # Colour index per data point: 0 for the first n_start points, i + 1 for batch i.
    # Computed from the actual number of points so that a shorter or longer run
    # than n_start + n_batch * batch_size does not raise; surplus points get
    # the colour of the last batch.
    batch_index = np.zeros(n_points, dtype = int)
    n_optimised = max(n_points - n_start, 0)
    batch_index[n_start:] = np.minimum(np.arange(n_optimised) // batch_size + 1, n_batch)
    point_colours = colours[batch_index] # one RGBA row per data point

    # Set boundaries between each batch, from -0.5 to n_batch + 0.5
    boundaries = np.arange(n_batch + 2) - 0.5
    norm = mcolors.BoundaryNorm(boundaries, cmap.N)

    # Scatter plot with custom colours, applying transparence (alpha = 0.8)    
    ax.scatter(
        points[:, 0], 
        points[:, 1], 
        points[:, 2], 
        c = point_colours, alpha = 0.8)
    
    # --- Baseline (original) medium, drawn as a red cross ---
    # only if model, initial medium and costs are given
    if MetModel and initial_medium and initial_costs:
        # Set the initial medium as medium in the model
        MetModel.medium = initial_medium
        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective

        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_cost = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()

        # -1 is used as placeholder when no production objective was optimised
        # (same convention as in the 2D Pareto plots)
        initial_production = -1
        if results["production objective"] is not None:
            initial_production = solution.fluxes[results["production objective"]]
        
        # map initial data to corresponding axes depending on call,
        # scaling the cost like the candidates
        baseline = {
            "growth rate tensors": initial_growth_rate,
            "cost tensors": initial_cost,
            "production tensors": initial_production
        }
        ini_x = _scaled(xax, baseline.get(xax))
        ini_y = _scaled(yax, baseline.get(yax))
        ini_z = _scaled(zax, baseline.get(zax))
        
        # Plot initial point as a red cross in 3D
        ax.scatter(
            ini_x, # x-axis
            ini_y, # y-axis
            ini_z, # z-axis
            color = "red", marker = "x", 
            label = "Original Medium", s = 150, zorder = 50
        )
    
    # --- Labels ---
    axis_labels = {
        "growth rate tensors": "Growth Rate [1/h]",
        "cost tensors": "Medium Cost [£/gDW·h]",
        "production tensors": "Production Rate [mmol/gDW·h]"
    }
    ax.set_xlabel(axis_labels.get(xax), fontsize = 14)
    ax.set_ylabel(axis_labels.get(yax), fontsize = 14)
    ax.set_zlabel(axis_labels.get(zax), fontsize = 14)

    # Add the colour bar (one tick every 5 batches)
    tick_positions = np.arange(0, n_batch + 1, 5)
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    cbar = fig.colorbar(sm, ax=ax, ticks = tick_positions, pad = 0.075)
    cbar.ax.set_title("Iteration", fontsize = 12)
    cbar.ax.tick_params(which='minor', size=0) # turn off minor ticks at colour boundaries
    
    # Display the legend
    ax.legend(fontsize = 12)

    # Save the figure before showing it, so the file is written even if the
    # figure window gets closed or resized during plt.show()
    fig.set_size_inches(9, 7)  # Consistent physical size in inches
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()

In [None]:
# Raises a KeyError with an explicit message when results lacks a required tensor entry
def plot_growth_per_cost(
        results, figname = "figure.png"):
    """
    Plots growth rate per cost (y-axis) against the sample number (x-axis) for each
    candidate medium; samples with a cost of exactly 0 are plotted as 0.
    Each dot is colour-coded according to the batch it resulted from (the initial
    points share colour 0, batch i of the optimisation gets colour i + 1).
    A black step line traces the best growth per cost found so far.
    Saves the figure under "figname" and shows it.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; read entries: "growth rate tensors",
      "cost tensors", "n_iter", "n_start", "n_candidates"
    * figname - string - name under which to save the figure
    
    RETURNS
    - 

    RAISES
    * KeyError - if results has no "growth rate tensors" or no "cost tensors" entry
    """
    # fail early with a readable message if a required tensor is missing
    missing = [key for key in ("growth rate tensors", "cost tensors") if key not in results]
    if missing:
        raise KeyError(f"results is missing the required entries: {missing}")
    
    # extract data from results (growth rate and medium costs)
    growth_rates = results['growth rate tensors'].cpu().numpy()
    medium_costs = results['cost tensors'].cpu().numpy()
    
    # growth per cost; entries with cost == 0 keep the 0 from the pre-filled output array
    growth_costs = np.divide(growth_rates, medium_costs, out = np.zeros_like(growth_rates), where = medium_costs != 0)
    
    # One running number per candidate medium (x-axis)
    n_samples = len(growth_rates)
    iterations = np.arange(1, n_samples + 1)

    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (9, 7))
    
    # Define colour mapping - 1 for random initial points, 1 each per batch (n_iter)
    n_batch = results["n_iter"]
    n_start = results["n_start"]
    n_candidates = results["n_candidates"]
    
    # Generate distinct colours and a matching colourmap for the colour bar
    colours = cmaps.bamako.cut(0.05, "right")(np.linspace(0, 1, n_batch + 1))
    cmap = mcolors.ListedColormap(colours)

    # Colour index per data point: 0 for the first n_start points, i + 1 for batch i.
    # Computed from the actual number of points so that a shorter or longer run
    # than n_start + n_batch * n_candidates does not raise; surplus points get
    # the colour of the last batch.
    batch_index = np.zeros(n_samples, dtype = int)
    n_optimised = max(n_samples - n_start, 0)
    batch_index[n_start:] = np.minimum(np.arange(n_optimised) // n_candidates + 1, n_batch)
    point_colours = colours[batch_index] # one RGBA row per data point
    
    # Set boundaries between each batch, from -0.5 to n_batch + 0.5
    boundaries = np.arange(n_batch + 2) - 0.5
    norm = mcolors.BoundaryNorm(boundaries, cmap.N)

    # Scatter plot with custom colours, applying transparence (alpha = 0.8)
    axes.scatter(
        x = iterations, 
        y = growth_costs, 
        c = point_colours, alpha = 0.8)

    # draw a line along the best so far result (growth/cost);
    # the explicit ">" comparison skips NaN values instead of propagating them
    best_so_far = -np.inf
    best_values = []
    for val in growth_costs:
        if val > best_so_far:
            best_so_far = val
        best_values.append(best_so_far)
    axes.step(
        iterations, best_values, where = "post", color = "black", linestyle = "-",
        linewidth = 3, label = "Best Growth per Cost So Far")
    
    # Add labels and styling
    axes.set_xlabel("Sample Number", fontsize = 14)
    axes.set_ylabel("Growth per Cost [gDW/$10^{-3}£$]", fontsize = 14)
    axes.xaxis.set_tick_params(width = 2, labelsize = 10)
    axes.yaxis.set_tick_params(width = 2, labelsize = 10)
    axes.spines["top"].set_visible(False)
    axes.spines["right"].set_visible(False)
    axes.spines["bottom"].set_linewidth(1.5)
    axes.spines["left"].set_linewidth(1.5)

    # Add the colour bar (one tick every 5 batches)
    tick_positions = np.arange(0, n_batch + 1, 5)
    sm = cm.ScalarMappable(cmap = cmap, norm = norm)
    cbar = fig.colorbar(sm, ax = axes, ticks = tick_positions, pad = 0.01) # pad defines distance
    cbar.ax.set_title("Iteration", fontsize = 10)
    cbar.ax.tick_params(which = "minor", size = 0) # turn off minor ticks at colour boundaries

    # Display the legend so the labelled best-so-far line is actually explained
    axes.legend(fontsize = 12)

    # Save the figure before showing it, so the file is written even if the
    # figure window gets closed or resized during plt.show()
    fig.set_size_inches(9, 7)  # Consistent physical size in inches
    fig.savefig(figname, dpi=300, bbox_inches=None)

    # Show the plot
    plt.show()

In [None]:
def plot_production_per_cost_coloured_by_growth(
    results,
    figname = "figure.png",
    growth_threshold = 0.5,
    production_threshold = 0.01,
    MetModel = None,
    initial_medium = None,
    initial_costs = None,
    model_objective = None
    ):
    """
    Subsets the results to the data points where both growth and production reach a given threshold.
    Plots production (y-axis) against cost (x-axis), with each point coloured by its growth rate.
    When MetModel, initial_medium and initial_costs are all given, the model is optimised on the
    initial medium and the result is added as a diamond marker, coloured on the same
    growth-rate scale as the other points.
    Saves the figure under "figname" and shows it.

    NOTE: MetModel.medium and MetModel.objective are overwritten when the baseline is plotted
    and are not restored afterwards.
    
    PARAMETERS
    * results - dictionary - output of media_BayesOpt; the entries "growth rate tensors",
      "production tensors" and "cost tensors" are read, plus "model objective",
      "biomass objective" and "production objective" when the baseline is plotted
    * figname - string - name under which to save the figure
    * growth_threshold - double - Minimum growth rate threshold (inclusive)
    * production_threshold - double - Minimum production threshold (inclusive)
    * MetModel - cobra model - Metabolic model for simulation & used for optimisation
    * initial_medium - dictionary - Medium for initial simulation
    * initial_costs - dictionary - Costs associated with the initial medium
    * model_objective - objective assigned to MetModel.objective before optimising;
      defaults to results["model objective"] when None

    RETURNS
    -
    """
    # Boolean masks for the conditions
    growth_mask = results["growth rate tensors"] >= growth_threshold
    production_mask = results["production tensors"] >= production_threshold

    # Combined mask for both conditions
    combined_mask = growth_mask & production_mask

    # Find indices where both conditions are satisfied; column 0 holds the sample index
    indices = torch.nonzero(combined_mask, as_tuple=False)
    rows = indices[:, 0]

    # Subset data based on indices
    growth_np = results["growth rate tensors"][rows].cpu().numpy()
    production_np = results["production tensors"][rows].cpu().numpy()
    cost_np = results["cost tensors"][rows].cpu().numpy()
    # scale costs by a factor
    cost_np = cost_np * 1e-3

    cmap = plt.get_cmap("cividis")

    # Plotting (explicit figure/axes handles instead of the implicit pyplot state)
    fig, ax = plt.subplots(figsize = (9, 7))
    scatter = ax.scatter(
        cost_np, production_np, c = growth_np, cmap = cmap, edgecolor = 'k', alpha = 0.9
    )
    cbar = fig.colorbar(scatter, ax = ax, pad = 0.01)
    cbar.set_label("Growth Rate [1/h]", fontsize = 12)
    ax.set_xlabel("Cost [£/gDW·h]", fontsize = 14)
    ax.set_ylabel("Production [mmol/gDW·h]", fontsize = 14)
    ax.grid(True)

    # If MetModel, initial_medium, and initial_costs are provided
    if MetModel and initial_medium and initial_costs:
        # Set the initial medium in the model
        MetModel.medium = initial_medium

        # Set model objective to desired one
        if model_objective is None:
            MetModel.objective = results["model objective"]
        else:
            MetModel.objective = model_objective

        # Run optimization
        solution = MetModel.optimize()
        initial_growth_rate = solution.fluxes[results["biomass objective"]]
        initial_production_rate = solution.fluxes[results["production objective"]]
        costs = calc_cost_tot(initial_costs, initial_medium).cpu().numpy()

        # Colour the baseline on the same scale as the scatter points.
        # The colormap expects values in [0, 1], so the raw growth rate has to go
        # through the scatter's normalisation first; otherwise the marker colour
        # would not correspond to the colour bar.
        if growth_np.size == 0 or growth_np.min() == growth_np.max():
            # No usable colour range (no points, or all growth rates identical):
            # fall back to the middle colour of the map
            initial_color = cmap(0.5)
        else:
            initial_color = cmap(float(scatter.norm(initial_growth_rate)))

        # Plot initial point
        ax.scatter(
            costs * 1e-3, initial_production_rate,
            color = initial_color, marker = "D", label = "Original Medium",
            s = 150, edgecolor = "k", zorder = 3, alpha = 0.9
        )

    # Only draw a legend when something is labelled; an empty legend just triggers a warning
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize = 12)

    fig.savefig(figname, dpi = 300)
    plt.show()
    plt.close(fig)