In [35]:
# This is it boys. The holy grail. 
# A systematically improvable, invertible descriptor which can do the whole table. 
#
# Dare we steal fire from the gods? 
import tensorflow as tf
import numpy as np
import time 

In [46]:
# Largest atomic number the element-code table has a row for.
MAXATOMIC = 55
l_max = 3
# Length of the per-element code vector.
NELECHANNEL = 4 

# One code row per atomic number 0..MAXATOMIC, so that row index == Z.
atom_codes = np.random.random((MAXATOMIC + 1,NELECHANNEL)) # Assumption is these come from an autoencoder. 
# Each row is (r_nought, sigma) of one radial gaussian.
GaussParams = np.array([[0.35, 0.35], [0.70, 0.35], [1.05, 0.35], [1.40, 0.35], [1.75, 0.35], [2.10, 0.35], [2.45, 0.35],
                        [2.80, 0.35], [3.15, 0.35], [3.50, 0.35], [3.85, 0.35], [4.20, 0.35], [4.55, 0.35], [4.90, 0.35]])

# The batch consists of a usual dense tensormol batch. 
batch_size = 200 
MaxNAtom = 20
xyzs = np.random.random((batch_size,MaxNAtom,3))*5.0
# Draws Z in [0, MAXATOMIC); the upper bound is exclusive, so every Z indexes a valid atom_codes row.
Zs = np.random.randint(MAXATOMIC,size=(batch_size,MaxNAtom)) # some z are zero. 

In [47]:
def inv_gaush_element_coded(xyzs, Zs, gauss_params, elecode, l_max, chiral=False): 
    """
    For a batch of xyzs create INVARIANT embeddings per atom. 
    These are both rotationally and reflectionally invariant 
    while including phase information. Chirality can be turned on
    with chiral, ie: if chiral = True then enantiomers have different embeddings. 
    
    Suppose mol m, atom i and environmental atoms j in mol m. 
    The intermediate tensors are 
    
    SH(Canonical(dxyz)) => mol X max_atom X max_atom X NSH
    Code(Zs) => mol X max_atom X NELECHANNEL
    Rad(dxyz) => mol X max_atom X max_atom X NRAD
    out(m,i,alpha,beta,zeta) = sum_j SH[m,i,j,alpha]*Rad[m,i,j,beta]*Code[m,j,zeta]
    
    Note: nothing in this function masks padding atoms (Z == 0); they contribute
    through whatever row elecode holds at index 0.
    
    Args: 
        xyzs: NMol X MaxNAtom X 3 coordinate tensor. 
        Zs: NMol X MaxNAtom integer atomic number tensor (used as row index into elecode).
        gauss_params: ngaus X 2 gaussian parameter tensor. 
        elecode: element coding, one NELECHANNEL-long row per atomic number appearing in Zs. 
        l_max: maximum spherical harmonic. 
        chiral: if True the sign convention in Canonicalize is skipped, so
            reflection invariance is switched off. 
    
    Returns: 
        NMol X MaxNAtom X NSH X NRAD X NELECHANNEL tensor. 
    """
    # Canonicalize's flag means "enforce reflection invariance", i.e. the opposite
    # sense of `chiral`, so it has to be negated here.
    dxyzs = Canonicalize(tf.expand_dims(xyzs, axis=2) - tf.expand_dims(xyzs, axis=1), not chiral)
    # The tiny shift keeps the norm strictly positive on the i == j diagonal.
    dist_tensor = tf.norm(dxyzs+1.e-36,axis=3)
    SH = tf_spherical_harmonics(dxyzs, dist_tensor, l_max)
    RAD = tf_gauss(dist_tensor, gauss_params)
    CODES = tf.gather(elecode,Zs)
    # Perform each of the contractions. 
    # Outer product of angular (k) and radial (l) channels for every pair (i,j).
    SHRAD = tf.einsum('mijk,mijl->mijkl',SH,RAD)
    # Weight by the code of neighbor j and sum over neighbors.
    SHRADCODE = tf.einsum('mijkl,mjn->mikln',SHRAD,CODES)
    return SHRADCODE

# Wrap the NumPy test batch as graph variables so gradients w.r.t. coordinates can be requested.
xyzs_tf = tf.Variable(xyzs)
Zs_tf = tf.Variable(Zs)
gauss_params_tf = tf.Variable(GaussParams)
elecode_tf = tf.Variable(atom_codes)
# Plain Python int (not a tensor): the harmonic order is chosen at graph-build time.
l_max_tf = 4

AwghYeahBitches = inv_gaush_element_coded(xyzs_tf, Zs_tf, gauss_params_tf, elecode_tf, l_max_tf)
ginvgaush = tf.gradients(AwghYeahBitches,xyzs_tf)

# NOTE: the session is left open on purpose so it can be reused interactively.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
init = tf.global_variables_initializer()
sess.run(init)

# Time NREPLICA evaluations of the coordinate gradient.
NREPLICA = 10
t0 = time.time()
for i in range(NREPLICA): 
    tp = time.time()
    sess.run(ginvgaush)
    print(time.time()-tp)

# Read the clock once so both averages describe the same elapsed interval.
elapsed = time.time()-t0
print("Average Time per batch:  %s" % (elapsed/NREPLICA))
print("Average Time per atom:  %s" % (elapsed/NREPLICA/batch_size/MaxNAtom))

1.54851698875
1.01922893524
1.00106596947
1.04743289948
1.0281059742
1.03989005089
1.87094902992
1.99342393875
1.96746683121
1.97200798988
Average Time per batch:  1.44896669388
Average Time per atom:  0.00036224707365


