In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np 
import torch 
import metatensor.torch as mts
from metatensor.torch import TensorMap, Labels, TensorBlock
import ase 
from mlelec.data.dataset import QMDataset
from mlelec.utils.target_utils import get_targets
from mlelec.utils.twocenter_utils import _to_coupled_basis,_to_uncoupled_basis_old
torch.set_default_dtype(torch.float64)
# print(torch.cuda.is_available())



In [3]:
# from mlelec.data.pyscf_calculator import _instantiate_pyscf_mol
import pyscf.pbc.tools.pyscf_ase as pyscf_ase
import torch
from collections import defaultdict
import pyscf

In [4]:
def compute_xhat(frame, basis='sto-3g', fix_xyz=False, device='cpu'):
    """Evaluate the PySCF ``int1e_r`` integrals of one structure.

    Args:
        frame: ASE atoms object; converted with ``pyscf_ase.ase_atoms_to_pyscf``.
        basis: name of the PySCF basis set assigned to the molecule.
        fix_xyz: when true, cyclically shift the three components by one
            position along the component axis, so the order (0, 1, 2)
            becomes (1, 2, 0).
        device: torch device the resulting tensor is moved to.

    Returns:
        A torch tensor holding the three-component integrals with the
        component axis moved to the last position.
    """
    molecule = pyscf.gto.Mole()
    molecule.atom = pyscf_ase.ase_atoms_to_pyscf(frame)
    molecule.basis = basis
    molecule.symmetry = False
    molecule.verbose = 2
    molecule.build()

    # The integrals are evaluated with the common origin placed at (0, 0, 0).
    with molecule.with_common_orig((0, 0, 0)):
        raw_integrals = molecule.intor('int1e_r', comp=3)
    components_first = torch.from_numpy(raw_integrals).to(device=device)

    if fix_xyz:
        components_first = torch.roll(components_first, shifts=-1, dims=0)

    # Component axis goes last.
    return components_first.moveaxis(0, -1)

# other integrals to try mol.intor('cint1e_kin_sph') #+ mol.intor('cint1e_nuc_sph')

In [5]:
# Toy system: a single non-periodic 'H2O' frame with hand-picked positions.
frames = [ase.Atoms('H2O', positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], pbc = False)]
# Random symmetric 7x7 stand-in for the real-space Fock matrix
# (7 = 5 orbitals listed for species 8 + 1 orbital for each of the two species-1 atoms).
# NOTE(review): no RNG seed is set, so H differs on every fresh kernel.
H = torch.randn(1,7,7) 
H = H + H.transpose(-1,-2)   
qmdata = QMDataset(frames = frames, 
                   fock_realspace= H,
                   dimension = 0, 
                   orbs = {8:[[1,0,0],[2,0,0],[2,1,0], [2,1,1], [2,1,-1]], 1:[[1,0,0]]},
                   orbs_name = 'sto-3g',    
                   device = 'cpu'
)   
# The resulting block maps are only used further down as a source of keys/metadata.
# NOTE: `coupled_blocks` is rebound by a later cell (`_to_coupled_basis(xhat_blocks, ...)`).
blocks, coupled_blocks  = get_targets(qmdata, cutoff = 4, device = 'cpu', all_pairs = False, sort_orbs =True, return_uncoupled=True)
# just needed them for keys 



In [6]:
# One integral tensor per frame, for each of the two basis sets.
xhat_sto3g = [compute_xhat(frame, basis='sto-3g') for frame in frames]
xhat_def2 = [compute_xhat(frame, basis='def2-svp') for frame in frames]

In [7]:
xhat_sto3g[0][...,0] - xhat_sto3g[0][...,0].T

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1102e-16,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1102e-16,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])

# tensor - to -blocks

Go from the dense matrix to blocks of shape `(nsample, nmi, nmj, 3, nprop)`.

In [8]:
from mlelec.utils.pbc_utils import blocks_to_matrix, matrix_to_blocks

blocks_to_matrix(matrix_to_blocks(qmdata, cutoff = 10), qmdata)[0][0,0,0] - qmdata.fock_realspace[0]

tensor([[0.0000e+00, 4.4409e-16, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.4409e-16, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00]])

