In [None]:
import pydeb
import numpy

In [None]:
# Obtain the Catalogue of Life identifier for the taxon of interest.
# get_ids returns multiple matches if the provided name is not unique; the identifier
# chosen from its output is hard-coded in the next cell.
pydeb.infer.get_ids('Artemia urmiana')

In [None]:
# Obtain the prior distribution of DEB parameters through phylogenetic inference
import os

# Location of the offline copy of the DEB parameter database used for the inference.
# Set the PYDEB_OFFLINE_DB environment variable to point elsewhere instead of editing
# this machine-specific default path.
OFFLINE_DB = os.environ.get('PYDEB_OFFLINE_DB', 'C:/Users/jornb/OneDrive/Code/debweb/data/20200514')

# Catalogue of Life identifier of the taxon (presumably the id reported by get_ids above).
taxon = pydeb.infer.Taxon.from_col_id('eeabc9e772823a64ab9af8110d03f4f6')
prior = taxon.infer_parameters(offline_db=OFFLINE_DB, add_del_M=True)

In [None]:
import yaml

# Observations from AghVSta2008b: per-salinity summary statistics (*_stats) and
# per-salinity time series of survival and length.
# YAML files are Unicode text (normally UTF-8); pass the encoding explicitly so they are
# decoded identically regardless of the platform's locale (e.g. cp1252 on Windows).
with open('obs/AghVSta2008b_stats.yaml', encoding='utf-8') as f:
    AghVSta2008b_stats = yaml.safe_load(f)
with open('obs/AghVSta2008b.yaml', encoding='utf-8') as f:
    AghVSta2008b = yaml.safe_load(f)

# Map each observed summary statistic to the pydeb expression that predicts it.
# The 1e-14 terms guard against division by zero when S_p evaluates to 0.
obs2target = {
    'total_offspring_per_female': 'N_R / (S_p + 1e-14)',
    'prereproductive_period': 'a_p - a_b',
    'lifespan': "a_p - a_b + (a_m - state_at_time(a_p)['a']) / (S_p + 1e-14)",
}


In [None]:
# Body temperature to use for prior information and simulations.
# Taken from the first summary-statistics record. NOTE(review): this assumes every
# experiment ran at that temperature; taxon.typical_temperature is the alternative.
T = AghVSta2008b_stats[0]['temperature']  # alternative: taxon.typical_temperature

sampler = pydeb.calibrate.MCMCSampler()
# Register the phylogenetically inferred prior: names, means, covariance and the
# inverse transforms mapping sampled values back to parameter values.
# NOTE(review): the order of add_parameters calls presumably fixes the parameter order
# returned by get_combined_prior(), which the optimisation cell zips against — keep it stable.
sampler.likelihood.add_parameters(prior.names, prior.mean, prior.cov, prior.inverse_transforms)

ip_M = prior.names.index('p_M')
# Extra parameters of SaltDependentModel.
# Salinity above which p_T starts to change. Presumably mean 100 and variance 100
# (sd 10, untransformed) — confirm the argument meaning against pydeb.
sampler.likelihood.add_parameters('salt_threshold', 100, 100)
# p_T at salinity 300 and at/below the threshold. Both reuse the prior mean and variance
# of p_M, with numpy.exp as inverse transform (so the prior presumably applies to log p_T).
sampler.likelihood.add_parameters('p_T_at_s300', prior.mean[ip_M], prior.cov[ip_M, ip_M], numpy.exp)
sampler.likelihood.add_parameters('p_T_ref', prior.mean[ip_M], prior.cov[ip_M, ip_M], numpy.exp)

