In [5]:
import numpy as np
import jax
import jax.numpy as jnp
import optax
import wandb
from NeuralPC.model.CNNs_flax import Encoder_Decoder 




In [None]:


def split_idx(length, key, train_frac=0.6):
    """Randomly split ``range(length)`` into disjoint train/validation index arrays.

    Args:
        length: number of samples.
        key: integer seed passed to ``jax.random.PRNGKey``.
        train_frac: fraction of samples assigned to training (default 0.6).

    Returns:
        (trainIdx, valIdx): the first ``int(train_frac * length)`` entries of a
        random permutation go to training and all remaining entries go to
        validation, so together they cover every sample exactly once.
    """
    k = jax.random.PRNGKey(key)
    idx = jax.random.permutation(k, length)
    n_train = int(train_frac * length)
    # Slice the remainder instead of idx[-int(.4 * length):]. That form dropped
    # samples whenever the two int() roundings lost one. When int(.4 * length)
    # == 0 it became idx[-0:], i.e. the whole permutation, which overlaps train.
    return idx[:n_train], idx[n_train:]

In [None]:
import matplotlib.pyplot as plt

data = np.load('../../../datasets/Dirac/precond_data/config.l16-N200-b2.0-k0.276-unquenched-test.x.npy')
data = jnp.asarray(data)
# Move axis 1 to the end (channels-last), matching the (N, 16, 16, C) inputs
# the Flax model is initialised with further down.
data = jnp.transpose(data, [0, 2, 3, 1])

# Exponential transform: map each angle theta to the phase exp(i * theta).
data_exp = jnp.exp(1j * data)
dataReal = data_exp.real
# Take the imaginary part of the *transformed* data. The raw values are
# presumably real angles (the CG test cell also applies exp(1j * ...) to them),
# so `data.imag` was all zeros and half of the input channels carried nothing.
dataImag = data_exp.imag
dataComb = jnp.concatenate([dataReal, dataImag], axis=-1)
print(dataComb.shape)

# TODO: normalize the data.


# Data Visualization 

In [None]:
# Split the configurations into train/validation sets (fixed seed).
trainIdx, valIdx = split_idx(dataComb.shape[0], 42)
trainData = dataComb[trainIdx]
valData = dataComb[valIdx]

from sklearn.manifold import TSNE

# Embed every configuration in 2-D so train and validation samples can be
# compared visually. random_state makes the embedding reproducible across runs.
tsne = TSNE(n_components=2, learning_rate='auto', init='random',
            random_state=42).fit_transform(dataComb.reshape(dataComb.shape[0], -1))

train_embed = tsne[trainIdx]
val_embed = tsne[valIdx]
fig, ax = plt.subplots()
ax.scatter(train_embed[:, 0], train_embed[:, 1], label='train')
ax.scatter(val_embed[:, 0], val_embed[:, 1], label='validation')
ax.set(title='t-SNE embedding of configurations', xlabel='t-SNE 1', ylabel='t-SNE 2')
ax.legend();



In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
class Data(Dataset):
    """Map-style torch Dataset over the leading axis of an array.

    Each item is a fresh NumPy array copy of ``data[index]``.
    """

    def __init__(self, data) -> None:
        super().__init__()
        self.data = data

    def __len__(self):
        n_samples = self.data.shape[0]
        return n_samples

    def __getitem__(self, index):
        sample = self.data[index]
        return np.array(sample)


def create_dataLoader(data, batchSize, shuffle: bool):
    """Wrap ``data`` in a ``Data`` dataset and return a DataLoader over it."""
    return DataLoader(Data(data), batch_size=batchSize, shuffle=shuffle)

BATCH_SIZE = 16
# Shuffle only the training set; validation order stays fixed.
TrainLoader = create_dataLoader(np.array(trainData), batchSize=BATCH_SIZE, shuffle=True)
ValLoader = create_dataLoader(np.array(valData), batchSize=BATCH_SIZE, shuffle=False)

# Sanity check: inspect the type of the first validation batch only.
for v in ValLoader:
    print(type(v))
    break


# Model Training

In [None]:
from typing import Any



key = jax.random.PRNGKey(0)
optimizer = optax.adam(learning_rate=0.001)

# print(model.tabulate(key, jnp.ones((1, 16, 16, 4))))
# print(model.tabulate(jax.random.key(0), x))
def MSE(params, x, y, train):
    """Mean squared error between ``y`` and ``model``'s prediction for ``x``.

    ``params`` is passed unchanged as the variables argument of ``model.apply``.
    ``train`` is forwarded as the model's ``train`` flag.
    """
    pred = model.apply(params, x, train=train)
    return jnp.mean((y - pred) ** 2)


