<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.5.3-py3-none-any.whl (202 kB)
[K     |████████████████████████████████| 202 kB 23.4 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.1 MB)
[K     |████████████████████████████████| 9.1 MB 57.4 MB/s 
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 44.9 MB/s 
Collecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 55.5 MB/s 
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 61.2 MB/s 
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |█

In [4]:
from typing import Sequence
import jax
import optax
import numpy as np
import jax.numpy as jnp
from jax import jit
from flax import linen as nn
from flax.training import train_state,checkpoints

In [5]:
def npy_loader(path):
    """Read the NumPy ``.npy`` file at ``path`` and return its contents.

    Thin wrapper around ``jnp.load``, so the result is a JAX array rather than
    a plain NumPy array.
    """
    loaded_array = jnp.load(path)
    return loaded_array

In [6]:
class SimpleMLP(nn.Module):
    """Fully connected network with one ``nn.Dense`` layer per entry of ``features``.

    A ReLU follows every layer except the last one, so the output of the final
    layer is returned without an activation.
    """
    features:Sequence[int]  # output width of each Dense layer, in order

    @nn.compact
    def __call__(self,inputs):
        last_layer = len(self.features) - 1
        hidden = inputs
        for layer_idx, width in enumerate(self.features):
            hidden = nn.Dense(width)(hidden)
            # Skip the activation on the output layer.
            if layer_idx < last_layer:
                hidden = nn.relu(hidden)
        return hidden

In [7]:
# Template TrainState: the restore_checkpoint call further down uses `state`
# as its target, so presumably this architecture/optimizer must match the one
# that produced checkpoint_0 — TODO confirm against the training notebook.
layer_sizes = [64, 256, 256, 256]
learning_rate = 1e-6
model = SimpleMLP(features=layer_sizes)

# Dummy 2-feature input (the later cells feed 2-element vectors); used only to
# trace shapes when initialising the parameters.
temp = jnp.ones((2,))
params = model.init(jax.random.PRNGKey(0), temp)

tx = optax.adam(learning_rate=learning_rate, b1=0.99)
# NOTE(review): `opt_state` is not referenced by any later cell shown here;
# per the flax docs, TrainState.create initialises its own optimizer state from `tx`.
opt_state = tx.init(params)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)



In [8]:
# Mount Google Drive; the checkpoint and cosmo.npy are read from under /content/drive.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
CKPT_DIR="/content/drive/My Drive/Colab Notebooks/checkpoint_0"
COSMO_PATH="/content/drive/My Drive/Colab Notebooks/cosmo.npy"

# `state` only supplies the pytree structure; its freshly initialised values
# are replaced by the checkpoint contents.
restored_state=checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR,target=state)
# restore_checkpoint does not raise when no checkpoint is found: it logs and
# returns the passed-in `target` unchanged. Without this check every gradient
# below would silently come from the untrained, randomly initialised network.
if restored_state is state:
    raise FileNotFoundError(f"No checkpoint could be restored from {CKPT_DIR!r}")

cosmo=npy_loader(COSMO_PATH)

In [10]:
@jit
def predict(params,x,j):
    """Run the restored network on input ``x`` and return only its ``j``-th output.

    ``restored_state.apply_fn`` is looked up in the notebook's global scope.
    ``params`` is passed explicitly so it can be held fixed while differentiating
    with respect to ``x``. For a single 1-D input, selecting one element gives
    the scalar output that ``jax.grad``/``jax.value_and_grad`` require.
    """
    network_output = restored_state.apply_fn(params, x)
    selected_output = network_output[j]
    return selected_output

In [11]:
@jit
def gradient_at(i,j):
    """Gradient of network output ``j`` with respect to the input built from row ``i`` of ``cosmo``.

    The network input is the 2-vector ``[cosmo[i, 0], cosmo[i, 2]]``. The
    parameters are the restored checkpoint's ``restored_state.params`` and are
    held fixed.

    Args:
        i: row index into the global ``cosmo`` array.
        j: index of the network output to differentiate.

    Returns:
        Array with the same shape as the input vector (2 elements): the
        derivative of output ``j`` with respect to each of the two input features.
    """
    x=jnp.array([cosmo[i,0],cosmo[i,2]])
    # Only the gradient is used, so jax.grad replaces value_and_grad (whose value
    # was discarded). argnums=1 differentiates w.r.t. `x`, not the parameters.
    return jax.grad(predict,argnums=1)(restored_state.params,x,j)

In [12]:
# Print d(output 0)/d(input) for each row of `cosmo`, up to 256 rows.
# Inside jit, JAX clamps out-of-bounds indices instead of raising, so looping
# past the end of `cosmo` would silently re-print the last row's gradient.
# Cap the loop at the number of rows actually loaded.
n_rows=min(256,cosmo.shape[0])
for i in range(n_rows):
    gradient=gradient_at(i,0)
    print(gradient)
# All Drive reads are done; flush and unmount.
drive.flush_and_unmount()

[-0.7614598  -0.07660961]
[-0.6179226   0.14039958]
[-1.6465209  -0.05333257]
[-0.6012536  -0.03397131]
[-2.122519   -0.20900702]
[-0.7819383  -0.07651401]
[-0.8673977  -0.12670219]
[-2.4973788  -0.21915102]
[-1.3637195 -0.0655787]
[-0.7712712  -0.05368638]
[-1.4150984  -0.05552626]
[-1.2053211  -0.09776592]
[-1.2053211  -0.09776592]
[-1.0869842  -0.22081172]
[-3.297094   -0.34798503]
[-0.7819383  -0.07651401]
[-2.5427752  -0.28564692]
[-2.643089   -0.09850395]
[-1.1670245  -0.28099203]
[-1.504325   -0.25573707]
[-2.7343364  -0.16926718]
[-1.3036954 -0.0136857]
[-0.8391055 -0.0782702]
[-0.6704273  -0.06920314]
[-0.5664872  -0.11496043]
[-0.5960954 -0.1000824]
[-0.60019374 -0.1283021 ]
[-0.6951926  -0.03896427]
[-1.4652946  -0.07964277]
[-0.43644795 -0.03236485]
[-0.7712712  -0.05368638]
[-0.6704273  -0.06920314]
[-0.47227395 -0.08108759]
[-0.842196   -0.04907477]
[-0.8352069  -0.02922058]
[-0.5113053  -0.07636976]
[-2.2646008  -0.07850146]
[-0.698038   -0.26320505]
[-1.2053211  -0.0977