## PyMC3 Implementation of Chempy
First, create a basic implementation of Chempy in PyMC3 for a single end time and a single star.

In [1]:
%pylab inline
import scipy,theano,os,inspect,platform
import pymc3 as pm
import time as ttime
import numpy.lib.recfunctions as rcfuncs
import numpy as np
import theano.tensor as tt

## TESTING PARAMETERS
full_checking = True # if True, check assertions in code (will be slower)
sample_yields = False # if False, yields are set to mean parameters

# Work from the ChempyMulti repository root: the Chempy package imported below and the
# relative data paths used later ('TNG_Linear_Yield_Fits.npz', 'Chempy/input/stars/...')
# are resolved against the current directory. Set the CHEMPY_DIR environment variable to
# use a different checkout; the default is the original author's location.
chempy_dir = os.environ.get('CHEMPY_DIR', '/home/oliverphilcox/ChempyMulti/')
os.chdir(chempy_dir)

# Import useful Chempy functions
from Chempy.solar_abundance import solar_abundances as SolarAbundances
from Chempy.infall import PRIMORDIAL_INFALL as PrimordialInfall

# Define localpath: the 'Chempy' sub-directory of the directory containing the file that
# runs this code, always ending with a path separator.
# NOTE(review): inside a notebook cell inspect.getfile presumably returns an IPython
# pseudo-filename, so this effectively resolves against the current directory - confirm.
localpath = os.path.join(
    os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))),
    'Chempy') + os.sep

if full_checking:
    # Define the (regularized lower) incomplete gamma function as a Theano operator so it
    # can be applied to symbolic scalars inside the model.
    from theano.compile.ops import as_op
    from scipy.special import gammainc

    @as_op(itypes=[tt.dscalar,tt.dscalar],
           otypes=[tt.dscalar])
    def gammainc_th(shape_par,x):
        """scipy.special.gammainc(shape_par, x) for two float64 scalars.

        as_op stores this return value directly as the op's output, and values held by
        a Theano tensor variable are expected to be ndarrays, so scipy's NumPy scalar
        result is wrapped into a 0-d float64 array.
        """
        return np.asarray(gammainc(shape_par,x),dtype=np.float64)


Populating the interactive namespace from numpy and matplotlib


In [2]:

### MODELLING PARAMETERS

class Parameters(object):
    """This holds various parameters for the PyMC3 Chempy implementation.

    Settings are class attributes of the nested classes and are read directly,
    e.g. ``par.imf.max_mass``. The ``yields`` class body reads
    'TNG_Linear_Yield_Fits.npz' when it executes, and creating an instance
    loads the observational data into ``self.observations``.
    """
    
    class general(object):
        # General parameters
        elements_to_trace = ['C', 'Fe', 'H', 'He', 'Mg', 'N', 'Ne', 'O', 'Si']
        gas_reservoir_mass_factor = 1.  # how much more mass the corona has compared to the integrated SFR
        gas_power = 1. # The Schmidt_exponent (default = 1, i.e. linear)
        sfr_factor_for_cosmic_accretion = 0.000001 # how much more gas should be infalling in the corona compared to the SFR

        # Derived parameters
        n_el = len(elements_to_trace) # number of elements
        # Indices of the metals (every element except H and He), of H and of Fe.
        # list.index raises a ValueError right here if H or Fe is not traced,
        # instead of leaving H_index / Fe_index undefined for later code.
        metal_list = [e for e, el in enumerate(elements_to_trace) if el not in ('H', 'He')]
        H_index = elements_to_trace.index('H')
        Fe_index = elements_to_trace.index('Fe')
        metal_list = tt._shared(np.array(metal_list))

    class imf(object):
        # IMF related parameters
        min_mass = 0.1
        max_mass = 100.
        n_mass_steps = 20000 # number of mass steps used

        # Chabrier parameters
        chabrier_A = 0.852464
        chabrier_B = 0.237912
        chabrier_sigma = 0.69
        chabrier_m_c = 0.079 

    class sfr(object):
        # SFR related parameters
        a_parameter = 2 # gamma function shape parameter

    class times(object):
        # Chempy start and end time in Gyr
        start = 0
        end = 13.797617 # Planck 2015 age
        time_steps = 28 # number of time steps in simulation

    class yields(object):
        # Parameters controlling priors on nucleosynthetic yield parameters
        # Format is (intercept, mass coefficient, log_10(Z) coefficient) for each element and data-set;
        # the last row of each data-set holds the remnant mass fraction parameters.
        
        # Use linear fits to TNG yield data here for testing.
        # Copy the arrays out of the .npz archive so its file handle is closed immediately.
        with np.load('TNG_Linear_Yield_Fits.npz') as _npz:
            data = dict(_npz)
        del _npz
        agb_yield_mean = data['agb'][:-1]
        sn2_yield_mean = data['sn2'][:-1]
        sn1a_yield_mean = data['sn1a'][:-1].reshape(-1,1) # for 1-parameter sn1a model
        
        # Prior standard deviations (in mass fraction): 1% of the magnitude of the mean
        agb_yield_std = np.abs(data['agb'][:-1])*0.01
        sn2_yield_std = np.abs(data['sn2'][:-1])*0.01
        # Reshaped to (n_el, 1) to match sn1a_yield_mean; otherwise mu (n_el, 1) and sd (n_el,)
        # broadcast to (n_el, n_el) in the sampled 'sn1a-yields' prior.
        # NOTE(review): the 1e-3 offset presumably keeps sd non-zero for zero yields - confirm
        sn1a_yield_std = (np.abs(data['sn1a'][:-1]+1.0e-3)*0.01).reshape(-1,1)

        # Also define parameters for remnant mass fraction in same format
        agb_remnant_mean = data['agb'][-1]
        sn2_remnant_mean = data['sn2'][-1]
        sn1a_remnant_mean = data['sn1a'][-1]
        
        # Prior standard deviations: 1% of the magnitude of the mean
        agb_remnant_std = np.abs(data['agb'][-1])*0.01
        sn2_remnant_std = np.abs(data['sn2'][-1])*0.01
        sn1a_remnant_std = np.abs(data['sn1a'][-1])*0.01

    class sn2(object):
        # Define SN2 mass ranges
        min_mass = 8.
        max_mass = 100.

    class agb(object):
        # Define AGB mass ranges
        min_mass = 1.
        max_mass = 8.

    class sn1a(object):
        # Define Maoz (2012) DTD parameters
        tau_8 = 0.04 # this is the SN1a time delay used in TNG
        s_exponent = 1.12
        
    class Observations(object):
        # This is accessible by self.observations initialized later.
        # Slightly clunky creation but not a time-limiting step
        def __init__(self,elements):
            """Load observed abundances and errors for every element except H.

            elements = list of traced element symbols (must contain 'H').
            Sets self.star_age, self.abundances and self.errors (shared tensors).
            """
            
            ## Load observational data (here proto-solar)
            dat=np.load('Chempy/input/stars/Proto-sun_all.npy') 

            observational_abundances = []
            observational_errors = []
            for el in elements:
                if el=='H':
                    continue
                # Each element field holds (value, error)
                observational_abundances.append(dat[el][0])
                observational_errors.append(dat[el][1])
            assert len(observational_abundances)==len(elements)-1 # ignoring H

            self.star_age = dat['age'][0]

            self.abundances = tt._shared(np.asarray(observational_abundances))
            self.errors = tt._shared(np.asarray(observational_errors))

    def __init__(self):
        self.observations = self.Observations(self.general.elements_to_trace)

