In [1]:
# Standard library.
import os
import sys
import math
import logging
from pathlib import Path

# NOTE: `np` is bound to jax.numpy here, not NumPy, so every `np.*` call in
# this notebook goes through JAX.
import jax.numpy as np
import scipy as sp
import sklearn
import statsmodels.api as sm
from statsmodels.formula.api import ols

# JAX itself, plus the transformations used directly by name below.
import jax
from jax import grad, jit

# Enable the autoreload extension in mode 2 so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
# Draw figures inline in the notebook, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
# Apply the base theme (and the figure size) FIRST: sns.set() re-applies its
# own default context and style, so calling it after set_context("poster")
# would silently reset the context and discard the poster scaling.
sns.set(rc={'figure.figsize': (16, 9.)})
# Then layer the explicit overrides on top of the base theme.
sns.set_context("poster")
sns.set_style("whitegrid")

import pandas as pd
# Let pandas display up to 120 rows and 120 columns of a frame.
pd.options.display.max_rows = 120
pd.options.display.max_columns = 120

# Root logger: INFO level, records written to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [2]:
# NOTE(review): the wildcard import hides which names come from `hoover`.
# The only hoover names referenced explicitly in the visible cells are
# FMatrix and LogProb, both imported by name below — consider dropping the
# `*` import once that is confirmed.
from hoover import *
from hoover.seds import FMatrix
from hoover.likelihood import LogProb

In [14]:
# Build an FMatrix from two named components and evaluate it on a
# log-spaced frequency grid between 10**1 and 10**2 (default sample count).
# (Component names presumably select SED models inside hoover.seds —
# TODO confirm.)
sed = FMatrix(['dustmbb', 'syncpl'])
parameters = {
    'nu': np.logspace(1, 2, dtype=np.float32),
    'nu_ref_d': np.float32(353),
    'nu_ref_s': np.float32(23.),
    'beta_d': np.float32(1.5),
    'beta_s': np.float32(-3.),
    'T_d': np.float32(20),
}
# Bound to `sed_values` rather than `eval`: assigning to `eval` would shadow
# Python's built-in eval() for every later cell run in this kernel.
sed_values = sed(**parameters)

In [15]:
def _summed_sed(p):
    """Scalar objective: sum of all entries of `sed(p, **parameters)`.

    `sed` and `parameters` are looked up in the notebook namespace at call
    time, so later re-bindings of either name are picked up.
    """
    return np.sum(sed(p, **parameters))


# Gradient of the summed SED with respect to its first (positional) argument.
sedg = grad(_summed_sed)

In [16]:
# Re-create the FMatrix and the parameter dict, this time with a single
# scalar frequency, and print the shape of the evaluated result.
# Note: this rebinds the notebook-level names `sed` and `parameters`.
sed = FMatrix(['dustmbb', 'syncpl'])
parameters = {
    'nu': np.float32(100),
    'nu_ref_d': np.float32(353),
    'nu_ref_s': np.float32(23.),
    'beta_d': np.float32(1.5),
    'beta_s': np.float32(-3.),
    'T_d': np.float32(20),
}
print(sed(**parameters).shape)

(2,)


In [17]:
def func(f):
    """Return `sed({'nu': f}, **parameters)` summed over axis 0.

    `sed` and `parameters` are read from the notebook namespace when the
    function is called.

    NOTE(review): `parameters` also carries a 'nu' entry; how FMatrix
    combines the positional dict with the keyword arguments is not visible
    in this notebook — confirm against hoover.seds.FMatrix.
    """
    return np.sum(sed({'nu': f}, **parameters), axis=0)

In [18]:
# Draw reproducible standard-normal noise with a fixed PRNG key.
key = jax.random.PRNGKey(32)
data = jax.random.normal(key, shape=(1, 2, 100), dtype=np.float32)

# Frequencies at which the mock data are defined; also stored in `parameters`
# (in-place update of the shared dict) so later cells evaluate at them.
freqs = np.array([20., 40., 100., 143., 240., 353.]).astype(np.float32)
parameters.update({'nu': freqs})

# Scale the noise per frequency. Broadcasting (1, 2, 100) against
# (len(freqs), 1, 1) expands `data` to (len(freqs), 2, 100).
data *= freqs[:, None, None]

# Per-frequency variance: the SED summed over axis 0, squared, and broadcast
# over the trailing two axes. The shape is taken from `data` instead of a
# hard-coded (6, 2, 100), so it stays consistent if `freqs` or the noise
# shape are changed above.
summed_sed = np.sum(sed({'nu': freqs}), axis=0)
cov = np.ones(data.shape) * (summed_sed ** 2)[:, None, None]

In [19]:
# Build the log-probability object from the mock data / variance and the
# FMatrix, then evaluate it at the current `parameters`.
#
# NOTE(review): in the recorded run this cell failed with
#   TypeError: __init__() missing 1 required positional argument: 'fmatrix'
# i.e. LogProb's constructor requires more positional arguments than the two
# supplied here. Its signature is not visible in this notebook — TODO check
# hoover.likelihood.LogProb and update this call accordingly. Cells that use
# `lnP` only work if it was created successfully earlier in the kernel session.
lnP = LogProb({'data_mean': data, 'data_var': cov}, sed)
print(lnP(parameters))

TypeError: __init__() missing 1 required positional argument: 'fmatrix'

In [11]:
def func(beta_d):
    """Evaluate `lnP` at `parameters`, passing `beta_d` as a keyword argument.

    `lnP` and `parameters` are read from the notebook namespace at call time.
    Note: this reuses the name `func`, which the In [17] cell also assigns.
    """
    return lnP(parameters, beta_d=beta_d)

In [12]:
# Derivative of `func` with respect to its single argument.
gfunc = jax.grad(func)

In [13]:
# Print the function value and its gradient at the same evaluation point.
beta_d_eval = 1.54
print(func(beta_d_eval))
print(gfunc(beta_d_eval))

55115870.0
-7622939.0