In [28]:
def tf_gaush_element_channel(xyzs, Zs, elements, gauss_params, l_max):
	"""
	Encodes atoms into a gaussians * spherical harmonics embedding
	cast into element channels. Works on a batch of molecules.

	Atoms with Z == 0 are treated as padding and get no embedding row.
	Every nonzero Z is expected to occur exactly once in elements; otherwise
	partition_idx does not have one entry per embedded atom.

	Args:
		xyzs (tf.float): NMol x MaxNAtoms x 3 coordinates tensor
		Zs (tf.int32): NMol x MaxNAtoms atomic number tensor
		elements (tf.int32): NElements tensor of atomic numbers, one channel per entry
			(its static length must be known)
		gauss_params (tf.float): NGaussians x 2 tensor of gaussian parameters
		l_max (tf.int32): Scalar for the highest order spherical harmonics to use

	Returns:
		embeds (list of tf.float): one tensor per entry of elements, holding the
			flattened (element channel x gaussian x harmonic) embedding of every
			atom of that element.
		mol_idx (list of tf.int64): matching (molecule index, atom index) rows.
	"""
	num_elements = elements.get_shape().as_list()[0]
	# (molecule, atom) index pairs of all non-padding atoms: NNZ X 2.
	padding_mask = tf.where(tf.not_equal(Zs, 0))

	dxyzs = tf.expand_dims(xyzs, axis=2) - tf.expand_dims(xyzs, axis=1)
	dxyzs = tf.gather_nd(dxyzs, padding_mask)
	dist_tensor = tf.norm(dxyzs+1.e-16,axis=-1)
	gauss = tf_gauss(dist_tensor, gauss_params)
	# dxyzs has dimension NNZ X MaxNAtoms X 3
	harmonics = tf_spherical_harmonics(dxyzs, dist_tensor, l_max)
	# 1.0 where neighbor j of the atom's molecule is of element k, else 0.0: NNZ X MaxNAtoms X NElements.
	channel_scatter = tf.gather(tf.equal(tf.expand_dims(Zs, axis=-1), elements), padding_mask[:,0])
	channel_scatter = tf.cast(channel_scatter, tf.float64)
	channel_gauss = tf.expand_dims(gauss, axis=-2) * tf.expand_dims(channel_scatter, axis=-1)
	channel_harmonics = tf.expand_dims(harmonics, axis=-2) * tf.expand_dims(channel_scatter, axis=-1)
	# Sum over neighbors j, keep (element channel, gaussian, harmonic), then flatten per atom.
	embeds = tf.reshape(tf.einsum('ijkg,ijkl->ikgl', channel_gauss, channel_harmonics),
			[tf.shape(padding_mask)[0], -1])
	# Position within elements of each embedded atom's own atomic number.
	partition_idx = tf.cast(tf.where(tf.equal(tf.expand_dims(tf.gather_nd(Zs, padding_mask), axis=-1),
						tf.expand_dims(elements, axis=0)))[:,1], tf.int32)
	with tf.device('/cpu:0'):
		embeds = tf.dynamic_partition(embeds, partition_idx, num_elements)
		mol_idx = tf.dynamic_partition(padding_mask, partition_idx, num_elements)
	return embeds, mol_idx

In [25]:
def Canonicalize(dxyzs,ChiralInv=True):
	"""
	Perform a PCA to create invariant axes.
	With ChiralInv=True these axes are invariant to both rotation and reflection.
	Original author's notes: MaxNAtom must be >= 4 otherwise this won't work;
	rotational invariance and differentiability of this routine were tested.

	Args:
	    dxyzs: a nMol X maxNatom X maxNatom X 3 tensor of atoms. (differenced from center of embedding
				ie: ... X i X i = (0.,0.,0.))
	    ChiralInv: if True, flip each principal axis so that the mean projected
				coordinate along it is non-negative. If False the projection is
				returned with whatever eigenvector signs the solver produced.
	Returns:
	    Cdxyz: canonically oriented versions of the above coordinates.
	"""
	# Center each environment before building its 3 X 3 scatter matrix.
	ap = dxyzs - tf.reduce_mean(dxyzs,axis=-2,keepdims=True)
	C = tf.einsum('lmji,lmjk->lmik',ap,ap) # Covariance matrix.
	# Only the eigenvectors (principal axes) are needed.
	_, v = tf.self_adjoint_eig(C)
	tore = tf.matmul(dxyzs,v)
	if (not ChiralInv):
		return tore
	# output axes only match up to a sign due to phase freedom of eigenvectors.
	# Make a convention that mean axis is positive.
	# tf.sign would return 0 for an exactly-zero mean and wipe out that whole
	# coordinate; treat zero as +1 so the geometry is preserved.
	mean_proj = tf.reduce_mean(tore,axis=-2,keepdims=True)
	signc = tf.where(tf.less(mean_proj, tf.zeros_like(mean_proj)),
				-tf.ones_like(mean_proj), tf.ones_like(mean_proj))
	return tore*signc

