In [38]:
import functools

from clu import metric_writers
import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp
import torch.utils.data as data
from tqdm import tqdm

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter

In [39]:
jax.devices()

[cuda(id=0), cuda(id=1)]

In [40]:
# Parameters for the computational task.

L = 4 # number of levels (even number)
s = 5 # leaf size
r = 3 # rank

# Discretization of Omega (n_eta * n_eta).
neta = (2**L)*s

# Number of sources/detectors (n_sc).
# Discretization of the domain of alpha in polar coordinates (n_theta * n_rho).
# For simplicity, these values are set equal (n_sc = n_theta = n_rho), facilitating computation.
nx = (2**L)*s

# Standard deviation for the Gaussian blur.
blur_sigma = 0.5

# Batch size.
batch_size = 16

# Number of training datapoints.
NTRAIN = 2000

# Number of testing datapoints.
NTEST = 320

In [4]:
def cart_polar(coords):
    """
    Transforms coordinates from Cartesian to polar coordinates with custom scaling.

    Parameters:
    - coords: A tuple or list containing the (i, j) coordinates to be transformed.

    Returns:
    - A tuple (rho, theta) representing the transformed coordinates.
    """
    i, j = coords[0], coords[1]
    # Calculate the radial distance with a scaling factor.
    rho = 2 * np.sqrt((i - neta / 2) ** 2 + (j - neta / 2) ** 2) * nx / neta
    # Calculate the angle in radians and adjust the scale to fit the specified range.
    theta = ((np.arctan2((neta / 2 - j), (i - neta / 2))) % (2 * np.pi)) * nx / np.pi / 2
    return theta, rho + neta // 2

In [5]:
# Define a function to precompute the transformation matrix
# Precompute the transformation matrix from polar coordinates to Cartesian coordiantes 
cart_mat = np.zeros((neta**2, nx, nx))

