# Multi-objective Bayesian Optimisation

# Imports

In [1]:
# When running in google colab
#pip install cobra

In [2]:
# When running in google colab
#pip install botorch

In [None]:
import torch
import numpy as np

# BayesOpt
from botorch.fit import fit_gpytorch_mll
from botorch.utils.transforms import unnormalize, normalize # for normalising media components
# sampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.sampling import sample_simplex

# ACQUISITION FUNCTION
# for qPAREGO
from botorch.optim.optimize import optimize_acqf_list 
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.multi_objective.pareto import is_non_dominated
# for q(Log)NEHVI
from botorch.optim import optimize_acqf 
from botorch.acquisition.multi_objective.logei import qLogNoisyExpectedHypervolumeImprovement # for qNEHVI

### Helper Functions & Plotting

In [None]:
# Helper and plotting functions shared across notebooks
%run HelperFunctions_MOBO_II.ipynb

In [None]:
# imports for .py version
#from Plotting_MOBO_II import *
#from BayesOpt_MOBO_II import *
#from HelperFunctions_MOBO_II import *

# BayesOpt

## Next Candidate

In [None]:
# qNEHVI reference points can be passed via `ref_point`; the per-objective defaults
# below are domain-knowledge placeholders -> ToDo: revisit (Reference Points)
def find_next_candidates(
        medium_tensors_normalised_stacked,
        growth_tensors,
        cost_tensors,
        production_tensors = None,
        opt_objective = "growth-cost",
        AF_type = "qPAREGO",
        n_candidates = 5,
        ref_point = None
        ):
    """
    Finds the next batch of medium compositions for which to evaluate cost and optimal growth rate
    * initialises the botorch model and marginal log likelihood (mll) via initialise_model
    * fits the model using the mll
    * sets up a SobolQMCNormalSampler to sample from the posterior
    * when qPAREGO is to be used
        * computes the posterior mean at all previously evaluated media
        * for each candidate of the batch
            * samples random weights from the simplex and builds a Chebyshev scalarisation
            * sets best_f to the best scalarised posterior mean, i.e. on the same scale as the
              objective that qLogExpectedImprovement measures improvement of
            * defines a qLogExpectedImprovement acquisition function and appends it to a list
        * finds all candidates by optimising the list of acquisition functions (optimize_acqf_list)
    * when qNEHVI is to be used
        * sets the reference point (user-supplied, or a hard-coded default per opt_objective)
        * initialises a qLogNoisyExpectedHypervolumeImprovement acquisition function
        * finds n_candidates candidates using sequential optimize_acqf

    PARAMETERS
    * medium_tensors_normalised_stacked - tensor - all medium compositions previously evaluated 0-1 normalised,
    stacked as rows (in order); the number of columns is the number of medium components
    * growth_tensors - tensor - corresponding growth rates
    * cost_tensors - tensor - corresponding medium costs; expected to be transformed so that larger is better,
    because the acquisition functions maximise
    * production_tensors - tensor - corresponding production rates
    * opt_objective - string - the (multi-)objective for which to find the optimal medium composition:
    "growth-cost", "growth-production", "production-cost" or "growth-production-cost"
    * AF_type - string - which acquisition function to use ("qPAREGO" or "qNEHVI")
    * n_candidates - integer - how many candidates to find at once
    * ref_point - sequence of floats or None - qNEHVI reference point, one value per objective in the order
    named by opt_objective (presumably the output order of the model built by initialise_model - confirm);
    if None, a hard-coded default for opt_objective is used. Ignored for qPAREGO.

    RETURNS
    * candidates - tensor - n_candidates 0-1 normalised medium compositions to be tested (one per row)

    RAISES
    * ValueError - if AF_type or opt_objective is unknown, or ref_point has the wrong number of entries
    """

    '''validate arguments'''
    # Default qNEHVI reference points, in the units of the objectives passed in.
    # Set from domain knowledge: slightly worse than the current Pareto front estimate,
    # so that it is possible to find data points dominating them.
    default_ref_points = {
        "growth-cost": [0.6, 200.0],
        "growth-production": [0.6, 0.0002],
        "production-cost": [0.0002, 200.0],
        "growth-production-cost": [0.6, 0.0002, 200.0],
    }
    valid_AF_types = {"qPAREGO", "qNEHVI"}
    if AF_type not in valid_AF_types:
        raise ValueError(f"AF_type must be one of {valid_AF_types}, but got '{AF_type}'")
    if opt_objective not in default_ref_points:
        raise ValueError(
            f"opt_objective must be one of {set(default_ref_points)}, but got '{opt_objective}'")
    # one objective per name in opt_objective, e.g. "growth-cost" -> 2
    n_objectives = len(opt_objective.split("-"))

    '''parameters; normalised (0,1) bounds of the medium composition'''
    MC_SAMPLES = 256 # Number of Monte Carlo samples in SobolQMCNormalSampler
    n_components = medium_tensors_normalised_stacked.size()[1] # of medium components
    standard_bounds = torch.tensor([[0.0] * n_components,
                                    [1.0] * n_components]).to(**tkwargs) # normalised bounds for medium composition
    # large values -> slower but possibly better accuracy
    NUM_RESTARTS = 5 # Number of restarts for acquisition function optimisation
    RAW_SAMPLES = 512 # Number of raw samples for initialisation of acquisition optimisation

    '''fit the surrogate model'''
    # initialise GP model and marginal likelihood (mll)
    mll, model = initialise_model(
        medium_tensors_normalised_stacked,
        growth_tensors,
        opt_objective,
        cost_tensors,
        production_tensors)

    fit_gpytorch_mll(mll) # Fit the model using the maximum marginal likelihood
    # Sobol quasi-Monte Carlo sampler for the posterior; fixed seed for reproducibility
    # https://botorch.readthedocs.io/en/latest/sampling.html#botorch.sampling.normal.SobolQMCNormalSampler
    sampler = SobolQMCNormalSampler(sample_shape = torch.Size([MC_SAMPLES]), seed = MC_SAMPLES)

    '''finding the new candidates'''
    if AF_type == "qPAREGO":
        # posterior mean at the training inputs; rows: media, columns: objectives
        with torch.no_grad():
            posterior_mean = model.posterior(medium_tensors_normalised_stacked).mean

        acq_fun_list = [] # one acquisition function per candidate of the batch
        for _ in range(n_candidates):
            # random weights from the simplex -> a different trade-off per candidate
            weights = sample_simplex(n_objectives, **tkwargs).squeeze()
            # Chebyshev scalarisation; min-max normalises each objective using the range of posterior_mean
            scalarization = get_chebyshev_scalarization(weights, posterior_mean)
            # best_f must be on the scale of the scalarised objective that qLogEI improves upon,
            # so evaluate the very same scalarisation on the (model-predicted) training data
            best_f = scalarization(posterior_mean).max().item()

            acq_fun = qLogExpectedImprovement(
                model = model,
                best_f = best_f, # best scalarised objective value so far
                sampler = sampler, # SobolQMCNormalSampler
                objective = GenericMCObjective(scalarization), # Chebyshev combination of objectives
            )
            acq_fun_list.append(acq_fun)

        candidates, _ = optimize_acqf_list(
            acq_function_list = acq_fun_list,  # List of acquisition functions to optimise
            bounds = standard_bounds, # The normalised bounds for optimisation
            num_restarts = NUM_RESTARTS, # Number of restarts for optimisation
            raw_samples = RAW_SAMPLES, # Number of raw samples for the initialisation heuristic
            options = {"batch_limit": 10, "maxiter": 200,} # Options for acquisition function optimisation
        )

    else: # "qNEHVI" (validated above)
        if ref_point is None:
            ref_point = default_ref_points[opt_objective]
        # accept any 1-D sequence (list, tuple, tensor); botorch expects a list of floats
        ref_point = [float(value) for value in ref_point]
        if len(ref_point) != n_objectives:
            raise ValueError(
                f"ref_point must have {n_objectives} entries for '{opt_objective}', but got {len(ref_point)}")

        acq_func = qLogNoisyExpectedHypervolumeImprovement(
            model = model,
            ref_point = ref_point,
            X_baseline = medium_tensors_normalised_stacked, # previously evaluated media (normalised)
            prune_baseline = True,  # prune baseline points with estimated zero probability of being Pareto optimal
            sampler = sampler
            )

        candidates, _ = optimize_acqf(
            acq_function = acq_func, # qNEHVI
            bounds = standard_bounds, # The normalised bounds for optimisation
            q = n_candidates,
            num_restarts = NUM_RESTARTS, # Number of restarts for optimisation
            raw_samples = RAW_SAMPLES, # used for initialisation heuristic
            options = {"batch_limit": 20, "maxiter": 200}, # Options for acquisition function optimisation
            sequential = True,
            )

    # Sanity check (debugging): model.posterior(candidates).mean shows what the
    # model predicts for the selected candidates.
    return candidates # candidate_tensor_normalised

