In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
import numpy as np
import json
from equistore import Labels, TensorBlock, TensorMap
from utils.builder import TensorBuilder
import ase.io
from itertools import product
from utils.acdc_mini import acdc_standardize_keys, cg_increment, cg_combine, _remove_suffix
from utils.clebsh_gordan import ClebschGordanReal
from utils.hamiltonians import hamiltonian_features
import matplotlib.pyplot as plt

from utils.librascal import  RascalSphericalExpansion, RascalPairExpansion
from rascal.representations import SphericalExpansion
import copy
from utils.model_hamiltonian import *
from itertools import product

In [3]:
from utils.mp_utils import * 

In [4]:
# Alternative datasets used during development (kept for reference):
# frames1 = ase.io.read("./data/hamiltonian/water-hamiltonian/water_coords_1000.xyz", ":2")
# frames2 = ase.io.read("./data/hamiltonian/ethanol-hamiltonian/ethanol_4500.xyz",":1")
# frames3= [ase.build.molecule('NH3')]
# frames = frames1+frames2#+frames3#+frames2
# frames = frames1

# Load the first 100 random-methane structures and gather their reference energies
# (read from each frame's `info['energy']`) into a numpy array.
frames = ase.io.read("./data/random-methane-10k.extxyz", ":100")
energy = np.array([f.info['energy'] for f in frames])

# Give every frame a 100x100x100 cell and shift all atoms by +50 along each axis
# (NOTE(review): presumably to centre molecules that sit near the origin — confirm).
for f in frames:
    f.cell = [100, 100, 100]
    f.positions += 50

In [5]:
# Wrap the reference energies in a TensorMap: one sample per structure, a single
# m=0 component and a single 'energy' property, under a single key.
en_block = TensorBlock(
    values=np.asarray(energy).reshape(len(energy), 1, -1),
    samples=Labels(['structure'], np.arange(len(energy), dtype=np.int32).reshape(-1, 1)),
    components=[Labels(['spherical_component_m'], np.asarray([[0]], dtype=np.int32))],
    properties=Labels(['energy'], np.asarray([[0]], dtype=np.int32)),
)
energy_tensor = TensorMap(Labels(['energy'], np.asarray([[0]], dtype=np.int32)), [en_block])

In [6]:
# Hyper-parameters for the librascal spherical expansion; the same dict is also
# handed to RascalPairExpansion further down.
rascal_hypers = dict(
    interaction_cutoff=3.5,
    cutoff_smooth_width=0.3,
    max_radial=2,
    max_angular=2,
    gaussian_sigma_type="Constant",
    compute_gradients=False,
    # expansion_by_species_method="user defined",
    # global_species=[1, 6, 8, 7],
)

# Atom-centred density expansion of every frame.
spex = RascalSphericalExpansion(rascal_hypers)
rhoi = spex.compute(frames)

In [7]:
# Angular cutoff for the Clebsch-Gordan coefficients, also passed as `lcut` to cg_combine
# below: two above the expansion's max_angular.
# NOTE(review): presumably the headroom needed by the CG products — confirm.
lmax = rascal_hypers["max_angular"]+2
cg = ClebschGordanReal(lmax)

In [8]:
# Pair expansion computed with the same hyper-parameters as the atom-centred one
# (presumably one entry per i-j pair, as the name `gij` suggests — confirm).
pairs = RascalPairExpansion(rascal_hypers)
gij = pairs.compute(frames)

In [9]:
# Standardize the key layout of both expansions before the cg_combine calls below.
rho1i = acdc_standardize_keys(rhoi)
# Called for its side effect on rho1i (return value unused): species_neighbor moves
# from the keys into the properties.
rho1i.keys_to_properties(['species_neighbor'])
# NOTE(review): gij is overwritten with its standardized version, so re-running this
# cell standardizes an already-standardized map. Re-run the pair-expansion cell first,
# or keep the raw expansion under a separate name.
gij =  acdc_standardize_keys(gij)

In [10]:
# Power spectrum: CG product of the density with itself (only blocks with matching
# species_center), then an independent copy of just the invariant L=0 blocks.
rho2i = cg_combine(rho1i, rho1i, clebsch_gordan=cg, lcut=lmax, other_keys_match=['species_center'])
rho2i_L0_keys = rho2i.blocks_matching(spherical_harmonics_l=0)
rho2i_L0 = TensorMap(
    rho2i.keys[rho2i_L0_keys],
    [rho2i.block(idx).copy() for idx in rho2i_L0_keys],
)

