In [1]:
import tensorflow as tf
import torch as tc
import numpy as np
import os 
from tf_utils.utils import *
from tc_utils.utils import *
import tensorflow.keras as tk

tf.keras.backend.set_floatx('float64')


os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from torch.autograd import grad as tcgrad
npDTYPE = np.float32
tcDTYPE = tc.float32
tfDTYPE = tf.float32

npDTYPE = np.float64
tcDTYPE = tc.float64
tfDTYPE = tf.float64

In [2]:
def convert_to_torch(tensors):
    new_tensors = []
    for tensor in tensors:
        new_tensor = tc.tensor(tensor.numpy(), requires_grad=True)
        new_tensors.append(new_tensor)
    return new_tensors

def tcVariable(tensor):
    return tc.autograd.Variable(tc.tensor(tensor))

def compute_relative_vectors(v1, v2):
    relative_vectors = v1.unsqueeze(2) - v2.unsqueeze(1)
    return relative_vectors

# def compare_tensors(ptorch, tflow):
#     ptorch = ptorch.detach().numpy()
#     tflow = tflow.numpy()
#     mask = np.isclose(ptorch, tflow, rtol=0.0, atol=1e-4).astype(np.float32)
#     return np.mean(mask)

def compare_tensors(ptorch, tflow):
    ptorch = ptorch.detach().numpy()
    tflow = tflow.numpy()
    ptorch = np.reshape(ptorch, tflow.shape)
    x = np.mean(np.abs(ptorch - tflow))
    return x

def compare_tf_tensors(tflow1, tflow2,rtol=1e-05, atol=1e-08): #
    tflow1 = tflow1.numpy()
    tflow2 = tflow2.numpy()
    mask = np.isclose(tflow1, tflow2).astype(np.float32)
    return np.mean(mask) 



ct = compare_tensors

In [3]:
# model fns

def linear(w, a, n_samples, n_conv, flow=True):
    if flow:
        a = tf.concat((a, tf.ones((n_samples, n_conv, 1), dtype=tfDTYPE)), axis=-1)
        out = tf.tanh(a @ w)
        return out
    else:
        a = tc.cat((a, tc.ones((n_samples, n_conv, 1), dtype=tcDTYPE)), dim=-1)
        out = tc.tanh(a @ w)
        return out
    

def env(inputs, ae_vecs, env_w, env_sigma, env_pi, n_samples, n_spins, flow=True):
    if flow:
        inputs = tf.concat((inputs, tf.ones((n_samples, n_spins, 1), dtype=tfDTYPE)), axis=-1)
        factor = tf.einsum('njf,kifs->njkis', inputs, env_w)
        
        exponent = tf.einsum('njmv,kimvc->njkimc', ae_vecs, env_sigma)
        exponential = tf.exp(-tf.norm(exponent, axis=-1))

        exp = tf.einsum('njkim,kims->njkis', exponential, env_pi)

        output = factor * exp
        output = tf.transpose(output, perm=(0, 2, 3, 1, 4))  # ij ordering doesn't matter / slight numerical diff

        return tf.squeeze(output, -1)  
        
    else:
        inputs = tc.cat((inputs, tc.ones((n_samples, n_spins, 1), dtype=tcDTYPE)), dim=-1)
        factor = tc.einsum('njf,kifs->njkis', inputs, env_w)
        
        exponent = tc.einsum('njmv,kimvc->njkimc', ae_vecs, env_sigma)
        exponential = tc.exp(-tc.norm(exponent, dim=-1))

        exp = tc.einsum('njkim,kims->njkis', exponential, env_pi)

        output = factor * exp
        output = output.permute((0, 2, 3, 1, 4))  # ij ordering doesn't matter / slight numerical diff

        return output.squeeze(-1)  
    