In [3]:

### PRELOADING

# Load parameters
par = Parameters()

class IMF(object):
    """Chabrier (2003) IMF tabulated on a fixed grid of stellar masses.

    The low-mass (m < 1 Msun) log-normal part does not depend on any inferred
    parameter, so it is computed once in NumPy here. The high-mass power law
    depends on the inferred slope and is added by ``initialize_slope`` inside
    the PyMC3 model.
    """

    def __init__(self):
        masses = np.linspace(par.imf.min_mass,par.imf.max_mass,par.imf.n_mass_steps)
        # Complementary split of the grid: a grid point at exactly 1 Msun goes to the
        # high-mass side instead of being dropped, so len(low) + len(high) == len(masses)
        # and the concatenation in initialize_slope always lines up with self.masses.
        masses_low = masses[masses<1.]
        masses_high = masses[masses>=1.]

        # Define lower (log-normal) end of IMF (in Numpy first)
        dn_low = par.imf.chabrier_A/masses_low*np.exp(-(np.log10(masses_low/par.imf.chabrier_m_c)**2.)/(2.*par.imf.chabrier_sigma**2.))
        self.dn_low = tt._shared(dn_low) # convert to tensors
        self.masses = tt._shared(masses)
        self.masses_high = tt._shared(masses_high)
        self.masses_low = tt._shared(masses_low)

    def initialize_slope(self,alpha_imf):
        """Complete the IMF for high-mass slope alpha_imf (needs to be done inside the PyMC3 model).

        Sets self.dm (mass formed in each grid bin, normalized to sum to 1) and
        self.dn (= self.dm / mass, the corresponding number weights).
        """
        self.dn_high = par.imf.chabrier_B*self.masses_high**alpha_imf
        # dN/dm multiplied by m gives the (unnormalized) mass formed per bin
        self.dn = tt.concatenate([self.dn_low,self.dn_high])*self.masses
        self.dm = self.dn/self.dn.sum() # mass of stars produced per unit mass
        self.dn = self.dm/self.masses

    def _mass_cut(self,mlow,mup):
        # Boolean mask over self.masses selecting mlow <= m < mup; a NaN upper limit means "up to the maximum IMF mass"
        mup = tt.switch(tt.isnan(mup),par.imf.max_mass,mup)
        return tt.and_(self.masses>=mlow,self.masses<mup)

    def imf_mass_fraction(self,mlow,mup):
        """Fraction of the formed mass (sum of self.dm) in stars with mlow <= m < mup."""
        return self.dm[self._mass_cut(mlow,mup)].sum()

    def imf_number_fraction(self,mlow,mup):
        """Sum of the number weights self.dn for stars with mlow <= m < mup."""
        return self.dn[self._mass_cut(mlow,mup)].sum()

