In [None]:
# Useful standard and scientific ML libraries
import ase.io
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
import torch

# M-Stack packages

import metatensor   # storage format for atomistic ML
import chemiscope  # interactive molecular visualization
import rascaline   # generating structural representations

from metatensor import Labels, TensorBlock, TensorMap
from rascaline.utils import clebsch_gordan

# Torch-based density learning
from rholearn import io, data, loss, models, predictor

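## Check the O(3) equivariance condition

An O(3)-equivariant map $f$ commutes with any O(3) transformation $\hat{R}$, i.e. $f(\hat{R}A) = \hat{R}f(A)$. Below we check this numerically for a random transformation by computing both sides and comparing them.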

### a) $\lambda$-SOAP descriptor

In [None]:
# Equivariance check for the lambda-SOAP descriptor: featurizing transformed
# structures must give the same TensorMap as transforming the descriptor of
# the original structures.

# Wigner-D matrices for a randomly drawn rotation (the angles are printed so a
# failing run can be reproduced). lmax is twice `max_angular` -- NOTE(review):
# presumably the highest lambda the descriptor can carry; confirm against
# `descriptor_kwargs`.
wig = clebsch_gordan.WignerDReal(
    lmax=2 * rascal_settings["hypers"]["max_angular"],
)
print("Random rotation angles (rad):", wig.angles)

# Route 1, step 1: apply the O(3) transformation to every structure
frames_o3 = [
    clebsch_gordan.transform_frame_o3(frame, wig.angles)
    for frame in frames
]
# Guard against a vacuous test: the transformation must actually move atoms
assert not np.allclose(frames[0].positions, frames_o3[0].positions)

# Featurize both the original and the transformed structures (original first)
lsoap, lsoap_o3 = (
    predictor.descriptor_builder(frames, **descriptor_kwargs),
    predictor.descriptor_builder(frames_o3, **descriptor_kwargs),
)

# Route 2: transform the descriptor computed on the original structures
lsoap_transformed = wig.transform_tensormap_o3(lsoap)

# Both routes must agree: identical metadata, numerically close values
assert metatensor.equal_metadata(lsoap_transformed, lsoap_o3)
assert metatensor.allclose(lsoap_transformed, lsoap_o3)
print("O(3) EQUIVARIANT!")

Random rotation angles (rad): [2.35138515 4.55346788 5.98524295]
O(3) EQUIVARIANT!


### b) torch model (untrained)

In [None]:
# Equivariance check for the (untrained) torch model: its prediction on the
# transformed lambda-SOAP must equal the transformed prediction on the
# original lambda-SOAP.

# Inference only -- autograd tracking is disabled while predicting. The
# original descriptor is converted and evaluated first, then the transformed.
with torch.no_grad():
    pred, pred_o3 = [
        model(metatensor.to(descriptor, "torch", **torch_settings))
        for descriptor in (lsoap, lsoap_o3)
    ]

# Transform the prediction made on the original (untransformed) lambda-SOAP
pred_transformed = wig.transform_tensormap_o3(pred)

# Transform-then-predict and predict-then-transform must agree
assert metatensor.equal_metadata(pred_transformed, pred_o3)
assert metatensor.allclose(pred_transformed, pred_o3)
print("O(3) EQUIVARIANT!")



O(3) EQUIVARIANT!