def tf_gauss_overlap(gauss_params):
	"""
	Build the NGaussians X NGaussians matrix

		sqrt(pi/2) * exp(-(r_i - r_j)^2 / (2 (s_i^2 + s_j^2)))
			* (1 + erf((r_j s_i^2 + r_i s_j^2) / (sqrt(2) s_i^2 s_j^2 q_ij))) / q_ij

	with q_ij = sqrt(1/s_i^2 + 1/s_j^2).
	NOTE(review): by its name this is the pairwise overlap of the radial
	gaussians; confirm the integration convention against its caller.

	Args:
		gauss_params: NGaussians X 2 float64 tensor with rows (r_nought, sigma).
	Returns:
		NGaussians X NGaussians float64 tensor.
	"""
	r_nought = gauss_params[:,0]
	sigma_sq = tf.square(gauss_params[:,1])
	# Row/column views used to broadcast every pair (i, j).
	r_j = tf.expand_dims(r_nought, axis=0)
	r_i = tf.expand_dims(r_nought, axis=1)
	sigma_sq_j = tf.expand_dims(sigma_sq, axis=0)
	sigma_sq_i = tf.expand_dims(sigma_sq, axis=1)
	# Evaluate the constant in double precision; tf.sqrt on a Python float
	# would compute it in float32 before the cast.
	scaling_factor = tf.constant(np.sqrt(np.pi / 2.0), dtype=tf.float64)
	exponential_factor = tf.exp(-tf.square(r_j - r_i) / (2.0 * (sigma_sq_j + sigma_sq_i)))
	root_inverse_sigma_sum = tf.sqrt((1.0 / sigma_sq_j) + (1.0 / sigma_sq_i))
	erf_numerator = r_j * sigma_sq_i + r_i * sigma_sq_j
	erf_denominator = (tf.sqrt(tf.cast(2.0, tf.float64)) * sigma_sq_j * sigma_sq_i
				* root_inverse_sigma_sum)
	erf_factor = 1 + tf.erf(erf_numerator / erf_denominator)
	overlap_matrix = scaling_factor * exponential_factor * erf_factor / root_inverse_sigma_sum
	return overlap_matrix

def tf_sparse_gauss(dist_tensor, gauss_params):
	"""
	Expand distances in a gaussian basis and damp them with a smooth cutoff.

	Args:
		dist_tensor: tensor of distances.
		gauss_params: NGaussians X 2 tensor with rows (center, width).
	Returns:
		Tensor shaped like dist_tensor with a trailing NGaussians axis.
	"""
	centers = tf.expand_dims(gauss_params[:,0], axis=0)
	widths = gauss_params[:,1]
	offsets = tf.expand_dims(dist_tensor, axis=-1) - centers
	exponent = tf.square(offsets) / (-2.0 * (widths ** 2))
	# Gaussians that have decayed below exp(-25) are set to exactly zero.
	negligible = tf.zeros_like(exponent)
	gaussian_embed = tf.where(tf.greater(exponent, -25.0), tf.exp(exponent), negligible)
	# Cubic switching polynomial: 1 at distance 6.0, 0 at distance 7.0.
	xi = (dist_tensor - 6.0) / (7.0 - 6.0)
	smooth_step = 1 - 3 * tf.square(xi) + 2 * tf.pow(xi, 3.0)
	beyond_outer = tf.greater(dist_tensor, 7.0)
	inside_inner = tf.less(dist_tensor, 6.0)
	cutoff_factor = tf.where(beyond_outer, tf.zeros_like(smooth_step), smooth_step)
	cutoff_factor = tf.where(inside_inner, tf.ones_like(cutoff_factor), cutoff_factor)
	return gaussian_embed * tf.expand_dims(cutoff_factor, axis=-1)

def tf_gauss(dist_tensor, gauss_params):
	"""
	Expand distances in a gaussian basis, zero out (near-)zero distances and
	damp the result with a smooth cutoff.

	Args:
		dist_tensor: tensor of distances.
		gauss_params: NGaussians X 2 tensor with rows (center, width).
	Returns:
		Tensor shaped like dist_tensor with a trailing NGaussians axis.
	"""
	centers = tf.expand_dims(tf.expand_dims(gauss_params[:,0], axis=0), axis=1)
	widths = gauss_params[:,1]
	offsets = tf.expand_dims(dist_tensor, axis=-1) - centers
	exponent = tf.square(offsets) / (-2.0 * (widths ** 2))
	# Gaussians that have decayed below exp(-25) are set to exactly zero.
	negligible = tf.zeros_like(exponent)
	gaussian_embed = tf.where(tf.greater(exponent, -25.0), tf.exp(exponent), negligible)
	# Distances below 1e-15 (an atom paired with itself) contribute nothing.
	is_self = tf.less(dist_tensor, 1.e-15)
	pair_mask = tf.where(is_self, tf.zeros_like(dist_tensor), tf.ones_like(dist_tensor))
	gaussian_embed *= tf.expand_dims(pair_mask, axis=-1)
	# Cubic switching polynomial: 1 at distance 6.0, 0 at distance 7.0.
	xi = (dist_tensor - 6.0) / (7.0 - 6.0)
	smooth_step = 1 - 3 * tf.square(xi) + 2 * tf.pow(xi, 3.0)
	beyond_outer = tf.greater(dist_tensor, 7.0)
	inside_inner = tf.less(dist_tensor, 6.0)
	cutoff_factor = tf.where(beyond_outer, tf.zeros_like(smooth_step), smooth_step)
	cutoff_factor = tf.where(inside_inner, tf.ones_like(cutoff_factor), cutoff_factor)
	return gaussian_embed * tf.expand_dims(cutoff_factor, axis=-1)

def tf_spherical_harmonics_0(inv_dist_tensor):
	"""
	Constant l=0 harmonic: a float64 tensor with the shape of inv_dist_tensor,
	filled with 0.28209479177387814. Only the shape of the argument is used.
	"""
	y00 = tf.constant(0.28209479177387814, dtype=tf.float64)
	out_shape = tf.shape(inv_dist_tensor)
	return tf.fill(out_shape, y00)

