In [12]:
import e3nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric.transforms as tg_transforms
from torch_geometric.datasets import MD17

In [13]:
from plotly import io as pio

# Render plotly figures inline using the classic Jupyter "notebook" renderer.
pio.renderers.default = "notebook"

In [14]:
radius = 3 # TODO check this is a reasonable value

class OneHotTransform(tg_transforms.BaseTransform):
    def __call__(self, graph):
        graph.z = F.one_hot(graph.z).float()
        return graph


# This adds edges between all nodes that are leq 2 away - otherwise there are no edges
radius_transform = tg_transforms.RadiusGraph(r=radius)

# The one 


transforms = tg_transforms.Compose([radius_transform,
                                    OneHotTransform(),
                                    tg_transforms.Distance()])

asp_train = MD17('./../real_datasets/MD17', name='aspirin CCSD', train=True)
asp_test = MD17('./../real_datasets/MD17', name='aspirin CCSD', train=False)

In [15]:
graph = asp_train[0] # pull a graph object out of the dataset

#Each graph has a set of positions (in R3)
print(f"positions: {graph.pos.shape}")

# Some associated scalar energies
print(f"energies: {graph.energy.shape}")

# And associated forces (vectors in R3)
print(f"forces: {graph.force.shape}")

# It also has a categorical variable associated with it 'z', which I think is the elemental identity? Looks right considering
# the chemical structure is C_9 H_8 O_4
# Should be able to check this in a visualisation.
print(f"z: {graph.z}")

positions: torch.Size([21, 3])
energies: torch.Size([1])
forces: torch.Size([21, 3])
z: tensor([6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 6, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1])


In [16]:
import plotly.graph_objects as go

In [17]:
# Irreps objects create a subset of the possible irreps of the group. For example:
# The following corresponds to 3 copies (in the sense of direct sums) of each of l=0,1,2 spherical harmonics,
# Each of even parity (invariant to reflections)
type(e3nn.o3.Irreps("3x0e+3x1e+3x2e"))

e3nn.o3._irreps.Irreps

In [18]:
# Let's start by visualising a graph.
# We will initially do this just for the molecules themselves
# But eventually I think we want to use this to plot the functions defined on the manifold 

class Visualiser:

    def __init__(self) -> None:
        self.group_structure = e3nn.o3
        self.s2_grid = self.create_s2_grid()

    def create_s2_grid(self, N=100):
        """N corresponds to a discretization parameter"""
        betas = torch.linspace(0, np.pi, N//2) # angle in vertical plane
        alphas = torch.linspace(0, 2 * np.pi, N) # angle in horizontal plane

        alphas, betas = torch.meshgrid(alphas, betas)

        return e3nn.o3.angles_to_xyz(alphas, betas)

    def show_molecule_structure(self, graph):
        
        fig = go.Figure()
        
        positions = graph.pos.numpy()
        marker = go.Marker(color=graph.z.numpy())
        

        trace = go.Scatter3d(x = positions[:, 0],
                             y= positions[:, 1],
                             z = positions[:, 2],
                             marker=marker,
                             mode='markers')
        fig.add_trace(trace)

        fig.show()

    def show_S2_function(self, graph, s2_function, cmap='RdBu', position_scaler=10):


        figure_layout = go.Layout(paper_bgcolor="rgba(0,0,0,0)",
                                 plot_bgcolor="rgba(0,0,0,0)",
                                 margin=dict(l=0, r=0, t=0, b=0),
                                 showlegend=False)
        fig = go.Figure(layout=figure_layout)
        
        # normalise the positions
        positions = graph.pos.numpy()
        positions -= positions.min()
        positions /= positions.max()
        positions *= position_scaler

        num_nodes = positions.shape[0]

        # We plot one surface for each node, essential because each node has its own surface color map
        for node_id in range(num_nodes):
            x = (s2_function[node_id].abs()*self.s2_grid)[..., 0]+ positions[node_id,  0]
            y = (s2_function[node_id].abs()*self.s2_grid)[..., 1]+ positions[node_id,  1]
            z = (s2_function[node_id].abs()*self.s2_grid)[..., 2]+ positions[node_id,  2]
            surface = go.Surface(x=x,
                                 y=y,
                                 z=z,
                                 surfacecolor=s2_function,
                                #  colorscale=cmap,
                                #  cmin=0,
                                #  cmax=1,
                                 )
            fig.add_trace(surface)
            break

        fig.show()




In [25]:
vis = Visualiser()
vis.show_molecule_structure(graph)


plotly.graph_objs.Marker is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Marker
  - plotly.graph_objs.histogram.selected.Marker
  - etc.




In [20]:
N = 100
betas = torch.linspace(0, np.pi, N//2) # angle in vertical plane
alphas = torch.linspace(0, 2 * np.pi, N) # angle in horizontal plane

alphas, betas = torch.meshgrid(alphas, betas)
f = torch.cos(alphas*2) + torch.sin(betas*2)
f = [f for i in range(graph.pos.shape[0])]

In [21]:
s2_domain = vis.create_s2_grid()

### Creating a transformer network
