# Dirichlet model fitting to nutrient productivity
Based on http://dirichletreg.r-forge.r-project.org/


In [1]:
# Import python packages
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import theano as T
import ternary
import theano.tensor as tt
import seaborn as sns
import scipy as sp
import pdb
import os
import arviz as az
from matplotlib.patches import Ellipse, transforms
from itertools import combinations
#import ternary

# Helper functions
def indexall(L):
    poo = []
    for p in L:
        if not p in poo:
            poo.append(p)
    Ix = np.array([poo.index(p) for p in L])
    return poo,Ix

# Helper functions
def indexall_B(L,B):
    poo = []
    for p in L:
        if not p in poo:
            poo.append(p)
    Ix = np.array([poo.index(p) for p in L])
    a, b = poo.index(B), 0
    poo[b], poo[a] = poo[a], poo[b]
    
    Ix[Ix==b] = -1
    Ix[Ix==a] = 0
    Ix[Ix==-1] = a
    return poo,Ix

def subindexall(short,long):
    poo = []
    out = []
    for s,l in zip(short,long):
        if not l in poo:
            poo.append(l)
            out.append(s)
    return indexall(out)

match = lambda a, b: np.array([ b.index(x) if x in b else None for x in a ])
grep = lambda s, l: np.array([i for i in l if s in i])

# Function to standardize covariates
def stdize(x):
    return (x-np.mean(x))/(2*np.std(x))

# Coefficient of variation
cv =  lambda x: np.var(x) / np.mean(x)

# Posterior draws for each covariate, save as csv
def extract_con(alpha, varname):

    # covariate trace
    alphas = trace_dm[alpha].T

    # get posterior mean
    alpha_0 = alphas.sum(0).mean(1)
    Ex_alphas = alphas.mean(2)
    Ex = pd.DataFrame(Ex_alphas/alpha_0).T
    Ex.to_csv('fg/prod/prod_posterior_'+varname+'.csv', index=False)

    # get HPDI
    Ex_hpd = np.array([az.hdi(a) for a in alphas.transpose(0,2,1)])
    Ex_hpd_lo = pd.DataFrame([a[:,0]/alpha_0 for a in Ex_hpd]).T
    Ex_hpd_hi = pd.DataFrame([a[:,1]/alpha_0 for a in Ex_hpd]).T
    Ex_hpd_lo.to_csv('fg/prod/prod_posterior_'+varname+'_hpd_lo.csv', index=False)
    Ex_hpd_hi.to_csv('fg/prod/prod_posterior_'+varname+'_hpd_hi.csv', index=False)

Dataset are proportion of nutrient productivity by fish functional groups, for prod.mg

In [11]:
# Import data
nut = pd.read_csv("productivity_unscaled.csv")
nut.head()
# hnames

Unnamed: 0.1,Unnamed: 0,nutrient,nutrient_lab,country,site,year,hard_coral,macroalgae,turf_algae,bare_substrate,...,grav_nc,sediment,nutrient_load,pop_count,herbivore-detritivore,herbivore-macroalgae,invertivore-mobile,omnivore,piscivore,planktivore
0,1,calcium.mg,Calcium,Belize,CZFR1,2019,22.67,42.33,5.5,0.0,...,0.098334,0.0,0.0,1.309482,0.341334,0.100659,0.100964,0.0,0.144413,0.31263
1,2,calcium.mg,Calcium,Belize,CZFR2,2019,17.89,33.05,25.63,0.0,...,0.098334,0.0,0.0,1.309482,0.39319,0.428793,0.060765,0.0,0.075191,0.042062
2,3,calcium.mg,Calcium,Belize,CZFR3,2019,23.83,38.33,11.67,0.0,...,0.098334,0.0,0.0,1.309482,0.555626,0.059596,0.138386,0.0,0.121831,0.124561
3,4,calcium.mg,Calcium,Belize,CZPR1,2019,5.83,37.5,9.83,0.0,...,0.134825,0.0,2e-06,1.309482,0.299253,0.534477,0.051503,0.0,0.114767,0.0
4,5,calcium.mg,Calcium,Belize,CZPR4,2019,9.67,53.0,1.33,0.0,...,0.134825,0.0,2e-06,1.309482,0.129711,0.387577,0.080821,0.0,0.401891,0.0


