In [1]:
import functools

from clu import metric_writers
import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp
import torch.utils.data as data
from tqdm import tqdm

import h5py
import natsort
import tensorflow as tf
from scipy.ndimage import geometric_transform
from scipy.ndimage import gaussian_filter



In [2]:
jax.devices()

[cuda(id=0), cuda(id=1)]

In [3]:
# Parameters for the computational task.

L = 4 # number of levels (even number)
s = 5 # leaf size
r = 3 # rank

# Discretization of Omega (n_eta * n_eta).
neta = (2**L)*s

# Number of sources/detectors (n_sc).
# Discretization of the domain of alpha in polar coordinates (n_theta * n_rho).
# For simplicity, these values are set equal (n_sc = n_theta = n_rho), facilitating computation.
nx = (2**L)*s

# Standard deviation for the Gaussian blur.
blur_sigma = 0.5

# Batch size.
batch_size = 16

# Number of training datapoints.
NTRAIN = 2000

# Number of testing datapoints.
NTEST = 320

In [4]:
def cart_polar(coords):
    """
    Transforms coordinates from Cartesian to polar coordinates with custom scaling.

    Parameters:
    - coords: A tuple or list containing the (i, j) coordinates to be transformed.

    Returns:
    - A tuple (rho, theta) representing the transformed coordinates.
    """
    i, j = coords[0], coords[1]
    # Calculate the radial distance with a scaling factor.
    rho = 2 * np.sqrt((i - neta / 2) ** 2 + (j - neta / 2) ** 2) * nx / neta
    # Calculate the angle in radians and adjust the scale to fit the specified range.
    theta = ((np.arctan2((neta / 2 - j), (i - neta / 2))) % (2 * np.pi)) * nx / np.pi / 2
    return theta, rho + neta // 2

In [5]:
# Define a function to precompute the transformation matrix
# Precompute the transformation matrix from polar coordinates to Cartesian coordiantes 
cart_mat = np.zeros((neta**2, nx, nx))