def compute_inputs(r_electrons, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise, flow=True):
    # r_atoms: (n_atoms, 3)
    # r_electrons: (n_samples, n_electrons, 3)
    # ae_vectors: (n_samples, n_electrons, n_atoms, 3)
    if flow:
        ae_distances = tf.norm(ae_vectors, axis=-1, keepdims=True)
        single_inputs = tf.concat((ae_vectors, ae_distances), axis=-1)
        single_inputs = tf.reshape(single_inputs, (-1, n_electrons, 4*n_atoms))

        re1 = tf.expand_dims(r_electrons, 2)
        re2 = tf.transpose(re1, perm=(0, 2, 1, 3))
        ee_vectors = re1 - re2

        # ** full pairwise
        if full_pairwise:
            # eye_mask = tf.expand_dims(tf.expand_dims(tf.eye(n_electrons, dtype=tf.bool), 0), -1)
            # tmp = tf.where(eye_mask, 1., tf.norm(ee_vectors, keepdims=True, axis=-1))
            # ee_distances = tf.where(eye_mask, tf.zeros_like(eye_mask, dtype=tf.float32), tmp)
            ee_vectors = tf.reshape(ee_vectors, (-1, n_electrons**2, 3))
            ee_distances = safe_norm(ee_vectors)
            pairwise_inputs = tf.concat((ee_vectors, ee_distances), axis=-1)
            # pairwise_inputs = tf.reshape(pairwise_inputs, (-1, n_electrons**2, 4))
        else:
            # ** partial pairwise
            mask = tf.eye(n_electrons, dtype=tf.bool)
            mask = ~tf.tile(tf.expand_dims(tf.expand_dims(mask, 0), 3), (n_samples, 1, 1, 3))

            ee_vectors = tf.boolean_mask(ee_vectors, mask)
            ee_vectors = tf.reshape(ee_vectors, (-1, n_electrons**2 - n_electrons, 3))
            ee_distances = tf.norm(ee_vectors, axis=-1, keepdims=True)

            pairwise_inputs = tf.concat((ee_vectors, ee_distances), axis=-1)

        return single_inputs, pairwise_inputs
    else:
        ae_distances = tc.norm(ae_vectors, dim=-1, keepdim=True)
        single_inputs = tc.cat((ae_vectors, ae_distances), dim=-1)
        single_inputs = single_inputs.view((-1, n_electrons, 4 * n_atoms))

        re1 = r_electrons.unsqueeze(2)
        re2 = re1.permute((0, 2, 1, 3))
        ee_vectors = re1 - re2

        # ** full pairwise
        if full_pairwise:
            # eye_mask = tf.expand_dims(tf.expand_dims(tf.eye(n_electrons, dtype=tf.bool), 0), -1)
            # tmp = tf.where(eye_mask, 1., tf.norm(ee_vectors, keepdims=True, axis=-1))
            # ee_distances = tf.where(eye_mask, tf.zeros_like(eye_mask, dtype=tf.float32), tmp)
            ee_vectors = ee_vectors.view((-1, n_electrons ** 2, 3))
            ee_distances = tc.norm(ee_vectors, dim=-1, keepdim=True) + 1e-16
            pairwise_inputs = tc.cat((ee_vectors, ee_distances), dim=-1)
            # pairwise_inputs = tf.reshape(pairwise_inputs, (-1, n_electrons**2, 4))
        else:
            print('NOT IMPLEMENTED')

        return single_inputs, pairwise_inputs
    
def compute_ae_vectors(r_atoms, r_electrons, flow):
    # ae_vectors (n_samples, n_electrons, n_atoms, 3)
    if flow:
        r_atoms = tf.expand_dims(r_atoms, 1)
        r_electrons = tf.expand_dims(r_electrons, 2)
        ae_vectors = r_electrons - r_atoms
    else:
        r_atoms = r_atoms.unsqueeze(1)
        r_electrons = r_electrons.unsqueeze(2)
        ae_vectors = r_electrons - r_atoms
    return ae_vectors


def init(in_dim, weight_shape, out_dim, env_init=0.0):
    if env_init == 0.0:
        minval = np.maximum(-1., -(6/(in_dim+out_dim))**0.5)
        maxval = np.minimum(1., (6/(in_dim+out_dim))**0.5)
        weights = np.random.uniform(size=weight_shape, low=minval, high=maxval)
    else:
        weights = np.random.uniform(size=weight_shape, low=-env_init, high=env_init)
    return weights.astype(npDTYPE)

def apbi(w, axis=None):
    if axis is None:
        shape = (1, *w.shape[1:])
        b = np.ones(shape).astype(npDTYPE)
        return np.concatenate((w, b), axis=0)
    shape = (*w.shape[:axis], 1, *w.shape[axis+1:])
    b = np.ones(shape).astype(npDTYPE)
    return np.concatenate((w, b), axis=axis)


# model architecture
full_pairwise = True
n_atoms = 1
n_electrons = 4
n_spin_up = 2
n_spin_down = n_electrons - n_spin_up
n_pairwise = n_electrons**2
if not full_pairwise:
    n_pairwise -= n_electrons

nf_single_in = 4 * n_atoms
nf_hidden_single = 128
nf_pairwise_in = 4
nf_hidden_pairwise = 16
nf_intermediate_single = 3*nf_hidden_single + 2*nf_hidden_pairwise

n_determinants = 8

env_init = 1.

n_samples = 100

# params # in_dim, weight_shape, out_dim, env_init=0.0
s_stream_0 = apbi(init(nf_single_in, (nf_single_in, nf_hidden_single), nf_hidden_single))
p_stream_0 = apbi(init(nf_pairwise_in, (nf_pairwise_in, nf_hidden_pairwise), nf_hidden_pairwise))

