# In [None]:
# -*- coding: utf-8 -*-

"""Contains function to perform iMAT"""

from __future__ import absolute_import

import logging
import numbers

import cobra
import six
from cobra.flux_analysis import flux_variability_analysis as fva
from optlang.symbolics import Zero, Add

from driven.data_sets import ExpressionProfile

logger = logging.getLogger(__name__)


# NOTE(review): ``or2min_and2max`` and ``IMATResult`` are not defined in this
# cell; they must be provided elsewhere (presumably driven's normalization and
# result modules) -- confirm before running.
def imat(model, expression_profile, cutoff, epsilon=0.1, condition=None,
         normalization=or2min_and2max, fraction_of_optimum=1.0,
         not_measured_value=None, objective=None):
    """
    Integrative Metabolic Analysis Tool [1].

    Adds binary indicator variables and constraints to ``model`` that reward
    carrying flux of at least ``epsilon`` (forward or reverse) through highly
    expressed reactions and carrying zero flux through lowly expressed
    reactions, then maximizes the number of satisfied indicators. All
    modifications happen inside a ``with model:`` block, so they are reverted
    when the function returns.

    Parameters
    ----------
    model: cobra.Model
        A constraint-based model to perform iMAT on.
    expression_profile: ExpressionProfile
        The expression profile.
    cutoff: 2-tuple of real numbers (low, high)
        The cut-off values for expression values. Reactions whose expression
        is strictly above ``high`` are treated as highly expressed, those
        strictly below ``low`` as lowly expressed; all others are left
        unconstrained. ``low`` must be smaller than ``high``.
    epsilon: float
        Positive flux threshold a highly expressed reaction must reach to be
        counted as active.
    condition: optional
        Condition of ``expression_profile`` to use. Defaults to the first
        entry of ``expression_profile.conditions``.
    normalization: callable
        Passed through to ``ExpressionProfile.to_reaction_dict``.
    fraction_of_optimum: float
        Passed to flux variability analysis, whose minimum/maximum fluxes are
        used as coefficients of the iMAT constraints. Defaults to 1.0, the
        default of ``cobra.flux_analysis.flux_variability_analysis``.
    not_measured_value: float, optional
        Passed through to ``ExpressionProfile.to_reaction_dict`` (presumably
        the value used for unmeasured genes). Defaults to 0.
    objective: optional
        If given, set as the model objective before flux variability analysis
        is run (only for the duration of this call).

    Returns
    -------
    IMATResult
        Built from the solution fluxes, the solution objective value, the
        reaction expression profile, both cutoffs and ``epsilon``.

    Raises
    ------
    TypeError
        If ``model``, ``expression_profile`` or ``cutoff`` has the wrong type.
    ValueError
        If ``cutoff`` does not have exactly two elements or ``low >= high``.

    References
    ----------
    .. [1] Shlomi, Tomer & N Cabili, Moran & Herrgård, Markus & Ø Palsson, Bernhard & Ruppin, Eytan. (2008).
           Network-based prediction of human tissue-specific metabolism.
           Nature biotechnology. 26. 1003-10.
           doi:10.1038/nbt.1487.
    """
    if not isinstance(model, cobra.Model):
        raise TypeError("model must be a cobra.Model, got %r" % (type(model),))
    if not isinstance(expression_profile, ExpressionProfile):
        raise TypeError("expression_profile must be an ExpressionProfile, got %r"
                        % (type(expression_profile),))
    low_cutoff, high_cutoff = _validate_cutoff(cutoff)

    condition = expression_profile.conditions[0] if condition is None else condition
    not_measured_value = 0 if not_measured_value is None else not_measured_value

    reaction_profile = expression_profile.to_reaction_dict(condition, model, not_measured_value, normalization)

    with model:
        prob = model.problem
        if objective is not None:
            model.objective = objective

        # Moderately expressed reactions (and NaN expression values, which
        # fail both comparisons) get no constraints, so they need no FVA.
        highly_expressed = [rxn_id for rxn_id, rxn_exp in six.iteritems(reaction_profile)
                            if rxn_exp > high_cutoff]
        lowly_expressed = [rxn_id for rxn_id, rxn_exp in six.iteritems(reaction_profile)
                           if rxn_exp < low_cutoff]

        fva_res = fva(model, reaction_list=highly_expressed + lowly_expressed,
                      fraction_of_optimum=fraction_of_optimum)

        obj_vars = []
        cons_vars = []
        for rxn_id in highly_expressed:
            new_obj, new_cons = _highly_expressed_cons_vars(
                prob, model.reactions.get_by_id(rxn_id),
                fva_res.at[rxn_id, "minimum"], fva_res.at[rxn_id, "maximum"], epsilon)
            obj_vars.extend(new_obj)
            cons_vars.extend(new_cons)
        for rxn_id in lowly_expressed:
            new_obj, new_cons = _lowly_expressed_cons_vars(
                prob, model.reactions.get_by_id(rxn_id),
                fva_res.at[rxn_id, "minimum"], fva_res.at[rxn_id, "maximum"])
            obj_vars.extend(new_obj)
            cons_vars.extend(new_cons)

        model.add_cons_vars(cons_vars)
        # Maximize the number of reactions whose flux agrees with expression:
        # every indicator variable gets coefficient 1.
        model.objective = prob.Objective(Zero, sloppy=True, direction="max")
        model.objective.set_linear_coefficients({v: 1.0 for v in obj_vars})

        solution = model.optimize()

        return IMATResult(solution.fluxes, solution.objective_value, reaction_profile, low_cutoff, high_cutoff, epsilon)