In [12]:
# Import data
nut = pd.read_csv("productivity_unscaled.csv")
nut.head()
nut.management_rules = nut.country + nut.management_rules

# y = nut[["browser","cropper/grazer","invertivore-mobile","scraper-excavator","piscivore","planktivore", "mixed-diet feeder"]]
y = nut[["herbivore-detritivore","herbivore-macroalgae","invertivore-mobile","piscivore","planktivore", "omnivore"]]
# Grab fg names
hnames = list(y[:-1])
y = y.to_numpy()
y = y.round(2)+0.000001
# Make y's sum to 1
y = y/y.sum(axis=1,keepdims=1)
nfg = y.shape[1]

# export fitted df
y2=pd.DataFrame(y)
y2.to_csv('prod_focal.csv', index=False)

# # identify predictors
hc = stdize(nut.hard_coral).to_numpy()
ma = stdize(nut.macroalgae).to_numpy()
bs = stdize(nut.bare_substrate).to_numpy()
ta = stdize(nut.turf_algae).to_numpy()
rub = stdize(nut.rubble).to_numpy()
pop = stdize(nut.pop_count).to_numpy()
rt = pd.Categorical(nut.reef_type)
grav_nc = stdize(nut.grav_nc).to_numpy()
nl = stdize(nut.nutrient_load).to_numpy()
sed = stdize(nut.sediment).to_numpy()
dep = stdize(nut.depth).to_numpy()

## categorical levels
reef_type = list(np.sort(pd.unique(nut["reef_type"])))
rt = np.array([reef_type.index(x) for x in nut["reef_type"]])

reef_zone = list(np.sort(pd.unique(nut["reef_zone"])))
rz = np.array([reef_zone.index(x) for x in nut["reef_zone"]])

manage = list(np.sort(pd.unique(nut["management_rules"])))
mr = np.array([manage.index(x) for x in nut["management_rules"]])

# if manage nested in country, use subindexall
country,c = subindexall(nut["country"], nut["management_rules"])
# else
# country = list(np.sort(pd.unique(nut["country"])))
# c = np.array([country.index(x) for x in nut["country"]])

# site is almost n = 1 for all, so would not converge sensibly

Now build the model

In [4]:
country, c

(['Belize', 'Fiji', 'Madagascar', 'Solomon Islands'],
 array([0, 0, 1, 1, 2, 2, 2, 3, 3]))

In [5]:
coords={'reef_type':reef_type, 'country':country,'nfg':nfg}

with pm.Model(coords=coords) as BDM:
    intercept = pm.Normal('intercept', 0, 2, shape=nfg)
    
    # conts
    hard_coral = pm.Normal('hard_coral', 0, 0.5, shape=nfg)
    macroalgae = pm.Normal('macroalgae', 0, 0.5, shape=nfg)
    bare_sub = pm.Normal('bare_sub', 0, 0.5, shape=nfg)
    turf = pm.Normal('turf', 0, 0.5, shape=nfg)
    rubble = pm.Normal('rubble', 0, 0.5, shape=nfg)
    population = pm.Normal('population', 0, 0.5, shape=nfg)
    gravity = pm.Normal('gravity', 0, 0.5, shape=nfg)
    sediment = pm.Normal('sediment', 0, 0.5, shape=nfg)
    nut_load = pm.Normal('nut_load', 0, 0.5, shape=nfg)
    depth = pm.Normal('depth', 0, 0.5, shape=nfg)
    
    # cats
    reeftype_x = pm.Normal("reeftype_x", 0, 1, shape = (len(reef_type), nfg))
    reefzone_x = pm.Normal("reefzone_x", 0, 1, shape = (len(reef_zone), nfg))
#     manage_x = pm.Normal("manage_x", 0, 1, shape = (len(manage), nfg))
    
    # country nested in global intercept
    σ_c = pm.Exponential('Sigma_country', 1)
    β0_cnc = pm.Normal('β0_cnc', 0, 1, shape = (len(country), nfg))
    β0_c = pm.Normal('β0_c', intercept+β0_cnc*σ_c, shape = (len(country), nfg))
    
    # Site