s_stream_1 = apbi(init(nf_intermediate_single, (nf_intermediate_single, nf_hidden_single), nf_hidden_single))
p_stream_1 = apbi(init(nf_hidden_pairwise, (nf_hidden_pairwise, nf_hidden_pairwise), nf_hidden_pairwise))

s_stream_2 = apbi(init(nf_intermediate_single, (nf_intermediate_single, nf_hidden_single), nf_hidden_single))

env_w_up = apbi(init(nf_hidden_single, (n_determinants, n_spin_up, nf_hidden_single, 1), 1), axis=2)
env_sigma_up  = init(3, (n_determinants, n_spin_up, n_atoms, 3, 3), 3, env_init=env_init)
env_pi_up = init(n_atoms, (n_determinants, n_spin_up, n_atoms, 1), 1, env_init=env_init)

env_w_down = apbi(init(nf_hidden_single, (n_determinants, n_spin_down, nf_hidden_single, 1), 1), axis=2)
env_sigma_down = init(3, (n_determinants, n_spin_down, n_atoms, 3, 3), 3, env_init=env_init)
env_pi_down = init(n_atoms, (n_determinants, n_spin_down, n_atoms, 1), 1, env_init=env_init)

w_final = init(n_determinants, (1, n_determinants, 1, 1), 1, env_init=env_init/n_determinants)

In [4]:

@tf.custom_gradient
def safe_norm_grad(x, norm):
    # x : (n, ne**2, 3)
    # norm : (n, ne**2, 1)
    g = x / norm
    g = tf.where(tf.math.is_nan(g), tf.zeros_like(g), g)
    cache = (x, norm)

    def grad_grad(dy):
        x, norm = cache
        x = tf.expand_dims(x, -1)  # (n, ne**2, 3, 1)
        xx = x * tf.transpose(x, perm=(0, 1, 3, 2))  # cross terms
        inv_norm = tf.tile(1. / norm, (1, 1, 3))  # (n, ne**2, 3) inf where the ee terms are same e
        norm_diag = tf.linalg.diag(inv_norm) # (n, ne**2, 3, 3) # diagonal where the basis vector is the same
        gg = norm_diag - xx / tf.expand_dims(norm, -1)**3
        gg = tf.reduce_sum(gg, axis=-1)
        gg = tf.where(tf.math.is_nan(gg), tf.zeros_like(gg), gg)
        tf.debugging.check_numerics(gg, 'gg')
        tf.debugging.check_numerics(dy, 'dy')
        return dy*gg, None

    return g, grad_grad

@tf.custom_gradient
def safe_norm(x):
    norm = tf.norm(x, keepdims=True, axis=-1)
    def grad(dy):
        g = safe_norm_grad(x, norm)
        return dy*g
    return norm, grad

class Mixer(tk.Model):
    def __init__(self, n_electrons, n_single_features, n_pairwise, n_pairwise_features, n_spin_up, n_spin_down, full_pairwise):
        super(Mixer, self).__init__()

        self.n_spin_up = float(n_spin_up)
        self.n_spin_down = float(n_spin_down)

        tmp1 = tf.ones((1, n_spin_up, n_single_features), dtype=tf.bool)
        tmp2 = tf.zeros((1, n_spin_down, n_single_features), dtype=tf.bool)
        self.spin_up_mask = tf.concat((tmp1, tmp2), 1)
        self.spin_down_mask = ~self.spin_up_mask

        if full_pairwise:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                generate_pairwise_masks_full(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)
        else:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                generate_pairwise_masks(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)

    # @tf.function
    def call(self, single, pairwise, n_samples, n_electrons):
        # single (n_samples, n_electrons, n_single_features)
        # pairwise (n_samples, n_electrons, n_pairwise_features)
        spin_up_mask = tf.tile(self.spin_up_mask, (n_samples, 1, 1))
        spin_down_mask = tf.tile(self.spin_down_mask, (n_samples, 1, 1))

        # --- Single summations
        replace = tf.zeros_like(single, dtype=tfDTYPE)
        # up
        sum_spin_up = tf.where(spin_up_mask, single, replace)
        sum_spin_up = tf.reduce_sum(sum_spin_up, 1, keepdims=True) / self.n_spin_up
        sum_spin_up = tf.tile(sum_spin_up, (1, n_electrons, 1))
        # down
        sum_spin_down = tf.where(spin_down_mask, single, replace)
        sum_spin_down = tf.reduce_sum(sum_spin_down, 1, keepdims=True) / self.n_spin_down
        sum_spin_down = tf.tile(sum_spin_down, (1, n_electrons, 1))

        # --- Pairwise summations
        sum_pairwise = tf.tile(tf.expand_dims(pairwise, 1), (1, n_electrons, 1, 1))
        replace = tf.zeros_like(sum_pairwise, dtype=tfDTYPE)
        # up
        sum_pairwise_up = tf.where(self.pairwise_spin_up_mask, sum_pairwise, replace)
        sum_pairwise_up = tf.reduce_sum(sum_pairwise_up, 2) / self.n_spin_up
        # down
        sum_pairwise_down = tf.where(self.pairwise_spin_down_mask, sum_pairwise, replace)
        sum_pairwise_down = tf.reduce_sum(sum_pairwise_down, 2) / self.n_spin_down

        features = tf.concat((single, sum_spin_up, sum_spin_down, sum_pairwise_up, sum_pairwise_down), 2)
        return features
    