In [11]:
# Three-centre (i, i1, i2) features: CG product of the pair expansion with itself;
# other_keys_match presumably restricts it to entries with the same species_center.
rhoii1i2_nu0 = cg_combine(gij, gij, clebsch_gordan=cg,lcut=lmax, other_keys_match=['species_center'])

In [12]:
# Inspect the block keys of the three-centre features.
rhoii1i2_nu0

TensorMap with 64 blocks
keys: ['order_nu' 'inversion_sigma' 'spherical_harmonics_l' 'species_center' 'species_neighbor_a' 'species_neighbor_b']
           2             1                    0                   1                 1                   1
           2             1                    0                   1                 1                   6
           2             1                    1                   1                 1                   1
        ...
           2             1                    4                   6                 6                   1
           2            -1                    3                   6                 6                   6
           2             1                    4                   6                 6                   6

In [41]:
# Add one more atom-centred density to the three-centre features; lcut=0 keeps only the
# L=0 blocks (all keys in the output below have spherical_harmonics_l == 0).
rhoii1i2_nu1 =  cg_combine(rho1i, rhoii1i2_nu0, clebsch_gordan=cg, lcut=0, other_keys_match = ['species_center'])

In [42]:
# Inspect the block keys of the nu=3 three-centre features.
rhoii1i2_nu1

TensorMap with 16 blocks
keys: ['order_nu' 'inversion_sigma' 'spherical_harmonics_l' 'species_center' 'species_neighbor_a' 'species_neighbor_b']
           3             1                    0                   1                 1                   1
           3             1                    0                   1                 1                   6
           3             1                    0                   1                 6                   1
        ...
           3            -1                    0                   6                 1                   6
           3            -1                    0                   6                 6                   1
           3            -1                    0                   6                 6                   6

In [15]:
# contracted_rhoii1i2 = contract_three_center(rhoii1i2_nu0)
# compare_with_rho2i(contracted_rhoii1i2, rho2i)

In [13]:
# Independent copy of only the invariant (L=0) three-centre blocks.
rhoii1i2_nu0_L0_keys = rhoii1i2_nu0.blocks_matching(spherical_harmonics_l=0)
rhoii1i2_nu0_L0 = TensorMap(
    rhoii1i2_nu0.keys[rhoii1i2_nu0_L0_keys],
    [rhoii1i2_nu0.block(idx).copy() for idx in rhoii1i2_nu0_L0_keys],
)

In [38]:
# Check that contracting the L=0 three-centre features reproduces the L=0 power
# spectrum: the printed norms of the differences should be ~0.
# NOTE(review): compare_with_rho2i is defined in the "old" section at the end of the
# notebook; run that definition first, otherwise this raises NameError on a fresh kernel.
compare_with_rho2i(contract_three_center_property(rhoii1i2_nu0_L0), rho2i_L0)

contract 3 center (400, 1, 48)
contract 3 center (100, 1, 48)
(2, 1, 0, 1) 3.5230068197095586e-19
(2, 1, 0, 6) 1.4410325788004896e-19


# Linear Model 

In [46]:
def scale(feat, val):
    """Return a new TensorMap whose block values are those of ``feat`` times ``val``.

    Keys, samples, components and properties are reused as-is. Only ``values`` is
    passed to the new blocks, so any gradients on the input blocks are not copied.
    """
    scaled_blocks = [
        TensorBlock(
            values=block.values * val,
            samples=block.samples,
            components=block.components,
            properties=block.properties,
        )
        for _, block in feat
    ]
    return TensorMap(feat.keys, scaled_blocks)

In [16]:
def get_rmse(first, second):
    """Root-mean-square error between torch tensor ``first`` and ``second``.

    ``second`` may be a numpy array (as ``loss_fn`` passes it) or a torch tensor;
    ``torch.as_tensor`` handles both without copying CPU numpy data. Gradients flow
    through ``first``.
    """
    return torch.sqrt(torch.mean((first - torch.as_tensor(second)) ** 2))
    