class SFR(object):
    """Gamma-function star formation rate, SFR(t) ~ t**(a-1) * exp(-t/scale), with shape
    parameter a = par.sfr.a_parameter (2 by default). We assume the SFR starts at t=0 here.

    The time grid is built when the class body runs; the scale (an inferred parameter)
    is applied by initialize_scale inside the PyMC3 model.
    """
    
    # First load in end-point of simulation (i.e. stellar birth-time)
    if full_checking:
        assert (par.observations.star_age <= 13.0), "Age of the star must be at most 13 Gyr"
    par.times.new_end = par.times.end - par.observations.star_age # time to stop the Chempy run

    # Define evolution time steps (first in Numpy)
    t = np.linspace(par.times.start,par.times.new_end,par.times.time_steps)
    dt = t[1]-t[0]
    
    # convenience functions
    tmp_t1 = np.power(t,par.sfr.a_parameter-1)
    tmp_t2 = np.exp(-1.*t) # raised to the power 1/scale in initialize_scale, giving exp(-t/scale)
    t = tt._shared(t) # convert to tensors
    tmp_t1 = tt._shared(tmp_t1)
    tmp_t2 = tt._shared(tmp_t2)
        
    def initialize_scale(self,sfr_scale):
        """Set self.sfr: the SFR on the time grid for the given scale, normalized to sum to 1.

        This must be run in the PyMC3 model. With full_checking, the modelling
        assumptions are enforced by Theano Assert ops attached to self.sfr, so they
        are evaluated whenever anything depending on the SFR is computed.
        """
        # Normalized gamma function SFR up to the correct end time
        sfr = self.tmp_t1*tt.power(self.tmp_t2,1./sfr_scale)
        sfr = sfr/(sfr.sum())

        if full_checking:
            # An Assert op only checks its condition when it is applied to a variable that
            # is used downstream, and Theano's == is not elementwise on tensors, so build
            # the conditions with tt.eq / tt.gt and pass the SFR itself through the Asserts.
            starts_at_zero = tt.eq(sfr[0],0.)

            # Mean SFR over the age of the universe, from the fraction of the full
            # gamma-function mass formed by par.times.end
            total_mass_univ = gammainc_th(theano.shared(1.*par.sfr.a_parameter),par.times.end/sfr_scale)
            mean_sfr = total_mass_univ/par.times.end
            # NOTE(review): sfr is a per-time-step fraction of this run's mass while mean_sfr
            # is a per-Gyr fraction of the full gamma function - the comparison mixes
            # normalizations; confirm the intended threshold.
            enough_sfr_at_end = tt.gt(sfr[-1],0.05*mean_sfr)

            sfr = tt.opt.Assert("SFR must be zero at t=0")(sfr,starts_at_zero)
            sfr = tt.opt.Assert("SFR at the end of the run is below 5% of the mean SFR")(sfr,enough_sfr_at_end)

        self.sfr = sfr

def inverse_lifetime_function(lifetime,Z):
    """Return the stellar mass (Msun) whose lifetime is `lifetime` (Gyr) at metallicity fraction Z.

    Inverts the Portinari+ 1998 lifetime fit used in IllustrisTNG. The fit is a quadratic
    in log10(M),  log10(t / yr) = a + b*log10(M) + c*log10(M)**2,  whose coefficients
    a, b, c are polynomials in log10(Z), so the inversion is analytic: the root
    (-b - sqrt(disc)) / (2c) is taken. Where the discriminant is not positive (lifetime
    too small for the fit) the maximum IMF mass is returned instead.
    """
    # Floor Z so that log10 stays finite for metal-free gas
    log_z = tt.log10(tt.max([Z,1e-5]))
    log_z_sq = log_z*log_z
    fit = np.array([10.0153615 , -3.91089893,  0.99947209, -0.03601771, -0.31139679,0.09109059, -0.03218365, -0.01323112])

    # Quadratic coefficients in log10(M)
    a = fit[0]+log_z*fit[3]+log_z_sq*fit[6]
    b = fit[1]+log_z*fit[4]+log_z_sq*fit[7]
    c = fit[2]+log_z*fit[5]

    # Lifetime converted from Gyr to log10(years)
    log_age_yr = 9+tt.log10(lifetime)
    disc = b**2.-4.*(a-log_age_yr)*c

    log_mass = (-b-tt.sqrt(disc))/(2.*c)
    mass = tt.power(10.,log_mass)
    return tt.switch(disc>0,mass,par.imf.max_mass)

def sn1a_time_delay(time_array,dt,sn1a_normalization,tau_8=par.sn1a.tau_8,s_exponent=par.sn1a.s_exponent):
    """Number of SN1a per time step from the Maoz (2012) power-law delay time distribution.

    time_array: times since the population formed, evenly separated by dt.
    sn1a_normalization: total number of SN1a over the age of the universe.
    tau_8: minimum delay time (the DTD is zero before it).
    s_exponent: power-law exponent of the DTD.

    Unlike Chempy, the DTD is normalized with its analytic integral from tau_8 to
    par.times.end (rather than a sum over steps), so the result is independent of the time-step.
    """
    # Power law (s-1)/tau_8 * (t/tau_8)**(-s), switched off for t <= tau_8
    power_law = tt.power(time_array/tau_8,-1.*s_exponent)*(s_exponent-1.)/tau_8
    dtd = tt.switch(time_array>tau_8,power_law,0)

    # Integral of the power law over [tau_8, par.times.end], converted to per-step units
    dtd_integral = 1.-tt.power((par.times.end/tau_8),1.-s_exponent)
    norm = dtd_integral/dt

    return dtd/norm*sn1a_normalization
    

def mass_fractions_to_abundances(mass_fractions):
    """Convert a vector of element masses (one per traced element) to abundances
    normed to Asplund 2009 solar values.

    Returns [X/Fe] for every traced element, except that the Fe entry holds [Fe/H].
    When full_checking is on, a Theano Assert makes evaluation fail if any
    H-normalized number fraction is negative.
    """
    # Divide by atomic mass to get number fractions, then normalize by H
    number_fractions = mass_fractions/solar_abundances.cut_masses
    number_fractions = number_fractions/number_fractions[par.general.H_index]

    if full_checking:
        # The Assert only checks when its output is used, and its condition must be a
        # scalar, so reduce the elementwise test with tt.all and thread the vector through.
        number_fractions = tt.opt.Assert("Negative element fraction when computing abundances")(
            number_fractions,tt.all(tt.ge(number_fractions,0.)))

    # log10(N_X/N_H)+12 minus the solar value gives [X/H]
    X_H_abundances = tt.log10(number_fractions)+12.-solar_abundances.cut_abundances

    # Convert to [X/Fe] = [X/H] - [Fe/H]
    Fe_H = X_H_abundances[par.general.Fe_index]
    X_Fe_abundances = X_H_abundances - Fe_H
    # Also include [Fe/H] abundance in place of the (trivially zero) [Fe/Fe]
    X_Fe_abundances = tt.set_subtensor(X_Fe_abundances[par.general.Fe_index],Fe_H)

    return X_Fe_abundances