def tf_spherical_harmonics_1(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0 and l = 1.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor scaling the l=1 components
			(the dispatcher passes 1/|dxyzs|, or 0 where the distance vanishes).
		invariant: if True, the three l=1 components are replaced by their
			single euclidean norm.
	Returns:
		(...) X 4 tensor, or (...) X 2 when invariant.
	"""
	lower_order_harmonics = tf_spherical_harmonics_0(tf.expand_dims(inv_dist_tensor, axis=-1))
	x, y, z = dxyzs[...,0], dxyzs[...,1], dxyzs[...,2]
	# Component order is (y, z, x).
	l1_harmonics = 0.4886025119029199 * tf.stack([y, z, x],
										axis=-1) * tf.expand_dims(inv_dist_tensor, axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l1_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l1_harmonics], axis=-1)

def tf_spherical_harmonics_2(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..2.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=2 polynomials are scaled by its square.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 5 l=2 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_1(dxyzs, inv_dist_tensor, invariant)
	x, y, z = dxyzs[...,0], dxyzs[...,1], dxyzs[...,2]
	x2, y2, z2 = tf.square(x), tf.square(y), tf.square(z)
	l2_harmonics = tf.stack([(-1.0925484305920792 * x * y),
			(1.0925484305920792 * y * z),
			(-0.31539156525252005 * (x2 + y2 - 2. * z2)),
			(1.0925484305920792 * x * z),
			(0.5462742152960396 * (x2 - 1. * y2))], axis=-1) \
			* tf.expand_dims(tf.square(inv_dist_tensor),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l2_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l2_harmonics], axis=-1)

def tf_spherical_harmonics_3(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..3.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=3 polynomials are scaled by its cube.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 7 l=3 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_2(dxyzs, inv_dist_tensor, invariant)
	x, y, z = dxyzs[...,0], dxyzs[...,1], dxyzs[...,2]
	x2, y2, z2 = tf.square(x), tf.square(y), tf.square(z)
	l3_harmonics = tf.stack([(-0.5900435899266435 * y * (-3. * x2 + y2)),
			(-2.890611442640554 * x * y * z),
			(-0.4570457994644658 * y * (x2 + y2 - 4. * z2)),
			(0.3731763325901154 * z * (-3. * x2 - 3. * y2 + 2. * z2)),
			(-0.4570457994644658 * x * (x2 + y2 - 4. * z2)),
			(1.445305721320277 * (x2 - 1. * y2) * z),
			(0.5900435899266435 * x * (x2 - 3. * y2))], axis=-1) \
				* tf.expand_dims(tf.pow(inv_dist_tensor,3),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l3_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l3_harmonics], axis=-1)

def tf_spherical_harmonics_4(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..4.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=4 polynomials are scaled by its 4th power.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 9 l=4 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_3(dxyzs, inv_dist_tensor, invariant)
	x, y, z = dxyzs[...,0], dxyzs[...,1], dxyzs[...,2]
	x2, y2, z2 = tf.square(x), tf.square(y), tf.square(z)
	x4, y4, z4 = tf.pow(x, 4), tf.pow(y, 4), tf.pow(z, 4)
	l4_harmonics = tf.stack([(2.5033429417967046 * x * y * (-1. * x2 + y2)),
			(-1.7701307697799304 * y * (-3. * x2 + y2) * z),
			(0.9461746957575601 * x * y * (x2 + y2 - 6. * z2)),
			(-0.6690465435572892 * y * z * (3. * x2 + 3. * y2 - 4. * z2)),
			(0.10578554691520431 * (3. * x4 + 3. * y4 - 24. * y2 * z2 + 8. * z4 + 6. * x2 * (y2 - 4. * z2))),
			(-0.6690465435572892 * x * z * (3. * x2 + 3. * y2 - 4. * z2)),
			(-0.47308734787878004 * (x2 - 1. * y2) * (x2 + y2 - 6. * z2)),
			(1.7701307697799304 * x * (x2 - 3. * y2) * z),
			(0.6258357354491761 * (x4 - 6. * x2 * y2 + y4))], axis=-1) \
			* tf.expand_dims(tf.pow(inv_dist_tensor,4),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l4_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l4_harmonics], axis=-1)

def tf_spherical_harmonics_5(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..5.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=5 polynomials are scaled by its 5th power.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 11 l=5 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_4(dxyzs, inv_dist_tensor, invariant)
	l5_harmonics = tf.stack([(0.6563820568401701 * dxyzs[...,1] * (5. * tf.pow(dxyzs[...,0], 4) - 10. \
				* tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + tf.pow(dxyzs[...,1], 4))),
			(8.302649259524166 * dxyzs[...,0] * dxyzs[...,1] * (-1. * tf.square(dxyzs[...,0]) \
				+ tf.square(dxyzs[...,1])) * dxyzs[...,2]),
			(0.4892382994352504 * dxyzs[...,1] * (-3. * tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1])) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2]))),
			(4.793536784973324 * dxyzs[...,0] * dxyzs[...,1] * dxyzs[...,2] \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 2. * tf.square(dxyzs[...,2]))),
			(0.45294665119569694 * dxyzs[...,1] * (tf.pow(dxyzs[...,0], 4) + tf.pow(dxyzs[...,1], 4) - 12. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. * tf.pow(dxyzs[...,2], 4) + 2. \
				* tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 6. * tf.square(dxyzs[...,2])))),
			(0.1169503224534236 * dxyzs[...,2] * (15. * tf.pow(dxyzs[...,0], 4) + 15. * tf.pow(dxyzs[...,1], 4) \
				- 40. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. * tf.pow(dxyzs[...,2], 4) + 10. \
				* tf.square(dxyzs[...,0]) * (3. * tf.square(dxyzs[...,1]) - 4. * tf.square(dxyzs[...,2])))),
			(0.45294665119569694 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 4) + tf.pow(dxyzs[...,1], 4) - 12. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. * tf.pow(dxyzs[...,2], 4) + 2. \
				* tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 6. * tf.square(dxyzs[...,2])))),
			(-2.396768392486662 * (tf.square(dxyzs[...,0]) - 1. * tf.square(dxyzs[...,1])) * dxyzs[...,2] \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 2. * tf.square(dxyzs[...,2]))),
			(-0.4892382994352504 * dxyzs[...,0] * (tf.square(dxyzs[...,0]) - 3. * tf.square(dxyzs[...,1])) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2]))),
			(2.0756623148810416 * (tf.pow(dxyzs[...,0], 4) - 6. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2]),
			(0.6563820568401701 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 4) - 10. \
				* tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + 5. * tf.pow(dxyzs[...,1], 4)))], axis=-1) \
			* tf.expand_dims(tf.pow(inv_dist_tensor,5),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l5_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l5_harmonics], axis=-1)

def tf_spherical_harmonics_6(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..6.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=6 polynomials are scaled by its 6th power.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 13 l=6 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_5(dxyzs, inv_dist_tensor, invariant)
	l6_harmonics = tf.stack([(-1.3663682103838286 * dxyzs[...,0] * dxyzs[...,1] * (3. * tf.pow(dxyzs[...,0], 4) \
				- 10. * tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + 3. * tf.pow(dxyzs[...,1], 4))),
			(2.366619162231752 * dxyzs[...,1] * (5. * tf.pow(dxyzs[...,0], 4) - 10. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2]),
			(2.0182596029148967 * dxyzs[...,0] * dxyzs[...,1] * (tf.square(dxyzs[...,0]) - 1. * tf.square(dxyzs[...,1])) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2]))),
			(0.9212052595149236 * dxyzs[...,1] * (-3. * tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1])) \
				* dxyzs[...,2] * (3. * tf.square(dxyzs[...,0]) + 3. * tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2]))),
			(-0.9212052595149236 * dxyzs[...,0] * dxyzs[...,1] * (tf.pow(dxyzs[...,0], 4) + tf.pow(dxyzs[...,1], 4) \
				- 16. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. * tf.pow(dxyzs[...,2], 4) \
				+ 2. * tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])))),
			(0.5826213625187314 * dxyzs[...,1] * dxyzs[...,2] * (5. * tf.pow(dxyzs[...,0], 4) + 5. * tf.pow(dxyzs[...,1], 4) \
				- 20. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. * tf.pow(dxyzs[...,2], 4) \
				+ 10. * tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 2. * tf.square(dxyzs[...,2])))),
			(-0.06356920226762842 * (5. * tf.pow(dxyzs[...,0], 6) + 5. * tf.pow(dxyzs[...,1], 6) - 90. \
				* tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 120. * tf.square(dxyzs[...,1]) \
				* tf.pow(dxyzs[...,2], 4) - 16. * tf.pow(dxyzs[...,2], 6) + 15. * tf.pow(dxyzs[...,0], 4) \
				* (tf.square(dxyzs[...,1]) - 6. * tf.square(dxyzs[...,2])) + 15. * tf.square(dxyzs[...,0]) \
				* (tf.pow(dxyzs[...,1], 4) - 12. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. * tf.pow(dxyzs[...,2], 4)))),
			(0.5826213625187314 * dxyzs[...,0] * dxyzs[...,2] * (5. * tf.pow(dxyzs[...,0], 4) + 5. \
				* tf.pow(dxyzs[...,1], 4) - 20. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 8. \
				* tf.pow(dxyzs[...,2], 4) + 10. * tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 2. \
				* tf.square(dxyzs[...,2])))),
			(0.4606026297574618 * (tf.square(dxyzs[...,0]) - 1. * tf.square(dxyzs[...,1])) * (tf.pow(dxyzs[...,0], 4) \
				+ tf.pow(dxyzs[...,1], 4) - 16. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. \
				* tf.pow(dxyzs[...,2], 4) + 2. * tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 8. \
				* tf.square(dxyzs[...,2])))),
			(-0.9212052595149236 * dxyzs[...,0] * (tf.square(dxyzs[...,0]) - 3. * tf.square(dxyzs[...,1])) * dxyzs[...,2] \
				* (3. * tf.square(dxyzs[...,0]) + 3. * tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2]))),
			(-0.5045649007287242 * (tf.pow(dxyzs[...,0], 4) - 6. * tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) \
				+ tf.pow(dxyzs[...,1], 4)) * (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2]))),
			(2.366619162231752 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 4) - 10. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + 5. * tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2]),
			(0.6831841051919143 * (tf.pow(dxyzs[...,0], 6) - 15. * tf.pow(dxyzs[...,0], 4) * tf.square(dxyzs[...,1]) \
				+ 15. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) - 1. * tf.pow(dxyzs[...,1], 6)))], axis=-1) \
			* tf.expand_dims(tf.pow(inv_dist_tensor,6),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l6_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l6_harmonics], axis=-1)

def tf_spherical_harmonics_7(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..7.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=7 polynomials are scaled by its 7th power.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 15 l=7 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_6(dxyzs, inv_dist_tensor, invariant)
	l7_harmonics = tf.stack([(-0.7071627325245962 * dxyzs[...,1] * (-7. * tf.pow(dxyzs[...,0], 6) + 35. \
				* tf.pow(dxyzs[...,0], 4) * tf.square(dxyzs[...,1]) - 21. * tf.square(dxyzs[...,0]) \
				* tf.pow(dxyzs[...,1], 4) + tf.pow(dxyzs[...,1], 6))),
			(-5.291921323603801 * dxyzs[...,0] * dxyzs[...,1] * (3. * tf.pow(dxyzs[...,0], 4) - 10. \
				* tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + 3. * tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2]),
			(-0.5189155787202604 * dxyzs[...,1] * (5. * tf.pow(dxyzs[...,0], 4) - 10. \
				* tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + tf.pow(dxyzs[...,1], 4)) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 12. * tf.square(dxyzs[...,2]))),
			(4.151324629762083 * dxyzs[...,0] * dxyzs[...,1] * (tf.square(dxyzs[...,0]) - 1. \
				* tf.square(dxyzs[...,1])) * dxyzs[...,2] * (3. * tf.square(dxyzs[...,0]) + 3. \
				* tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2]))),
			(-0.15645893386229404 * dxyzs[...,1] * (-3. * tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1])) \
				* (3. * tf.pow(dxyzs[...,0], 4) + 3. * tf.pow(dxyzs[...,1], 4) - 60. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 80. * tf.pow(dxyzs[...,2], 4) + 6. * tf.square(dxyzs[...,0]) \
				* (tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2])))),
			(-0.4425326924449826 * dxyzs[...,0] * dxyzs[...,1] * dxyzs[...,2] * (15. * tf.pow(dxyzs[...,0], 4) \
				+ 15. * tf.pow(dxyzs[...,1], 4) - 80. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 48. \
				* tf.pow(dxyzs[...,2], 4) + 10. * tf.square(dxyzs[...,0]) * (3. * tf.square(dxyzs[...,1]) - 8. \
				* tf.square(dxyzs[...,2])))),
			(-0.0903316075825173 * dxyzs[...,1] * (5. * tf.pow(dxyzs[...,0], 6) + 5. * tf.pow(dxyzs[...,1], 6) - 120. \
				* tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 240. * tf.square(dxyzs[...,1]) \
				* tf.pow(dxyzs[...,2], 4) - 64. * tf.pow(dxyzs[...,2], 6) + 15. * tf.pow(dxyzs[...,0], 4) \
				* (tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])) + 15. * tf.square(dxyzs[...,0]) \
				* (tf.pow(dxyzs[...,1], 4) - 16. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. \
				* tf.pow(dxyzs[...,2], 4)))),
			(0.06828427691200495 * dxyzs[...,2] * (-35. * tf.pow(dxyzs[...,0], 6) - 35. * tf.pow(dxyzs[...,1], 6) \
				+ 210. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) - 168. * tf.square(dxyzs[...,1]) \
				* tf.pow(dxyzs[...,2], 4) + 16. * tf.pow(dxyzs[...,2], 6) - 105. * tf.pow(dxyzs[...,0], 4) \
				* (tf.square(dxyzs[...,1]) - 2. * tf.square(dxyzs[...,2])) - 21. * tf.square(dxyzs[...,0]) \
				* (5. * tf.pow(dxyzs[...,1], 4) - 20. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) \
				+ 8. * tf.pow(dxyzs[...,2], 4)))),
			(-0.0903316075825173 * dxyzs[...,0] * (5. * tf.pow(dxyzs[...,0], 6) + 5. * tf.pow(dxyzs[...,1], 6) \
				- 120. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 240. * tf.square(dxyzs[...,1]) \
				* tf.pow(dxyzs[...,2], 4) - 64. * tf.pow(dxyzs[...,2], 6) + 15. * tf.pow(dxyzs[...,0], 4) \
				* (tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])) + 15. * tf.square(dxyzs[...,0]) \
				* (tf.pow(dxyzs[...,1], 4) - 16. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. \
				* tf.pow(dxyzs[...,2], 4)))),
			(0.2212663462224913 * (tf.square(dxyzs[...,0]) - 1. * tf.square(dxyzs[...,1])) * dxyzs[...,2] \
				* (15. * tf.pow(dxyzs[...,0], 4) + 15. * tf.pow(dxyzs[...,1], 4) - 80. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 48. * tf.pow(dxyzs[...,2], 4) + 10. * tf.square(dxyzs[...,0]) \
				* (3. * tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])))),
			(0.15645893386229404 * dxyzs[...,0] * (tf.square(dxyzs[...,0]) - 3. * tf.square(dxyzs[...,1])) \
				* (3. * tf.pow(dxyzs[...,0], 4) + 3. * tf.pow(dxyzs[...,1], 4) - 60. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 80. * tf.pow(dxyzs[...,2], 4) + 6. * tf.square(dxyzs[...,0]) \
				* (tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2])))),
			(-1.0378311574405208 * (tf.pow(dxyzs[...,0], 4) - 6. * tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) \
				+ tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2] * (3. * tf.square(dxyzs[...,0]) \
				+ 3. * tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2]))),
			(-0.5189155787202604 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 4) - 10. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + 5. * tf.pow(dxyzs[...,1], 4)) * (tf.square(dxyzs[...,0]) \
				+ tf.square(dxyzs[...,1]) - 12. * tf.square(dxyzs[...,2]))),
			(2.6459606618019005 * (tf.pow(dxyzs[...,0], 6) - 15. * tf.pow(dxyzs[...,0], 4) * tf.square(dxyzs[...,1]) \
				+ 15. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) - 1. * tf.pow(dxyzs[...,1], 6)) * dxyzs[...,2]),
			(0.7071627325245962 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 6) - 21. * tf.pow(dxyzs[...,0], 4) \
				* tf.square(dxyzs[...,1]) + 35. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) - 7. \
				* tf.pow(dxyzs[...,1], 6)))], axis=-1) \
			* tf.expand_dims(tf.pow(inv_dist_tensor,7),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l7_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l7_harmonics], axis=-1)

def tf_spherical_harmonics_8(dxyzs, inv_dist_tensor, invariant=False):
	"""
	Spherical harmonics for l = 0..8.

	Args:
		dxyzs: (...) X 3 tensor of displacement vectors.
		inv_dist_tensor: (...) tensor; the l=8 polynomials are scaled by its 8th power.
		invariant: if True, each order l >= 1 is reduced to the single
			euclidean norm of its components.
	Returns:
		Lower-order output concatenated with the 17 l=8 components
		(or with their single norm when invariant).
	"""
	lower_order_harmonics = tf_spherical_harmonics_7(dxyzs, inv_dist_tensor, invariant)
	l8_harmonics = tf.stack([(-5.831413281398639 * dxyzs[...,0] * dxyzs[...,1] * (tf.pow(dxyzs[...,0], 6) \
				- 7. * tf.pow(dxyzs[...,0], 4) * tf.square(dxyzs[...,1]) + 7. * tf.square(dxyzs[...,0]) \
				* tf.pow(dxyzs[...,1], 4) - 1. * tf.pow(dxyzs[...,1], 6))),
			(-2.9157066406993195 * dxyzs[...,1] * (-7. * tf.pow(dxyzs[...,0], 6) + 35. * tf.pow(dxyzs[...,0], 4) \
				* tf.square(dxyzs[...,1]) - 21. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) \
				+ tf.pow(dxyzs[...,1], 6)) * dxyzs[...,2]),
			(1.0646655321190852 * dxyzs[...,0] * dxyzs[...,1] * (3. * tf.pow(dxyzs[...,0], 4) - 10. \
				* tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) + 3. * tf.pow(dxyzs[...,1], 4)) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 14. * tf.square(dxyzs[...,2]))),
			(-3.449910622098108 * dxyzs[...,1] * (5. * tf.pow(dxyzs[...,0], 4) - 10. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2] * (tf.square(dxyzs[...,0]) \
				+ tf.square(dxyzs[...,1]) - 4. * tf.square(dxyzs[...,2]))),
			(-1.9136660990373227 * dxyzs[...,0] * dxyzs[...,1] * (tf.square(dxyzs[...,0]) - 1. \
				* tf.square(dxyzs[...,1])) * (tf.pow(dxyzs[...,0], 4) + tf.pow(dxyzs[...,1], 4) - 24. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 40. * tf.pow(dxyzs[...,2], 4) + 2. \
				* tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 12. * tf.square(dxyzs[...,2])))),
			(-1.2352661552955442 * dxyzs[...,1] * (-3. * tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1])) \
				* dxyzs[...,2] * (3. * tf.pow(dxyzs[...,0], 4) + 3. * tf.pow(dxyzs[...,1], 4) - 20. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. * tf.pow(dxyzs[...,2], 4) \
				+ tf.square(dxyzs[...,0]) * (6. * tf.square(dxyzs[...,1]) - 20. * tf.square(dxyzs[...,2])))),
			(0.912304516869819 * dxyzs[...,0] * dxyzs[...,1] * (tf.pow(dxyzs[...,0], 6) + tf.pow(dxyzs[...,1], 6) \
				- 30. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 80. * tf.square(dxyzs[...,1]) \
				* tf.pow(dxyzs[...,2], 4) - 32. * tf.pow(dxyzs[...,2], 6) + 3. * tf.pow(dxyzs[...,0], 4) \
				* (tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2])) + tf.square(dxyzs[...,0]) \
				* (3. * tf.pow(dxyzs[...,1], 4) - 60. * tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 80. \
				* tf.pow(dxyzs[...,2], 4)))),
			(-0.10904124589877995 * dxyzs[...,1] * dxyzs[...,2] * (35. * tf.pow(dxyzs[...,0], 6) + 35. \
				* tf.pow(dxyzs[...,1], 6) - 280. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 336. \
				* tf.square(dxyzs[...,1]) * tf.pow(dxyzs[...,2], 4) - 64. * tf.pow(dxyzs[...,2], 6) + 35. \
				* tf.pow(dxyzs[...,0], 4) * (3. * tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])) + 7. \
				* tf.square(dxyzs[...,0]) * (15. * tf.pow(dxyzs[...,1], 4) - 80. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 48. * tf.pow(dxyzs[...,2], 4)))),
			(0.009086770491564996 * (35. * tf.pow(dxyzs[...,0], 8) + 35. * tf.pow(dxyzs[...,1], 8) - 1120. \
				* tf.pow(dxyzs[...,1], 6) * tf.square(dxyzs[...,2]) + 3360. * tf.pow(dxyzs[...,1], 4) \
				* tf.pow(dxyzs[...,2], 4) - 1792. * tf.square(dxyzs[...,1]) * tf.pow(dxyzs[...,2], 6) + 128. \
				* tf.pow(dxyzs[...,2], 8) + 140. * tf.pow(dxyzs[...,0], 6) * (tf.square(dxyzs[...,1]) - 8. \
				* tf.square(dxyzs[...,2])) + 210. * tf.pow(dxyzs[...,0], 4) * (tf.pow(dxyzs[...,1], 4) - 16. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. * tf.pow(dxyzs[...,2], 4)) + 28. \
				* tf.square(dxyzs[...,0]) * (5. * tf.pow(dxyzs[...,1], 6) - 120. * tf.pow(dxyzs[...,1], 4) \
				* tf.square(dxyzs[...,2]) + 240. * tf.square(dxyzs[...,1]) * tf.pow(dxyzs[...,2], 4) - 64. \
				* tf.pow(dxyzs[...,2], 6)))),
			(-0.10904124589877995 * dxyzs[...,0] * dxyzs[...,2] * (35. * tf.pow(dxyzs[...,0], 6) + 35. \
				* tf.pow(dxyzs[...,1], 6) - 280. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 336. \
				* tf.square(dxyzs[...,1]) * tf.pow(dxyzs[...,2], 4) - 64. * tf.pow(dxyzs[...,2], 6) + 35. \
				* tf.pow(dxyzs[...,0], 4) * (3. * tf.square(dxyzs[...,1]) - 8. * tf.square(dxyzs[...,2])) + 7. \
				* tf.square(dxyzs[...,0]) * (15. * tf.pow(dxyzs[...,1], 4) - 80. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 48. * tf.pow(dxyzs[...,2], 4)))),
			(-0.4561522584349095 * (tf.square(dxyzs[...,0]) - 1. * tf.square(dxyzs[...,1])) * (tf.pow(dxyzs[...,0], 6) \
				+ tf.pow(dxyzs[...,1], 6) - 30. * tf.pow(dxyzs[...,1], 4) * tf.square(dxyzs[...,2]) + 80. \
				* tf.square(dxyzs[...,1]) * tf.pow(dxyzs[...,2], 4) - 32. * tf.pow(dxyzs[...,2], 6) + 3. \
				* tf.pow(dxyzs[...,0], 4) * (tf.square(dxyzs[...,1]) - 10. * tf.square(dxyzs[...,2])) \
				+ tf.square(dxyzs[...,0]) * (3. * tf.pow(dxyzs[...,1], 4) - 60. * tf.square(dxyzs[...,1]) \
				* tf.square(dxyzs[...,2]) + 80. * tf.pow(dxyzs[...,2], 4)))),
			(1.2352661552955442 * dxyzs[...,0] * (tf.square(dxyzs[...,0]) - 3. * tf.square(dxyzs[...,1])) \
				* dxyzs[...,2] * (3. * tf.pow(dxyzs[...,0], 4) + 3. * tf.pow(dxyzs[...,1], 4) - 20. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 16. * tf.pow(dxyzs[...,2], 4) \
				+ tf.square(dxyzs[...,0]) * (6. * tf.square(dxyzs[...,1]) - 20. * tf.square(dxyzs[...,2])))),
			(0.47841652475933066 * (tf.pow(dxyzs[...,0], 4) - 6. * tf.square(dxyzs[...,0]) * tf.square(dxyzs[...,1]) \
				+ tf.pow(dxyzs[...,1], 4)) * (tf.pow(dxyzs[...,0], 4) + tf.pow(dxyzs[...,1], 4) - 24. \
				* tf.square(dxyzs[...,1]) * tf.square(dxyzs[...,2]) + 40. * tf.pow(dxyzs[...,2], 4) + 2. \
				* tf.square(dxyzs[...,0]) * (tf.square(dxyzs[...,1]) - 12. * tf.square(dxyzs[...,2])))),
			(-3.449910622098108 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 4) - 10. * tf.square(dxyzs[...,0]) \
				* tf.square(dxyzs[...,1]) + 5. * tf.pow(dxyzs[...,1], 4)) * dxyzs[...,2] * (tf.square(dxyzs[...,0]) \
				+ tf.square(dxyzs[...,1]) - 4. * tf.square(dxyzs[...,2]))),
			(-0.5323327660595426 * (tf.pow(dxyzs[...,0], 6) - 15. * tf.pow(dxyzs[...,0], 4) * tf.square(dxyzs[...,1]) \
				+ 15. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) - 1. * tf.pow(dxyzs[...,1], 6)) \
				* (tf.square(dxyzs[...,0]) + tf.square(dxyzs[...,1]) - 14. * tf.square(dxyzs[...,2]))),
			(2.9157066406993195 * dxyzs[...,0] * (tf.pow(dxyzs[...,0], 6) - 21. * tf.pow(dxyzs[...,0], 4) \
				* tf.square(dxyzs[...,1]) + 35. * tf.square(dxyzs[...,0]) * tf.pow(dxyzs[...,1], 4) - 7. \
				* tf.pow(dxyzs[...,1], 6)) * dxyzs[...,2]),
			(0.7289266601748299 * (tf.pow(dxyzs[...,0], 8) - 28. * tf.pow(dxyzs[...,0], 6) * tf.square(dxyzs[...,1]) \
				+ 70. * tf.pow(dxyzs[...,0], 4) * tf.pow(dxyzs[...,1], 4) - 28. * tf.square(dxyzs[...,0]) \
				* tf.pow(dxyzs[...,1], 6) + tf.pow(dxyzs[...,1], 8)))], axis=-1) \
			* tf.expand_dims(tf.pow(inv_dist_tensor,8),axis=-1)
	if invariant:
		# The 1e-16 shift keeps the norm away from exactly zero.
		return tf.concat([lower_order_harmonics, tf.norm(l8_harmonics+1.e-16, axis=-1, keepdims=True)], axis=-1)
	return tf.concat([lower_order_harmonics, l8_harmonics], axis=-1)

def tf_spherical_harmonics(dxyzs, dist_tensor, max_l, invariant=False):
	"""
	Evaluate all spherical harmonics up to order max_l for a set of displacements.

	Args:
		dxyzs: (...) X MaxNAtom X MaxNAtom X 3 (differenced from center of embedding
				ie: ... X i X i = (0.,0.,0.))
		dist_tensor: just tf.norm of the above.
		max_l : integer, maximum angular momentum (0..8).
		invariant: whether to return just total angular momentum of a given l.
	Returns:
		(...) X MaxNAtom X MaxNAtom X {NSH = (max_l+1)^2}
		(NSH = max_l+1 when invariant is set and max_l >= 1).
	Raises:
		Exception: if max_l is not one of 0..8.
	"""
	# Distances at or below 1e-9 get an inverse of 0. The reciprocal is taken of a
	# dummy 1.0 there so the unselected branch cannot inject inf/NaN into gradients.
	nonzero = tf.greater(dist_tensor, 1.e-9)
	safe_dist = tf.where(nonzero, dist_tensor, tf.ones_like(dist_tensor))
	inv_dist_tensor = tf.where(nonzero, tf.reciprocal(safe_dist), tf.zeros_like(dist_tensor))
	if max_l == 0:
		# Add the trailing NSH axis here, exactly as the higher orders do for their l=0 part.
		return tf_spherical_harmonics_0(tf.expand_dims(inv_dist_tensor, axis=-1))
	higher_orders = (tf_spherical_harmonics_1, tf_spherical_harmonics_2, tf_spherical_harmonics_3,
			tf_spherical_harmonics_4, tf_spherical_harmonics_5, tf_spherical_harmonics_6,
			tf_spherical_harmonics_7, tf_spherical_harmonics_8)
	for order, build in enumerate(higher_orders, 1):
		if max_l == order:
			return build(dxyzs, inv_dist_tensor, invariant)
	raise Exception("Spherical Harmonics only implemented up to l=8. Choose a lower order")