In [96]:
import numpy as np
import jax
import flax
import flax.linen as nn
import jax.numpy as jnp
from tqdm import tqdm
from jax import jit
import math
from scipy.optimize import fsolve


import jax.example_libraries.optimizers as jax_opt

import matplotlib.pyplot as plt


In [97]:
def init_params_SAMPLE(layer_widths):
    weights = []
    biases = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        weights.append(jnp.array(np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in)))
        biases.append(jnp.array(np.random.normal(size=(n_out,))))
        
    return [weights, biases]

In [98]:
#initializes a set of NN parameters
def init_params(layer_widths):
    weights = []
    biases = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        weights.append(jnp.array(np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in)))
        biases.append(jnp.array(np.random.normal(size=(n_out,))))
        
    return flatten_params([weights, biases])

In [99]:
def get_phi_params(layer_widths):
    tot_params = []
    for i in range(N_UP + N_DOWN):
        tot_params.append(init_params(layer_widths))
    
    return tot_params

In [100]:
@jit
def flatten_params(ps):
    weights, biases = ps
    params = jnp.array([])
    for i in range(len(weights)):
        params = jnp.concatenate((params, weights[i].flatten()))
        params = jnp.concatenate((params, biases[i].flatten()))
    return jnp.array(params)

In [101]:
@jit
def unflatten_params(params):
    weights = []
    biases = []
    start = 0
    for i in range(len(SAMPLE_W)):
        end = start + SAMPLE_W[i].size 
        weights.append(jnp.reshape(jnp.array(params[start:end]), SAMPLE_W[i].shape))
        start = end
        end = start + SAMPLE_B[i].size
        biases.append(jnp.reshape(jnp.array(params[start:end]), SAMPLE_B[i].shape))
        start = end
    return [weights, biases]

In [102]:
@jit
def NN(coords, params):
    weights, biases = unflatten_params(params) 
    a = jnp.array(coords)
    for i in range(len(weights) - 1):
        z = jnp.dot(a, weights[i]) + biases[i]
        a = jax.nn.celu(z)
    a = jnp.dot(a, weights[-1]) + biases[-1]
    return a[0]

In [103]:
def inputs_up(coords, j):
    
    reordered = [coords[j]] + coords[:j] + coords[j+1:] 
    
    
    sym_piece1 = reordered[1:N_UP]
    sym_piece2 = reordered[N_UP:]
    
    num = 2.0
    
    new1 = []
    new2 = []
    for i in range(1, N_UP):
        new1.append(sum((np.array(sym_piece1)/num)**i))
    for i in range(1, N_DOWN+1):
        new2.append(sum((np.array(sym_piece2)/num)**i))
    
    return [reordered[0]] + new1 + new2

In [104]:
def inputs_down(coords, j):
    reordered = [coords[j+N_UP]] + coords[:j+N_UP] + coords[j+N_UP+1:] 
    
    sym_piece1 = reordered[1:N_UP+1]
    sym_piece2 = reordered[N_UP+1:]
    
    num = 2.0
    
    new1 = []
    new2 = []
    for i in range(1, N_UP+1):
        new1.append(sum((np.array(sym_piece1)/num)**i))
    for i in range(1, N_DOWN):
        new2.append(sum((np.array(sym_piece2)/num)**i))
        
    return [reordered[0]] + new1 + new2
    

In [105]:
def PHI_up(coords, params):
    mat = np.zeros((N_UP, N_UP))
    for i in range(N_UP):
        for j in range(N_UP):
            mat[i][j] = NN(inputs_up(coords, j), params[i])
    
    return np.linalg.det(mat)

In [106]:
def PHI_down(coords, params):
    mat = np.zeros((N_DOWN, N_DOWN))
    for i in range(N_DOWN):
        for j in range(N_DOWN):
            mat[i][j] = NN(inputs_down(coords, j), params[i])
    
    return np.linalg.det(mat)

In [107]:
def Psi(coords, params):
    return PHI_up(coords, params)*PHI_down(coords, params)

In [108]:
N_UP = 2
N_DOWN = 1

In [109]:
lss = [N_UP+N_DOWN, 5, 2, 1]

In [110]:
SAMPLE_W, SAMPLE_B = init_params_SAMPLE(lss)

In [111]:
phis = get_phi_params(lss)

In [119]:
print(phis)

[Array([-1.0841054 ,  0.250168  ,  0.95714045,  0.43275496, -1.4159592 ,
        0.72935706,  0.13850223,  0.9647855 ,  0.01983636, -1.4240693 ,
       -0.28382698, -0.02209239, -0.3736783 ,  0.78938675, -1.0608984 ,
        2.102186  ,  0.48607296,  1.7776774 ,  0.4649085 ,  0.72186756,
        1.2012659 , -0.24226609,  0.11984123, -0.7048975 ,  0.4780057 ,
       -0.6237609 ,  0.37379402,  1.7054244 , -0.07227264,  0.25605917,
       -0.15040253, -1.0511962 ,  2.1371129 ,  0.6787612 , -0.21914542],      dtype=float32), Array([-0.24875341, -0.41059425,  0.0235319 ,  0.25725865, -0.3345478 ,
        0.73682964, -0.8733575 ,  0.15866962,  0.3437804 , -0.4928135 ,
        0.02310087,  1.0655915 , -1.4893063 , -1.3438495 , -0.4990888 ,
       -1.6293653 ,  0.01850508,  0.51292956,  0.28775594,  2.1952846 ,
       -0.303544  ,  0.6366342 , -0.7383139 ,  0.9870084 ,  0.50908405,
        1.6146858 ,  0.07765636, -0.02969163, -0.8505752 ,  0.1657593 ,
       -0.01972426,  0.84009546, -1.72876

In [112]:
x1u = 0.1
x2u = 0.2
x1d = 0.3

In [113]:
Psi([x1u, x2u, x1d], phis)

1.307428885699463

In [114]:
Psi([x1u, x1d, x2u], phis) #no symmetry expected

2.426176759000212

In [115]:
Psi([x2u, x1u, x1d], phis) #antisymmetry

-1.30742888569947

In [116]:
Psi([x2u, x1d, x1u], phis) #no symmetry expected

1.090128361079354

In [117]:
Psi([x1d, x1u, x1d], phis) #no symmetry expected

-2.63548694422091

In [121]:
Psi([x1d, x2u, x1u], phis) #no symmetry expected but this is -Psi([x2u, x1d, x1u], phis). Is this ok?

-1.0901283610793557

In [122]:
Psi([x1d, x1u, x2u], phis) 

-2.426176759000219