#     σ_s = pm.Exponential('Sigma_site', 1)
#     β0_snc = pm.Normal('β0_snc', 0, 1, shape = (len(site), nfg))
#     β0_s = pm.Normal('β0_s', β0_c[c]+β0_snc*σ_c, shape = (len(site), nfg)) 

    # Management nested in country 
    σ_m = pm.Exponential('Sigma_manage', 1)
    β0_managenc = pm.Normal('β0_managenc', 0, 1, shape = (len(manage), nfg))
    β0_manage = pm.Normal('β0_manage', β0_c[c]+β0_managenc*σ_m, shape = (len(manage), nfg)) 

    α = pm.Deterministic('alpha', tt.exp(#intercept +
                                         β0_manage[mr, :None] +
#                                          β0_c[c, :None] +
                                         reeftype_x[rt,:None] + reefzone_x[rz, :None] +
#                                          manage_x[mr,:None] +
                                         hard_coral*hc[:,None]+macroalgae*ma[:,None]+
                                         bare_sub*bs[:,None]+turf*ta[:,None]+rubble*rub[:,None]+
                                         gravity*grav_nc[:,None] + population*pop[:,None]+
                                         sediment*sed[:,None] + nut_load*nl[:,None] +
                                         depth*dep[:,None]))
    Yi = pm.Dirichlet('Yi', α, observed=y)
    
    # Covariate predictions
    gravG = np.linspace(min(grav_nc),max(grav_nc), num=100)
    α2 = pm.Deterministic('alpha2', tt.exp(intercept.mean(0)+gravity*gravG[:,None]))
    
    coralG = np.linspace(min(hc),max(hc), num=100)
    α3 = pm.Deterministic('alpha3', tt.exp(intercept.mean(0)+hard_coral*coralG[:,None]))
    
    macroG = np.linspace(min(ma),max(ma), num=100)
    α4 = pm.Deterministic('alpha4', tt.exp(intercept.mean(0)+macroalgae*macroG[:,None]))
    
    bareG = np.linspace(min(bs),max(bs), num=100)
    α5 = pm.Deterministic('alpha5', tt.exp(intercept.mean(0)+bare_sub*bareG[:,None]))
    
    turfG = np.linspace(min(ta),max(ta), num=100)
    α6 = pm.Deterministic('alpha6', tt.exp(intercept.mean(0)+turf*turfG[:,None]))
    
    # manage is nested in country intercepts, so B0_manage covariate sample is the combined effect
    # B0_managenc covariate sample is the relative management effect
    manageG = np.linspace(0, 8, num = 9).astype(int)
    α7 = pm.Deterministic('alpha7', tt.exp(β0_manage[manageG,:None]))
    
    popG = np.linspace(min(pop),max(pop), num=100)
    α8 = pm.Deterministic('alpha8', tt.exp(intercept.mean(0)+population*popG[:,None]))
    
    sedG = np.linspace(min(sed),max(sed), num=100)
    α9 = pm.Deterministic('alpha9', tt.exp(intercept.mean(0)+sediment*sedG[:,None]))
    
    nlG = np.linspace(min(nl),max(nl), num=100)
    α10 = pm.Deterministic('alpha10', tt.exp(intercept.mean(0)+nut_load*nlG[:,None]))
    
    rubG = np.linspace(min(rub),max(rub), num=100)
    α11 = pm.Deterministic('alpha11', tt.exp(intercept.mean(0)+rubble*rubG[:,None]))
    
    # vectors of management type for each hard_coral gradient
    futureHC = np.repeat(coralG, 9)
    futureG = np.tile(manageG, 100)
    α12 = pm.Deterministic('alpha12', tt.exp(β0_manage[futureG,:None] + hard_coral*futureHC[:,None]))
    
    # nc manage and country effects
    manageGnc = np.linspace(0, 8, num = 9).astype(int)
    α13 = pm.Deterministic('alpha13', tt.exp(β0_managenc[manageG,:None]))
    
    countrync = np.linspace(0, 3, num = 4).astype(int)
    α14 = pm.Deterministic('alpha14', tt.exp(β0_cnc[c,:None]))