def get_mae(first, second):
    """Mean absolute error between torch tensor ``first`` and ``second`` (tensor or numpy array)."""
    return torch.mean(torch.abs(first - torch.as_tensor(second)))
    
def get_spread(values):
    """Root-mean-square deviation of ``values`` from their mean (population std)."""
    deviations = values - np.mean(values)
    return np.sqrt(np.mean(deviations ** 2))


In [48]:
class LinearModel(torch.nn.Module):
    """Per-block linear model on invariant (L=0) feature blocks.

    One weight vector is kept per feature block (keyed by the block key). ``forward``
    computes ``X @ w`` (plus an optional per-block intercept) for every block and returns
    the predictions as a TensorMap keyed by ``self.features.keys``.

    Parameters
    ----------
    target :
        Reference target; stored on the instance but not used by ``forward``.
    features : TensorMap
        Used to size the weights (one weight per property of each block) and to
        provide the output keys.
    weights : dict, optional
        Pre-built ``key -> weight tensor`` mapping. If ``None`` (default), each block
        gets a trainable float64 weight vector initialised to ones.
    intercepts : dict, optional
        ``key -> intercept`` added to each block's prediction; ``None`` disables it.
        Intercepts are not yielded by ``parameters()``, so an optimiser built from it
        does not update them.
    """

    def __init__(self, target, features, weights=None, intercepts=None):
        super().__init__()
        self.target = target
        self.features = features
        if weights is None:
            # Plain dict rather than ParameterDict: the keys are block keys, not strings.
            # parameters() below exposes the weights to the optimiser instead.
            self.weights = {}
            for key, block in self.features:
                size = block.values.shape[-1]
                self.weights[key] = torch.nn.Parameter(torch.ones(size, dtype=torch.float64))
        else:
            self.weights = weights
        self.intercepts = intercepts

    def forward(self, features):
        """Predict one value per sample for every block that has a weight vector.

        ``features`` must have 6 key fields (three-centre: order_nu, inversion_sigma,
        spherical_harmonics_l, species_center, species_neighbor_a, species_neighbor_b)
        or 4 key fields (atom-centred: order_nu, inversion_sigma, spherical_harmonics_l,
        species_center). Blocks are looked up by species only, and only L=0 is supported.

        Raises
        ------
        ValueError
            If the keys of ``features`` have neither 4 nor 6 fields.
        """
        pred_blocks = []
        n_key_fields = len(features.keys.dtype)
        for idx, wts in self.weights.items():
            if n_key_fields == 6:
                order_nu, inversion_sigma, L, ai, aj, ak = idx
                X = features.block(species_center=ai, species_neighbor_a=aj, species_neighbor_b=ak)
            elif n_key_fields == 4:
                order_nu, inversion_sigma, L, ai = idx
                X = features.block(species_center=ai)
            else:
                raise ValueError('features are neither three center nor atom centered')
            assert L == 0  # invariant only

            # Flatten samples x components into rows; the last axis is the property axis.
            X_new = torch.from_numpy(X.values.reshape(-1, X.values.shape[-1]))
            if self.intercepts is not None:
                Y = X_new @ wts + self.intercepts[idx]
            else:
                Y = X_new @ wts

            newblock = TensorBlock(
                values=Y.reshape((-1, 2 * L + 1, 1)),
                samples=X.samples,
                components=[Labels(
                    ["mu"], np.asarray(range(-L, L + 1), dtype=np.int32).reshape(-1, 1)
                )],
                properties=Labels(["dummy"], np.asarray([[0]], dtype=np.int32)),
            )
            pred_blocks.append(newblock)

        # Blocks were produced in self.weights order, which follows self.features.
        return TensorMap(self.features.keys, pred_blocks)

    def parameters(self, recurse=True):
        """Yield the per-block weight tensors.

        ``recurse`` is accepted for ``torch.nn.Module.parameters`` compatibility and ignored.
        """
        yield from self.weights.values()

In [49]:
def loss_fn(pred, target, three_center=True):
    """RMSE between structure-level predictions and the single block of ``target``.

    ``pred`` and ``target`` are TensorMaps. Three-centre predictions are first contracted
    with ``contract_three_center_property``. The result is then aggregated per structure
    with ``atom_to_structure`` before it is compared to ``target``.
    """
    atom_pred = contract_three_center_property(pred, numpy=False) if three_center else pred
    structure_pred = atom_to_structure(atom_pred, numpy=False)
    return get_rmse(structure_pred.block().values, target.block().values)

