In [2]:
%load_ext autoreload
%autoreload 2

In [8]:
from ase.io import read
from ase.visualize import view
import numpy as np 
import torch 
import metatensor 
import matplotlib.pyplot as plt
torch.set_default_dtype(torch.float64)

from metatensor import Labels, TensorBlock, TensorMap
from mlelec.data.dataset import PySCFPeriodicDataset

from mlelec.utils.plot_utils import print_matrix, matrix_norm, block_matrix_norm

from metatensor import load, sort
from mlelec.utils.twocenter_utils import _to_coupled_basis
from mlelec.utils.pbc_utils import matrix_to_blocks, kmatrix_to_blocks, TMap_bloch_sums, precompute_phase

from mlelec.features.acdc import pair_features, single_center_features, twocenter_features_periodic_NH, twocenter_hermitian_features
import rascaline
from mlelec.utils.pbc_utils import kblocks_to_matrix, kmatrix_to_blocks, blocks_to_matrix, matrix_to_blocks
from mlelec.utils.plot_utils import plot_block_errors

import rascaline
from mlelec.models.linear import LinearModelPeriodic

In [13]:
def get_targets(dataset, device ="cpu", cutoff = None, target='fock', all_pairs = False, sort_orbs = True):
    
    blocks = matrix_to_blocks(dataset, device = device, cutoff = cutoff, all_pairs = all_pairs, target = target, sort_orbs = sort_orbs)
    coupled_blocks = _to_coupled_basis(blocks, skip_symmetry = True, device = device, translations = True)

    blocks = blocks.to(arrays='numpy')
    blocks = sort(blocks)
    blocks = blocks.to(arrays='torch')
    
    coupled_blocks = coupled_blocks.to(arrays='numpy')
    coupled_blocks = sort(coupled_blocks)
    coupled_blocks = coupled_blocks.to(arrays='torch')
    
    return blocks, coupled_blocks

In [7]:
def compute_features(dataset, all_pairs=False):

    rhoij = pair_features(dataset.structures, hypers_atom, hypers_pair, order_nu = 1, all_pairs = all_pairs, both_centers = both_centers,
                          kmesh = dataset.kmesh, device = device, lcut = LCUT, return_rho0ij = return_rho0ij)  
    
    if both_centers and not return_rho0ij:
        NU = 3
    else:
        NU = 2
    rhonui = single_center_features(dataset.structures, hypers_atom, order_nu = NU, lcut = LCUT, device = device,
                                    feature_names = rhoij.property_names)
    
    hfeat = twocenter_features_periodic_NH(single_center = rhonui, pair = rhoij, all_pairs = all_pairs)

    return hfeat

In [11]:
device = 'cpu'

