In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from ase.io import read
from ase.visualize import view
import numpy as np 
import torch 
import metatensor 
import matplotlib.pyplot as plt
torch.set_default_dtype(torch.float64)

In [3]:
from metatensor import Labels, TensorBlock, TensorMap
from mlelec.data.dataset import PySCFPeriodicDataset
from mlelec.utils.twocenter_utils import fix_orbital_order



In [4]:
orbitals = {'sto-3g': {6: [[1,0,0],[2,0,0],[2,1,1], [2,1,-1],[2,1,0]]}, 
            'def2svp': {6: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]]}}

In [7]:
workdir = '/home/pegolo/Software/my_mlelec/'
root = f'{workdir}/examples/data/periodic/c2/'
ORBS = 'sto-3g'
START = 0
STOP = 1
frames = read(f'{root}/C2_174.extxyz', slice(START, STOP))
for f in frames: 
    f.pbc = True

kmesh = [8,8,1]
kfock = [np.load(f"{root}/fock_{i}_881.npy") for i in range(START, STOP)]
kover = [np.load(f"{root}/over_{i}_881.npy") for i in range(START, STOP)]
for ifr in range(len(frames)):
    for ik, k in enumerate(kfock[ifr]):
        kfock[ifr][ik] = fix_orbital_order(k, frames[ifr], orbitals[ORBS]) #### TODO <<
        kover[ifr][ik] = fix_orbital_order(kover[ifr][ik], frames[ifr], orbitals[ORBS]) #### TODO <<

dataset = PySCFPeriodicDataset(frames = frames, kmesh = kmesh, fock_kspace = kfock, overlap_kspace = kover, device = "cpu", orbs = orbitals[ORBS], orbs_name = ORBS)

# Targets

In [8]:
from metatensor import load 
from mlelec.utils.twocenter_utils import _to_coupled_basis
from mlelec.utils.pbc_utils import matrix_to_blocks 

def get_targets(dataset, device ="cpu", cutoff = None, target='fock'):
    if target.lower() == 'fock':
        matrices_negative = dataset._fock_realspace_negative_translations
    elif target.lower() == 'overlap':
        matrices_negative = dataset._overlap_realspace_negative_translations
    else: 
        raise ValueError('target must be fock or overlap')
    blocks = matrix_to_blocks(dataset, matrices_negative , device = 'cpu', cutoff = cutoff, all_pairs = True, target= target)
    coupled_blocks = _to_coupled_basis(blocks, skip_symmetry = True, device = device, translations = True)

    blocks = blocks.keys_to_samples('cell_shift_a')
    blocks = blocks.keys_to_samples('cell_shift_b')
    blocks = blocks.keys_to_samples('cell_shift_c')

    coupled_blocks = coupled_blocks.keys_to_samples('cell_shift_a')
    coupled_blocks = coupled_blocks.keys_to_samples('cell_shift_b')
    coupled_blocks = coupled_blocks.keys_to_samples('cell_shift_c')
    return blocks , coupled_blocks

In [106]:
cutoff = 6
target_blocks, target_coupled_blocks = get_targets(dataset, cutoff = cutoff)

In [17]:
from mlelec.utils.pbc_utils import blocks_to_matrix, move_cell_shifts_to_keys

In [144]:
from mlelec.utils.twocenter_utils import (
    _components_idx,
    ISQRT_2,
    _orbs_offsets,
    _atom_blocks_idx,
)

def moveit(blocks):
    if "cell_shift_a" not in blocks.keys.names:
        assert "cell_shift_b" not in blocks.keys.names, "Weird! keys contain 'cell_shift_b' but not 'cell_shift_a'."
        assert "cell_shift_c" not in blocks.keys.names, "Weird! keys contain 'cell_shift_c' but not 'cell_shift_a'."

        assert "cell_shift_a" in blocks.sample_names, "Cell shifts must be in samples."
        assert "cell_shift_b" in blocks.sample_names, "Cell shifts must be in samples."
        assert "cell_shift_c" in blocks.sample_names, "Cell shifts must be in samples."

        if "L" in blocks.keys.names:
            from mlelec.utils.twocenter_utils import _to_uncoupled_basis
            blocks = _to_uncoupled_basis(blocks)
        blocks = move_cell_shifts_to_keys(blocks) 
    return blocks