# Compile with `train` static. Under a plain @jax.jit it would arrive as a
# tracer, and any Python-level branch on it inside the model (e.g. choosing
# the BatchNorm mode) would fail with a ConcretizationTypeError.
MSE = jax.jit(MSE, static_argnames=("train",))

# NOTE(review): `model` is created in the training cell further down, so on a
# fresh kernel this cell raises NameError until that cell has run.
# Initialise from a single sample instead of the whole dataset. The training
# code also initialises with batch size 1 (shape (1, 16, 16, 4)) and then
# trains on batches of 16, so variable shapes do not depend on the batch size.
variables = model.init(jax.random.key(0), dataComb[:1], train=False)
params = variables['params']
batch_stats = variables['batch_stats']

# load data
import numpy as np
import optax

from flax.training import train_state

class TrainState(train_state.TrainState):
    """Flax TrainState extended with the model's batch statistics.

    The ``batch_stats`` collection is kept outside ``params``. ``train_step``
    replaces it with the updated statistics returned by the model, and
    ``eval_step`` reads it.
    """
    # Variables of the 'batch_stats' collection (nested dict from model.init).
    batch_stats: Any

def init_train_state(
    model, random_key, shape, learning_rate
) -> train_state.TrainState:
    """Build a TrainState (params, batch stats, Adam optimizer) for ``model``.

    The model is initialised in training mode on a dummy all-ones input of
    ``shape``. The optimizer is ``optax.adam(learning_rate)``.
    """
    dummy_input = jnp.ones(shape)
    init_vars = model.init(random_key, dummy_input, train=True)
    return TrainState.create(
        apply_fn=model.apply,
        params=init_vars['params'],
        tx=optax.adam(learning_rate),
        batch_stats=init_vars['batch_stats'],
    )



@jax.jit
def train_step(
    state: train_state.TrainState, batch: jnp.ndarray
):
    """One optimisation step of the autoencoder on ``batch``.

    The batch is both the input and the reconstruction target. Returns the
    updated state (new params and batch statistics) and the batch MSE loss.
    """
    def loss_fn(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # mutable=['batch_stats'] makes apply return the updated statistics too.
        recon, new_vars = state.apply_fn(variables, batch, train=True, mutable=['batch_stats'])
        return jnp.mean((recon - batch) ** 2), new_vars

    (loss, new_vars), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads).replace(batch_stats=new_vars['batch_stats'])
    return new_state, loss

@jax.jit
def eval_step(
    state, batch
):
    """Reconstruction MSE of ``batch`` under the current ``state``, in eval mode.

    Uses ``state.params`` (the weights being trained) together with
    ``state.batch_stats``. The state is not modified.
    """
    # Previously this read the module-level `params` (the initial weights
    # captured at trace time), so validation loss never reflected training.
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    # Any collection updates returned by apply are discarded: evaluation must
    # not change the state.
    pred, _ = state.apply_fn(variables, batch, train=False, mutable=['batch_stats'])
    return jnp.mean((pred - batch) ** 2)

def train_val(trainLoader, valLoader, state, epochs, verbose, log_every=100):
    """Train for ``epochs`` epochs, evaluating on ``valLoader`` after each one.

    Args:
        trainLoader: iterable of torch tensor batches (``.numpy()`` is called on each).
        valLoader: iterable of torch tensor batches used for validation.
        state: train state passed to ``train_step`` / ``eval_step``.
        epochs: number of passes over ``trainLoader``.
        verbose: if truthy, print the mean losses every ``log_every`` epochs.
        log_every: print interval in epochs (default 100, the previous fixed value).

    Each epoch's mean train/validation loss is logged to wandb as
    ``trainLoss`` / ``valLoss``.

    Returns:
        The final train state.
    """
    for ep in range(epochs):
        # training
        trainBatchLoss = []
        for train_batch in trainLoader:
            state, loss = train_step(state, train_batch.numpy())
            trainBatchLoss.append(loss)
        valBatchLoss = [eval_step(state, val_batch.numpy()) for val_batch in valLoader]

        # Reduce once per epoch and reuse the result for printing and logging.
        train_loss = float(np.mean(trainBatchLoss))
        val_loss = float(np.mean(valBatchLoss))
        if verbose and ep % log_every == 0:
            print('Epoch {}, train loss {:.4f} validation loss {:.4f}'.format(ep + 1, train_loss, val_loss))
        wandb.log({'trainLoss': train_loss, 'valLoss': val_loss})
    return state
        





# from flax import struct

# class TrainState(train_state.TrainState):
#   metrics: MSE

