<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.5.3-py3-none-any.whl (202 kB)
[K     |████████████████████████████████| 202 kB 23.4 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.1 MB)
[K     |████████████████████████████████| 9.1 MB 57.4 MB/s 
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 44.9 MB/s 
Collecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 55.5 MB/s 
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 61.2 MB/s 
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |█

In [4]:
import os
from typing import Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax import jit
from flax import linen as nn
from flax.training import train_state,checkpoints

In [5]:
def npy_loader(path):
    """Read the ``.npy`` file at ``path`` via ``jnp.load`` and return the loaded array."""
    loaded=jnp.load(path)
    return loaded

In [6]:
class SimpleMLP(nn.Module):
    """Stack of Dense layers, one per entry of ``features``.

    Every Dense layer except the last is followed by a ReLU, so the output
    of the final layer is returned without an activation.

    Attributes:
        features: output width of each Dense layer, in order.
    """
    features:Sequence[int]

    @nn.compact
    def __call__(self,inputs):
        final_layer=len(self.features)-1
        h=inputs
        for layer_idx,width in enumerate(self.features):
            h=nn.Dense(width)(h)
            # Hidden layers get a ReLU; the output layer stays linear.
            if layer_idx<final_layer:
                h=nn.relu(h)
        return h

In [7]:
# Build the model and a TrainState. `state` is passed below as the `target`
# template for restore_checkpoint, so its pytree structure (layer sizes,
# optimizer) should match the saved checkpoint.
layer_sizes=[64,256,256,256]
learning_rate=1e-6
N_INPUTS=2  # the network is fed 2-element inputs (cosmo columns 0 and 2 in gradient_at)
model=SimpleMLP(features=layer_sizes)
temp=jnp.ones(N_INPUTS)  # dummy input; only its shape matters for model.init
params=model.init(jax.random.PRNGKey(0),temp)
tx=optax.adam(learning_rate=learning_rate,b1=0.99)
state=train_state.TrainState.create(apply_fn=model.apply,params=params,tx=tx)
# TrainState.create already runs tx.init(params); reuse that result instead of
# initialising the optimizer state a second time.
opt_state=state.opt_state



In [13]:
# Colab-only: mount Google Drive so the checkpoint and cosmo.npy under
# "/content/drive/My Drive/Colab Notebooks" can be read in the next cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
DRIVE_DIR="/content/drive/My Drive/Colab Notebooks"
CKPT_PATH=os.path.join(DRIVE_DIR,"checkpoint_0")
COSMO_PATH=os.path.join(DRIVE_DIR,"cosmo.npy")

# restore_checkpoint returns `target` unchanged (i.e. the freshly initialised,
# untrained `state`) when no checkpoint is found, instead of raising. Fail loudly
# so gradients are never silently computed on a random network.
if not os.path.exists(CKPT_PATH):
    raise FileNotFoundError(f"No checkpoint found at {CKPT_PATH!r}; is Drive mounted?")
restored_state=checkpoints.restore_checkpoint(ckpt_dir=CKPT_PATH,target=state)
cosmo=npy_loader(COSMO_PATH)

In [15]:
@jit
def predict(params,x,j):
    """Apply the restored model to ``x`` with ``params`` and return ``exp`` of output ``j``.

    Only ``apply_fn`` is taken from the global ``restored_state``; the
    parameters come in explicitly through ``params``. The network output is
    exponentiated, so the model presumably predicts a log-quantity (TODO confirm).
    """
    raw_output=restored_state.apply_fn(params,x)
    return jnp.exp(raw_output)[j]

In [16]:
@jit
def gradient_at(i,j):
    """Gradient of ``predict(restored_state.params, x, j)`` with respect to ``x``.

    ``x`` is the 2-element input built from columns 0 and 2 of row ``i`` of
    the global ``cosmo`` array.

    Args:
        i: row index into ``cosmo``. Under jit, JAX clamps an out-of-range
            index instead of raising, so callers must keep ``i`` in bounds.
        j: index of the network output component to differentiate.

    Returns:
        Array with the same shape as ``x`` (2,): d predict / d x at that row.
    """
    # `cosmo` is a global closed over by jit, so it is baked into the compiled
    # computation as a constant; only i and j are traced.
    x=jnp.array([cosmo[i,0],cosmo[i,2]])
    # Only the gradient is needed, so use jax.grad rather than discarding the
    # value returned by value_and_grad.
    return jax.grad(predict,argnums=1)(restored_state.params,x,j)

In [17]:
# Print d(predict)/d(input) for output component OUTPUT_INDEX at the first
# N_SAMPLES rows of cosmo, then flush and unmount Drive.
N_SAMPLES=256
OUTPUT_INDEX=0
try:
    # Inside jit, JAX clamps out-of-range indices instead of raising, so a cosmo
    # with fewer rows would silently repeat its last row. Check up front.
    if cosmo.shape[0]<N_SAMPLES:
        raise ValueError(f"cosmo has {cosmo.shape[0]} rows; need at least {N_SAMPLES}")
    for i in range(N_SAMPLES):
        gradient=gradient_at(i,OUTPUT_INDEX)
        print(gradient)
finally:
    # Always flush and unmount Drive, even if a gradient evaluation fails.
    drive.flush_and_unmount()

[-0.00969556 -0.00097546]
[-0.00724644  0.00164647]
[-0.02539393 -0.00082254]
[-0.00698015 -0.00039439]
[-0.03381014 -0.00332933]
[-0.01011131 -0.00098941]
[-0.01145566 -0.00167335]
[-0.04179464 -0.00366759]
[-0.01994135 -0.00095894]
[-0.00992877 -0.00069112]
[-0.02034703 -0.00079838]
[-0.01684329 -0.00136619]
[-0.01743107 -0.00141387]
[-0.01579178 -0.00320797]
[-0.06052754 -0.00638826]
[-0.01023008 -0.00100104]
[-0.0430318  -0.00483406]
[-0.0449253  -0.00167431]
[-0.01692978 -0.0040763 ]
[-0.02116844 -0.00359866]
[-0.04763147 -0.0029486 ]
[-0.01862137 -0.00019548]
[-0.01090807 -0.00101749]
[-0.0083493  -0.00086184]
[-0.00697048 -0.00141456]
[-0.00725721 -0.00121846]
[-0.00744338 -0.00159115]
[-0.00882618 -0.0004947 ]
[-0.02131055 -0.00115828]
[-0.00509593 -0.00037789]
[-0.010048   -0.00069942]
[-0.00850423 -0.00087783]
[-0.00556357 -0.00095524]
[-0.01100722 -0.00064139]
[-0.01104703 -0.0003865 ]
[-0.0061329  -0.00091603]
[-0.03634317 -0.00125983]
[-0.00895896 -0.0033781 ]
[-0.01701989