In [None]:
import torch
from torch import nn

# Model

In [None]:
class EnergyModel(nn.Module):
    """Minimal energy model used to exercise ``torch.autograd.grad`` under TorchScript.

    A single linear layer maps each 256-feature input row to one scalar energy.
    ``forward`` returns that energy together with the negative gradient of the
    summed energy with respect to the input.
    """

    def __init__(self):
        super().__init__()
        # Must be registered as an attribute: a bare local would be discarded
        # when __init__ returns and would be invisible to energy()/forward().
        self.layer = nn.Linear(256, 1)

    def energy(self, x):
        """Return the per-row energy ``self.layer(x)``."""
        return self.layer(x)

    def forward(self, x: torch.Tensor):
        """Compute the energy and the negative input-gradient of its sum.

        Args:
            x: input whose last dimension is 256 (required by the linear layer).
               It is switched to ``requires_grad=True`` in place so that the
               gradient with respect to it can be taken.

        Returns:
            ``(energy, grad)`` where ``energy`` is the layer output and ``grad``
            is ``-d(energy.sum())/dx``. ``grad`` is typed as optional because
            ``torch.autograd.grad`` yields optional tensors under TorchScript.
        """
        x.requires_grad_(True)
        energy = self.energy(x)
        y = energy.sum()
        # create_graph=False: the gradient is for inference only and is not
        # differentiated again.
        grad = torch.autograd.grad([y], [x], create_graph=False)[0]
        if grad is not None:  # lets the JIT compiler refine Optional[Tensor] to Tensor
            grad = -grad
        return energy, grad
# Build the model, compile it with TorchScript, then run one forward pass on a
# random batch of 4 rows x 256 features. The model is constructed before the
# batch is drawn, so the random-number consumption order is unchanged.
model = torch.jit.script(EnergyModel())
input = torch.randn(4, 256)  # NOTE(review): `input` shadows the Python builtin
model(input)

In [None]:
# Fresh model instance, scripted with TorchScript and evaluated once on a random
# (4, 256) batch; the cell displays the resulting (energy, grad) pair.
model = torch.jit.script(EnergyModel())
input = torch.randn(4, 256)  # NOTE(review): `input` shadows the Python builtin
model(input)

# Scripting Example
Reference: PyTorch issue #46483 — https://github.com/pytorch/pytorch/issues/46483

In [None]:
import torch

# Batched operands for sample_function: x is (16, 30, 10), y is (16, 10, 1) and
# z is (16, 30, 1). Each tensor is drawn on the CPU with requires_grad=True and
# then copied to the GPU, in the order x, y, z.
x, y, z = (
    torch.randn(*shape, requires_grad=True).cuda()
    for shape in ((16, 30, 10), (16, 10, 1), (16, 30, 1))
)