In [4]:

class AbundanceMatrix(object):
    """Preload important quantities for holding the chemical evolution properties of Chempy.
    
    We do not use the 'cube' and 'reservoir' rec-arrays used in Chempy to hold evolutionary properties. 
    The key properties can just be recovered from the output of the advance_simulation() wrapper below
    
    On initialization we create useful quantities including the initial states of the main evolutionary variables.
    These are analogous to self.cube['quantity'][0] in Chempy and are labelled as self.quantity_init here
    """       
    def __init__(self,SFR,PrimordialInfall):
        """SFR = SFR instance (provides the time grid t and step dt).
        PrimordialInfall = infall instance (provides .symbols and .fractions).
        """
        # Copy useful variables
        print("Do we need all cube variables?")
        self.t = SFR.t
        self.dt = SFR.dt
        self.elements = par.general.elements_to_trace
        self.infall_symbols = PrimordialInfall.symbols
        self.infall_fractions = tt._shared(PrimordialInfall.fractions)
        self.sfr_factor_for_cosmic_accretion = par.general.sfr_factor_for_cosmic_accretion
        
        # Initialize table to hold amounts of feedback created at timestep i from timestep j for each element
        self.all_feedback = tt.zeros((par.times.time_steps,par.times.time_steps,par.general.n_el))

        # Add initial gas composition
        # (since we have no initial gas and SFR=0 at t=0 this is independent of free parameters)
        starting_gas = par.general.gas_reservoir_mass_factor
        # np.float64 explicitly, rather than the bare float64 injected by %pylab
        self.corona_gas_init = np.float64(starting_gas)
        self.corona_mass_fractions_init = starting_gas*self.infall_fractions
        
        # Initialize ISM gas to negligible mass with correct mass ratios
        self.ism_gas_init = np.float64(starting_gas*1e-20)
        self.ism_mass_fractions_init = self.corona_mass_fractions_init*1e-20
        self.ism_Z_init = self.ism_mass_fractions_init[par.general.metal_list].sum()/self.ism_gas_init
        self.element_fractions_init = self.ism_mass_fractions_init/self.ism_gas_init