class fermi_tf():
    def __init__(self, r_atoms):
        self.r_atoms = r_atoms
        
        self.s0 = tf.Variable(s_stream_0)
        self.p0 = tf.Variable(p_stream_0)
        self.m0 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s1 = tf.Variable(s_stream_1)
        self.p1 = tf.Variable(p_stream_1)
        self.m1 = Mixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s2 = tf.Variable(s_stream_2)
        
        self.env_w_up = tf.Variable(env_w_up)
        self.env_sigma_up = tf.Variable(env_sigma_up)
        self.env_pi_up = tf.Variable(env_pi_up)
        
        self.env_w_down = tf.Variable(env_w_down)
        self.env_sigma_down = tf.Variable(env_sigma_down)
        self.env_pi_down = tf.Variable(env_pi_down)
        
        self.w_final = tf.Variable(w_final)
        
    def __call__(self, samples):
        n_samples = samples.shape[0]
        ae_vectors = compute_ae_vectors(self.r_atoms, samples, flow=True)
        single, pairwise = compute_inputs(samples, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise)
        
        # streams 
        s0 = linear(self.s0, single, n_samples, n_electrons, flow=True)
        p0 = linear(self.p0, pairwise, n_samples, n_pairwise, flow=True)
        s0m = self.m0(s0, p0, n_samples, n_electrons)
        
        s1 = linear(self.s1, s0m, n_samples, n_electrons, flow=True)
        p1 = linear(self.p1, p0, n_samples, n_pairwise, flow=True)
        s1m = self.m1(s1, p1, n_samples, n_electrons)
        
        s2 = linear(self.s2, s1m, n_samples, n_electrons, flow=True)
        
        # env inputs
        ae_vectors_up, ae_vectors_down = tf.split(ae_vectors, [n_spin_up, n_spin_down], axis=1)
        inputs_up, inputs_down = tf.split(s2, [n_spin_up, n_spin_down], axis=1)
        
        # env_w 'njf,kifs->nkjis'
        # env_sigma 'njmv,kimvc->nkjimc'
        # env_pi 'njkim,kims->nkjis'
        up_dets = env(inputs_up, ae_vectors_up, self.env_w_up, self.env_sigma_up, self.env_pi_up, n_samples, n_spin_up)
        down_dets = env(inputs_down, ae_vectors_down, self.env_w_down, self.env_sigma_down, self.env_pi_down, n_samples, n_spin_down)
        
        log_psi, _, _, _ = log_abs_sum_det(up_dets, down_dets, self.w_final)
        
        return log_psi

In [5]:

