/
base_grid.py
53 lines (45 loc) · 1.29 KB
/
base_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from abc import abstractmethod, abstractproperty
from typing import List
import torch
import xitorch as xt
class BaseGrid(xt.EditableModule):
    """
    Abstract base class for a grid of integration points over the spatial
    dimensions.

    Subclasses must implement the abstract properties ``dtype``, ``device``
    and ``coord_type``, and the abstract methods ``get_dvolume``,
    ``get_rgrid`` and ``getparamnames``.
    """

    # ``@property`` stacked on ``@abstractmethod`` is the documented
    # replacement for ``abc.abstractproperty`` (deprecated since Python 3.3).
    # The resulting property still reports ``__isabstractmethod__ = True``.
    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """
        The ``torch.dtype`` of the grid (presumably the dtype of the tensors
        returned by ``get_dvolume`` and ``get_rgrid`` -- confirm in subclasses).
        """
        pass

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """
        The ``torch.device`` of the grid (presumably where the tensors
        returned by ``get_dvolume`` and ``get_rgrid`` live -- confirm in
        subclasses).
        """
        pass

    @property
    @abstractmethod
    def coord_type(self) -> str:
        """
        Returns the type of the coordinate returned in ``get_rgrid``.
        """
        pass

    @abstractmethod
    def get_dvolume(self) -> torch.Tensor:
        """
        Obtain the tensor containing the dV elements for the integration.

        Returns
        -------
        torch.Tensor (*BG, ngrid)
            The dV elements for the integration.
        """
        pass

    @abstractmethod
    def get_rgrid(self) -> torch.Tensor:
        """
        Returns the grid points position in the coordinate type given by
        ``self.coord_type``.

        Returns
        -------
        torch.Tensor (*BG, ngrid, ndim)
            The grid points position.
        """
        pass

    @abstractmethod
    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        """
        Return the names of the parameters used by the method named
        ``methodname``, each prefixed with ``prefix``. This is the interface
        required by xitorch's ``EditableModule``.

        Parameters
        ----------
        methodname : str
            Name of the method whose parameters are requested.
        prefix : str
            String prepended to every returned name.

        Returns
        -------
        List[str]
            The (prefixed) parameter names.
        """
        pass