# Plotting functions for the various BayesOpt Implementations

In [2]:
# imports
from matplotlib.cm import ScalarMappable
import torch
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # for 3D plotting
from botorch.utils.multi_objective.pareto import is_non_dominated

In [None]:
# Helper functions to be used across notebooks
%run HelperFunctions_MOBO.ipynb

# same but for .py version
#from HelperFunctions_MOBO import *

In [None]:
# to enable GPU processing: use the first CUDA device (GPU 0) when one is
# available, otherwise fall back to the CPU
if torch.cuda.is_available():
    # with several GPUs, change the ordinal here to select a different device
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# shared tensor options for the notebooks;
# torch.double and torch.float64 are the same dtype in PyTorch
tkwargs = {'device': device, 'dtype': torch.double}

In [None]:
# TODO: Limit choice of options for xax and yax to "legal" ones
# TODO: Fix colour-scaling so that it accounts for batches
def plot_pareto(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    figname = "figure.png",
    MetModel = None,
    M9_medium = None,
    M9_costs = None
    ):
    """
    Plot two objectives of every candidate medium against each other together
    with the Pareto front deduced from the data, then save the figure.

    Each dot is colour-coded by its position in ``results`` ("iteration").
    For the Pareto computation "cost tensors" is minimised and every other
    objective (e.g. growth rate, production rate) is maximised.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; must contain the
      tensors named by ``xax`` and ``yax``. If the M9 reference point is
      requested it must also contain "biomass_objective" and, for a
      production y-axis, "production_objective".
    * xax - string - key of ``results`` plotted on the x-axis
    * yax - string - key of ``results`` plotted on the y-axis
      (log scale when it is "cost tensors")
    * figname - string - name under which to save the figure
    * MetModel - optional model with ``medium``, ``objective`` and
      ``optimize()``; used to evaluate the M9 reference medium
    * M9_medium - optional - M9 medium composition assigned to the model
    * M9_costs - optional - cost data passed to ``calc_cost_tot`` with M9_medium

    RETURNS
    - 

    RAISES
    * ValueError - if ``results[xax]`` contains no candidates
    """

    # extract data from results (e.g. growth rate and medium costs)
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    if len(x_np) == 0:
        raise ValueError("results contain no candidates to plot")

    # Stack both objectives into one 2D array
    # rows: candidates; columns: xax objective, yax objective
    y = np.column_stack([x_np, y_np])

    # one (1-based) iteration number per candidate medium
    iterations = np.arange(1, len(x_np) + 1)

    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (10, 7))

    # get the colormap
    cm = plt.colormaps.get_cmap('viridis')

    # Scatter plot of all points, colour-coded by iteration, slightly transparent
    axes.scatter(y[:, 0], y[:, 1], c = iterations, cmap = cm, alpha = 0.8)
    # Set y-axis to log scale if it contains costs
    if yax == "cost tensors":
        axes.set_yscale("log")

    """
    Pareto Front
    """
    # is_non_dominated treats every column as "to be maximised", so flip the
    # sign of objectives that should be minimised (medium cost) instead of
    # negating values that are later plotted.
    signs = np.array([-1.0 if ax == "cost tensors" else 1.0 for ax in (xax, yax)])
    # sort by the x objective so the front is drawn from left to right
    y_sorted = y[np.argsort(y[:, 0])]

    # non-dominated (Pareto-optimal) points, i.e. optimal trade-offs;
    # moved back to the CPU so the mask can index the numpy array
    is_pareto = is_non_dominated(torch.tensor(y_sorted * signs).to(**tkwargs)).cpu().numpy()
    pareto_points = y_sorted[is_pareto]

    # Plot the Pareto front in the original (un-negated) units
    axes.plot(
        pareto_points[:, 0],
        pareto_points[:, 1],
        label="Pareto Front",
        color="r",
        linewidth=2,
    )

    """
    M9 on Pareto front
    """
    # only when model, M9 medium and M9 costs are all given
    if MetModel is not None and M9_medium is not None and M9_costs is not None:
        biomass_objective = results["biomass_objective"]
        production_objective = results.get("production_objective")
        # the model context reverts the medium/objective changes on exit,
        # so the caller's model is not modified by plotting
        with MetModel:
            MetModel.medium = M9_medium
            MetModel.objective = biomass_objective
            solution = MetModel.optimize()
            M9_growth_rate = solution.fluxes[biomass_objective]
            if yax == "production tensors" and production_objective is not None:
                # Use production rate if specified
                M9_y_value = solution.fluxes[production_objective]
            else:
                # Otherwise, calculate cost
                M9_y_value = calc_cost_tot(M9_costs, M9_medium)

        # Plot M9 point as a black cross
        axes.scatter(
            M9_growth_rate, M9_y_value, 
            color = "black", marker = "x", label = "M9 Medium", s = 100, zorder = 5
        )

    """
    Add labels and titles
    """
    axes.set_xlabel("Growth Rate [1/h]")

    if yax == "cost tensors":
        axes.set_ylabel("Medium Cost")
        axes.set_title("Growth Rate vs Medium Cost with Pareto Front")
    elif yax == "production tensors":
        axes.set_ylabel("Production Rate [1/h]")
        axes.set_title("Growth Rate vs Production Rate with Pareto Front")

    # Normalize the color bar according to iteration
    norm = plt.Normalize(iterations.min(), iterations.max())
    sm = ScalarMappable(norm=norm, cmap=cm)
    sm.set_array([])

    # Add the color bar
    cbar = fig.colorbar(sm, ax=axes)
    cbar.ax.set_title("Iteration")

    # Display the legend
    axes.legend()

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(figname, dpi=fig.dpi)
    plt.show()

