In [1]:
import torch

import sys
sys.path.append('../code/')
from code_pytorch import *
from utilities import *
from miscellaneous import ClebschGordan
import ase.io
import numpy as np
import copy

In [2]:
# Highest angular channel of the expansion; later cells also size the
# Clebsch-Gordan tables from it.
LAMBDA_MAX = 3

# Hyper-parameters handed to `get_coefs` below.
HYPERS = dict(
    interaction_cutoff=6.3,
    max_radial=5,
    max_angular=LAMBDA_MAX,
    gaussian_sigma_type='Constant',
    gaussian_sigma_constant=0.05,
    cutoff_smooth_width=0.3,
    radial_basis='GTO',
)

# ASE-style slice string selecting which frames of the dataset to read.
subset = '0:100'

In [3]:
# Read the selected frames, collect every species present, and expand them
# into per-atom coefficients keyed by angular channel (see output below).
structures = ase.io.read(
    '../methane.extxyz',
    index=subset,
)
all_species = get_all_species(structures)
coefficients = get_coefs(structures, HYPERS, all_species)

In [4]:
# Show the tensor shape stored under each key of the expansion.
for key, coefficient in coefficients.items():
    print(key, coefficient.shape)

0 torch.Size([500, 10, 1])
1 torch.Size([500, 10, 3])
2 torch.Size([500, 10, 5])
3 torch.Size([500, 10, 7])


In [5]:
# Add two extra feature blocks of arbitrary trailing shape next to the real
# coefficients; the outputs of the following cells show that they pass through
# the splitter, uniter and accumulator just like the coefficients do.
all_features = copy.deepcopy(coefficients)

# Take the number of atoms from the data rather than hard-coding 500, so the
# dummy blocks stay aligned with the coefficients if `subset` is changed.
n_atoms = next(iter(coefficients.values())).shape[0]

# Seeded local generator: reproducible values without touching global RNG state.
feature_rng = torch.Generator().manual_seed(0)
all_features['bla bla bla'] = torch.randn(n_atoms, 137, generator=feature_rng)
# Note: with the default float32 dtype this block is ~148 MB for 500 atoms.
all_features['one another'] = torch.randn(n_atoms, 42, 42, 42, generator=feature_rng)

In [6]:
# Shapes of the real coefficients plus the two dummy blocks.
for name, feature in all_features.items():
    print(name, feature.shape)

0 torch.Size([500, 10, 1])
1 torch.Size([500, 10, 3])
2 torch.Size([500, 10, 5])
3 torch.Size([500, 10, 7])
bla bla bla torch.Size([500, 137])
one another torch.Size([500, 42, 42, 42])


In [7]:
# Split every feature block by central species (keys 1 and 6 in the output
# below). A descriptive name replaces the generic `block`, which the notebook
# otherwise reuses for several unrelated objects.
central_splitter = CentralSplitter()
splitted = central_splitter(all_features, get_central_species(structures))

In [8]:
# For each central species, show the shape of every feature block routed to it.
for species, features in splitted.items():
    for name, tensor in features.items():
        print(species, name, tensor.shape)

1 0 torch.Size([400, 10, 1])
1 1 torch.Size([400, 10, 3])
1 2 torch.Size([400, 10, 5])
1 3 torch.Size([400, 10, 7])
1 bla bla bla torch.Size([400, 137])
1 one another torch.Size([400, 42, 42, 42])
6 0 torch.Size([100, 10, 1])
6 1 torch.Size([100, 10, 3])
6 2 torch.Size([100, 10, 5])
6 3 torch.Size([100, 10, 7])
6 bla bla bla torch.Size([100, 137])
6 one another torch.Size([100, 42, 42, 42])


In [9]:
# Reassemble the per-species blocks into per-atom tensors; the printed shapes
# match those of `all_features` in cell 6. Named descriptively instead of
# reusing the generic `block`.
central_uniter = CentralUniter()
back = central_uniter(splitted, get_central_species(structures))
for key, value in back.items():
    print(key, value.shape)

0 torch.Size([500, 10, 1])
1 torch.Size([500, 10, 3])
2 torch.Size([500, 10, 5])
3 torch.Size([500, 10, 7])
bla bla bla torch.Size([500, 137])
one another torch.Size([500, 42, 42, 42])


In [10]:
# Aggregate per-atom rows into per-structure rows using the structural indices
# (500 atoms -> 100 structures below). Named descriptively instead of reusing
# the generic `block`.
accumulator = Accumulator()
summed = accumulator(back, get_structural_indices(structures))
for key, value in summed.items():
    print(key, value.shape)

