In [1]:
import tensorflow as tf
import torch as tc
import numpy as np
import os 
from tf_utils.utils import *
from tc_utils.utils import *
import tensorflow.keras as tk

tf.keras.backend.set_floatx('float64')


os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from torch.autograd import grad as tcgrad
npDTYPE = np.float32
tcDTYPE = tc.float32
tfDTYPE = tf.float32

npDTYPE = np.float64
tcDTYPE = tc.float64
tfDTYPE = tf.float64

In [2]:
def convert_to_torch(tensors):
    new_tensors = []
    for tensor in tensors:
        new_tensor = tc.tensor(tensor.numpy(), requires_grad=True)
        new_tensors.append(new_tensor)
    return new_tensors

def tcVariable(tensor):
    return tc.autograd.Variable(tc.tensor(tensor))

def compute_relative_vectors(v1, v2):
    relative_vectors = v1.unsqueeze(2) - v2.unsqueeze(1)
    return relative_vectors

# def compare_tensors(ptorch, tflow):
#     ptorch = ptorch.detach().numpy()
#     tflow = tflow.numpy()
#     mask = np.isclose(ptorch, tflow, rtol=0.0, atol=1e-4).astype(np.float32)
#     return np.mean(mask)

def compare_tensors(ptorch, tflow):
    ptorch = ptorch.detach().numpy()
    tflow = tflow.numpy()
    ptorch = np.reshape(ptorch, tflow.shape)
    x = np.mean(np.abs(ptorch - tflow))
    return x

# def compare_tf_tensors(tflow1, tflow2,rtol=1e-05, atol=1e-08): #
#     tflow1 = tflow1.numpy()
#     tflow2 = tflow2.numpy()
#     mask = np.isclose(tflow1, tflow2).astype(np.float32)
#     return np.mean(mask) 

def compare_tf_tensors(tflow1, tflow2,rtol=1e-05, atol=1e-08): #
    tflow1 = tflow1.numpy()
    tflow2 = tflow2.numpy()
    return np.mean(np.abs(tflow1 - tflow2))

ct = compare_tensors

def tflaplacian(model, r_electrons):
    n_electrons = r_electrons.shape[1]
    r_electrons = tf.reshape(r_electrons, (-1, n_electrons*3))
    r_s = [r_electrons[..., i] for i in range(r_electrons.shape[-1])]
    with tf.GradientTape(True) as g:
        [g.watch(r) for r in r_s]
        r_electrons = tf.stack(r_s, -1)
        r_electrons = tf.reshape(r_electrons, (-1, n_electrons, 3))
        with tf.GradientTape(True) as gg:
            gg.watch(r_electrons)
            log_phi = model(r_electrons)
        dlogphi_dr = gg.gradient(log_phi, r_electrons)
        dlogphi_dr = tf.reshape(dlogphi_dr, (-1, n_electrons*3))
        grads = [dlogphi_dr[..., i] for i in range(dlogphi_dr.shape[-1])]
    d2logphi_dr2 = tf.stack([g.gradient(grad, r) for grad, r in zip(grads, r_s)], -1)
    return dlogphi_dr**2, d2logphi_dr2


In [3]:
# model fns

def linear(w, a, n_samples, n_conv, flow=True):
    if flow:
        a = tf.concat((a, tf.ones((n_samples, n_conv, 1), dtype=tfDTYPE)), axis=-1)
        out = tf.tanh(a @ w)
        return out
    else:
        a = tc.cat((a, tc.ones((n_samples, n_conv, 1), dtype=tcDTYPE)), dim=-1)
        out = tc.tanh(a @ w)
        return out
    