@torch.jit.script
def sample_function(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
    """Take one unit-step gradient update of ``y`` on ``mean(z - bmm(x, y))``.

    Args:
        x: left operand of the batched matrix product.
        y: right operand; the gradient is taken with respect to this tensor.
        z: tensor the batched product is subtracted from.

    Returns:
        ``y - grad`` when TorchScript reports a gradient for ``y``, otherwise
        ``y`` unchanged.
    """
    residual = z - torch.bmm(x, y)
    objective = residual.mean()
    # Under TorchScript autograd.grad returns a list of optional tensors, so
    # the entry has to be checked against None before it is used.
    y_grad = torch.autograd.grad([objective], [y])[0]
    if y_grad is not None:
        y = y - y_grad
    return y

# Compile

## Region3D

In [None]:
import torch
class Region3D(object):
    """A 3D simulation box with conversions between physical and internal coordinates."""

    def __init__(self, boxt):
        '''Construct a simulation box.

        Args:
            boxt: tensor with 9 elements; it is reshaped to 3x3 and transposed
                before being stored.
        '''
        boxt = boxt.reshape([3, 3])
        # Stored on the instance so the conversion methods can reach them.
        self.boxt = boxt.permute(1, 0)  # right-multiplied to map internal -> physical
        self.rec_boxt = torch.linalg.inv(self.boxt)  # right-multiplied to map physical -> internal

        # Spatial properties of the box.
        self.volume = torch.linalg.det(self.boxt)  # volume of the parallelepiped
        # Each face distance is volume / area of the face spanned by the two
        # other rows. dim is given explicitly because relying on torch.cross's
        # implicit dimension is deprecated.
        c_yz = torch.cross(self.boxt[1], self.boxt[2], dim=-1)
        self._h2yz = self.volume / torch.linalg.norm(c_yz)
        c_zx = torch.cross(self.boxt[2], self.boxt[0], dim=-1)
        self._h2zx = self.volume / torch.linalg.norm(c_zx)
        c_xy = torch.cross(self.boxt[0], self.boxt[1], dim=-1)
        self._h2xy = self.volume / torch.linalg.norm(c_xy)

    def phys2inter(self, coord):
        '''Convert physical coordinates to internal ones (``coord @ rec_boxt``).'''
        return coord @ self.rec_boxt

    def inter2phys(self, coord):
        '''Convert internal coordinates to physical ones (``coord @ boxt``).'''
        return coord @ self.boxt

    def get_face_distance(self):
        '''Return face distances to each surface of YZ, ZX, XY as a stacked tensor.'''
        return torch.stack([self._h2yz, self._h2zx, self._h2xy])

def func(box):
    """Build a ``Region3D`` from ``box`` and map ``box`` itself to internal coordinates."""
    return Region3D(box).phys2inter(box)

In [None]:
# Compile `func` with dynamic shapes enabled and invoke the *compiled* callable;
# calling plain `func` here would never exercise torch.compile.
opt_func = torch.compile(func, dynamic=True)
opt_func(2 * torch.eye(3))

## Full Model

In [None]:
import logging
import torch

from typing import Any, Dict

from deepmd_pt import my_random
from deepmd_pt.dataset import DeepmdDataSet
from deepmd_pt.learning_rate import LearningRateExp
from deepmd_pt.loss import EnergyStdLoss
from deepmd_pt.model import EnergyModel
from deepmd_pt.env import DEVICE, JIT
import json

# Load the JSON configuration for the water se_e2_a test case.
with open("tests/water/se_e2_a.json", 'r') as fin:
    content = fin.read()
config = json.loads(content)

# Split the configuration into its model and training sections.
model_params = config['model']
training_params = config['training']
# Seed the project's random helper before the dataset and model are built.
my_random.seed(training_params['seed'])
# pop() removes 'training_data' from training_params (and therefore from config).
dataset_params = training_params.pop('training_data')
training_data = DeepmdDataSet(
    systems=dataset_params['systems'],
    batch_size=dataset_params['batch_size'],
    type_map=model_params['type_map']
)
# EnergyModel here is the class imported from deepmd_pt.model in this cell.
model = EnergyModel(model_params, training_data).to(DEVICE)

# Enable verbose TorchDynamo output, then wrap the model with torch.compile
# (dynamic shapes enabled).
torch._dynamo.config.verbose = True
model = torch.compile(model, dynamic=True)

# Fetch one batch; presumably tf=False / pt=True selects the PyTorch-format
# batch only — TODO confirm against DeepmdDataSet.get_batch.
bdata = training_data.get_batch(tf=False, pt=True)

# Prepare inputs
coord = bdata['coord']
atype = bdata['type']
natoms = bdata['natoms_vec']
box = bdata['box']
# Reference labels; they are read here but not compared with the predictions in this cell.
l_energy = bdata['energy']
l_force = bdata['force']

# Run the compiled model once. coord is switched to requires_grad=True first;
# presumably the model differentiates the energy w.r.t. coord to get forces — TODO confirm.
coord.requires_grad_(True)
p_energy, p_force = model(coord, atype, natoms, box)

## Descriptor

In [None]:
from deepmd_pt.descriptor import smoothDescriptor
import numpy as np
from deepmd_pt.env import *
from deepmd_pt.dataset import DeepmdDataSet
import os

# Descriptor settings passed to smoothDescriptor below.
rcut = 6.
rcut_smth = 0.5
sel = [46, 92]

sec = np.cumsum(sel)  # running totals of sel: [46, 138]
ntypes = len(sel)
nnei = sum(sel)  # 138
CUR_DIR = "tests/"
# NOTE(review): re-assigns the value already computed from len(sel) above.
ntypes=2
# Dataset built from three water systems; the positional arguments 2 and
# ['O', 'H'] are presumably batch_size and type_map — TODO confirm.
ds = DeepmdDataSet([
    os.path.join(CUR_DIR, 'water/data/data_0'),
    os.path.join(CUR_DIR, 'water/data/data_1'),
    os.path.join(CUR_DIR, 'water/data/data_2')
], 2, ['O', 'H'])
np_batch, pt_batch = ds.get_batch(pt=True)
pt_coord = pt_batch['coord']
pt_coord.requires_grad_(True)
# NOTE(review): rebinds the imported name, so re-running this cell wraps the
# already-compiled callable again. `torch`, DEVICE and GLOBAL_PT_FLOAT_PRECISION
# are not imported by name in this cell; presumably they come from the star
# import or an earlier cell — TODO confirm.
smoothDescriptor = torch.compile(smoothDescriptor, dynamic=True)
# All-zero and all-one tensors with one row per type; presumably the mean/std
# normalisation statistics expected by the descriptor — TODO confirm.
avg_zero = torch.zeros([ntypes, nnei*4], dtype=GLOBAL_PT_FLOAT_PRECISION)
std_ones = torch.ones([ntypes, nnei*4], dtype=GLOBAL_PT_FLOAT_PRECISION)
# Run the compiled descriptor once on the batch.
my_d = smoothDescriptor(
    pt_coord.to(DEVICE),
    pt_batch['type'],
    pt_batch['natoms_vec'],
    pt_batch['box'],
    avg_zero.reshape([-1, nnei, 4]).to(DEVICE),
    std_ones.reshape([-1, nnei, 4]).to(DEVICE),
    rcut,
    rcut_smth,
    sec
)