for i in range(nx):
    for j in range(nx):
        # Create a dummy matrix with a single one at position (i, j) and zeros elsewhere.
        mat_dummy = np.zeros((nx, nx))
        mat_dummy[i, j] = 1
        # Pad the dummy matrix in polar coordinates to cover the target space in Cartesian coordinates.
        pad_dummy = np.pad(mat_dummy, ((0, 0), (neta // 2, neta // 2)), 'edge')
        # Apply the geometric transformation to map the dummy matrix to polar coordinates
        cart_mat[:, i, j] = geometric_transform(pad_dummy, cart_polar, output_shape=[neta, neta], mode='grid-wrap').flatten()

cart_mat = np.reshape(cart_mat, (neta**2, nx**2))
# Removing small values
cart_mat = np.where(np.abs(cart_mat) > 0.001, cart_mat, 0)
# Convert to sparse matrix in tensorflow
#cart_mat = tf.sparse.from_dense(tf.cast(cart_mat, dtype='float32'))

In [6]:
from jax.experimental import sparse
cart_mat = sparse.BCOO.fromdense(cart_mat)

In [7]:
tf.config.set_visible_devices([], device_type='GPU')

name = 'shepp_logan'

# Loading and preprocessing perturbation data (eta)
with h5py.File(f'{name}/eta.h5', 'r') as f:
    # Read eta data, apply Gaussian blur, and reshape
    eta_re = f[list(f.keys())[0]][:NTRAIN, :].reshape(-1, neta, neta)
    blur_fn = lambda x: gaussian_filter(x, sigma=blur_sigma)
    eta_re = np.stack([blur_fn(eta_re[i, :, :].T) for i in range(NTRAIN)]).astype('float32')

# Loading and preprocessing scatter data (Lambda)
with h5py.File(f'{name}/scatter.h5', 'r') as f:
    keys = natsort.natsorted(f.keys())

    # Process real part of scatter data
    tmp1 = f[keys[3]][:NTRAIN, :]
    tmp2 = f[keys[4]][:NTRAIN, :]
    tmp3 = f[keys[5]][:NTRAIN, :]
    scatter_re = np.stack((tmp1, tmp2, tmp3), axis=-1)

    # Process imaginary part of scatter data
    tmp1 = f[keys[0]][:NTRAIN, :]
    tmp2 = f[keys[1]][:NTRAIN, :]
    tmp3 = f[keys[2]][:NTRAIN, :]
    scatter_im = np.stack((tmp1, tmp2, tmp3), axis=-1)
    
    # Combine real and imaginary parts
    scatter = np.stack((scatter_re, scatter_im), axis=1).astype('float32')
    
# Clean up temporary variables to free memory
del scatter_re, scatter_im, tmp1, tmp2, tmp3

## Create a TensorFlow dataset for training
#dataset = tf.data.Dataset.from_tensor_slices((scatter, eta_re))
#dataset = dataset.shuffle(buffer_size=500)
#dataset = dataset.batch(batch_size)
#dataset = dataset.prefetch(tf.data.AUTOTUNE)
#dataset = data_loader = dataset.as_numpy_iterator()

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

dataset = [(scatter[i,:,:,:], eta_re[i,:,:]) for i in range(NTRAIN)]
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=numpy_collate)

In [9]:
# Define the F^* layer using Flax
class Fstar(nn.Module):
    nx: int
    neta: int
    cart_mat: jnp.ndarray

    def setup(self):
        kernel_shape = (self.nx, self.nx)
        p_shape = (1, self.nx)
        
        self.pre1 = self.param('pre1', nn.initializers.glorot_uniform(), p_shape)
        self.pre2 = self.param('pre2', nn.initializers.glorot_uniform(), p_shape)
        self.pre3 = self.param('pre3', nn.initializers.glorot_uniform(), p_shape)
        self.pre4 = self.param('pre4', nn.initializers.glorot_uniform(), p_shape)

        self.post1 = self.param('post1', nn.initializers.glorot_uniform(), p_shape)
        self.post2 = self.param('post2', nn.initializers.glorot_uniform(), p_shape)
        self.post3 = self.param('post3', nn.initializers.glorot_uniform(), p_shape)
        self.post4 = self.param('post4', nn.initializers.glorot_uniform(), p_shape)
        
        self.cos_kernel1 = self.param('cos_kernel1', nn.initializers.glorot_uniform(), kernel_shape)
        self.sin_kernel1 = self.param('sin_kernel1', nn.initializers.glorot_uniform(), kernel_shape)
        self.cos_kernel2 = self.param('cos_kernel2', nn.initializers.glorot_uniform(), kernel_shape)
        self.sin_kernel2 = self.param('sin_kernel2', nn.initializers.glorot_uniform(), kernel_shape)

    def __call__(self, inputs):
        # Separate real and imaginary parts of inputs
        R, I = inputs[:, 0, :], inputs[:, 1, :]
        
        # Define rotation function
        def rotationindex(n):
            index = jnp.reshape(jnp.arange(0, n**2, 1), [n, n])
            return jnp.concatenate([jnp.roll(index, shift=[-i,-i], axis=[0,1]) for i in range(n)], 0)
        
        rindex = lambda d: jnp.take(d, rotationindex(self.nx))
        
        Rs = jax.vmap(rindex)(R)
        Rs = jnp.reshape(Rs, [-1, self.nx, self.nx])
        Is = jax.vmap(rindex)(I)
        Is = jnp.reshape(Is, [-1, self.nx, self.nx])
        
        def helper(pre, post, kernel2, kernel1, data):
            return jnp.matmul(post, jnp.multiply(kernel2, jnp.matmul(jnp.multiply(data, pre), kernel1)))  
        
        output_polar = helper(self.pre1, self.post1, self.cos_kernel1, self.cos_kernel2, Rs) \
                      + helper(self.pre2, self.post2, self.sin_kernel1, self.sin_kernel2, Rs) \
                      + helper(self.pre3, self.post3, self.cos_kernel2, self.sin_kernel1, Is) \
                      + helper(self.pre4, self.post4, self.sin_kernel2, self.cos_kernel1, Is)
        
        output_polar = jnp.reshape(output_polar, (-1, self.nx, self.nx))
        
        # Convert from polar to Cartesian coordinates
        def polar_to_cart(x):
            x = jnp.reshape(x, (self.nx**2, 1))
            x = self.cart_mat @ x
            return jnp.reshape(x, (self.neta, self.neta))
        
        output_cart = jax.vmap(polar_to_cart)(output_polar)
        return jnp.reshape(output_cart, (-1, self.neta, self.neta, 1))

# Define the main model using Flax
class MyModel(nn.Module):
    nx: int
    neta: int
    cart_mat: jnp.ndarray
    num_cnn: int

    def setup(self):
        self.fstar_layer0 = Fstar(nx=self.nx, neta=self.neta, cart_mat=self.cart_mat)
        self.fstar_layer1 = Fstar(nx=self.nx, neta=self.neta, cart_mat=self.cart_mat)
        self.fstar_layer2 = Fstar(nx=self.nx, neta=self.neta, cart_mat=self.cart_mat)
        
        self.convs = [nn.Conv(features=6, kernel_size=(3, 3), padding='SAME') for _ in range(self.num_cnn - 1)]
        self.final_conv = nn.Conv(features=1, kernel_size=(3, 3), padding='SAME')

    def __call__(self, inputs):
        y0 = self.fstar_layer0(inputs[:, :, :, 0])
        y1 = self.fstar_layer1(inputs[:, :, :, 1])
        y2 = self.fstar_layer2(inputs[:, :, :, 2])
        
        y = jnp.concatenate([y0, y1, y2], axis = -1)

        for conv_layer in self.convs:
            tmp = conv_layer(y)
            tmp = jax.nn.relu(tmp)
            y = jnp.concatenate([y, tmp], axis = -1)
        
        y = self.final_conv(y)

        return y[:,:,:,0]



In [18]:
# Instantiate the model
model = MyModel(nx, neta, cart_mat, 8)


rng = jax.random.PRNGKey(42)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (batch_size, 2, 6400, 3))  
# Define an optimizer
optimizer = optax.adam(learning_rate=1e-3)
params = model.init(init_rng, inp)  # Initialize parameters


In [19]:
from flax.training import train_state

model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

In [20]:
def calculate_loss_acc(state, params, batch):
    x, y = batch
    # Obtain the logits and predictions of the model for the input data
    pred = state.apply_fn(params, x)
       
    # Calculate the loss and accuracy
    loss = jnp.mean((pred - y) ** 2)
    acc = jnp.sqrt(loss/jnp.mean(y ** 2))
    return loss, acc

In [21]:
batch = next(iter(data_loader))
calculate_loss_acc(model_state, model_state.params, batch)

(Array(0.01939696, dtype=float32), Array(0.97791815, dtype=float32))

In [22]:
@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc


In [23]:
@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    # Determine the accuracy
    _, acc = calculate_loss_acc(state, state.params, batch)
    return acc

In [24]:
def train_model(state, data_loader, num_epochs=100):
    # Training loop
    for epoch in range(num_epochs):
        
        for batch in data_loader:
            state, loss, acc = train_step(state, batch)
        print(acc)
            
    return state

In [None]:
trained_model_state = train_model(model_state, data_loader, num_epochs=5000)

0.25455862
0.19897443
0.17354843
0.15688986
0.13922612
0.13644192
0.12989336
0.13155568
0.1180487
0.11828743
0.11255716
0.10880896
0.109767854
0.10609624
0.09753256
0.092296325
0.08264763
0.08392045
0.09134734
0.089241624
0.09046501
0.079866685
0.080469795
0.075873986
0.07323089
0.08085374
0.07404339
0.065801635
0.074725606
0.066407755
0.06981976
0.06988
0.07045619
0.062468685
0.060781524
0.06484609
0.061060004
0.060614735
0.06206696
0.05766756
0.05866712
0.05608634
0.058511708
0.05944217
0.051764812
0.055019785
0.056162808
0.05359541
0.054161176
0.05138499
0.051150475
0.04726951
0.049116228
0.05062388
0.053629696
0.047699884
0.05012803
0.048455406
0.05057133
0.048115037
0.044389706
0.047832746
0.048106536
0.046598658
0.050254732
0.048895407
0.046088554
0.045845576
0.042230926
0.042200796
0.047221813
0.045259368
0.047108307
0.04007353
0.04360658
0.04501091
0.046463728
0.043374684
