In [None]:
import functools

from clu import metric_writers
import numpy as np
import jax
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp
import torch.utils.data as data
from tqdm import tqdm

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter



In [None]:
!nvidia-smi

In [None]:
jax.devices()

In [None]:
# Parameters for the computational task.

L = 4 # number of levels (even number)
s = 5 # leaf size
r = 3 # rank

# Discretization of Omega (n_eta * n_eta).
neta = (2**L)*s

# Number of sources/detectors (n_sc).
# Discretization of the domain of alpha in polar coordinates (n_theta * n_rho).
# For simplicity, these values are set equal (n_sc = n_theta = n_rho), facilitating computation.
nx = (2**L)*s

# Standard deviation for the Gaussian blur.
blur_sigma = 0.5

# Batch size.
batch_size = 16

# Number of training datapoints.
NTRAIN = 2000

# Number of testing datapoints.
NTEST = 320

In [1]:
def cart_polar(coords, neta, nx):
    """
    Transforms coordinates from Cartesian to polar coordinates with custom scaling.

    Parameters:
    - coords: A tuple or list containing the (i, j) coordinates to be transformed.
    - neta: Scaling factor for the radial distance.
    - nx: Scaling factor for the angle.

    Returns:
    - A tuple (rho, theta) representing the transformed coordinates.
    """
    i, j = coords[0], coords[1]
    # Calculate the radial distance with a scaling factor.
    rho = 2 * jnp.sqrt((i - neta / 2) ** 2 + (j - neta / 2) ** 2) * nx / neta
    # Calculate the angle in radians and adjust the scale to fit the specified range.
    theta = ((jnp.arctan2((neta / 2 - j), (i - neta / 2))) % (2 * jnp.pi)) * nx / jnp.pi / 2
    return theta, rho + neta // 2