In [None]:
# TODO: Limit choice of options for xax, yax and zax to "legal" ones
def plot_pareto_3D(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    zax = "production tensors",
    figname = "figure.png",
    n_candidates = 1,
    MetModel = None,
    M9_medium = None,
    M9_costs = None
    ):
    """
    3D scatter of three objectives for every candidate medium together with a
    Pareto front surface deduced from the data, then save the figure.

    Each dot is colour-coded by its position in ``results`` ("iteration").
    For the Pareto computation "cost tensors" is minimised and every other
    objective (e.g. growth rate, production rate) is maximised.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; must contain the
      tensors named by ``xax``, ``yax`` and ``zax``. If the M9 reference
      point is requested it must also contain "biomass_objective" and
      "production_objective".
    * xax, yax, zax - strings - keys of ``results`` plotted on each axis
    * figname - string - name under which to save the figure
    * n_candidates - int - currently unused
    * MetModel - optional model with ``medium``, ``objective`` and
      ``optimize()``; used to evaluate the M9 reference medium
    * M9_medium - optional - M9 medium composition assigned to the model
    * M9_costs - optional - cost data passed to ``calc_cost_tot`` with M9_medium

    RETURNS
    - 

    RAISES
    * ValueError - if ``results[xax]`` contains no candidates
    """

    # extract the three objectives from results
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    z_np = results[zax].cpu().numpy()
    if len(x_np) == 0:
        raise ValueError("results contain no candidates to plot")

    # rows: candidates; columns: xax, yax, zax objectives
    y = np.column_stack([x_np, y_np, z_np])

    # one (1-based) iteration number per candidate medium
    iterations = np.arange(1, len(x_np) + 1)

    # Create figure and add 3D subplot
    fig = plt.figure(figsize = (12, 7))
    axes = fig.add_subplot(111, projection='3d')

    # get the colormap
    cm = plt.colormaps.get_cmap('viridis')

    # Scatter plot of all points, colour-coded by iteration, slightly transparent
    axes.scatter(y[:, 0], y[:, 1], y[:, 2], c = iterations, cmap = cm, alpha = 0.8)
    # NOTE: no log scaling is applied to the cost axis of the 3D plot

    """
    Pareto Front
    """
    # is_non_dominated treats every column as "to be maximised", so flip the
    # sign of objectives that should be minimised (medium cost) instead of
    # negating values that are later plotted.
    signs = np.array([-1.0 if ax == "cost tensors" else 1.0 for ax in (xax, yax, zax)])
    # non-dominated (Pareto-optimal) points; mask moved to the CPU for numpy indexing
    is_pareto = is_non_dominated(torch.tensor(y * signs).to(**tkwargs)).cpu().numpy()
    pareto_points = y[is_pareto]

    # triangulating a surface needs at least three points
    if len(pareto_points) >= 3:
        axes.plot_trisurf(
            pareto_points[:, 0], pareto_points[:, 1], pareto_points[:, 2],
            color = 'red', alpha=0.5, linewidth=0.2,
        )

    """
    M9 on Pareto front
    """
    # only when model, M9 medium and M9 costs are all given
    if MetModel is not None and M9_medium is not None and M9_costs is not None:
        biomass_objective = results["biomass_objective"]
        # the model context reverts the medium/objective changes on exit,
        # so the caller's model is not modified by plotting
        with MetModel:
            MetModel.medium = M9_medium
            MetModel.objective = biomass_objective
            solution = MetModel.optimize()
            M9_growth_rate = solution.fluxes[biomass_objective]
            M9_production = solution.fluxes[results["production_objective"]]
        M9_cost_total = calc_cost_tot(M9_costs, M9_medium)

        # Plot M9 point as a black cross in 3D
        axes.scatter(
            M9_growth_rate,  # x-axis: growth rate
            M9_cost_total,   # y-axis: medium cost
            M9_production,   # z-axis: production rate
            color="black", marker="x", label="M9 Medium", s=100, zorder=5
        )

    """
    Add labels and titles
    """
    axes.set_xlabel("Growth Rate [1/h]")
    axes.set_ylabel("Medium Cost")
    axes.set_zlabel("Production Rate [1/h]")
    axes.set_title("Growth Rate vs Production Rate vs Costs with Pareto Front")

    # Normalize the color bar according to iteration
    norm = plt.Normalize(iterations.min(), iterations.max())
    sm = ScalarMappable(norm=norm, cmap=cm)
    sm.set_array([])

    # Add the color bar
    cbar = fig.colorbar(sm, ax=axes, shrink=0.5, aspect=5)
    cbar.ax.set_title("Iteration")

    # only the M9 marker carries a label; skip an empty legend
    handles, _ = axes.get_legend_handles_labels()
    if handles:
        axes.legend()

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(figname, dpi=fig.dpi)
    plt.show()