## Main

In [None]:
# TODO: Use the dedicated function (for medium conversion)
# TODO: check that model.medium and costs have the same number of entries
def _run_FBA_on_medium(MetModel, medium, opt_objective, biomass_objective, production_objective):
    """
    Assigns `medium` to MetModel (in place), runs FBA and extracts the growth and production fluxes.

    Some medium compositions make FBA return NaN or negative fluxes; those are set to 0
    so they do not break the optimisation. For opt_objective == "growth-cost" no production
    flux is read and -1 is returned as a placeholder.

    RETURNS
    * (growth, production) - the two flux values
    """
    MetModel.medium = medium # reassign medium
    solution = MetModel.optimize() # perform FBA
    # TODO: find a solution for when biomass_objective = None
    growth = solution.fluxes[biomass_objective]
    if np.isnan(growth) or growth < 0:
        growth = 0

    if opt_objective == "growth-cost":
        production = -1 # placeholder: production is not an objective here
    else:
        production = solution.fluxes[production_objective]
        if np.isnan(production) or production < 0:
            production = 0
    return growth, production


def _medium_to_normalised_tensor(medium, medium_keys, bounds_tensors_stacked):
    """Converts a medium dictionary to a tensor ordered by medium_keys and 0-1 normalises it with the bounds."""
    medium_tensor = torch.tensor([medium[key] for key in medium_keys], dtype=torch.double).to(**tkwargs)
    return normalize(medium_tensor, bounds_tensors_stacked)