In [6]:
# manage, c, country, mr
country

['Belize', 'Fiji', 'Madagascar', 'Solomon Islands']

Note the `[:,None]` in the code is to broadcast the predictor measurements across all fish functional groups.

In [7]:
for RV in BDM.basic_RVs:
    print(RV.name, RV.logp(BDM.test_point))

intercept -9.672514282587708
hard_coral -1.3547481158683645
macroalgae -1.3547481158683645
bare_sub -1.3547481158683645
turf -1.3547481158683645
rubble -1.3547481158683645
population -1.3547481158683645
gravity -1.3547481158683645
sediment -1.3547481158683645
nut_load -1.3547481158683645
depth -1.3547481158683645
reeftype_x -27.56815599614019
reefzone_x -16.540893597684104
Sigma_country_log__ -1.0596601002984287
β0_cnc -22.054524796912148
β0_c -22.054524796912148
Sigma_manage_log__ -1.0596601002984287
β0_managenc -49.62268079305237
β0_manage -49.62268079305237
Yi_missing 0.0
Yi 60868.170017743054


In [8]:
with BDM:
    trace_dm = pm.sample()

  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'intercept': array([-0.97530238,  0.53393712, -0.90633195, -0.95718033, -0.03488871,
       -0.64781459]), 'hard_coral': array([ 0.78706868,  0.99552155,  0.95818798, -0.28658814,  0.88342019,
       -0.5819818 ]), 'macroalgae': array([-0.39649629,  0.0134038 ,  0.87049493, -0.68595199, -0.42438595,
        0.25431937]), 'bare_sub': array([ 0.4270202 , -0.00303172,  0.42935496,  0.78809352,  0.28548637,
        0.47551284]), 'turf': array([-0.89396864,  0.10164511,  0.55655767, -0.82087765,  0.76651729,
       -0.18971262]), 'rubble': array([-0.64938089,  0.39651056, -0.57114715, -0.38652568,  0.93577494,
        0.22970425]), 'population': array([0.97690705, 0.503683  , 0.66477365, 0.14223838, 0.25436439,
       0.26739764]), 'gravity': array([ 0.50962293,  0.1654385 , -0.01461738,  0.02052242,  0.65778682,
        0.86391285]), 'sediment': array([ 0.21599594, -0.1102446 , -0.14881663,  0.54724505,  0.92306615,
       -0.52197706]), 'nut_load': array([ 0.00738983, -0.67148202,  0.6673775 , -0.18766474,  0.5804526 ,
       -0.68939865]), 'depth': array([-0.59389603,  0.75861174,  0.96750689,  0.88628209,  0.47305258,
        0.03370344]), 'reeftype_x': array([[ 0.56488298, -0.4865105 , -0.66749513,  0.76107618,  0.30528202,
        -0.48240739],
       [ 0.62230694, -0.08842795,  0.07658347,  0.80094545,  0.50042824,
         0.54688607],
       [-0.61699776,  0.09874649, -0.7337129 ,  0.67354137, -0.55329366,
        -0.81262558],
       [ 0.96778888,  0.84281385, -0.90467781,  0.7219746 , -0.7521089 ,
        -0.65786811],
       [-0.86860934, -0.8065277 ,  0.77902835,  0.87408076, -0.4211348 ,
        -0.34991652]]), 'reefzone_x': array([[-0.23248806,  0.67582122, -0.15220647,  0.78308126, -0.78394979,
         0.88057424],
       [ 0.70965451, -0.03127182,  0.748333  ,  0.54493171,  0.6936497 ,
         0.06660313],
       [ 0.26513032,  0.40081166,  0.45991333,  0.43526768, -0.88264145,
         0.18455736]]), 'Sigma_country_log__': array(-1.34450756), 'β0_cnc': array([[-0.76950672, -0.06232219, -0.42021356, -0.57709725,  0.40562732,
         0.33780256],
       [-0.5021766 , -0.24213796, -0.36275682, -0.27575064, -0.22153816,
         0.17558156],
       [ 0.52020807, -0.03392551, -0.5084242 , -0.92391127,  0.77313131,
         0.54388559],
       [ 0.39477684, -0.11979494,  0.67193514, -0.31785168, -0.49148168,
         0.94109985]]), 'β0_c': array([[-0.16961466, -0.42482096, -0.00921238, -0.95788061,  0.18737795,
         0.87241897],
       [-0.50150597,  0.99520197, -0.74394633, -0.04033552, -0.72154176,
         0.88905743],
       [ 0.22468373, -0.47038904, -0.15565621, -0.7610944 , -0.89544441,
         0.94674896],
       [-0.13569253, -0.38875877, -0.67368443, -0.66412422, -0.02816509,
         0.38274433]]), 'Sigma_manage_log__': array(-1.07594023), 'β0_managenc': array([[-0.11871316,  0.32903416,  0.3462266 , -0.78299819,  0.3759836 ,
        -0.96737866],
       [ 0.26102716,  0.83773131,  0.43868699, -0.58892063, -0.76867176,
        -0.58192198],
       [ 0.3737681 , -0.37807928, -0.83096127,  0.00220233, -0.27057515,
        -0.46660563],
       [-0.44727287, -0.4654652 ,  0.17181193,  0.97771923, -0.78241612,
        -0.68010898],
       [-0.18780906, -0.09824308, -0.94162904,  0.02900902, -0.38598318,
        -0.48998714],
       [ 0.05198607, -0.69824975, -0.2244989 ,  0.33651667,  0.03231382,
        -0.61609953],
       [-0.79583397, -0.88506214, -0.37645853,  0.54672311,  0.35463887,
        -0.89025147],
       [-0.88196622,  0.72201488,  0.50929964, -0.74449922,  0.68164679,
        -0.61369938],
       [-0.07726745, -0.05494474, -0.45186817,  0.01176519,  0.27299274,
        -0.42361816]]), 'β0_manage': array([[ 0.69492957,  0.51312498, -0.03989004,  0.01207116,  0.49093305,
         0.76666915],
       [ 0.37625072, -0.61177985, -0.19374529,  0.76563169, -0.01409508,
        -0.36016922],
       [ 0.32433132, -0.61773456, -0.79779355,  0.18909128, -0.99799441,
        -0.74443808],
       [ 0.57174688, -0.2414321 , -0.99113469,  0.27056134,  0.26635132,
        -0.68872951],
       [ 0.51416782,  0.24302141,  0.68663831, -0.73672623,  0.98147117,
        -0.77804534],
       [ 0.12939172, -0.17750656, -0.47472773,  0.70762923,  0.62635342,
        -0.31910196],
       [-0.8662859 ,  0.19913624,  0.54586012,  0.91641494, -0.31912741,
         0.66806451],
       [-0.7190387 , -0.09614016,  0.99499422, -0.63107796,  0.79928687,
        -0.93047209],
       [ 0.02550764,  0.26303364, -0.24092932, -0.66868806, -0.09040428,
        -0.83958573]]), 'Yi_missing': array([-0.8773193 , -0.96031713,  0.01482709, ...,  0.87147426,
       -0.13208981, -0.02141747])}

