<a href="https://colab.research.google.com/github/epodkwan/growthfunction/blob/main/jaxtrainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from typing import Sequence
import jax
import optax
import numpy as np
import jax.numpy as jnp
from jax import jit,random
from flax import linen as nn
import matplotlib.pyplot as plt

In [3]:
def npy_loader(path):
    return jnp.load(path)

In [4]:
input_data=npy_loader("/content/drive/My Drive/Colab Notebooks/cosmo.npy")
input_result=npy_loader("/content/drive/My Drive/Colab Notebooks/combined.npy")
x_train=jnp.stack((input_data[0:800,0],input_data[0:800,2]),axis=1)
y_train=input_result[0:800,:]
x_validate=jnp.stack((input_data[800:900,0],input_data[800:900,2]),axis=1)
y_validate=input_result[800:900,:]



In [5]:
class SimpleMLP(nn.Module):
    features:Sequence[int]

    @nn.compact
    def __call__(self,inputs):
        x=inputs
        for i,feat in enumerate(self.features):
            x=nn.Dense(feat)(x)
            if i != len(self.features)-1:
                x=nn.relu(x)
        return x

In [6]:
layer_sizes=[64,256,256,256]
learning_rate=1e-4
epochs=30000
model=SimpleMLP(features=layer_sizes)
temp=jnp.ones(2)
params=model.init(random.PRNGKey(0),temp)
tx=optax.adam(learning_rate=learning_rate,b1=0.9,b2=0.999)
opt_state=tx.init(params)

In [7]:
@jit
def mse_loss(params,x,y_ref):
    preds=model.apply(params,x)
    diff=preds-y_ref
    return jnp.mean(diff*diff)

In [8]:
@jit
def train_step(params,opt_state,x,y_ref):
    loss,grads=jax.value_and_grad(mse_loss,argnums=0)(params,x,y_ref)
    updates,opt_state=tx.update(grads,opt_state)
    params=optax.apply_updates(params,updates)
    return loss,params,opt_state

In [None]:
order=jnp.arange(800)
for i in range(epochs):
    order=random.permutation(random.PRNGKey(i),order)
    train_loss=0
    for j in range(25):
        x_batch=x_train[order[32*j:32*(j+1)-1],:]
        y_batch=y_train[order[32*j:32*(j+1)-1],:]
        loss,params,opt_state=train_step(params,opt_state,x_batch,y_batch)
        train_loss=train_loss+loss
    if i % 100 == 99:
        train_loss=train_loss/25
        validate_loss=mse_loss(params,x_validate,y_validate)
        print((i+1),validate_loss)
        plt.scatter((i+1),jnp.log(train_loss),c='b')
        plt.scatter((i+1),jnp.log(validate_loss),c='g')
print("Training ended")
jnp.save("model.npy",params)
plt.xlabel("Epoch")
plt.ylabel("ln(loss)")
plt.title("Loss function")
plt.legend(["Training Loss","Validation Loss"])
plt.savefig("/content/drive/My Drive/Colab Notebooks/loss.png")
drive.flush_and_unmount()

100 1.2974906e-05
200 2.9636426e-06
300 1.9140075e-06
400 1.0295079e-06
500 5.6282306e-07
600 3.3947725e-07
700 5.6003995e-07
800 2.312849e-07
900 5.284297e-07
1000 1.863303e-07
1100 2.7114066e-07
1200 2.7625995e-07
1300 1.1119169e-07
1400 1.5026271e-07
1500 2.1081915e-07
1600 1.0484211e-07
1700 4.878363e-07
1800 1.3387722e-07
1900 2.0873645e-07
2000 1.13631664e-07
2100 1.3957991e-07
2200 1.645417e-07
2300 2.0037328e-07
2400 6.504763e-08
2500 1.0011713e-07
2600 7.288363e-08
2700 7.549696e-08
2800 4.5012493e-07
2900 1.8610187e-07
3000 8.010554e-08
3100 5.98634e-08
3200 1.095756e-07
3300 1.3752718e-07
3400 8.8834184e-08
3500 1.7069385e-07
3600 5.575019e-08
3700 6.6704246e-08
3800 9.184428e-08
3900 8.1158575e-08
4000 5.267464e-08
4100 1.00093175e-07
4200 9.650679e-08
4300 1.8250337e-07
4400 4.234588e-08
4500 5.213507e-08
4600 4.1672454e-08
4700 3.925358e-08
4800 4.660768e-08
4900 2.7219335e-07
5000 3.3578695e-07
5100 4.5168584e-08
5200 2.3379224e-07
5300 1.07402855e-07
5400 7.276722e-08
5

In [None]:
x_test=jnp.stack((input_data[900:1000,0],input_data[900:1000,2]),axis=1)
y_test=input_result[900:1000,:]
y_pred=model.apply(params,x_test)
print(y_pred)
error=abs(y_pred/y_test-1)
print("Max error =",jnp.max(error)*100,"%")