class SaltDependentModel(pydeb.Model):
    """DEB model whose p_T depends piecewise-linearly on salinity.

    Every parameter listed in the global ``prior.names`` is copied onto the model
    unchanged. ``p_T`` equals ``p_T_ref`` up to ``salt_threshold``. Above the threshold
    it changes linearly with salinity, reaching ``p_T_at_s300`` at salinity 300.

    Args:
        salinity: salinity of the environment.
        **params: values for every name in ``prior.names`` plus ``p_T_ref``,
            ``p_T_at_s300`` and ``salt_threshold``; any other keys are ignored.
    """

    def __init__(self, salinity, **params):
        super().__init__()
        for parname in prior.names:
            setattr(self, parname, params[parname])

        reference = params['p_T_ref']
        at_s300 = params['p_T_at_s300']
        threshold = params['salt_threshold']
        # Salinity in excess of the threshold; zero (p_T == p_T_ref) at or below it.
        excess = max(salinity - threshold, 0.)
        self.p_T = reference + excess * (at_s300 - reference) / (300. - threshold)

class SalinityDependent(pydeb.calibrate.likelihood.Component):
    """Likelihood component grouping all observations made at one salinity.

    On each evaluation it builds a SaltDependentModel for this salinity from the
    proposed parameter values. It stores that model in the caller's ``values`` under
    ``'model_<salinity>'`` (later cells read these models back from the samples). It then
    evaluates the child components with the model available under the key ``'model'``.
    """

    def __init__(self, salinity):
        super().__init__()
        self.salt = salinity

    def evaluate(self, values) -> 'float | None':
        """Return the log-likelihood contribution, or None if the model is invalid.

        Side effect: adds ``values['model_%i' % salinity]``, even for invalid models.
        NOTE(review): '%i' truncates non-integer salinities, so salinities differing only
        in their fractional part would share a key.
        """
        model = SaltDependentModel(self.salt, **values)
        model.initialize()
        # Stored on the caller's dict on purpose: samples / opt_sample carry these models.
        values['model_%i' % self.salt] = model
        if not model.valid:
            # None tells the caller the parameter set is invalid
            # (the optimisation objective maps it to +inf).
            return None
        # `|` builds a new dict, so 'model' is visible to the children only and is not
        # added to the caller's values.
        return super().evaluate(values | {'model': model})

# One likelihood component per salinity. Components are created lazily and shared
# between the summary statistics and the time series observed at that salinity.
salt2comp = {}
def get_component(salinity, temperature=None):
    """Return the likelihood component for `salinity`, creating it on first use.

    Args:
        salinity: salinity of the experiment.
        temperature: temperature for the ultimate-length constraint attached to newly
            created components with salinity < 100. Defaults to the global simulation
            temperature ``T``. Pass it explicitly rather than relying on global state.

    Components with salinity >= 400 are created but not added to the sampler's
    likelihood, so they do not constrain the calibration.
    NOTE(review): samples then carry no 'model_<salinity>' entry for them.
    """
    if salinity not in salt2comp:
        if temperature is None:
            temperature = T
        salt2comp[salinity] = component = SalinityDependent(salinity=salinity)
        if salinity < 100:
            # Constrain ultimate structural length: L_i/del_M = 1.5 (sd 0.1).
            component.add_child(pydeb.calibrate.likelihood.ImpliedProperty('L_i/del_M', 1.5, 0.1, temperature=temperature))
        if salinity < 400:
            sampler.likelihood.add_child(component)
    return salt2comp[salinity]

for info in AghVSta2008b_stats:
    comp = get_component(info['salinity'], temperature=info['temperature'])

    # Add observed life history traits (total offspring per female, prereproductive
    # period, lifespan), each with its observed standard deviation.
    for obsname, target in obs2target.items():
        comp.add_child(pydeb.calibrate.likelihood.ImpliedProperty(target, info[obsname], info['%s_sd' % obsname], temperature=info['temperature']))

    # Constrain structural length at birth: L_b/del_M = 0.05 (sd 0.01).
    comp.add_child(pydeb.calibrate.likelihood.ImpliedProperty('L_b/del_M', 0.05, 0.01, temperature=info['temperature']))

for info in AghVSta2008b:
    comp = get_component(info['salinity'], temperature=info['temperature'])

    # Time series of survival and length. The first observation (index 0) is skipped, and
    # time is presumably measured from birth (offset='a_b').
    ts = pydeb.calibrate.likelihood.TimeSeries(info['time'][1:], temperature=info['temperature'], offset='a_b')
    if info['salinity'] < 1200:
        ts.add_series('S', info['survival'][1:], sd=info['survival_sd'][1:])
    # Observed length is scaled by 0.1 to compare with L/del_M (the plot cell scales by 10
    # to show mm, so this is presumably a mm -> cm conversion).
    ts.add_series('L/del_M', 0.1 * numpy.array(info['length'])[1:], sd=0.1 * numpy.array(info['length_sd'])[1:])
    comp.add_child(ts)

