<a href="https://colab.research.google.com/github/davidwhogg/FewProcessModel/blob/main/notebooks/Hogg_for_Griffith.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A two-process (or `K`-process) model for stellar abundances

## Authors:
- **Emily J. Griffith** (Colorado)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Things to discuss:
- Note that the model becomes ambiguous at high metallicity (maybe because there are no longer two sequences!). Fix this somehow? Maybe by putting all of high metallicity into a wider bin? Or perhaps by penalizing derivatives dq/dZ? Observation is: When we use the fine `EJG` binning we get craziness at high Z; when we use the coarser `DWH` binning we don't? Maybe? Or maybe it's sufficient to extrapolate outside the range with the linear dependence inside the range? A lot of this can be fixed with the knot formalism proposed in the Major bugs list.
- Right now the code de-emphasizes element measurements that fall outside a hard box in [X/Mg]. That's wrong; what should we do instead to deal with outliers?
- How to display results and/or what plots are most diagnostic? How to best show dependences on `q_CC_Fe`?
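
A minimal sketch of the derivative-penalty idea in the first item above, assuming (hypothetically) that we append extra residuals proportional to the finite-difference slopes of `q` between adjacent knots. The name `dq_dZ_penalty_chi` and the strength `sqrt_Lambda_s` are illustrative and not part of the current code. It is written in plain NumPy, but it only uses `exp`, `diff`, and `ravel`, so swapping in `jax.numpy` should let it be concatenated onto the output of `one_element_chi()`.

```python
import numpy as np

def dq_dZ_penalty_chi(lnqs, knot_xs, sqrt_Lambda_s):
    """
    Hypothetical smoothness residuals for one element.

    - `lnqs`: shape `(K, Nknot)` natural-logarithmic process values at the knots
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `sqrt_Lambda_s`: scalar square root of an assumed regularization strength

    Returns shape `(K * (Nknot - 1), )`: one residual per process per knot interval.
    """
    qs = np.exp(lnqs)                                         # linear q at the knots
    slopes = np.diff(qs, axis=1) / np.diff(knot_xs)[None, :]  # finite-difference dq/dZ
    return sqrt_Lambda_s * slopes.ravel()

# Example: q goes 1 -> 2 -> 4 across knots at Z = 0, 1, 3, so both slopes are 1.
example = dq_dZ_penalty_chi(np.log(np.array([[1., 2., 4.]])),
                            np.array([0., 1., 3.]), 2.0)
```

Because the penalty acts on slopes rather than on values, a flat `q(Z)` costs nothing, which may be the behavior we want where the data stop constraining two sequences.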

## Major bugs / to-do items:
- When we want to go to `K=4` I now think that we have to optimize `K=2` and then add a process, optimize, and then add another, and then optimize again. That requires some refactoring.
- Right now, the `bins` and `metallicities` inputs are used only for plotting? Maybe we should drop them entirely?
- Set up the notebook so it doesn't go from scratch by default; it can save pickle files at checkpoints and restart there instead. Move to operating on local hardware rather than the Colab.
- Consider de-emphasizing some elements in the fit when we run the A-step or looking at which elements we want to use / believe for the A-step.
- The violence we do to the uncertainties / inverse variances right after reading the data is terrible; fix this.
- We ought to produce some kind of error estimates on everything for which we claim results.
- We should go to a wider metallicity range, and non-uniform metallicity bins.
- Maybe regularize the CCSN process to be more dominant at low metallicity, for cases where the model doesn't know which process to assign abundances to.
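
A minimal sketch of the checkpoint item above, assuming we only need to persist `lnAs`, `lnqs`, and `ln_noise` between runs. The helper names `save_checkpoint` and `load_checkpoint` and the dictionary layout are hypothetical; nothing like this exists in the notebook yet.

```python
import os
import pickle
import numpy as np

def save_checkpoint(filename, lnAs, lnqs, ln_noise):
    """Write the optimizer state to `filename` via a temporary file."""
    state = {"lnAs": np.asarray(lnAs),    # convert any jax arrays to numpy
             "lnqs": np.asarray(lnqs),
             "ln_noise": float(ln_noise)}
    tmp = filename + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, filename)             # rename only after a complete write

def load_checkpoint(filename):
    """Return the saved state dictionary, or `None` if there is no checkpoint."""
    if not os.path.isfile(filename):
        return None
    with open(filename, "rb") as f:
        return pickle.load(f)
```

The optimization cells could then call `load_checkpoint()` first and fall back to `initialize()` only when it returns `None`.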

## Minor bugs / stretch goals:
- Should we make versions of the element-element plots in which we add in fake noise so that the model looks more like the data?
- `Aq_step()` contains a lot of optimization hacks of which we are not proud.
- We could add a "jitter" term to the observational errors on the abundances?
- Shouldn't be taking a `sqrt` in the residual (chi) code.
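
A minimal sketch of the "jitter" item above, assuming the jitter is added in quadrature to each measurement's variance and that an inverse variance of zero continues to mean "bad data". The name `add_jitter` and the example jitter value are hypothetical.

```python
import numpy as np

def add_jitter(ivars, jitter):
    """
    - `ivars`: shape `(N, M)` inverse variances; zeros mark bad data
    - `jitter`: scalar or shape `(M, )` extra error in dex (an assumed parameter)

    Returns new inverse variances `1 / (1 / ivar + jitter ** 2)`; zeros stay zero.
    """
    ivars = np.asarray(ivars, dtype=float)
    good = ivars > 0.
    safe = np.where(good, ivars, 1.)          # avoid dividing by zero for bad data
    new_ivars = 1. / (1. / safe + np.asarray(jitter) ** 2)
    return np.where(good, new_ivars, 0.)

# Example: an error of 0.1 dex plus a jitter of 0.1 dex halves the inverse variance.
example = add_jitter(np.array([[100., 0.]]), 0.1)
```

The jitter could be a fixed hyper-parameter or something we fit per element; this sketch takes no position on that.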

## Comments
- `jax.vmap()` completely changed what was possible in this project. We must acknowledge and cite the JAX developers in the paper.
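
For readers new to `vmap`, here is a small self-contained sketch of the pattern this notebook leans on: mapping `jnp.interp` over the element axis so that one call interpolates every element's knot values at once. The wrapper name `interp_all_elements` and the toy numbers are illustrative only.

```python
import jax.numpy as jnp
from jax import vmap

def interp_all_elements(xs, knot_xs, lnq_knots):
    """
    - `xs`: shape `(N, )` metallicities at which to evaluate
    - `knot_xs`: shape `(Nknot, )` knot locations
    - `lnq_knots`: shape `(Nknot, M)` values at the knots, one column per element

    Returns shape `(N, M)`: each column linearly interpolated to `xs`.
    """
    # in_axes: share xs and knot_xs, map over axis 1 (elements) of lnq_knots
    return vmap(jnp.interp, in_axes=(None, None, 1), out_axes=1)(xs, knot_xs, lnq_knots)

knot_xs = jnp.array([-1.0, 0.0, 1.0])
lnq_knots = jnp.array([[0.0, 10.0],
                       [1.0, 20.0],
                       [2.0, 40.0]])
xs = jnp.array([-0.5, 0.5])
out = interp_all_elements(xs, knot_xs, lnq_knots)  # shape (2, 2)
```

The same trick, applied per star and per element, is what lets the A-step and q-step run as single vectorized calls.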

In [None]:
!pip install jaxopt

In [None]:
!pip install wget

In [None]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jaxopt
import wget
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import os.path
from tqdm import tqdm
from jax import vmap, grad, jit
from functools import partial
import time

In [None]:
import jax
jax.config.update("jax_enable_x64", True) # `from jax.config import config` fails on recent JAX

In [None]:
rng = np.random.default_rng(42) # for all important random numbers
rng2 = np.random.default_rng(17) # for random numbers used just in plotting

In [None]:
# Revise default plotting style
style_revisions = {
            'axes.linewidth': 1.5,
            'xtick.top' : True, 
            'ytick.right' : True, 
            'xtick.direction' : 'in',
            'ytick.direction' : 'in', 
            'xtick.major.size' : 11, 
            'ytick.major.size' : 11, 
            'xtick.minor.size' : 5.5, 
            'ytick.minor.size' : 5.5,
            'font.size' : 16,
            'figure.figsize' : [6, 6],
            'lines.linewidth' : 2.5,
        }
plt.rcParams.update(style_revisions)

In [None]:
# Set constants
ln10 = np.log(10.)

In [None]:
# Set hyper-parameters
K = 2              # number of processes
processes = np.array(['CC', 'Ia', 'AGB', 'fourth']) # process names
processes = processes[:K]
Lambda_a = 1.e6    # regularization strength on Mg for CC and on Fe for Ia.
Lambda_b = 1.e6    # regularization strength on Mg for Ia and Fe for CC.
q_CC_Fe = 0.35     # q_CC,Fe at Z=0
dq_CC_Fe_dZ = 0.0  # slope wrt Z
Lambda_c = 1.0     # regularization strength on everything else in q
Lambda_d = 1.e3    # regularization strength on the A values for processes 3+
sqrt_Lambda_A = jnp.ones(K) * jnp.sqrt(Lambda_d) # see way above
sqrt_Lambda_A = jnp.where(processes == "CC", 0., sqrt_Lambda_A)
sqrt_Lambda_A = jnp.where(processes == "Ia", 0., sqrt_Lambda_A)
def get_A_regularization():
    return sqrt_Lambda_A

In [None]:
# Download files from Emily's website
url = 'https://www.emilyjgriffith.com/s/'
for fn in ['lnqs.npy', 'lnAs.npy', 'bins.npy', 'alldata.npy', 'allivars.npy']:
    if not os.path.isfile(fn):
        wget.download(url + fn)

# These are the new files that extend to lower Z
# Emily commented out the bins file so that we remember to recreate it

#if not os.path.isfile('bins_2.npy'): wget.download(url + 'bins_2.npy')
for fn in ['alldata_2.npy', 'allivars_2.npy']:
    if not os.path.isfile(fn):
        wget.download(url + fn)

In [None]:
# Load numpy files
# N - number of stars = 34410
# M - number of elements = 16

# warning: MAGIC; BRITTLE
elements  = np.array(['Mg','O','Si','S','Ca','CN','Na','Al','K','Cr','Fe','Ni','V','Mn','Co','Ce'])
M = len(elements)

# lnqs: shape(2, 12, 16), 0 is qcc, 1 is qIa, replaced negative values with 0.05
w22_lnqs = np.load('lnqs.npy')
w22_metallicity_labels = np.array(['-0.7', '-0.6', '-0.5', '-0.4', '-0.3',
                                   '-0.2', '-0.1', '0.0', '0.1', '0.2', '0.3',
                                   '0.4', ])
w22_metallicities = np.array([float(m) for m in w22_metallicity_labels])
a, b, c = w22_lnqs.shape
assert b == len(w22_metallicities)
assert c == M

# artificially raise the zeros above zero
w22_lnqs = np.clip(w22_lnqs, -7., None)

# alldata: shape(34410, 16), bad data = 0
alldata = np.load('alldata_2.npy')
N, b = alldata.shape
assert b == len(elements)

# allivars: shape(34410,16), bad data = 0
allivars = np.load('allivars_2.npy')
assert allivars.shape == alldata.shape

In [None]:
# define the binning of the data MAGIC
# THIS IS NOW DEPRECATED -- IT IS ONLY USED IN PLOTTING.
EJG_bin_edges = np.array([-2.2, -1.7, -1.5, -1.3, -1.1, -0.95, -0.80, -0.65, -0.55, 
             -0.45, -0.35, -0.25, -0.15, -0.05,  0.05,  0.15,  0.25,  0.35,
             0.45, 0.6])
DWH_bin_edges = np.array([-2.2, -1.4, -1.1, -0.80, -0.55, -0.45, -0.35, -0.25, -0.15,
                          -0.05,  0.05,  0.15,  0.25, 0.35, 0.6])
bin_edges = DWH_bin_edges
Nbin = len(bin_edges) - 1
bins = (np.digitize(alldata[:, elements == "Mg"], bin_edges) - 1).flatten()
metallicities = np.zeros(Nbin)
for bin in range(Nbin):
    metallicities[bin] = np.median(alldata[:, elements == "Mg"][bins == bin])
    print(bin, bin_edges[bin], metallicities[bin], bin_edges[bin + 1], np.sum(bins == bin))

In [None]:
# Set up knots for q values
knot_xs = np.array([-2.5, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.55])
ii = 0
assert elements[ii] == "Mg"
xs = alldata[:, ii] # BRITTLE MAGIC
Nknot = len(knot_xs)

In [None]:
# mess with the uncertainties -- THIS IS TOTAL HACKING!
# THIS IS NOT PERMITTED.
# NOTE brittle 0 (should be "Mg")
delta = alldata - alldata[:, 0][:, None] # [X/Mg] for every star and element
allivars[(delta < -0.5) | (delta > 0.2)] = 1. # UGH HACK MAGIC: hard box in [X/Mg]
sqrt_allivars = jnp.sqrt(allivars)

In [None]:
def get_lnqs(lnqs, knot_xs, xs):
    """
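    Linearly interpolate `lnqs` (shape `(K, Nknot, M)`, defined at `knot_xs`)
    to the metallicities `xs` (shape `(N, )`); the `vmap` runs over elements.
    Returns shape `(K, N, M)`.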
    """
    return jnp.concatenate([vmap(jnp.interp, in_axes=(None, None, 1),
                                 out_axes=(1))(xs, knot_xs, lnqs[k])[None, :, :]
                            for k in range(K)], axis=0)

In [None]:
def all_stars_K_process_model(lnAs, lnqs, knot_xs, xs):
    """
    ## inputs
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, Nknot, M)` natural-logarithmic processes
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `xs`: shape `(N, )` metallicities (used to interpolate the `lnqs`)

    ## outputs
    shape `(N, M)` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnAs[:, :, None]
                     + get_lnqs(lnqs, knot_xs, xs), axis=0) / ln10

In [None]:
def one_star_K_process_model(lnAs, lnqs):
    """
    ## inputs
    - `lnAs`: shape `(K,)` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, M)` natural-logarithmic processes

    ## outputs
    shape `(M, )` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnAs[:, None] + lnqs, axis=0) / ln10

def one_star_chi(lnAs, lnqs, data, sqrt_ivars, sqrt_Lambda):
    """
    ## inputs
    - `lnAs`: shape `(K, )` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, M)` natural-logarithmic processes
    - `data`: shape `(M, )` log_10 abundance measurements
    - `sqrt_ivars`: shape `(M, )` inverse errors on the data
    - `sqrt_Lambda`: shape `(K, )` regularization strength on As

    ## outputs
    chi for this one star
    """
    return jnp.concatenate([sqrt_ivars * (data - one_star_K_process_model(lnAs, lnqs)),
                            sqrt_Lambda * jnp.exp(lnAs)])

def one_star_A_step(lnqs, data, sqrt_ivars, sqrt_Lambda, init):
    """
    ## inputs
    - `lnqs`: shape `(K, M)` natural-logarithmic processes
    - `data`: shape `(M, )` log_10 abundance measurements
    - `sqrt_ivars`: shape `(M, )` inverse errors on the data
    - `sqrt_Lambda`: shape `(K, )` regularization
    - `init`: shape `(K,)` initial guess for the A vector

    ## outputs
    - shape `(K,)` best-fit natural-logarithmic amplitudes
    - scalar decrease in chi-squared achieved by the optimizer (initial minus final)

    ## bugs
    - Doesn't check the output of the optimizer AT ALL.
    - Check out the crazy `maxiter` input!
    """
    solver = jaxopt.GaussNewton(residual_fun=one_star_chi, maxiter=4)
    lnAs_init = init.copy()
    chi2_init = jnp.sum(one_star_chi(lnAs_init, lnqs, data, sqrt_ivars, sqrt_Lambda) ** 2)
    res = solver.run(lnAs_init, lnqs=lnqs, data=data, sqrt_ivars=sqrt_ivars,
                     sqrt_Lambda=sqrt_Lambda)
    chi2_res = jnp.sum(one_star_chi(res.params, lnqs, data, sqrt_ivars, sqrt_Lambda) ** 2)
    return res.params, chi2_init - chi2_res

def A_step(lnqs, data, sqrt_ivars, knot_xs, xs, old_lnAs):
    """
    ## inputs
    - `lnqs`: shape `(K, Nknot, M)` natural-logarithmic processes
    - `data`: shape `(N, M)` log_10 abundance measurements
    - `sqrt_ivars`: shape `(N, M)` inverse errors (square roots of the inverse variances) on the data
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `xs`: shape `(N, )` metallicities (to use with the knots)
    - `old_lnAs`: shape `(K, N)` previous `lnAs`; used for initialization of the optimizer

    ## outputs
    - shape `(K, N)` best-fit natural-logarithmic amplitudes
    - shape `(N, )` per-star decrease in chi-squared from the optimizer

    ## bugs
    - Ridiculous post-processing of outputs, with MAGIC numbers.
    """
    N, M = data.shape
    K, Nknot, em = lnqs.shape
    assert em == M
    assert sqrt_ivars.shape == (N, M)
    assert knot_xs.shape == (Nknot, )
    assert old_lnAs.shape == (K, N)
    sqrt_Lambda = get_A_regularization()
    new_lnAs, dc2 = vmap(one_star_A_step, in_axes=(1, 0, 0, None, 1),
                    out_axes=(1, 0))(get_lnqs(lnqs, knot_xs, xs), data,
                                     sqrt_ivars, sqrt_Lambda, old_lnAs)
    if not jnp.all(jnp.isfinite(new_lnAs)):
        print("A-step(): fixing bad elements:", jnp.sum(jnp.logical_not(jnp.isfinite(new_lnAs))))
        new_lnAs = jnp.where(jnp.isfinite(new_lnAs), new_lnAs, old_lnAs)
    if np.any(new_lnAs > 2.0): # MAGIC HACK
        print("A-step(): fixing large elements:", np.sum(new_lnAs > 2.0), np.max(new_lnAs))
        new_lnAs = jnp.where(new_lnAs > 2.0, 2.0, new_lnAs)
    if np.any(new_lnAs < -9.0): # MAGIC HACK
        print("A-step(): fixing small elements:", np.sum(new_lnAs < -9.0), np.min(new_lnAs))
        new_lnAs = jnp.where(new_lnAs < -9.0, -9.0, new_lnAs)
    return new_lnAs, dc2

In [None]:
def one_element_K_process_model(lnqs, lnAs):
    """
    ## inputs
    - `lnqs`: shape `(K, N)` natural-logarithmic process elements
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes

    ## outputs
    shape `(N, )` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnqs + lnAs, axis=0) / ln10

def q_step_regularization(lnqs):
    """
    Build arrays that are used for the regularization of the q step.

    ## outputs
    `Lambdas, q0s, fixed`: regularization amplitudes, mean values, and a boolean
    mask of the entries that are held fixed; all the same shape as `lnqs`.

    ## bugs:
    - Depends on many global variables and choices.
    """
    Lambdas = np.zeros_like(lnqs) + Lambda_c # default value
    q0s = np.zeros_like(lnqs) + 0.5 # default value
    fixed = np.zeros_like(lnqs).astype(bool)

    # First point: Strongly require that q_Mg = 1 for CC
    elem = elements == "Mg"
    proc = processes == "CC"
    Lambdas[proc, :, elem] = Lambda_a
    q0s[    proc, :, elem] = 1.0
    fixed[  proc, :, elem] = True

    # Second point: Strongly require that q_Fe = 0.5 for Ia
    elem = elements == "Fe"
    proc = processes == "Ia"
    Lambdas[proc, :, elem] = Lambda_a
    q0s[    proc, :, elem] = 0.5
    fixed[  proc, :, elem] = True

    # Third point: Require that q_Mg = 0 for all but CC
    elem = elements == "Mg"
    proc = processes != "CC"
    Lambdas[proc, :, elem] = Lambda_b
    q0s[    proc, :, elem] = 0.0
    fixed[  proc, :, elem] = True

    # Fourth point: Require that q_Fe has some particular value / form for CC
    elem = elements == "Fe"
    proc = processes == "CC"
    Lambdas[proc, :, elem] = Lambda_b
    q0s[    proc, :, elem] = q_CC_Fe + dq_CC_Fe_dZ * knot_xs
    fixed[  proc, :, elem] = True

    # Now set the form for any AGB process
    elem = elements == "CN"
    proc = processes == "AGB"
    Lambdas[proc, :, elem] = Lambda_b
    q0s[    proc, :, elem] = 0.5
    fixed[  proc, :, elem] = True

    # Now set the form for a fourth process
    elem = elements == "Co"
    proc = processes == "fourth"
    Lambdas[proc, :, elem] = Lambda_b
    q0s[    proc, :, elem] = 0.5
    fixed[  proc, :, elem] = True

    return Lambdas, q0s, fixed

def one_element_chi(lnqs, lnAs, data, sqrt_ivars, knot_xs, xs, sqrt_Lambdas, q0s):
    """
    ## inputs
    - `lnqs`: shape `(K, Nknot)` natural-logarithmic process vectors
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `data`: shape `(N, )` log_10 abundance measurements
    - `sqrt_ivars`: shape `(N, )` inverse errors on the data
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `xs` : shape `(N, )` metallicities to use with `knot_xs`
    - `sqrt_Lambdas`: shape `(K, Nknot)` square roots of the regularization amplitudes
    - `q0s`: shape `(K, Nknot)` regularization mean values (linear `q`, not log)

    ## outputs
    chi for this one element (weighted residuals, followed by the regularization terms)
    """
    interp_lnqs = get_lnqs(lnqs[:, :, None], knot_xs, xs)[:, :, 0]
    return jnp.concatenate([sqrt_ivars * (data - one_element_K_process_model(interp_lnqs, lnAs)),
                            jnp.ravel(sqrt_Lambdas * (jnp.exp(lnqs) - q0s))])

def one_element_q_step(lnAs, data, sqrt_ivars, knot_xs, xs, sqrt_Lambdas, q0s,
                       fixed, init):
    """
    ## inputs
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `data`: shape `(N, )` log_10 abundance measurements
    - `sqrt_ivars`: shape `(N, )` inverse errors on the data
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `xs` : shape `(N, )` metallicities to use with `knot_xs`
    - ... 

    ## outputs
    - shape `(K, Nknot)` best-fit natural-logarithmic process elements
    - scalar decrease in chi-squared achieved by the optimizer (initial minus final)

    ## bugs
    - Uses the `fixed` input incredibly stupidly, because Hogg SUX.
    - Doesn't check the output of the optimizer AT ALL.
    - Check out the crazy `maxiter` input!
    """
    solver = jaxopt.GaussNewton(residual_fun=one_element_chi, maxiter=4)
    lnqs_init = init.copy()
    chi2_init = jnp.sum(one_element_chi(lnqs_init, lnAs, data, sqrt_ivars,
                        knot_xs, xs, sqrt_Lambdas, q0s) ** 2)
    res = solver.run(lnqs_init, lnAs=lnAs, data=data, sqrt_ivars=sqrt_ivars,
                     knot_xs=knot_xs, xs=xs,
                     sqrt_Lambdas=sqrt_Lambdas, q0s=q0s)
    chi2_res = jnp.sum(one_element_chi(res.params, lnAs, data, sqrt_ivars,
                       knot_xs, xs, sqrt_Lambdas, q0s) ** 2)
    return jnp.where(fixed, lnqs_init, res.params), chi2_init - chi2_res

def q_step(lnAs, data, sqrt_ivars, knot_xs, xs, old_lnqs):
    """
    ## inputs
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `data`: shape `(N, M)` log_10 abundance measurements
    - `sqrt_ivars`: shape `(N, M)` inverse errors on the data
    - `knot_xs`: shape `(Nknot, )` metallicity knot locations
    - `xs` : shape `(N, )` metallicities to use with `knot_xs`
    - `old_lnqs`: shape `(K, Nknot, M)` initialization for optimizations

    ## outputs
    - shape `(K, Nknot, M)` best-fit natural-logarithmic processes
    - shape `(M, )` per-element decrease in chi-squared from the optimizer

    ## bugs
    - Ridiculous post-processing of outputs.
    """
    N, M = data.shape
    assert lnAs.shape == (K, N)
    assert sqrt_ivars.shape == (N, M)
    Nknot = len(knot_xs)
    assert len(xs) == N
    assert old_lnqs.shape == (K, Nknot, M)
    Lambdas, q0s, fixed = q_step_regularization(old_lnqs)
    new_lnqs, dc2 = vmap(one_element_q_step, in_axes=(None, 1, 1, None, None, 2, 2, 2, 2),
                    out_axes=(2, 0))(lnAs, data, sqrt_ivars, knot_xs, xs,
                                jnp.sqrt(Lambdas), jnp.array(q0s),
                                jnp.array(fixed), old_lnqs)
    if not np.all(jnp.isfinite(new_lnqs)):
        print("q-step(): fixing bad elements:", np.sum(jnp.logical_not(jnp.isfinite(new_lnqs))))
        new_lnqs = jnp.where(jnp.isfinite(new_lnqs), new_lnqs, old_lnqs)
    if np.any(new_lnqs > 1.0): # MAGIC HACK
        print("q-step(): fixing large elements:", np.sum(new_lnqs > 1.0), np.max(new_lnqs))
        new_lnqs = jnp.where(new_lnqs > 1.0, 1.0, new_lnqs)
    if np.any(new_lnqs < -9.0): # MAGIC HACK
        print("q-step(): fixing small elements:", np.sum(new_lnqs < -9.0), np.min(new_lnqs))
        new_lnqs = jnp.where(new_lnqs < -9.0, -9.0, new_lnqs)
    return new_lnqs, dc2

In [None]:
def objective_q(lnAs, lnqs, data, sqrt_ivars, knot_xs, xs):
    """
    This is NOT the objective, but it stands in for now!!
    """
    Lambdas, q0s, _ = q_step_regularization(lnqs)
    chi = vmap(one_element_chi, in_axes=(2, None, 1, 1, None, None, 2, 2),
               out_axes=(0))(lnqs, lnAs, data, sqrt_ivars, knot_xs, xs,
                             jnp.sqrt(Lambdas), q0s)
    sqrt_Lambda = get_A_regularization()
    return np.sum(chi * chi) + np.sum((sqrt_Lambda[:, None] * jnp.exp(lnAs)) ** 2)

def objective_A(lnAs, lnqs, data, sqrt_ivars, knot_xs, xs):
    """
    This is NOT the objective, but it stands in for now!!
    """
    sqrt_Lambda = get_A_regularization()
    chi = vmap(one_star_chi, in_axes=(1, 1, 0, 0, None),
               out_axes=(0))(lnAs, get_lnqs(lnqs, knot_xs, xs),
                             data, sqrt_ivars, sqrt_Lambda)
    Lambdas, q0s, _ = q_step_regularization(lnqs)
    return np.sum(chi ** 2) + np.sum(Lambdas * (jnp.exp(lnqs) - q0s) ** 2)

def Aq_step(data, sqrt_ivars, knot_xs, xs, ln_noise, old_lnAs, old_lnqs, rng=rng):
    """
    ## Bugs:
    - This contains multiple hacks.
    - Maybe some of the hacks should be pushed back into the A-step and
      the q-step?
    """
    prefix = "Aq-step():"
    old_objective = objective_A(old_lnAs, old_lnqs, data, sqrt_ivars, knot_xs, xs)

    # fix old_lnAs
    old_lnAs = jnp.where(jnp.isnan(old_lnAs), 1., old_lnAs)

    # add noise
    A_noise = ln_noise + np.log(rng.uniform(size=old_lnAs.shape))
    init_lnAs = jnp.logaddexp(old_lnAs, A_noise)
    q_noise = ln_noise + np.log(rng.uniform(size=old_lnqs.shape))
    q_noise[:, :, elements == "Mg"] = -np.inf # HACK 
    q_noise[:, :, elements == "Fe"] = -np.inf # HACK
    init_lnqs = jnp.logaddexp(old_lnqs, q_noise) # NOTE: unused; the q step below starts from old_lnqs

    # run q step
    objective1 = objective_q(init_lnAs, old_lnqs, data, sqrt_ivars, knot_xs, xs)
    new_lnqs, _ = q_step(init_lnAs, data, sqrt_ivars, knot_xs, xs, old_lnqs)
    objective2 = objective_q(init_lnAs, new_lnqs, data, sqrt_ivars, knot_xs, xs)
    if objective2 > objective1:
        print(prefix, "q-step WARNING: objective function got worse:", objective1, objective2)
        new_lnqs = old_lnqs.copy()
        objective2 = objective1

    # run A step
    objective3 = objective_A(init_lnAs, new_lnqs, data, sqrt_ivars, knot_xs, xs)
    new_lnAs, _ = A_step(new_lnqs, data, sqrt_ivars, knot_xs, xs, init_lnAs)
    objective4 = objective_A(new_lnAs, new_lnqs, data, sqrt_ivars, knot_xs, xs)
    if objective4 > objective3:
        print(prefix, "A-step WARNING: objective function got worse:", objective3, objective4)
        new_lnAs = init_lnAs.copy()
        objective4 = objective3

    # check objective
    print(old_objective, objective1, objective2, objective3, objective4)
    if objective4 < old_objective:
        print(prefix, "we took a step!", ln_noise, objective4, old_objective - objective4)
        return new_lnAs, new_lnqs, np.around(ln_noise + 0.1, 1)
    else:
        print(prefix, "we didn't take a step :(", ln_noise, old_objective, old_objective - objective4)
        return old_lnAs.copy(), old_lnqs.copy(), np.around(ln_noise - 1.0, 1)

In [None]:
def initialize():
    """
    ## Bugs:
    - DOESN'T WORK for K > 2 ??
    - very brittle
    - relies on global variables
    """
    lnqs = np.zeros((K, Nknot, M))
    lnAs = np.zeros((K, N))
    _, q0s, fixed = q_step_regularization(lnqs)
    lnq0s = np.log(np.clip(q0s, 1.e-7, None))
    lnqs = np.zeros_like(lnq0s)
    lnqs = jnp.where(fixed, lnq0s, lnqs)
    I = [0, 10, 5, 14] # BRITTLE HACK
    I = I[:K]
    lnAs, _ = A_step(lnqs[:, :, I], alldata[:, I], sqrt_allivars[:, I],
                     knot_xs, xs, lnAs)
    print("initialize():", np.median(lnAs[1:] - lnAs[0], axis=1))
    lnqs, _ = q_step(lnAs, alldata, sqrt_allivars, knot_xs, xs, lnqs)
    lnAs, _ = A_step(lnqs, alldata, sqrt_allivars, knot_xs, xs, lnAs)
    return lnAs, lnqs

In [None]:
ln_noise = -4.
new_lnAs, new_lnqs = initialize()
for i in range(16):
    new_lnAs, new_lnqs, ln_noise = Aq_step(alldata, sqrt_allivars,
                                           knot_xs, xs, ln_noise, new_lnAs,
                                           new_lnqs)
    print(i + 1, np.median(new_lnAs[1:] - new_lnAs[0], axis=1))

In [None]:
# Now do one more round of optimization
ln_noise = -4.
for i in range(16):
    new_lnAs, new_lnqs, ln_noise = Aq_step(alldata, sqrt_allivars, knot_xs, xs,
                                           ln_noise, new_lnAs, new_lnqs)
    print(i + 1, np.median(new_lnAs[1:] - new_lnAs[0], axis=1))

In [None]:
def plot_qs(lnqs):
    """
    # Bugs:
    - Relies on many global variables.
    - Assumes a rigid structure for the processes?
    """
    MgH = np.linspace(np.min(knot_xs), np.max(knot_xs), 300) # plotting xs
    new_qs = np.exp(get_lnqs(lnqs, knot_xs, MgH)) # interp to plotting xs
    w22_MgH = w22_metallicities
    w22_qs = np.exp(w22_lnqs)

    plt.figure(figsize=(10,10))
    for i in range(16):
        plt.subplot(4,4,i+1)
        new_qcc = new_qs[0,:,i]
        new_qIa = new_qs[1,:,i]
        w22_qcc = w22_qs[0,:,i]
        w22_qIa = w22_qs[1,:,i]

        plt.plot(w22_MgH, w22_qcc, 'b-', lw=4, alpha=0.25, label='qcc W22')
        plt.plot(w22_MgH, w22_qIa, 'r-', lw=4, alpha=0.25, label='qIa W22')

        plt.plot(MgH, new_qcc, 'b-', alpha=0.9, label='qcc new')
        plt.plot(MgH, new_qIa, 'r-', alpha=0.9, label='qIa new')

        plt.xlabel('[Mg/H]')
        plt.xlim(np.min(knot_xs), np.max(knot_xs))
        plt.ylabel('q '+elements[i])
        plt.ylim(-0.15, 1.5)

        if i==0:
            plt.legend(ncol=1, fontsize=10)
        #plt.ylim(-0.1,1.1)
    plt.tight_layout()

In [None]:
plot_qs(new_lnqs)

In [None]:
def plot_model_abundances(lnAs, lnqs, knot_xs, xs, data, sqrt_ivars, noise=False):
    """
    ## bugs:
    - Relies on lots of global variables.
    """
    MgHmin = -1.2

    synthdata = all_stars_K_process_model(lnAs, lnqs, knot_xs, xs)
    synthnoise = 0.
    noisestr = ""
    if noise:
        synthnoise = rng2.normal(size=synthdata.shape) / sqrt_ivars
        noisestr = " + noise"
    fig, axes = plt.subplots(len(elements) - 1, 3, figsize=(12,3 * (len(elements) - 1)))

    for j in range(len(elements) - 1):
        ax = axes[j, 0]
        ax.hist2d(data[:,0], data[:,j+1] - data[:,0],
                  cmap='magma', bins=100, range=[[MgHmin,0.4],[-0.5,0.2]], norm=LogNorm())
        ax.set_xlabel('[Mg/H]')
        ax.set_ylabel('[{}/Mg]'.format(elements[j+1]))
        ax.set_ylim(-0.5,0.2)
        if j == 0:
            ax.set_title('observed')

        ax = axes[j, 1]
        sata = synthdata + synthnoise
        ax.hist2d(sata[:,0], sata[:,j+1] - sata[:,0],
                  cmap='magma', bins=100, range=[[MgHmin,0.4],[-0.5,0.2]], norm=LogNorm())
        ax.set_xlabel('[Mg/H]')
        ax.set_ylabel('[{}/Mg]'.format(elements[j+1]))
        ax.set_ylim(-0.5,0.2)
        if j == 0:
            ax.set_title('predicted' + noisestr)

        ax = axes[j, 2]
        ax.hist2d(sqrt_ivars[:, 0] * (data[:, 0] - synthdata[:, 0]),
                  sqrt_ivars[:, j+1] * (data[:, j+1] - synthdata[:, j+1]),
                cmap='magma', bins=100, range=[[-10, 10], [-10, 10]], norm=LogNorm())
        ax.set_xlabel('[Mg/H] chi')
        ax.set_ylabel('[{}/H] chi'.format(elements[j+1]))
        if j == 0:
            ax.set_title('dimensionless residual')

    plt.tight_layout()

In [None]:
plot_model_abundances(new_lnAs, new_lnqs, knot_xs, xs, alldata, sqrt_allivars, noise=True)

In [None]:
# PCA the residuals (data minus model). This code is stupid-slow.
allresids = alldata - all_stars_K_process_model(new_lnAs, new_lnqs, knot_xs, xs)
ss = np.zeros((Nbin, M))
vs = np.zeros((Nbin, 2, M))
for bin in range(Nbin):
    u, s, v = np.linalg.svd(allresids[bins == bin], full_matrices=False)
    ss[bin] = s
    vs[bin] = v[:2]

In [None]:
NUM_COLORS = M
cm = plt.get_cmap('cool')
colors =[]
for i in range(NUM_COLORS):
    color = cm(1.*i/NUM_COLORS)
    colors.append(color)

In [None]:
plt.figure()
for bin in np.arange(Nbin)[-1::-1]:
    plt.plot(ss[bin] / np.sum(ss[bin]), "o", color=colors[bin], alpha=0.5,
             label="{:5.2f}".format(metallicities[bin]))
    plt.legend(loc="upper right")

In [None]:
plt.figure(figsize=(16, 6))
for bin in np.arange(Nbin)[-1::-1]:
    v0 = vs[bin, 0]
    if np.median(v0) < 0:
        v0 *= -1.
    plt.plot(vs[bin, 0], "o", color=colors[bin], label="$Z = {:+5.2f}$".format(metallicities[bin]))
    plt.plot(vs[bin, 0], "-", color=colors[bin], alpha=0.5)
    plt.xlim(-0.5, 17)
    ax = plt.gca()
    ax.set_xticks(range(M))
    ax.set_xticklabels(elements)
    plt.legend(loc="upper right", fontsize=10)

In [None]:
# DON'T GO BELOW THIS LINE.
assert False

In [None]:
# Now optimize some models at different values of q_CC_Fe
models = []
for q_CC_Fe in np.arange(0.20, 0.51, 0.05):
    ln_noise = -4.
    for i in range(32):
        print("model:", q_CC_Fe, "iteration:", i)
        new_lnAs, new_lnqs, ln_noise = Aq_step(alldata, sqrt_allivars,
                                               knot_xs, xs, ln_noise,
                                               new_lnAs, new_lnqs)
    models += [(q_CC_Fe, new_lnqs, new_lnAs)]

In [None]:
for q_CC_Fe, new_lnqs, new_lnAs in models:
    plot_qs(new_lnqs)

In [None]:
newer_models = []
for q_CC_Fe, new_lnqs, new_lnAs in models: # same order as the tuples stored in `models`
    ln_noise = -4.
    for i in range(32):
        print("model:", q_CC_Fe, "iteration:", i)
        new_lnAs, new_lnqs, ln_noise = Aq_step(alldata, sqrt_allivars,
                                               knot_xs, xs, ln_noise,
                                               new_lnAs, new_lnqs)
    newer_models += [(q_CC_Fe, new_lnqs, new_lnAs)]