Initial evaluation results:
intercept             -10.10
hard_coral             -8.81
macroalgae             -4.62
bare_sub               -3.95
turf                   -6.19
rubble                 -5.32
population             -4.97
gravity                -4.29
sediment               -4.36
nut_load               -4.84
depth                  -7.10
reeftype_x            -34.02
reefzone_x            -19.46
Sigma_country_log__    -1.61
β0_cnc                -25.11
β0_c                  -30.93
Sigma_manage_log__     -1.42
β0_managenc           -57.86
β0_manage             -71.82
Yi_missing              0.00
Yi                      -inf
Name: Log-probability of test_point, dtype: float64

In [None]:
pm.summary(trace_dm, var_names=['β0_manage'], filter_vars="regex")

In [None]:
pm.summary(trace_dm, var_names=['~^alpha'], filter_vars="regex")

In [None]:
# Export summary stats
tmp = pm.summary(trace_dm, var_names=['~^Sigma', '~^alpha'], hdi_prob=0.95, filter_vars="regex")
varnames = np.array(list(tmp.index), dtype=object)
varnames[match(grep('β0_c',list(varnames)),list(varnames))] = np.array(list(np.repeat(country, nfg))*2)
varnames[match(grep('β0_manage',list(varnames)),list(varnames))] = np.array(list(np.repeat(manage, nfg))*2)
varnames[match(grep('reeftype_x',list(varnames)),list(varnames))] = np.repeat(reef_type, nfg)
varnames[match(grep('reefzone_x',list(varnames)),list(varnames))] = np.repeat(reef_zone, nfg)
# varnames[match(grep('manage_x',list(varname§s)),list(varnames))] = np.repeat(manage, nfg)

