In [None]:
from se3cnn.SO3 import basis_transformation_Q_J
from sph_projection_utils import *
# from equivariant_point_autoencoder.utils import *
import torch
import numpy as np
from lie_vae.lie_tools import block_wigner_matrix_multiply
%matplotlib inline
import matplotlib.pyplot as plt
import sys
import math

In [None]:
# plotly provides the interactive figures displayed with plotly.offline.iplot in later cells.
import plotly
from plotly.graph_objs import *
from plotly import tools
# Initialise plotly's offline notebook mode before any iplot call.
plotly.offline.init_notebook_mode(connected=False)
# NOTE(review): `py` is not referenced in the cells visible here, and `plotly.plotly`
# presumably only exists in older plotly releases — confirm the pinned plotly version.
import plotly.plotly as py

In [None]:
import scipy
def random_rotation_matrix(numpy_random_state, eps=1e-8):
    """
    Draws a random 3D rotation matrix by sampling an axis and an angle.

    The axis is a standard-normal 3-vector divided by (its norm + eps); the
    angle is 2*pi times a uniform sample from `uniform(0.0, 1.0)`.
    Args:
        numpy_random_state: numpy random state object (its `randn` and
            `uniform` methods are used, in that order)
        eps: small constant added to the axis norm before dividing
    Returns:
        The matrix returned by `rotation_matrix(axis, angle)`.
    """
    direction = numpy_random_state.randn(3)
    direction = direction / (np.linalg.norm(direction) + eps)
    angle = 2 * np.pi * numpy_random_state.uniform(0.0, 1.0)
    return rotation_matrix(direction, angle)

def rotation_matrix(axis, theta):
    """
    Builds the 3D rotation matrix for a rotation of `theta` about `axis`.

    The result is the matrix exponential of the cross-product (skew-symmetric)
    matrix of `axis * theta`. For a unit-length `axis` this is a right-handed
    rotation by `theta` radians; for a non-unit axis the effective angle is
    `theta` times the axis length.
    Args:
        axis: length-3 array-like (numpy array, list or tuple)
        theta: rotation angle
    Returns:
        3x3 numpy array.
    """
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed
    # to make `scipy.linalg` available as an attribute.
    from scipy.linalg import expm

    # Convert first so that plain lists/tuples can be scaled by theta.
    rotation_vector = np.asarray(axis, dtype=float) * theta
    # np.cross(I, v) has rows e_i x v, which is the skew-symmetric matrix of v.
    return expm(np.cross(np.eye(3), rotation_vector))

In [None]:
# Four alternating corners of the unit cube, shifted so the cube centre is the origin.
tetrahedron = np.array([
    [0., 0., 0.],
    [1., 1., 0.],
    [1., 0., 1.],
    [0., 1., 1.],
]) - 0.5

# Randomly rotated copy of the tetrahedron (seed 42). Note that it is built
# from the shifted vertices *before* they are scaled to unit length below.
rnd = np.random.RandomState(42)
rnd_rot = random_rotation_matrix(rnd)
rot_tetrahedron = np.einsum('xy,bx->by', rnd_rot, tetrahedron)

# The six face centres of the unit cube, shifted the same way.
octahedron = np.array([
    [0.5, 0.5, 0.0],
    [0.5, 0.5, 1.0],
    [0.0, 0.5, 0.5],
    [1.0, 0.5, 0.5],
    [0.5, 0.0, 0.5],
    [0.5, 1.0, 0.5],
]) - 0.5

# Scale every vertex to unit length.
tetrahedron = tetrahedron / np.linalg.norm(tetrahedron, axis=-1, keepdims=True)
octahedron = octahedron / np.linalg.norm(octahedron, axis=-1, keepdims=True)

In [None]:
# Unit vector along the (1, 1, 1) diagonal and the unit vector along z.
vector_1 = np.array([1., 1., 1.])
vector_1 = vector_1 / np.linalg.norm(vector_1)
vector_2 = np.array([0., 0., 1.])

# The cross product of two unit vectors is perpendicular to both and its
# length is the sine of the angle between them.
axis = np.cross(vector_1, vector_2)
norm_axis = np.linalg.norm(axis)
theta = np.arcsin(norm_axis)  # in radians; divide by 2 * np.pi for a fraction of a turn