def env(inputs, ae_vecs, env_w, env_sigma, env_pi, n_samples, n_spins, flow=True):
    if flow:
        inputs = tf.concat((inputs, tf.ones((n_samples, n_spins, 1), dtype=tfDTYPE)), axis=-1)
        factor = tf.einsum('njf,kifs->njkis', inputs, env_w)
        
        exponent = tf.einsum('njmv,kimvc->njkimc', ae_vecs, env_sigma)
        exponential = tf.exp(-tf.norm(exponent, axis=-1))

        exp = tf.einsum('njkim,kims->njkis', exponential, env_pi)

        output = factor * exp
        output = tf.transpose(output, perm=(0, 2, 3, 1, 4))  # ij ordering doesn't matter / slight numerical diff

        return tf.squeeze(output, -1)  
        
    else:
        inputs = tc.cat((inputs, tc.ones((n_samples, n_spins, 1), dtype=tcDTYPE)), dim=-1)
        factor = tc.einsum('njf,kifs->njkis', inputs, env_w)
        
        exponent = tc.einsum('njmv,kimvc->njkimc', ae_vecs, env_sigma)
        exponential = tc.exp(-tc.norm(exponent, dim=-1))

        exp = tc.einsum('njkim,kims->njkis', exponential, env_pi)

        output = factor * exp
        output = output.permute((0, 2, 3, 1, 4))  # ij ordering doesn't matter / slight numerical diff

        return output.squeeze(-1)  
    
def compute_inputs(r_electrons, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise, flow=True):
    # r_atoms: (n_atoms, 3)
    # r_electrons: (n_samples, n_electrons, 3)
    # ae_vectors: (n_samples, n_electrons, n_atoms, 3)
    if flow:
        ae_distances = tf.norm(ae_vectors, axis=-1, keepdims=True)
        single_inputs = tf.concat((ae_vectors, ae_distances), axis=-1)
        single_inputs = tf.reshape(single_inputs, (-1, n_electrons, 4*n_atoms))

        re1 = tf.expand_dims(r_electrons, 2)
        re2 = tf.transpose(re1, perm=(0, 2, 1, 3))
        ee_vectors = re1 - re2

        # ** full pairwise
        if full_pairwise:
            # eye_mask = tf.expand_dims(tf.expand_dims(tf.eye(n_electrons, dtype=tf.bool), 0), -1)
            # tmp = tf.where(eye_mask, 1., tf.norm(ee_vectors, keepdims=True, axis=-1))
            # ee_distances = tf.where(eye_mask, tf.zeros_like(eye_mask, dtype=tf.float32), tmp)
            ee_vectors = tf.reshape(ee_vectors, (-1, n_electrons**2, 3))
            ee_distances = safe_norm(ee_vectors)
            pairwise_inputs = tf.concat((ee_vectors, ee_distances), axis=-1)
            # pairwise_inputs = tf.reshape(pairwise_inputs, (-1, n_electrons**2, 4))
        else:
            # ** partial pairwise
            mask = tf.eye(n_electrons, dtype=tf.bool)
            mask = ~tf.tile(tf.expand_dims(tf.expand_dims(mask, 0), 3), (n_samples, 1, 1, 3))

            ee_vectors = tf.boolean_mask(ee_vectors, mask)
            ee_vectors = tf.reshape(ee_vectors, (-1, n_electrons**2 - n_electrons, 3))
            ee_distances = tf.norm(ee_vectors, axis=-1, keepdims=True)

            pairwise_inputs = tf.concat((ee_vectors, ee_distances), axis=-1)

        return single_inputs, pairwise_inputs
    else:
        ae_distances = tc.norm(ae_vectors, dim=-1, keepdim=True)
        single_inputs = tc.cat((ae_vectors, ae_distances), dim=-1)
        single_inputs = single_inputs.view((-1, n_electrons, 4 * n_atoms))

        re1 = r_electrons.unsqueeze(2)
        re2 = re1.permute((0, 2, 1, 3))
        ee_vectors = re1 - re2

        # ** full pairwise
        if full_pairwise:
            # eye_mask = tf.expand_dims(tf.expand_dims(tf.eye(n_electrons, dtype=tf.bool), 0), -1)
            # tmp = tf.where(eye_mask, 1., tf.norm(ee_vectors, keepdims=True, axis=-1))
            # ee_distances = tf.where(eye_mask, tf.zeros_like(eye_mask, dtype=tf.float32), tmp)
            ee_vectors = ee_vectors.view((-1, n_electrons ** 2, 3))
            ee_distances = tc.norm(ee_vectors, dim=-1, keepdim=True) + 1e-16
            pairwise_inputs = tc.cat((ee_vectors, ee_distances), dim=-1)
            # pairwise_inputs = tf.reshape(pairwise_inputs, (-1, n_electrons**2, 4))
        else:
            print('NOT IMPLEMENTED')

        return single_inputs, pairwise_inputs
    
