In [None]:
import numpy as np
import xsimlab as xs
import pandas as pd
import igraph as ig
from scipy.sparse import csgraph
import vmlab

In [None]:
from vmlab.models import vmango
from vmlab.processes.topology import Topology

In [None]:
# Initial tree handed to the model: one row per node, the root has no parent (NaN).
tree = pd.DataFrame(dict(
    parent_id=[np.nan, 0, 1, 1, 1, 1],
    id=list(range(6)),
    topology__is_apical=[1, 1, 1, 0, 0, 0],
))
# Dropping the NaN row removes the root's missing parent, leaving one edge per remaining row.
in_g = ig.Graph.DataFrame(tree.dropna())
ig.plot(
    in_g,
    bbox=(0, 0, 150, 150),
    layout=in_g.layout_reingold_tilford(),
)

In [None]:
@xs.process
class TopologyPruning(Topology):
    """Topology process extended with a simple rule-based pruning step.

    After the inherited ``Topology`` step has run, units matching the pruning
    rule are selected; they and all of their descendants are flagged in
    ``pruned``, the outgoing edges of every flagged unit are removed from the
    adjacency matrix, and the distance matrix and L-string are rebuilt.
    """

    # 1. for units flagged by pruning (selected parents and their descendants), 0. otherwise
    pruned = xs.variable(dims='GU', intent='out')

    def all_descendants(self, parent, children=None):
        """Collect the indices of every descendant of ``parent``.

        Parameters
        ----------
        parent : int
            Row index into ``self.adjacency`` of the unit to start from.
        children : sequence of int, optional
            Indices placed in front of the result. The recursion uses it to
            carry the current child; callers normally leave it unset.

        Returns
        -------
        numpy.ndarray
            int64 array with ``children`` followed by the descendants in
            depth-first order. Always an integer array, even when ``parent``
            has no children, so it is safe to use as an index.
        """
        # A fresh array per call: no shared mutable default, stable int64 dtype.
        found = np.asarray([] if children is None else children, dtype=np.int64)
        for child in np.flatnonzero(self.adjacency[parent, :] == 1.):
            found = np.append(found, self.all_descendants(child, [child]))
        return found

    @xs.runtime(args=('nsteps', 'step_start'))
    def initialize(self, nsteps, step_start):
        """Run the inherited initialization, then start with nothing pruned."""
        super().initialize(nsteps, step_start)
        self.pruned = np.zeros(self.GU.shape, dtype=np.float32)

    @xs.runtime(args=('step', 'step_start', 'nsteps'))
    def run_step(self, step, step_start, nsteps):
        """Run the inherited step, then apply the pruning rule."""
        super().run_step(step, step_start, nsteps)

        # Simple rule: select lateral (non-apical) units, not yet flagged, whose
        # nb_descendants equals 3. Their descendants are cut; the selected unit
        # keeps its incoming edge, so it stays attached to the tree.
        # NOTE(review): the original comment said "2 descendants" while the code
        # tests == 3 - confirm how Topology counts nb_descendants.
        prune = (self.is_apical == 0.) & (self.nb_descendants == 3) & (self.pruned == 0.)
        if np.any(prune):
            # np.any(prune) guarantees at least one subtree, so concatenate is safe.
            subtrees = [self.all_descendants(parent) for parent in np.flatnonzero(prune)]
            removed = np.unique(np.concatenate(subtrees))
            # Flag the selected units too, so the rule does not fire on them again.
            self.pruned[prune] = 1.
            self.pruned[removed] = 1.
            # Rows hold outgoing (parent -> child) edges, as read by all_descendants.
            self.adjacency[self.pruned == 1, :] = 0.
            self.distance = csgraph.shortest_path(csgraph.csgraph_from_dense(self.adjacency)).astype(np.float32)
            # NOTE(review): resets the inherited `bursted` state for every unit -
            # confirm against Topology that this is the intended effect.
            self.bursted[:] = 0.
            # Rebuild the entire topology: derive from the axiom for as many
            # steps as the largest finite shortest-path distance.
            self.lstring = self.lsystem.derive(self.lsystem.axiom, 0, int(np.max(self.distance[np.isfinite(self.distance)])))
    

In [None]:
# Derive a model from vmango in which the 'topology' process is the pruning variant.
vmango_pruned = vmango.update_processes(dict(topology=TopologyPruning))

In [None]:
# Setup for the pruned model: two-year run starting from `tree`, parameters read from vmango.toml.
setup = vmlab.create_setup(
    model=vmango_pruned,
    tree=tree,
    start_date='2003-06-01',
    end_date='2005-06-01',
    setup_toml='vmango.toml',
    input_vars=dict(geometry__interpretation_freq=1),
    output_vars=dict(),
)

In [None]:
# Run the pruned model with the setup built above; geometry=True is passed through to vmlab.run.
vmlab.run(setup, vmango_pruned, geometry=True)

In [None]:
# Reference run (unpruned): the stock vmango model with the same tree, dates and inputs.
setup_unpruned = vmlab.create_setup(
    model=vmango,
    tree=tree,
    start_date='2003-06-01',
    end_date='2005-06-01',
    setup_toml='vmango.toml',
    input_vars=dict(geometry__interpretation_freq=1),
    output_vars=dict(),
)
vmlab.run(setup_unpruned, vmango, geometry=True)