In [9]:
xhat_blocks= matrix_to_blocks(qmdata, matrix= xhat_sto3g, cutoff = 10, high_rank = True, )
# xhat_blocks = mts.remove_dimension(xhat_blocks, 'samples', 'cell_shift_a') 
# xhat_blocks = mts.remove_dimension(xhat_blocks, 'samples', 'cell_shift_b') 
# xhat_blocks= mts.remove_dimension(xhat_blocks, 'samples', 'cell_shift_c') 




# couple blocks

For a 3D tensor we should have the same block structure, except that there is an additional components dimension.

In [10]:
# Labels of the extra component axis: a single dimension 'm_3' with values -1, 0, 1.
position_components = Labels(['m_3'], values=torch.tensor([-1, 0, 1]).reshape(3, -1))

# Mirror every 2D block with a randomly filled block that has one more
# component axis of length 3; only the metadata is taken from `blocks`.
blocks_3D = []
for block in blocks:
    nsample, nmi, nmj, nprop = block.values.shape
    random_values = torch.randn(nsample, nmi, nmj, 3, nprop)
    rank3_components = [block.components[0], block.components[1], position_components]
    blocks_3D.append(
        TensorBlock(
            values=random_values,
            samples=block.samples,
            components=rank3_components,
            properties=block.properties,
        )
    )

# Same keys as `blocks`, with an additional 'l_3' column set to 1 everywhere.
key_names = blocks.keys.names + ["l_3"]
key_value = torch.nn.functional.pad(blocks.keys.values, (0, 1), mode='constant', value=1)
uncoupled_blocks_3D = TensorMap(Labels(key_names, key_value), blocks_3D)

# Drop the three cell-shift sample dimensions, one after the other.
for shift_name in ('cell_shift_a', 'cell_shift_b', 'cell_shift_c'):
    uncoupled_blocks_3D = mts.remove_dimension(uncoupled_blocks_3D, 'samples', shift_name)
uncoupled_blocks_3D

TensorMap with 12 blocks
keys: block_type  species_i  n_i  l_i  species_j  n_j  l_j  l_3
          -1          1       1    0       1       1    0    1
          0           1       1    0       1       1    0    1
          0           8       1    0       8       1    0    1
          0           8       1    0       8       2    0    1
          0           8       1    0       8       2    1    1
          0           8       2    0       8       2    0    1
          0           8       2    0       8       2    1    1
          0           8       2    1       8       2    1    1
          1           1       1    0       1       1    0    1
          2           1       1    0       8       1    0    1
          2           1       1    0       8       2    0    1
          2           1       1    0       8       2    1    1

In [11]:
coupled_blocks = _to_coupled_basis(xhat_blocks, skip_symmetry = False, device = 'cpu', translations = True, high_rank = True)

# decouple blocks 

In [12]:
unc = _to_uncoupled_basis_old(coupled_blocks,device = 'cpu', translations = None, high_rank = True)

In [13]:
import metatensor.torch as mts

for k,b in unc.items():
    b1 = xhat_blocks.block(k)
    print(k.values, torch.norm(b.values - b1.values))

tensor([0, 1, 1, 0, 1, 1, 0, 1], dtype=torch.int32) tensor(1.3323e-15)
tensor([1, 1, 1, 0, 1, 1, 0, 1], dtype=torch.int32) tensor(4.4409e-16)
tensor([-1,  1,  1,  0,  1,  1,  0,  1], dtype=torch.int32) tensor(0.)
tensor([2, 1, 1, 0, 8, 1, 0, 1], dtype=torch.int32) tensor(3.1038e-17)
tensor([2, 1, 1, 0, 8, 2, 0, 1], dtype=torch.int32) tensor(1.9230e-16)
tensor([2, 1, 1, 0, 8, 2, 1, 1], dtype=torch.int32) tensor(6.1448e-16)
tensor([0, 8, 1, 0, 8, 1, 0, 1], dtype=torch.int32) tensor(4.4409e-16)
tensor([0, 8, 1, 0, 8, 2, 0, 1], dtype=torch.int32) tensor(1.1102e-16)
tensor([0, 8, 1, 0, 8, 2, 1, 1], dtype=torch.int32) tensor(8.1218e-17)
tensor([0, 8, 2, 0, 8, 2, 0, 1], dtype=torch.int32) tensor(4.4409e-16)
tensor([0, 8, 2, 0, 8, 2, 1, 1], dtype=torch.int32) tensor(9.9920e-16)
tensor([0, 8, 2, 1, 8, 2, 1, 1], dtype=torch.int32) tensor(2.1896e-15)