def compute_ae_vectors(r_atoms, r_electrons, flow):
    # ae_vectors (n_samples, n_electrons, n_atoms, 3)
    if flow:
        r_atoms = tf.expand_dims(r_atoms, 1)
        r_electrons = tf.expand_dims(r_electrons, 2)
        ae_vectors = r_electrons - r_atoms
    else:
        r_atoms = r_atoms.unsqueeze(1)
        r_electrons = r_electrons.unsqueeze(2)
        ae_vectors = r_electrons - r_atoms
    return ae_vectors


def init(in_dim, weight_shape, out_dim, env_init=0.0):
    if env_init == 0.0:
        minval = np.maximum(-1., -(6/(in_dim+out_dim))**0.5)
        maxval = np.minimum(1., (6/(in_dim+out_dim))**0.5)
        weights = np.random.uniform(size=weight_shape, low=minval, high=maxval)
    else:
        weights = np.random.uniform(size=weight_shape, low=-env_init, high=env_init)
    return weights.astype(npDTYPE)

def apbi(w, axis=None):
    if axis is None:
        shape = (1, *w.shape[1:])
        b = np.ones(shape).astype(npDTYPE)
        return np.concatenate((w, b), axis=0)
    shape = (*w.shape[:axis], 1, *w.shape[axis+1:])
    b = np.ones(shape).astype(npDTYPE)
    return np.concatenate((w, b), axis=axis)


# model architecture
full_pairwise = True
n_atoms = 1
n_electrons = 4
n_spin_up = 2
n_spin_down = n_electrons - n_spin_up
n_pairwise = n_electrons**2
if not full_pairwise:
    n_pairwise -= n_electrons

nf_single_in = 4 * n_atoms
nf_hidden_single = 128
nf_pairwise_in = 4
nf_hidden_pairwise = 16
nf_intermediate_single = 3*nf_hidden_single + 2*nf_hidden_pairwise

n_determinants = 8

env_init = 1.

n_samples = 100

# params # in_dim, weight_shape, out_dim, env_init=0.0
s_stream_0 = apbi(init(nf_single_in, (nf_single_in, nf_hidden_single), nf_hidden_single))
p_stream_0 = apbi(init(nf_pairwise_in, (nf_pairwise_in, nf_hidden_pairwise), nf_hidden_pairwise))

s_stream_1 = apbi(init(nf_intermediate_single, (nf_intermediate_single, nf_hidden_single), nf_hidden_single))
p_stream_1 = apbi(init(nf_hidden_pairwise, (nf_hidden_pairwise, nf_hidden_pairwise), nf_hidden_pairwise))

s_stream_2 = apbi(init(nf_intermediate_single, (nf_intermediate_single, nf_hidden_single), nf_hidden_single))

env_w_up = apbi(init(nf_hidden_single, (n_determinants, n_spin_up, nf_hidden_single, 1), 1), axis=2)
env_sigma_up  = init(3, (n_determinants, n_spin_up, n_atoms, 3, 3), 3, env_init=env_init)
env_pi_up = init(n_atoms, (n_determinants, n_spin_up, n_atoms, 1), 1, env_init=env_init)

env_w_down = apbi(init(nf_hidden_single, (n_determinants, n_spin_down, nf_hidden_single, 1), 1), axis=2)
env_sigma_down = init(3, (n_determinants, n_spin_down, n_atoms, 3, 3), 3, env_init=env_init)
env_pi_down = init(n_atoms, (n_determinants, n_spin_down, n_atoms, 1), 1, env_init=env_init)