tmp['varname'] = list(varnames)
tmp['fg']=int(len(tmp)/nfg)*hnames

tmp.to_csv('fg/prod/prod_posterior_summary.csv')

tmp = pm.summary(trace_dm, var_names=['~^Sigma', '~^alpha'], hdi_prob=0.5, filter_vars="regex")
varnames = np.array(list(tmp.index), dtype=object)
varnames[match(grep('β0_c',list(varnames)),list(varnames))] = np.array(list(np.repeat(country, nfg))*2)
varnames[match(grep('β0_manage',list(varnames)),list(varnames))] = np.array(list(np.repeat(manage, nfg))*2)
varnames[match(grep('reeftype_x',list(varnames)),list(varnames))] = np.repeat(reef_type, nfg)
varnames[match(grep('reefzone_x',list(varnames)),list(varnames))] = np.repeat(reef_zone, nfg)
# varnames[match(grep('manage_x',list(varname§s)),list(varnames))] = np.repeat(manage, nfg)

tmp['varname'] = list(varnames)
tmp['fg']=int(len(tmp)/nfg)*hnames

tmp.to_csv('fg/prod/prod_posterior_summary_50.csv')

In [None]:
# Grab expected alphas
alpha = trace_dm['alpha'].T
alpha_0 = alpha.sum(0).mean(1)
Ex_alphas = alpha.mean(2)
Ex = Ex_alphas/alpha_0

## Predicted vs. observed
[plt.scatter(ei,yi,label=l) for yi,l,ei in zip(y.T,hnames,Ex)]
plt.ylabel('Observed'),plt.xlabel('Predicted')
plt.legend(loc=(1.04,0));

In [None]:
# extract posterior dists for covariates
out = pm.trace_to_dataframe(trace_dm)
varnames = np.array(list(out), dtype=object)
varnames[match(grep('β0_c',list(varnames)),list(varnames))] = np.array(list(np.repeat(country, nfg))*2)
varnames[match(grep('reeftype_x',list(varnames)),list(varnames))] = np.repeat(reef_type, nfg)
varnames[match(grep('reefzone_x',list(varnames)),list(varnames))] = np.repeat(reef_zone, nfg)
# varnames[match(grep('manage_x',list(varnames)),list(varnames))] = np.repeat(manage, nfg)
varnames[match(grep('β0_manage',list(varnames)),list(varnames))] = np.array(list(np.repeat(manage, nfg))*2)

out.columns=varnames
out.to_csv('fg/prod/prod_posterior_trace.csv', index=False)

In [None]:
## posterior predictive distribution
with BDM:
    ppc = pm.sample_posterior_predictive(
        trace_dm, random_seed=43
    )

with BDM:
    az.plot_ppc(az.from_pymc3(posterior_predictive=ppc))

In [None]:
## extract posterior predicted 
extract_con(alpha='alpha2', varname = 'gravity')
extract_con(alpha='alpha3', varname = 'hard_coral')
extract_con(alpha='alpha4', varname = 'macroalgae')
extract_con(alpha='alpha5', varname = 'bare_substrate')
extract_con(alpha='alpha6', varname = 'turf')
extract_con(alpha='alpha7', varname = 'manage')
extract_con(alpha='alpha8', varname = 'pop')
extract_con(alpha='alpha9', varname = 'sediment')
extract_con(alpha='alpha10', varname = 'nut_load')
extract_con(alpha='alpha11', varname = 'rubble')
extract_con(alpha='alpha12', varname = 'future_hc')
extract_con(alpha='alpha13', varname = 'manage_nc')
extract_con(alpha='alpha14', varname = 'country')