def opt_blocks_to_matrix(blocks, dataset, device=None, return_negative=False):
    if device is None:
        device = dataset.device

    orbs_tot, orbs_offset = _orbs_offsets(dataset.basis)
    atom_blocks_idx = _atom_blocks_idx(dataset.structures, orbs_tot)
    orbs_mult = {
        species: 
                {tuple(k): v
            for k, v in zip(
                *np.unique(
                    np.asarray(dataset.basis[species])[:, :2],
                    axis=0,
                    return_counts=True,
                )
            )
        }
        for species in dataset.basis
    }

    reconstructed_matrices_plus = []
    reconstructed_matrices_minus = []

    # Loop over frames
    for A, shifts in enumerate(dataset.realspace_translations):
        norbs = np.sum([orbs_tot[ai] for ai in dataset.structures[A].numbers])

        reconstructed_matrices_plus.append({T: torch.zeros(norbs, norbs, device = device) for T in shifts})
        reconstructed_matrices_minus.append({T: torch.zeros(norbs, norbs, device = device) for T in shifts})

    # loops over block types
    for key, block in blocks.items():
        block_type = key["block_type"]
        ai, ni, li = key["species_i"], key["n_i"], key["l_i"]
        aj, nj, lj = key["species_j"], key["n_j"], key["l_j"]
        Tx, Ty, Tz = key["cell_shift_a"], key["cell_shift_b"], key["cell_shift_c"]
        
        # What's the multiplicity of the orbital type, ex. 2p_x, 2p_y, 2p_z makes the multiplicity 
        # of a p block = 3
        orbs_i = orbs_mult[ai]
        orbs_j = orbs_mult[aj]
        
        # The shape of the block corresponding to the orbital pair
        shapes = {
            (k1 + k2): (orbs_i[tuple(k1)], orbs_j[tuple(k2)])
            for k1 in orbs_i
            for k2 in orbs_j
        }
        # offset of the orbital (ni, li) within a block of atom i
        ioffset = orbs_offset[(ai, ni, li)] 
        # offset of the orbital (nj,lj) within a block of atom j
        joffset = orbs_offset[(aj, nj, lj)]

        i_end, j_end = shapes[(ni, li, nj, lj)]

        # loops over samples (structure, i, j)
        for sample, blockval in zip(block.samples, block.values):
            
            A = sample["structure"]
            i = sample["center"]
            j = sample["neighbor"]


            matrix_T_plus  = reconstructed_matrices_plus[A][Tx, Ty, Tz]
            matrix_T_minus = reconstructed_matrices_minus[A][Tx, Ty, Tz]

            i_start, j_start = atom_blocks_idx[(A, i, j)]

            # values = blockval[:, :, 0].clone().reshape(2 * li + 1, 2 * lj + 1)
            values = blockval[:, :, 0].clone()

            if block_type == 0 or block_type == 2:

                matrix_T_plus[
                    i_start + ioffset : i_start + ioffset + i_end,
                    j_start + joffset : j_start + joffset + j_end,
                             ] = values

            if abs(block_type) == 1:
                values *= ISQRT_2
                if block_type == 1:
                    matrix_T_plus[
                        i_start + ioffset : i_start + ioffset + i_end,
                        j_start + joffset : j_start + joffset + j_end,
                    ] += values

                    matrix_T_minus[
                        j_start + ioffset : j_start + ioffset + i_end,
                        i_start + joffset : i_start + joffset + j_end,
                    ] += values
                    
                else:
                    
                    matrix_T_plus[
                        i_start + ioffset : i_start + ioffset + i_end,
                        j_start + joffset : j_start + joffset + j_end,
                    ] += values

                    matrix_T_minus[
                        j_start + ioffset : j_start + ioffset + i_end,
                        i_start + joffset : i_start + joffset + j_end,
                    ] -= values

    if return_negative:
        return reconstructed_matrices_plus, reconstructed_matrices_minus
    return reconstructed_matrices_plus

def opt_move_cell_shifts_to_keys(blocks):
    """ Move cell shifts when present in samples, to keys"""
    from metatensor import Labels, TensorBlock, TensorMap

    out_blocks = []
    out_block_keys = []

    sample_names = blocks.sample_names[:-3]
    
    for key, block in blocks.items():        
        translations = np.unique(block.samples.values[:, -3:], axis = 0)
        block_view = block.samples.view(["cell_shift_a", "cell_shift_b", "cell_shift_c"]).values
        for T in translations:
            idx = np.where(np.all(np.isclose(np.array(block_view),np.array([T[0], T[1], T[2]])), axis = 1))[0]

            if len(idx):
                out_block_keys.append(list(key.values)+[T[0], T[1], T[2]])
                out_blocks.append(TensorBlock(
                        samples = Labels(
                            sample_names,
                            values = np.asarray(block.samples.values[idx])[:, :-3],
                        ),
                        values = block.values[idx],
                        components = block.components,
                        properties = block.properties,
                    ))
                
    return TensorMap(Labels(blocks.keys.names + ["cell_shift_a", "cell_shift_b", "cell_shift_c"], np.asarray(out_block_keys)), out_blocks)

In [126]:
%load_ext line_profiler

In [120]:
%prun bl = moveit(target_blocks)

 

         407673 function calls (407004 primitive calls) in 0.293 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.029    0.029    0.290    0.290 pbc_utils.py:329(move_cell_shifts_to_keys)
     5181    0.020    0.000    0.079    0.000 labels.py:363(_from_mts_labels_t)
    13491    0.013    0.000    0.013    0.000 __init__.py:511(cast)
     5821    0.012    0.000    0.042    0.000 utils.py:42(_ptr_to_ndarray)
      639    0.009    0.000    0.015    0.000 array.py:163(__init__)
     5821    0.009    0.000    0.029    0.000 ctypeslib.py:506(as_array)
     6391    0.008    0.000    0.011    0.000 extract.py:99(data_origin)
      639    0.008    0.000    0.027    0.000 numeric.py:2330(within_tol)
      639    0.008    0.000    0.039    0.000 block.py:68(__init__)
     3864    0.007    0.000    0.063    0.000 block.py:282(_labels)
     6391    0.007    0.000    0.009    0.000 block.py:226(_raw_values)
      640    0.