w_final = init(n_determinants, (1, n_determinants, 1, 1), 1, env_init=env_init/n_determinants)

In [4]:

@tf.custom_gradient
def safe_norm_grad(x, norm):
    # x : (n, ne**2, 3)
    # norm : (n, ne**2, 1)
    g = x / norm
    g = tf.where(tf.math.is_nan(g), tf.zeros_like(g), g)
    cache = (x, norm)

    def grad_grad(dy):
        x, norm = cache
        x = tf.expand_dims(x, -1)  # (n, ne**2, 3, 1)
        xx = x * tf.transpose(x, perm=(0, 1, 3, 2))  # cross terms
        inv_norm = tf.tile(1. / norm, (1, 1, 3))  # (n, ne**2, 3) inf where the ee terms are same e
        norm_diag = tf.linalg.diag(inv_norm) # (n, ne**2, 3, 3) # diagonal where the basis vector is the same
        gg = norm_diag - xx / tf.expand_dims(norm, -1)**3
        gg = tf.reduce_sum(gg, axis=-1)
        gg = tf.where(tf.math.is_nan(gg), tf.zeros_like(gg), gg)
        tf.debugging.check_numerics(gg, 'gg')
        tf.debugging.check_numerics(dy, 'dy')
        return dy*gg, None

    return g, grad_grad

@tf.custom_gradient
def safe_norm(x):
    norm = tf.norm(x, keepdims=True, axis=-1)
    def grad(dy):
        g = safe_norm_grad(x, norm)
        return dy*g
    return norm, grad

class Mixer(tk.Model):
    def __init__(self, n_electrons, n_single_features, n_pairwise, n_pairwise_features, n_spin_up, n_spin_down, full_pairwise):
        super(Mixer, self).__init__()

        self.n_spin_up = float(n_spin_up)
        self.n_spin_down = float(n_spin_down)

        tmp1 = tf.ones((1, n_spin_up, n_single_features), dtype=tf.bool)
        tmp2 = tf.zeros((1, n_spin_down, n_single_features), dtype=tf.bool)
        self.spin_up_mask = tf.concat((tmp1, tmp2), 1)
        self.spin_down_mask = ~self.spin_up_mask

        if full_pairwise:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                generate_pairwise_masks_full(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)
        else:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                generate_pairwise_masks(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)

    # @tf.function
    def call(self, single, pairwise, n_samples, n_electrons):
        # single (n_samples, n_electrons, n_single_features)
        # pairwise (n_samples, n_electrons, n_pairwise_features)
        spin_up_mask = tf.tile(self.spin_up_mask, (n_samples, 1, 1))
        spin_down_mask = tf.tile(self.spin_down_mask, (n_samples, 1, 1))

        # --- Single summations
        replace = tf.zeros_like(single, dtype=tfDTYPE)
        # up
        sum_spin_up = tf.where(spin_up_mask, single, replace)
        sum_spin_up = tf.reduce_sum(sum_spin_up, 1, keepdims=True) / self.n_spin_up
        sum_spin_up = tf.tile(sum_spin_up, (1, n_electrons, 1))
        # down
        sum_spin_down = tf.where(spin_down_mask, single, replace)
        sum_spin_down = tf.reduce_sum(sum_spin_down, 1, keepdims=True) / self.n_spin_down
        sum_spin_down = tf.tile(sum_spin_down, (1, n_electrons, 1))

        # --- Pairwise summations
        sum_pairwise = tf.tile(tf.expand_dims(pairwise, 1), (1, n_electrons, 1, 1))
        replace = tf.zeros_like(sum_pairwise, dtype=tfDTYPE)
        # up
        sum_pairwise_up = tf.where(self.pairwise_spin_up_mask, sum_pairwise, replace)
        sum_pairwise_up = tf.reduce_sum(sum_pairwise_up, 2) / self.n_spin_up
        # down
        sum_pairwise_down = tf.where(self.pairwise_spin_down_mask, sum_pairwise, replace)
        sum_pairwise_down = tf.reduce_sum(sum_pairwise_down, 2) / self.n_spin_down

        features = tf.concat((single, sum_spin_up, sum_spin_down, sum_pairwise_up, sum_pairwise_down), 2)
        return features
    

