In [1]:
import tensorflow as tf
import torch as tc
import numpy as np
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from torch.autograd import grad as tcgrad
DTYPE = tf.float64

In [2]:
def convert_to_torch(tensors):
    new_tensors = []
    for tensor in tensors:
        new_tensor = tc.tensor(tensor.numpy(), requires_grad=True)
        new_tensors.append(new_tensor)
    return new_tensors

def fn(A, B, W):
    return tc.log(tc.abs(tc.sum(W * A.det().view(-1,n_k,1,1) * B.det().view(-1,n_k,1,1), axis=1, keepdim=True)))

def compute_ae_vectors(r_atoms, r_electrons):
    # ae_vectors (n_samples, n_electrons, n_atoms, 3)
    r_atoms = tf.expand_dims(r_atoms, 1)
    r_electrons = tf.expand_dims(r_electrons, 2)
    ae_vectors = r_electrons - r_atoms
    return ae_vectors

def compute_relative_vectors(v1, v2):
    relative_vectors = v1.unsqueeze(2) - v2.unsqueeze(1)
    return relative_vectors

def compare_tensors(ptorch, tflow):
    ptorch = ptorch.detach().numpy()
    tflow = tflow.numpy()
    mask = np.isclose(ptorch, tflow, rtol=0.0, atol=1e-4).astype(np.float32)
    return np.mean(mask)

def compare_tf_tensors(tflow1, tflow2,rtol=1e-05, atol=1e-08): #
    tflow1 = tflow1.numpy()
    tflow2 = tflow2.numpy()
    mask = np.isclose(tflow1, tflow2).astype(np.float32)
    return np.mean(mask) 

ct = compare_tensors

In [16]:

def slogdet_keepdim(tensor):
    sign, tensor_out = tf.linalg.slogdet(tensor)
    tensor_out = tf.reshape(tensor_out, (*tensor_out.shape, 1, 1))
    sign = tf.reshape(sign, (*sign.shape, 1, 1))
    return sign, tensor_out


def generate_gamma(s):
    n_egvs = s.shape[2]
    gamma = [tf.reduce_prod(s[:, :, :i], axis=-1) * tf.reduce_prod(s[:, :, i+1:], axis=-1) for i in range(n_egvs-1)]
    gamma.append(tf.reduce_prod(s[:, :, :-1], axis=-1))
    gamma = tf.stack(gamma, axis=2)
    gamma = tf.expand_dims(gamma, axis=2)
    return gamma


def first_derivative_det(A):
    with tf.device("/cpu:0"):  # this is incredible stupid /// its actually not
        s, u, v = tf.linalg.svd(A, full_matrices=False)
    # s, u, v = tf.linalg.svd(A, full_matrices=False)
    v_t = tf.linalg.matrix_transpose(v)
    gamma = generate_gamma(s)
    sign = (tf.linalg.det(u) * tf.linalg.det(v))[..., None, None]
    out = sign * ((u * gamma) @ v_t)
    return out, (s, u, v_t, sign)


def generate_p(s):
    """
    :param s:
    :return:
    """
    n_samples, n_k, n_dim = s.shape
    new_shape = (1, 1, 1, n_dim, n_dim)
    s = s[..., None, None]
    s = tf.tile(s, new_shape)
    mask = np.ones(s.shape, dtype=np.bool)
    for i in range(n_dim):
        for j in range(n_dim):
            mask[..., i, i, j] = False
            mask[..., j, i, j] = False
    mask = tf.convert_to_tensor(mask)
    s = tf.where(mask, s, tf.ones_like(s, dtype=s.dtype))
    s_prod = tf.reduce_prod(s, axis=-3)
    s_prod = tf.linalg.set_diag(s_prod, tf.zeros((s_prod.shape[:-1]), dtype=s.dtype))
    return s_prod


def second_derivative_det(A, C_dash, *A_cache):
    """

    :param A:
    :param C_dash:
    :param A_cache:
    :return:
    """
    # This function computes the second order derivative of detA wrt to A
    # A matrix
    # C_bar backward sensitivity
    # A_cache cached values returned by grad_det(A)
    s, u, v_t, sign = A_cache  # decompose the cache

    M = v_t @ tf.linalg.matrix_transpose(C_dash) @ u

    p = generate_p(s)

    sgn = tf.math.sign(sign)

    m_jj = tf.linalg.diag_part(M)
    xi = -M * p

    xi_diag = p @ tf.expand_dims(m_jj, -1)
    xi = tf.linalg.set_diag(xi, tf.squeeze(xi_diag, -1))
    return sgn * u @ xi @ v_t


