In [None]:
from dptb.nnops.trainer import Trainer
from dptb.nn.build import build_model
from dptb.data.build import build_dataset
from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer
from dptb.plugins.train_logger import Logger
from dptb.utils.argcheck import normalize
from dptb.plugins.saver import Saver
from typing import Dict, List, Optional, Any
from dptb.utils.tools import j_loader, setup_seed, j_must_have
from dptb.utils.constants import dtype_dict
from dptb.utils.loggers import set_log_handles
import heapq
import logging
import torch
import random
import numpy as np
from pathlib import Path
import json
import os
import time
import matplotlib.pyplot as plt

In [None]:
# Logger for the basis-reconciliation messages below. `log` was never defined in this
# notebook, so an orbital mismatch in the nnsk branch raised NameError instead of logging.
log = logging.getLogger(__name__)

# Load the input config and the checkpoint whose model is inspected in this notebook.
jdata = j_loader('./input_dftb_env.json')
jdata = normalize(jdata)
checkpoint = 'dftbenv2/checkpoint/mix.best.pth'
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
f = torch.load(checkpoint)
ckpt_config = f["config"]

# Fall back to the checkpoint's model options when the input config provides none.
if jdata.get("model_options", None) is None:
    jdata["model_options"] = ckpt_config["model_options"]

# Reconcile the input basis with the basis stored in the checkpoint.
basis = ckpt_config["common_options"]["basis"]
is_pure_nnsk = (
    len(ckpt_config["model_options"]) == 1
    and ckpt_config["model_options"].get("nnsk") is not None
)
if is_pure_nnsk:
    # nnsk-only checkpoints may be re-initialized with different orbitals per atom.
    for asym, orb in jdata["common_options"]["basis"].items():
        assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
        if orb != basis[asym]:
            log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
    # Add atom types present in the checkpoint but missing from the input basis: the
    # dataset's orbital mapper is built from this basis, so it must cover every model atom.
    for asym, orb in basis.items():
        if asym not in jdata["common_options"]["basis"].keys():
            jdata["common_options"]["basis"][asym] = orb
else:
    # Any other model type requires the input basis to match the checkpoint exactly.
    for asym, orb in jdata["common_options"]["basis"].items():
        assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
        assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"

    jdata["common_options"]["basis"] = basis
model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"])

In [None]:
from dptb.nnops.loss import Loss

# Training loss configured from the input json; it shares the model's index mapper (idp).
train_lossfunc = Loss(
    **jdata['train_options']["loss_options"]["train"],
    **jdata["common_options"],
    idp=model.hamiltonian.idp,
)

In [None]:
from dptb.data import AtomicDataset, DataLoader, AtomicData, AtomicDataDict

# Build the training dataset/loader from the config and take one shuffled batch as a dict.
train_datasets = build_dataset(**jdata["data_options"]["train"], **jdata["common_options"])
train_loader = DataLoader(
    dataset=train_datasets,
    batch_size=jdata['train_options']["batch_size"],
    shuffle=True,
)
batch = AtomicData.to_AtomicDataDict(next(iter(train_loader)))
# Keep only element 0 of the batched k-points (presumably the first frame's set — TODO confirm).
batch[AtomicDataDict.KPOINT_KEY] = batch[AtomicDataDict.KPOINT_KEY][0]

# Full forward pass / eigenvalue evaluation intentionally disabled here:
#   batch = model(batch); batch = train_lossfunc.eigenvalue(batch)


In [5]:
# Run only the env-embedding submodule (not the full model) on the training batch; the
# returned dict supplies the env_* tensors read out further down.
batch = model.nnenv.embedding(batch)


In [6]:
from ase.io import read
from dptb.data import AtomicData, AtomicDataDict

# Build a single, slightly perturbed structure and prepare it as model input.
stru_data = "./data/struct.vasp"  # alternative structure source (not read below)
AtomicData_options = {
    "r_max": 5.0,
    "er_max": 3.5,
    "oer_max": 1.6,
    "pbc": True,
}
# Other sources tried previously: read(stru_data), or frame 0 of
# ase.io.trajectory.Trajectory('./data/set.0/xdat.traj', 'r').

structase = read('./asestruct2.vasp')
# Jitter every Cartesian coordinate by N(0, 0.01). The noise shape follows the structure;
# the previous hard-coded [3, 3] only worked for structures with exactly 3 atoms.
# NOTE(review): unseeded, so the perturbation differs on every run.
structase.positions = structase.positions + np.random.normal(0, 0.01, structase.positions.shape)
data = AtomicData.from_ase(structase, **AtomicData_options)
data = AtomicData.to_AtomicDataDict(data)
data = model.idp(data)
data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(np.load('./data/set.0/kpoints.npy'), dtype=torch.float32)
# Full forward pass / eigenvalue evaluation left disabled: model(data), bcal.eigv(data)

In [None]:
# Embed the perturbed structure with the env-embedding submodule only; the full nnenv
# call is kept commented out for reference.
#data = model.nnenv(data)
data = model.nnenv.embedding(data)