class fermi_tf_custom():
    def __init__(self, r_atoms):
        self.r_atoms = r_atoms
        
        self.s0 = tf.Variable(s_stream_0)
        self.p0 = tf.Variable(p_stream_0)
        self.m0 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s1 = tf.Variable(s_stream_1)
        self.p1 = tf.Variable(p_stream_1)
        self.m1 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s2 = tf.Variable(s_stream_2)
        
        self.env_w_up = tf.Variable(env_w_up)
        self.env_sigma_up = tf.Variable(env_sigma_up)
        self.env_pi_up = tf.Variable(env_pi_up)
        
        self.env_w_down = tf.Variable(env_w_down)
        self.env_sigma_down = tf.Variable(env_sigma_down)
        self.env_pi_down = tf.Variable(env_pi_down)
        
        self.w_final = tf.Variable(w_final)
        
    def __call__(self, samples):
        n_samples = samples.shape[0]
        ae_vectors = compute_ae_vectors(self.r_atoms, samples, flow=True)
        single, pairwise = compute_inputs(samples, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise)
        
        # streams 
        s0 = linear(self.s0, single, n_samples, n_electrons, flow=True)
        p0 = linear(self.p0, pairwise, n_samples, n_pairwise, flow=True)
        s0m = self.m0(s0, p0, n_samples, n_electrons)
        
        s1 = linear(self.s1, s0m, n_samples, n_electrons, flow=True)
        p1 = linear(self.p1, p0, n_samples, n_pairwise, flow=True)
        s1m = self.m1(s1, p1, n_samples, n_electrons)
        
        s2 = linear(self.s2, s1m, n_samples, n_electrons, flow=True)
        
        # env inputs
        ae_vectors_up, ae_vectors_down = tf.split(ae_vectors, [n_spin_up, n_spin_down], axis=1)
        inputs_up, inputs_down = tf.split(s2, [n_spin_up, n_spin_down], axis=1)
        
        # env_w 'njf,kifs->nkjis'
        # env_sigma 'njmv,kimvc->nkjimc'
        # env_pi 'njkim,kims->nkjis'
        up_dets = env(inputs_up, ae_vectors_up, self.env_w_up, self.env_sigma_up, self.env_pi_up, n_samples, n_spin_up)
        down_dets = env(inputs_down, ae_vectors_down, self.env_w_down, self.env_sigma_down, self.env_pi_down, n_samples, n_spin_down)
        
        log_psi, _, _, _ = log_abs_sum_det(up_dets, down_dets, self.w_final)
        
        return log_psi
    
    
