In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from functools import partial

import numpy as np
import h5py

import torch
torch.set_default_dtype(torch.float64)

import torch_geometric as tg

import e3nn
import e3nn.point
import e3nn.radial
import e3nn.kernel
from e3nn.point.message_passing import Convolution
from e3nn.networks import GatedConvParityNetwork
from e3nn.non_linearities import rescaled_act
import e3nn.point.data_helpers as dh 

import matplotlib
import matplotlib.pyplot as plt

torch.set_default_dtype(torch.float64)
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

## Import data and create torch.Tensors

In [2]:
# read perturbed data
with h5py.File("../acetone/acetone-b3lyp_d3bj-631gd-gas-NMR-pcSseg_1.hdf5", "r") as h5:
    geoms_and_shieldings = np.array(h5.get("data"))

In [4]:
### Graphic to get order of atom types
# import plotly.graph_objects as go
# import numpy as np

# annotations = [
#     dict(showarrow=False, 
#          x=geometry[0, i, 0], 
#          y=geometry[0, i, 1],
#          z=geometry[0, i, 2], text=str(i), xanchor="left") for i in range(len(geometry[0]))
# ]

# fig = go.Figure(data=[go.Scatter3d(x=geometry[0, :, 0], y=geometry[0, :, 1], z=geometry[0, :, 2], 
#                                    text=[str(i) for i in range(10)],
#                                    mode='markers')])
# fig.update_layout(scene={'annotations': annotations})
# fig.show()

In [5]:
# geometry = torch.tensor(geoms_and_shieldings[:, :, :3], dtype=torch.float64)
# B, N, _ = geometry.shape
# shielding = torch.tensor(geoms_and_shieldings[:, :, 3], dtype=torch.float64).unsqueeze(-1)
# shielding = (shielding - shielding.mean()) / shielding.std()
# atom_types = ['C', 'O', 'C', 'H', 'H', 'H', 'C', 'H', 'H', 'H']
# map_dict = {'H': 0, 'C': 1, 'O': 2}
# features = torch.zeros(B, N, 3)
# for i, a in enumerate(atom_types):
#     features[:, i, map_dict[a]] = 1.
# print(geometry.shape, shielding.shape, features.shape)

In [6]:
max_radius = 2.

In [7]:
#### This is not efficient we'll need a 
# dataset = []
# for i, (geo, shield, feat) in enumerate(zip(geometry, shielding, features)):
#     data = dh.DataNeighbors(x=feat, Rs_in=[(3, 0, 1)], 
#                             pos=geo, r_max=max_radius, 
#                             y=shield, Rs_out=[(1, 0, 1)])
#     dataset.append(data)
#     if i % 2500 == 2499:
#         print(i)
#         torch.save(dataset, 'acetone_geo/acetone_geometric_dataset_{}.torch'.format(i))
#         dataset = []

In [3]:
dataset = torch.load('acetone_geo/acetone_geometric_dataset_2499.torch')
batch_size = 64
dataloader = tg.data.DataListLoader(dataset, batch_size=batch_size)

FileNotFoundError: [Errno 2] No such file or directory: 'acetone_geo/acetone_geometric_dataset_2499.torch'

In [9]:
test_dataset = torch.load('acetone_geo/acetone_geometric_dataset_4999.torch')
batch_size = 64
test_dataloader = tg.data.DataListLoader(test_dataset, batch_size=batch_size)

## Define model, optimizer, and loss function

In [10]:
Rs_in = [(3, 0, 1)]
Rs_out = [(1, 0, 1)]
mul = 8
model = GatedConvParityNetwork(
    Rs_in=Rs_in, Rs_out=Rs_out,
    mul=mul, lmax=2, layers=5,
    number_of_basis=10, 
    max_radius=max_radius,
    convolution=Convolution
)
opt = torch.optim.Adam(model.parameters(), 3e-3)
loss_fn = lambda x,y: ((x - y)**2).mean()  # MSE
# loss_fn = lambda x,y: ((x - y)**2).sqrt().mean()  # MAE

## Train!

In [11]:
model.to(device)
max_iter = 100
n_norm = 5
for i in range(max_iter):
    for data in dataloader:
        data = tg.data.Batch.from_data_list(data)
        data.to(device)
        output = model(data.x, data.edge_index, data.edge_attr, n_norm=n_norm)
        loss = loss_fn(output, data.y)
    if i % 5 == 0:
        print(loss)
    opt.zero_grad()
    loss.backward()
    opt.step()

tensor(1.0578, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.7954, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.4901, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0753, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0814, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0201, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0223, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0068, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0066, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0045, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0022, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0018, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0008, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0004, device='cuda:1', grad_fn=<MeanBackward0>)


KeyboardInterrupt: 

In [13]:
max_iter = 100
n_norm = 5
for data in test_dataloader:
    data = tg.data.Batch.from_data_list(data)
    data.to(device)
    output = model(data.x, data.edge_index, data.edge_attr, n_norm=n_norm)
    loss = loss_fn(output, data.y)
    print(loss)

tensor(0.0145, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0155, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0173, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0178, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0175, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0166, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0177, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0163, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0196, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0168, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0188, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0150, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0194, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0164, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0141, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0134, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0135, device='cuda:1', grad_fn=<MeanBackward0>)
tensor(0.0157, device='cuda:1',