# def create_train_state(module, rng, learning_rate, momentum):
#   """Creates an initial `TrainState`."""
#   params = model.init(rng, jnp.ones([1, 16, 16, 4]))['params'] # initialize parameters by passing a template image
#   tx = optax.sgd(learning_rate, momentum)
#   return TrainState.create(
#       apply_fn=model.apply, params=params, tx=tx,
#       metrics=MSE)

# @jax.jit
# def train_step(state, batch):
#   """Train for a single step."""
#   def loss_fn(params):
#     logits = state.apply_fn({'params': params}, batch['image'])
#     loss = optax.softmax_cross_entropy_with_integer_labels(
#         logits=logits, labels=batch['label']).mean()
#     return loss
#   grad_fn = jax.grad(loss_fn)
#   grads = grad_fn(state.params)
#   state = state.apply_gradients(grads=grads)
#   return state





# trainIdx, valIdx = split_idx(dataComb.shape[0], 42)
# trainData = dataComb[trainIdx]
# valData = dataComb[valIdx]


# opt_state = optimizer.init(params)
# loss_grad_fn = jax.value_and_grad(MSE)
# for i in range(epochs):
#     trainLoss, grads = loss_grad_fn(params, trainData, trainData, True)
#     valLoss = MSE(params, valData, valData, False)
#     # Do the learning - updating the params.
#     updates, opt_state = optimizer.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)
#     wandb.log({'trainLoss': trainLoss, 'valLoss': valLoss})
#     if i % 100 == 0:
#         print('Loss step {}: '.format(i), trainLoss, 'validation loss {}'.format(valLoss))

In [None]:
epochs = 10000
learning_rate = 0.1
h_ch = 64
# Dataset label. Named data_tag so it no longer overwrites the `data` array
# loaded in the data-preparation cell.
data_tag = "config_exp"

model = Encoder_Decoder(4, 4, h_ch, (3, 3))
# Train the autoencoder.
run = wandb.init(
    # Set the project where this run will be logged
    project="reconstructDiracConfig",
    # Derived from learning_rate so the run name cannot drift from the actual config.
    name=f'BatchNorm-4-lr{learning_rate}',
    # Track hyperparameters and run metadata
    config={
        "learning_rate": learning_rate,
        "epochs": epochs,
        "h_ch": h_ch,
        "data": data_tag,
    })

state = init_train_state(
    model, key, (1, 16, 16, 4), learning_rate
)
final_state = train_val(TrainLoader, ValLoader, state, epochs, verbose=False)
# Close the run so re-running this cell starts a fresh wandb run.
run.finish()

# Train for the preconditioner

In [2]:
# Open question: does the element-wise projection increase the number of channels? (Probably yes; TODO: verify.)
from NeuralPC.utils.dirac import DiracOperator
from jax.scipy.sparse.linalg import cg


def conditionNum(L):
    '''
    Intended to compute a condition number for the preconditioner built from ``L``
    (per the function name). Not implemented yet.

    L is the lower triangular matrix.

    Raises:
        NotImplementedError: always, until an implementation exists. This replaces
        a silent ``pass`` that returned ``None`` as if it were a result.
    '''
    raise NotImplementedError("conditionNum is not implemented yet")

def PCG_loss(L, U1, b, kappa, steps, precond=None):
    '''
    Residual of a fixed-iteration (P)CG solve of D x = b, with D = DiracOperator(U1, kappa).

        L : the neural network output, lower triangular preconditioner of shape (b, L^2, L^2, 2).
            NOTE(review): not used yet. Nothing here turns L into a preconditioner,
            so pass one explicitly via ``precond`` until that is written.
        U1: gauge configuration of shape (b, L, L, 2).
            NOTE(review): the CG test cell loads U1 as (b, 2, L, L); confirm which
            layout DiracOperator expects.
        b: right-hand side. The initial guess is zeros with b's shape and dtype.
        kappa: forwarded to DiracOperator as ``kappa``.
        steps: total steps for PCG to run (``maxiter`` of ``cg``).
        precond: optional preconditioner passed to ``cg`` as ``M`` (default: plain CG).

    Returns the residual ``b - D x_sol`` after ``steps`` iterations. Minimising its
    squared magnitude is the intended training objective.
    '''
    from functools import partial

    # Bind the configuration so cg sees a linear map x -> D x. Keyword binding
    # matches the DiracOperator(U1=..., kappa=..., x=...) convention used in the
    # CG test cell (the old code called DiracOperator() with no arguments).
    operator = partial(DiracOperator, U1=U1, kappa=kappa)
    x0 = jnp.zeros_like(b)
    # cg returns (solution, info). Only the solution is needed.
    x_sol, _ = cg(A=operator, b=b, x0=x0, M=precond, maxiter=steps)
    return b - operator(x=x_sol)