In [26]:
for k,b in unc.items():

    print(b.values.shape)

torch.Size([2, 1, 1, 3, 1])
torch.Size([1, 1, 1, 3, 1])
torch.Size([1, 1, 1, 3, 1])
torch.Size([2, 1, 1, 3, 1])
torch.Size([2, 1, 1, 3, 1])
torch.Size([2, 1, 3, 3, 1])
torch.Size([1, 1, 1, 3, 1])
torch.Size([1, 1, 1, 3, 1])
torch.Size([1, 1, 3, 3, 1])
torch.Size([1, 1, 1, 3, 1])
torch.Size([1, 1, 3, 3, 1])
torch.Size([1, 3, 3, 3, 1])


# blocks to tensor 

In [35]:
from mlelec.utils.pbc_utils import inverse_bloch_sum, _orbs_offsets, _components_idx, ISQRT_2, _atom_blocks_idx
from ase.units import Bohr
import warnings



In [37]:
recon = blocks_to_matrix(xhat_blocks, qmdata)

0 0
torch.Size([1, 1, 1])
0 0
torch.Size([1, 1, 1])
0 0
0 0
0 0
0 0
0 0
0 0
0 1
0 1
0 0
torch.Size([1, 1, 1])
0 0
torch.Size([1, 1, 1])
0 1
torch.Size([1, 3, 1])
0 0
torch.Size([1, 1, 1])
0 1
torch.Size([1, 3, 1])
1 1
torch.Size([3, 3, 1])