In [None]:
# Define a function to precompute the transformation matrix
def precompute_transform_matrix(neta, nx, cart_polar):
    cart_mat = jnp.zeros((neta**2, nx, nx))

    for i in range(nx):
        for j in range(nx):
            # Create a dummy matrix with a single one at position (i, j) and zeros elsewhere.
            mat_dummy = jnp.zeros((nx, nx))
            mat_dummy = lax.dynamic_update_index_in_dim(mat_dummy, (i, j), 1)
            # Pad the dummy matrix in polar coordinates to cover the target space in Cartesian coordinates.
            pad_dummy = jnp.pad(mat_dummy, ((0, 0), (neta // 2, neta // 2)), 'edge')
            # Apply the geometric transformation to map the dummy matrix to polar coordinates
            cart_mat = lax.dynamic_update_index_in_dim(cart_mat, (slice(None), i, j), 
                                                        jnp.ravel(geometric_transform(pad_dummy, cart_polar, output_shape=[neta, neta], mode='grid-wrap')))

    cart_mat = jnp.reshape(cart_mat, (neta**2, nx**2))
    # Removing small values
    cart_mat = jnp.where(jnp.abs(cart_mat) > 0.001, cart_mat, 0)
    return cart_mat

In [None]:
name = 'shepp_logan'

# Define a function to load and preprocess perturbation data (eta)
def load_and_preprocess_eta(name, NTRAIN, neta, blur_sigma):
    with h5py.File(f'{name}/eta.h5', 'r') as f:
        # Read eta data, apply Gaussian blur, and reshape
        eta_re = f[list(f.keys())[0]][:NTRAIN, :].reshape(-1, neta, neta)
        blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
        eta_re = jnp.stack([blur_fn(eta_re[i, :, :]) for i in range(NTRAIN)]).astype('float32')
    return eta_re

# Define a function to load and preprocess scatter data (Lambda)
def load_and_preprocess_scatter(name, NTRAIN, nx):
    with h5py.File(f'{name}/scatter.h5', 'r') as f:
        keys = natsort.natsorted(f.keys())
        
        # Process real part of scatter data
        tmp1 = f[keys[3]][:NTRAIN, :].reshape((-1, nx, nx))
        tmp2 = f[keys[4]][:NTRAIN, :].reshape((-1, nx, nx))
        tmp3 = f[keys[5]][:NTRAIN, :].reshape((-1, nx, nx))
        scatter_re = jnp.stack((tmp1, tmp2, tmp3), axis=-1)
        
        # Process imaginary part of scatter data
        tmp1 = f[keys[0]][:NTRAIN, :].reshape((-1, nx, nx))
        tmp2 = f[keys[1]][:NTRAIN, :].reshape((-1, nx, nx))
        tmp3 = f[keys[2]][:NTRAIN, :].reshape((-1, nx, nx))
        scatter_im = jnp.stack((tmp1, tmp2, tmp3), axis=-1)
        
        # Combine real and imaginary parts
        scatter = jnp.stack((scatter_re, scatter_im), axis=1).astype('float32')
    return scatter

# Load and preprocess perturbation data
eta_re = load_and_preprocess_eta(name, NTRAIN, neta, blur_sigma)

# Load and preprocess scatter data
scatter = load_and_preprocess_scatter(name, NTRAIN, nx)

# Clean up temporary variables to free memory
del scatter_re, scatter_im, tmp1, tmp2, tmp3

# Convert JAX arrays to NumPy arrays as TensorFlow works with NumPy arrays
eta_re_np = jax.device_get(eta_re)
scatter_np = jax.device_get(scatter)

# Create a TensorFlow dataset for training
trn_dataset = tf.data.Dataset.from_tensor_slices((scatter_np, eta_re_np))
trn_dataset = trn_dataset.prefetch(tf.data.experimental.AUTOTUNE)
trn_dataset = trn_dataset.shuffle(buffer_size=200)
trn_dataset = trn_dataset.batch(BATCH_SIZE)
dataset = eval_dataloader = trn_dataset.as_numpy_iterator()

In [None]:
# Define the F^* layer using Flax
class Fstar(nn.Module):
    nx: int
    neta: int
    cart_mat: jnp.ndarray

    def setup(self):
        self.pre1 = self.param('pre1', lambda key, shape: jax.random.uniform(key, shape))
        self.post1 = self.param('post1', lambda key, shape: jax.random.uniform(key, shape))
        self.cos_kernel1 = self.param('cos_kernel1', lambda key, shape: jax.random.uniform(key, shape))
        self.sin_kernel1 = self.param('sin_kernel1', lambda key, shape: jax.random.uniform(key, shape))
        self.cos_kernel2 = self.param('cos_kernel2', lambda key, shape: jax.random.uniform(key, shape))
        self.sin_kernel2 = self.param('sin_kernel2', lambda key, shape: jax.random.uniform(key, shape))

    def __call__(self, inputs):
        # Separate real and imaginary parts of inputs
        R, I = inputs[:, 0, :, :], inputs[:, 1, :, :]
        
        # Define rotation function
        def rotationindex(n):
            index = jnp.reshape(jnp.arange(0, n**2, 1), [n, n])
            return jnp.concatenate([jnp.roll(index, shift=[-i,-i], axis=[0,1]) for i in range(n)], 0)
        
        rindex = lambda d: jax.lax.dynamic_slice(jnp.reshape(d, [-1]), rotationindex(self.nx), [self.nx * self.nx])
        
        Rs = jax.vmap(rindex)(R)
        Rs = jnp.reshape(Rs, [-1, self.nx, self.nx])
        Is = jax.vmap(rindex)(I)
        Is = jnp.reshape(Is, [-1, self.nx, self.nx])
        
        def helper(pre, post, kernel2, kernel1, data):
            return jnp.matmul(post, jnp.multiply(kernel2, jnp.matmul(jnp.multiply(data, pre), kernel1)))  
        
        output_polar = helper(self.pre1, self.post1, self.cos_kernel1, self.cos_kernel2, Rs) \
                      + helper(self.pre2, self.post2, self.sin_kernel1, self.sin_kernel2, Rs) \
                      + helper(self.pre3, self.post3, self.cos_kernel2, self.sin_kernel1, Is) \
                      + helper(self.pre4, self.post4, self.sin_kernel2, self.cos_kernel1, Is)
        
        output_polar = jnp.reshape(output_polar, (-1, self.nx, self.nx))
        
        # Convert from polar to Cartesian coordinates
        def polar_to_cart(x):
            x = jnp.reshape(x, (self.nx**2, 1))
            x = jnp.dot(self.cart_mat, x)
            return jnp.reshape(x, (self.neta, self.neta))
        
        output_cart = jax.vmap(polar_to_cart)(output_polar)
        return jnp.reshape(output_cart, (-1, self.neta, self.neta, 1))

# Define the main model using Flax
class MyModel(nn.Module):
    nx: int
    neta: int
    cart_mat: jnp.ndarray
    num_cnn: int

    def setup(self):
        self.fstar_layer = Fstar(nx=self.nx, neta=self.neta, cart_mat=self.cart_mat)
        self.convs = [nn.Conv(features=6, kernel_size=(3, 3), padding='SAME') for _ in range(self.num_cnn - 1)]
        self.final_conv = nn.Conv(features=1, kernel_size=(3, 3), padding='SAME')

    def __call__(self, inputs):
        y1 = self.fstar_layer(inputs[:, :, :, :, 0])
        y2 = self.fstar_layer(inputs[:, :, :, :, 1])
        y3 = self.fstar_layer(inputs[:, :, :, :, 2])

        y = jnp.concatenate([y1, y2, y3], axis=-1)

        for conv_layer in self.convs:
            y = conv_layer(y)
            y = jax.nn.relu(y)
        
        y = self.final_conv(y)
        
        return y



In [None]:
# Instantiate the model
model = MyModel(nx, neta, cart_mat, NUM_CNN)


rng = jax.random.PRNGKey(42)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (batch_size, L1, L1, 2))  
# Define an optimizer
optimizer = optax.adam(learning_rate=1e-4)
params = model.init(init_rng, inp)  # Initialize parameters


In [None]:
from flax.training import train_state

model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

In [None]:
def calculate_loss_acc(state, params, batch):
    x, y = batch
    # Obtain the logits and predictions of the model for the input data
    pred = state.apply_fn(params, x)
       
    # Calculate the loss and accuracy
    loss = jnp.mean((pred - y) ** 2)
    acc = jnp.sqrt(loss/jnp.mean(y ** 2))
    return loss, acc

In [None]:
batch = next(iter(data_loader))
calculate_loss_acc(model_state, model_state.params, batch)

In [None]:
@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc


In [None]:
@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    # Determine the accuracy
    _, acc = calculate_loss_acc(state, state.params, batch)
    return acc

In [41]:
def train_model(state, data_loader, num_epochs=100):
    # Training loop
    for epoch in range(num_epochs):
        
        for batch in data_loader:
            state, loss, acc = train_step(state, batch)
        print(acc)
            
    return state

In [42]:
trained_model_state = train_model(model_state, data_loader, num_epochs=500)

0.62765276
0.5699285
0.51149094
0.4364011
0.43604264
0.3773166
0.36706853
0.38370332
0.35950708
0.3578673
0.33781496
0.33543062
0.3399269
0.30684948
0.30500194
0.2904274
0.28377196
0.30791003
0.30995208
0.30437276
0.3007315
0.27253976
0.286515
0.26192582
0.25777122
0.25902832
0.25168607
0.2622497
0.2806948
0.26001588
0.2630823
0.24773355
0.23084229
0.25927553
0.26018977
0.2271517
0.24457313
0.22907971
0.23938344
0.24449262
0.24650888
0.26309568
0.22288127
0.23615713
0.220491
0.21007022
0.21267845
0.23347838
0.22431682
0.20537148
0.20106609
0.22448424
0.20818888
0.21766713
0.21276915
0.18663907
0.19759814
0.18237337
0.19505891
0.22396858
0.19007027
0.20751317
0.21308897
0.19851148
0.19169652
0.1964183
0.19370717
0.18364118
0.18232246
0.18579541
0.18737678
0.18589503
0.19248821
0.18174112
0.17948163
0.1833093
0.1747682
0.18085857
0.18993571
0.18191735
0.1803571
0.17224161
0.17769665
0.17789787
0.1709675
0.1734087
0.16577762
0.17144592
0.17202622
0.1773083
0.16912267
0.18149474
0.19026102