In [5]:
class SSP(object):
    """Class to hold an SSP (single stellar population) object giving its enrichment over time.

    Feedback is accumulated in self.table.yields, a (len(times), n_el) tensor whose row k
    holds the feedback of each element k time steps after the SSP forms. These are gross
    tables (net yields plus returned unprocessed gas), unlike Chempy. The feedback methods
    call the module-level yield/remnant functions (sn2_yield, agb_remnant, ...) that are
    defined inside the PyMC3 model cell.
    """ 
    def __init__(self,z,times,fractions_in_gas,imf):
        """z = metallicity in mass fraction (i.e. Z).
        times = times over which to compute the SSP enrichment (assumed evenly spaced and
                measured from the SSP's birth; times[1]-times[0] is used as the step)
        fractions_in_gas = initial SSP mass fractions (for gross feedback)
        imf = IMF object providing the mass grid (masses) and mass weights (dm)
        """
        self.fractions_in_gas = fractions_in_gas
        self.z = z
        self.t = times
        self.dt = times[1]-times[0]
        self.imf = imf
        
        self.logZ = tt.log10(tt.max([self.z,1e-5])) # to avoid zero errors
        self.time_steps_trunc = tt.shape(self.t)

        # Mass of the stars whose lifetime equals each entry of self.t
        self.inverse_imf = inverse_lifetime_function(self.t,self.z)

        # Create class to hold enrichment table
        class EnrichmentTable(object):
            # SSP enrichment table
            def __init__(self,n_steps):
                # n_steps is the shape of the time array; its first entry is the table length
                print("Can we define this as one-step shorter?")
                self.yields = tt.zeros((n_steps[0],par.general.n_el))
                
        self.table=EnrichmentTable(self.time_steps_trunc)

    def sn2_feedback(self):
        """Add the SN2 feedback of this SSP to rows 1 onwards of self.table.yields."""
        def feedback_per_time_step(inv_imf_low,inv_imf_high):
            # Compute the SN2 feedback for one time-step; this will be looped in a theano scan.
            # inv_imf_low / inv_imf_high are the dying masses at the start / end of the step;
            # inv_imf_low is used as the upper mass bound (presumably as stars dying later are lighter).

            # Need to sum over masses in correct region allowed by IMF
            lower_cut = tt.max([inv_imf_high,par.sn2.min_mass])
            upper_cut = tt.min([inv_imf_low,par.sn2.max_mass])

            # Here looking at all masses in correct range
            cut = tt.and_(self.imf.masses<upper_cut,self.imf.masses>=lower_cut)

            print("This assumes a linear weight function in mass")
            # Total IMF mass weight in range, and the weight-averaged mass at which yields are evaluated
            sum_weights = self.imf.dm[cut].sum()
            av_mass = (self.imf.masses[cut]*self.imf.dm[cut]).sum()/sum_weights
            sn2_yields = sn2_yield(av_mass,self.logZ)*sum_weights

            sn2_remnants = sn2_remnant(av_mass,self.logZ)*sum_weights
            
            # Now add unprocessed mass to gross yields
            sn2_yields += (sum_weights-sn2_remnants)*self.fractions_in_gas

            print("NB: all tables are gross tables here unlike Chempy")

            # Enforce that cut.sum()>=1 to avoid errors when there are no mass steps in correct range
            # (av_mass would then be 0/0)
            return tt.switch(cut.sum()<1,tt.zeros_like(sn2_yields),sn2_yields)
            
        # Consecutive pairs of dying masses give one entry per time step after the first
        scan_results,_ = theano.scan(feedback_per_time_step,sequences=[self.inverse_imf[:-1],self.inverse_imf[1:]])
        
        self.table.yields = tt.set_subtensor(self.table.yields[1:],self.table.yields[1:]+scan_results)
        
    def agb_feedback(self):
        """Add the AGB feedback of this SSP to rows 1 onwards of self.table.yields."""
        
        def feedback_per_time_step(inv_imf_low,inv_imf_high):
            # Compute the AGB feedback for one time-step; this will be looped in a theano scan.
            # Same mass bounds convention as in sn2_feedback.

            # Need to sum over masses in correct region allowed by IMF
            lower_cut = tt.max([inv_imf_high,par.agb.min_mass])
            upper_cut = tt.min([inv_imf_low,par.agb.max_mass])

            # Here looking at all masses in correct range
            cut = tt.and_(self.imf.masses<upper_cut,self.imf.masses>=lower_cut)

            sum_weights = self.imf.dm[cut].sum()
            
            av_mass = (self.imf.masses[cut]*self.imf.dm[cut]).sum()/sum_weights
            agb_yields = agb_yield(av_mass,self.logZ)*sum_weights

            agb_remnants = agb_remnant(av_mass,self.logZ)*sum_weights
            
            # Now add unprocessed mass to gross yields
            agb_yields += (sum_weights-agb_remnants)*self.fractions_in_gas

            print("NB: agb tables are treated weirdly in Chempy - possibly missing first/last step, so may be different")

            # Zero feedback when no grid masses fall in the range (av_mass would be 0/0)
            return tt.switch(cut.sum()<1,tt.zeros_like(agb_yields),agb_yields)
            
        scan_results,_ = theano.scan(feedback_per_time_step,sequences=[self.inverse_imf[:-1],self.inverse_imf[1:]])

        # Now add to global yield tables
        self.table.yields = tt.set_subtensor(self.table.yields[1:],self.table.yields[1:]+scan_results)
        
    def sn1a_feedback(self,sn1a_normalization):
        """ Compute the SNIa feedback table from the SSP initialized above.
        
            sn1a_normalization is the DTD normalization (number of SN1a per 1Msun per 13.8 Gyr)
        """
        
        # Compute the number of supernovae in each time-step (from the DTD) 
        sn1a_feedback_number = sn1a_time_delay(self.t,self.dt,sn1a_normalization)
        
        # Set feedback mass parameters
        print("Should this be fit or set from data? (well constrained)")
        mean_mass_of_feedback = -sn1a_remnant() # This mass is turned into the explosion

        mean_mass = 2.156 #mass_of_stars_in_mass_range_for_remnant/number_of_stars_in_mass_range_for_remnant
        mean_remnant_mass = mean_mass*0.3
        mean_accretion_mass = mean_mass_of_feedback - mean_remnant_mass # from Hydrogen feedback
        print("Only SN1a mean_remnant_mass is unset here - set as free parameter?")
        
        # Add in H-accretion onto White Dwarfs
        self.table.yields = tt.set_subtensor(self.table.yields[:,par.general.H_index],
                                             self.table.yields[:,par.general.H_index]-sn1a_feedback_number * mean_accretion_mass)
        
        # Add element yields to each time-step: row k is the per-SN yield vector scaled by the
        # number of SN1a in step k, i.e. an outer product (no scan needed)
        tmp_yields = sn1a_yield()*mean_mass_of_feedback
        self.table.yields += sn1a_feedback_number.dimshuffle(0,'x')*tmp_yields
        

In [6]:
        
## MAIN CODE

# Pre-load modules
imf = IMF()
sfr = SFR()

# Load solar abundances from Chempy
solar_abundances = SolarAbundances()
solar_abundances.Asplund09()

def _solar_table_row(symbol):
    # np.where index tuple locating `symbol` in the solar abundance table
    return np.where(solar_abundances.table['Symbol']==symbol)

# For each traced element, locate its table row once and read off its atomic mass and
# photospheric abundance (first match in each case)
element_masses = [] # to hold mass of each element
solar_abundances_cut = [] # to hold photospheric abundances of each element
for symbol in par.general.elements_to_trace:
    row = _solar_table_row(symbol)
    element_masses.append(solar_abundances.table['Mass'][row][0])
    solar_abundances_cut.append(solar_abundances.table['photospheric'][row][0])
solar_abundances.cut_masses = tt._shared(np.asarray(element_masses))
solar_abundances.cut_abundances = tt._shared(np.asarray(solar_abundances_cut))

# Primordial infall composition for the traced elements
primordial_infall = PrimordialInfall(par.general.elements_to_trace,solar_abundances.table)
primordial_infall.primordial()

In [9]:
full_model = pm.Model()