class fermi_tf():
    def __init__(self, r_atoms):
        self.r_atoms = r_atoms
        
        self.s0 = tf.Variable(s_stream_0)
        self.p0 = tf.Variable(p_stream_0)
        self.m0 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s1 = tf.Variable(s_stream_1)
        self.p1 = tf.Variable(p_stream_1)
        self.m1 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s2 = tf.Variable(s_stream_2)
        
        self.env_w_up = tf.Variable(env_w_up)
        self.env_sigma_up = tf.Variable(env_sigma_up)
        self.env_pi_up = tf.Variable(env_pi_up)
        
        self.env_w_down = tf.Variable(env_w_down)
        self.env_sigma_down = tf.Variable(env_sigma_down)
        self.env_pi_down = tf.Variable(env_pi_down)
        
        self.w_final = tf.Variable(w_final)
        
    def __call__(self, samples):
        n_samples = samples.shape[0]
        ae_vectors = compute_ae_vectors(self.r_atoms, samples, flow=True)
        single, pairwise = compute_inputs(samples, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise)
        
        # streams 
        s0 = linear(self.s0, single, n_samples, n_electrons, flow=True)
        p0 = linear(self.p0, pairwise, n_samples, n_pairwise, flow=True)
        s0m = self.m0(s0, p0, n_samples, n_electrons)
        
        s1 = linear(self.s1, s0m, n_samples, n_electrons, flow=True)
        p1 = linear(self.p1, p0, n_samples, n_pairwise, flow=True)
        s1m = self.m1(s1, p1, n_samples, n_electrons)
        
        s2 = linear(self.s2, s1m, n_samples, n_electrons, flow=True)
        
        # env inputs
        ae_vectors_up, ae_vectors_down = tf.split(ae_vectors, [n_spin_up, n_spin_down], axis=1)
        inputs_up, inputs_down = tf.split(s2, [n_spin_up, n_spin_down], axis=1)
        
        # env_w 'njf,kifs->nkjis'
        # env_sigma 'njmv,kimvc->nkjimc'
        # env_pi 'njkim,kims->nkjis'
        up_dets = env(inputs_up, ae_vectors_up, self.env_w_up, self.env_sigma_up, self.env_pi_up, n_samples, n_spin_up)
        down_dets = env(inputs_down, ae_vectors_down, self.env_w_down, self.env_sigma_down, self.env_pi_down, n_samples, n_spin_down)
        print('\n')
        up_dets = tf.expand_dims(tf.expand_dims(tf.linalg.det(up_dets), -1), -1)
        print(up_dets.shape, self.w_final.shape, (self.w_final * up_dets).shape)
        down_dets = tf.expand_dims(tf.expand_dims(tf.linalg.det(down_dets), -1), -1)
        
#         up_dets = tf.linalg.det(up_dets)
#         down_dets = tf.linalg.det(down_dets)
        tmp = self.w_final * up_dets * down_dets 
        print(tmp.shape)
        log_psi = tf.math.log(tf.abs(tf.reduce_sum(tmp, axis=1)))
        print('\n')
        return log_psi
    

def slogdet_keepdim(tensor):
    sign, tensor_out = tf.linalg.slogdet(tensor)
    tensor_out = tf.reshape(tensor_out, (*tensor_out.shape, 1, 1))
    sign = tf.reshape(sign, (*sign.shape, 1, 1))
    return sign, tensor_out


def generate_gamma(s):
    n_egvs = s.shape[2]
    gamma = [tf.reduce_prod(s[:, :, :i], axis=-1) * tf.reduce_prod(s[:, :, i+1:], axis=-1) for i in range(n_egvs-1)]
    gamma.append(tf.reduce_prod(s[:, :, :-1], axis=-1))
    gamma = tf.stack(gamma, axis=2)
    gamma = tf.expand_dims(gamma, axis=2)
    return gamma


def first_derivative_det(A):
    with tf.device("/cpu:0"):  # this is incredible stupid /// its actually not
        s, u, v = tf.linalg.svd(A, full_matrices=False)
    v_t = tf.linalg.matrix_transpose(v)
    gamma = generate_gamma(s)
    sign = (tf.linalg.det(u) * tf.linalg.det(v))[..., None, None]
    out = sign * ((u * gamma) @ v_t)
    return out, (s, u, v_t, sign)


def generate_p(s):
    n_samples, n_k, n_dim = s.shape
    new_shape = (1, 1, 1, n_dim, n_dim)
    s = s[..., None, None]
    s = tf.tile(s, new_shape)
    mask = np.ones(s.shape, dtype=np.bool)
    for i in range(n_dim):
        for j in range(n_dim):
            mask[..., i, i, j] = False
            mask[..., j, i, j] = False
    mask = tf.convert_to_tensor(mask)
    s = tf.where(mask, s, tf.ones_like(s, dtype=s.dtype))
    s_prod = tf.reduce_prod(s, axis=-3)
    s_prod = tf.linalg.set_diag(s_prod, tf.zeros((s_prod.shape[:-1]), dtype=s.dtype))
    return s_prod


