# Jupyter Notebook example for MERRIN

In [None]:
# ==============================================================================
# Imports
# ==============================================================================
from typing import Literal
from pandas import DataFrame

from merrin import MerrinLearner, Observation, MetabolicNetwork

## Inputs

### Mandatory

In [None]:
# ~ Selection of the instance: the name of a sub-directory of `instances/`
#   holding the input files, and the objective handed to the learner
#   (presumably the name of the metabolic objective reaction — confirm)
instance: Literal["toy", "core-regulated", "large-scale"] = "core-regulated"
objective: str = "Growth"

In [None]:
# ~ Files describing the selected instance, all located in
#   `instances/<instance>/`
# Metabolic network (SBML)
sbml: str = f"instances/{instance}/metabolic_network.sbml"
# Prior Knowledge Network (tab-separated triples)
pkn: str = f"instances/{instance}/pkn.txt"
# Observations (JSON)
observations_json: str = f"instances/{instance}/timeseries_kft.json"

### Optional

In [None]:
# ~ Solving modes
# Projection mode, i.e. what gets enumerated:
#   - `network`: every regulatory network compatible with the observations
#   - `node`: for each node, every rule compatible with the observations
#   - `trace`: every class of regulatory networks compatible with the
#              observations (one class per equivalent rFBA trace)
projection_mode: Literal["network", "node", "trace"] = "network"
# When True, only enumerate subset-minimal rules or regulatory networks
subset_minimal_optimisation: bool = False

# ~ Solving parameters
# LP solver to use (default: `glpk`)
lpsolver: Literal["glpk", "gurobi"] = "glpk"
# Number of solutions to enumerate; 0 enumerates all of them
nbsol: int = 0
# Time limit in seconds; -1 means no time limit
timelimit: float = -1
# Maximum number of timesteps that can be added
max_gap: int = 10
# Maximum error rate between observations and predictions
max_error: float = 0.1
# Maximum number of clauses per rule in DNF
max_clause: int = 20

# ~ Optional parameters
# Display the learned rules/BNs at runtime
display: bool = False

## Preprocessing

### Parse the Prior Knowledge Network (PKN) file

In [None]:
# Parse the PKN file: each non-empty line is a tab-separated triple
# `source<TAB>sign<TAB>target`, stored as (source, int(sign), target).
parsed_pkn: list[tuple[str, int, str]] = []
with open(pkn, 'r', encoding='utf-8') as file:
    for lineno, raw_line in enumerate(file, start=1):
        line = raw_line.strip()
        # Skip blank lines (e.g. a trailing newline at the end of the file):
        # they cannot be unpacked into three fields.
        if not line:
            continue
        try:
            u, s, v = line.split('\t')
            parsed_pkn.append((u, int(s), v))
        except ValueError as err:
            # Wrong number of fields or a non-integer sign: report where.
            raise ValueError(
                f'{pkn}:{lineno}: expected `source<TAB>sign<TAB>target`, '
                f'got {line!r}'
            ) from err

### Parse the JSON file describing the observations

In [None]:
# Load the observations from the JSON file (annotated as a list of
# `Observation` objects).
observations: list[Observation] = Observation.load_json(
    observations_json
)

### Parse the SBML

In [None]:
# Build the metabolic network from the SBML file.
mn: MetabolicNetwork = MetabolicNetwork.read_sbml(
    sbml
)

## MERRIN

In [None]:
# Result table, filled by whichever projection-mode cell below runs.
rules_df: DataFrame

# Create the learner and give it the whole instance: metabolic network,
# objective, parsed PKN and observations.
learner: MerrinLearner = MerrinLearner()
learner.load_instance(mn, objective, parsed_pkn, observations)

In [None]:
# ~ Learn all the Boolean networks
if projection_mode == 'network':
    # ~ Enumerate the Boolean networks; each one is a list of
    #   (str, str) pairs turned into one DataFrame row below
    bns: list[list[tuple[str, str]]] = learner.learn(
        subsetmin=subset_minimal_optimisation,
        timelimit=timelimit, max_gap=max_gap, max_error=max_error,
        max_clause=max_clause, lp_solver=lpsolver, display=display,
        nbsol=nbsol,
    )
    # ~ Post-processing: one row per network; cells missing from a network
    #   are filled with '1' (presumably the constant-true rule — confirm)
    rules_df = DataFrame(list(map(dict, bns))).fillna('1')

In [None]:
# ~ Learn all the rules per nodes of the PKN
if projection_mode == 'node':
    # ~ Learn, for each node, the rules compatible with the observations
    rules: dict[str, list[str]] = learner.learn_per_node(
        nbsol=nbsol, display=display, lp_solver=lpsolver, max_clause=max_clause,
        max_error=max_error, max_gap=max_gap, timelimit=timelimit,
        subsetmin=subset_minimal_optimisation
    )
    # ~ Post-processing: format the results into a pandas DataFrame with one
    #   column per node. Nodes can have different numbers of rules, so the
    #   shorter columns are padded with '' to a common length.
    #   `default=0` keeps an empty result from crashing `max()`.
    max_length = max((len(values) for values in rules.values()), default=0)
    padded_rules = {
        col: values + [''] * (max_length - len(values))
        for col, values in rules.items()
    }
    rules_df = DataFrame(padded_rules)

In [None]:
# ~ Learn all the classes of BNs grouped per equivalent rFBA traces
if projection_mode == 'trace':
    # ~ Learn the classes; each class maps a node to the list of its rules
    rules: list[dict[str, list[str]]] = learner.learn_per_trace(
        subsetmin=subset_minimal_optimisation,
        timelimit=timelimit, max_gap=max_gap, max_error=max_error,
        max_clause=max_clause, lp_solver=lpsolver, display=display,
        nbsol=nbsol,
    )
    # ~ Post-processing: one row per class; each cell holds the node's rules,
    #   sorted and joined with ';'
    format_rules: list[dict[str, str]] = [
        {
            node_name: ';'.join(sorted(node_rule_list))
            for node_name, node_rule_list in bn_class.items()
        }
        for bn_class in rules
    ]
    rules_df = DataFrame(format_rules)

## Results

In [None]:
# `rules_df` is only assigned by the cells handling these three modes; fail
# with an explicit message rather than a NameError on `rules_df`.
if projection_mode not in ('network', 'node', 'trace'):
    raise ValueError(
        f"Unknown projection_mode {projection_mode!r}; "
        "expected 'network', 'node' or 'trace'"
    )
# Sort columns and rows by label so the displayed table is stable.
rules_df.sort_index(axis=1).sort_index(axis=0)