def media_BayesOpt(
        MetModel,
        medium = None,
        bounds = None,
        costs = None,
        opt_objective = "growth-cost",
        biomass_objective = None,
        production_objective = None,
        AF_type = "qPAREGO",
        n_start = 5,
        n_iter = 50,
        n_candidates = 5,
        model_objective = None
        ):
    """
    Performs medium optimisation for various objectives: trade-off between
    * growth rate and medium cost
    * growth rate and production rate
    * production rate and medium cost
    * growth rate, production rate and medium cost

    1. Validates AF_type and opt_objective
    2. Sets default values for medium, bounds and costs if not provided by the user
    3. Calls generate_initial_data(args) once to create n_start random initial data points
    4. Performs n_iter iterations, each of which
        1. finds a batch of n_candidates new media by calling find_next_candidates(args)
        2. evaluates each new medium for cost, growth rate (and production rate) via FBA
        3. appends all values to the history
    5. Marks the evaluated media that lie on the Pareto front of the chosen objectives
    In total n_start + n_iter * n_candidates media are evaluated.

    Side effects: MetModel.medium is overwritten during evaluation (it keeps the last evaluated
    candidate afterwards), and MetModel.objective is set when model_objective is given.

    PARAMETERS:
    * MetModel - COBRApy model - the metabolic model to be evaluated
    * medium - dictionary - the medium composition of that model; if not provided defaults to MetModel.medium
    * bounds - dictionary - (lower, upper) bounds for the values the medium components are allowed to take,
    determines the search space; if not provided defaults to 0 and the current medium value
    * costs - dictionary - the (monetary) cost of each component; if not provided defaults to unit costs
    * opt_objective - string - indicates what is to be optimised
    * biomass_objective - string - the name of the biomass reaction of the chosen model
    * production_objective - string - the name of the producing reaction to be maximised of the chosen model
    * AF_type - string - which acquisition function to use (qPAREGO or qNEHVI)
    * n_start - integer - how many random media compositions are created to set up the BayesOpt
    * n_iter - integer - how many batches of candidate media should be found and evaluated
    * n_candidates - integer - batch size, i.e. how many candidates to find at once
    * model_objective - COBRApy objective - what is set as the objective of the modelled organism, used in FBA

    RETURNS:
    A dictionary containing
    * "medium list" - a list of all evaluated medium compositions
    * "medium component bounds" - a dictionary with the upper and lower bounds of each medium components
    * "medium component costs" - a dictionary with the cost of each medium component
    * "growth rate tensors" - a tensor with corresponding growth rates
    * "cost tensors" - a tensor with corresponding total medium costs
    * "production tensors" - a tensor with corresponding production rates (-1 placeholders for "growth-cost")
    * "is pareto" - boolean tensor marking the evaluated media that are non-dominated for opt_objective
    * "optimisation objective" - the objective with which the algorithm was run
    * "biomass objective" - the biomass function to be optimised
    * "production objective" - the production flux to be optimised
    * "model objective" - the COBRApy objective of the model that was used
    * "AF_type" - the acquisition function that was used
    * "n_start" - number of random start points
    * "n_iter" - number of iterations
    * "n_candidates" - batch size

    RAISES:
    * ValueError - if AF_type or opt_objective is unknown, or bounds lack a medium component
    """

    # TEST VALIDITY OF ARGUMENTS
    valid_AF_types = {"qPAREGO", "qNEHVI"}
    if AF_type not in valid_AF_types:
        raise ValueError(f"AF_type must be one of {valid_AF_types}, but got '{AF_type}'")

    valid_opt_objective = {"growth-cost", "growth-production", "production-cost", "growth-production-cost"}
    if opt_objective not in valid_opt_objective:
        raise ValueError(f"opt_objective must be one of {valid_opt_objective}, but got '{opt_objective}'")

    # INITIALISE: default values for medium, boundaries and costs
    if medium is None:
        medium = MetModel.medium  # Default medium to model.medium if not provided
    if bounds is None:
        # lower limit 0, upper limit the value in medium
        bounds = {key: (0, medium[key]) for key in medium.keys()}
    if costs is None:
        # unit costs if no costs are provided
        costs = {key: 1 for key in medium.keys()}

    # if a model_objective is given, set it
    if model_objective:
        MetModel.objective = model_objective

    # GET RANDOM INITIAL DATA POINTS (parameters and corresponding cost + growth rate)
    initial_para, initial_growth, initial_production, initial_cost = generate_initial_data(
        MetModel, medium, bounds, costs,
        n_samples = n_start, opt_objective = opt_objective,
        biomass_objective = biomass_objective, production_objective = production_objective)

    medium_list = initial_para # list of dictionaries
    # The key order of the last initial medium fixes the column order of every medium tensor,
    # of the bounds tensor and of the candidate dictionaries built from tensors.
    medium_keys = list(medium_list[-1].keys())
    growth_tensors = initial_growth
    production_tensors = initial_production
    cost_tensors = initial_cost # tensor

    # CONVERT BOUNDS AND MEDIUM_LIST TO TENSORS
    missing_bounds = [key for key in medium_keys if key not in bounds]
    if missing_bounds:
        raise ValueError(f"bounds are missing for medium components: {missing_bounds}")
    # Look bounds up by key (not by dict order), so that column j of the bounds belongs
    # to the same medium component as column j of every medium tensor.
    bounds_tensor = torch.tensor([bounds[key] for key in medium_keys], dtype=torch.double).to(**tkwargs) # [x, 2]
    # Stack the lower and upper bounds to the [2, x] format expected by (un)normalize
    bounds_tensors_stacked = torch.stack([bounds_tensor[:, 0], bounds_tensor[:, 1]], dim=0)

    medium_tensors_normalised = [
        _medium_to_normalised_tensor(medium_m, medium_keys, bounds_tensors_stacked)
        for medium_m in medium_list]

    # MAIN LOOP
    for i in range(n_iter):
        # Stack the list of normalised medium tensors into a single [n, x] tensor
        medium_tensors_normalised_stacked = torch.stack(medium_tensors_normalised, dim = 0)

        # Costs should be minimised but the acquisition functions maximise,
        # so pass a transformed cost for which larger is better.
        if AF_type == "qPAREGO":
            # min-max normalised cost and production (growth is passed unnormalised)
            cost_objective = 1 - normalise_1Dtensors(cost_tensors)
            production_objective_values = normalise_1Dtensors(production_tensors)
        else: # "qNEHVI": raw units (the default reference points are presumably set in these units)
            cost_objective = cost_tensors.max() - cost_tensors
            production_objective_values = production_tensors

        candidates_tensor_normalised = find_next_candidates(
            medium_tensors_normalised_stacked,
            growth_tensors,
            cost_objective,
            production_tensors = production_objective_values,
            opt_objective = opt_objective,
            AF_type = AF_type,
            n_candidates = n_candidates
            )

        # if n_candidates > 1, each candidate needs to be evaluated individually
        # TODO: parallelise?
        for candidate_tensor_normalised in candidates_tensor_normalised:
            # unnormalise new candidate and convert back to dictionary
            candidate_tensor_unnormalised = unnormalize(candidate_tensor_normalised, bounds_tensors_stacked)
            candidate_medium = convert_to_dict(candidate_tensor_unnormalised, medium_keys)

            # for the new medium compute cost, growth and production
            cost_tot = calc_cost_tot(costs, candidate_medium) # tensor
            FBA_growth, FBA_production = _run_FBA_on_medium(
                MetModel, candidate_medium, opt_objective, biomass_objective, production_objective)

            # APPEND RESULTS (all 1D tensors are concatenated along dimension 0)
            medium_list.append(candidate_medium)
            medium_tensors_normalised.append(candidate_tensor_normalised)
            FBA_growth_tensor = torch.tensor([FBA_growth], dtype=torch.double).to(**tkwargs)
            growth_tensors = torch.cat((growth_tensors, FBA_growth_tensor), dim = 0)
            FBA_production_tensor = torch.tensor([FBA_production], dtype = torch.double).to(**tkwargs)
            production_tensors = torch.cat((production_tensors, FBA_production_tensor), dim = 0)
            cost_tensors = torch.cat((cost_tensors, cost_tot), dim = 0)

        if ((i+1)%10 == 0):
            print("Iteration:\t", i+1)

    # FIND POINTS ON PARETO FRONT
    # is_non_dominated assumes maximisation -> negate costs
    objective_columns = {
        "growth": growth_tensors,
        "production": production_tensors,
        "cost": cost_tensors*(-1),
    }
    # rows: evaluated media; columns: objectives in the order named by opt_objective
    y = torch.stack([objective_columns[name] for name in opt_objective.split("-")], dim = 1)
    # non-dominated (Pareto front) points, i.e. optimal trade-offs
    is_pareto = is_non_dominated(y.to(**tkwargs))

    return {
        "medium list" : medium_list,
        "medium component bounds" : bounds,
        "medium component costs" : costs,
        "growth rate tensors" : growth_tensors,
        "production tensors" : production_tensors,
        "cost tensors" : cost_tensors,
        "is pareto" : is_pareto,
        "optimisation objective" : opt_objective,
        "biomass objective" : biomass_objective,
        "production objective" : production_objective,
        "model objective" : model_objective,
        "AF_type" : AF_type,
        "n_start" : n_start,
        "n_iter" : n_iter,
        "n_candidates" : n_candidates
        }