In [121]:
%prun opt_blocks_to_matrix(bl, dataset)

 

         146797 function calls (146158 primitive calls) in 0.143 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.068    0.068    0.143    0.143 1159621805.py:23(opt_blocks_to_matrix)
    11304    0.007    0.000    0.009    0.000 labels.py:125(__getitem__)
     1279    0.006    0.000    0.024    0.000 labels.py:363(_from_mts_labels_t)
     1638    0.006    0.000    0.006    0.000 {method 'clone' of 'torch._C._TensorBase' objects}
     2277    0.004    0.000    0.005    0.000 labels.py:73(__init__)
     2917    0.004    0.000    0.012    0.000 labels.py:411(__iter__)
      639    0.003    0.000    0.003    0.000 {method 'unbind' of 'torch._C._TensorBase' objects}
     1279    0.003    0.000    0.011    0.000 utils.py:42(_ptr_to_ndarray)
    39747    0.002    0.000    0.002    0.000 {built-in method builtins.isinstance}
      639    0.002    0.000    0.002    0.000 1159621805.py:66(<dictcomp>)
     1279    0.002

In [110]:
%prun move_cell_shifts_to_keys(target_blocks)

 

         407065 function calls (406399 primitive calls) in 0.294 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.029    0.029    0.291    0.291 pbc_utils.py:329(move_cell_shifts_to_keys)
     5168    0.019    0.000    0.079    0.000 labels.py:363(_from_mts_labels_t)
    13478    0.013    0.000    0.013    0.000 __init__.py:511(cast)
     5808    0.012    0.000    0.043    0.000 utils.py:42(_ptr_to_ndarray)
     6391    0.009    0.000    0.013    0.000 extract.py:99(data_origin)
      639    0.009    0.000    0.016    0.000 array.py:163(__init__)
     5808    0.008    0.000    0.029    0.000 ctypeslib.py:506(as_array)
      639    0.008    0.000    0.040    0.000 block.py:68(__init__)
      639    0.008    0.000    0.024    0.000 numeric.py:2330(within_tol)
     3861    0.008    0.000    0.065    0.000 block.py:282(_labels)
     6391    0.008    0.000    0.009    0.000 block.py:226(_raw_values)
      640    0.

In [138]:
%lprun -f move_cell_shifts_to_keys move_cell_shifts_to_keys(target_blocks)

Timer unit: 1e-09 s

Total time: 0.250651 s
File: /home/pegolo/Software/my_mlelec/src/mlelec/utils/pbc_utils.py
Function: move_cell_shifts_to_keys at line 329

Line #      Hits         Time  Per Hit   % Time  Line Contents
   329                                           def move_cell_shifts_to_keys(blocks):
   330                                               """ Move cell shifts when present in samples, to keys"""
   331         1       6222.0   6222.0      0.0      from metatensor import Labels, TensorBlock, TensorMap
   332                                           
   333         1        140.0    140.0      0.0      out_blocks = []
   334         1        100.0    100.0      0.0      out_block_keys = []
   335                                           
   336        28     920317.0  32868.5      0.4      for key, block in blocks.items():        
   337        27    3475974.0 128739.8      1.4          translations = np.unique(block.samples.values[:, -3:], axis = 0)
   338       6

In [145]:
%lprun -f opt_move_cell_shifts_to_keys opt_move_cell_shifts_to_keys(target_blocks)

Timer unit: 1e-09 s

Total time: 0.198754 s
File: /tmp/ipykernel_3394842/2136903063.py
Function: opt_move_cell_shifts_to_keys at line 130

Line #      Hits         Time  Per Hit   % Time  Line Contents
   130                                           def opt_move_cell_shifts_to_keys(blocks):
   131                                               """ Move cell shifts when present in samples, to keys"""
   132         1       8276.0   8276.0      0.0      from metatensor import Labels, TensorBlock, TensorMap
   133                                           
   134         1        190.0    190.0      0.0      out_blocks = []
   135         1        160.0    160.0      0.0      out_block_keys = []
   136                                           
   137         1     186703.0 186703.0      0.1      sample_names = blocks.sample_names[:-3]
   138                                               
   139        28     983649.0  35130.3      0.5      for key, block in blocks.items():        
   140

In [137]:
opt_move_cell_shifts_to_keys(target_blocks)

TensorMap with 639 blocks
keys: block_type  species_i  n_i  l_i  species_j  n_j  l_j  cell_shift_a  cell_shift_b  cell_shift_c
          0           6       1    0       6       1    0        0             0             0
          0           6       1    0       6       2    0        0             0             0
                                                   ...
          -1          6       2    1       6       2    1        3             2             0
          -1          6       2    1       6       2    1        3             3             0