def _validate_cutoff(cutoff):
    """Unpack ``cutoff`` into ``(low, high)`` after checking it is a valid range."""
    if not isinstance(cutoff, (tuple, list)):
        raise TypeError("cutoff must be a (low, high) tuple, got %r" % (type(cutoff),))
    if len(cutoff) != 2:
        raise ValueError("cutoff must have exactly two values, got %r" % (cutoff,))
    low_cutoff, high_cutoff = cutoff
    for value in (low_cutoff, high_cutoff):
        if not isinstance(value, numbers.Real):
            raise TypeError("cutoff values must be real numbers, got %r" % (value,))
    if not low_cutoff < high_cutoff:
        raise ValueError("low cutoff (%r) must be smaller than high cutoff (%r)"
                         % (low_cutoff, high_cutoff))
    return low_cutoff, high_cutoff


def _highly_expressed_cons_vars(prob, rxn, v_min, v_max, epsilon):
    """Return (objective variables, cons_vars) rewarding |flux| >= epsilon in ``rxn``."""
    rxn_id = rxn.id
    y_pos = prob.Variable("y_%s_pos" % rxn_id, type="binary")
    y_neg = prob.Variable("y_%s_neg" % rxn_id, type="binary")

    # y_pos = 1 turns this into v >= epsilon; y_pos = 0 into v >= v_min.
    pos_const = prob.Constraint(
        rxn.flux_expression + y_pos * (v_min - epsilon),
        lb=v_min, name="pos_highly_%s" % rxn_id)

    # y_neg = 1 turns this into v <= -epsilon; y_neg = 0 into v <= v_max.
    neg_const = prob.Constraint(
        rxn.flux_expression + y_neg * (v_max + epsilon),
        ub=v_max, name="neg_highly_%s" % rxn_id)

    return [y_pos, y_neg], [y_pos, y_neg, pos_const, neg_const]


def _lowly_expressed_cons_vars(prob, rxn, v_min, v_max):
    """Return (objective variables, cons_vars) rewarding zero flux in ``rxn``."""
    rxn_id = rxn.id
    x = prob.Variable("x_%s" % rxn_id, type="binary")

    # x = 1 forces 0 <= v <= 0 (reaction blocked); x = 0 relaxes the pair of
    # constraints to v_min <= v <= v_max.
    upper_const = prob.Constraint(
        (1 - x) * v_max - rxn.flux_expression,
        lb=0, name="x_%s_upper" % rxn_id)

    lower_const = prob.Constraint(
        (1 - x) * v_min - rxn.flux_expression,
        ub=0, name="x_%s_lower" % rxn_id)

    return [x], [x, upper_const, lower_const]