orbitals = {
    'sto-3g': {5: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               6: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               7: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]]}, 
    
    'def2svp': {6: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]]},
    'benzene': {6: [[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 1:[[1,0,0]]},
    'gthszvmolopt': {
        6: [[2, 0, 0], [2, 1, -1], [2, 1, 0], [2, 1, 1]],
        
        16: [[3,0,0], 
             [3,1,-1], [3,1,0], [3,1,1]],

        42: [[4,0,0], 
             [5,0,0], 
             [4,1,-1], [4,1,0], [4,1,1], 
             [4, 2, -2], [4, 2, -1], [4, 2, 0], [4, 2, 1], [4, 2, 2]]}
}

# QC dataset

In [48]:
workdir = './'
START = 0 
STOP = 5
ORBS = 'sto-3g'
root = f'{workdir}/examples/data/periodic/deepH_graphene/wrap/'
data_dir = root
frames = read(f'{data_dir}/wrapped_deepH_graphene.xyz', slice(START, STOP))
rfock = [np.load(f"{data_dir}/realfock_{i}.npy", allow_pickle = True).item() for i in range(START, STOP)]
rover = [np.load(f"{data_dir}/realoverlap_{i}.npy", allow_pickle = True).item() for i in range(START, STOP)]
kmesh = [1,1,1]
dataset = PySCFPeriodicDataset(frames = frames, 
                               kmesh = kmesh, 
                               dimension = 2,
                               fock_realspace = rfock, 
                               overlap_realspace = rover, 
                               device = device, 
                               orbs = orbitals[ORBS], 
                               orbs_name = 'sto-3g')

# Targets

In [49]:
cutoff = 6

In [50]:
target_blocks, target_coupled_blocks = get_targets(dataset, cutoff = cutoff, device = device)

# Features

In [51]:
max_radial  = 6
max_angular = 4
atomic_gaussian_width = 0.3

hypers_pair = {'cutoff': cutoff,
               'max_radial': max_radial,
               'max_angular': max_angular,
               'atomic_gaussian_width': atomic_gaussian_width,
               'center_atom_weight': 1,
               "radial_basis": {"Gto": {}},
               "cutoff_function": {"ShiftedCosine": {"width": 0.1}}}


hypers_atom = {'cutoff': 4,
               'max_radial': max_radial,
               'max_angular': max_angular,
               'atomic_gaussian_width': 0.3,
               'center_atom_weight': 1,
               "radial_basis": {"Gto": {}},
               "cutoff_function": {"ShiftedCosine": {"width": 0.1}}}


return_rho0ij = False
both_centers = False
LCUT = 3

In [52]:
features = compute_features(dataset)

cpu pair features
cpu single center features
cpu single center features


# Dataloader

In [53]:
from metatensor.learn import Dataset, DataLoader
import metatensor as mts

In [117]:
%%timeit
umd1 = mts.unique_metadata(features, axis = "samples", names = ["structure", "center", "neighbor"]).values.tolist()

80.4 ms ± 379 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [123]:
%%timeit
Aijs, invs = unique_Aij(features)
umd2 = np.unique(np.concatenate(Aijs), axis = 0).tolist()

187 ms ± 943 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [133]:
def unique_Aij(tensor):
    Aijs = []
    invs = []
    for b in tensor.blocks():
        Aij, inv = np.unique(b.samples.values[:, :3].tolist(), axis = 0, return_inverse = True)
        Aijs.append(ifrij)
        invs.append(inv)
    return Aijs, invs

In [65]:
split_by_axis = "samples"
split_by_dimension = ["structure", "center", "neighbor"]
grouped_labels = [
    mts.Labels(names=split_by_dimension, values=np.array([A]))
    for A in mts.unique_metadata(features, axis = split_by_axis, names = split_by_dimension)
]

In [185]:
other_samples = features[10].samples.values[:,:3].tolist()

In [192]:
Labels(features[10].samples.names, np.array(features[10].samples.values[:10].tolist()))

Labels(
    structure  center  neighbor  cell_shift_a  cell_shift_b  cell_shift_c
        0        0        1           0             0             0
        0        0        2           0             0             0
        0        0        3           0             0             0
        0        0        4           0             0             0
        0        0        8           0             0             0
        0        0        9           0             0             0
        0        0        10          0             0             0
        0        0        11          0             0             0
        0        0        12          0             0             0
        0        0        13          0             0             0
)

In [230]:
from mlelec.utils.pbc_utils import unique_Aij_block

def split_block_by_Aij(block):
    
    Aij, where_inv = unique_Aij_block(block)

    values = {}
    b_values = block.values
    for I, (A, i, j) in enumerate(Aij):
        idx = np.where(where_inv == I)[0]
        values[A, i, j] = b_values[idx]

    return values

def split_block_by_Aij_mts(block):

    from metatensor import Labels, TensorBlock
    
    Aij, where_inv = unique_Aij_block(block)

    # samples = {}
    # components = {}
    # properties = {}
    # values = {}
    new_blocks = {}
    b_values = block.values
    b_samples = block.samples
    b_components = block.components
    b_properties = block.properties
    for I, (A, i, j) in enumerate(Aij):
        idx = np.where(where_inv == I)[0]
        # samples[A, i, j] = Labels(b_samples.names, np.array(b_samples.values[idx].tolist()))
        # components[A, i, j] = Labels(b_components.names, np.array(b_components.values[idx].tolist()))
        # properties[A, i, j] = Labels(b_properties.names, np.array(b_properties.values[idx].tolist()))
        # values[A, i, j] = b_values[idx]
        new_blocks[A, i, j] = TensorBlock(samples = Labels(b_samples.names, np.array(b_samples.values[idx].tolist())),
                                          components = b_components, #Labels(b_components.names, np.array(b_components.values[idx].tolist())), 
                                          properties = b_properties, #Labels(b_properties.names, np.array(b_properties.values[idx].tolist())),
                                          values = b_values[idx])

    return new_blocks

def split_by_Aij(tensor, features = None):

    if features is None:
        values = {}
        for k, b in tensor.items():
            kl = tuple(k.values.tolist())
            values[kl] = split_block_by_Aij(b)

        return values

    else:
        from mlelec.utils.twocenter_utils import map_targetkeys_to_featkeys
        
        target_values = {}
        feature_values = {}
        
        for k, target in tensor.items():
            
            feature = map_targetkeys_to_featkeys(features, k)
            
            kl = tuple(k.values.tolist())
            
            target_values[kl] = split_block_by_Aij(target)
            feature_values[kl] = split_block_by_Aij(feature)
    
        return feature_values, target_values

def split_by_Aij_mts(tensor, features = None):

    from metatensor import TensorMap
    
    if features is None:

        blocks = {}
        keys = {}
        for k, b in tensor.items():
            block = split_block_by_Aij_mts(b)

            for Aij in block:
                if Aij not in blocks:
                    blocks[Aij] = []
                    keys[Aij] = []
                keys[Aij].append(k.values.tolist())
                blocks[Aij].append(block[Aij])


        tmaps = {}
        for Aij in blocks:
            tmap_keys = Labels(tensor.keys.names, np.array(keys[Aij]))
            tmap_blocks = blocks[Aij]
            tmaps[Aij] = TensorMap(tmap_keys, tmap_blocks)
        
        return tmaps

    else:
        from mlelec.utils.twocenter_utils import map_targetkeys_to_featkeys

        feature_blocks = {}
        target_blocks = {}
        keys = {}
        for k, b in tensor.items():
            feature = map_targetkeys_to_featkeys(features, k)
            
            target_block = split_block_by_Aij_mts(b)
            feature_block = split_block_by_Aij_mts(feature)

            for Aij in target_block:
                if Aij not in target_blocks:
                    feature_blocks[Aij] = []
                    target_blocks[Aij] = []
                    keys[Aij] = []
                kval = k.values.tolist()
                keys[Aij].append(kval)
                feature_blocks[Aij].append(feature_block[Aij])
                target_blocks[Aij].append(target_block[Aij])

        tmaps_feature = {}
        tmaps_target = {}
        for Aij in feature_blocks:
            tmap_keys = Labels(tensor.keys.names, np.array(keys[Aij]))
            tmaps_feature[Aij] = TensorMap(tmap_keys, feature_blocks[Aij])
            tmaps_target[Aij] = TensorMap(tmap_keys, target_blocks[Aij])
        
        return tmaps_feature, tmaps_target

In [200]:
from mlelec.utils.twocenter_utils import map_targetkeys_to_featkeys

In [173]:
split_features_, split_target_ = split_by_Aij(target_coupled_blocks, features = features)

In [231]:
%%timeit
split_target_mts = split_by_Aij_mts(target_coupled_blocks, features=features)

55.8 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [232]:
%%timeit
split_features_, split_target_ = split_by_Aij(target_coupled_blocks, features = features)

2.94 s ± 5.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [222]:
%%timeit
split_target_mts = split_by_Aij_mts(target_coupled_blocks)

23.2 s ± 175 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [224]:
%%timeit
split_target_ = split_by_Aij(target_coupled_blocks)

1.43 s ± 2.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [182]:
for k in split_features_:
    for f, t in zip(split_features_[k].values(), split_target_[k].values()):
        assert (f.shape[:2] == t.shape[:2])

In [167]:
split_target = split_by_Aij(target_coupled_blocks)
split_features = split_by_Aij(features)

In [221]:
for (A,i,j), v in split_target_mts.items():
    for k,b in v.items():
        assert A == np.unique(b.samples.values[:,0]), (A, np.unique(b.samples.values[:,0]))
        assert i == np.unique(b.samples.values[:,1]), (i, np.unique(b.samples.values[:,1]))
        assert j == np.unique(b.samples.values[:,2]), (j, np.unique(b.samples.values[:,2]))

In [162]:
split_features.keys()

dict_keys([(2, 1, 0, 6, 6, 0), (2, 1, 1, 6, 6, 0), (2, 1, 2, 6, 6, 0), (2, 1, 3, 6, 6, 0), (2, -1, 1, 6, 6, 0), (2, -1, 2, 6, 6, 0), (2, -1, 3, 6, 6, 0), (2, 1, 0, 6, 6, 1), (2, 1, 0, 6, 6, -1), (2, 1, 1, 6, 6, 1), (2, 1, 1, 6, 6, -1), (2, 1, 2, 6, 6, 1), (2, 1, 2, 6, 6, -1), (2, 1, 3, 6, 6, 1), (2, 1, 3, 6, 6, -1), (2, -1, 1, 6, 6, 1), (2, -1, 1, 6, 6, -1), (2, -1, 2, 6, 6, 1), (2, -1, 2, 6, 6, -1), (2, -1, 3, 6, 6, 1), (2, -1, 3, 6, 6, -1)])

In [126]:
np.random.seed(1)
randidx = np.random.choice(382, 40)
np.random.seed(1)
randidx1 = np.random.choice(382, 40, replace=False)

In [23]:
ml_data = Dataset(descriptor = [features], target = [target_coupled_blocks])

In [31]:
batch_size = 2

In [32]:
dataloader = DataLoader(ml_data, batch_size = batch_size)

In [33]:
batch = next(iter(dataloader))

In [35]:
print("batch.descriptor =", batch.descriptor)

# `scalar` data are float objects, so are just grouped and returned in a tuple
print("batch.target =", batch.target)
assert len(batch.target) == batch_size, len(batch.target)

batch.descriptor = TensorMap with 21 blocks
keys: order_nu  inversion_sigma  spherical_harmonics_l  species_center  species_neighbor  block_type
         2             1                   0                  6                6              0
         2             1                   1                  6                6              0
         2             1                   2                  6                6              0
         2             1                   3                  6                6              0
         2            -1                   1                  6                6              0
         2            -1                   2                  6                6              0
         2            -1                   3                  6                6              0
         2             1                   0                  6                6              1
         2             1                   0                  6                6       

AssertionError: 24