class fermi_tc():
    def __init__(self, r_atoms):
        self.r_atoms = r_atoms
        
        self.s0 = tcVariable(s_stream_0)
        self.p0 = tcVariable(p_stream_0)
        self.m0 = tcMixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s1 = tcVariable(s_stream_1)
        self.p1 = tcVariable(p_stream_1)
        self.m1 = tcMixer(n_electrons, nf_hidden_single, n_pairwise, nf_hidden_pairwise, n_spin_up, n_spin_down, full_pairwise)
        
        self.s2 = tcVariable(s_stream_2)
        
        self.env_w_up = tcVariable(env_w_up)
        self.env_sigma_up = tcVariable(env_sigma_up)
        self.env_pi_up = tcVariable(env_pi_up)
        
        self.env_w_down = tcVariable(env_w_down)
        self.env_sigma_down = tcVariable(env_sigma_down)
        self.env_pi_down = tcVariable(env_pi_down)
        
        self.w_final = tcVariable(w_final)
        
    def __call__(self, samples):
        n_samples = samples.shape[0]
        ae_vectors = compute_ae_vectors(self.r_atoms, samples, flow=False)
        single, pairwise = compute_inputs(samples, n_samples, ae_vectors, n_atoms, n_electrons, full_pairwise, flow=False)
        
        # streams 
        s0 = linear(self.s0, single, n_samples, n_electrons, flow=False)
        p0 = linear(self.p0, pairwise, n_samples, n_pairwise, flow=False)
        s0m = self.m0(s0, p0, n_samples, n_electrons)
        
        s1 = linear(self.s1, s0m, n_samples, n_electrons, flow=False)
        p1 = linear(self.p1, p0, n_samples, n_pairwise, flow=False)
        s1m = self.m1(s1, p1, n_samples, n_electrons)
        
        s2 = linear(self.s2, s1m, n_samples, n_electrons, flow=False)
        
        # env inputs
        ae_vectors_up, ae_vectors_down = ae_vectors.split([n_spin_up, n_spin_down], dim=1)
        inputs_up, inputs_down = s2.split([n_spin_up, n_spin_down], dim=1)
        
        # env_w 'njf,kifs->nkjis'
        # env_sigma 'njmv,kimvc->nkjimc'
        # env_pi 'njkim,kims->nkjis'
        up_dets = env(inputs_up, ae_vectors_up, self.env_w_up, self.env_sigma_up, self.env_pi_up, n_samples, n_spin_up, flow=False)
        down_dets = env(inputs_down, ae_vectors_down, self.env_w_down, self.env_sigma_down, self.env_pi_down, n_samples, n_spin_down, flow=False)
        
        log_psi = (self.w_final.squeeze(-1).squeeze(-1) * up_dets.det() * down_dets.det()).sum(1).abs().log()

        return log_psi



class tcMixer():
    def __init__(self, n_electrons, n_single_features, n_pairwise, n_pairwise_features, n_spin_up, n_spin_down, full_pairwise):
        super(tcMixer, self).__init__()

        self.n_spin_up = float(n_spin_up)
        self.n_spin_down = float(n_spin_down)

        tmp1 = tc.ones((1, n_spin_up, n_single_features), dtype=tc.bool)
        tmp2 = tc.zeros((1, n_spin_down, n_single_features), dtype=tc.bool)
        self.spin_up_mask = tc.cat((tmp1, tmp2), dim=1)
        self.spin_down_mask = ~self.spin_up_mask

        if full_pairwise:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                tc_generate_pairwise_masks_full(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)
        else:
            self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
                tc_generate_pairwise_masks(n_electrons, n_pairwise, n_spin_up, n_spin_down, n_pairwise_features)

    # @tf.function
    def __call__(self, single, pairwise, n_samples, n_electrons):
        # single (n_samples, n_electrons, n_single_features)
        # pairwise (n_samples, n_electrons, n_pairwise_features)
        spin_up_mask = self.spin_up_mask.repeat((n_samples, 1, 1))
        spin_down_mask = self.spin_down_mask.repeat((n_samples, 1, 1))

        # --- Single summations
        replace = tc.zeros_like(single, dtype=tcDTYPE)
        # up
        sum_spin_up = tc.where(spin_up_mask, single, replace)
        sum_spin_up = sum_spin_up.sum(1, keepdim=True) / self.n_spin_up
        sum_spin_up = sum_spin_up.repeat((1, n_electrons, 1))
        # down
        sum_spin_down = tc.where(spin_down_mask, single, replace)
        sum_spin_down = sum_spin_down.sum(1, keepdim=True) / self.n_spin_down
        sum_spin_down = sum_spin_down.repeat((1, n_electrons, 1))

        # --- Pairwise summations
        sum_pairwise = pairwise.unsqueeze(1).repeat((1, n_electrons, 1, 1))
        replace = tc.zeros_like(sum_pairwise, dtype=tcDTYPE)
        # up
        sum_pairwise_up = tc.where(self.pairwise_spin_up_mask, sum_pairwise, replace)
        sum_pairwise_up = sum_pairwise_up.sum(2) / self.n_spin_up
        # down
        sum_pairwise_down = tc.where(self.pairwise_spin_down_mask, sum_pairwise, replace)
        sum_pairwise_down = sum_pairwise_down.sum(2) / self.n_spin_down

        features = tc.cat((single, sum_spin_up, sum_spin_down, sum_pairwise_up, sum_pairwise_down), dim=2)
        return features
    





In [6]:
samples = np.random.normal(size=(n_samples, n_electrons, 3)).astype(npDTYPE)
r_atoms = np.zeros(shape=(1,3)).astype(npDTYPE)

ftf = fermi_tf(tf.convert_to_tensor(r_atoms, dtype=tfDTYPE))
psi = ftf(tf.convert_to_tensor(samples, dtype=tfDTYPE))

ftc = fermi_tc(tc.tensor(r_atoms))
tcpsi = ftc(tc.tensor(samples))

print(compare_tensors(tcpsi, psi))