In [8]:
# Environment-graph tensors of the embedded training batch.
env_vectors, env_index = batch['env_vectors'], batch['env_index']
atom_attr = batch['node_attrs']
edge_index, edge_length = batch['edge_index'], batch['edge_lengths']
n_env = env_index.shape[1]
# For each env pair, concatenate the attributes of both atoms. Assumes env_index is
# (2, n_env): atom_attr[env_index] -> (2, n_env, ...) -> (n_env, 2, ...) -> (n_env, -1).
env_attr = atom_attr[env_index].movedim(0, 1).reshape(n_env, -1)


In [None]:
# Environment-graph tensors of the embedded perturbed structure (`data`).
env_vectors, env_index, atom_attr, edge_index, edge_length = (
    data[key] for key in ('env_vectors', 'env_index', 'node_attrs', 'edge_index', 'edge_lengths')
)

n_env = env_index.shape[1]
# Pair-wise concatenation of the two atoms' attributes (env_index assumed (2, n_env)).
env_attr = atom_attr[env_index].swapaxes(0, 1).reshape(n_env, -1)

In [9]:
# Descriptor submodule of the env embedding. The cells below replay its message-passing
# steps (_collect -> message -> aggregate) by hand to inspect intermediates.
se2 = model.nnenv.embedding.descriptor

In [10]:
# Replay the preamble of se2's propagate (presumably torch_geometric MessagePassing).
# Validate the same index tensor that is collected/propagated over below (env_index);
# checking edge_index validated an unrelated graph.
size = se2._check_input(env_index, size=None)
decomposed_layers = 1 if se2.explain else se2.decomposed_layers


In [12]:
# Keyword inputs for message passing, gathered per env pair of env_index by _collect.
kwargs = dict(env_vectors=env_vectors, env_attr=env_attr)
coll_dict = se2._collect(se2._user_args, env_index, size, kwargs)

In [13]:
# Pick from coll_dict the arguments se2.message expects (presumably by signature inspection).
msg_kwargs = se2.inspector.distribute('message', coll_dict)

In [15]:
# Per-pair messages computed by the descriptor.
out = se2.message(**msg_kwargs)
#out2 = se2.message(**msg_kwargs)

In [17]:
# Pair distances and se2's smoothing of them with hard-coded bounds 2.5 and 3.5
# (presumably a smooth cutoff weight — confirm against se2.smooth).
rij = torch.linalg.vector_norm(env_vectors, dim=-1, keepdim=True)
snorm = se2.smooth(rij, 2.5, 3.5)

In [18]:
# Pick from coll_dict the arguments se2.aggregate expects.
aggr_kwargs = se2.inspector.distribute('aggregate', coll_dict)

In [20]:
# Aggregate the per-pair messages; `out` is overwritten with the aggregated result.
out = se2.aggregate(out, **aggr_kwargs)
#out2 = se2.aggregate(out2, **aggr_kwargs)

In [21]:
# Arguments for se2.update; not consumed by the cells that follow.
update_kwargs = se2.inspector.distribute('update', coll_dict)

In [22]:
# Batched out @ out^T, truncated to the first se2.n_axis columns, then flattened per row.
reout2 = torch.bmm(out, out.transpose(-1, -2))[..., :se2.n_axis].flatten(1, 2)

In [23]:
# Euclidean norm of each flattened descriptor row.
torch.linalg.vector_norm(reout2, dim=1, keepdim=True)

tensor([[0.1425],
        [0.1841],
        [0.1841]], grad_fn=<LinalgVectorNormBackward0>)

In [37]:
# Centre (row 0) and neighbour (row 1) atom indices of each env pair.
iind, jjind = env_index[0], env_index[1]

In [39]:
# All env vectors whose centre atom is atom 0.
env_vectors[iind.eq(0)]

tensor([[-3.1841e+00,  0.0000e+00,  0.0000e+00],
        [ 1.5920e+00, -2.7575e+00,  0.0000e+00],
        [ 1.5497e-06, -1.8383e+00,  1.5636e+00],
        [-1.5920e+00, -2.7575e+00,  0.0000e+00],
        [-1.5920e+00,  9.1916e-01,  1.5636e+00],
        [ 1.5920e+00,  9.1916e-01,  1.5636e+00],
        [ 1.5497e-06, -1.8383e+00, -1.5636e+00],
        [ 1.5920e+00,  9.1916e-01, -1.5636e+00],
        [-1.5920e+00,  9.1916e-01, -1.5636e+00],
        [ 3.1841e+00,  0.0000e+00,  0.0000e+00],
        [-1.5920e+00,  2.7575e+00,  0.0000e+00],
        [ 1.5920e+00,  2.7575e+00,  0.0000e+00]])

In [41]:
# Env vectors where both centre and neighbour are atom 0 (presumably its periodic images,
# since pbc=True). Use `&` for the element-wise logical AND instead of bool multiplication.
env_vectors[(jjind == 0) & (iind == 0)]

tensor([[-3.1841,  0.0000,  0.0000],
        [ 1.5920, -2.7575,  0.0000],
        [-1.5920, -2.7575,  0.0000],
        [ 3.1841,  0.0000,  0.0000],
        [-1.5920,  2.7575,  0.0000],
        [ 1.5920,  2.7575,  0.0000]])