In [50]:
# Sanity check that atom_to_structure sums atomic contributions per structure.
# The result is: structure values, minus the species-6 block (block 1, presumably one C
# per structure), minus the species-1 block (block 0) summed over consecutive groups of
# 4 rows. It should vanish.
# NOTE(review): 400 and 4 are hard-coded for 100 CH4 frames (4 H per frame, rows assumed
# contiguous per structure) — derive them from the data if the dataset changes.
np.linalg.norm(atom_to_structure(rho2i_L0).block().values - rho2i_L0.block(1).values - np.add.reduceat(rho2i_L0.block(0).values, np.arange(0,400,4)))

3.476593498823191e-19

In [51]:
# Contracting the L=0 three-centre features yields blocks keyed by the atom-centred
# fields only (order_nu, inversion_sigma, spherical_harmonics_l, species_center).
contract_three_center_property(rhoii1i2_nu0_L0)

TensorMap with 2 blocks
keys: ['order_nu' 'inversion_sigma' 'spherical_harmonics_l' 'species_center']
           2             1                    0                   1
           2             1                    0                   6

## Model with three-centre features (`rhoii1i2`)

In [None]:
# rhoii1i2_nu0_L0.keys_to_properties('species_neighbor_b')
# rhoii1i2_nu0_L0.keys_to_properties('species_neighbor_a')

# Train a linear model on the L=0 three-centre features. The features passed to the
# constructor only fix the block keys and weight sizes, so their scale is irrelevant.
model = LinearModel(energy, scale(rhoii1i2_nu0_L0, 1e0))
optimizer = torch.optim.Adam(model.parameters(), lr=10)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

# The training inputs never change, so scale them once instead of on every epoch.
rhoii1i2_nu0_L0_scaled = scale(rhoii1i2_nu0_L0, 1e2)

all_losses = []
for epoch in range(100000):
    optimizer.zero_grad()
    pred = model(rhoii1i2_nu0_L0_scaled)
    loss = loss_fn(pred, energy_tensor, three_center=True)
    loss.backward()
    optimizer.step()
    # scheduler.step()

    all_losses.append(loss.item())

    if epoch % 100 == 0:
        print(epoch, loss.item())

0 39.95268937601291
100 4.295995747963264
200 1.2624713714037
300 0.8521195168717316
400 0.6507369632923024
500 0.567716683772019
600 0.5180635327567081


## Model with atom-centred power-spectrum features (`rho2i`)

In [26]:
# Smoke test: build a model on the L=0 power spectrum and run a single forward pass.
model = LinearModel(energy, rho2i_L0)
pred = model(rho2i_L0)

In [52]:
# Train a linear model on the L=0 power spectrum (atom-centred features).
# NOTE: with 1_000_000 epochs this run was stopped manually (KeyboardInterrupt);
# all_losses below replaces the list from the three-centre run.
model = LinearModel(energy, scale(rho2i_L0,1e0))
optimizer = torch.optim.Adam(model.parameters(), lr=10)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

# The training inputs never change, so scale them once instead of on every epoch.
rho2i_L0_scaled = scale(rho2i_L0, 1e2)

all_losses = []
for epoch in range(1000000):
    optimizer.zero_grad()
    pred = model(rho2i_L0_scaled)
    loss = loss_fn(pred, energy_tensor, three_center=False)
    loss.backward()
    optimizer.step()
    # scheduler.step()

    all_losses.append(loss.item())

    if epoch % 100 == 0:
        print(epoch, loss.item())

0 50.58802177239843
100 2.2498598467112556
200 1.157211877558215
300 0.8594678531674216
400 0.7282455652815527
500 0.6314600994731302
600 0.5511290853251981
700 0.5035962831901813
800 0.4554906744893913
900 0.41707392153337824
1000 0.3841635318172585
1100 0.3558735302746162
1200 0.3313600039273276
1300 0.31003367890141015
1400 0.29142678117737414
1500 0.2751578953920216
1600 0.2609066185242776
1700 0.2484022582217272
1800 0.23740649286397605
1900 0.22772036179123123


