# 0. Calculate the `davgs` and `dstds` using the first 10 frames.
1. `davgs.shape = (num_types, 4)`
2. `dstds.shape = (num_types, 4)`

In [10]:
import numpy as np
from matersdk.io.pwmat.output.movement import Movement
from matersdk.data.deepmd.data_system import DpLabeledSystem

from matersdk.feature.deepmd.preprocess import TildeRNormalizer

# 1. Use the first 10 frames of `DpLabeledSystem` to calculate statistic data, and init `TildeRNormalizer`

## Step 1. Initialize the `DpLabeledSystem`

In [11]:
movement_path = "/data/home/liuhanyu/hyliu/code/mlff/test/demo2/PWdata/data1/MOVEMENT"
movement = Movement(movement_path=movement_path)

dpsys = DpLabeledSystem.from_trajectory_s(trajectory_object=movement)
print(dpsys)

****************** LabeledSystem Summary *******************
	 * Images Number           : 550           
	 * Atoms Number            : 72            
	 * Virials Information     : True          
	 * Energy Deposition       : True          
	 * Elements List           :
		 - Li: 48              
		 - Si: 24              
************************************************************



## Step 1.2. Calculate the `davgs` and `dstds`, init `TildeRNormalizer`

In [12]:
structure_indices = [*range(10)]    # PWmat-MLFF 取前10帧结构计算`davg`和`dstd`
rcut = 5
rcut_smooth = 0.5
center_atomic_numbers = [3, 14]
nbr_atomic_numbers = [3, 14]
max_num_nbrs = [100, 100]
scaling_matrix = [3, 3, 3]

In [13]:
tilde_r_normalizer = TildeRNormalizer.from_dp_labeled_system(
                dp_labeled_system=dpsys,
                structure_indices=structure_indices,
                rcut=rcut,
                rcut_smooth=rcut_smooth,
                center_atomic_numbers=center_atomic_numbers,
                nbr_atomic_numbers=nbr_atomic_numbers,
                max_num_nbrs=max_num_nbrs,
                scaling_matrix=scaling_matrix
)

davgs, dstds = tilde_r_normalizer.davgs, tilde_r_normalizer.dstds

In [14]:
print("\nStep 1. davgs = ")
print(davgs)
print("\nStep 2. dstds = ")
print(dstds)


Step 1. davgs = 
[[0.0099991  0.         0.         0.        ]
 [0.01075823 0.         0.         0.        ]]

Step 2. dstds = 
[[0.03942547 0.02348297 0.02348297 0.02348297]
 [0.04493047 0.02667387 0.02667387 0.02667387]]


## Step 1.3. Normalize $\tilde{R}$ of new `DStructure` using `davgs` nad `dstds`

In [6]:
new_structure = movement.get_frame_structure(idx_frame=100)
tildeR_dict, tildeR_derivative_dict = tilde_r_normalizer.normalize(structure=new_structure)

In [8]:
print("Step 1. The Rij:")
for tmp_pair, tmp_normed_tildeR in tildeR_dict.items():
    print('\t', tmp_pair, ": ", tmp_normed_tildeR.shape)
    
print("Step 2. The derivative of Rij with respect to x, y, z:")
for tmp_key, tmp_value in tildeR_derivative_dict.items():
    print("\t", tmp_key, ": ", tmp_value.shape)

Step 1. The Rij:
	 3_3 :  (48, 100, 4)
	 3_14 :  (48, 80, 4)
	 14_3 :  (24, 100, 4)
	 14_14 :  (24, 80, 4)
Step 2. The derivative of Rij with respect to x, y, z:
	 3_3 :  (48, 100, 4, 3)
	 3_14 :  (48, 80, 4, 3)
	 14_3 :  (24, 100, 4, 3)
	 14_14 :  (24, 80, 4, 3)


In [11]:
np.concatenate(
        [
            tildeR_derivative_dict["3_3"], 
            tildeR_derivative_dict["3_14"]], axis=1).shape

(48, 180, 4, 3)

# 2. Save the `TildeRNormalizer` to `hdf5 file`

In [8]:
hdf5_file_path = "./demo_normalizer.h5"

tilde_r_normalizer.to(hdf5_file_path=hdf5_file_path)

In [9]:
ll ./demo_normalizer.h5

-rw-rw-r-- 1 liuhanyu 5632 Jun 25 16:48 ./demo_normalizer.h5


# 3. Init the `TildeRNormalizer` from `hdf5` file

In [10]:
new_trn = TildeRNormalizer.from_file(hdf5_file_path=hdf5_file_path)
print(new_trn)

*************************** TildeRNormalizer Summary ***************************
	 * rcut                      :       6.500000
	 * rcut_smooth               :       6.000000
	 * center_atomic_numbers:    :	 [ 3 14]
	 * nbr_atomic_numbers:       :	 [ 3 14]
	 * max_num_nbrs              :	 [100  80]
	 * scaling_matrix            :	 [3 3 3]
	 * davgs                     :	
[[0.06974313 0.         0.         0.        ]
 [0.06922328 0.         0.         0.        ]]
	 * dstds                     :	
[[0.11278804 0.07656205 0.07656205 0.07656205]
 [0.1140824  0.07704253 0.07704253 0.07704253]]
********************************************************************************