for i in range(nx):
    for j in range(nx):
        # Create a dummy matrix with a single one at position (i, j) and zeros elsewhere.
        mat_dummy = np.zeros((nx, nx))
        mat_dummy[i, j] = 1
        # Pad the dummy matrix in polar coordinates to cover the target space in Cartesian coordinates.
        pad_dummy = np.pad(mat_dummy, ((0, 0), (neta // 2, neta // 2)), 'edge')
        # Apply the geometric transformation to map the dummy matrix to polar coordinates
        cart_mat[:, i, j] = geometric_transform(pad_dummy, cart_polar, output_shape=[neta, neta], mode='grid-wrap').flatten()

cart_mat = np.reshape(cart_mat, (neta**2, nx**2))
# Removing small values
cart_mat = np.where(np.abs(cart_mat) > 0.001, cart_mat, 0)
# Convert to sparse matrix in tensorflow
#cart_mat = tf.sparse.from_dense(tf.cast(cart_mat, dtype='float32'))

In [6]:
from jax.experimental import sparse
cart_mat = sparse.BCOO.fromdense(cart_mat)

In [41]:
tf.config.set_visible_devices([], device_type='GPU')

name = 'shepp_logan'

# Loading and preprocessing perturbation data (eta)
with h5py.File(f'{name}/eta.h5', 'r') as f:
    # Read eta data, apply Gaussian blur, and reshape
    eta_re = f[list(f.keys())[0]][:NTRAIN, :].reshape(-1, neta, neta)
    blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
    eta_re = np.stack([blur_fn(eta_re[i, :, :].T) for i in range(NTRAIN)]).astype('float32')

# Loading and preprocessing scatter data (Lambda)
with h5py.File(f'{name}/scatter.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())

    # Process real part of scatter data
    tmp1 = f[keys[3]][:NTRAIN, :]
    tmp2 = f[keys[4]][:NTRAIN, :]
    tmp3 = f[keys[5]][:NTRAIN, :]
    scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

    # Process imaginary part of scatter data
    tmp1 = f[keys[0]][:NTRAIN, :]
    tmp2 = f[keys[1]][:NTRAIN, :]
    tmp3 = f[keys[2]][:NTRAIN, :]
    scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
    # Combine real and imaginary parts
    scatter = np.stack((scatter_re, scatter_im), axis=-2).astype('float32')
    
# Clean up temporary variables to free memory
del scatter_re, scatter_im, tmp1, tmp2, tmp3

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

dataset = [(scatter[i,:,:,:], eta_re[i,:,:]) for i in range(NTRAIN)]
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=numpy_collate)

In [48]:
class V(nn.Module):
    r: int

    @nn.compact
    def __call__(self, x):
        n, s = x.shape[2], x.shape[3]

        init_fn = nn.initializers.glorot_uniform()
        vr1 = self.param('vr1', init_fn, (n, s, self.r))
        vi1 = self.param('vi1', init_fn, (n, s, self.r))
        vr2 = self.param('vr2', init_fn, (n, s, self.r))
        vi2 = self.param('vi2', init_fn, (n, s, self.r))
        vr3 = self.param('vr3', init_fn, (n, s, self.r))
        vi3 = self.param('vi3', init_fn, (n, s, self.r))
        vr4 = self.param('vr4', init_fn, (n, s, self.r))
        vi4 = self.param('vi4', init_fn, (n, s, self.r))

        x_re, x_im = x[..., 0], x[..., 1]

        y_re_1 = jnp.einsum('...iaj,ajk->...iak', x_re, vr1)
        y_re_1 = jnp.einsum('abj...i,bjk->abk...i', y_re_1, vr1)
        y_re_2 = jnp.einsum('...iaj,ajk->...iak', x_re, vi1)
        y_re_2 = jnp.einsum('abj...i,bjk->abk...i', y_re_2, vi1)
        y_re_3 = jnp.einsum('...iaj,ajk->...iak', x_im, vi2)
        y_re_3 = jnp.einsum('abj...i,bjk->abk...i', y_re_3, vr2)
        y_re_4 = jnp.einsum('...iaj,ajk->...iak', x_im, vr2)
        y_re_4 = jnp.einsum('abj...i,bjk->abk...i', y_re_4, vi2)
        y_re = y_re_1+y_re_2+y_re_3+y_re_4
        
        y_im_1 = jnp.einsum('...iaj,ajk->...iak', x_im, vr3)
        y_im_1 = jnp.einsum('abj...i,bjk->abk...i', y_im_1, vr3)
        y_im_2 = jnp.einsum('...iaj,ajk->...iak', x_im, vi3)
        y_im_2 = jnp.einsum('abj...i,bjk->abk...i', y_im_2, vi3)
        y_im_3 = jnp.einsum('...iaj,ajk->...iak', x_re, vi4)
        y_im_3 = jnp.einsum('abj...i,bjk->abk...i', y_im_3, vr4)
        y_im_4 = jnp.einsum('...iaj,ajk->...iak', x_re, vr4)
        y_im_4 = jnp.einsum('abj...i,bjk->abk...i', y_im_4, vi4)
        y_im = y_im_1+y_im_2+y_im_3+y_im_4
        
        y = jnp.stack([y_re, y_im], axis=-1)
        
        return y

In [49]:
# Precomputing indices used for grouping neighboring blocks prior to applying Layer Hs.
def build_permutation_indices(L, l):
    delta = 2**(L-l-1)
    tmp = np.tile(np.arange(2)*delta, delta)
    tmp += np.repeat(np.arange(delta), 2)
    tmp = np.tile(tmp, 2**l)
    tmp += np.repeat(np.arange(2**l)*(2**(L-l)), 2**(L-l))
    return tmp

In [50]:
class H(nn.Module):
    L: int
    l: int

    def setup(self):
        # Compute permutation indices
        self.perm_idx = build_permutation_indices(self.L, self.l)

    @nn.compact
    def __call__(self, x):
        # Placeholder for actual input shape dependent variables
        m = x.shape[2] // 2
        s = x.shape[3] * 2

        # Define weights
        init_fn = nn.initializers.glorot_uniform()
        hr1 = self.param('hr1', init_fn, (m, s, s))
        hi1 = self.param('hi1', init_fn, (m, s, s))
        hr2 = self.param('hr2', init_fn, (m, s, s))
        hi2 = self.param('hi2', init_fn, (m, s, s))
        hr3 = self.param('hr3', init_fn, (m, s, s))
        hi3 = self.param('hi3', init_fn, (m, s, s))
        hr4 = self.param('hr4', init_fn, (m, s, s))
        hi4 = self.param('hi4', init_fn, (m, s, s))

        # Apply permutations
        x = x.take(self.perm_idx, axis=1).take(self.perm_idx, axis=3)
        
        # Reshape operation
        x = x.reshape((-1, m, s, m, s, 2))
        # Split real and imaginary parts for processing
        x_re, x_im = x[..., 0], x[..., 1]
        
        y_re_1 = jnp.einsum('...iaj,ajk->...iak', x_re, hr1)
        y_re_1 = jnp.einsum('abj...i,bjk->abk...i', y_re_1, hr1)
        y_re_2 = jnp.einsum('...iaj,ajk->...iak', x_re, hi1)
        y_re_2 = jnp.einsum('abj...i,bjk->abk...i', y_re_2, hi1)
        y_re_3 = jnp.einsum('...iaj,ajk->...iak', x_im, hi2)
        y_re_3 = jnp.einsum('abj...i,bjk->abk...i', y_re_3, hr2)
        y_re_4 = jnp.einsum('...iaj,ajk->...iak', x_im, hr2)
        y_re_4 = jnp.einsum('abj...i,bjk->abk...i', y_re_4, hi2)
        y_re = y_re_1+y_re_2+y_re_3+y_re_4
        
        y_im_1 = jnp.einsum('...iaj,ajk->...iak', x_im, hr3)
        y_im_1 = jnp.einsum('abj...i,bjk->abk...i', y_im_1, hr3)
        y_im_2 = jnp.einsum('...iaj,ajk->...iak', x_im, hi3)
        y_im_2 = jnp.einsum('abj...i,bjk->abk...i', y_im_2, hi3)
        y_im_3 = jnp.einsum('...iaj,ajk->...iak', x_re, hi4)
        y_im_3 = jnp.einsum('abj...i,bjk->abk...i', y_im_3, hr4)
        y_im_4 = jnp.einsum('...iaj,ajk->...iak', x_re, hr4)
        y_im_4 = jnp.einsum('abj...i,bjk->abk...i', y_im_4, hi4)
        y_im = y_im_1+y_im_2+y_im_3+y_im_4
        
        y = jnp.stack([y_re, y_im], axis=-1)

        n = m * 2
        r = s // 2
        y = y.reshape((-1, n, r, n, r, 2))

        return y

In [51]:
# Precomputing indices used for redistributing blocks according to the transformation represented by x -> M*xM.
def build_switch_indices(L):
    L = L // 2
    tmp = np.arange(2**L)*(2**L)
    tmp = np.tile(tmp, 2**L)
    tmp += np.repeat(np.arange(2**L), 2**L)
    return tmp

In [52]:
class M(nn.Module):
    @nn.compact
    def __call__(self, x):
        n, r = x.shape[2], x.shape[3]

        # Initialize weights
        init_fn = nn.initializers.glorot_uniform()
        mr1 = self.param('mr1', init_fn, (n, r, r))
        mi1 = self.param('mi1', init_fn, (n, r, r))
        mr2 = self.param('mr2', init_fn, (n, r, r))
        mi2 = self.param('mi2', init_fn, (n, r, r))
        mr3 = self.param('mr3', init_fn, (n, r, r))
        mi3 = self.param('mi3', init_fn, (n, r, r))
        mr4 = self.param('mr4', init_fn, (n, r, r))
        mi4 = self.param('mi4', init_fn, (n, r, r))

        x_re, x_im = x[..., 0], x[..., 1]

        y_re_1 = jnp.einsum('...iaj,ajk->...iak', x_re, mr1)
        y_re_1 = jnp.einsum('abj...i,bjk->abk...i', y_re_1, mr1)
        y_re_2 = jnp.einsum('...iaj,ajk->...iak', x_re, mi1)
        y_re_2 = jnp.einsum('abj...i,bjk->abk...i', y_re_2, mi1)
        y_re_3 = jnp.einsum('...iaj,ajk->...iak', x_im, mi2)
        y_re_3 = jnp.einsum('abj...i,bjk->abk...i', y_re_3, mr2)
        y_re_4 = jnp.einsum('...iaj,ajk->...iak', x_im, mr2)
        y_re_4 = jnp.einsum('abj...i,bjk->abk...i', y_re_4, mi2)
        y_re = y_re_1+y_re_2+y_re_3+y_re_4
        
        y_im_1 = jnp.einsum('...iaj,ajk->...iak', x_im, mr3)
        y_im_1 = jnp.einsum('abj...i,bjk->abk...i', y_im_1, mr3)
        y_im_2 = jnp.einsum('...iaj,ajk->...iak', x_im, mi3)
        y_im_2 = jnp.einsum('abj...i,bjk->abk...i', y_im_2, mi3)
        y_im_3 = jnp.einsum('...iaj,ajk->...iak', x_re, mi4)
        y_im_3 = jnp.einsum('abj...i,bjk->abk...i', y_im_3, mr4)
        y_im_4 = jnp.einsum('...iaj,ajk->...iak', x_re, mr4)
        y_im_4 = jnp.einsum('abj...i,bjk->abk...i', y_im_4, mi4)
        y_im = y_im_1+y_im_2+y_im_3+y_im_4
        
        y = jnp.stack([y_re, y_im], axis=-1)

        return y

In [53]:
class G(nn.Module):
    L: int
    l: int

    def setup(self):
        # Setup is called once to create parameters, we'll store perm_idx here but its creation is static
        self.perm_idx = build_permutation_indices(self.L, self.l)

    @nn.compact
    def __call__(self, x):
        # Dimensions need to be dynamically inferred from 'x'
        m = x.shape[2] // 2
        s = x.shape[3] * 2

        # Initialize weights
        init_fn = nn.initializers.glorot_uniform()
        gr1 = self.param('gr1', init_fn, (m, s, s))
        gi1 = self.param('gi1', init_fn, (m, s, s))
        gr2 = self.param('gr2', init_fn, (m, s, s))
        gi2 = self.param('gi2', init_fn, (m, s, s))
        gr3 = self.param('gr3', init_fn, (m, s, s))
        gi3 = self.param('gi3', init_fn, (m, s, s))
        gr4 = self.param('gr4', init_fn, (m, s, s))
        gi4 = self.param('gi4', init_fn, (m, s, s))

        # Reshape and perform operations
        x = x.reshape((-1, m, s, m, s, 2))
        x_re, x_im = x[..., 0], x[..., 1]

        y_re_1 = jnp.einsum('...iaj,ajk->...iak', x_re, gr1)
        y_re_1 = jnp.einsum('abj...i,bjk->abk...i', y_re_1, gr1)
        y_re_2 = jnp.einsum('...iaj,ajk->...iak', x_re, gi1)
        y_re_2 = jnp.einsum('abj...i,bjk->abk...i', y_re_2, gi1)
        y_re_3 = jnp.einsum('...iaj,ajk->...iak', x_im, gi2)
        y_re_3 = jnp.einsum('abj...i,bjk->abk...i', y_re_3, gr2)
        y_re_4 = jnp.einsum('...iaj,ajk->...iak', x_im, gr2)
        y_re_4 = jnp.einsum('abj...i,bjk->abk...i', y_re_4, gi2)
        y_re = y_re_1+y_re_2+y_re_3+y_re_4
        
        y_im_1 = jnp.einsum('...iaj,ajk->...iak', x_im, gr3)
        y_im_1 = jnp.einsum('abj...i,bjk->abk...i', y_im_1, gr3)
        y_im_2 = jnp.einsum('...iaj,ajk->...iak', x_im, gi3)
        y_im_2 = jnp.einsum('abj...i,bjk->abk...i', y_im_2, gi3)
        y_im_3 = jnp.einsum('...iaj,ajk->...iak', x_re, gi4)
        y_im_3 = jnp.einsum('abj...i,bjk->abk...i', y_im_3, gr4)
        y_im_4 = jnp.einsum('...iaj,ajk->...iak', x_re, gr4)
        y_im_4 = jnp.einsum('abj...i,bjk->abk...i', y_im_4, gi4)
        y_im = y_im_1+y_im_2+y_im_3+y_im_4

        y = jnp.stack([y_re, y_im], axis=-1)

        # Final reshape and permutation
        n, r = m * 2, s // 2
        y = y.reshape((-1, n, r, n, r, 2))
        y = y.take(self.perm_idx, axis=1).take(self.perm_idx, axis=3)

        return y

In [54]:
class U(nn.Module):
    s: int  # Size parameter

    @nn.compact
    def __call__(self, x):
        # Extracting the shapes for weight initialization
        n, r, c = x.shape[2], x.shape[3], x.shape[-1]
        nx = n*self.s
        
        # Weight initialization
        init_fn = nn.initializers.glorot_uniform()
        ur1 = self.param('ur1', init_fn, (n, r, self.s))
        ui1 = self.param('ui1', init_fn, (n, r, self.s))
        ur2 = self.param('ur2', init_fn, (n, r, self.s))
        ui2 = self.param('ui2', init_fn, (n, r, self.s))
        ur3 = self.param('ur3', init_fn, (n, r, self.s))
        ui3 = self.param('ui3', init_fn, (n, r, self.s))
        ur4 = self.param('ur4', init_fn, (n, r, self.s))
        ui4 = self.param('ui4', init_fn, (n, r, self.s))

        # Splitting real and imaginary parts
        x_re, x_im = x[..., 0], x[..., 1]

        # Performing the einsum operations
        y_re_1 = jnp.einsum('...iaj,ajk->...iak', x_re, ur1)
        y_re_1 = jnp.einsum('abj...i,bjk->abk...i', y_re_1, ur1)
        y_re_2 = jnp.einsum('...iaj,ajk->...iak', x_re, ui2)
        y_re_2 = jnp.einsum('abj...i,bjk->abk...i', y_re_2, ui2)
        y_re_3 = jnp.einsum('...iaj,ajk->...iak', x_im, ui3)
        y_re_3 = jnp.einsum('abj...i,bjk->abk...i', y_re_3, ur3)
        y_re_4 = jnp.einsum('...iaj,ajk->...iak', x_im, ur4)
        y_re_4 = jnp.einsum('abj...i,bjk->abk...i', y_re_4, ui4)
        y_re = y_re_1+y_re_2+y_re_3+y_re_4
        # Final sum of y_re components
        y_re = y_re_1 + y_re_2 + y_re_3 + y_re_4

        return y_re.reshape((-1, nx, nx, 1))

In [None]:
class Fstar(nn.Module):
    L: int
    s: int
    r: int
    NUM_RESNET: int
    cart_mat: jnp.ndarray
    r_index: jnp.ndarray
    
    def setup(self):
        self.n = 2**self.L
        self.nx = (2**self.L)*self.s
        self.neta = (2**self.L)*self.s
        self.V = V(self.r)
        self.Hs = [H(self.L, l) for l in range(self.L-1, self.L//2-1, -1)]
        self.Ms = [M() for _ in range(2 * self.NUM_RESNET)]
        self.Gs = [G(self.L, l) for l in range(self.L//2, self.L)]
        self.U = U(self.s)
        self.switch_idx = build_switch_indices(self.L)

    def __call__(self, inputs):
        
        def helper(input):
            y = jnp.reshape(jnp.take(input, r_index, axis=0), (-1, self.nx, self.nx, 2))

            y = self.V(y)
            
            for h in self.Hs:
                y = h(y)
                     
            y = y.take(self.switch_idx, axis=1).take(self.switch_idx, axis=3)
            for m in self.Ms:
                y = m(y) if m is self.Ms[-1] else y + nn.relu(m(y))
        
            for g in self.Gs:
                y = g(y)
        
            y = self.U(y)
            
            y = jnp.diagonal(y, axis1 = 1, axis2 = 2)
            y = jnp.reshape(y, (-1, self.nx**2, 1))
            y = self.cart_mat @ y
            
            return jnp.reshape(x, (self.neta, self.neta, 1))

        return jax.vmap(helper)(inputs)

# Define the main model using Flax
class MyModel(nn.Module):
    L: int
    s: int
    r: int
    NUM_RESNET: int
    cart_mat: jnp.ndarray
    r_index: jnp.ndarray

    def setup(self):
        self.fstar_layer0 = Fstar(L=self.L, s=self.s, r=self.r, NUM_RESNET = self.NUM_RESNET, cart_mat=self.cart_mat, r_index=self.r_index)
        self.fstar_layer1 = Fstar(L=self.L, s=self.s, r=self.r, NUM_RESNET = self.NUM_RESNET, cart_mat=self.cart_mat, r_index=self.r_index)
        self.fstar_layer2 = Fstar(L=self.L, s=self.s, r=self.r, NUM_RESNET = self.NUM_RESNET, cart_mat=self.cart_mat, r_index=self.r_index)
        self.convs = [nn.Conv(features=6, kernel_size=(3, 3), padding='SAME') for _ in range(9)]
        self.final_conv = nn.Conv(features=1, kernel_size=(3, 3), padding='SAME')

    def __call__(self, inputs):
        y0 = self.fstar_layer0(inputs[:, :, :, 0])
        y1 = self.fstar_layer1(inputs[:, :, :, 1])
        y2 = self.fstar_layer2(inputs[:, :, :, 2])
        
        y = jnp.concatenate([y0, y1, y2], axis = -1)

        for conv_layer in self.convs:
            tmp = conv_layer(y)
            tmp = jax.nn.relu(tmp)
            y = jnp.concatenate([y, tmp], axis = -1)
        
        y = self.final_conv(y)
        
        return y[:,:,:,0]



In [17]:
def rotationindex(n):
    index = jnp.reshape(jnp.arange(0, n**2, 1), [n, n])
    return jnp.concatenate([jnp.roll(index, shift=[-i,-i], axis=[0,1]) for i in range(n)], 0)

In [19]:
r_index = rotationindex(80)
r_index

Array([[   0,    1,    2, ...,   77,   78,   79],
       [  80,   81,   82, ...,  157,  158,  159],
       [ 160,  161,  162, ...,  237,  238,  239],
       ...,
       [6159, 6080, 6081, ..., 6156, 6157, 6158],
       [6239, 6160, 6161, ..., 6236, 6237, 6238],
       [6319, 6240, 6241, ..., 6316, 6317, 6318]], dtype=int32)

In [29]:
data = scatter[0,:,:,0]

In [30]:
data.shape

(2, 6400)

In [31]:
c1 = jnp.take(data, r_index, axis = 1)

In [35]:

c2 = jnp.stack((r1, i1), axis=0)

In [36]:
c1 - c2

Array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

In [19]:
# Instantiate the model
model = MyModel(L, s, r, 3, cart_mat, r_index)


rng = jax.random.PRNGKey(42)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (batch_size, 2, 6400, 3))  
# Define an optimizer
optimizer = optax.adam(learning_rate=4e-3)
params = model.init(init_rng, inp)  # Initialize parameters


In [20]:
from flax.training import train_state

model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

In [21]:
def calculate_loss_acc(state, params, batch):
    x, y = batch
    # Obtain the logits and predictions of the model for the input data
    pred = state.apply_fn(params, x)
       
    # Calculate the loss and accuracy
    loss = jnp.mean((pred - y) ** 2)
    acc = jnp.sqrt(loss/jnp.mean(y ** 2))
    return loss, acc

In [22]:
batch = next(iter(data_loader))
calculate_loss_acc(model_state, model_state.params, batch)

(Array(0.01913041, dtype=float32), Array(1., dtype=float32))

In [23]:
@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc


In [24]:
@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    # Determine the accuracy
    _, acc = calculate_loss_acc(state, state.params, batch)
    return acc

In [25]:
def train_model(state, data_loader, num_epochs=100):
    # Training loop
    for epoch in range(num_epochs):
        
        for batch in data_loader:
            state, loss, acc = train_step(state, batch)
        print(acc)
            
    return state

In [None]:
trained_model_state = train_model(model_state, data_loader, num_epochs=5000)

0.8835075
0.8780626
0.57185066
0.28665906
0.23034663
0.2321662
0.2068171
0.17213953
0.17073564
0.1839958
0.15804748
0.1711674
0.1473232
0.14383674
0.1601862
0.13681433
0.13639821
0.12729992
0.11110984
0.11529502
0.101152524
0.11395437
0.09848287
0.091624506
0.09400712
0.08499524
0.09534076
0.07303455
0.0848759
0.083433114
0.09147592
0.08247699
0.0744942
0.076485455
0.08341152
0.081816405
0.07906338
0.074586496
0.08885229
0.0724868
0.077469006
0.07469498