with full_model:

    ## PARAMETERS TO INFER
    class InferenceParameters(object):
        # Main parameters
        alpha_imf = pm.Normal('imf-slope',mu=-2.3,sd=0.3,testval=-2.3) # high-mass IMF slope
        log10_N_1a = pm.Normal('log10-N-1a',mu=-2.89,sd=0.3,testval=-2.89) # SN-1a normalization

        log10_sfe = pm.Normal('log10-sfe',mu=-0.3,sd=0.1,testval=-0.3) # the SFE for a linear Kennicutt-Schmidt law
        log10_sfr_scale = pm.Normal('log10-sfr-scale',mu=0.55,sd=0.1,testval=0.55) 
        x_out = pm.Normal('x-out',mu=0.5,sd=0.1,testval=0.5) # fraction of the feedback that goes into the corona

        # Derived parameters
        N_1a = tt.pow(10.,log10_N_1a)
        sfr_scale = tt.pow(10.,log10_sfr_scale)
        sfe = tt.pow(10.,log10_sfe)

        ## Yield parameters
        print("How do we fix yields to be less than unity? All elements should sum to 1...")
        # Net yield parameters
        
        if sample_yields:
            sn2_par = pm.Normal('sn2-yields',mu=par.yields.sn2_yield_mean,sd=par.yields.sn2_yield_std,shape=(par.general.n_el,3),
                               testval = par.yields.sn2_yield_mean)
            agb_par = pm.Normal('agb-yields',mu=par.yields.agb_yield_mean,sd=par.yields.agb_yield_std,shape=(par.general.n_el,3),
                               testval=par.yields.agb_yield_mean)
            sn1a_par = pm.Normal('sn1a-yields',mu=par.yields.sn1a_yield_mean,sd=par.yields.sn1a_yield_std,shape=(par.general.n_el,1),
                                testval=par.yields.sn1a_yield_mean)

            # Remnant mass fraction parameters
            sn2_remnant = pm.Normal('sn2-remnant',mu=par.yields.sn2_remnant_mean,sd=par.yields.sn2_remnant_std,shape=3,
                                   testval = par.yields.sn2_remnant_mean)
            agb_remnant = pm.Normal('agb-remnant',mu=par.yields.agb_remnant_mean,sd=par.yields.agb_remnant_std,shape=3,
                                   testval = par.yields.agb_remnant_mean)
            sn1a_remnant = pm.Normal('sn1a-remnant',mu=par.yields.sn1a_remnant_mean,sd=par.yields.sn1a_remnant_std,shape=1,
                                    testval = par.yields.sn1a_remnant_mean)
        else:
            sn2_par = par.yields.sn2_yield_mean
            agb_par = par.yields.agb_yield_mean
            sn1a_par = par.yields.sn1a_yield_mean
            sn2_remnant = par.yields.sn2_remnant_mean
            agb_remnant = par.yields.agb_remnant_mean
            sn1a_remnant = np.asarray([par.yields.sn1a_remnant_mean])
            
    inference = InferenceParameters()
        
    cube=AbundanceMatrix(sfr,primordial_infall)

    ## Load IMF (Chabrier 2003)
    imf.initialize_slope(inference.alpha_imf)    
    
    ## SFR (Gamma function)
    sfr.initialize_scale(inference.sfr_scale)
    
    ## AbundanceMatrix class
    cube.sfr = sfr.sfr
    cube.star_formation_efficiency = inference.sfe*cube.dt

    ## YIELDS
    # Define functions for mass fraction and remnant mass
    # (linear in mass and log10(Z); coefficients are (intercept, mass, logZ))
    print("NB: We use logZ=-5 for Z=0")
    print("NB: No dependence on mass or logZ for SN1a yields")
    def sn1a_yield():
        return inference.sn1a_par[:,0]
    def sn2_yield(mass,logZ):
        return inference.sn2_par[:,0]+inference.sn2_par[:,1]*mass+inference.sn2_par[:,2]*logZ
    def agb_yield(mass,logZ):
        return inference.agb_par[:,0]+inference.agb_par[:,1]*mass+inference.agb_par[:,2]*logZ
    def sn1a_remnant():
        return inference.sn1a_remnant[0]
    def sn2_remnant(mass,logZ):
        return inference.sn2_remnant[0]+inference.sn2_remnant[1]*mass+inference.sn2_remnant[2]*logZ
    def agb_remnant(mass,logZ):
        return inference.agb_remnant[0]+inference.agb_remnant[1]*mass+inference.agb_remnant[2]*logZ

    def advance_simulation(time_index,
                           ism_mass_fractions,ism_gas,ism_Z,corona_mass_fractions,corona_gas,feedback_cube,previous_element_fractions,
                           sfr_array,sfr_times,infall_fractions,star_formation_efficiency):
        """Function to step through time in the simulation. 
        This must be iterated over by theano's scan function (i.e. a for loop) 

        *Sequence Variables*
        time_index is the current time-step
        
        *Tensor Variables*
        These contain all tensors that are updated in this time-step. 
        Each contains only the value for this current step e.g.
            ism_Z contains the metallicity at [time_index] initially, which is updated to [time_index+1] in this step.
        These are passed as output to the next time-step.
        This is much more memory efficient than passing the whole data-cube of outputs to each step.
        We can reconstruct the outputs at each time-step from the theano.scan() function if needed
        
        This comprises:
            - ism_mass_fractions = mass of each element in the ISM (relative to total simulation mass of 1 Msun)
            - ism_gas = mass of gas in the ISM
            - ism_Z = metallicity fraction of the ISM
            - corona_mass_fractions = mass of each element in the corona (relative to total simulation mass of 1 Msun)
            - corona_gas = mass of gas in the corona
            - feedback_cube = feedback of each element at time-step i from time-step j. This has n_times*n_times*n_el dimensions and is sourced by SN2, SN1a and AGB feedback
            - previous_element_fractions = ISM fraction of each element from the PREVIOUS time-step i.e. time_index-1 (used for gross yields)
            
        The 'previous_element_fractions' tensor is not updated during the function call, but we pass a different tensor (old_element_fractions) for the next iteration.

        *Non-Sequence Arguments*
        
        These are unchanged by the code.
            - sfr_array = vector with SFR of stars formed in each time-step i.e. cube.sfr
            - sfr_times = vector of time-steps used by the SFR
            - infall_fractions = element fractions for cosmic accretion (from PrimordialInfall class)
            - star_formation_efficiency = SFE*time-step
            
        With full_checking, the physical assumptions are enforced by Theano Assert ops that
        are threaded through the tensors they check (an Assert only fires if its output is used).
            
        The function outputs the variables in Tensor Variables for input to the next time-step
        """
        
        # Define SFR for next time-step which is used below
        next_sfr = sfr_array[time_index+1]
        
        # Also save previous mass fractions separately for later
        print("Do we need previous mass fractions or current mass fractions for gross feedback? Check with Jan")
        old_element_fractions = (ism_mass_fractions/ism_gas).copy()
        
        if full_checking:
            # Check for unphysical negative element fractions (zero is allowed, e.g. for
            # elements absent from the initial gas)
            previous_element_fractions = tt.opt.Assert("Negative ISM element fraction")(
                previous_element_fractions,tt.all(tt.ge(previous_element_fractions,0.)))
                
        # Compute time steps from now until the end of the simulation
        remaining_time_steps = sfr_times[:par.times.time_steps-time_index].copy()
        
        ## Compute SSP feedback from each process
        ssp = SSP(ism_Z.copy(),remaining_time_steps,previous_element_fractions,imf)
        ssp.sn2_feedback()
        ssp.agb_feedback()
        ssp.sn1a_feedback(inference.N_1a)

        ## Compute required quantities for the next time step
        
        # First fill up all_feedback table from SSP (row = birth step of this SSP)
        feedback_cube = tt.set_subtensor(feedback_cube[time_index,time_index:,:],ssp.table.yields)
        
        ### UPDATE CORONA AND ISM QUANTITIES
        ## Update via FEEDBACK
        # First update elements: feedback arriving at the next step from every SSP formed so far,
        # weighted by the mass each SSP formed
        tmp = feedback_cube[:time_index+1,time_index+1,:]*sfr_array[:time_index+1].reshape((-1,1))
        feedback_mass = tmp.sum(axis=0)
        total_feedback_mass = feedback_mass.sum()
        ism_mass_fractions += (1.-inference.x_out)*feedback_mass
        
        # Also add COSMIC INFLOW to corona
        cosmic_inflow = next_sfr * par.general.sfr_factor_for_cosmic_accretion
        corona_gas += cosmic_inflow + inference.x_out*total_feedback_mass
        corona_mass_fractions += infall_fractions*cosmic_inflow+inference.x_out*feedback_mass
        
        ## Compute INFALL (SFR-related) via SCHMIDT LAW
        gas_needed = tt.power(next_sfr/star_formation_efficiency,1./par.general.gas_power)
        gas_there = ism_mass_fractions.sum()
        infall_needed = tt.max([0.,(gas_needed - gas_there)*1.00000001]) # to avoid less gas being requested than needed due to rounding errors
        infall_needed = tt.switch(infall_needed+gas_there<=next_sfr,
                                  (next_sfr-gas_there)*1.01,infall_needed)
        infall_needed = tt.min([corona_gas,infall_needed])
        
        # Move infall gas from corona to ISM
        ism_mass_fractions += infall_needed*corona_mass_fractions/corona_gas
        corona_mass_fractions -= infall_needed*corona_mass_fractions/corona_gas
        corona_gas -=  infall_needed

        ## Create STARS by subtracting SFR
        ism_gas = ism_mass_fractions.sum()
        if full_checking:
            ism_gas = tt.opt.Assert("Not enough ISM gas to form stars")(ism_gas,tt.ge(ism_gas,next_sfr))
        ism_mass_fractions -= next_sfr*ism_mass_fractions/ism_gas
        ism_gas -= next_sfr
        
        ## Determine METALLICITY fractions
        ism_Z = ism_mass_fractions[par.general.metal_list].sum()/ism_gas
        
        # Check we have positive values everywhere
        print("How do we deal with breaking of assumptions? Somehow need to return inf etc.")
        if full_checking:
            ism_mass_fractions = tt.opt.Assert("Negative ISM element mass")(
                ism_mass_fractions,tt.all(tt.ge(ism_mass_fractions,0.)))
            ism_gas = tt.opt.Assert("Negative ISM gas mass")(ism_gas,tt.ge(ism_gas,0.))
            corona_gas = tt.opt.Assert("Negative corona gas mass")(corona_gas,tt.ge(corona_gas,0.))
        
        return ism_mass_fractions,ism_gas,ism_Z,corona_mass_fractions,corona_gas,feedback_cube,old_element_fractions
    
    ### RUN SIMULATION
    
    # Define initial values:
    initials = [cube.ism_mass_fractions_init[:],cube.ism_gas_init,cube.ism_Z_init,
                cube.corona_mass_fractions_init[:],cube.corona_gas_init,cube.all_feedback,cube.element_fractions_init]
        
    # Run the iterator in theano using cube as initial value which is continually updated
    scan_results,scan_updates=theano.scan(advance_simulation,sequences=np.arange(0,par.times.time_steps-1),
                                          outputs_info=initials,
                                          non_sequences = [cube.sfr,sfr.t,cube.infall_fractions,cube.star_formation_efficiency])

    # Compute abundances from mass_fractions.
    # scan_results[0] stacks the ISM element masses after every step; the last entry is the
    # ISM at the final time (par.times.new_end, the star's birth time), which is what the
    # observed stellar abundances are compared with.
    X_Fe_abundances = mass_fractions_to_abundances(scan_results[0][-1])
    # Remove the H/Fe abundance
    output_abundances = tt.concatenate([X_Fe_abundances[:par.general.H_index],X_Fe_abundances[par.general.H_index+1:]])
    
    print("Add some infinities to outputs etc.?")
    
    print("Now compare to observational data!")

    ## COMPARE TO OBSERVATIONAL DATA
    # Create likelihood function
    likelihood=pm.Normal('likelihood', mu=output_abundances, sd=par.observations.errors, 
                             observed=par.observations.abundances)