def k_sum(x):
    return tf.reduce_sum(x, axis=1, keepdims=True)


def matrix_sum(x):
    return tf.reduce_sum(x, axis=[-2, -1], keepdims=True)


@tf.function
def _log_abs_sum_det_fwd(a, b, w):
    a = tf.stop_gradient(a)
    b = tf.stop_gradient(b)
    w = tf.stop_gradient(w)
    """

    :param a:
    :param b:
    :param w:
    :return:
    """
    # Take the slogdet of all k determinants
    sign_a, logdet_a = slogdet_keepdim(a)
    sign_b, logdet_b = slogdet_keepdim(b)

    x = logdet_a + logdet_b
    xmax = tf.math.reduce_max(x, axis=1, keepdims=True)

    unshifted_exp = sign_a * sign_b * tf.exp(x)
    unshifted_exp_w = w * unshifted_exp
    sign_unshifted_sum = tf.math.sign(tf.reduce_sum(unshifted_exp_w, axis=1, keepdims=True))

    exponent = x - xmax
    shifted_exp = sign_a * sign_b * tf.exp(exponent)

    u = w * shifted_exp
    u_sum = tf.reduce_sum(u, axis=1, keepdims=True)
    sign_shifted_sum = tf.math.sign(u_sum)
    log_psi = tf.math.log(tf.math.abs(u_sum)) + xmax

    # Both of these derivations appear to be valid
    # activations = shifted_exp
    # sensitivities = sign_unshifted_sum * tf.exp(-log_psi)
    # dw = sign_unshifted_sum * sign_a * sign_b * tf.exp(x-log_psi)
    #
    # return log_psi, sign_shifted_sum, activations, sensitivities, \
    #           (a, b, w, unshifted_exp, sign_unshifted_sum, dw, sign_a, logdet_a, sign_b, logdet_b, log_psi)

    sensitivities = tf.exp(-log_psi) * sign_unshifted_sum
    # sensitivities = tf.exp(xmax-log_psi) * sign_shifted_sum

    dw = sign_unshifted_sum * sign_a * sign_b * tf.exp(x - log_psi)

    return log_psi, sign_unshifted_sum, unshifted_exp, sensitivities, \
           (a, b, w, unshifted_exp, sign_unshifted_sum, dw, sign_a, logdet_a, sign_b, logdet_b, log_psi)

# @tf.function
def _log_abs_sum_det_first_order(*fwd_cache):
    """

    :param fwd_cache:
    :return:
    """
    a, b, w, unshifted_exp, sign_unshifted_sum, dw, sign_a, logdet_a, sign_b, logdet_b, log_psi = fwd_cache

    ddeta, ddeta_cache = first_derivative_det(a)
    ddetb, ddetb_cache = first_derivative_det(b)

    dfddeta = w * sign_unshifted_sum * sign_b * tf.exp(logdet_b - log_psi)
    dfddetb = w * sign_unshifted_sum * sign_a * tf.exp(logdet_a - log_psi)

    da = dfddeta * ddeta
    db = dfddetb * ddetb

    return (da, db, dw), (sign_unshifted_sum, ddeta, ddeta_cache, ddetb, ddetb_cache, dfddeta, dfddetb, da, db, dw)


# @tf.function
def _log_abs_sum_det_second_order(a_dash, b_dash, w_dash, *cache):
    """

    :param a_dash:
    :param b_dash:
    :param w_dash:
    :param cache:
    :return:
    """
    a, b, w, unshifted_exp, sign_unshifted_sum, _, sign_a, logdet_a, sign_b, logdet_b, log_psi, \
    sign_u, ddeta, ddeta_cache, ddetb, ddetb_cache, dfddeta, dfddetb, da, db, dw = cache

    dfddeta_w = dfddeta / w
    dfddetb_w = dfddetb / w

    ddeta_sum = matrix_sum(a_dash * ddeta)
    da_sum = matrix_sum(da * a_dash)
    ddetb_sum = matrix_sum(b_dash * ddetb)
    db_sum = matrix_sum(db * b_dash)
    a_sum = k_sum(dfddeta * ddeta_sum)
    b_sum = k_sum(dfddetb * ddetb_sum)

    # Compute second deriviate of f wrt to w
    d2w = w_dash * -dw * k_sum(dw)

    # compute deriviate of df/da wrt to w
    dadw = -dw * k_sum(da_sum)
    dadw += dfddeta_w * ddeta_sum  # i=j

    # compute derivative of df/db wrt to w
    dbdw = -dw * k_sum(db_sum)
    dbdw += dfddetb_w * ddetb_sum  # i=j

    # Compute second derivative of f wrt to a
    d2a = -da * a_sum
    d2a += dfddeta * second_derivative_det(a, a_dash, *ddeta_cache)  # i=j
    # Compute derivative of df/db wrt to a
    dbda = -da * b_sum
    dbda += ddeta * sign_u * tf.exp(-log_psi) * w * ddetb_sum  # i=j
    # Compute derivative of df/dw wrt to a
    dwda = w_dash * -da * k_sum(dw)
    dwda += w_dash * da / w  # i=j

    # Compute second derivative of f wrt to b
    d2b = -db * b_sum
    d2b += dfddetb * second_derivative_det(b, b_dash, *ddetb_cache)  # i=j
    # Compute derivative of df/da wrt to b
    dadb = -db * a_sum
    dadb += ddetb * sign_u * tf.exp(-log_psi) * w * ddeta_sum  # i=j
    # Compute derivative of df/dw wrt to b
    dwdb = w_dash * -db * k_sum(dw)
    dwdb += w_dash * db / w  # i=j

    return (d2a + dbda + dwda), (d2b + dadb + dwdb), \
           (d2w + dadw + dbdw), \
           None, None, None, None, None, None, None, None, None, None, None, None