# for x, y in zip(tf.reshape(psi, (-1,)), tcpsi.view(-1)):
#     print(x.numpy(), y.numpy(), x.numpy() - y.numpy())

# def get_jacobian(net, x, noutputs):
#     x = x.squeeze()
#     n = x.size()[0]
#     x = x.repeat(noutputs, 1, 1)
#     x.requires_grad_(True)
#     y = net(x)
#     y.backward(tc.eye(noutputs))
#     y.backward()
#     return x

# z = get_jacobian(ftc, tc.tensor(samples), n_samples)

Instructions for updating:
Use tf.identity instead.
6.297184995673888e-15


In [7]:
def tflaplacian(model, r_electrons):
    n_electrons = r_electrons.shape[1]
    r_electrons = tf.reshape(r_electrons, (-1, n_electrons*3))
    r_s = [r_electrons[..., i] for i in range(r_electrons.shape[-1])]
    with tf.GradientTape(True) as g:
        [g.watch(r) for r in r_s]
        r_electrons = tf.stack(r_s, -1)
        r_electrons = tf.reshape(r_electrons, (-1, n_electrons, 3))
        with tf.GradientTape(True) as gg:
            gg.watch(r_electrons)
            log_phi = model(r_electrons)
        dlogphi_dr = gg.gradient(log_phi, r_electrons)
        dlogphi_dr = tf.reshape(dlogphi_dr, (-1, n_electrons*3))
        grads = [dlogphi_dr[..., i] for i in range(dlogphi_dr.shape[-1])]
    d2logphi_dr2 = tf.stack([g.gradient(grad, r) for grad, r in zip(grads, r_s)], -1)
    return dlogphi_dr**2, d2logphi_dr2



# split before
def tclaplacian(model, samples):
    gs = []
    ggs = []
    for sample in samples:
#         sample = tc.tensor(sample, requires_grad=True).view((1,n_electrons,3))
        s_vars = [tc.tensor(s, requires_grad=True, dtype=tcDTYPE) for s in sample.reshape(-1)]
        sample = tc.tensor(s_vars, requires_grad=True).view((1, n_electrons, 3))
        
        log_psi = ftc(sample)
#         print(log_psi.shape)
        grads = [tcgrad(log_psi, s, retain_graph=True, create_graph=True)[0] for s in s_vars]
        grads_grads = [tc.grad(g, s)[0] for s in s_vars]

#         grad_grad = []
#         for g, s in zip(grad, sample):
#             gg = tcgrad(g, s)[0]
#             grad_grad.append(gg)
#         grad_grad = tc.tensor(grad_grad).view((1, n_electrons, 3))
#         gs.append(grad)
#         ggs.append(grad_grad)
        
    gs = tc.cat(gs, dim=0)
    ggs = tc.cat(ggs, dim=0)
    
    return gs**2, ggs

# use single element of grad
def tclaplacian(model, samples):
    gs = []
    ggs = []
    for sample in samples:
        sample = tc.tensor(sample, requires_grad=True).view((1,n_electrons,3))

        log_psi = ftc(sample)

        g = tcgrad(log_psi, sample, retain_graph=True, create_graph=True)[0]
        grads = g.view(-1)
        grads_grads = tc.cat([tcgrad(g, sample, retain_graph=True)[0].view(-1, 1) for g in grads], dim=1)
        
        tmp = tc.tensor(grads_grads)

        diag = tc.diag(tmp).view(1, n_electrons, 3)
        gg = diag.view(1, n_electrons, 3)
        gs.append(g)
        ggs.append(gg)
        
    gs = tc.cat(gs, dim=0)
    ggs = tc.cat(ggs, dim=0)
    
    return gs**2, ggs



In [8]:
n_samples = 5

samples = np.random.normal(size=(n_samples, n_electrons, 3)).astype(npDTYPE)
r_atoms = np.zeros(shape=(1,3)).astype(npDTYPE)

ftf = fermi_tf(tf.convert_to_tensor(r_atoms, dtype=tfDTYPE))
psi = ftf(tf.convert_to_tensor(samples, dtype=tfDTYPE))

ftc = fermi_tc(tc.tensor(r_atoms))

tcpsi = ftc(tc.tensor(samples))

print(compare_tensors(tcpsi, psi))

tfgs, tfggs = tflaplacian(ftf, samples)

gs, ggs = tclaplacian(ftc, samples)

print(compare_tensors(gs, tfgs))
print(compare_tensors(ggs, tfggs))
print(gs.shape)
print(ggs.shape)


for g1, g2 in zip(tfggs, ggs):
    print(g1, g2)

3.375077994860476e-14