In [None]:
# TODO: Limit choice of options for xax, yax and zax to "legal" ones
def plot_3D(
    results, 
    xax = "growth rate tensors", 
    yax = "cost tensors", 
    zax = "production tensors",
    figname = "figure.png",
    n_candidates = 1,
    MetModel = None,
    M9_medium = None,
    M9_costs = None
    ):
    """
    3D scatter of three objectives for every candidate medium (no Pareto
    front), optionally with the M9 reference medium, then save the figure.

    Each dot is colour-coded by its position in ``results`` ("iteration").

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; must contain the
      tensors named by ``xax``, ``yax`` and ``zax``. If the M9 reference
      point is requested it must also contain "biomass_objective" and
      "production_objective".
    * xax, yax, zax - strings - keys of ``results`` plotted on each axis
    * figname - string - name under which to save the figure
    * n_candidates - int - currently unused
    * MetModel - optional model with ``medium``, ``objective`` and
      ``optimize()``; used to evaluate the M9 reference medium
    * M9_medium - optional - M9 medium composition assigned to the model
    * M9_costs - optional - cost data passed to ``calc_cost_tot`` with M9_medium

    RETURNS
    - 

    RAISES
    * ValueError - if ``results[xax]`` contains no candidates
    """

    # extract the three objectives from results
    x_np = results[xax].cpu().numpy()
    y_np = results[yax].cpu().numpy()
    z_np = results[zax].cpu().numpy()
    if len(x_np) == 0:
        raise ValueError("results contain no candidates to plot")

    # rows: candidates; columns: xax, yax, zax objectives
    y = np.column_stack([x_np, y_np, z_np])

    # one (1-based) iteration number per candidate medium
    iterations = np.arange(1, len(x_np) + 1)

    # Create figure and add 3D subplot
    fig = plt.figure(figsize=(12, 7))
    ax = fig.add_subplot(projection='3d')

    # get the colormap
    cm = plt.colormaps.get_cmap('viridis')

    # Scatter plot of all points, colour-coded by iteration, slightly transparent
    ax.scatter(y[:, 0], y[:, 1], y[:, 2], c = iterations, cmap = cm, alpha = 0.8)
    # NOTE: no log scaling is applied to the cost axis of the 3D plot

    """
    M9 reference point
    """
    # only when model, M9 medium and M9 costs are all given
    if MetModel is not None and M9_medium is not None and M9_costs is not None:
        biomass_objective = results["biomass_objective"]
        # the model context reverts the medium/objective changes on exit,
        # so the caller's model is not modified by plotting
        with MetModel:
            MetModel.medium = M9_medium
            MetModel.objective = biomass_objective
            solution = MetModel.optimize()
            M9_growth_rate = solution.fluxes[biomass_objective]
            M9_production = solution.fluxes[results["production_objective"]]
        M9_cost_total = calc_cost_tot(M9_costs, M9_medium)

        # Plot M9 point as a black cross in 3D
        ax.scatter(
            M9_growth_rate,  # x-axis: growth rate
            M9_cost_total,   # y-axis: medium cost
            M9_production,   # z-axis: production rate
            color="black", marker="x", label="M9 Medium", s=100, zorder=5
        )

    """
    Add labels and titles
    """
    ax.set_xlabel("Growth Rate [1/h]")
    ax.set_ylabel("Medium Cost")
    ax.set_zlabel("Production Rate [1/h]")
    # no Pareto front is drawn by this function, so the title does not claim one
    ax.set_title("Growth Rate vs Production Rate vs Costs")

    # Normalize the color bar according to iteration
    norm = plt.Normalize(iterations.min(), iterations.max())
    sm = ScalarMappable(norm=norm, cmap=cm)
    sm.set_array([])

    # Add the color bar
    cbar = fig.colorbar(sm, ax=ax)
    cbar.ax.set_title("Iteration")

    # only the M9 marker carries a label; skip an empty legend
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(figname, dpi=fig.dpi)
    plt.show()

