/
base_grid.py
127 lines (104 loc) · 3.32 KB
/
base_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from deepchem.utils.differentiation_utils import EditableModule
import torch
from abc import abstractmethod, abstractproperty
from typing import List
class BaseGrid(EditableModule):
    """
    BaseGrid is a class that regulates the integration points over the spatial
    dimensions.

    Subclasses describe a set of grid points (``get_rgrid``), the volume
    element associated with each point (``get_dvolume``), and metadata about
    them (``dtype``, ``device`` and ``coord_type``).

    Examples
    --------
    >>> import torch
    >>> from deepchem.utils.dft_utils import BaseGrid
    >>> class Grid(BaseGrid):
    ...     def __init__(self):
    ...         super(Grid, self).__init__()
    ...         self.ngrid = 10
    ...         self.ndim = 3
    ...         self.dvolume = torch.ones(self.ngrid, dtype=self.dtype, device=self.device)
    ...         self.rgrid = torch.ones((self.ngrid, self.ndim), dtype=self.dtype, device=self.device)
    ...     def get_dvolume(self):
    ...         return self.dvolume
    ...     def get_rgrid(self):
    ...         return self.rgrid
    >>> grid = Grid()
    >>> grid.get_dvolume()
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
    >>> grid.get_rgrid()
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])

    Note that ``Grid`` above does not override ``dtype`` or ``device``. The
    base-class getters have empty bodies and so return ``None``, which makes
    ``torch.ones`` fall back to torch's default dtype and device.

    References
    ----------
    Kasim, Muhammad F., and Sam M. Vinko. "Learning the exchange-correlation
    functional from nature with fully differentiable density functional theory."
    Physical Review Letters 127.12 (2021): 126403.
    https://github.com/diffqc/dqc/blob/0fe821fc92cb3457fb14f6dff0c223641c514ddb/dqc/grid/base_grid.py
    """

    # ``abc.abstractproperty`` is deprecated since Python 3.3; the documented
    # replacement is ``property`` stacked over ``abstractmethod`` (in that
    # order). The resulting property is still flagged as abstract.
    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """dtype of the grid points.

        Returns
        -------
        torch.dtype
            dtype of the grid points
        """
        pass

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """device of the grid points.

        Returns
        -------
        torch.device
            device of the grid points
        """
        pass

    @property
    @abstractmethod
    def coord_type(self) -> str:
        """Type of the coordinates returned by ``get_rgrid``.

        Returns
        -------
        str
            Either ``'cartesian'`` or ``'spherical'``.
        """
        pass

    @abstractmethod
    def get_dvolume(self) -> torch.Tensor:
        """Obtain the tensor containing the dV elements for the integration.

        Returns
        -------
        torch.Tensor (*BG, ngrid)
            The dV (volume element) of each grid point. ``*BG`` denotes any
            leading dimensions of the grid (possibly none).
        """
        pass

    @abstractmethod
    def get_rgrid(self) -> torch.Tensor:
        """Return the positions of the grid points.

        The coordinates are expressed in the system given by
        ``self.coord_type``.

        Returns
        -------
        torch.Tensor (*BG, ngrid, ndim)
            The grid point positions. ``*BG`` denotes any leading dimensions
            of the grid (possibly none).
        """
        pass

    @abstractmethod
    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        """Return the parameter names used by the method ``methodname``.

        Parameters
        ----------
        methodname: str
            Name of the method whose tensor parameters are requested.
        prefix: str (default "")
            String to prepend to every returned name. Presumably this is
            used when the grid is nested inside another ``EditableModule``;
            confirm against the ``EditableModule`` contract.

        Returns
        -------
        List[str]
            List of parameter names of ``methodname``.
        """
        pass