1.0270995505490266e-09
nan
torch.Size([5, 4, 3])
torch.Size([5, 4, 3])
tf.Tensor(
[ -19.29961634  -87.18604379   -0.81376779   -4.90160414 -132.53382403
  -11.76403391   -0.99668069   -0.80846922   -0.78383973   -1.88864622
   -1.68921441   -0.39341681], shape=(12,), dtype=float64) tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], dtype=torch.float64)
tf.Tensor(
[ -673.18811932  -251.43500673  -809.75039198 -1266.04850088
   -94.8649032  -1482.00867262   -12.41136838   -87.36419847
    -3.36276821     5.85261893   -54.23747558   -11.01805754], shape=(12,), dtype=float64) tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], dtype=torch.float64)
tf.Tensor(
[  0.20252174  -9.02504047  -0.69286701  -5.40744415  -2.38037426
 -18.45727703  -3.24288052  -0.48214612  -0.81017975  -0.42474699
  -7.97809512  -8.04122579], shape=(12,), dtype=float64) tensor([[nan, nan, nan],
        [nan, nan, nan],

In [9]:
tc.diag(tc.tensor([[1,2], [4, 5]]))

tensor([1, 5])

In [None]:
import torch

class tc_safe_norm_2nd(tc.autograd.Function):
   @staticmethod
    def forward(ctx, tensor):
        norm = tc.norm(tensor, dim=-1, keepdim=True
        grad = tensor / norm
        grad = tf.where(tc.isnan(grad), tc.zeros_like(grad), grad)
        ctx.save_for_backward([norm, tensor])
        return norm

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        norm, tensor = ctx.saved_tensors
        
        return grad 
    
class tc_safe_norm(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, tensor):
        x = tc.norm(tensor, dim=-1, keepdim=True)
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [10]:



# n_samples = 100
# n_k = 10
# ndim = 4
# A = tf.random.normal((n_samples, n_k, ndim, ndim), dtype=DTYPE)
# B = tf.random.normal((n_samples, n_k, ndim, ndim), dtype=DTYPE)
# W = tf.random.normal((1,n_k,1,1), dtype=DTYPE)

# A_tc, B_tc, W_tc = convert_to_torch([A, B, W])

# with tf.GradientTape(True) as g:
#     g.watch(A)
#     g.watch(W)
#     g.watch(B)
#     with tf.GradientTape(True) as gg:
#         gg.watch(A)
#         gg.watch(B)
#         gg.watch(W)
#         log_psi,_, _, _ = log_abs_sum_det(A, B, W)
#     grad_1_A = gg.gradient(log_psi, A)
#     grad_1_B = gg.gradient(log_psi, B)
#     grad_1_W = gg.gradient(log_psi, W)
# grad_2_A = g.gradient(grad_1_A, A)

# log_psi_tc = fn(A_tc, B_tc, W_tc)

# # first order
# grad_1_Atc = tcgrad(log_psi_tc.sum(), A_tc, retain_graph=True, create_graph=True)[0]
# grad_1_Btc = tcgrad(log_psi_tc.sum(), B_tc, retain_graph=True, create_graph=True)[0]
# grad_1_Wtc = tcgrad(log_psi_tc.sum(), W_tc, retain_graph=True, create_graph=True)[0]

# # second order
# grad_2_Atc = tcgrad(grad_1_Atc.sum(), A_tc)[0]

# # validation
# validate_log_psi = compare_tensors(log_psi_tc, log_psi)
# print(validate_log_psi)
# validate_gradW = compare_tensors(grad_1_Wtc, grad_1_W)
# print(validate_gradW)

# validate_grad_1_A = compare_tensors(grad_1_Atc, grad_1_A) 
# print(validate_grad_1_A)

# validate_grad_2_A = compare_tensors(grad_2_Atc, grad_2_A)
# print(validate_grad_2_A)

# # if validate_grad1 < 1:
# #     log_psi_tc = log_psi_tc.reshape(-1)
# #     log_psi = tf.reshape(log_psi, (-1,))
# #     for el1, el2 in zip(log_psi_tc, log_psi):
# #         print(abs(el1.detach().numpy()-el2.numpy()))
# #     print(log_psi_tc, log_psi)
    
# #     print('determinants: ', A_tc.det(), B_tc.det())
# # print(compare_tensors(grad_1_tc, grad_1))

# # print('B: ', ct(B_tc, B))
# # print('ae_vectors:', compare_tensors(ae_vectors_tc, ae_vectors))
# # print('exponential: ', compare_tensors(expo))

In [11]:
# n_samples = 100
# n_k = 10
# n_f = 20
# n_up = 2
# n_e = 4
# n_down = n_e - n_up
# n_atom = 1 

# ### Parameters
# # First
# r_atoms = tf.zeros((n_samples, n_atom, 3), dtype=DTYPE)
# re = tf.random.normal((n_samples, n_e, 3), dtype=DTYPE) #* 100.
# w1 = tf.random.normal((n_f, 3), dtype=DTYPE)
# # up
# w_up = tf.random.normal((n_k, n_f, n_up), dtype=DTYPE)
# b_up = tf.random.normal((n_k, n_up, 1), dtype=DTYPE)
# Sigma_up = tf.random.normal((n_k, n_up, n_atom, 3, 3), dtype=DTYPE)
# Pi_up = tf.random.normal((n_k, n_up, n_atom), dtype=DTYPE)
# # down
# w_down = tf.random.normal((n_k, n_f, n_down), dtype=DTYPE)
# b_down = tf.random.normal((n_k, n_down, 1), dtype=DTYPE)
# Sigma_down = tf.random.normal((n_k, n_down, n_atom, 3, 3), dtype=DTYPE)
# Pi_down = tf.random.normal((n_k, n_down, n_atom), dtype=DTYPE)
# # Weights
# W = tf.random.normal((1, n_k, 1, 1), dtype=DTYPE)

# with tf.GradientTape(True) as g:
#     g.watch(W)
#     g.watch(re)
#     with tf.GradientTape(True) as gg:
#         gg.watch(re)
#         gg.watch(W)
#         A, B = fermi_tf(r_atoms, re, w1,  w_up, b_up, Sigma_up, Pi_up, \
#                          w_down, b_down, Sigma_down, Pi_down)
#         log_psi,_ = log_abs_sum_det(A, B, W)

#     grad_1_A = gg.gradient(log_psi, A)
#     grad_1_B = gg.gradient(log_psi, B)
#     grad_1_W = gg.gradient(log_psi, W)
#     grad_1_re = gg.gradient(log_psi, re)
#     grads_1_re = tf.reshape(grad_1_re, (-1, n_e*3))
#     grads_1_re = [grads_1_re[..., i] for i in range(grads_1_re.shape[-1])]
    
# grad_2_re = g.gradient(grads_1_re[0], re)
# grad_2_A = g.gradient(grad_1_A, A)

# # print(grad_2_re)
# # print(grad_2)

# ### Torch
# re_tc, w1_tc, w_up_tc, b_up_tc, w_down_tc, b_down_tc, W_tc = \
# convert_to_torch([re, w1, w_up, b_up, w_down, b_down, W])
# Sigma_up_tc, Pi_up_tc, Sigma_down_tc, Pi_down_tc = \
# convert_to_torch([Sigma_up, Pi_up, Sigma_down, Pi_down])
# r_atoms_tc = convert_to_torch([r_atoms])[0]


# A_tc, B_tc = fermi_tc(r_atoms_tc, re_tc, w1_tc, w_up_tc, b_up_tc, Sigma_up_tc, Pi_up_tc, \
#               w_down_tc, b_down_tc, Sigma_down_tc, Pi_down_tc)
# log_psi_tc = fn(A_tc, B_tc, W_tc)

# grad_1_Atc = tcgrad(log_psi_tc.sum(), A_tc, retain_graph=True, create_graph=True)[0]
# grad_2_Atc = tcgrad(grad_1_Atc.sum(), A_tc, retain_graph=True)[0]
# grad_1_Btc = tcgrad(log_psi_tc.sum(), B_tc, retain_graph=True)[0]
# grad_1_Wtc = tcgrad(log_psi_tc.sum(), W_tc, retain_graph=True)[0]
# grad_1_retc = tcgrad(log_psi_tc.sum(), re_tc, retain_graph=True, create_graph=True)[0]
# grads_1_retc = grad_1_retc.view(-1,3*n_e)
# grad_2_retc = tcgrad(grads_1_retc[:,0].sum(), re_tc, retain_graph=True)[0]

# grad_1 = tcgrad(log_psi_tc.sum(), A_tc)

# validate_log_psi = compare_tensors(log_psi_tc, log_psi)
# print(validate_log_psi)
# validate_gradW = compare_tensors(grad_1_Wtc, grad_1_W)
# print(validate_gradW)

# print(compare_tensors(grad_1_Atc, grad_1_A))
# print(compare_tensors(grad_1_Btc, grad_1_B))

# validate_re = ct(grad_1_retc, grad_1_re)
# print('re grad: ', validate_re)

# validate_grads_re = ct(grads_1_retc[:,0], grads_1_re[0])
# print('re grads: ', validate_grads_re)

# validate_2_re = ct(grad_2_retc, grad_2_re)
# print('re grad: ', validate_2_re)

# validate_2_A = ct(grad_2_Atc, grad_2_A)
# print('re grad: ', validate_2_A)