How do we fix yields to be less than unity? All elements should sum to 1...
Do we need all cube variables?
NB: We use logZ=-5 for Z=0
NB: No dependence on mass or logZ for SN1a yields
Do we need previous mass fractions or current mass fractions for gross feedback? Check with Jan
Can we define this as one-step shorter?
This assumes a linear weight function in mass
NB: all tables are gross tables here unlike Chempy


  rval = inputs[0].__getitem__(inputs[1:])


NB: agb tables are treated weirdly in Chempy - possibly missing first/last step, so may be different
Should this be fit or set from data? (well constrained)
Only SN1a mean_remnant_mass is unset here - set as free parameter?
How do we deal with breaking of assumptions? Somehow need to return inf etc.


  rval = inputs[0].__getitem__(inputs[1:])


Add some infinities to outputs etc.?
Now compare to observational data!


In [None]:
# Sample the posterior with Metropolis and report the wall-clock time taken
init_time = ttime.time()
with full_model:
    samples = pm.sample(
        draws=1000,
        chains=8,
        cores=8,
        tune=1000,
        init='adapt_diag',
        step=pm.Metropolis(),
    )
    # NUTS alternative previously tried instead of step=pm.Metropolis():
    # nuts_kwargs={'target_accept':0.9}, init='advi+adapt_diag', n_init=1000