# @tf.autograph.experimental.do_not_convert
# @tf.function
@tf.custom_gradient
def first_order_gradient(a_unused, b_unused, w_unused, *fwd_cache):
    """

    :param a_unused:
    :param b_unused:
    :param w_unused:
    :param fwd_cache:
    :return:
    """
    (da, db, dw), first_order_cache = _log_abs_sum_det_first_order(*fwd_cache)
    return (da, db, dw), \
           lambda a_dash, b_dash, w_dash: _log_abs_sum_det_second_order(
               a_dash, b_dash, w_dash, *fwd_cache, *first_order_cache)


# @tf.autograph.experimental.do_not_convert
# @tf.function
@tf.custom_gradient
def log_abs_sum_det(a, b, w):
    """

    :param a:
    :param b:
    :param w:
    :return:
    """
    log_psi, sign, act, sens, fwd_cache = _log_abs_sum_det_fwd(a, b, w)

    def _first_order_grad(dy, dsg, _, __):
        da, db, dw = first_order_gradient(a, b, w, *fwd_cache)

        # print('dy', dy)
        return dy * da, dy * db, tf.reduce_sum(dy * dw, axis=0, keepdims=True)
    # tf.reduce_sum(dy * dw, axis=0, keepdims=True)

    return (log_psi, sign, act, sens), _first_order_grad





In [25]:
n_samples = 100
n_k = 10
ndim = 4
A = tf.random.normal((n_samples, n_k, ndim, ndim), dtype=DTYPE)
B = tf.random.normal((n_samples, n_k, ndim, ndim), dtype=DTYPE)
W = tf.random.normal((1,n_k,1,1), dtype=DTYPE)

A_tc, B_tc, W_tc = convert_to_torch([A, B, W])

with tf.GradientTape(True) as g:
    g.watch(A)
    g.watch(W)
    g.watch(B)
    with tf.GradientTape(True) as gg:
        gg.watch(A)
        gg.watch(B)
        gg.watch(W)
        log_psi,_, _, _ = log_abs_sum_det(A, B, W)
    grad_1_A = gg.gradient(log_psi, A)
    grad_1_B = gg.gradient(log_psi, B)
    grad_1_W = gg.gradient(log_psi, W)
grad_2_A = g.gradient(grad_1_A, A)

log_psi_tc = fn(A_tc, B_tc, W_tc)

# first order
grad_1_Atc = tcgrad(log_psi_tc.sum(), A_tc, retain_graph=True, create_graph=True)[0]
grad_1_Btc = tcgrad(log_psi_tc.sum(), B_tc, retain_graph=True, create_graph=True)[0]
grad_1_Wtc = tcgrad(log_psi_tc.sum(), W_tc, retain_graph=True, create_graph=True)[0]

# second order
grad_2_Atc = tcgrad(grad_1_Atc.sum(), A_tc)[0]

# validation
validate_log_psi = compare_tensors(log_psi_tc, log_psi)
print(validate_log_psi)
validate_gradW = compare_tensors(grad_1_Wtc, grad_1_W)
print(validate_gradW)

validate_grad_1_A = compare_tensors(grad_1_Atc, grad_1_A) 
print(validate_grad_1_A)

validate_grad_2_A = compare_tensors(grad_2_Atc, grad_2_A)
print(validate_grad_2_A)