0 torch.Size([100, 10, 1])
1 torch.Size([100, 10, 3])
2 torch.Size([100, 10, 5])
3 torch.Size([100, 10, 7])
bla bla bla torch.Size([100, 137])
one another torch.Size([100, 42, 42, 42])


In [11]:
# Combine the coefficients with themselves via Clebsch-Gordan products; with
# 2 * LAMBDA_MAX as the second argument the output keys run 0..6 (see below).
# Named descriptively instead of reusing the generic `block`.
clebsch_combiner = ClebschCombining(ClebschGordan(2 * LAMBDA_MAX).precomputed_, 2 * LAMBDA_MAX)
ps_covariants = clebsch_combiner(coefficients, coefficients)
for key, value in ps_covariants.items():
    print(key, value.shape)

0 torch.Size([500, 400, 1])
1 torch.Size([500, 900, 3])
2 torch.Size([500, 1100, 5])
3 torch.Size([500, 1000, 7])
4 torch.Size([500, 600, 9])
5 torch.Size([500, 300, 11])
6 torch.Size([500, 100, 13])


In [12]:
# Same combination, truncated: judging from the printed keys, the second
# argument caps the output lambda, so only lambda = 0, 1, 2 are produced here.
block = ClebschCombining(
    ClebschGordan(LAMBDA_MAX).precomputed_,
    2,
)
ps_covariants = block(coefficients, coefficients)
for key, covariant in ps_covariants.items():
    print(key, covariant.shape)

0 torch.Size([500, 400, 1])
1 torch.Size([500, 900, 3])
2 torch.Size([500, 1100, 5])


## Main block: the `Atomistic` model wrapper

In [13]:
class Atomistic(torch.nn.Module):
    """Apply a model to atomic features, optionally per central species and
    optionally accumulated into per-structure outputs.

    Parameters
    ----------
    models : torch.nn.Module, dict or torch.nn.ModuleDict
        Either a single module applied to all of ``X``, or a mapping from
        central species to the module used for that species. Mapping keys may
        be ints (e.g. ``{1: ..., 6: ...}``) or strings; they are stored as
        strings because ``torch.nn.ModuleDict`` only accepts string keys.
    accumulate : bool, default True
        If True, the result is passed through an ``Accumulator`` together with
        ``structural_indices`` at forward time.
    """

    def __init__(self, models, accumulate: bool = True):
        super().__init__()
        self.accumulate = accumulate
        if isinstance(models, (dict, torch.nn.ModuleDict)):
            self.central_specific = True
            self.splitter = CentralSplitter()
            self.uniter = CentralUniter()
            # Fully qualified torch.nn: `nn` is not imported by this notebook
            # and would only be available if it leaked out of a star import.
            # Keys are normalised to str to match the str(key) lookup in forward.
            self.models = torch.nn.ModuleDict(
                {str(species): model for species, model in models.items()}
            )
        else:
            self.central_specific = False
            self.model = models

        if self.accumulate:
            self.accumulator = Accumulator()

    def forward(self, X, central_species=None, structural_indices=None):
        """Run the wrapped model(s) on ``X``.

        Parameters
        ----------
        X
            Input features. For a species-specific model they are handed to
            ``CentralSplitter`` (a dict of per-atom tensors in the cells
            above); otherwise they go straight to the single model.
        central_species
            Required when the model is species specific; forwarded to the
            splitter and the uniter.
        structural_indices
            Required when ``accumulate`` is True; forwarded to the accumulator.

        Returns
        -------
        Whatever the model (or the uniter) returns, passed through the
        accumulator when ``accumulate`` is True.

        Raises
        ------
        ValueError
            If a required ``central_species`` or ``structural_indices`` is None.
        """
        if self.central_specific:
            if central_species is None:
                raise ValueError("central species should be provided for central specie specific model")

            splitted = self.splitter(X, central_species)
            # The splitter's keys keep their original type for the uniter;
            # only the ModuleDict lookup needs the string form.
            result = {
                species: self.models[str(species)](features)
                for species, features in splitted.items()
            }
            result = self.uniter(result, central_species)
        else:
            result = self.model(X)

        if self.accumulate:
            if structural_indices is None:
                raise ValueError("structural indices should be provided to accumulate structural targets")
            result = self.accumulator(result, structural_indices)
        return result