end_time = ttime.time()-init_time
print("Sampling complete in %d seconds"%end_time)

Multiprocess sampling (8 chains in 8 jobs)
CompoundStep
>Metropolis: [x-out]
>Metropolis: [log10-sfr-scale]
>Metropolis: [log10-sfe]
>Metropolis: [log10-N-1a]
>Metropolis: [imf-slope]
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
Sampling 8 chains:  22%|██▏       | 3460/16000 [04:49<19:16, 10.84draws/s]

## TODO: Assertions are NOT implemented properly here — check this

## Sampling still gives a NaN initial energy (unless initializing with MAP or adapt_diag?). This may indicate a bug in the model, or simply that the data are poor.

##### It is worth generating mock Chempy data with fixed parameters and linear yields to check that this works


##### We could also try running the model computation on a GPU

## NB: 
We can probably use Theano shared variables to pass multiple end times into the model efficiently, e.g. following
https://pymc3.readthedocs.io/en/latest/advanced_theano.html

i.e. store the end time as a shared variable that can be updated without rebuilding the model.

In [None]:
# Log-probability of each basic random variable at the model's test point
test_point = full_model.test_point
for random_variable in full_model.basic_RVs:
    print(random_variable.name, random_variable.logp(test_point))

In [None]:
# Unpack the per-step outputs of the simulation scan (each stacked along its first axis)
(ism_mass_fractions, ism_gas, ism_Z,
 corona_mass_fractions, corona_gas,
 feedback_cube, old_element_fractions) = scan_results

In [None]:
# ISM mass of every traced element after each scan step, at the model's test point
# (previously range(8) silently skipped the last of the traced elements)
ism_mass_history = ism_mass_fractions.tag.test_value
for i, element in enumerate(par.general.elements_to_trace):
    plt.plot(ism_mass_history[:,i], label=element)
plt.xlabel('Scan step')
plt.ylabel('ISM element mass')
plt.legend();

In [None]:
# Feedback table after the final scan step for one element
# (rows: SSP birth step, columns: time step receiving the feedback)
element_to_show = 'N' # previously hard-coded as index 5 of elements_to_trace
element_index = par.general.elements_to_trace.index(element_to_show)
plt.matshow(feedback_cube.tag.test_value[-1][:,:,element_index])
plt.colorbar()
plt.title('Feedback of %s'%element_to_show)
plt.ylabel('SSP birth step')
plt.xlabel('Time step');

## Once running:

- Profile code
- Use GPUs
- Check graphs

e.g. see http://www.marekrei.com/blog/theano-tutorial/

## NB:
- Unprocessed mass fraction in winds is 1 - Remnant Mass Fraction here (ignoring small correction terms)
- We assume Maoz time delay form here for SN1a with time-delay of 40 Myr as in TNG
- IMF is Chabrier 2003 as in TNG
- Even if we don't use major elements in the analysis, they **should** be included in Chempy since they affect metallicities
- We assume SN1a yields to have **no** dependency on mass or logZ (as in the TNG yield set)
- The SN1a mean remnant mass could be parametrized better — it currently depends on some dodgy assumptions (a fixed mean mass of 2.156 Msun with a 30% remnant fraction).