def BatchPCGLoss(L, U1, b, kappa, steps):
    '''
    Mean squared residual magnitude of independent per-sample PCG solves.

    L, U1 and b are batched along axis 0. kappa and steps are shared by every sample.
    '''
    def per_sample(l, u, rhs):
        # Re-add a leading batch axis of size 1 so each sample reaches PCG_loss in
        # the batched (b, ...) layout that DiracOperator is called with elsewhere.
        return PCG_loss(l[None], u[None], rhs[None], kappa, steps)[0]

    # Map all batched inputs over axis 0. The old in_axes of 1 and 2 mapped U1
    # and b over lattice axes. kappa/steps are closed over, so `steps` stays a
    # concrete Python int for cg's maxiter.
    batchResidual = jax.vmap(per_sample)(L, U1, b)
    # The residual is complex. Use |r|^2 rather than r**2 so the loss is real
    # and non-negative.
    return jnp.mean(jnp.abs(batchResidual) ** 2)



In [3]:
# test block for cg

from jax.scipy.sparse.linalg import cg
from NeuralPC.utils.dirac import DiracOperator
from functools import partial
import jax
import numpy as np
import jax.numpy as jnp

# Enable double precision; the complex128 casts below need it.
# NOTE(review): this flag is meant to be set at startup. Arrays created in
# earlier cells keep their 32-bit dtypes, so consider moving it to the top.
jax.config.update("jax_enable_x64", True)
steps=100  # maxiter used by runPCG's cg call

def random_b(key, shape):
    """Random complex array whose real and imaginary parts lie in (0, 1].

    ``key`` is split into two independent subkeys, one per part. The parent key
    is not drawn from directly, because JAX keys must not be reused after splitting.
    """
    key_real, key_imag = jax.random.split(key)
    # 1 - U[0, 1) lies in (0, 1], so neither part is ever exactly zero.
    real_part = 1 - jax.random.uniform(key_real, shape)
    imag_part = 1 - jax.random.uniform(key_imag, shape)
    return real_part + 1j * imag_part


def runPCG(operator, b, precond=None, maxiter=None):
    """Run (preconditioned) CG on ``operator(x) = b`` from a zero initial guess.

    Args:
        operator: callable linear map, passed to ``cg`` as ``A``.
        b: right-hand side. The initial guess is zeros with b's shape and dtype.
        precond: optional preconditioner, passed to ``cg`` as ``M``.
        maxiter: iteration cap. Defaults to the notebook-level ``steps``
            (the value previously always used).

    Returns:
        The ``(solution, info)`` tuple returned by ``cg``.
    """
    if maxiter is None:
        maxiter = steps
    x0 = jnp.zeros_like(b)
    return cg(A=operator, b=b, x0=x0, M=precond, maxiter=maxiter)

# Load the angles and map them to complex phases exp(i * theta).
U1 = jnp.exp(
    1j * np.load('../../../datasets/Dirac/precond_data/config.l8-N200-b2.0-k0.276-unquenched-test.x.npy')
).astype(jnp.complex128)
print(U1.shape)

# Bind the gauge field and kappa so the operator is a function of x alone, as cg expects.
operator = partial(DiracOperator, U1=U1, kappa=0.276)
field_shape = (200, 8, 8, 2)
y = operator(x=jnp.ones(field_shape))  # smoke test: apply the operator once

b = random_b(jax.random.PRNGKey(0), field_shape).astype(jnp.complex128)
# M = random_b(jax.random.PRNGKey(1), (8, 8)).astype(jnp.complex128)

# Unpreconditioned CG. runPCG returns cg's (solution, info) tuple.
x_sol = runPCG(operator, b)
x_sol[0].shape



(200, 2, 8, 8)


(200, 8, 8, 2)

In [None]:
# Test the JAX implementation of the Dirac operator against the reference
# implementation on the same configurations.

from jax.scipy.sparse.linalg import cg
from NeuralPC.utils.dirac import example, exampleJAX
import jax
jax.config.update("jax_enable_x64", True)

config_path = '../../../datasets/Dirac/precond_data/config.l8-N200-b2.0-k0.276-unquenched-test.x.npy'
y_true = example(config_path)
y_jax = exampleJAX(config_path)

# Summarise the discrepancy instead of dumping the full difference array.
abs_err = np.abs(y_true.numpy() - np.asarray(y_jax))
print('max abs difference:', abs_err.max(), '| mean abs difference:', abs_err.mean())