In [40]:
recon[0][0,0,0]  - xhat_sto3g[0]

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -4.6911e-01],
         [ 0.0000e+00, -8.8543e-02,  0.0000e+00],
         [ 0.0000e+00, -4.8729e-01,  0.0000e+00],
         [ 0.0000e+00,  2.8755e-01,  2.8755e-01],
         [ 0.0000e+00,  1.1610e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -2.8755e-01]],

        [[ 0.0000e+00,  0.0000e+00, -4.6911e-01],
         [ 0.0000e+00,  0.0000e+00, -1.8897e+00],
         [ 0.0000e+00, -3.4216e-02, -9.1265e-04],
         [ 0.0000e+00, -2.6750e-01, -1.6970e-01],
         [ 0.0000e+00,  1.5190e-01,  1.5190e-01],
         [ 0.0000e+00,  2.6498e-02,  1.2224e-01],
         [ 0.0000e+00, -1.7839e-01, -2.7414e-01]],

        [[ 0.0000e+00, -8.8543e-02,  0.0000e+00],
         [ 0.0000e+00, -3.4216e-02, -9.1265e-04],
         [ 0.0000e+00, -1.8897e+00,  0.0000e+00],
         [ 0.0000e+00, -4.4731e-01,  0.0000e+00],
         [ 0.0000e+00,  5.0792e-02,  5.0792e-02],
         [ 0.0000e+00, -5.0792e-02,  0.0000e+0

In [20]:
def _component_block(block, index):
    """Build a block from ``block`` keeping only entry ``index`` of its third component axis."""
    return TensorBlock(
        values=block.values[..., index, :],
        samples=block.samples,
        components=[block.components[0], block.components[1]],
        properties=block.properties,
    )


# Entries 0, 1, 2 of the extra component axis go to the y, z and x lists respectively.
blocks_2Dy, blocks_2Dz, blocks_2Dx = [], [], []
for block in uncoupled_blocks_3D:
    # The unpacking only succeeds for 5-dimensional values.
    nsample, nmi, nmj, _, nprop = block.values.shape
    blocks_2Dy.append(_component_block(block, 0))
    blocks_2Dz.append(_component_block(block, 1))
    blocks_2Dx.append(_component_block(block, 2))

# The per-component maps reuse the keys of the original 2D target.
uncoupled_blocks_2Dx = TensorMap(blocks.keys, blocks_2Dx)
uncoupled_blocks_2Dy = TensorMap(blocks.keys, blocks_2Dy)
uncoupled_blocks_2Dz = TensorMap(blocks.keys, blocks_2Dz)

# uncoupled_blocks_2Dx = mts.remove_dimension(uncoupled_blocks_2Dx, 'samples', 'cell_shift_a') 
# uncoupled_blocks_2Dx = mts.remove_dimension(uncoupled_blocks_2Dx, 'samples', 'cell_shift_b') 
# uncoupled_blocks_2Dx = mts.remove_dimension(uncoupled_blocks_2Dx, 'samples', 'cell_shift_c') 

# uncoupled_blocks_2Dy = mts.remove_dimension(uncoupled_blocks_2Dy, 'samples', 'cell_shift_a') 
# uncoupled_blocks_2Dy = mts.remove_dimension(uncoupled_blocks_2Dy, 'samples', 'cell_shift_b') 
# uncoupled_blocks_2Dy = mts.remove_dimension(uncoupled_blocks_2Dy, 'samples', 'cell_shift_c') 

# uncoupled_blocks_2Dz = mts.remove_dimension(uncoupled_blocks_2Dz, 'samples', 'cell_shift_a') 
# uncoupled_blocks_2Dz = mts.remove_dimension(uncoupled_blocks_2Dz, 'samples', 'cell_shift_b') 
# uncoupled_blocks_2Dz = mts.remove_dimension(uncoupled_blocks_2Dz, 'samples', 'cell_shift_c') 

In [6]:
# Couple each per-component map on its own, with identical settings.
coupling_options = dict(skip_symmetry=False, device='cpu', translations=None)
coupled_blocks_2Dx = _to_coupled_basis(uncoupled_blocks_2Dx, **coupling_options)
coupled_blocks_2Dy = _to_coupled_basis(uncoupled_blocks_2Dy, **coupling_options)
coupled_blocks_2Dz = _to_coupled_basis(uncoupled_blocks_2Dz, **coupling_options)

Create a tensor map from these three individually coupled TensorMaps by concatenating them along the components dimension.

In [7]:
# Move the orbital-related key dimensions into the properties of every map.
# NOTE(review): the maps are rebound in place, so this cell is presumably not re-runnable as is.
orbital_key_names = ['species_i', 'n_i', 'l_i', 'species_j', 'n_j', 'l_j']
coupled_blocks_2Dx = coupled_blocks_2Dx.keys_to_properties(orbital_key_names)
coupled_blocks_2Dy = coupled_blocks_2Dy.keys_to_properties(orbital_key_names)
coupled_blocks_2Dz = coupled_blocks_2Dz.keys_to_properties(orbital_key_names)

In [8]:
# Stack the y, z and x maps (in that order) along a new component axis at position 2.
stacked_blocks = []
component_triplets = zip(
    coupled_blocks_2Dy.blocks(),
    coupled_blocks_2Dz.blocks(),
    coupled_blocks_2Dx.blocks(),
)
for block_y, block_z, block_x in component_triplets:
    # Only the shapes are compared here; the keys themselves are not checked.
    assert block_y.values.shape == block_z.values.shape == block_x.values.shape
    stacked_values = torch.stack([block_y.values, block_z.values, block_x.values], dim=2)
    # Metadata is taken from the first (y) block.
    stacked_blocks.append(
        TensorBlock(
            values=stacked_values,
            samples=block_y.samples,
            components=[block_y.components[0], position_components],
            properties=block_y.properties,
        )
    )

# Keys of the z map, extended by an 'L2' column that is 1 for every block.
padded_key_values = torch.nn.functional.pad(
    coupled_blocks_2Dz.keys.values, (0, 1), mode='constant', value=1
)
newkeys = Labels(coupled_blocks_2Dz.keys.names + ['L2'], values=padded_key_values)
uncoupled_2 = TensorMap(newkeys, stacked_blocks)

In [9]:
uncoupled_2

TensorMap with 6 blocks
keys: block_type  L  L2
          0       0  1
          0       1  1
          0       2  1
          1       0  1
          2       0  1
          2       1  1

In [None]:
# _to_coupled_basis()

In [82]:
coupled_blocks[0].values

tensor([[[-0.8360],
         [ 1.1375],
         [ 0.0626]]])

In [57]:
coupled_blocks

{(0, 1, 0, 1): {1: tensor([[[-0.8360,  1.1375,  0.0626]]])}}

In [21]:
from mlelec.utils.symmetry import ClebschGordanReal
CG = ClebschGordanReal(10, 'cpu')

In [17]:
def couple(decoupled, iterate = 0, cg=None, selfdevice= 'cpu', lmax=10):
    """Couple the last two axes of every term into total-L components.

    Args:
        decoupled: either a tensor whose last two axes have odd lengths
            ``2*l1 + 1`` and ``2*l2 + 1``, or a nested dict
            ``{ltuple: {l: tensor}}`` as produced by a previous call.
        iterate: number of additional coupling passes applied to the result.
        cg: object exposing a ``_cg`` mapping from ``(l1, l2, L)`` to the
            coefficient tensor contracted with the last two axes. It must be
            provided; the default ``None`` is not usable.
        selfdevice: device the terms are moved to before the contraction.
        lmax: upper bound on the generated ``L`` values.

    Returns:
        A dict ``{(l1, l2) + ltuple: {L: tensor}}`` in which the two trailing
        axes of each input term are replaced by one axis per ``L``.

    Raises:
        ValueError: if a dict key ``l`` does not match the ``l`` implied by
            the last axis of the tensor stored under it.
    """
    # A bare tensor is wrapped in the dict layout handled by the generic loop.
    if not isinstance(decoupled, dict):
        wrapped_l = (decoupled.shape[-1] - 1) // 2
        decoupled = {(): {wrapped_l: decoupled}}

    coupled = {}
    for ltuple, lcomponents in decoupled.items():
        for lc, dec_term in lcomponents.items():
            # Angular momenta implied by the sizes of the last two axes.
            l1 = (dec_term.shape[-2] - 1) // 2
            l2 = (dec_term.shape[-1] - 1) // 2

            # The dict key is redundant with the shape, so it must agree with it.
            if lc != l2:
                raise ValueError(
                    "Inconsistent shape for coupled angular momentum block."
                )

            if dec_term.device != selfdevice:
                dec_term = dec_term.to(selfdevice)

            lowest_L = max(l1, l2) - min(l1, l2)
            highest_L = min(lmax, (l1 + l2))
            # The new label is the old one with (l1, l2) prepended.
            coupled[(l1, l2) + ltuple] = {
                L: torch.tensordot(
                    dec_term, cg._cg[(l1, l2, L)].to(dec_term), dims=2
                )
                for L in range(lowest_L, highest_L + 1)
            }

    if iterate > 0:
        return couple(coupled, iterate - 1, cg=cg, selfdevice=selfdevice, lmax=lmax)
    return coupled

def decouple(coupled, iterate: int = 0, cg=None, selfdevice= 'cpu', lmax=10 ):
    """Undo one (or more) coupling steps performed by ``couple``.

    Args:
        coupled: nested dict ``{ltuple: {L: tensor}}`` where the first two
            entries of ``ltuple`` are the ``(l1, l2)`` pair that generated the
            ``L`` terms and the last axis of each tensor indexes the coupled
            components.
        iterate: number of additional decoupling passes applied to the result.
        cg: object exposing a ``_cg`` mapping from ``(l1, l2, L)`` to the
            coefficient tensor whose last axis is contracted with the last
            axis of the coupled term. It must be provided; the default
            ``None`` is not usable.
        selfdevice: device on which the decoupled accumulators are allocated.
        lmax: upper bound on the ``L`` values that are looked up.

    Returns:
        A dict ``{ltuple[2:]: {l2: tensor}}``, or a bare tensor once the
        labels are exhausted (``ltuple[2:] == ()``). Entries whose
        L-dictionary is empty are skipped, and an empty input yields an
        empty dict.
    """
    decoupled = {}
    # Remaining label of the last entry visited; stays None for empty input,
    # which previously triggered an UnboundLocalError after the loop.
    last_tail = None

    for ltuple, lcomponents in coupled.items():
        # The leading pair identifies the decoupled terms behind the L entries.
        l1, l2 = ltuple[:2]
        last_tail = ltuple[2:]

        # Nothing to decouple and no shape to infer from an empty entry; it is
        # treated like a missing term instead of raising StopIteration.
        if not lcomponents:
            continue

        # Leading shape and dtype come from any available L term
        # (its last axis is the coupled-component axis).
        reference = next(iter(lcomponents.values()))
        shape = reference.shape[:-1]
        dtype_ = reference.dtype

        dec_term = torch.zeros(
            shape + (2 * l1 + 1, 2 * l2 + 1), device=selfdevice, dtype=dtype_
        )
        for L in range(max(l1, l2) - min(l1, l2), min(lmax, (l1 + l2)) + 1):
            # Missing L components are allowed, e.g. when zero by symmetry.
            if L not in lcomponents:
                continue
            # Contract the coupled axis with the last axis of the coefficients,
            # restoring the two decoupled component axes.
            dec_term += torch.tensordot(
                lcomponents[L], cg._cg[(l1, l2, L)].to(dtype_), dims=([-1], [-1])
            )
        decoupled.setdefault(last_tail, {})[l2] = dec_term

    # Rinse, repeat.
    if iterate > 0:
        decoupled = decouple(decoupled, iterate - 1, cg=cg, selfdevice=selfdevice, lmax=lmax)

    # A fully decoupled state is returned as a plain tensor.
    if last_tail == () and isinstance(decoupled, dict) and () in decoupled:
        decoupled = next(iter(decoupled[()].values()))
    return decoupled




In [92]:
a = torch.randn(1,5,7,3)


In [93]:
b = couple(a, 1,cg = CG,)

In [94]:
c = decouple(b, 1, cg = CG)

In [95]:
c.shape

torch.Size([1, 5, 7, 3])

In [35]:
torch.norm(c - a)

tensor(1.1898e-06)

In [53]:
bb = b.copy()
bb[(2, 2, 3, 1)] = {}

In [64]:
c = decouple(bb, 1, cg = CG)

In [23]:
# Debug cell: walk a nested {key tuple: {L: tensor}} dict and print the shapes.
# NOTE(review): relies on hidden kernel state. `a` must be a dict keyed by
# tuples here (the printed keys are 4-tuples), whereas another cell assigns a
# plain tensor to `a`; `block` is not defined in this cell and is whatever an
# earlier loop left behind.
coupled = a
for coupledkey in coupled:
    k = coupledkey[1]
    print(coupledkey)
    for L in coupled[coupledkey]:
        print(coupledkey, L)
                    # block_idx = tuple(idx) + (k, L)
                    # skip blocks that are zero because of symmetry - TBD 
                    # if ai == aj and ni == nj and li == lj:
                    #     parity = (-1) ** (li + lj + L)
                    #     if ((parity == -1 and block_type in (0, 1)) or (parity == 1 and block_type == -1)) and not skip_symmetry:
                    #         continue
                    
        # Shapes of: the leaked block's sample labels, the L term, and the L term with its last two axes swapped.
        print(block.samples.values.shape, coupled[coupledkey][L].shape, torch.moveaxis(coupled[coupledkey][L], -1, -2).shape)


(2, 2, 3, 1)
(2, 2, 3, 1) 0
torch.Size([2, 6]) torch.Size([1, 1]) torch.Size([1, 1])
(2, 2, 3, 1) 1
torch.Size([2, 6]) torch.Size([1, 3]) torch.Size([3, 1])
(2, 2, 3, 1) 2
torch.Size([2, 6]) torch.Size([1, 5]) torch.Size([5, 1])
(2, 2, 3, 1) 3
torch.Size([2, 6]) torch.Size([1, 7]) torch.Size([7, 1])
(2, 2, 3, 1) 4
torch.Size([2, 6]) torch.Size([1, 9]) torch.Size([9, 1])
(2, 3, 3, 1)
(2, 3, 3, 1) 1
torch.Size([2, 6]) torch.Size([1, 3]) torch.Size([3, 1])
(2, 3, 3, 1) 2
torch.Size([2, 6]) torch.Size([1, 5]) torch.Size([5, 1])
(2, 3, 3, 1) 3
torch.Size([2, 6]) torch.Size([1, 7]) torch.Size([7, 1])
(2, 3, 3, 1) 4
torch.Size([2, 6]) torch.Size([1, 9]) torch.Size([9, 1])
(2, 3, 3, 1) 5
torch.Size([2, 6]) torch.Size([1, 11]) torch.Size([11, 1])
(2, 4, 3, 1)
(2, 4, 3, 1) 2
torch.Size([2, 6]) torch.Size([1, 5]) torch.Size([5, 1])
(2, 4, 3, 1) 3
torch.Size([2, 6]) torch.Size([1, 7]) torch.Size([7, 1])
(2, 4, 3, 1) 4
torch.Size([2, 6]) torch.Size([1, 9]) torch.Size([9, 1])
(2, 4, 3, 1) 5
torch.Si