In [None]:
def plot_growth_per_cost(results, figname = "figure.png", M9_medium = None):
    """
    Plot growth rate per medium cost (y-axis) against iteration (x-axis) for
    each candidate medium, colour-coded by iteration, and save the figure.

    Candidates whose medium cost is exactly zero are plotted with a value of 0
    instead of an infinite ratio.

    PARAMETERS
    * results - dictionary - output of media_BayesOpt; must contain
      "growth rate tensors" and "cost tensors"
    * figname - string - name under which to save the figure
    * M9_medium - currently unused

    RETURNS
    - 

    RAISES
    * KeyError - if ``results`` has no "cost tensors" entry
    * ValueError - if ``results`` contains no candidates
    """
    if "cost tensors" not in results:
        raise KeyError(
            "results has no 'cost tensors' entry; growth per cost cannot be computed"
        )

    # extract data from results (growth rate and medium costs)
    growth_rates = results['growth rate tensors'].cpu().numpy()
    medium_costs = results['cost tensors'].cpu().numpy()
    if len(growth_rates) == 0:
        raise ValueError("results contain no candidates to plot")

    # growth per unit cost; a float output buffer keeps the division valid
    # even for integer inputs, and zero-cost entries stay at 0
    growth_costs = np.divide(
        growth_rates,
        medium_costs,
        out=np.zeros_like(growth_rates, dtype=np.float64),
        where=medium_costs != 0,
    )

    # one (1-based) iteration number per candidate medium
    iterations = np.arange(1, len(growth_costs) + 1)

    # Create the plot with given size
    fig, axes = plt.subplots(1, 1, figsize = (10, 7))

    # get the colormap
    cm = plt.colormaps.get_cmap('viridis')

    # Scatter plot of all points, colour-coded by iteration, slightly transparent
    axes.scatter(x = iterations, y = growth_costs, c = iterations, cmap = cm, alpha = 0.8)

    # Add labels and titles
    axes.set_xlabel("Iteration")
    axes.set_ylabel("Growth [1/h] per Cost")
    axes.set_title("Growth [1/h] per Cost for each tested medium composition")

    # Normalize the color bar according to iteration
    norm = plt.Normalize(iterations.min(), iterations.max())
    sm = ScalarMappable(norm=norm, cmap=cm)
    sm.set_array([])

    # Add the color bar
    cbar = fig.colorbar(sm, ax=axes)
    cbar.ax.set_title("Iteration")

    # Save before showing: some backends release the figure once it is shown
    fig.savefig(figname, dpi=fig.dpi)
    plt.show()