# Run the MCMC sampler: 100000 retained samples after nburn=100000 burn-in iterations
# (NOTE(review): meaning of the positional count inferred from the `nburn` keyword — confirm).
# No random seed is set in this notebook, so the samples differ between runs.
bar = pydeb.calibrate.ProgressBar()
display(bar.widget)
samples = sampler.get_samples(100000, progress_reporter=bar, nburn=100000)

In [None]:
# Median and 95% credible interval (2.5-97.5 percentiles) of each calibrated parameter.
for name in sampler.likelihood.get_combined_prior()[0]:
    values = [sample[name] for sample in samples]
    perc025, perc500, perc975 = numpy.percentile(values, (2.5, 50, 97.5))
    print('%s = %.4g (%.4g - %.4g)' % (name, perc500, perc025, perc975))

# Sample with the highest log-likelihood; it is the optimiser's starting point below.
# best_sample is always determined afresh, so a value left over from an earlier
# execution is never reused. NaN and -inf likelihoods are ignored. Ties keep the first
# sample, as max() only replaces on a strictly larger key.
candidates = [sample for sample in samples if sample['lnl'] > -numpy.inf]
if not candidates:
    raise ValueError('No sample has a finite log-likelihood; cannot select a best sample.')
best_sample = max(candidates, key=lambda sample: sample['lnl'])
best_lnl = best_sample['lnl']
print('Ln likelihood = %s' % best_lnl)

import scipy.optimize
# Refine the best MCMC sample into a maximum-likelihood estimate with Nelder-Mead.
parnames = sampler.likelihood.get_combined_prior()[0]
def f(x):
    """Objective for scipy.optimize.minimize: negative log-likelihood of `x`.

    `x` is ordered like `parnames`. An invalid model (evaluate returns None) maps to +inf.
    """
    lnl = sampler.likelihood.evaluate(dict(zip(parnames, x)))
    return numpy.inf if lnl is None else -lnl
# NOTE(review): starts from best_sample['mu'] rather than the top-level keys read in the
# summary loop; presumably these are the parameters in the space the likelihood
# evaluates (before inverse transforms). Confirm against pydeb's sample format.
opt = scipy.optimize.minimize(f, [best_sample['mu'][n] for n in parnames], method='Nelder-Mead')
opt_sample = dict(zip(parnames, opt['x']))
# Re-evaluating at the optimum presumably also stores the per-salinity models in
# opt_sample (SalinityDependent.evaluate adds 'model_<salinity>' keys to its values);
# the plotting cell reads opt_sample['model_<salinity>'].
opt_lnl = sampler.likelihood.evaluate(opt_sample)
print(opt_sample, opt_lnl) 


In [None]:
# Simulate the full posterior ensemble once per observed salinity. Every output is
# summarised by its 2.5/25/50/75/97.5 percentiles across samples; the time axis 't' is
# kept as-is.
bar = pydeb.calibrate.ProgressBar()
display(bar.widget)
outputs = ['L/del_M', 'S', 'R', 'cumR']
salt2result = {}
for info in AghVSta2008b:
    salt = int(info['salinity'])
    if salt in salt2result:
        continue  # this salinity was already simulated
    ensemble = [sample['model_%i' % salt] for sample in samples]
    r = pydeb.simulate_ensemble(ensemble, outputs, T=T, t_end=90., progress_reporter=bar, t_offset='a_b')
    salt2result[salt] = {
        key: values if key == 't' else numpy.percentile(values, [2.5, 25, 50, 75, 97.5], axis=0)
        for key, values in r.items()
    }

In [None]:
from matplotlib import pyplot