In [None]:
# Rotation by `theta` about the normalised cross-product axis computed above.
rot_oct_matrix = rotation_matrix(axis=axis / norm_axis, theta=theta)

In [None]:
# Right-multiply every vertex (as a row vector) by the rotation matrix.
rotated_octahedron = np.einsum('na,ab->nb', octahedron, rot_oct_matrix)

In [None]:
# Split the vertices by the sign of their z coordinate (z == 0 goes to `lower`).
upper = rotated_octahedron[rotated_octahedron[:, 2] > 0]
lower = rotated_octahedron[rotated_octahedron[:, 2] <= 0]

In [None]:
# Quarter-turn rotations about the z axis, one for each sense of rotation.
twist, negative_twist = (
    rotation_matrix(np.array([0., 0., 1.]), sign * np.pi / 2.)
    for sign in (1., -1.)
)

In [None]:
# Eighth-turn (pi / 4) rotations about the z axis, one for each sense of rotation.
intermediate_twist, negative_intermediate_twist = (
    rotation_matrix(np.array([0., 0., 1.]), sign * np.pi / 4.)
    for sign in (1., -1.)
)

In [None]:
# Twist the two halves in opposite senses by a quarter turn each.
trig_upper, trig_lower = (
    np.einsum('ix,xy->iy', half, spin)
    for half, spin in ((upper, twist), (lower, negative_twist))
)

In [None]:
# Twist the two halves in opposite senses by pi / 4 each.
inter_upper, inter_lower = (
    np.einsum('ix,xy->iy', half, spin)
    for half, spin in ((upper, intermediate_twist), (lower, negative_intermediate_twist))
)

In [None]:
# Stack the two twisted halves into one vertex array and display its shape.
trigonal_prism = np.concatenate([trig_upper, trig_lower])
trigonal_prism.shape

In [None]:
# Vertex array for the pi / 4 twisted halves, upper rows first.
inter = np.concatenate([inter_upper, inter_lower])

In [None]:
L_max = 8

# For each vertex set, convert the points with xyz_to_phi_theta and pass the
# result to get_Ylm_coeffs (presumably spherical-harmonic coefficients up to
# degree L_max — confirm in sph_projection_utils).
# `map` is lazy, so each shape is converted and projected before the next one
# is touched, in the order octahedron, rotated octahedron, trigonal prism.
(
    (oct_angles, oct_coeffs),
    (rot_oct_angles, rot_oct_coeffs),
    (trig_prism_angles, trig_prism_coeffs),
) = (
    (angles, get_Ylm_coeffs(*angles, L_max=L_max))
    for angles in map(xyz_to_phi_theta, (octahedron, rotated_octahedron, trigonal_prism))
)

In [None]:
def oscillate(theta):
    """Blend the two coefficient sets: sin(theta) * rotated octahedron + cos(theta) * trigonal prism."""
    return np.sin(theta) * rot_oct_coeffs + np.cos(theta) * trig_prism_coeffs

In [None]:
def oscillate_signals(signal_1, signal_2, eps=1e-11):
    """
    Builds a function of `theta` that sweeps between two signals.

    An element-wise shared term is formed as the mean of
    (signal_1 * signal_2) / (signal_1 + eps) and
    (signal_1 * signal_2) / (signal_2 + eps). It is subtracted from each
    signal, the two residues are mixed with cos(theta) and sin(theta), and the
    shared term is added back.
    Args:
        signal_1: first signal (weighted by cos(theta))
        signal_2: second signal (weighted by sin(theta))
        eps: small constant added to each denominator
    Returns:
        A callable taking `theta` and returning the mixed signal.
    """
    product = signal_1 * signal_2
    shared = (product / (signal_1 + eps) + product / (signal_2 + eps)) / 2.
    residue_1 = signal_1 - shared
    residue_2 = signal_2 - shared

    def blend(theta):
        return np.cos(theta) * residue_1 + np.sin(theta) * residue_2 + shared

    return blend