# if validate_grad1 < 1:
#     log_psi_tc = log_psi_tc.reshape(-1)
#     log_psi = tf.reshape(log_psi, (-1,))
#     for el1, el2 in zip(log_psi_tc, log_psi):
#         print(abs(el1.detach().numpy()-el2.numpy()))
#     print(log_psi_tc, log_psi)
    
#     print('determinants: ', A_tc.det(), B_tc.det())
# print(compare_tensors(grad_1_tc, grad_1))

# print('B: ', ct(B_tc, B))
# print('ae_vectors:', compare_tensors(ae_vectors_tc, ae_vectors))
# print('exponential: ', compare_tensors(expo))

1.0
1.0
1.0
1.0


In [5]:
grad_2

NameError: name 'grad_2' is not defined

In [6]:
def fermi_tf(r_atoms, re, w1, w_up, b_up, Sigma_up, Pi_up, w_down, b_down, Sigma_down, Pi_down):
    ae_vectors = compute_ae_vectors(r_atoms, re)
    r2 = tf.einsum('fv,niv->nif',w1,re)

    # Envelopes
    ae_vectors_up = ae_vectors[:,:n_up,...]
    r_up = r2[:,:n_up,:]
    factor = tf.einsum('njf,kfi->nkji', r_up, w_up)
    factor = factor + b_up
    exp = tf.einsum('kimvc,njmv->nkijmc', Sigma_up, ae_vectors_up)
    exponential = tf.exp(-tf.norm(exp, axis=-1))
    exp = tf.einsum('nkijm,kim->nkij', exponential, Pi_up)
    A = factor * exp

    ae_vectors_down = ae_vectors[:,n_up:,...]
    r_down = r2[:,n_up:,:]
    factor = tf.einsum('njf,kfi->nkji', r_down, w_down)
    factor = factor + b_down
    exp = tf.einsum('kimvc,njmv->nkijmc', Sigma_down, ae_vectors_down)
    exponential = tf.exp(-tf.norm(exp, axis=-1))
    exp = tf.einsum('nkijm,kim->nkij', exponential, Pi_down)
    B = factor * exp
    
    return A, B
    
    
def fermi_tc(r_atoms_tc, re_tc, w1_tc,  w_up_tc, b_up_tc, Sigma_up_tc, Pi_up_tc, \
              w_down_tc, b_down_tc, Sigma_down_tc, Pi_down_tc):
    ae_vectors_tc = compute_relative_vectors(re_tc,r_atoms_tc)
    # ### Torch
    r2_tc = tc.einsum('fv,niv->nif',[w1_tc,re_tc])

    # Envelopes
    ae_vectors_up_tc = ae_vectors_tc[:,:n_up,...]
    r_up_tc = r2_tc[:,:n_up,:]
    factor = tc.einsum('njf,kfi->nkji', [r_up_tc, w_up_tc])
    factor = factor + b_up_tc
    exp = tc.einsum('kimvc,njmv->nkijmc', [Sigma_up_tc, ae_vectors_up_tc])
    exponential = tc.exp(-exp.norm(dim=-1))
    exp = tc.einsum('nkijm,kim->nkij', [exponential, Pi_up_tc])
    A_tc = factor * exp

    ae_vectors_down_tc = ae_vectors_tc[:,n_up:,...]
    r_down_tc = r2_tc[:,n_up:,:]
    factor = tc.einsum('njf,kfi->nkji', [r_down_tc, w_down_tc])
    factor = factor + b_down_tc
    exp = tc.einsum('kimvc,njmv->nkijmc', [Sigma_down_tc, ae_vectors_down_tc])
    exponential = tc.exp(-exp.norm(dim=-1))
    exp = tc.einsum('nkijm,kim->nkij', [exponential, Pi_down_tc])
    B_tc = factor * exp
    return A_tc, B_tc
    

In [7]:
n_samples = 100
n_k = 10
n_f = 20
n_up = 2
n_e = 4
n_down = n_e - n_up
n_atom = 1 

### Parameters
# First
r_atoms = tf.zeros((n_samples, n_atom, 3), dtype=DTYPE)
re = tf.random.normal((n_samples, n_e, 3), dtype=DTYPE) #* 100.
w1 = tf.random.normal((n_f, 3), dtype=DTYPE)
# up
w_up = tf.random.normal((n_k, n_f, n_up), dtype=DTYPE)
b_up = tf.random.normal((n_k, n_up, 1), dtype=DTYPE)
Sigma_up = tf.random.normal((n_k, n_up, n_atom, 3, 3), dtype=DTYPE)
Pi_up = tf.random.normal((n_k, n_up, n_atom), dtype=DTYPE)
# down
w_down = tf.random.normal((n_k, n_f, n_down), dtype=DTYPE)
b_down = tf.random.normal((n_k, n_down, 1), dtype=DTYPE)
Sigma_down = tf.random.normal((n_k, n_down, n_atom, 3, 3), dtype=DTYPE)
Pi_down = tf.random.normal((n_k, n_down, n_atom), dtype=DTYPE)
# Weights
W = tf.random.normal((1, n_k, 1, 1), dtype=DTYPE)