def add_series(ax, result, name, title=None, scale_factor=1., color='C0', label=None, opt_res=None):
    """Draw the percentile envelope of one simulated output onto `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        result: mapping holding the time axis under 't' and, under `name`, five rows
            with the 2.5, 25, 50, 75 and 97.5 percentiles.
        name: output to plot; also used to build the default axis title.
        title: y-axis label; when None it is built from pydeb.long_names and pydeb.units.
        scale_factor: multiplier applied to every percentile row.
        color: colour of the bands, percentile lines and dashed optimum line.
        label: legend label attached to the median line.
        opt_res: optional (time, values) pair, drawn as a black line overlaid with a
            dashed line in `color`.
    """
    if title is None:
        title = '%s (%s)' % (pydeb.long_names.get(name, name), pydeb.units.get(name, '?'))
    lo95, lo50, median, hi50, hi95 = [row * scale_factor for row in result[name]]
    time = numpy.array(result['t'])

    # Shaded outer (2.5-97.5) and inner (25-75) bands.
    for lower, upper in ((lo95, hi95), (lo50, hi50)):
        ax.fill_between(time, lower, upper, alpha=0.3, color=color)
    # Thin outlines along the band edges, then the median on top.
    for edge in (lo95, hi95, lo50, hi50):
        ax.plot(time, edge, '-', color=color, lw=.2)
    ax.plot(time, median, '-', color=color, label=label)

    if opt_res is not None:
        for linestyle, linecolor in (('-', 'k'), ('--', color)):
            ax.plot(opt_res[0], opt_res[1], linestyle, color=linecolor)

    ax.set_ylabel(title)
    ax.grid(True)
    ax.set_xlim(0, 90)

def create_plot(ax, name, title=None, scale_factor=1., obs_name=None):
    """Plot output `name` for every simulated salinity onto `ax`.

    For each salinity in `salt2result` (ascending, one colour per salinity) this draws:
    - the posterior percentile envelope;
    - the trajectory of the optimised model ``opt_sample['model_<salinity>']``;
    - if `obs_name` is given, the observations of the first `AghVSta2008b` record at that
      salinity, with their ``<obs_name>_sd`` error bars.

    Args:
        ax: matplotlib Axes to draw on.
        name: pydeb output/expression to plot, e.g. 'L/del_M'.
        title: y-axis label; defaults to pydeb's long name and unit for `name`.
        scale_factor: multiplier applied to model output before plotting.
        obs_name: key of the observed series in the AghVSta2008b records, or None.

    Raises:
        KeyError: if `obs_name` is given and no observation record exists for a salinity
            (previously the last record was silently plotted instead).
    """
    # First observation record per (integer) salinity.
    salt2obs = {}
    for record in AghVSta2008b:
        salt2obs.setdefault(int(record['salinity']), record)

    for i, salt in enumerate(sorted(salt2result)):
        color = 'C%i' % i

        # Simulate the optimised model up to 90 d after birth (dt = 0.1) at temperature T,
        # shifting time so that 0 corresponds to birth (a_b).
        m = opt_sample['model_%i' % salt]
        c_T = m.get_temperature_correction(T)
        a_b = m.evaluate('a_b', c_T=c_T)
        dt = 0.1
        r = m.simulate(int((90 + a_b) / dt), dt, c_T=c_T)
        opt_res = (r['t'] - a_b, scale_factor * m.evaluate(name, c_T=c_T, locals=r))

        add_series(ax, salt2result[salt], name, title, scale_factor=scale_factor, color=color, label='S=%i' % salt, opt_res=opt_res)
        if obs_name is not None:
            info = salt2obs[salt]
            ax.errorbar(info['time'], info[obs_name], info['%s_sd' % obs_name], fmt='ow', mec=color, ecolor=color)
    ax.legend()


# Four stacked panels sharing the time axis: growth, survival, reproduction rate and
# cumulative reproductive output, for each simulated salinity.
fig, (axgrowth, axsurv, axrep, axcumr) = pyplot.subplots(nrows=4, figsize=(8, 15), sharex=True)
create_plot(axgrowth, 'L/del_M', 'length (mm)', scale_factor=10, obs_name='length')
create_plot(axsurv, 'S', 'survival (1)', obs_name='survival')
create_plot(axrep, 'R', 'reproduction rate (# d-1)')
create_plot(axcumr, 'cumR', 'lifetime reproductive output (#)')
axcumr.set_xlabel('time since hatching (d)');