In [None]:
# Element-wise term shared by both coefficient sets (product divided by the prism coefficients).
parallel = rot_oct_coeffs * trig_prism_coeffs / (trig_prism_coeffs + 1e-8)
# What is left of each coefficient set after removing its own product-based term.
rot_oct_perp = rot_oct_coeffs - parallel
trig_perp = trig_prism_coeffs - (rot_oct_coeffs * trig_prism_coeffs) / (rot_oct_coeffs + 1e-8)


def oscillate_perp(theta):
    """Mix the two residues with sin(theta) / cos(theta) and add `parallel` back."""
    return np.sin(theta) * rot_oct_perp + np.cos(theta) * trig_perp + parallel

In [None]:
# Coefficient vectors of the plain sin/cos blend, sampled over a full turn.
plot_coeffs = [oscillate(theta) for theta in np.linspace(0, 2 * np.pi, 10)]

# Coefficient vectors of the oscillate_signals blends, sampled over half a turn.
# Each interpolator is built once and mapped over the angles; previously
# oscillate_signals(...) was re-evaluated for every single angle although its
# arguments never change inside the loop.
plot_perp_coeffs = list(map(oscillate_signals(rot_oct_coeffs, trig_prism_coeffs),
                            np.linspace(0, np.pi, 20)))
plot_rot_perp_coeffs = list(map(oscillate_signals(oct_coeffs, trig_prism_coeffs),
                                np.linspace(0, np.pi, 20)))

In [None]:
# Build the figure for the plain sin/cos blend.
# cmin/cmax presumably bound the colour scale — confirm in sph_projection_utils.
fig = visualize_coeff_series(
    plot_coeffs,
    L_max=L_max,
    num_angular_points=50,
    cmin=-5,
    cmax=5,
)

In [None]:
# Build the figure for the rotated-octahedron / trigonal-prism sweep.
# cmin/cmax presumably bound the colour scale — confirm in sph_projection_utils.
fig_perp = visualize_coeff_series(
    plot_perp_coeffs,
    L_max=L_max,
    num_angular_points=50,
    cmin=-5,
    cmax=5,
)

In [None]:
# Build the figure for the (unrotated) octahedron / trigonal-prism sweep.
# cmin/cmax presumably bound the colour scale — confirm in sph_projection_utils.
fig_rot_perp = visualize_coeff_series(
    plot_rot_perp_coeffs,
    L_max=L_max,
    num_angular_points=50,
    cmin=-5,
    cmax=5,
)

In [None]:
# Show the octahedron / trigonal-prism sweep figure inline via plotly's offline mode.
plotly.offline.iplot(fig_rot_perp)

In [None]:
# Show the rotated-octahedron / trigonal-prism sweep figure inline via plotly's offline mode.
plotly.offline.iplot(fig_perp)

In [None]:
# One panel per degree L; each panel draws one curve per order M of the
# rotated-octahedron / trigonal-prism sweep, against the sample index.
array = np.array(plot_perp_coeffs)
fig, axes = plt.subplots(1, L_max + 1, figsize=(10,5))
for L, ax in enumerate(axes):
    # Degree L occupies columns L**2 .. (L + 1)**2 - 1 of the coefficient array.
    first, stop = L**2, (L + 1)**2
    print(first, stop)
    for column, M in zip(range(first, stop), range(-L, L + 1)):
        ax.plot(array[:, column:column + 1], label="M={}".format(M))
    ax.set_ylim([-3, 3])
    ax.set_title("L={}".format(L))
    ax.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.05),
        fancybox=True,
        shadow=True,
        ncol=1,
    )

In [None]:
# One panel per degree L; each panel draws one curve per order M of the
# octahedron / trigonal-prism sweep, against the sample index.
array = np.array(plot_rot_perp_coeffs)
fig, axes = plt.subplots(1, L_max + 1, figsize=(10,5))
for L, ax in enumerate(axes):
    # Degree L occupies columns L**2 .. (L + 1)**2 - 1 of the coefficient array.
    first, stop = L**2, (L + 1)**2
    print(first, stop)
    for column, M in zip(range(first, stop), range(-L, L + 1)):
        ax.plot(array[:, column:column + 1], label="M={}".format(M))
    ax.set_ylim([-3, 3])
    ax.set_title("L={}".format(L))
    ax.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.05),
        fancybox=True,
        shadow=True,
        ncol=1,
    )

