<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/bsplinegradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.5.3-py3-none-any.whl (202 kB)
[K     |████████████████████████████████| 202 kB 4.0 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.22-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.5 MB)
[K     |████████████████████████████████| 7.5 MB 25.4 MB/s 
[?25hCollecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 49.2 MB/s 
[?25hCollecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 47.0 MB/s 
[?25hCollecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 44.6 MB/s 
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB

In [2]:
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks/growth/')

from typing import Sequence
import jax
import optax
import numpy as np
import jax.numpy as jnp
from jax import jit,vmap,grad
from flax import linen as nn
from flax.training import train_state,checkpoints
from conf import Configuration
from cosmology import Cosmology, SimpleLCDM, growth_integ
import matplotlib.pyplot as plt

In [None]:
class SimpleMLP(nn.Module):
    """MLP mapping an input batch to cubic B-spline knots and coefficients.

    Attributes:
        features: widths of the hidden Dense layers; each is followed by sin.
        nodes: spline size; the head emits ``nodes - 1`` knot logits and
            ``nodes + 1`` free coefficients.
    """
    features:Sequence[int]
    nodes:int

    @nn.compact
    def __call__(self,inputs):
        """Return ``(t, c)`` for a 2-D ``(batch, in_features)`` input.

        ``t`` has ``nodes + 6`` columns: four zeros, the cumulative sum of a
        softmax over ``nodes - 1`` logits (non-decreasing, last value ~1.0),
        then three ones, i.e. a knot vector on [0, 1].
        ``c`` has ``nodes + 2`` columns; its first column is fixed to 0.
        """
        x=inputs
        for feat in self.features:
            x=nn.Dense(feat)(x)
            x=jnp.sin(x)
        # Read the module attribute; the original read the notebook-global
        # `nodes`, which silently ignored the `nodes=` passed to the model.
        t=nn.Dense(self.nodes-1)(x)
        c=nn.Dense(self.nodes+1)(x)
        batch=t.shape[0]
        t=jnp.concatenate([jnp.zeros((batch,4)),jnp.cumsum(jax.nn.softmax(t),axis=1),jnp.ones((batch,3))],axis=1)
        c=jnp.concatenate([jnp.zeros((batch,1)),c],axis=1)
        return t,c

In [None]:
def npy_loader(path):
    """Read the NumPy ``.npy`` file at ``path`` and return it as a JAX array."""
    array=jnp.load(path)
    return array

In [None]:
@jit
def _deBoorVectorized(x,t,c):
    p=3
    k=jnp.digitize(x,t)-1
    d=[c[j+k-p] for j in range(0,p+1)]
    for r in range(1,p+1):
        for j in range(p,r-1,-1):
            alpha=(x-t[j+k-p])/(t[j+1+k-r]-t[j+k-p])
            d[j]=(1.0-alpha)*d[j-1]+alpha*d[j]
    return d[p]

In [None]:
# Network / optimiser hyper-parameters.
layer_sizes = [64, 64]
nodes = 32
learning_rate = 1e-5

# Build the model and initialise its parameters from a dummy (1, 1) input.
model = SimpleMLP(features=layer_sizes, nodes=nodes)
temp = jnp.array([[1]])
params = model.init(jax.random.PRNGKey(0), temp)

# Adam optimiser. Note that opt_state is not passed to TrainState.create below
# (only tx is), so `state` does not use it.
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Batched spline evaluator: the evaluation points are shared (None) and one
# (t, c) pair is taken per leading-axis row.
deBoor = vmap(_deBoorVectorized, in_axes=(None, 0, 0))



In [None]:
@jit
def eval_func(params,x,a):
    """Evaluate the network's predicted splines at points ``a``.

    Args:
        params: Flax parameter tree for ``model``.
        x: 2-D network input, e.g. the (1, 1) arrays built in gradient_at.
        a: evaluation points shared by every predicted spline.

    Returns:
        Array of shape ``(x.shape[0], len(a))``: one spline evaluation per
        input row.
    """
    # Use model.apply (defined in an earlier cell) instead of
    # restored_state.apply_fn: restored_state only exists after a later cell.
    # It is restored with target=state, and state was created with
    # apply_fn=model.apply, so the function is expected to be the same.
    t,c=model.apply(params,x)
    # The knot vector spans [0, 1]; keep a just below 1.0.
    preds=deBoor(jnp.clip(a,0,0.999999),t,c)
    return preds

In [None]:
# Mount Google Drive (Colab only) so the checkpoint and .npy files read in the
# following cells are reachable under /content/drive.
import google.colab.drive as drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Restore the trained weights into the structure of `state`, then load the
# cosmology parameter table and the redshifts z (first row of 999.npy).
# NOTE(review): these paths use "My Drive" while the sys.path entry uses
# "MyDrive" -- confirm both resolve on the current Colab runtime.
restored_state = checkpoints.restore_checkpoint(
    ckpt_dir="/content/drive/My Drive/Colab Notebooks/checkpoint_0",
    target=state,
)
cosmo = npy_loader("/content/drive/My Drive/Colab Notebooks/cosmo.npy")
z = npy_loader("/content/drive/My Drive/Colab Notebooks/999.npy")[0, :]
# Scale factor a = 1 / (1 + z).
a = 1.0 / (1.0 + z)

In [None]:
@jit
def predict(params,x,a,j):
    """Return element ``j`` of the flattened ``eval_func(params, x, a)``.

    The output is a scalar, which lets callers differentiate it with
    ``jax.value_and_grad``.
    """
    return eval_func(params,x,a).ravel()[j]

In [None]:
@jit
def gradient_at(i,j):
    """Compare autodiff with a forward-difference probe for one prediction.

    The network input is ``cosmo[i, 0]`` wrapped as a (1, 1) array, and the
    prediction is output element ``j``.

    Returns:
        ``(value1, value2, gradient)``:
        - the prediction at x;
        - the prediction at x + 0.0001;
        - d(prediction)/dx, with the same (1, 1) shape as x.
    """
    x=jnp.array([[cosmo[i,0]]])
    x_shifted=jnp.array([[cosmo[i,0]+0.0001]])
    value_and_dx=jax.value_and_grad(predict,1)
    value1,gradient=value_and_dx(restored_state.params,x,a,j)
    value2=predict(restored_state.params,x_shifted,a,j)
    return value1,value2,gradient

In [None]:
# Compare dD/dOmega_m at cosmology row COSMO_IDX: a finite difference of the
# reference growth factor vs. autodiff through the trained network.
COSMO_IDX=999   # row of `cosmo` being compared
FD_STEP=0.0001  # finite-difference step in Omega_m (same step as gradient_at)

# NOTE(review): the name `cosmology` is not bound by the imports above; they
# only import names *from* the `cosmology` module. That is what raised the
# NameError in the saved output. The legend labels this curve "nbodykit", so
# this presumably needs nbodykit's cosmology module
# (e.g. `from nbodykit.lab import cosmology`). Confirm, and import it before
# running this cell.
cc=cosmology.Cosmology()
fig,ax1=plt.subplots(1,1,constrained_layout=True)

omega_m=cosmo[COSMO_IDX,0]
omega_b=cosmo[COSMO_IDX,1]
hubble_h=cosmo[COSMO_IDX,2]
cosmolh=cc.clone(Omega0_cdm=omega_m-omega_b,Omega_b=omega_b,h=hubble_h)
cosmolh2=cc.clone(Omega0_cdm=omega_m+FD_STEP-omega_b,Omega_b=omega_b,h=hubble_h)
g=cosmolh.scale_independent_growth_factor(z)
g2=cosmolh2.scale_independent_growth_factor(z)
ax1.plot(a,(g2-g)/FD_STEP)

# One autodiff gradient per predicted point. Iterate over however many scale
# factors were loaded instead of assuming there are 256.
om_grad=[]
for j in range(len(a)):
    _,_,gradient=gradient_at(COSMO_IDX,j)
    om_grad.append(gradient[0,0])
ax1.plot(a,om_grad)
ax1.set_title("Gradient with respect to Omega_m")
ax1.legend(["nbodykit","Autodifferentiation (NN)"])
ax1.set_xlabel("a")
fig.savefig("/content/drive/My Drive/Colab Notebooks/gradient.png")
# Flush the saved figure to Drive and unmount. The loading cells above need a
# re-mount before they can run again.
drive.flush_and_unmount()

NameError: ignored

In [3]:
@jit
def D(a, cosmo):
    """Return ``a * g(a) / g(1)``, equal to 1 at a = 1.

    ``g`` linearly interpolates ``cosmo.growth[0][0]`` on the grid
    ``cosmo.conf.growth_a``. ``a`` is cast to ``conf.cosmo_dtype`` first.
    """
    conf = cosmo.conf
    a = jnp.asarray(a, dtype=conf.cosmo_dtype)

    def interp_on_grid(s):
        return jnp.interp(s, conf.growth_a, cosmo.growth[0][0])

    growth_at_a = a * interp_on_grid(a)
    growth_at_one = interp_on_grid(1.)
    return growth_at_a / growth_at_one

In [4]:
@jit
def objective_a(params,conf,a_test):
    """Normalised growth ``D(a_test)`` for a cosmology built from ``params``.

    Args:
        params: sequence ``(Omega_m, Omega_k, w_0, w_a)``.
        conf: Configuration passed through to SimpleLCDM.
        a_test: point(s) at which D is evaluated.
    """
    omega_m, omega_k, w_0, w_a = params
    model_cosmo = growth_integ(
        SimpleLCDM(conf, Omega_m=omega_m, Omega_k=omega_k, w_0=w_0, w_a=w_a)
    )
    return D(jnp.asarray(a_test), model_cosmo)

In [5]:
# Jitted gradient of objective_a with respect to its first argument (params).
obj_grad_a = jit(grad(objective_a, argnums=0))

In [6]:
@jit
def calculate_gradient(omegam):
    """Return dD/dOmega_m at 50 evenly spaced points in [0, 1].

    The other parameters are fixed: Omega_k = 0, w_0 = -1, w_a = 0
    (i.e. flat LCDM).

    Args:
        omegam: matter density parameter Omega_m.

    Returns:
        List of 50 scalar arrays. Element n is dD(a_n)/dOmega_m, where
        a_n = np.linspace(0, 1, 50)[n].
    """
    omegak=0.0
    w0=-1.0
    wa=0.0
    nc=32
    cell_size=8
    growth_anum=512
    conf=Configuration(cell_size=cell_size,mesh_shape=(nc,)*3,growth_anum=growth_anum)
    params=[omegam,omegak,w0,wa]
    # obj_grad_a differentiates w.r.t. all four params; [0] keeps d/dOmega_m.
    return [obj_grad_a(params,conf,a_i)[0] for a_i in np.linspace(0,1,50)]

In [7]:
# Print one dD/dOmega_m list for each of 50 Omega_m values in [0.1, 0.5].
omegam_grid = np.linspace(0.1, 0.5, 50)
for i in omegam_grid:
    print(calculate_gradient(i))



[DeviceArray(0., dtype=float64), DeviceArray(-0.09297562, dtype=float64), DeviceArray(-0.18585847, dtype=float64), DeviceArray(-0.27841043, dtype=float64), DeviceArray(-0.37023823, dtype=float64), DeviceArray(-0.46080027, dtype=float64), DeviceArray(-0.54941805, dtype=float64), DeviceArray(-0.63530309, dtype=float64), DeviceArray(-0.71757092, dtype=float64), DeviceArray(-0.79530098, dtype=float64), DeviceArray(-0.86755237, dtype=float64), DeviceArray(-0.93343192, dtype=float64), DeviceArray(-0.99212842, dtype=float64), DeviceArray(-1.04295341, dtype=float64), DeviceArray(-1.08539306, dtype=float64), DeviceArray(-1.11909872, dtype=float64), DeviceArray(-1.14394341, dtype=float64), DeviceArray(-1.1599771, dtype=float64), DeviceArray(-1.16744325, dtype=float64), DeviceArray(-1.16674195, dtype=float64), DeviceArray(-1.15840447, dtype=float64), DeviceArray(-1.14306218, dtype=float64), DeviceArray(-1.12141423, dtype=float64), DeviceArray(-1.0941901, dtype=float64), DeviceArray(-1.06213421, d