<p style="font-size:32px; font-weight: bolder; text-align: center"> Rotations, equivariance and <br/> symmetry-adapted regression </p>
<p style="text-align: center"><i> authored by: <a href="mailto:michele.ceriotti@gmail.com"> Michele Ceriotti </a></i></p>

This notebook discusses the concept of equivariance, with a specific focus on the rotation group. We will learn about Cartesian rotations, spherical harmonics, and Wigner matrices.
We will see how these concepts apply to some of the equivariant descriptors used in machine learning, and how it is possible to build simple regression models that yield rotationally equivariant predictions for vectorial or tensorial properties.

### Packages and dependencies

This module uses some utility functions from `scipy` and `spherical` to handle rotations, and `rascaline` to compute descriptors. 

In [None]:
%matplotlib widget
# scwidgets import
import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as mplot3d
import chemiscope

import ipywidgets
from ipywidgets import FloatSlider, IntSlider, Checkbox, Dropdown, HBox, Layout, HTML, Text

from markdown import markdown as mdwn

import scwidgets
from scwidgets.check import (
    Check,
    CheckRegistry,
    assert_numpy_allclose,
    assert_numpy_floating_sub_dtype,
    assert_shape,
    assert_type,
)
from scwidgets.code import ParameterPanel, CodeInput
from scwidgets.cue import CueObject, CueFigure
from scwidgets.exercise import CodeExercise, TextExercise, ExerciseRegistry

In [None]:
import numpy as np
import ase, ase.io
import itertools
from copy import deepcopy
from tqdm.notebook import tqdm

import rascaline
from metatensor import mean_over_samples, Labels, TensorMap, TensorBlock, slice_block

from sklearn.decomposition import PCA
from sklearn.linear_model import RidgeCV

from scipy.spatial.transform import Rotation
from spherical import wigner_D
import quaternionic

In [None]:
def _real2complex(L):
    """
    Computes a matrix that can be used to convert from real to complex-valued
    spherical harmonics (coefficients) of order L.

    It's meant to be applied to the left, ``real2complex @ [-L..L]``.
    """
    result = np.zeros((2 * L + 1, 2 * L + 1), dtype=np.complex128)

    I_SQRT_2 = 1.0 / np.sqrt(2)

    for m in range(-L, L + 1):
        if m < 0:
            result[L - m, L + m] = I_SQRT_2 * 1j * (-1) ** m
            result[L + m, L + m] = -I_SQRT_2 * 1j

        if m == 0:
            result[L, L] = 1.0

        if m > 0:
            result[L + m, L + m] = I_SQRT_2 * (-1) ** m
            result[L - m, L + m] = I_SQRT_2

    return result
    
def rotation_matrix(alpha, beta, gamma):
    """A Cartesian rotation matrix in the appropriate convention
    (ZYZ, intrinsic rotations) to be consistent with the common Wigner D definition.
    (alpha, beta, gamma) are Euler angles (radians)."""
    return Rotation.from_euler("ZYZ", [alpha, beta, gamma]).as_matrix()

def wigner_d_real(l, alpha, beta, gamma):
    """Computes a real-valued Wigner D matrix
     D^l_{mm'}(alpha, beta, gamma)
    (alpha, beta, gamma) are Euler angles (radians, ZYZ convention) and l the irrep.
    Rotates real spherical harmonics by application from the left.
    """

    R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
    wD = wigner_D(R_euler,0,l)[((4*l*l-1)*l)//3:].reshape(2*l+1,-1).conj()
    r2c = _real2complex(l)
    
    return np.real(np.conjugate(r2c.T@wD)@r2c)

In [None]:
# set CSS style for code-hide
scwidgets.get_css_style()

In [None]:
exercise_registry = ExerciseRegistry(filename_prefix="module_03")
exercise_registry

In [None]:
check_registry = CheckRegistry()
check_registry

In [None]:
module_summary = TextExercise(
    exercise_description="""You can use this box to make general considerations, 
    or keep track of your doubts and questions about this notebook.""",
    exercise_registry=exercise_registry,
    exercise_title="Module comments",
    exercise_key="00"
)
display(module_summary)

# The rotation group 

Rotations describe the changes in the orientation in space of a rigid body relative to a fixed coordinate system. The mathematical description of rotations is notoriously tedious, with a plethora of different conventions that are often applied inconsistently in different works. 
If you are the kind of person who enjoys this stuff, this [Wikipedia article](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions) provides a comprehensive overview. 

In this exercise we are going to define rotations in terms of [Euler angles](https://en.wikipedia.org/wiki/Euler_angles) in the so-called ZYZ convention, in which the rotation is identified by three angles $(\alpha, \beta, \gamma)$ where $\alpha$ and $\gamma$ are periodic and can be chosen in the interval $[-\pi,\pi]$, and $\beta$ in the interval $[0,\pi]$.
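
As a concrete reference, the ZYZ convention can be written out as a product of three elementary rotations. The sketch below assumes the intrinsic ZYZ convention as implemented by `scipy` (the call used by the `rotation_matrix` helper in this notebook); the helper names `rot_z` and `rot_y` are introduced here only for illustration.

```python
import numpy as np
from scipy.spatial.transform import Rotation


def rot_z(angle):
    """Elementary rotation by `angle` (radians) about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(angle):
    """Elementary rotation by `angle` (radians) about the y axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


# arbitrary test angles, chosen within the ranges given above
alpha, beta, gamma = 0.3, 1.1, -0.7

# intrinsic ZYZ: the full matrix is the product Rz(alpha) Ry(beta) Rz(gamma)
R_manual = rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
R_scipy = Rotation.from_euler("ZYZ", [alpha, beta, gamma]).as_matrix()

zyz_matches = np.allclose(R_manual, R_scipy)
is_orthogonal = np.allclose(R_manual @ R_manual.T, np.eye(3))
is_proper = np.isclose(np.linalg.det(R_manual), 1.0)
print(zyz_matches, is_orthogonal, is_proper)
```

The last two checks are the defining properties of a rotation matrix: it is orthogonal and has unit determinant.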

To get a grasp of what Euler angles do, and why you need to define three angles to properly characterize the orientation of a structure, you can play around with the following visualization.

In [None]:
ex01_pb =  ParameterPanel(
    alpha = FloatSlider(value=0,min=-np.pi,max=np.pi,step=0.01,description=r'$\alpha$'),
    beta = FloatSlider(value=0,min=0,max=np.pi,step=0.01,description=r'$\beta$'),
    gamma = FloatSlider(value=0,min=-np.pi,max=np.pi,step=0.01,description=r'$\gamma$'))

In [None]:
ex01_fig = plt.figure(tight_layout=True)
ax01 = ex01_fig.add_subplot(111, projection='3d')
ex01_cuefig = CueFigure(ex01_fig) 

theta = np.linspace(0, 2 * np.pi, 20)
w = np.linspace(-0.5, 0.5, 10)
theta, w = np.meshgrid(theta, w)
R = 1
x = (R + w * np.cos(theta / 2)) * np.cos(theta)
y = (R + w * np.cos(theta / 2)) * np.sin(theta)
z = w * np.sin(theta / 2)

ex01_xyz =  np.array([x,y,z]).T

def update_01(code_exercise):    
    alpha, beta, gamma = code_exercise.parameters.values()
    cue_figure = code_exercise.cue_outputs[0]
    ax = cue_figure.figure.get_axes()[0]
    rot = rotation_matrix(alpha,beta,gamma)
    (x,y,z) = (ex01_xyz@rot.T).T
    ax.set_xlim([-2,2])
    ax.set_ylim([-2,2])
    ax.set_zlim([-2,2])
    dax = 2*np.eye(3)@rot.T    
    ax.quiver(0,0,0,*(dax[0]),color='r')
    ax.quiver(0,0,0,*(dax[1]),color='g')
    ax.quiver(0,0,0,*(dax[2]),color='b')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.plot_surface(x, y, z, color='gray')
    
    ax.set_aspect('auto')
    cue_figure.figure.subplots_adjust(left=0.0, right=1, top=1, bottom=0.0)

ce01 = CodeExercise(
            parameters=ex01_pb,
            cue_outputs = [ex01_cuefig],
            update_func = update_01,
            update_mode="continuous")

display(ce01)
ce01.run_update()

## Cartesian rotations

In practical terms, a rotation operator $\hat{R}$ parameterized by the Euler angles acts on a 3D object by applying the corresponding rotation matrix $\mathbf{R}$ to the Cartesian coordinates of all its points: if a structure $A$ has atomic positions $\mathbf{r}_i$ (each of these being a 3-vector corresponding to the Cartesian coordinates $(x,y,z)$) then the rotated structure $\hat{R}A$ has atomic coordinates $\mathbf{R}\mathbf{r}_i$. The same transformation is applied to all the properties $\mathbf{y}$ of $A$ that have a vectorial character, e.g. the dipole moment - so that the dipole of $\hat{R}A$ is $\mathbf{R}\mathbf{y}$. 
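
The statement that vectorial properties rotate together with the structure can be checked on a toy model. The sketch below uses a hypothetical set of three point charges (not taken from the dataset) and a dipole defined as $\boldsymbol{\mu}=\sum_i q_i \mathbf{r}_i$; the function name `point_charge_dipole` is an illustrative assumption.

```python
import numpy as np

# hypothetical toy "molecule": three point charges with zero total charge
positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.5]])
charges = np.array([-0.8, 0.4, 0.4])


def point_charge_dipole(positions, charges):
    """Dipole of a set of point charges, mu = sum_i q_i r_i."""
    return charges @ positions


# rotation by 90 degrees about the z axis
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# each row of `positions` is one vector r_i, so R r_i for all atoms is positions @ R.T
rotated_positions = positions @ R.T

mu = point_charge_dipole(positions, charges)
mu_rotated = point_charge_dipole(rotated_positions, charges)

# the dipole of the rotated structure equals the rotated dipole
print(np.allclose(mu_rotated, R @ mu))
```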

Let's consider a dataset that contains a few organic molecules, and for each of them the computed dipole moment $\boldsymbol{\mu}$ and polarizability $\boldsymbol{\alpha}$. These are structures from the "showcase" dataset from ([Yang et al. (2019)](http://doi.org/10.1038/s41597-019-0157-8)).

In [None]:
frames_alphamu = ase.io.read('data/showcase.xyz', ":")

In [None]:
dipoles_show = chemiscope.ase_vectors_to_arrows(frames_alphamu, "dipole_ccsd", scale=0.5)
dipoles_show["parameters"]["global"]["color"]="0xff8000"

alphas_show = chemiscope.ase_tensors_to_ellipsoids(frames_alphamu, "ccsd_pol", scale=0.2)
alphas_show["parameters"]["global"]["color"]="0xff0080"

In [None]:
chemiscope.show(frames=frames_alphamu, 
                shapes={
                    "mu": dipoles_show,
                    "alpha": alphas_show
                       },
                mode="structure",
               settings=chemiscope.quick_settings(structure_settings={"shape":"mu"})
               )

In [None]:
ex02_wci = CodeInput(
        function_name="rotate_atoms", 
        function_parameters="positions, dipole, rotm",
        docstring="""takes the positions and dipole of a structure and transforms them
        according to the given rotation matrix 
        
        :param positions: a (n_atoms,3) array containing the atomic positions
        :param dipole: a (3) array containing the dipole components
        :param rotm: a (3,3) array containing the rotation matrix
        
        :returns: (positions, dipole) - a tuple containing the transformed positions and dipole
""",
        function_body="""

# NB: be careful with how you can apply the rotations to the positions array

new_positions = positions.copy()
new_dipole = dipole.copy()

# Apply the rotation here. Be careful with the shape and layout of the arrays

return new_positions, new_dipole
"""
        )

In [None]:
def update_02(code_exercise):
    output = code_exercise.cue_outputs[0]
    output.clear_output()
    rots = []
    f = frames_alphamu[0]
    for r in np.pi*np.array([[0,0,0],[0,0.125,0],[0,0.250,0],[0,0.375,0],[0,0.5,0],
                             [0.125,0.5,0],[0.25,0.5,0],[0.375,0.5,0],[0.5,0.5,0],
                             [0.5,0.5,0.125],[0.5,0.5,0.25],[0.5,0.5,0.375],[0.5,0.5,0.5]]):
        nf = deepcopy(f)
        nf.positions, nf.info["dipole_ccsd"] = ex02_wci.get_function_object()(
            f.positions, f.info["dipole_ccsd"], rotation_matrix(*r) )
        rots.append(nf)
    with output:
        dipoles_show = chemiscope.ase_vectors_to_arrows(rots, "dipole_ccsd", scale=0.5)
        dipoles_show["parameters"]["global"]["color"]="0xff8000"
        cs=chemiscope.show(rots, shapes={
                    "mu": dipoles_show,
                },
                mode="structure",
               settings=chemiscope.quick_settings(structure_settings={"shape":"mu", 
                                                                      "keepOrientation":True})
                          )
        cs.save("module_02-dipole_rotations.chemiscope.json.gz")
        display(cs)

ex02_reference_input = [{'positions':np.array([[0.,0,1],[1,2,0],[3,2,-1]]), 
                         'dipole':np.array([5.,6,7]),
                         'rotm': np.eye(3)},
                       {'positions':np.array([[0.,0,1],[1,2,0],[3,2,-1]]), 
                         'dipole':np.array([5.,6,7]),
                         'rotm': rotation_matrix(0,np.pi/2,0)}]
ex02_reference_output = [(np.array([[0.,0,1],[1,2,0],[3,2,-1]]),np.array([5.,6,7])),
                         (np.array([[ 1.00000000e+00,  0.00000000e+00,  2.22044605e-16],
         [ 2.22044605e-16,  2.00000000e+00, -1.00000000e+00],
         [-1.00000000e+00,  2.00000000e+00, -3.00000000e+00]]),
  np.array([ 7.,  6., -5.]))]
ex02_code_demo = CodeExercise(
    code= ex02_wci,
    check_registry=check_registry,
    cue_outputs = [CueObject()],
    update_func = update_02,
    exercise_key="02",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 02: Molecular rotations",
    exercise_description=mdwn("""
Implement a function that gets the positions and dipole of a molecule and rotates them according to
the provided rotation matrix.
""")
)

check_registry.add_check(ex02_code_demo,
    asserts= [
        assert_type,
        assert_shape,
        assert_numpy_allclose,
    ],
     inputs_parameters=ex02_reference_input,
     outputs_references =ex02_reference_output)

In [None]:
display(ex02_code_demo)

[Download chemiscope datafile](./module_02-dipole_rotations.chemiscope.json.gz)

There are however more complicated properties than those transforming as vectors. Go back to the dataset viewer, and change the visualizer settings to display the polarizability `alpha`. The polarizability describes the second order response of the energy of a molecule to an applied electric field, i.e.

$$
\alpha_{ab} = \frac{\partial^2 U}{\partial E_a \partial E_b}
$$

It is therefore a _tensor_ labeled by two Cartesian indices. In order to see how it transforms under rotations, you should consider that a rotation would affect the relation of the reference frame of the molecule to that of _both_ electric field vectors, so one needs to apply _two_ rotation matrices,

$$
\boldsymbol{\alpha}(\hat{R}A) = \mathbf{R}\boldsymbol{\alpha}(A)\mathbf{R}^T
$$

In [None]:
ex03_wci = CodeInput(
        function_name="rotate_atoms_pol", 
        function_parameters="positions, alpha, rotm",
        docstring="""takes the positions and polarizability of a structure and transforms 
        them according to the given rotation matrix 
        
        :param positions: a (n_atoms,3) array containing the atomic positions
        :param alpha: a (3,3) matrix containing the polarizability
        :param rotm: a (3,3) array containing the rotation matrix
        
        :returns: (positions, alpha) - a tuple containing the transformed positions and polarizability
""",
        function_body="""

# NB: be careful with how you can apply the rotations to the positions array

new_positions = positions.copy()
new_alpha = alpha.copy()

# Apply the rotation here. Be careful with the shape and layout of the arrays

return new_positions, new_alpha
"""
        )

In [None]:
def update_03(code_exercise):
    output = code_exercise.cue_outputs[0]
    output.clear_output()
    rots = []
    f = frames_alphamu[0]
    for r in np.pi*np.array([[0,0,0],[0,0.125,0],[0,0.250,0],[0,0.375,0],[0,0.5,0],
                             [0.125,0.5,0],[0.25,0.5,0],[0.375,0.5,0],[0.5,0.5,0],
                             [0.5,0.5,0.125],[0.5,0.5,0.25],[0.5,0.5,0.375],[0.5,0.5,0.5]]):
        nf = deepcopy(f)
        pol = nf.info["ccsd_pol"]
        pol = np.array([[pol[0], pol[3], pol[4]],[pol[3], pol[1], pol[5]],[pol[4], pol[5], pol[2]]])
        nf.positions, pol = ex03_wci.get_function_object()(
            nf.positions, pol, rotation_matrix(*r) )
        nf.info["ccsd_pol"][:] = [ pol[0,0], pol[1,1], pol[2,2], pol[0,1], pol[0,2], pol[1,2]] 
        rots.append(nf)
    with output:
        alphas_show = chemiscope.ase_tensors_to_ellipsoids(rots, "ccsd_pol", scale=0.2)
        alphas_show["parameters"]["global"]["color"]="0xff0080"
        cs=chemiscope.show(rots, shapes={
                    "alpha": alphas_show,
                },
                mode="structure",
               settings=chemiscope.quick_settings(structure_settings={"shape":"alpha", 
                                                                      "keepOrientation":True})
                          )
        cs.save("module_02-alpha_rotations.chemiscope.json.gz")
        display(cs)

ex03_reference_input = [{'positions':np.array([[0.,0,1],[1,2,0],[3,2,-1]]), 
                         'alpha':np.array([[5.,1,1],[1,3,0],[1,0,4]]),
                         'rotm': np.eye(3)},
                       {'positions':np.array([[0.,0,1],[1,2,0],[3,2,-1]]), 
                         'alpha':np.array([[8.,-1,1],[-1,9,0],[1,0,4]]),
                         'rotm': rotation_matrix(0,np.pi/2,0)}]
ex03_reference_output = [(np.array([[ 0.,  0.,  1.],
         [ 1.,  2.,  0.],
         [ 3.,  2., -1.]]),
  np.array([[5., 1., 1.],
         [1., 3., 0.],
         [1., 0., 4.]])),
 (np.array([[ 1.00000000e+00,  0.00000000e+00,  2.22044605e-16],
         [ 2.22044605e-16,  2.00000000e+00, -1.00000000e+00],
         [-1.00000000e+00,  2.00000000e+00, -3.00000000e+00]]),
  np.array([[ 4.00000000e+00, -2.22044605e-16, -1.00000000e+00],
         [-2.22044605e-16,  9.00000000e+00,  1.00000000e+00],
         [-1.00000000e+00,  1.00000000e+00,  8.00000000e+00]]))]

ex03_code_demo = CodeExercise(
    code= ex03_wci,
    check_registry=check_registry,
    cue_outputs = [CueObject()],
    update_func = update_03,
    exercise_key="03",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 03: Tensor rotations",
    exercise_description=mdwn("""
Implement a function that gets positions and polarizability of a molecule and rotates 
them according to the provided rotation matrix.
""")
)

check_registry.add_check(ex03_code_demo,
    asserts= [
        assert_type,
        assert_shape,
        assert_numpy_allclose,
    ],
     inputs_parameters=ex03_reference_input,
     outputs_references =ex03_reference_output)

In [None]:
display(ex03_code_demo)

[Download chemiscope datafile](./module_02-alpha_rotations.chemiscope.json.gz)

## Rotating tensors

The action of rotations on a Cartesian tensorial quantity can always be formulated as a matrix-vector multiplication, by "unrolling" the tensor and combining multiple rotation matrices together, e.g.

$$
\alpha_{(ab)}(\hat{R}A) = \sum_{(a'b')} R_{(ab)(a'b')} \alpha_{(a'b')}(A)
$$

where $R_{(ab)(a'b')}=R_{aa'}R_{bb'}$.  
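
The equivalence between the two-matrix and the unrolled form can be verified numerically. This is a minimal sketch with a random symmetric matrix standing in for a polarizability and a random rotation built from a QR decomposition (both are illustrative assumptions, not data from the notebook).

```python
import numpy as np

rng = np.random.default_rng(0)

# a random proper rotation from the QR decomposition of a random matrix
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Q * np.sign(np.linalg.det(Q))  # flip the overall sign if needed so that det(R) = +1

# a random symmetric tensor standing in for a polarizability
A = rng.normal(size=(3, 3))
alpha = A + A.T

# tensorial form: two 3x3 rotation matrices
alpha_rot = R @ alpha @ R.T

# unrolled form: one 9x9 matrix acting on the flattened (row-major) tensor
R_big = np.einsum("ab,cd->acbd", R, R).reshape(9, 9)
alpha_rot_flat = R_big @ alpha.reshape(9)

print(np.allclose(alpha_rot_flat.reshape(3, 3), alpha_rot))
# with row-major flattening the combined matrix is the Kronecker product of R with itself
print(np.allclose(R_big, np.kron(R, R)))
```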

In the visualization below you can see how the elements of the combined rotation matrix change with the Euler angles. 

In [None]:
ex04_pb =  ParameterPanel(
    alpha = FloatSlider(value=0,min=-np.pi,max=np.pi,step=0.01,description=r'$\alpha$'),
    beta = FloatSlider(value=0,min=0,max=np.pi,step=0.01,description=r'$\beta$'),
    gamma = FloatSlider(value=0,min=-np.pi,max=np.pi,step=0.01,description=r'$\gamma$'))

ex04_fig = plt.figure(tight_layout=True)
ax04 = ex04_fig.add_subplot(111)
ex04_cuefig = CueFigure(ex04_fig) 

ex04_cbar = None
def update_04(code_exercise):
    global ex04_cbar
    alpha, beta, gamma = code_exercise.parameters.values()
    cue_figure = code_exercise.cue_outputs[0]
    ax = cue_figure.figure.get_axes()[0]
    rot = rotation_matrix(alpha,beta,gamma)
    ROT = np.einsum("ab,cd->acbd",rot, rot).reshape(9,9)
    
    fig = code_exercise.cue_outputs[0].figure
    ax = fig.get_axes()[0]

    cax=ax.matshow(ROT, cmap='seismic', vmin=-1, vmax=1)
    if ex04_cbar is None:
        ex04_cbar = fig.colorbar(cax, ax=ax, orientation='vertical' )
    else:
        ex04_cbar.update_normal(cax)

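    # matshow puts the row index (ab) on the vertical axis and the column index (a'b') on the horizontal one
    ax.set_xlabel("(a'b')")
    ax.set_ylabel("(ab)")    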

In [None]:
ex04_code_demo = CodeExercise(
            parameters= ex04_pb,            
            cue_outputs = [ex04_cuefig],
            update_func = update_04,
    update_mode="continuous",
    #exercise_key="04",
    #exercise_registry=exercise_registry,
    exercise_title="Exercise 04",
    exercise_description=mdwn("")
)

display(ex04_code_demo)
ex04_code_demo.run_update()

In [None]:
ex04_txt = TextExercise(
    exercise_description="""
Play around with the parameters. Does the matrix have any obvious structure? 
How many multiplications would you have to perform to rotate the polarizability 
using the tensorial (two rotations) notation? And how about the combined (single matrix)
case?""",
    exercise_registry=exercise_registry,
    exercise_key="04",
    exercise_title=""
)
display(ex04_txt)

In [None]:
ex04b_txt = TextExercise(
    exercise_description=mdwn(r"""
Given we have seen there are different ways to transform a tensor such as the 
polarizability, we can wonder if there are _better_ ways to apply a rotation to a tensor.
For example, consider the _trace_ of the polarizability, $\alpha_{xx}+\alpha_{yy}+\alpha_{zz}$.
How does it transform under rotations?"""),
    exercise_registry=exercise_registry,
    exercise_key="04b",
    exercise_title="Polarizability trace"
)
display(ex04b_txt)

## Irreducible representations and spherical harmonics

[Spherical harmonics]() are special functions of the polar angles $(\theta, \phi)$ that can be obtained as the orthogonal solutions of the Laplacian eigenvalue problem on the sphere $\nabla^2 Y^m_l(\theta, \phi) = \epsilon_l Y^m_l(\theta, \phi)$. 
Much as for rotations, there are many different conventions for their definition.
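
One concrete instance: for $l=1$, real spherical harmonics are proportional to the Cartesian components of the unit vector. The sketch below assumes one common real-valued convention (ordering $m=-1,0,1$ mapped to $y,z,x$, with prefactor $\sqrt{3/4\pi}$), which may differ from other libraries in sign or ordering, and checks orthonormality on the sphere by numerical quadrature. The function name `real_sph_l1` is hypothetical.

```python
import numpy as np


def real_sph_l1(xyz):
    """Real l=1 spherical harmonics for unit vectors, ordered m=-1,0,1 (assumed convention)."""
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    prefactor = np.sqrt(3.0 / (4.0 * np.pi))
    return prefactor * np.stack([y, z, x], axis=-1)


# quadrature on the sphere: Gauss-Legendre nodes in cos(theta), uniform grid in phi
n_theta, n_phi = 8, 16
cos_theta, w_theta = np.polynomial.legendre.leggauss(n_theta)
phi = 2 * np.pi * np.arange(n_phi) / n_phi
w_phi = 2 * np.pi / n_phi

ct, ph = np.meshgrid(cos_theta, phi, indexing="ij")
st = np.sqrt(1.0 - ct**2)
xyz = np.stack([st * np.cos(ph), st * np.sin(ph), ct], axis=-1)
weights = w_theta[:, None] * w_phi * np.ones_like(ph)

# overlap matrix: integral over the sphere of Y_m * Y_m'
Y = real_sph_l1(xyz)  # shape (n_theta, n_phi, 3)
overlap = np.einsum("tp,tpm,tpn->mn", weights, Y, Y)
print(np.round(overlap, 6))
```

The overlap matrix should be close to the identity, which is the statement that the three functions are orthonormal on the sphere.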

In [None]:
from sphericart import SphericalHarmonics
sph = SphericalHarmonics(l_max=8)

In [None]:
# Define the resolution of the sphere
num_points = 40

# Create the angles
theta = np.linspace(0, np.pi, num_points)
phi =   np.linspace(0, 2*np.pi, 2*num_points)
theta, phi = np.meshgrid(theta, phi)

# Convert to Cartesian coordinates
x = np.sin(theta) * np.cos(phi)
y = np.sin(theta) * np.sin(phi)
z = np.cos(theta)

ex05_xyz =  np.array([x,y,z])

# Define a function on the sphere
# Evaluate the real spherical harmonics (all l, m up to l_max) on the grid points
ex05_sph = sph.compute(ex05_xyz.T.reshape(-1,3)).reshape(theta.shape[1], theta.shape[0], -1).T

In [None]:
ex05_mslider = IntSlider(value=0,min=0,max=8,step=1,description=r'$m$')
ex05_pb =  ParameterPanel(
    l = IntSlider(value=1,min=0,max=8,step=1,description=r'$l$'),
    m = ex05_mslider
)

ex05_fig = plt.figure(tight_layout=True)
ax05 = ex05_fig.add_subplot(111, projection='3d')
ex05_cuefig = CueFigure(ex05_fig) 

def update_05(code_exercise):    
    l, m = code_exercise.parameters.values()
    # updates the range of the m slider
    ex05_mslider.min = -l 
    ex05_mslider.max = l
    if m>l: 
        m=l
    if m<-l:
        m=-l
    cue_figure = code_exercise.cue_outputs[0]
    ax = cue_figure.figure.get_axes()[0]
    ax.set_xlim([-1.2,1.2])
    ax.set_ylim([-1.2,1.2])
    ax.set_zlim([-1.2,1.2])
    
    color_map = lambda x:  mpl.colormaps['seismic']((x-x.min())/(1e-15+x.max()-x.min()))
    x,y,z = ex05_xyz
    lm=l*l+l+m
    # Plot the sphere with colors
    ax.plot_surface(x, y, z, rstride=1, antialiased=True,
                    cstride=1, shade=True, facecolors=color_map(ex05_sph[lm]) )
    ax.set_axis_off()
    ax.set_aspect('auto')
    cue_figure.figure.subplots_adjust(left=0.0, right=1, top=1, bottom=0.0)
    
ce05 = CodeExercise(
            parameters=ex05_pb,
            cue_outputs = [ex05_cuefig],
            update_func = update_05,
            update_mode="continuous")

display(ce05)
ce05.run_update()

In [None]:

rot = rotation_matrix(1,1,1)
ROT = np.einsum("ab,cd->acbd",rot, rot).reshape(9,9)

In [None]:
ex05_xyz.T.shape

In [None]:
ex05_sph.shape

In [None]:
len(phi)

In [None]:
from rascaline.utils import cartesian_to_spherical
from metatensor import TensorMap, TensorBlock, block_from_array
import metatensor as mtt

In [None]:
mtr = block_from_array(ROT.reshape(1,3,3,3,3,1))

In [None]:
tm = TensorMap(blocks=[mtr], keys=mtt.Labels.single())

In [None]:
tm.block(0)

In [None]:
tms=cartesian_to_spherical(tm, components=['component_1', 'component_2'])

In [None]:
tms.blocks()

In [None]:
tms.keys.values

In [None]:
Labels(["o3_lambda1", "o3_sigma1"], tms.keys.values[:,:2])

In [None]:
tms.block(0).components

In [None]:
tms = TensorMap(blocks=[
    TensorBlock(values=b.values, samples=b.samples, components=[Labels(
        ["o3_mu1"], b.components[0].values), b.components[1], b.components[2]],
               properties=b.properties)
    for b in tms.blocks()
], 
keys=Labels(["o3_lambda1", "o3_sigma1"], tms.keys.values[:,:2]))

In [None]:
tms.block(1)

In [None]:
tmss=cartesian_to_spherical(tms, components=['component_3', 'component_4'])

In [None]:
tmss

In [None]:
tmss.block(2).values

In [None]:
tms.block(2).values

In [None]:
tms.block(0).values

In [None]:
wigner_d_real(0, 0.1, 0.5, 0.3)

## Equivariance: a primer

_Equivariance_ indicates the property of a function for which the inputs and outputs are subject to the action of the same symmetries, and which commutes with the application of the symmetries, that is: $f(\hat{S}A) = \hat{S} f(A)$. _Invariance_ can be seen as a special case, in which  $f(\hat{S}A) = f(A)$.
This section focuses in particular on the case of 3D rotations and inversion - in technical terms the $O(3)$ group symmetries - and their combination with translations - the three-dimensional Euclidean group $E(3)$. 
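
The two definitions can be turned into simple numerical checks. The sketch below is only an illustration for rotations acting on positions stored as rows of an array; the helper names (`is_invariant`, `is_equivariant`, `pair_distances`, `first_bond_vector`) and the water-like geometry are assumptions made for this example.

```python
import numpy as np


def is_invariant(func, positions, R):
    """Check f(R A) == f(A) for a rotation matrix R."""
    return np.allclose(func(positions @ R.T), func(positions))


def is_equivariant(func, positions, R):
    """Check f(R A) == R f(A) for a vector-valued function f."""
    return np.allclose(func(positions @ R.T), R @ func(positions))


def pair_distances(positions):
    """Matrix of interatomic distances (a rotationally invariant quantity)."""
    diff = positions[:, None, :] - positions[None, :, :]
    return np.linalg.norm(diff, axis=-1)


def first_bond_vector(positions):
    """Vector from atom 0 to atom 1 (a quantity that rotates with the structure)."""
    return positions[1] - positions[0]


# hypothetical water-like geometry (not taken from the dataset)
positions = np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])

# rotation by 0.4 rad about the z axis
c, s = np.cos(0.4), np.sin(0.4)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

print(is_invariant(pair_distances, positions, R))
print(is_equivariant(first_bond_vector, positions, R))
print(is_invariant(first_bond_vector, positions, R))
```

Distances pass the invariance check, while the bond vector passes the equivariance check and fails the invariance one.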

In [None]:
from IPython.display import display, Javascript

script = """
document.querySelectorAll('.jp-Cell-inputWrapper').forEach(function(element) {
    element.style.display = ''; // 'none' to hide
});
"""
display(Javascript(script))

In [None]:
# set CSS style for code-hide
scwidgets.get_css_style()

In [None]:
exercise_registry = ExerciseRegistry(filename_prefix="module_02")
exercise_registry

In [None]:
check_registry = CheckRegistry()
check_registry

In [None]:
module_summary = TextExercise(
    exercise_description="""You can use this box to make general considerations, 
    or keep track of your doubts and questions about this notebook.""",
    exercise_registry=exercise_registry,
    exercise_title="Module comments",
    exercise_key="00"
)
display(module_summary)

Let's consider this dataset, which contains a collection of configurations for a single water molecule. The configurations are generated by distorting an equilibrium configuration along the bending mode and the asymmetric stretching coordinate. Each frame also contains the energy and dipole moment, computed with the Partridge-Schwenke monomer potential ([Partridge, Schwenke, J. Chem. Phys. (1997)](http://doi.org/10.1063/1.473987)). 

In [None]:
h2o_frames = ase.io.read("data/water_energy-dipole.xyz", ":")

h2o_energy = np.zeros(len(h2o_frames))
h2o_dipole = np.zeros((len(h2o_frames),3))
h2o_force = np.zeros((len(h2o_frames),3,3))
for fi, f in enumerate(h2o_frames):
    h2o_energy[fi] = f.info['energy']
    h2o_dipole[fi] = f.info['dipole']
    h2o_force[fi] = f.arrays['force']

<a id="data-driven"> </a>

# Descriptors of atomic environments in supercooled iron

As a first example we consider a structure which is cut out of a simulation of freezing iron ([Shibuta et al., Acta Mater. (2016)](https://www.sciencedirect.com/science/article/abs/pii/S1359645415301397)).
The snapshot contains a few solid nuclei embedded in a supercooled liquid.

We will use this structure to define atom-centered descriptors, and perform principal component analysis to color atoms based on whether they are in liquid or solid regions. 

Let's start by taking a look at the structure. Note that, to make the notebook fast enough, this is carved out of a larger structure, and so it is not periodic in the $x,y$ directions. 

In [None]:
frame_iron = ase.io.read("data/iron-snapshot.xyz", 0)

# requires running in a jupyter notebook, and takes a while to load - it's > 100k atoms.
cs = chemiscope.show(frames=[frame_iron], mode="structure", 
                     settings={"structure": [ {"bonds": False, "unitCell": True, 
                             } ] },)
display(cs)

## Atom-centered environments

A first important consideration is that we are looking at an individual configuration, and that we want to identify atomic structures _within_ this configuration - distinguishing liquid regions, crystalline nuclei, and ideally the interfacial regions.

<center><img src="figures/environments.png" width="500"/></center>

One way to do this is to look at atomic _environments_ i.e. spherical atom-centered regions that we can describe in terms of the collection of interatomic distance vectors around each atom. You can look at the environments for the frame 

In [None]:
# requires running in a jupyter notebook, and takes a while to load - it's > 100k atoms.
sel_env_idx = np.array([29030, 55650, 99980, 97370, 19570, 125940])
sel_env_idx.sort()
cs = chemiscope.show(frames=[frame_iron], mode="structure", 
                     settings={"structure": [ 
                         {"bonds": False, "unitCell": False,
                          "keepOrientation": True,
                     'environments': {'activated': False, 'center': False}}] 
                              },  
                     environments=[[0,s,5.0] for s in sel_env_idx ] ,                     
                    )

In [None]:
def update_co(code_exercise):
    cutoff = code_exercise.parameters["cutoff"]
    showenv = code_exercise.parameters["showenv"]
    cs.settings={"structure": [{"environments": {'activated': showenv, 
                                                 'center':showenv,
                                                 "cutoff":cutoff}}]}
cs_wp = ParameterPanel(
    showenv=Checkbox(value=False, description="show environments"),
    cutoff=FloatSlider(value=5.,min=2,max=8,step=0.25, description=r"cutoff / Å"),    
)
cue_cs = CueObject()
with cue_cs:
    display(cs)
    
cs_demo = CodeExercise(
            parameters=cs_wp,
            cue_outputs = cue_cs,
            update_func = update_co,
            update_mode="release")
display(cs_demo)
cs_demo.run_update()

In [None]:
ex01_txt = TextExercise(
    exercise_description="""
It is always a good idea to take a good look at the data you are working with. 
Just play around with the viewer, look at the structure. 
What kind of features can you note by just observing the arrangement of the atoms?
Now switch on the environment view and use the atom slider to highlight a few select ones.
Can you easily recognize individual environments as liquid-like or solid-like?
    """,
    exercise_registry=exercise_registry,
    exercise_key="01",
    exercise_title="Exercise 01: What am I looking at?"
)
display(ex01_txt)

## Representations

Having taken the decision of focusing on atomic _environments_ for a structure $A$, that we will indicate as $A_i$, we need to come up with an appropriate way to encode information on the positions and types of _neighbors_ within the environment, $\{(a_j, \mathbf{r}_{ji})\}$.

<center><img src="figures/requirements.png" width="400"/></center>

In practice, we want to map each environment to a vector of descriptors, or features, $A_i\rightarrow\boldsymbol\xi(A_i)$. It is desirable to use a mapping that fulfills a number of basic mathematical requirements: 

1. **Locality** (that is already satisfied by the use of atom-centered environments with a finite cutoff)
2. **Completeness** (two environments that are inequivalent should have different feature vectors)
3. **Smoothness** (the mapping between Cartesian coordinates and features should be differentiable, and "regular")
4. **Symmetry** (the mapping should be independent of rigid translations, rotations and permutation of atom indices)

It is clear that $\{\mathbf{r}_{ij}\}$ fulfills (1) and (2), but it is not smooth (the number of vectors changes when atoms enter or leave the cutoff) and is only invariant to translations. Using interatomic _distances_ $r_{ij}=|\mathbf{r}_{ij}|$ makes the features invariant to rotations, but they still depend on the ordering of the atoms. 
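
The difference between rotational and permutational symmetry can be seen on a small example. This sketch uses a hypothetical set of five neighbor vectors; a plain histogram is used as the simplest order-independent summary of the distances, as a stand-in for the smoother construction discussed next.

```python
import numpy as np

# hypothetical displacement vectors r_ij of five neighbors around a central atom
neighbors = np.array(
    [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [1.0, 1.0, 0.0], [2.0, 0.0, 2.0]]
)
distances = np.linalg.norm(neighbors, axis=1)

# rotating the environment (90 degrees about z) leaves the list of distances unchanged
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
distances_rotated = np.linalg.norm(neighbors @ R.T, axis=1)

# relabeling the neighbors changes the order of the entries in the list
perm = np.array([2, 0, 4, 1, 3])
distances_permuted = np.linalg.norm(neighbors[perm], axis=1)

# a histogram of the distances does not depend on the order of the neighbors
hist, _ = np.histogram(distances, bins=8, range=(0.0, 4.0))
hist_permuted, _ = np.histogram(distances_permuted, bins=8, range=(0.0, 4.0))

print(np.allclose(distances_rotated, distances))
print(np.allclose(distances_permuted, distances))
print(np.array_equal(hist, hist_permuted))
```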

Let's now try to build an invariant descriptor: a _histogram_ of the distances, discretized on a real-space grid. We use a kernel-density estimation, and include a _cutoff function_ to smoothly send contributions to zero as atoms approach the cutoff distance:

$$
\xi_k(A_i) = \sum_{j\in A_i} g(k - r_{ij}/\Delta_r) f_\mathrm{cut}(r_{ij})
$$

where $g(\cdot)$ is a Gaussian with zero mean and unit variance, $\Delta_r=r_\mathrm{cut}/n_\mathrm{grid}$ is the resolution of the real-space grid, and 
$f_\mathrm{cut}(r_{ij})=1+\cos(\pi r_{ij}/r_\mathrm{cut})$.
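
To see what the cutoff function does, it helps to evaluate it on its own. The sketch below implements only $f_\mathrm{cut}$ as written above; setting it to zero beyond $r_\mathrm{cut}$ is an assumption made here so that atoms outside the environment contribute nothing, and the name `f_cut` is illustrative.

```python
import numpy as np


def f_cut(r, r_cut):
    """Cosine cutoff 1 + cos(pi r / r_cut), set to zero beyond r_cut (assumption)."""
    r = np.asarray(r, dtype=float)
    return np.where(r < r_cut, 1.0 + np.cos(np.pi * r / r_cut), 0.0)


r_cut = 5.0
r = np.linspace(0.0, 6.0, 7)  # distances 0, 1, ..., 6 in the same units as r_cut
values = f_cut(r, r_cut)

for ri, vi in zip(r, values):
    print(f"r = {ri:.1f}  f_cut = {vi:.4f}")
```

The function decreases monotonically from 2 at $r=0$ to 0 at $r=r_\mathrm{cut}$, and its slope also vanishes at the cutoff, so that contributions fade out smoothly.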

In the following exercise you will be asked to implement this radial distribution fingerprint, and the exercise will compute and display it for the six environments visualized in the viewer for exercise 1. 

In [None]:
ex02_wci = CodeInput(
        function_name="radial_fp", 
        function_parameters="rij_list, rcut, ngrid",
        docstring="""
        compute a radial distribution fingerprint using a kernel density estimation in real-space
        
        
        :param rij_list: a list of interatomic distances for an environment
        :param rcut: cutoff distance
        :param ngrid: number of grid points and size of the feature vector
        
        :returns: a vector with the radial distribution features computed for the given environment
""",
        function_body="""

import numpy as np
rgrid = np.linspace(0, rcut, ngrid)
feats = np.zeros(shape=rgrid.shape)

### ADD THE CALCULATION OF THE FEATURES HERE ###

return feats
"""
        )


In [None]:
# makes neighbor list for the six selected environments (ASE is too slow to be usable for this box)
max_cutoff = 8
px = frame_iron.positions
az = frame_iron.cell[2,2]
nl_idx = []
nl_dx = []
nl_dr = []
for isel in sel_env_idx:
    dx = px - px[isel]
    dx[:,2] /= az  # pbc along z
    dx[:,2] -= np.round(dx[:,2])
    dx[:,2] *= az
    dr = np.sqrt((dx**2).sum(axis=1))
    iw = np.where(dr<max_cutoff)[0]
    nl_idx.append(iw)
    nl_dx.append(dx[iw])
    nl_dr.append(dr[iw])

In [None]:
ex02_img = mpl.image.imread('figures/selected-env.jpg')
def update_02(code_exercise):
    rcut, ngrid = code_exercise.parameters.values()
    ax, aximg = code_exercise.cue_outputs[0].figure.get_axes()
    aximg.imshow(ex02_img)
    aximg.axis('off') 
    rgrid = np.linspace(0, rcut, ngrid)
    for dr, l in zip(nl_dr, ["A", "B", "C", "D", "E", "F"]):
        ygrid = ex02_wci.get_function_object()(dr, rcut, ngrid)    
        ax.plot(rgrid, ygrid,label=l)
    # ax.text(-4,8,f'$\ell = ${l:.3f}')
    ax.set_xlabel(r'$r$ / Å')
    ax.set_ylabel(r'$\xi$')
    ax.legend()

ex02_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    ngrid = IntSlider(value=10,min=5,max=20,description=r'$n_{grid}$') )

In [None]:
ex02_figure, ex02_ax = plt.subplots(1, 2, figsize=(8,5), tight_layout=True)
ex02_output = CueFigure(ex02_figure)
ex02_ax[1].imshow(ex02_img)
ex02_ax[1].axis('off') 

ex02_code_demo = CodeExercise(
            code= ex02_wci,
            parameters= ex02_pb,
            check_registry=check_registry,
            cue_outputs = [ex02_output],
            update_func = update_02,
    exercise_key="02",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 02: Radial distribution fingerprints",
    exercise_description="""
Implement a function that computes a radial distribution fingerprint given the list
of distances for an environment, a cutoff and the number of grid points. 
You should implement the exact functional form given above, if you want checks to pass, 
but of course you're also encouraged to try something different!
"""
)

ex02_ref_input = [{"rij_list": np.array([1,3,4]), "rcut": 5, "ngrid": 10},
                 {"rij_list": np.array([5,7,8]), "rcut": 6, "ngrid": 4}
                 ]
ex02_ref_output = [(np.array([0.0331333 , 0.82091159, 1.72185318, 0.30631179, 0.06605532,
       0.56761871, 0.47532342, 0.21108182, 0.08683   , 0.00349805])/2,),
                  (np.array([2.00234259e-06, 2.45588889e-03, 8.87637079e-02, 2.56310425e-01])/2,)
                  ]

check_registry.add_check(ex02_code_demo,
    asserts=[
        assert_type,
        assert_shape,
        assert_numpy_allclose,
    ],
    inputs_parameters=ex02_ref_input,
    outputs_references=ex02_ref_output
)
                         
#                         inputs_parameters=ex_08_ref_input,
#                         reference_outputs = ex_08_ref_output,
#                         equal=ex08_chk,
#                        fingerprint=identity)

display(ex02_code_demo)

In [None]:
ex02b_txt = TextExercise(
    exercise_description="""
Experiment with different grid resolutions, cutoff radius, etc. 
Can you recognize clear-cut differences between liquid-like and solid-like environments?
    """,
    exercise_registry=exercise_registry,
    exercise_key="02b",
    exercise_title="Exercise 02b: Resolving power of radial fingerprints."
)
display(ex02b_txt)

## Atom-centered symmetry functions

This set of radial features can be seen as a special case of so-called _atom-centered symmetry functions_ (ACSFs), one of the first types of representations used e.g. by [Behler and Parrinello](http://doi.org/10.1103/PhysRevLett.98.146401). 


<center><img src="figures/radial-acsf.png" width="500"/><br/>
<i> Representative examples of radial symmetry functions.</i><br/><br/>
</center>


ACSFs are designed as bespoke functions $\phi_k$ of the internal coordinates of the environment, accumulated over neighbors to achieve invariance to atom index permutations.
They can be generalized to also include functions of distances and angles (3-body symmetry functions) and can be tuned to focus on the structural features that are most discriminating, or most straightforwardly related to the structure-property relations one is trying to learn.
Radial (two-body) symmetry functions take the form

$$
\xi_k(A_i) = \sum_{j\in A_i} \phi_k(r_{ij}) f_\mathrm{cut}(r_{ij})
$$

where $\phi_k$ typically has a parametric form, or enumerates a set of orthogonal basis functions. 

In [None]:
ex03_wci = CodeInput(
        function_name="radial_acsf", 
        function_parameters="rij_list, rcut, delta, rs",
        docstring="""
        compute a Gaussian radial atom-centered symmetry function (ACSF) for an environment
        
        
        :param rij_list: a list of interatomic distances for an environment
        :param rcut: cutoff distance
        :param delta: the smearing of the Gaussian ACSF
        :param rs: the center of the Gaussian ACSF
        
        :returns: a float containing the value of the ACSF for the environment
""",
        function_body="""

import numpy as np

acsf = 0.0
### ADD THE CALCULATION OF THE ACSF VALUE HERE ###

return acsf
"""
        )

In [None]:
ex03_img = mpl.image.imread('figures/selected-env.jpg')
def update_03(code_exercise):
    rcut, delta, rs = code_exercise.parameters.values()
    ax, aximg = code_exercise.cue_outputs[0].figure.get_axes()
    aximg.imshow(ex03_img)
    aximg.axis('off') 
    rgrid = np.linspace(0, rcut, 100)
    ygrid = np.zeros_like(rgrid)
    for ir, r in enumerate(rgrid):
        ygrid[ir] = ex03_wci.get_function_object()([r], rcut, delta, rs)    
    ax.plot(rgrid, ygrid, 'r-')
    # ax.text(-4,8,f'$\ell = ${l:.3f}')
    ax.set_xlabel(r'$r$ / Å')
    ax.set_ylabel(r'$\phi_k(r)$')

    labels = []
    for dr, l in zip(nl_dr, ["A", "B", "C", "D", "E", "F"]):
        acf = ex03_wci.get_function_object()(dr, rcut, delta, rs) 
        labels.append(f"{l}: {acf:9.4f}")
    aximg.legend(handles=[mpl.patches.Patch(color="w", )]*6, labels=labels,
                 handlelength=0.1, loc='lower left')

ex03_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    delta = FloatSlider(value=0.5,min=0.1,max=2,step=0.1,description=r'$\delta$ / Å'),
    rs = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_s$ / Å'),
    )

In [None]:
ex03_figure, ex03_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex03_output = CueFigure(ex03_figure)
ex03_ax[1].imshow(ex03_img)
ex03_ax[1].axis('off') 

ex03_code_demo = CodeExercise(
            code= ex03_wci,
            parameters= ex03_pb,
            check_registry=check_registry,
            cue_outputs = [ex03_output],
            update_func = update_03,
    update_mode="manual",
    exercise_key="03",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 03: Radial ACSF",
    exercise_description=mdwn("""
Implement a function that computes a Behler-Parrinello atom-centered symmetry function
of the form

$$
\phi_k(r) = \exp\[-(r-r_s)^2/\delta^2\]  f_c(r)
$$

using a cosine cutoff function.
""")
)

ex03_ref_input = [{"rij_list": np.array([1,3,4]), "rcut": 5, "delta": 1, "rs":5},
                 {"rij_list": np.array([5,3,8]), "rcut": 6, "delta": 0.5, "rs":4}
                 ]
ex03_ref_output = [(0.04145736008495549,),
                  (0.010384734606641187,)
                  ]

check_registry.add_check(ex03_code_demo,
    asserts=[
        assert_type,
        assert_shape,
        assert_numpy_allclose,
    ],
    inputs_parameters=ex03_ref_input,
    outputs_references=ex03_ref_output
)
                         

display(ex03_code_demo)

In [None]:
ex03b_txt = TextExercise(
    exercise_description="""
Observe how the shape of the symmetry function, and its value for the 
various environments, change with its parameters. Try to find values that maximise the difference
between solid-like and liquid-like environments. 
    """,
    exercise_registry=exercise_registry,
    exercise_key="03b",
    exercise_title="Exercise 03b: ACSF."
)
display(ex03b_txt)

## Discretized density expansion

The neighbor density provides a way to characterize the position of atoms in the vicinity of a tagged center. It is instructive to see it built up starting from a structure-level smooth atom density, in which a structure  $A$ is represented in terms of localized functions (e.g. Gaussians) centered on each atom $i$, "labelled" by their chemical nature $a$

$$
\rho_A^a(\mathbf{x}) = \langle a \mathbf{x} | A; \rho\rangle = \sum_{i \in A} \delta_{a a_i} \langle \mathbf{x} | \mathbf{r}_i \rangle.
$$

Note how, by summing over $i$, the identity of atoms of the same species is lost, making the representation invariant to atom labeling. 

We use  the notation $\langle \mathbf{x} | \mathbf{r}_i \rangle = g(\mathbf{x}-\mathbf{r}_i)$ to emphasize how the full structure is built as a sum of terms that describe individual atoms, and how this description can be implemented in any continuous or discrete basis. 

In general terms, in analogy with the Dirac notation used to describe a quantum state, we use  $\langle q | A\rangle$ to indicate a descriptor $| A\rangle$ for an entity $A$, discretized in a basis that is enumerated by the index $q$. 
See Section 3.1 of [this review](https://doi.org/10.1021/acs.chemrev.1c00021) for a gentler introduction. 

This density is then symmetrized with respect to translations (reflecting the fact that atomic properties are invariant to rigid translations of a molecule) which leads to expressing the structure descriptors as a sum of descriptors of _atom centered environments_ $A_i$,


$$
\langle a \mathbf{x} | A; \rho\rangle = \sum_i \langle a \mathbf{x} | \rho_i\rangle
$$

$$
\langle a \mathbf{x} | \rho_i\rangle = \sum_{j \in A_i} \delta_{a a_j} \langle \mathbf{x} | \mathbf{r}_{ji} \rangle.
$$

where the Gaussians are evaluated at the interatomic distance vectors $\mathbf{r}_{ji}=\mathbf{r}_j-\mathbf{r}_i$.

To manipulate this atom-centered density, it is more convenient to express it on a discrete basis. Guided by symmetry considerations, and in analogy with what is done routinely in quantum chemistry for the electron wavefunction (or density), we use a basis of radial functions $R_{nl}(x) \equiv \langle x|nl\rangle$ and [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics) $Y^m_{l}(\hat{\mathbf{x}}) \equiv \langle \hat{\mathbf{x}}|lm\rangle$

$$
\langle a nlm | \rho_i\rangle = \int \mathrm{d}\mathbf{x} 
 \langle nl| x\rangle  \langle lm| \hat{\mathbf{x}} \rangle
\langle a \mathbf{x} | \rho_i\rangle  
$$

If you are confused by this rather abstract notation, you can also define the neighbor density as an explicit sum over neighbor positions $\{\mathbf{r}_{ij}\}$,

$$
\langle a nlm | \rho_i\rangle \equiv \rho_{nlm}^a(A_i) = \sum_{j\in A_i} \delta_{a a_j} Y^m_l(\hat{\mathbf{r}}_{ij}) \tilde{R}_{nl}(r_{ij}),
$$

where the Gaussian smearing of the density has been implemented as a transformation of the radial basis. See [Goscinski et al.](http://doi.org/10.1063/5.0057229) for a derivation of the equivalence of the two expressions. 
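
The explicit sum can be sketched directly for $l=0$ and $l=1$, for which the real spherical harmonics have a simple closed form. The example below rests on several assumptions: a single chemical species, random neighbor vectors, no cutoff function, and a hypothetical `radial_basis` of Gaussians that does not depend on $l$ and is not the basis used by `rascaline`. The $l=1$ harmonics are kept in Cartesian $(x, y, z)$ order for readability, so that under a rotation they transform with the rotation matrix itself.

```python
import numpy as np


def radial_basis(r, nmax, rcut):
    """Hypothetical radial functions: Gaussians centered on a uniform grid."""
    centers = np.linspace(0.0, rcut, nmax)
    return np.exp(-(((r[:, None] - centers[None, :]) / (rcut / nmax)) ** 2))


def expansion_l0_l1(rij, nmax=3, rcut=5.0):
    """Coefficients sum_j Y_lm(r_ij / |r_ij|) R_n(|r_ij|) for l=0 and l=1."""
    r = np.linalg.norm(rij, axis=1)
    rhat = rij / r[:, None]
    radial = radial_basis(r, nmax, rcut)  # shape (n_neighbors, nmax)
    y0 = np.full((len(r), 1), 1.0 / np.sqrt(4.0 * np.pi))  # l=0 harmonic
    y1 = np.sqrt(3.0 / (4.0 * np.pi)) * rhat  # l=1 real harmonics, (x, y, z) order
    rho0 = np.einsum("jm,jn->mn", y0, radial)  # shape (1, nmax)
    rho1 = np.einsum("jm,jn->mn", y1, radial)  # shape (3, nmax)
    return rho0, rho1


rng = np.random.default_rng(0)
rij = rng.normal(size=(8, 3))  # hypothetical neighbor vectors

# an arbitrary rotation matrix: a rotation about z followed by one about x
a, b = 0.4, -0.9
rot_z = np.array(
    [[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]]
)
rot_x = np.array(
    [[1.0, 0.0, 0.0], [0.0, np.cos(b), -np.sin(b)], [0.0, np.sin(b), np.cos(b)]]
)
rot = rot_x @ rot_z

rho0, rho1 = expansion_l0_l1(rij)
rho0_rot, rho1_rot = expansion_l0_l1(rij @ rot.T)

print("l=0 coefficients are invariant:  ", np.allclose(rho0_rot, rho0))
print("l=1 coefficients rotate as R rho:", np.allclose(rho1_rot, rot @ rho1))
```

The $l=0$ coefficients do not change when the environment is rotated, while the $l=1$ coefficients transform like a vector: this is the equivariant behavior that the Wigner matrices generalize to higher $l$.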

In [None]:
ex04_wci = CodeInput(
        function_name="density_expansion", 
        function_parameters="frame, rcut, nmax, lmax, sigma, select_idx",
        docstring="""
        compute a discretization of the neighbor density in terms 
        of radial functions and spherical harmonics. a Gaussian smearing
        is applied to the neighbor density
        
        :param frame: ase.Atoms frame to compute
        :param rcut: cutoff distance
        :param nmax: number of radial functions
        :param lmax: maximum angular momentum
        :param sigma: Gaussian smearing of the density
        :param select_idx: indices for the atoms to use as centers
        
        :returns: a TensorMap containing the density expansion coefficients
""",
        function_body="""
from rascaline import SphericalExpansion
from metatensor import Labels

# parameters of the density expansion
hypers = {
    "cutoff": rcut,
    "max_radial": nmax,
    "max_angular": lmax,
    "atomic_gaussian_width": sigma,
    "cutoff_function": {"ShiftedCosine": {"width": 0.5}}, # type of cutoff and parameters
    "center_atom_weight": 1.0, # weight to include the central atom in the expansion
    "radial_basis": { "Gto": {}, }, # choice of radial basis
}

calculator = SphericalExpansion(**hypers)

if select_idx is None:
    rhoi = calculator.compute(frame)
else:
    rhoi = calculator.compute(frame,
            selected_samples=Labels(names="atom", values=select_idx.reshape(-1,1))
    )

# Uncomment to print some of the metadata associated with the density coefficients
# print(rhoi)
# print(rhoi.block(1))

return rhoi
"""
)

In [None]:
ex04_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    nmax = IntSlider(value=4,min=1,max=8,description=r'$n_{max}$'),
    lmax = IntSlider(value=2,min=1,max=6,description=r'$l_{max}$'),
    sigma = FloatSlider(value=0.5,min=0.3,max=1,step=0.1,description=r'$\sigma$ / Å'),
    environment = Dropdown(value="A", options=["A", "B", "C", "D", "E", "F"],
                          description="environment")
    )

ex04_img = mpl.image.imread('figures/selected-env.jpg')
def combine_l(tmap, i_env):
    feats=[]
    for b in tmap:
        feats.append(b.values[i_env])
    return np.vstack(feats)

ex04_cache_pars = (-1, -1, -1, -1, "")
ex04_cache_value = None
ex04_cbar = None
def update_04(code_exercise):
    global ex04_cache_pars, ex04_cache_value, ex04_cbar
    rcut, nmax, lmax, sigma, env = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure
    ax, aximg = fig.get_axes()[:2]
    aximg.imshow(ex04_img)
    aximg.axis('off') 

    if ex04_cache_pars == (rcut, nmax, lmax, sigma, ex04_wci.function_body):
        rhoi = ex04_cache_value
    else:
        rhoi = ex04_wci.get_function_object()(frame_iron, rcut, nmax, lmax, sigma, sel_env_idx)
        ex04_cache_value = rhoi
        ex04_cache_pars = (rcut, nmax, lmax, sigma, ex04_wci.function_body)
    envidx={"A":0, "B":1, "C":2, "D":3, "E":4, "F":5}
    feats = combine_l(rhoi, envidx[env])
    frange = np.max(np.abs(feats))
    norm = mpl.colors.SymLogNorm(vmin=-frange, vmax=frange, linthresh=1e-1)
    
    cax=ax.matshow(feats.T, cmap='seismic', norm=norm)
    if ex04_cbar is None:
        ex04_cbar = fig.colorbar(cax, ax=ax, orientation='horizontal' )
    else:
        ex04_cbar.update_normal(cax)

    ax.set_ylabel("n")
    ax.set_xlabel("(l,m)")
    xticklabels = []
    xtickpos = []
    for l in range(lmax+1):
        ax.add_patch(mpl.patches.Rectangle(
            (-0.5+l**2,-0.5), 2*l+1, nmax,
            edgecolor='black', facecolor='none', linewidth=3
        ))
        xticklabels.append(f"$l={l}$")
        xtickpos.append((l)**2+l)
    ax.set_xticks(xtickpos); ax.set_xticklabels(xticklabels)
    

In [None]:
ex04_figure, ex04_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex04_output = CueFigure(ex04_figure)
ex04_ax[1].imshow(ex04_img)
ex04_ax[1].axis('off') 

ex04_code_demo = CodeExercise(
            code= ex04_wci,
            parameters= ex04_pb,            
            cue_outputs = [ex04_output],
            update_func = update_04,
    update_mode="release",
    exercise_key="04",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 04: Density expansion",
    exercise_description=mdwn("""
This exercise is set up to compute neighbor density expansion coefficients, based
on the specified parameters. You can observe how the density coefficients change 
for different hyperparameters. You can also change other options that are not 
provided as function parameters, but make sure to revert them to the initial values
as this function is also used in what follows.

_NB:_ It takes a few seconds to update the figure when you change the hypers, but 
it should be faster to switch between environments, as the features are computed for
all environments and cached. 
""")
)

display(ex04_code_demo)
ex04_code_demo.run_update()

In [None]:
rhoi = ex04_wci.get_function_object()(frame_iron, 6.0, 4, 2, 0.5, sel_env_idx)

`rascaline` calculators return a `metatensor.TensorMap` object, which works as a container that holds blocks of data. The pattern is reminiscent of a `dict`, but with some more structure and metadata: each block is associated with a _key_, which consists of a tuple of ints. The set of keys is a `metatensor.Labels` object, which also keeps track of the _names_ that describe each index in the key. 

The expansion keys hold the `o3_lambda` and `o3_sigma` indices, which correspond to the angular momentum rotational symmetry and the symmetry with respect to inversion ($1$ being associated with the "normal" symmetry of a spherical harmonic). The two following indices correspond to the atomic number of the central atom, `center_type`, and of the neighbors, `neighbor_type`; here both are 26, as we are dealing with pure iron.  

For instance, if you compute a `TensorMap` that contains density coefficients, `rhoi`, you can print an overview of the map content

```python
print(rhoi)
```

In [None]:
print(rhoi)

<center><img src="figures/mtt-tensorblock-components.svg" width="300"/><br/>
<i> Structure of a TensorBlock</i><br/><br/>
</center>

Each entry in a `TensorMap` is a `TensorBlock` object, that contains a dense storage of properties, with three axes that identify the `samples` the properties refer to, their `components` (e.g. the $m$ index in a spherical harmonic of order $l$) and the actual `properties`, that may be associated with further indices that enumerate all possible values (e.g. the radial basis index $n$). 

If you print out one of the blocks, you will get an overview of the associated metadata, e.g.

```python
print(rhoi.block(1))
```

In [None]:
print(rhoi.block(1))

You can see the [metatensor documentation](http://docs.metatensor.org) if you want to find out how to manipulate and access the entries in a `TensorMap` or `TensorBlock`. We will see a few examples further down. 

## Three-body correlations: SOAP features

The $l=0$ part of the density-expansion coefficients $\langle an00|\rho_i\rangle$ corresponds to a discretization of the pair correlation function: using a real-space basis,

$$
    \langle ax00|\rho_i\rangle \approx \sum_{j\in A_i} \delta_{a a_j} \langle x | r_{ji} \rangle 
$$

where $ \langle x | r_{ji} \rangle $ is a localized function centered on $r_{ji}$. 

In order to obtain a richer description of the atomic environment it is possible to combine several copies of $\langle a\mathbf{x} | \rho_i \rangle$, to build $\nu$-neighbors atom-centered density correlations (ACDCs). 
The formalism we use was introduced by [Willatt et al.](https://doi.org/10.1063/1.5090481), and is explained in detail, discussing its relation with the leading frameworks for atomistic machine learning, in a [review by Musil et al.](https://doi.org/10.1021/acs.chemrev.1c00021)

Essentially, the idea is that considering tensor products of the atom density provides simultaneous information on the mutual position of several neighbors

$$
\langle \mathbf{x} |  \rho_i \rangle \langle \mathbf{x}' |  \rho_i \rangle =
\sum_{jj'\in A_i}
\langle \mathbf{x} |\mathbf{r}_{ji} \rangle \langle \mathbf{x}' |\mathbf{r}_{j'i} \rangle. 
$$

The invariant part of this two-neighbor correlation function can be extracted by taking a symmetrized product of the density coefficients,
$$
\langle aa'nn'l|\overline{\rho_i^{\otimes 2}}\rangle \propto 
\sum_m  \langle anlm|\rho_i\rangle \langle a'n'lm|\rho_i\rangle
$$

These are the [SOAP powerspectrum coefficients](http://doi.org/10.1103/PhysRevB.87.184115), that have been widely used in machine-learning models and especially in the context of Gaussian approximation potentials - kernel models based on SOAP features (see [Deringer et al. 2021](http://doi.org/10.1021/acs.chemrev.1c00022) for a review of kernel methods in the field). 

For those confused by the Dirac notation, this can be written using the vectorial notation for the density descriptors, $\rho_{nlm}^a(A_i)$: 

$$
p_{nn'l}^{aa'}(A_i) \propto \sum_m \rho_{nlm}^a(A_i)\rho_{n'lm}^{a'}(A_i)
$$
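
The invariance of this contraction can be verified with mock data. In the sketch below the density coefficients are random numbers, and a random orthogonal matrix is used as a stand-in for the real-valued Wigner matrix that mixes the $m$ components under a rotation; both are assumptions made for illustration, and the only property that matters for the argument is orthogonality.

```python
import numpy as np

rng = np.random.default_rng(1)
n_env, l, nmax = 4, 2, 3

# mock density coefficients for a single l, indexed as (environment, m, n)
rho = rng.normal(size=(n_env, 2 * l + 1, nmax))


def power_spectrum(rho_l):
    """p[i, n, n'] = sum_m rho[i, m, n] * rho[i, m, n']."""
    return np.einsum("imn,imp->inp", rho_l, rho_l)


# random orthogonal matrix acting on the m index (stand-in for a Wigner matrix)
q, _ = np.linalg.qr(rng.normal(size=(2 * l + 1, 2 * l + 1)))
rho_rot = np.einsum("km,imn->ikn", q, rho)

p = power_spectrum(rho)
p_rot = power_spectrum(rho_rot)

print("coefficients change:     ", not np.allclose(rho, rho_rot))
print("power spectrum unchanged:", np.allclose(p, p_rot))
```

The coefficients themselves change, but the sum over $m$ of their products does not, because the orthogonal matrix cancels against its transpose in the contraction.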

Using the [addition theorem for spherical harmonics](https://mathworld.wolfram.com/SphericalHarmonicAdditionTheorem.html) it is possible to draw a connection between the SOAP powerspectrum and three-body symmetry functions computed as a sum over pairs of neighbors. Ignoring for simplicity the element indices, 

$$
\langle nn'l|\overline{\rho_i^{\otimes 2}}\rangle \propto
\sum_{jj'\in A_i} \tilde{R}_{nl}(r_{ij}) \tilde{R}_{n'l}(r_{ij'}) P_l(\hat{\mathbf{r}}_{ij}\cdot \hat{\mathbf{r}}_{ij'})
$$

One sees that computing SOAP features from the density expansion avoids the double sum over neighbors, at the cost of computing a large number of $m$ components of the spherical harmonics - an idea often referred to as the _density trick_.
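
The equivalence between the two routes can be checked numerically for $l=1$, where $P_1(\cos\theta)=\cos\theta$ and the addition theorem gives a prefactor of $3/4\pi$. This is a sketch under stated assumptions: random neighbor vectors, and a hypothetical set of Gaussian radial functions that stands in for $\tilde{R}_{nl}$.

```python
import numpy as np

rng = np.random.default_rng(2)
rij = rng.normal(size=(6, 3))  # hypothetical neighbor vectors
r = np.linalg.norm(rij, axis=1)
rhat = rij / r[:, None]

# hypothetical radial functions R_n(r): Gaussians centered at r = 0, 1, 2
nmax = 3
radial_values = np.stack([np.exp(-((r - n) ** 2)) for n in range(nmax)], axis=1)

# density trick: a single sum over neighbors, then a contraction over m
y1 = np.sqrt(3.0 / (4.0 * np.pi)) * rhat  # l=1 real harmonics, (x, y, z) order
rho1 = np.einsum("jm,jn->mn", y1, radial_values)  # shape (3, nmax)
p_trick = np.einsum("mn,mp->np", rho1, rho1)

# explicit double sum over pairs of neighbors, with P_1(cos) = cos
cos_theta = rhat @ rhat.T
p_double = (3.0 / (4.0 * np.pi)) * np.einsum(
    "jn,kp,jk->np", radial_values, radial_values, cos_theta
)

print("the two routes agree:", np.allclose(p_trick, p_double))
```

The first route scales linearly with the number of neighbors, the second quadratically, which is why the expansion in spherical harmonics pays off for dense environments.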

`rascaline` provides the infrastructure to evaluate SOAP features using the `SoapPowerSpectrum` calculator. You can see that the syntax is very similar to that for the density expansion features, as those define the discretization level on top of which the power spectrum is computed. 

In [None]:
ex05_wci = CodeInput(
        function_name="soap_ps", 
        function_parameters="frame, rcut, nmax, lmax, sigma, select_idx",
        docstring="""
        compute the SOAP powerspectrum using the rascaline calculator
        
        :param frame: ase.Atoms frame to compute
        :param rcut: cutoff distance
        :param nmax: number of radial functions
        :param lmax: maximum angular momentum
        :param sigma: Gaussian smearing of the density
        :param select_idx: indices for the atoms to use as centers
        
        :returns: a TensorMap containing the SOAP powerspectrum coefficients
""",
        function_body="""
from rascaline import SoapPowerSpectrum
from metatensor import Labels

# parameters of the density expansion
hypers = {
    "cutoff": rcut,
    "max_radial": nmax,
    "max_angular": lmax,
    "atomic_gaussian_width": sigma,
    "cutoff_function": {"ShiftedCosine": {"width": 0.5}}, # type of cutoff and parameters
    "center_atom_weight": 1.0, # weight to include the central atom in the expansion
    "radial_basis": { "Gto": {}, }, # choice of radial basis
}

calculator = SoapPowerSpectrum(**hypers)

if select_idx is None:
    soap = calculator.compute(frame)
else:
    soap = calculator.compute(frame,
            selected_samples=Labels(names="atom", values=select_idx.reshape(-1,1))
    )

# Uncomment to print some of the metadata associated with the SOAP features
# print(soap)
# print(soap.block(0))

return soap
"""
)

In [None]:
ex05_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    nmax = IntSlider(value=4,min=1,max=8,description=r'$n_{max}$'),
    lmax = IntSlider(value=2,min=1,max=6,description=r'$l_{max}$'),
    sigma = FloatSlider(value=0.5,min=0.3,max=1,step=0.1,description=r'$\sigma$ / Å'),
    environment = Dropdown(value="A", options=["A", "B", "C", "D", "E", "F"],
                          description="environment")
    )

ex05_img = mpl.image.imread('figures/selected-env.jpg')

ex05_cache_pars = (-1, -1, -1, -1, "")
ex05_cache_value = None
ex05_cbar = None

def update_05(code_exercise):
    global ex05_cache_pars, ex05_cache_value, ex05_cbar
    rcut, nmax, lmax, sigma, env = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure
    ax, aximg = fig.get_axes()[:2]
    aximg.imshow(ex05_img)
    aximg.axis('off') 

    if ex05_cache_pars == (rcut, nmax, lmax, sigma, ex05_wci.function_body):
        soap = ex05_cache_value
    else:
        soap = ex05_wci.get_function_object()(frame_iron, rcut, nmax, lmax, sigma, sel_env_idx)
        ex05_cache_value = soap
        ex05_cache_pars = (rcut, nmax, lmax, sigma, ex05_wci.function_body)
    envidx={"A":0, "B":1, "C":2, "D":3, "E":4, "F":5}
    feats = soap.block(0).values.squeeze()[envidx[env]].reshape(-1, nmax)
    frange = np.max(np.abs(feats))
    
    norm = mpl.colors.SymLogNorm(vmin=-frange, vmax=frange, linthresh=1e-3)
    
    cax=ax.matshow(feats.T, cmap='seismic', norm=norm)
    if ex05_cbar is None:
        ex05_cbar = fig.colorbar(cax, ax=ax, orientation='horizontal' )
    else:
        ex05_cbar.update_normal(cax)
    ax.set_ylabel("n2")
    ax.set_xlabel("(l,n1)")
    xticklabels = []
    xtickpos = []
    for l in range((lmax+1)):
        ax.add_patch(mpl.patches.Rectangle(
            (-0.5+l*nmax,-0.5), nmax, nmax,
            edgecolor='black', facecolor='none', linewidth=3
        ))
        xticklabels.append(f"$l={l}$")
        xtickpos.append((l)*nmax)
    ax.set_xticks(xtickpos); ax.set_xticklabels(xticklabels)

In [None]:
ex05_figure, ex05_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex05_output = CueFigure(ex05_figure)
ex05_ax[1].imshow(ex05_img)
ex05_ax[1].axis('off') 

ex05_code_demo = CodeExercise(
            code= ex05_wci,
            parameters= ex05_pb,            
            cue_outputs = [ex05_output],
            update_func = update_05,
    update_mode="release",
    exercise_key="05",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 05: SOAP features",
    exercise_description=mdwn("""
This exercise is set up to compute SOAP power spectrum coefficients, based
on the specified parameters. You can observe how the values change for different
environments and hyperparameters. You can also change other options that are not 
provided as function parameters, but make sure to revert them to the initial values
as this function is also used in what follows.

_NB:_ It takes a few seconds to update the figure when you change the hypers, but 
it should be faster to switch between environments, as the features are computed for
all environments and cached. Note also the very large difference in magnitude for the
$l=0$ and $l>0$ blocks. 
""")
)

display(ex05_code_demo)
ex05_code_demo.run_update()

In [None]:
soap_values = []
soap_props = []
for key, block in rhoi.items():
    soap_values.append(np.einsum('imn,imp->inp', block.values, block.values
                         ).reshape(len(block.samples),-1)/np.sqrt(2*key['o3_lambda']+1) )
    soap_props.append(  [ [key["o3_lambda"], n, p] for n in block.properties['n'] for p in block.properties['n']] )
    print(block)
soap_values = np.hstack(soap_values)
soap_props = np.vstack(soap_props)
soap_samples = rhoi.block(0).samples
soap_manual = TensorBlock(values=soap_values, 
                          properties=Labels(["l", "n1", "n2"], soap_props), 
                          components=[],
                          samples=soap_samples) 

In [None]:
soap_manual

It is also instructive to compute the powerspectrum manually. 
This requires a bit of bookkeeping, in particular to create
`metatensor` objects. Note also the normalization factor, that is
omitted in the definitions above, but is important to match
the `rascaline`-computed values. You can check the [documentation of
numpy.einsum](XXXXXXXX) if you're not familiar with the function. 

In [None]:
ex06_wci = CodeInput(
        function_name="soap_manual", 
        function_parameters="frame, rcut, nmax, lmax, sigma, select_idx, density_function",
        docstring="""
        compute the SOAP powerspectrum manually, starting from the density 
        expansion coefficients 
        
        :param frame: ase.Atoms frame to compute
        :param rcut: cutoff distance
        :param nmax: number of radial functions
        :param lmax: maximum angular momentum
        :param sigma: Gaussian smearing of the density
        :param select_idx: indices for the atoms to use as centers
        :param density_function: function that computes the density coefficients,
            with signature density_function(frame, rcut, nmax, lmax, sigma, select_idx)
        
        :returns: a TensorBlock containing the SOAP powerspectrum
""",
        function_body="""
import numpy as np
from metatensor import TensorBlock, Labels

rhoi = density_function(frame, rcut, nmax, lmax, sigma, select_idx)

soap_values = []
soap_props = []
for key, block in rhoi.items():    
    # loops over the blocks of the density expansion, corresponding to the l components.
    # each block is indexed by (sample, angular_m, radial_n)
    soap_values.append( # these are the actual SOAP values
                         np.einsum('XXXX',  # enter the correct einsum string here
                           block.values, block.values
                         ).reshape(len(block.samples),-1)*
                         (-1)**key['o3_lambda']/np.sqrt(2*key['o3_lambda']+1) )
    # these are the corresponding l,n,n' indices that are accumulated for convenience
    soap_props.append(  [ [key["o3_lambda"], n, p] for n in block.properties['n'] for p in block.properties['n']] )

# combines the values 
soap_values = np.hstack(soap_values)
soap_props = np.vstack(soap_props)
soap_samples = rhoi.block(0).samples
soap_manual = TensorBlock(values=soap_values, 
                          properties=Labels(["l", "n1", "n2"], soap_props), 
                          components=[],
                          samples=soap_samples) 
                          
return soap_manual
"""
)

In [None]:
ex06_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    nmax = IntSlider(value=4,min=1,max=8,description=r'$n_{max}$'),
    lmax = IntSlider(value=2,min=1,max=6,description=r'$l_{max}$'),
    sigma = FloatSlider(value=0.5,min=0.3,max=1,step=0.1,description=r'$\sigma$ / Å'),
    environment = Dropdown(value="A", options=["A", "B", "C", "D", "E", "F"],
                          description="environment")
    )

ex06_img = mpl.image.imread('figures/selected-env.jpg')

ex06_cache_pars = (-1, -1, -1, -1, "")
ex06_cache_value = None
ex06_cbar = None

def update_06(code_exercise):
    global ex06_cache_pars, ex06_cache_value, ex06_cbar
    rcut, nmax, lmax, sigma, env = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure
    ax, aximg = fig.get_axes()[:2]
    aximg.imshow(ex06_img)
    aximg.axis('off') 

    if ex06_cache_pars == (rcut, nmax, lmax, sigma, ex06_wci.function_body):
        soap = ex06_cache_value
    else:
        soap = ex06_wci.get_function_object()(frame_iron, rcut, nmax, lmax, 
                                              sigma, sel_env_idx, ex04_wci.get_function_object())
        ex06_cache_value = soap
        ex06_cache_pars = (rcut, nmax, lmax, sigma, ex06_wci.function_body)
    envidx={"A":0, "B":1, "C":2, "D":3, "E":4, "F":5}
    feats = soap.values.squeeze()[envidx[env]].reshape(-1, nmax)
    frange = np.max(np.abs(feats))
    
    norm = mpl.colors.SymLogNorm(vmin=-frange, vmax=frange, linthresh=1e-3)
    
    cax=ax.matshow(feats.T, cmap='seismic', norm=norm)
    if ex06_cbar is None:
        ex06_cbar = fig.colorbar(cax, ax=ax, orientation='horizontal' )
    else:
        ex06_cbar.update_normal(cax)
    ax.set_ylabel("n2")
    ax.set_xlabel("(l,n1)")
    xticklabels = []
    xtickpos = []
    for l in range((lmax+1)):
        # each l block spans nmax columns; pixel centers sit at integer coordinates
        ax.add_patch(mpl.patches.Rectangle(
            (-0.5+l*nmax,-0.5), nmax, nmax,
            edgecolor='black', facecolor='none', linewidth=3
        ))
        xticklabels.append(f"$l={l}$")
        xtickpos.append((l+0.5)*nmax-0.5)
    ax.set_xticks(xtickpos); ax.set_xticklabels(xticklabels)

In [None]:
ex06_figure, ex06_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex06_output = CueFigure(ex06_figure)
ex06_ax[1].imshow(ex06_img)
ex06_ax[1].axis('off') 

ex06_code_demo = CodeExercise(
            code= ex06_wci,
            parameters= ex06_pb,            
            cue_outputs = [ex06_output],
            update_func = update_06,
    update_mode="release",
    exercise_key="06",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 06: Manual evaluation of SOAP features",
    exercise_description=mdwn("""
This exercise is set up to compute SOAP power spectrum coefficients _manually_,
starting from the density coefficients (in practice these are computed using the
function from exercise 4). All the bookkeeping and `metatensor` code is already 
implemented; you only need to write the correct `einsum` string to evaluate the 
summation, which should return, for each "l" block, the features in the order 
`(sample, n1, n2)`.
""")
)

display(ex06_code_demo)

# Automatic identification of environments

In this section we try to come up with a strategy to differentiate between solid-like and liquid-like regions. This is not easy, because the snapshot is taken from a finite-temperature simulation, and even in the solid parts the atomic positions are free to fluctuate. 
We will look at plots for a thin slice of the sample, to keep the cost of computing descriptors (which use the Python functions you defined previously) as low as possible. 

In [None]:
selection = np.sort(np.where((frame_iron.positions[:,0]>max_cutoff+1) & (frame_iron.positions[:,0]<199-max_cutoff) &
                     (frame_iron.positions[:,1]>max_cutoff+1) & (frame_iron.positions[:,1]<199-max_cutoff) & 
                     (frame_iron.positions[:,2]>20) & (frame_iron.positions[:,2]<24)
                    )[0])

In [None]:
nl_code = rascaline.NeighborList(cutoff=6.0, full_neighbor_list=True)

In [None]:
#%%time 
nl_all = nl_code.compute(frame_iron)
#                    selected_samples=Labels(names=["first_atom"], values=selection[:,np.newaxis]))

In [None]:
#%%time
#nl_selected = slice_block(nl_all.block(0),axis="samples", 
#            labels=Labels(names=["first_atom"], values=selection[:,np.newaxis]))
#this is way too slow

In [None]:
#%%time
# this is a better way to do this, exploiting the fact we are looking 
# for a single index match in an integer array
# extract int arrays
labs = np.sort(selection)
samp = np.asarray(nl_all.block(0).samples["first_atom"])
mask = np.zeros(len(samp), dtype=bool)

# sorting order and reverse order
sort_idx = np.argsort(samp)
sort_inv = np.argsort(sort_idx)
sort_samp = samp[sort_idx]

# now blocks with the same atom are contiguous so we can use searchsorted to find the bounds
lower = np.searchsorted(sort_samp,  labs)
upper = np.searchsorted(sort_samp,  labs+1)


for l,u in zip(lower, upper):
    mask[l:u] = True
mask = mask[sort_inv]
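
The `searchsorted` trick used in the cell above can be checked on a toy array. This is a minimal sketch with made-up data: `first_atom` and `wanted` are illustrative stand-ins for the `"first_atom"` sample column and the sorted `selection`, not values from the actual neighbor list.

```python
import numpy as np

# toy "first_atom" column of a neighbor list (unsorted on purpose)
first_atom = np.array([2, 0, 1, 2, 0, 3, 1, 2])
# atoms we want to keep; must be sorted, like `labs` in the cell above
wanted = np.array([0, 2])

# after sorting, all rows with the same first atom are contiguous
order = np.argsort(first_atom, kind="stable")
sorted_atoms = first_atom[order]

# rows of atom a occupy sorted_atoms[lower:upper]
lower = np.searchsorted(sorted_atoms, wanted)      # first row of each block
upper = np.searchsorted(sorted_atoms, wanted + 1)  # one past the last row

mask_sorted = np.zeros(len(first_atom), dtype=bool)
for lo, hi in zip(lower, upper):
    mask_sorted[lo:hi] = True

# the same selection computed with a generic membership test
assert np.array_equal(mask_sorted, np.isin(sorted_atoms, wanted))
print(lower, upper)
```

The loop runs once per selected atom rather than once per neighbor pair, which is why this scales better than matching labels row by row.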

In [None]:
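# it is more convenient to extract a classical (flat) neighbor list
# mask is in the original row order after the step above, so reorder it to match values[sort_idx]
neigh_dx = nl_all.block(0).values[sort_idx][mask[sort_idx]].squeeze()
neigh_sz = (upper-lower)
# start offset of each atom's block of neighbors in the flat arrays
neigh_i = np.cumsum(neigh_sz) - neigh_sz
neigh_dr = np.sqrt((neigh_dx**2).sum(axis=1))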
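
The flat arrays built above store the neighbors of all selected atoms one after another. One common way to recover the block belonging to each atom is to keep its size and its start offset. A minimal sketch with made-up numbers (the names `sizes`, `starts` and `flat` are illustrative only):

```python
import numpy as np

# number of neighbors of each selected atom
sizes = np.array([3, 1, 2])
# start offset of each block: sum of the sizes of all previous blocks
starts = np.cumsum(sizes) - sizes
# all neighbor distances, stored atom after atom
flat = np.array([2.5, 2.9, 4.1, 3.0, 2.4, 2.6])

# distances around atom i are flat[starts[i] : starts[i] + sizes[i]]
blocks = [flat[s:s + n] for s, n in zip(starts, sizes)]
for i, b in enumerate(blocks):
    print(i, b)
```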

## Manually tuning a descriptor

A first way to define a structural descriptor is to design one based on intuition, or trial and error. For example, one could take an ACSF and tune its parameters until it has a high discriminating power. As we shall see, this is not necessarily an easy task!
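
As a reminder of what is being tuned, here is a minimal sketch of a radial symmetry function. It assumes a Gaussian of width `delta` centered at `rs`, multiplied by a cosine cutoff that vanishes at `rcut`; the function you wrote in exercise 03 may use a different functional form or normalization, and `radial_acsf_sketch` is a hypothetical name used only here.

```python
import numpy as np

def radial_acsf_sketch(distances, delta, rs, rcut):
    """Sum over neighbors of a Gaussian in r, damped by a cosine cutoff (assumed form)."""
    r = np.asarray(distances, dtype=float)
    # smooth cutoff: 1 at r=0, 0 at and beyond r=rcut
    cutoff = np.where(r < rcut, 0.5 * (1.0 + np.cos(np.pi * r / rcut)), 0.0)
    gauss = np.exp(-0.5 * ((r - rs) / delta) ** 2)
    return float(np.sum(gauss * cutoff))

# eight neighbors on an ideal shell at 2.5 Å, probed exactly at the shell radius
shell = np.full(8, 2.5)
xi_on_shell = radial_acsf_sketch(shell, delta=0.5, rs=2.5, rcut=5.0)
# the same shell probed 1 Å away from the shell radius
xi_off_shell = radial_acsf_sketch(shell, delta=0.5, rs=3.5, rcut=5.0)
print(xi_on_shell, xi_off_shell)
```

The three parameters play different roles: `rs` picks the distance being probed, `delta` sets how sharply, and `rcut` limits the range of the environment.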

In [None]:
ex07_figure, ex07_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex07_output = CueFigure(ex07_figure)

ex07_xy = frame_iron.positions[selection,:2]
ex07_cbar = None
def update_07(code_exercise):
    global ex07_cbar
    rcut, delta, rs = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure
        
    ax = fig.get_axes()[0]
    ax.axis('off')
    vals = np.zeros(len(selection))
    acsf=ex03_wci.get_function_object()
    for i in tqdm(range(len(vals))):
        vals[i] = acsf(neigh_dr[neigh_i[i]:neigh_i[i]+neigh_sz[i]], delta=delta, rs=rs, rcut=rcut)
    cax=ax.scatter(ex07_xy[:,0], ex07_xy[:,1], c=vals, marker='.', s=5,
                  vmin=vals.min(), vmax = vals.max() )       
    
    if ex07_cbar is None:
        ex07_cbar = fig.colorbar(cax, ax=ax )
    else:
        ex07_cbar.update_normal(cax)

    ax = fig.get_axes()[1]
    ax.hist(vals, color='red', bins=50)
    ax.set_xlabel(r"$\xi$")
    ax.set_ylabel(r"counts")

ex07_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    delta = FloatSlider(value=0.5,min=0.1,max=2,step=0.1,description=r'$\Delta$ / Å'),
    rs = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_s$ / Å'),
    )

ex07_code_demo = CodeExercise(
    parameters= ex07_pb,
    cue_outputs = [ex07_output],
    update_func = update_07,
    update_mode="manual",
    #exercise_key="07",
    exercise_title="Exercise 07: Radial ACSF as a structural fingerprint",
    exercise_description=mdwn("""
Play around with the ACSF parameters (this exercise uses the function *you* implemented
in exercise 03), and see if you can find values that differentiate clearly between liquid-like
and solid-like environments.
"""
))

In [None]:
display(ex07_code_demo)
ex07_code_demo.run_update()

In [None]:
ex07b_txt = TextExercise(
    exercise_description="""
Observe the variability in the descriptor values.
Why do you think it is hard to find ACSF parameters that discriminate well between the two phases?
Think also about the radial function plots in exercise 02.
    """,
    exercise_registry=exercise_registry,
    exercise_key="07b",
    exercise_title=""
)
display(ex07b_txt)

## Principal component analysis

We could then try to compute many symmetry functions (or, equivalently, use the radial descriptors discretized on a grid). We would then, however, face the problem of having a large number of fingerprints for each environment: how can we use a single value to identify the two phases? Luckily, we can use a principal component analysis to identify the dominant modes of variation!
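
To make the idea concrete, the sketch below projects a toy feature matrix onto its principal directions using a plain singular value decomposition of the centered data. It is an illustration under stated assumptions rather than a solution to the exercise: `pca_sketch` is a hypothetical helper, the data are random, and the sign of each component is arbitrary, so the values may differ in sign from those returned by scikit-learn.

```python
import numpy as np

def pca_sketch(features, n_components):
    """Project centered features onto the leading right singular vectors."""
    X = np.asarray(features, dtype=float)
    Xc = X - X.mean(axis=0)                     # PCA works on centered data
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T             # shape (n_samples, n_components)

# toy data: 200 "environments" with 6 features of very different spread
rng = np.random.default_rng(0)
xi = rng.normal(size=(200, 6)) * np.array([5.0, 2.0, 1.0, 0.5, 0.2, 0.1])

proj = pca_sketch(xi, n_components=3)
print(proj.shape)
print(proj.var(axis=0))
```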

In [None]:
ex08_wci = CodeInput(
        function_name="features_pca", 
        function_parameters="xi_list, n_components",
        docstring="""
        perform a principal component analysis of a list of environment descriptors
        
        
        :param xi_list: a list of features computed for some atomic environments
        :param n_components: the number of PCA components to evaluate
        
        :returns: the list of PCA features for the given environments
""",
        function_body="""
# here you only have to call sklearn functions, see ex. 01 or the sklearn documentation
# if you forgot how to perform PCA. for simplicity (and given there are way more environments
# than descriptors and components) you should apply the transformation to the same features
# so you can also use fit_transform. 

pca = xi_list
return pca
"""
        )

In [None]:
ex08_figure, ex08_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex08_output = CueFigure(ex08_figure)

ex08_xy = frame_iron.positions[selection,:2]
ex08_cbar = None
ex08_cache_pars=(-1,-1,"")
ex08_cache_vals=None
def update_08(code_exercise):
    global ex08_cbar, ex08_cache_pars, ex08_cache_vals
    rcut, ngrid, ipca = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure

    # cache values if we only change ipca
    if (ex08_cache_pars[0] == rcut and 
        ex08_cache_pars[1] == ngrid and
        ex08_cache_pars[2] == ex02_wci.function_body):
        vals = ex08_cache_vals
    else:
        vals = np.zeros((len(selection), ngrid))
        radial=ex02_wci.get_function_object()
        for i in tqdm(range(len(vals))):
            vals[i] = radial(neigh_dr[neigh_i[i]:neigh_i[i]+neigh_sz[i]], rcut=rcut, ngrid=ngrid)
        ex08_cache_vals  = vals
        ex08_cache_pars = (rcut, ngrid, ex02_wci.function_body)
        
    ax = fig.get_axes()[0]
    ax.axis('off')
    
    pcafull = ex08_wci.get_function_object()(vals, n_components=8)
    pca = pcafull[:,ipca]
    cax=ax.scatter(ex08_xy[:,0], ex08_xy[:,1], c=pca, marker='.', s=5,
                  vmin=pca.min(), vmax = pca.max() )       
    
    if ex08_cbar is None:
        ex08_cbar = fig.colorbar(cax, ax=ax )
    else:
        ex08_cbar.update_normal(cax)

    ax = fig.get_axes()[1]
    ax.hist(pca, color='red', bins=50)
    ax.set_xlabel(r"PCA")
    ax.set_ylabel(r"counts")

ex08_pb =  ParameterPanel(
    rcut = FloatSlider(value=5,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    ngrid = IntSlider(value=8,min=4,max=32,step=1,description=r'$n_{grid}$'),
    ipca = IntSlider(value=0,min=0,max=6,step=1,description=r'$i_{PCA}$')
    )
                         
ex08_code_demo = CodeExercise(
    code=ex08_wci,
    check_registry=check_registry,
    parameters= ex08_pb,
    cue_outputs = [ex08_output],
    update_func = update_08,
    update_mode="manual",
    exercise_key="08",
    exercise_registry=exercise_registry,
    exercise_title="Exercise 08: Feature PCA",
    exercise_description=mdwn("""
Write a function that computes the principal component analysis for a set of 
descriptors. The visualizer will then compute radial features using the code you
wrote in exercise 02, with the parameters specified, and plot the color-coded
structure and a histogram of the projected features. You can select which PCA
component you display.

_NB:_ It can take up to a minute to compute all the radial descriptors, but the 
viewer will cache the values so as long as you only change the PCA index the 
plotting will be almost instantaneous. This function will also be used in 
some of the following exercises so make sure it works correctly!
"""
))

In [None]:

ex08_ref_input = [{"xi_list": np.array([[1,3,4],[2,5,7],[12,-5,7]]), "n_components":2},
                 {"xi_list": np.array([[11,32,4],[23,3,-7],[23,5,7],[12,-5,7]]), "n_components": 3}
                 ]
ex08_ref_output = [(np.array([[-4.52278769,  1.88265544],
         [-4.7507803 , -1.85204928],
         [ 9.27356799, -0.03060617]]),),
 (np.array([[ 24.04316249,  -1.42811158,  -1.03650133],
         [ -7.06128363,  10.24872457,  -2.50787124],
         [ -4.63236046,  -0.28829947,   6.60648318],
         [-12.3495184 ,  -8.53231352,  -3.0621106 ]]),)]

check_registry.add_check(ex08_code_demo,
    asserts=[
        assert_type,
        assert_shape,
        assert_numpy_allclose,
    ],
    inputs_parameters=ex08_ref_input,
    outputs_references=ex08_ref_output
)

In [None]:
display(ex08_code_demo)

In [None]:
ex08b_txt = TextExercise(
    exercise_description="""
Observe the range spanned by the different components. How does it change with increasing index?
Why? Is the first fingerprint the best at distinguishing solid and liquid?
    """,
    exercise_registry=exercise_registry,
    exercise_key="08b",
    exercise_title=""
)
display(ex08b_txt)

In [None]:
ex08c_txt = TextExercise(
    exercise_description="""
Can you get much better results than with the ACSF-based approach? 
The difference between solid and liquid regions should be
clearly reflected in the color-coded structure, and
an ideal fingerprint should clearly show a bimodal distribution.
    """,
    exercise_registry=exercise_registry,
    exercise_key="08c",
    exercise_title=""
)
display(ex08c_txt)

## Incorporating directional information with density descriptors

If the problem is that the radial descriptors do not have enough resolution to identify solid structures, it makes sense to use features that are aware of directional information, such as the density descriptors. Given the large number of descriptors available, we also perform a PCA to select the most informative combinations. Note that this section uses the functions you wrote above in exercises 4 and 8 to compute the density descriptors and to perform the PCA. 

In [None]:
def combine_l_all(tmap):
    feats=[]
    for b in tmap:
        feats.append(b.values.reshape((len(b.values),-1)) )
    return np.hstack(feats)
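
The helper above flattens every block and stacks the results side by side. The sketch below reproduces the same bookkeeping with plain NumPy arrays standing in for the block values, assuming each block has shape `(sample, 2l+1, nmax)` as in the density expansion; the sizes are arbitrary.

```python
import numpy as np

n_samples, nmax, lmax = 5, 4, 2

# stand-ins for block.values: one array per l, indexed by (sample, m, n)
blocks = [np.zeros((n_samples, 2 * l + 1, nmax)) for l in range(lmax + 1)]

# flatten the (m, n) axes of each block, then concatenate along the feature axis
flat = np.hstack([b.reshape(len(b), -1) for b in blocks])

# total number of features: nmax * sum_l (2l+1) = nmax * (lmax+1)**2
print(flat.shape)
```

The quadratic growth with `lmax` is one reason a PCA is useful before looking at individual features.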

In [None]:
ex09_figure, ex09_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex09_output = CueFigure(ex09_figure)

ex09_xy = frame_iron.positions[selection,:2]
ex09_cbar = None
ex09_cache_pars=(-1,-1,"")
ex09_cache_vals=None
rhoi_pca = None
def update_09(code_exercise):
    global ex09_cbar, ex09_cache_pars, ex09_cache_vals, rhoi_pca
    rcut, nmax, lmax, sigma, ipca = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure

    # cache values if we only change ipca
    if (ex09_cache_pars == (rcut, nmax, lmax, sigma, ex04_wci.function_body) ):
        vals = ex09_cache_vals
    else:        
        dsel = len(selection)//8 +1
        vals = [] 
        for i in tqdm(range(8)): # compute in pieces so we can show progress
            rhoi = ex04_wci.get_function_object()(frame_iron, rcut, nmax, lmax, sigma, selection[i*dsel:(i+1)*dsel])
            vals.append(combine_l_all(rhoi))        
        vals = np.vstack(vals)
        ex09_cache_vals  = vals
        ex09_cache_pars = (rcut, nmax, lmax, sigma, ex04_wci.function_body)
        
    ax = fig.get_axes()[0]
    ax.axis('off')
    
    rhoi_pca = ex08_wci.get_function_object()(vals, n_components=8)
    pca = rhoi_pca[:,ipca]
    cax=ax.scatter(ex09_xy[:,0], ex09_xy[:,1], c=pca, marker='.', s=5,
                  vmin=pca.min(), vmax = pca.max() )       
    
    if ex09_cbar is None:
        ex09_cbar = fig.colorbar(cax, ax=ax )
    else:
        ex09_cbar.update_normal(cax)

    ax = fig.get_axes()[1]
    ax.hist(pca, color='red', bins=50)
    ax.set_xlabel(r"PCA")
    ax.set_ylabel(r"counts")

ex09_pb =  ParameterPanel(
    rcut = FloatSlider(value=4,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    nmax = IntSlider(value=4,min=1,max=8,description=r'$n_{max}$'),
    lmax = IntSlider(value=6,min=1,max=8,description=r'$l_{max}$'),
    sigma = FloatSlider(value=0.5,min=0.3,max=1,step=0.1,description=r'$\sigma$ / Å'),
    ipca = IntSlider(value=0,min=0,max=6,step=1,description=r'$i_{PCA}$')
    )
                         
ex09_code_demo = CodeExercise(
    parameters= ex09_pb,
    cue_outputs = [ex09_output],
    update_func = update_09,
    update_mode="manual",
    #exercise_key="09",
    #exercise_registry=exercise_registry,
    exercise_title="Exercise 09: Density expansion PCA",
    exercise_description=mdwn(""" Experiment with the parameters of the density expansion 
    to see if you can identify more clearly the solid regions. 
Note that it takes up to a minute to recompute descriptors after you change the parameters 
so be patient. 

_Hint:_ You'll probably need to include high values of $l$. Given there is lots of 
directional information, small cutoff values are usually sufficient (and even beneficial).""")
)

In [None]:
display(ex09_code_demo)

In [None]:
ex09b_txt = TextExercise(
    exercise_key="09b",
    exercise_registry=exercise_registry,
    exercise_title="",
    exercise_description=mdwn("""
Note the values of the PCA features for the different solid regions. 
Do they differ? Why? We will investigate this further in the next exercises.
""") )
    
display(ex09b_txt)

## Invariant descriptors of the atomic structure

We can repeat the same analysis using the invariant SOAP descriptors.

In [None]:
ex10_figure, ex10_ax = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)
ex10_output = CueFigure(ex10_figure)

ex10_xy = frame_iron.positions[selection,:2]
ex10_cbar = None
ex10_cache_pars=(-1,-1,-1,-1,"")
ex10_cache_vals=None
soap_pca = None
def update_10(code_exercise):
    global ex10_cbar, ex10_cache_pars, ex10_cache_vals, soap_pca
    rcut, nmax, lmax, sigma, select_l, ipca = code_exercise.parameters.values()
    fig = code_exercise.cue_outputs[0].figure

    # cache values if we only change ipca
    if (ex10_cache_pars == (rcut, nmax, lmax, sigma, ex05_wci.function_body) ):
        vals = ex10_cache_vals
    else:        
        dsel = len(selection)//8 +1
        vals = [] 
        for i in tqdm(range(8)): # compute in pieces so we can show progress
            soap = ex05_wci.get_function_object()(frame_iron, rcut, nmax, lmax, sigma, selection[i*dsel:(i+1)*dsel])
            vals.append(soap.block(0).values.squeeze())        
        vals = np.vstack(vals)
        ex10_cache_vals  = vals
        ex10_cache_pars = (rcut, nmax, lmax, sigma, ex05_wci.function_body)
        
    ax = fig.get_axes()[0]
    ax.axis('off')
    npca = 8
    if select_l!="": # picks manually some l channels
        select_l = np.array(list(map(int, select_l.split(','))))
        if select_l.max()>lmax:
            raise ValueError("Cannot extract l channels larger than lmax")
        select_blocks = []
        nmax2=nmax**2
        for l in select_l:
            # columns of channel l: nmax**2 consecutive entries starting at l*nmax**2
            select_blocks.extend(range(l*nmax2, (l+1)*nmax2))
        vals=vals[:,select_blocks]
        if len(select_blocks)<npca:
            npca = len(select_blocks)

    soap_pca = ex08_wci.get_function_object()(vals, n_components=npca)
    pca = soap_pca[:,ipca]
    cax=ax.scatter(ex10_xy[:,0], ex10_xy[:,1], c=pca, marker='.', s=5,
                  vmin=pca.min(), vmax = pca.max() )       
    
    if ex10_cbar is None:
        ex10_cbar = fig.colorbar(cax, ax=ax )
    else:
        ex10_cbar.update_normal(cax)

    ax = fig.get_axes()[1]
    ax.hist(pca, color='red', bins=50)
    ax.set_xlabel(r"PCA")
    ax.set_ylabel(r"counts")

ex10_pb =  ParameterPanel(
    rcut = FloatSlider(value=4,min=3,max=8,step=0.1,description=r'$r_{cut}$ / Å'),
    nmax = IntSlider(value=4,min=1,max=8,description=r'$n_{max}$'),
    lmax = IntSlider(value=6,min=1,max=8,description=r'$l_{max}$'),
    sigma = FloatSlider(value=0.5,min=0.3,max=1,step=0.1,description=r'$\sigma$ / Å'),
    select_l = Text(value="",description=r'$l$ channels'),
    ipca = IntSlider(value=0,min=0,max=6,step=1,description=r'$i_{PCA}$')    
    )
                         
ex10_code_demo = CodeExercise(
    parameters= ex10_pb,
    cue_outputs = [ex10_output],
    update_func = update_10,
    update_mode="manual",
    #exercise_key="10",
    #exercise_registry=exercise_registry,
    exercise_title="Exercise 10: SOAP PCA",
    exercise_description=mdwn("""Experiment with the hyperparameters to improve the contrast
    between solid and liquid regions. 
Note that it takes up to a minute to recompute descriptors after you change the parameters 
so be patient. 

_Hint:_ Start from values that gave good resolution for the density expansion. 
You will notice less clear-cut classification. One way to improve resolution is 
to select only specific $l$ channels that have good resolving power (because they 
are sensitive to fcc ordering). You can pick a comma-separated list of values, e.g. 
`5,7,8`.""")
)
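
The selection of $l$ channels relies on the layout of the feature columns. The sketch below assumes, as the update function above does, that the SOAP features are ordered with $l$ as the slowest index and `nmax**2` columns per channel; `l_channel_columns` is a hypothetical helper written only for this illustration.

```python
import numpy as np

def l_channel_columns(select_l, nmax, lmax):
    """Turn a comma-separated list of l values into feature column indices."""
    ls = [int(s) for s in select_l.split(",")]
    if max(ls) > lmax:
        raise ValueError("Cannot extract l channels larger than lmax")
    cols = []
    for l in ls:
        # channel l occupies nmax**2 consecutive columns
        cols.extend(range(l * nmax**2, (l + 1) * nmax**2))
    return np.array(cols)

cols = l_channel_columns("0,2", nmax=2, lmax=2)
print(cols)
```

Keeping only a few channels reduces the number of columns, so the number of PCA components that can be computed may drop below the default of 8.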

In [None]:
display(ex10_code_demo)

In [None]:
ex10b_txt = TextExercise(
    exercise_key="10b",
    exercise_registry=exercise_registry,
    exercise_title="",
    exercise_description=mdwn("""
Compare the values of the PCA features in the various solid nuclei.
Are there qualitative differences with what we saw for density expansion coefficients?
In which modeling or analysis use cases would you employ either type of descriptors?
""") )
    
display(ex10b_txt)

## Combined visualization

Only _after_ having completed all the exercises, and having found good parameters to visualize the phase transition, click the `Update` button below. It will generate an interactive viewer that you can use to get further insights on the relation between the structure and the SOAP and density based PCA descriptors.

In [None]:
def update_cs(code_exercise):
    frame_slice = frame_iron[selection]
    chemiscope.write_input("module_02-iron_analysis.chemiscope.json.gz", frames=[frame_slice], 
                     properties={
                         "SOAP-PCA": soap_pca[:,:4],
                          "RHO-PCA": rhoi_pca[:,:4]
                     },
                     settings={
                         'map': {'x': {'property': 'RHO-PCA[1]'},  
                                 'y': {'property': 'RHO-PCA[2]'},
                                 'color': {'property': 'SOAP-PCA[1]'},
                          'palette': 'viridis'},
                         "structure": [ 
                         {"bonds": False, "unitCell": False,
                           'spaceFilling': True,
                          "keepOrientation": True,  
                          'color': {'property': 'SOAP-PCA[1]','palette': 'viridis'},                         
                     'environments': {'activated': False, 'center': False}}] 
                              },  
                     environments=chemiscope.all_atomic_environments(frame_slice) ,                     
                    )
    with cue_cs_final:
        display(chemiscope.show_input("module_02-iron_analysis.chemiscope.json.gz"))

cs_wp_final = ParameterPanel(
    showenv=Checkbox(value=False, description="show environments"),
    cutoff=FloatSlider(value=5.,min=2,max=8,step=0.25, description=r"cutoff / Å"),    
)

cue_cs_final = CueObject()
    
cs_demo_final = CodeExercise(
            #parameters=cs_wp_final,
            cue_outputs = cue_cs_final,
            update_func = update_cs,
            update_mode="manual")

display(cs_demo_final)

[Download chemiscope datafile](module_02-iron_analysis.chemiscope.json.gz)