In [None]:
# Coefficients as a function of twist: the upper and lower halves are rotated
# about z by +angle / -angle, and the coefficients of the upper half, the lower
# half and the combined shape are collected for every angle.
L_max = 6
coeffs, upper_coeffs, lower_coeffs = [], [], []
z_axis = np.array([0., 0., 1.])
for angle in np.linspace(0, np.pi / 3., 50):
    new_upper = np.einsum('ix,xy->iy', upper, rotation_matrix(z_axis, angle))
    new_lower = np.einsum('ix,xy->iy', lower, rotation_matrix(z_axis, -angle))
    new_shape = np.concatenate((new_upper, new_lower), axis=0)

    # Same projection for all three point sets, appended to the matching list.
    for points, series in ((new_upper, upper_coeffs),
                           (new_lower, lower_coeffs),
                           (new_shape, coeffs)):
        series.append(get_Ylm_coeffs(*xyz_to_phi_theta(points), L_max=L_max))

In [None]:
# Coefficients as a function of twist, combined with a second rotation about
# `axis` whose angle goes from -theta (first snapshot) to 0 (last snapshot).
L_max = 6
num_snaps = 50
rot_coeffs = []
z_axis = np.array([0., 0., 1.])
unit_axis = axis / norm_axis
for i, angle in enumerate(np.linspace(0, np.pi / 6., num_snaps)):
    new_upper = np.einsum('ix,xy->iy', upper, rotation_matrix(z_axis, angle))
    new_lower = np.einsum('ix,xy->iy', lower, rotation_matrix(z_axis, -angle))
    twisted = np.concatenate((new_upper, new_lower), axis=0)

    # Fraction of -theta still applied: -1 at i == 0, 0 at i == num_snaps - 1.
    remaining = -(num_snaps - 1 - i) / (num_snaps - 1)
    rot_mat = rotation_matrix(unit_axis, remaining * theta)
    new_shape = np.einsum('ix,xy->iy', twisted, rot_mat)

    rot_coeffs.append(get_Ylm_coeffs(*xyz_to_phi_theta(new_shape), L_max=L_max))

In [None]:
# Coefficients along the octahedral-to-trigonal distortion (twist plus rotation):
# one panel per degree L, one curve per order M, against the snapshot index.
np_rot_coeffs = np.array(rot_coeffs)
fig, axes = plt.subplots(1, L_max + 1, figsize=(10,5))
for L, ax in enumerate(axes):
    # Degree L occupies columns L**2 .. (L + 1)**2 - 1 of the coefficient array.
    first, stop = L**2, (L + 1)**2
    print(first, stop)
    for column, M in zip(range(first, stop), range(-L, L + 1)):
        ax.plot(np_rot_coeffs[:, column:column + 1], label="M={}".format(M))
    ax.set_ylim([-3, 3])
    ax.set_title("L={}".format(L))
    ax.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.05),
        fancybox=True,
        shadow=True,
        ncol=1,
    )

In [None]:
# Build the figure for the twist-plus-rotation snapshots collected in `rot_coeffs`.
# cmin/cmax presumably bound the colour scale — confirm in sph_projection_utils.
fig_rot_perp_analytic = visualize_coeff_series(
    rot_coeffs,
    L_max=6,
    num_angular_points=50,
    cmin=-5,
    cmax=5,
)

In [None]:
# Show the twist-plus-rotation figure inline via plotly's offline mode.
plotly.offline.iplot(fig_rot_perp_analytic)

In [None]:
# Coefficients along the octahedral-to-trigonal distortion (pure twist):
# one panel per degree L, one curve per order M, against the snapshot index.
np_coeffs = np.array(coeffs)
fig, axes = plt.subplots(1, L_max + 1, figsize=(10,5))
for L, ax in enumerate(axes):
    # Degree L occupies columns L**2 .. (L + 1)**2 - 1 of the coefficient array.
    first, stop = L**2, (L + 1)**2
    print(first, stop)
    for column, M in zip(range(first, stop), range(-L, L + 1)):
        ax.plot(np_coeffs[:, column:column + 1], label="M={}".format(M))
    ax.set_ylim([-3, 3])
    ax.set_title("L={}".format(L))
    ax.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.05),
        fancybox=True,
        shadow=True,
        ncol=1,
    )