KeyboardInterrupt: 

## Old: sanity checks

### Sanity check
We need to compare the contraction with `rho2i`, but the property labels of the two tensor maps are in different orders.
Below is a naive solution: for each property label of one tensor, it finds the index of the same label in the other.

In [19]:
def compare_with_rho2i(contracted_rhoii1i2, rho2i):
    """Print and return, block by block, how far contracted three-centre features are from rho2i.

    The two maps list their properties in different orders. Each rho2i property
    (species_neighbor_a, species_neighbor_b, n_1_a, k_2, n_1_b, l_2) is therefore matched
    to the contracted property with the same label tuple before taking the difference.

    Parameters
    ----------
    contracted_rhoii1i2 : TensorMap
        Three-centre features contracted over i1, i2. Its property labels must be ordered
        (species_neighbor_a, species_neighbor_b, n_1_a, k_2, n_1_b, l_2).
    rho2i : TensorMap
        Reference features with the same number of blocks.

    Returns
    -------
    dict
        ``tuple(key) -> norm of the difference`` over the matched properties. Each norm is
        also printed next to its key. A warning is printed when some rho2i properties have
        no counterpart, because they are excluded from the norm.
    """
    assert len(rho2i) == len(contracted_rhoii1i2)
    norms = {}
    for rho2_k, rho2_b in rho2i:
        contracted_block = contracted_rhoii1i2.block(
            order_nu=rho2_k[0],
            inversion_sigma=rho2_k[1],
            spherical_harmonics_l=rho2_k[2],
            species_center=rho2_k[3],
        )
        # Index the contracted labels once (first occurrence wins, like the former linear
        # search) instead of re-scanning them for every rho2i property.
        contracted_index = {}
        for ip, cp in enumerate(contracted_block.properties):
            contracted_index.setdefault(tuple(cp), ip)

        idx = []
        cidx = []
        for i, p in enumerate(rho2_b.properties):
            find_p = (p['species_neighbor_a'], p['species_neighbor_b'], p['n_1_a'],
                      p['k_2'], p['n_1_b'], p['l_2'])
            ip = contracted_index.get(find_p)
            if ip is not None:
                idx.append(i)
                cidx.append(ip)

        norm = np.linalg.norm(rho2_b.values[:, :, idx] - contracted_block.values[:, :, cidx])
        print(rho2_k, norm)
        n_unmatched = len(rho2_b.properties) - len(idx)
        if n_unmatched:
            # Skipped properties do not contribute to the norm; without this warning a
            # partial comparison would look like a perfect match.
            print(f"  warning: {n_unmatched} of {len(rho2_b.properties)} properties of "
                  f"{rho2_k} have no match in the contracted block and were skipped")
        norms[tuple(rho2_k)] = norm
    return norms

# NOTE(review): `contracted_rhoii1i2` is only assigned in a commented-out cell earlier in
# the notebook (`contract_three_center(rhoii1i2_nu0)`), so this call raises NameError on
# a fresh kernel. The output below also lists species 8 and L>0 blocks, so it presumably
# comes from the water frames in the commented-out loader, not from the methane data.
compare_with_rho2i(contracted_rhoii1i2, rho2i)

(2, 1, 0, 1) 2.3412456464305706e-20
(2, 1, 1, 1) 1.8707075487275463e-20
(2, 1, 2, 1) 1.9540399352564176e-20
(2, 1, 0, 8) 2.938484138683137e-20
(2, 1, 1, 8) 1.3797231381027056e-20
(2, 1, 2, 8) 2.6540990298228056e-20
(2, -1, 1, 1) 0.0
(2, -1, 2, 1) 0.0
(2, 1, 3, 1) 0.0
(2, -1, 1, 8) 4.290085066403457e-21
(2, -1, 2, 8) 1.1003280752469429e-20
(2, 1, 3, 8) 8.852469447498877e-21
(2, -1, 3, 1) 0.0
(2, 1, 4, 1) 0.0
(2, -1, 3, 8) 7.727960980828157e-21
(2, 1, 4, 8) 8.668935063650161e-21


So simply summing the three-centre feature over i1 and i2 yields the power spectrum, as it should!