<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/bsplineemulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.6.0-py3-none-any.whl (180 kB)
[K     |████████████████████████████████| 180 kB 13.8 MB/s 
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 73.8 MB/s 
[?25hCollecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 41.6 MB/s 
Collecting jax>=0.3.16
  Downloading jax-0.3.16.tar.gz (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 63.0 MB/s 
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 3.3 MB/s 
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)
Collecting chex>=0.0.4
  Downloading chex-0.1.4-py3-none-any.whl (76 kB)
[K     |████████████████████████████████| 76 kB 2.3 MB/

In [None]:
import random
import statistics
from typing import Sequence
import jax
import optax
import numpy as np
import jax.numpy as jnp
from jax import jit,vmap
from flax import linen as nn
from flax.training import train_state,checkpoints
import matplotlib
import matplotlib.pyplot as plt



In [None]:
class SimpleMLP(nn.Module):
    """MLP that maps an input batch to the knots and coefficients of a spline.

    The hidden layers are ``Dense`` layers with sine activations.  Two linear
    heads then produce, for every input row:

    * ``t``: a knot vector of length ``nodes+6``.  It is four zeros, followed
      by the cumulative sum of a softmax over ``nodes-1`` logits (positive
      increments, so the values increase and end at 1), followed by three ones.
    * ``c``: a coefficient vector of length ``nodes+2`` whose first entry is
      fixed to zero.
    """
    # Width of each hidden Dense layer, in order.
    features:Sequence[int]
    # Sets the size of both output heads (nodes-1 knot logits, nodes+1 free coefficients).
    nodes:int

    @nn.compact
    def __call__(self,inputs):
        """Return ``(t, c)`` for ``inputs`` of shape ``(batch, n_inputs)``."""
        x=inputs
        for feat in self.features:
            x=nn.Dense(feat)(x)
            x=jnp.sin(x)
        # Use the module's own field rather than a same-named global, so the
        # value passed to the constructor is the one that sizes the heads.
        t=nn.Dense(self.nodes-1)(x)
        c=nn.Dense(self.nodes+1)(x)
        t=jnp.concatenate([jnp.zeros((t.shape[0],4)),jnp.cumsum(jax.nn.softmax(t),axis=1),jnp.ones((t.shape[0],3))],axis=1)
        c=jnp.concatenate([jnp.zeros((c.shape[0],1)),c],axis=1)
        return t,c

In [None]:
def npy_loader(path):
    """Read the ``.npy`` file at ``path`` with ``jnp.load`` and return the result."""
    loaded=jnp.load(path)
    return loaded

In [None]:
@jit
def _deBoorVectorized(x,t,c):
    p=3
    k=jnp.digitize(x,t)-1
    d=[c[j+k-p] for j in range(0,p+1)]
    for r in range(1,p+1):
        for j in range(p,r-1,-1):
            alpha=(x-t[j+k-p])/(t[j+1+k-r]-t[j+k-p])
            d[j]=(1.0-alpha)*d[j-1]+alpha*d[j]
    return d[p]

In [None]:
# Hyperparameters of the emulator network and its optimiser.
nodes=16
layer_sizes=[64,64]
learning_rate=1e-6

# Build the network and initialise its parameters from a dummy (1, 1) input.
model=SimpleMLP(nodes=nodes,features=layer_sizes)
temp=jnp.array([[1]])
params=model.init(jax.random.PRNGKey(0),temp)

# Adam optimiser and the TrainState later used as the checkpoint-restore template.
tx=optax.adam(learning_rate=learning_rate)
opt_state=tx.init(params)
state=train_state.TrainState.create(tx=tx,params=params,apply_fn=model.apply)

# Batched spline evaluation: the evaluation points are shared, while the knot
# and coefficient vectors are mapped over their leading axis.
deBoor=vmap(_deBoorVectorized,in_axes=(None,0,0))

In [None]:
@jit
def eval_func(params,x,a):
    """Evaluate the emulated spline at the points ``a`` for every row of ``x``.

    Args:
        params: parameter pytree handed to the model's apply function.
        x: model input batch; each row yields one knot vector and one
            coefficient vector.
        a: evaluation points, shared by all rows of ``x``.

    Returns:
        The spline values, with one leading entry per row of ``x``.
    """
    # Take the apply function from ``state``, which exists before this cell,
    # rather than from a TrainState that is only created in a later cell.
    t,c=state.apply_fn(params,x)
    # The model's knot vectors end at 1; clipping keeps every point strictly
    # below that last knot so the knot-span lookup stays inside the spline.
    preds=deBoor(jnp.clip(a,0,0.99999),t,c)
    return preds

In [None]:
# Colab-only: mount Google Drive at /content/drive, the root of every data,
# checkpoint and figure path used below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Restore the trained weights, using the freshly built ``state`` as the template.
restored_state=checkpoints.restore_checkpoint(target=state,ckpt_dir="/content/drive/My Drive/Colab Notebooks/retrain/checkpoint_0")

# Tables shared by every cosmology: model inputs, reference curves, and the
# grid the reference curves are plotted against (row 0 of one data file).
cosmo=npy_loader("/content/drive/My Drive/Colab Notebooks/data1/cosmo.npy")
input_result=npy_loader("/content/drive/My Drive/Colab Notebooks/data1/combined.npy")
a=npy_loader("/content/drive/My Drive/Colab Notebooks/data1/999.npy")[0,:]

# Pick one cosmology at random and load its separate check file.
cosmo_num=random.randrange(1000)
check_result=npy_loader("/content/drive/My Drive/Colab Notebooks/data1check/"+str(cosmo_num)+".npy")
parameters=np.array([[cosmo[cosmo_num,0]]])
d_data=input_result[cosmo_num,:]

# Emulator output on the reference grid and on the check file's own grid.
d_test=eval_func(restored_state.params,parameters,a).reshape(-1)
d_check=eval_func(restored_state.params,parameters,check_result[0,:]).reshape(-1)
discrepancy=d_test/d_data
# Alternative ratio, against row 1 of the check file:
# discrepancy=d_check/check_result[1,:]

# Top panel: the three curves.  Bottom panel: prediction over reference.
fig,(ax1,ax2)=plt.subplots(nrows=2,ncols=1,constrained_layout=True)
curves=((a,d_data,"Reference"),(a,d_test,"Fitting"),(check_result[0,:],d_check,"Verify"))
for curve_x,curve_y,curve_label in curves:
    ax1.plot(curve_x,curve_y,label=curve_label)
ax1.set_title("Omega_m="+str(np.round(parameters[0,0],3)))
ax1.set_ylabel("D")
ax1.legend()
ax2.plot(a,discrepancy,label="Discrepancy")
ax2.ticklabel_format(useOffset=False)
ax2.set_ylabel("Predict/Data")
ax2.set_xlabel("a")
ax2.legend()
plt.savefig("/content/drive/My Drive/Colab Notebooks/cosmo"+str(cosmo_num)+".png")
# Fractional error of the emulator over all cosmologies, inspected at a few
# regularly spaced points of the grid ``a``.
N_COSMO=1000        # cosmologies scanned (rows of cosmo / input_result)
N_SAMPLES=10        # number of grid points inspected
SAMPLE_STRIDE=28    # spacing, in grid points, between inspected points
sample_idx=np.arange(N_SAMPLES)*SAMPLE_STRIDE
# NOTE: JAX clamps out-of-range indices instead of raising, so ``a`` needs at
# least sample_idx[-1]+1 points and the tables at least N_COSMO rows for the
# numbers below to mean what they claim.

# The prediction for a cosmology does not depend on which grid point is being
# inspected, so evaluate the model once per cosmology and keep every sampled
# error, instead of re-evaluating for each inspected point.
omega_m=[]
frac_err=np.empty((N_COSMO,N_SAMPLES))
for j in range(N_COSMO):
    omega_input=np.array([[cosmo[j,0]]])
    d_pred=eval_func(restored_state.params,omega_input,a).reshape(-1)
    frac_err[j,:]=np.asarray(d_pred[sample_idx]/input_result[j,sample_idx]-1)
    omega_m.append(cosmo[j,0].item())

a_plot=[]
med=[]
mean_error=[]
std=[]
for i,idx in enumerate(sample_idx.tolist()):
    errors=frac_err[:,i].tolist()
    a_value=a[idx].item()
    # One scatter call per figure; each figure is closed once it has been saved.
    fig,ax=plt.subplots(constrained_layout=True)
    ax.scatter(omega_m,errors,c='b')
    ax.set_xlabel("Omega_m")
    ax.set_ylabel("Fractional Error")
    ax.set_title("Fractional Error of Cosmos (a="+str(np.round(a_value,3))+")")
    fig.savefig("/content/drive/My Drive/Colab Notebooks/error"+str(i)+".png")
    plt.close(fig)
    a_plot.append(a_value)
    med.append(statistics.median(errors))
    mean_error.append(statistics.mean(errors))
    std.append(statistics.stdev(errors))

# Median, and mean with a one-standard-deviation bar, versus the inspected point.
fig,ax=plt.subplots(constrained_layout=True)
ax.plot(a_plot,med,label="Median")
ax.errorbar(a_plot,mean_error,std,label="Mean")
ax.set_xscale('log')
ax.set_xlabel("a")
ax.set_ylabel("Fractional Error")
ax.set_title("Fractional Error")
ax.legend()
fig.savefig("/content/drive/My Drive/Colab Notebooks/centralerror.png")
# Flush pending writes to Drive and unmount it.
drive.flush_and_unmount()