In [None]:
# Posterior predictions of each target along a salinity gradient (0-300, 100 points)
# at temperature T, for every sample.
# Ordered, de-duplicated list rather than a set: iteration order of a set of strings
# varies between Python processes (hash randomisation). That made the order of the
# per-target figures below non-reproducible.
targets = list(dict.fromkeys([*obs2target.values(), 'N_R', 'a_m']))
bar = pydeb.calibrate.ProgressBar()
display(bar.widget)
salts = numpy.linspace(0, 300, 100)
# Value recorded for a target when the model is invalid: 0 by default, +inf for the
# prereproductive period.
target2default = {'a_p - a_b': numpy.inf}
target2values = {target: numpy.zeros((len(samples), salts.size)) for target in targets}
for i, sample in enumerate(samples):
    bar(i / len(samples), '')
    for isalt, salt in enumerate(salts):
        model = SaltDependentModel(salt, **sample)
        model.initialize()
        if model.valid:
            # Loop-invariant for this model: compute once instead of once per target.
            c_T = model.get_temperature_correction(T)
            for target in targets:
                target2values[target][i, isalt] = model.evaluate(target, c_T=c_T)
        else:
            for target in targets:
                target2values[target][i, isalt] = target2default.get(target, 0.)

In [None]:
# Predictions of the same targets along the salinity gradient for the optimised
# parameter set; invalid models get the same defaults as in the posterior sweep.
target2optvalues = {target: numpy.zeros((salts.size,)) for target in targets}
for isalt, salt in enumerate(salts):
    model = SaltDependentModel(salt, **opt_sample)
    model.initialize()
    if model.valid:
        # Loop-invariant for this model: compute once instead of once per target.
        c_T = model.get_temperature_correction(T)
        for target in targets:
            target2optvalues[target][isalt] = model.evaluate(target, c_T=c_T)
    else:
        for target in targets:
            target2optvalues[target][isalt] = target2default.get(target, 0.)

In [None]:
# One figure per target versus salinity, showing:
# - the posterior envelope (2.5-97.5 and 25-75 percentiles) and median;
# - the optimised parameter set (black line overlaid with a dashed coloured line);
# - for targets that correspond to an observed statistic, the observations with error bars.
target2title = {
    'N_R': 'lifetime reproductive output (#)',
    'a_m': 'lifespan (d)',
    obs2target['total_offspring_per_female']: 'total offspring per mature female (#)',
    obs2target['prereproductive_period']: 'prereproductive period (d)',
    obs2target['lifespan']: 'expected lifespan of mature individual (d)',
}
target2range = {
    obs2target['lifespan']: (0, 90),
    obs2target['prereproductive_period']: (0, 50),
}
color = 'C0'
for target, values in target2values.items():
    # values: one row per sample, one column per salinity.
    p025, p250, p500, p750, p975 = numpy.percentile(values, (2.5, 25, 50, 75, 97.5), axis=0)
    fig, ax = pyplot.subplots()
    for lower, upper in ((p025, p975), (p250, p750)):
        ax.fill_between(salts, lower, upper, alpha=0.3, color=color)
    ax.plot(salts, p500, color=color)
    ax.grid(True)

    optimum = target2optvalues[target]
    ax.plot(salts, optimum, '-k')
    ax.plot(salts, optimum, '--', color=color)

    ax.set_xlim(salts[0], salts[-1])
    ax.set_xlabel('salinity')
    ax.set_ylabel(target2title.get(target, target))
    ylim = target2range.get(target, (None, None))
    observed = [obsname for obsname, obstarget in obs2target.items() if obstarget == target]
    for obsname in observed:
        for info in AghVSta2008b_stats:
            ax.errorbar(info['salinity'], info[obsname], info['%s_sd' % obsname], fmt='ow', mec='k', ecolor='k')
    ax.set_ylim(*ylim)
        