with tf.GradientTape(True) as g:
    g.watch(W)
    g.watch(re)
    with tf.GradientTape(True) as gg:
        gg.watch(re)
        gg.watch(W)
        A, B = fermi_tf(r_atoms, re, w1,  w_up, b_up, Sigma_up, Pi_up, \
                         w_down, b_down, Sigma_down, Pi_down)
        log_psi,_ = log_abs_sum_det(A, B, W)

    grad_1_A = gg.gradient(log_psi, A)
    grad_1_B = gg.gradient(log_psi, B)
    grad_1_W = gg.gradient(log_psi, W)
    grad_1_re = gg.gradient(log_psi, re)
    grads_1_re = tf.reshape(grad_1_re, (-1, n_e*3))
    grads_1_re = [grads_1_re[..., i] for i in range(grads_1_re.shape[-1])]
    
grad_2_re = g.gradient(grads_1_re[0], re)
grad_2_A = g.gradient(grad_1_A, A)

# print(grad_2_re)
# print(grad_2)

### Torch
re_tc, w1_tc, w_up_tc, b_up_tc, w_down_tc, b_down_tc, W_tc = \
convert_to_torch([re, w1, w_up, b_up, w_down, b_down, W])
Sigma_up_tc, Pi_up_tc, Sigma_down_tc, Pi_down_tc = \
convert_to_torch([Sigma_up, Pi_up, Sigma_down, Pi_down])
r_atoms_tc = convert_to_torch([r_atoms])[0]


A_tc, B_tc = fermi_tc(r_atoms_tc, re_tc, w1_tc, w_up_tc, b_up_tc, Sigma_up_tc, Pi_up_tc, \
              w_down_tc, b_down_tc, Sigma_down_tc, Pi_down_tc)
log_psi_tc = fn(A_tc, B_tc, W_tc)

grad_1_Atc = tcgrad(log_psi_tc.sum(), A_tc, retain_graph=True, create_graph=True)[0]
grad_2_Atc = tcgrad(grad_1_Atc.sum(), A_tc, retain_graph=True)[0]
grad_1_Btc = tcgrad(log_psi_tc.sum(), B_tc, retain_graph=True)[0]
grad_1_Wtc = tcgrad(log_psi_tc.sum(), W_tc, retain_graph=True)[0]
grad_1_retc = tcgrad(log_psi_tc.sum(), re_tc, retain_graph=True, create_graph=True)[0]
grads_1_retc = grad_1_retc.view(-1,3*n_e)
grad_2_retc = tcgrad(grads_1_retc[:,0].sum(), re_tc, retain_graph=True)[0]

grad_1 = tcgrad(log_psi_tc.sum(), A_tc)

validate_log_psi = compare_tensors(log_psi_tc, log_psi)
print(validate_log_psi)
validate_gradW = compare_tensors(grad_1_Wtc, grad_1_W)
print(validate_gradW)

print(compare_tensors(grad_1_Atc, grad_1_A))
print(compare_tensors(grad_1_Btc, grad_1_B))

validate_re = ct(grad_1_retc, grad_1_re)
print('re grad: ', validate_re)

validate_grads_re = ct(grads_1_retc[:,0], grads_1_re[0])
print('re grads: ', validate_grads_re)

validate_2_re = ct(grad_2_retc, grad_2_re)
print('re grad: ', validate_2_re)

validate_2_A = ct(grad_2_Atc, grad_2_A)
print('re grad: ', validate_2_A)



NameError: name 'logabssumdet' is not defined

In [8]:
grad_1_retc.shape

NameError: name 'grad_1_retc' is not defined

In [9]:
grad_1_re_loss

NameError: name 'grad_1_re_loss' is not defined

In [10]:
grad_2_A

NameError: name 'grad_2_A' is not defined

In [11]:
grad_2_retc

NameError: name 'grad_2_retc' is not defined

In [12]:
# --- compare nico to new 

# --- outputs the same

# --- grad1 the same (ABW)

# --- grad1 the same (r)

# --- grad2 the same (ABW)

# --- grad2 the same (r)

# --- speed