def second_derivative_det(A, C_dash, *A_cache):

    s, u, v_t, sign = A_cache  # decompose the cache

    M = v_t @ tf.linalg.matrix_transpose(C_dash) @ u

    p = generate_p(s)

    sgn = tf.math.sign(sign)

    m_jj = tf.linalg.diag_part(M)
    xi = -M * p

    xi_diag = p @ tf.expand_dims(m_jj, -1)
    xi = tf.linalg.set_diag(xi, tf.squeeze(xi_diag, -1))
    return sgn * u @ xi @ v_t


def k_sum(x):
    return tf.reduce_sum(x, axis=1, keepdims=True)


def matrix_sum(x):
    return tf.reduce_sum(x, axis=[-2, -1], keepdims=True)


def _log_abs_sum_det_second_order(a_dash, b_dash, w_dash, *cache):

    a, b, w, unshifted_exp, sign_unshifted_sum, sign_a, logdet_a, sign_b, logdet_b, log_psi, \
    sign_u, ddeta, ddeta_cache, ddetb, ddetb_cache, dfddeta, dfddetb, da, db, dw = cache

    dfddeta_w = dfddeta / w
    dfddetb_w = dfddetb / w

    ddeta_sum = matrix_sum(a_dash * ddeta)
    da_sum = matrix_sum(da * a_dash)
    ddetb_sum = matrix_sum(b_dash * ddetb)
    db_sum = matrix_sum(db * b_dash)
    a_sum = k_sum(dfddeta * ddeta_sum)
    b_sum = k_sum(dfddetb * ddetb_sum)

    # Compute second deriviate of f wrt to w
    d2w = w_dash * -dw * k_sum(dw)

    # compute deriviate of df/da wrt to w
    dadw = -dw * k_sum(da_sum)
    dadw += dfddeta_w * ddeta_sum  # i=j

    # compute derivative of df/db wrt to w
    dbdw = -dw * k_sum(db_sum)
    dbdw += dfddetb_w * ddetb_sum  # i=j

    # Compute second derivative of f wrt to a
    d2a = -da * a_sum
    d2a += dfddeta * second_derivative_det(a, a_dash, *ddeta_cache)  # i=j
    # Compute derivative of df/db wrt to a
    dbda = -da * b_sum
    dbda += ddeta * sign_u * tf.exp(-log_psi) * w * ddetb_sum  # i=j
    # Compute derivative of df/dw wrt to a
    dwda = w_dash * -da * k_sum(dw)
    dwda += w_dash * da / w  # i=j

    # Compute second derivative of f wrt to b
    d2b = -db * b_sum
    d2b += dfddetb * second_derivative_det(b, b_dash, *ddetb_cache)  # i=j
    # Compute derivative of df/da wrt to b
    dadb = -db * a_sum
    dadb += ddetb * sign_u * tf.exp(-log_psi) * w * ddeta_sum  # i=j
    # Compute derivative of df/dw wrt to b
    dwdb = w_dash * -db * k_sum(dw)
    dwdb += w_dash * db / w  # i=j

    return (d2a + dbda + dwda), (d2b + dadb + dwdb), \
           (d2w + dadw + dbdw), \
           None, None, None, None, None, None, None, None, None, None

def _log_abs_sum_det_fwd(a, b, w):

    # Take the slogdet of all k determinants
    sign_a, logdet_a = slogdet_keepdim(a)
    sign_b, logdet_b = slogdet_keepdim(b)

    x = logdet_a + logdet_b
    xmax = tf.math.reduce_max(x, axis=1, keepdims=True)

    unshifted_exp = sign_a * sign_b * tf.exp(x)
    unshifted_exp_w = w * unshifted_exp
    sign_unshifted_sum = tf.math.sign(tf.reduce_sum(unshifted_exp_w, axis=1, keepdims=True))

    exponent = x - xmax
    shifted_exp = sign_a * sign_b * tf.exp(exponent)

    u = w * shifted_exp
    u_sum = tf.reduce_sum(u, axis=1, keepdims=True)
    sign_shifted_sum = tf.math.sign(u_sum)
    log_psi = tf.math.log(tf.math.abs(u_sum)) + xmax

    sensitivities = tf.exp(-log_psi) * sign_unshifted_sum

    return log_psi, sign_unshifted_sum, unshifted_exp, sensitivities, \
           (a, b, w, unshifted_exp, sign_unshifted_sum, sign_a, logdet_a, sign_b, logdet_b, log_psi)


