<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.5.3-py3-none-any.whl (202 kB)
[K     |████████████████████████████████| 202 kB 7.9 MB/s 
[?25hCollecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 67.6 MB/s 
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 57.9 MB/s 
[?25hCollecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 8.7 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.1 MB)
[K     |████████████████████████████████| 9.1 MB 30.7 MB/s 
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)
Colle

In [2]:
from typing import Sequence
import jax
import optax
import numpy as np
import jax.numpy as jnp
from jax import jit,grad,jacfwd,jacrev
from flax import linen as nn
from flax.training import train_state,checkpoints

In [3]:
def npy_loader(path):
    """Read the ``.npy`` file at ``path`` via ``jnp.load`` and return the result."""
    loaded=jnp.load(path)
    return loaded

In [4]:
class SimpleMLP(nn.Module):
    """Stack of ``nn.Dense`` layers, one per entry of ``features``.

    A ReLU follows every layer except the final one, whose output is
    returned as-is.
    """
    # Output width of each Dense layer, in the order the layers are applied.
    features:Sequence[int]

    @nn.compact
    def __call__(self,inputs):
        """Apply the layers to ``inputs`` and return the last layer's output."""
        final_index=len(self.features)-1
        hidden=inputs
        for depth,width in enumerate(self.features):
            hidden=nn.Dense(width)(hidden)
            # No activation on the output layer.
            if depth<final_index:
                hidden=nn.relu(hidden)
        return hidden

In [5]:
# Build the model and a TrainState with the same structure as the saved
# checkpoint. In the visible cells `state` is only used as the restore target;
# no training step is run here.
layer_sizes=[64,256,256,256]  # width of each Dense layer, in order; last entry is the output size
learning_rate=1e-6
model=SimpleMLP(features=layer_sizes)
temp=jnp.ones(2)  # dummy 2-element input, used only to initialise the parameters
params=model.init(jax.random.PRNGKey(0),temp)
tx=optax.adam(learning_rate=learning_rate,b1=0.99)
opt_state=tx.init(params)  # NOTE(review): not referenced again in the visible cells
state=train_state.TrainState.create(apply_fn=model.apply,params=params,tx=tx)

In [27]:
# Mount Google Drive (Colab-only) so the checkpoint and the cosmo.npy data
# under /content/drive/My Drive can be read by the following cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [28]:
# Load the trained weights into the TrainState skeleton (`state`) created earlier.
restored_state=checkpoints.restore_checkpoint(ckpt_dir="/content/drive/My Drive/Colab Notebooks/checkpoint_0",target=state)
# restore_checkpoint returns `target` itself when it finds nothing to load.
# Fail loudly in that case: otherwise every gradient computed later would
# silently come from the randomly initialised parameters.
if restored_state is state:
    raise FileNotFoundError("No checkpoint found at /content/drive/My Drive/Colab Notebooks/checkpoint_0")
# Input table; columns 0 and 2 of each row are later fed to the network.
cosmo=npy_loader("/content/drive/My Drive/Colab Notebooks/cosmo.npy")

In [29]:
@jit
def predict(params,x,j):
    """Run the restored model's ``apply_fn`` on ``x`` with ``params`` and return element ``j`` of the output."""
    return restored_state.apply_fn(params,x)[j]

In [30]:
@jit
def gradient_at(i,j):
    """Gradient of network output ``j`` with respect to the inputs taken from row ``i`` of ``cosmo``.

    The input vector is ``[cosmo[i,0], cosmo[i,2]]``. Differentiation is with
    respect to that vector (argument 1 of ``predict``), evaluated at the
    restored parameters.
    """
    inputs=jnp.array([cosmo[i,0],cosmo[i,2]])
    value_and_gradient=jax.value_and_grad(predict,1)
    # Only the gradient is returned; the function value is discarded.
    _,input_gradient=value_and_gradient(restored_state.params,inputs,j)
    return input_gradient

In [31]:
# Print the gradient of output 0 with respect to the two inputs for the first
# 256 rows of `cosmo`.
# JAX clamps out-of-range indices instead of raising, so the loop is capped at
# the real row count; otherwise a shorter table would silently repeat its
# last row.
for i in range(min(256,cosmo.shape[0])):
    gradient=gradient_at(i,0)
    print(gradient)
# All Drive-hosted inputs have been read by now; unmount Drive.
drive.flush_and_unmount()

[-0.7614598  -0.07660961]
[-0.6179223   0.14039922]
[-1.6465211  -0.05333281]
[-0.6012536  -0.03397155]
[-2.12252   -0.2090075]
[-0.78193784 -0.07651401]
[-0.86739767 -0.12670243]
[-2.4973788 -0.2191515]
[-1.3637196 -0.0655781]
[-0.77127093 -0.05368638]
[-1.4150987 -0.0555259]
[-1.2053214  -0.09776592]
[-1.2053214  -0.09776592]
[-1.0869842  -0.22081149]
[-3.2970943  -0.34798503]
[-0.78193784 -0.07651401]
[-2.542775  -0.2856474]
[-2.6430893  -0.09850419]
[-1.1670246 -0.2809925]
[-1.5043253  -0.25573742]
[-2.7343366  -0.16926742]
[-1.3036956  -0.01368558]
[-0.8391052  -0.07827044]
[-0.6704276  -0.06920338]
[-0.5664871  -0.11496091]
[-0.59609514 -0.10008264]
[-0.6001935  -0.12830174]
[-0.6951928  -0.03896451]
[-1.465295   -0.07964277]
[-0.43644843 -0.03236485]
[-0.77127093 -0.05368638]
[-0.6704276  -0.06920338]
[-0.4722743  -0.08108783]
[-0.84219587 -0.04907477]
[-0.8352068  -0.02922058]
[-0.5113052  -0.07637024]
[-2.2646012  -0.07850134]
[-0.69803756 -0.2632047 ]
[-1.2053214  -0.09776592