def _log_abs_sum_det_first_order(*fwd_cache):

    a, b, w, unshifted_exp, sign_unshifted_sum, sign_a, logdet_a, sign_b, logdet_b, log_psi = fwd_cache

    ddeta, ddeta_cache = first_derivative_det(a)
    ddetb, ddetb_cache = first_derivative_det(b)

    dfddeta = w * sign_unshifted_sum * sign_b * tf.exp(logdet_b - log_psi)
    dfddetb = w * sign_unshifted_sum * sign_a * tf.exp(logdet_a - log_psi)

    da = dfddeta * ddeta
    db = dfddetb * ddetb
    dw = sign_unshifted_sum * unshifted_exp * tf.exp(-log_psi)
    
    return (da, db, dw), (sign_unshifted_sum, ddeta, ddeta_cache, ddetb, ddetb_cache, dfddeta, dfddetb, da, db, dw)


@tf.custom_gradient
def first_order_gradient(a_unused, b_unused, w_unused, *fwd_cache):

    (da, db, dw), first_order_cache = _log_abs_sum_det_first_order(*fwd_cache)
    return (da, db, dw), \
           lambda a_dash, b_dash, w_dash: _log_abs_sum_det_second_order(
               a_dash, b_dash, w_dash, *fwd_cache, *first_order_cache)

@tf.custom_gradient
def log_abs_sum_det(a, b, w):

    log_psi, sign, act, sens, fwd_cache = _log_abs_sum_det_fwd(a, b, w)

    def _first_order_grad(dy, dsg, _, __):
        da, db, dw = first_order_gradient(a, b, w, *fwd_cache)

        return dy * da, dy * db, tf.reduce_sum(dy * dw, axis=0, keepdims=True)

    return (log_psi, sign, act, sens), _first_order_grad






def ps(ls):
    for t, n in ls:
        print(n, t.shape)
        
    

In [5]:
samples = np.random.normal(size=(n_samples, n_electrons, 3)).astype(npDTYPE)
r_atoms = np.zeros(shape=(1,3)).astype(npDTYPE)

r_atoms = tf.convert_to_tensor(r_atoms, dtype=tfDTYPE)
samples = tf.convert_to_tensor(samples, dtype=tfDTYPE)

ftf_custom = fermi_tf_custom(r_atoms)
psi = ftf_custom(samples)

ftf = fermi_tf(r_atoms)
tcpsi = ftf(samples)

print(tcpsi.shape, psi.shape)
print(compare_tf_tensors(tf.squeeze(tcpsi), tf.squeeze(psi)))

# for a, b in zip(psi, tcpsi):
#     print(a.numpy(), b.numpy())


Instructions for updating:
Use tf.identity instead.


(100, 8, 1, 1) (1, 8, 1, 1) (100, 8, 1, 1)
(100, 8, 1, 1)


(100, 1, 1) (100, 1, 1, 1)
3.730349362740526e-16


In [6]:
g_custom, gg_custom = tflaplacian(ftf_custom, samples)

g, gg = tflaplacian(ftf, samples)

print(compare_tf_tensors(tf.squeeze(g_custom), tf.squeeze(g)))
print(compare_tf_tensors(tf.squeeze(gg_custom), tf.squeeze(gg)))





(100, 8, 1, 1) (1, 8, 1, 1) (100, 8, 1, 1)
(100, 8, 1, 1)


4.039320372641201e-14
87.1913604402419


In [7]:
tc.diag(tc.tensor([[1,2], [4, 5]]))

tensor([1, 5])