Skip to content

Commit

Permalink
Support CIF output
Browse files Browse the repository at this point in the history
- support CIF output for Atoms and AtomCell
- fix symmetry parsing in CIF input
- more CIF tests
  • Loading branch information
hexane360 committed Jan 24, 2024
1 parent cd89634 commit e2a368a
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 116 deletions.
7 changes: 1 addition & 6 deletions atomlib/atomcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,7 @@ def crop_to_box(self: HasAtomCellT, eps: float = 1e-5) -> HasAtomCellT:

def wrap(self: HasAtomCellT, eps: float = 1e-5) -> HasAtomCellT:
"""Wrap atoms around the cell boundaries."""
def transform(atoms):
coords = atoms.coords()
coords = (coords + eps) % 1. - eps
return atoms.with_coords(coords)

return self.with_atoms(self._transform_atoms_in_frame('cell_box', transform))
return self.with_atoms(self._transform_atoms_in_frame('cell_box', lambda a: a._wrap(eps)))

"""
def explode(self: HasAtomCellT) -> HasAtomCellT:
Expand Down
9 changes: 6 additions & 3 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,12 @@ def crop(self: HasAtomsT, x_min: float = -numpy.inf, x_max: float = numpy.inf,

crop_atoms = crop

def _wrap(self: HasAtomsT, eps: float = 1e-5) -> HasAtomsT:
    """
    Wrap coordinates into the interval ``[-eps, 1 - eps)``.

    Every coordinate is shifted up by `eps`, reduced modulo 1, and shifted back
    down, so values lying within `eps` below zero stay where they are instead
    of jumping to just under 1. No frame conversion is done here; the
    coordinates are wrapped exactly as `self.coords()` returns them.
    """
    shifted = self.coords() + eps
    return self.with_coords(shifted % 1. - eps)

def deduplicate(self: HasAtomsT, tol: float = 1e-3, subset: t.Iterable[str] = ('x', 'y', 'z', 'symbol'),
keep: UniqueKeepStrategy = 'first') -> HasAtomsT:
keep: UniqueKeepStrategy = 'first', maintain_order: bool = True) -> HasAtomsT:
"""
De-duplicate atoms in `self`. Atoms of the same `symbol` that are closer than `tol`
to each other (by Euclidean distance) will be removed, leaving only the atom specified by
Expand All @@ -346,7 +350,6 @@ def deduplicate(self: HasAtomsT, tol: float = 1e-3, subset: t.Iterable[str] = ('
cols -= spatial_cols
if len(spatial_cols) > 0:
coords = self.select(list(spatial_cols)).to_numpy()
print(coords.shape)
tree = scipy.spatial.KDTree(coords)

# TODO This is a bad algorithm
Expand All @@ -364,7 +367,7 @@ def deduplicate(self: HasAtomsT, tol: float = 1e-3, subset: t.Iterable[str] = ('
self = self.with_column(polars.Series('_unique_pts', indices))
cols.add('_unique_pts')

frame = self._get_frame().unique(subset=list(cols), keep=keep)
frame = self._get_frame().unique(subset=list(cols), keep=keep, maintain_order=maintain_order)
if len(spatial_cols) > 0:
frame = frame.drop('_unique_pts')

Expand Down
9 changes: 9 additions & 0 deletions atomlib/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,9 @@ def __pow__(self: SupportsNumSelf, other: SupportsNumSelf) -> SupportsNumSelf:
def __neg__(self: SupportsNumSelf) -> SupportsNumSelf:
    """Unary negation (``-self``)."""
    ...

def __pos__(self: SupportsNumSelf) -> SupportsNumSelf:
    """Unary plus (``+self``)."""
    ...


def parse_numeric(s: str) -> t.Union[int, float]:
try:
Expand All @@ -564,6 +567,12 @@ def sub(lhs: SupportsNum, rhs: t.Optional[SupportsNum] = None):
return lhs-rhs


def add(lhs: SupportsNum, rhs: t.Optional[SupportsNum] = None):
    """Return ``lhs + rhs``, or the unary ``+lhs`` when `rhs` is omitted (``None``)."""
    return +lhs if rhs is None else lhs + rhs


def parse_boolean(s: str) -> bool:
if s.lower() in ("0", "false", "f"):
return False
Expand Down
20 changes: 17 additions & 3 deletions atomlib/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def read_cif(f: t.Union[FileOrPath, CIF, CIFDataBlock], block: t.Union[int, str,

logging.debug("cif data: %r", cif.data_dict)

# TODO: support atom_site_Cartn_[xyz]
# TODO: support atom_site_B_iso_or_equiv
df = cif.stack_tags('atom_site_fract_x', 'atom_site_fract_y', 'atom_site_fract_z',
'atom_site_type_symbol', 'atom_site_label', 'atom_site_occupancy', 'atom_site_U_iso_or_equiv',
rename=('x', 'y', 'z', 'symbol', 'label', 'frac_occupancy', 'wobble'),
Expand All @@ -68,17 +70,29 @@ def read_cif(f: t.Union[FileOrPath, CIF, CIFDataBlock], block: t.Union[int, str,
sym_atoms.append(atoms.transform(sym))

if len(sym_atoms) > 0:
atoms = AtomCell.from_ortho(Atoms.concat(sym_atoms), LinearTransform3D()) \
.wrap().get_atoms().deduplicate()
atoms = Atoms.concat(sym_atoms)._wrap().deduplicate()

if (cell_size := cif.cell_size()) is not None:
cell_size = to_vec3(cell_size)
if (cell_angle := cif.cell_angle()) is not None:
# degrees to radians
cell_angle = to_vec3(cell_angle) * numpy.pi/180.
return AtomCell.from_unit_cell(atoms, cell_size, cell_angle, frame='cell_frac')
return Atoms(atoms)


def write_cif(atoms: t.Union[HasAtoms, CIF, CIFDataBlock], f: FileOrPath):
    """
    Write a structure to a CIF file.

    `atoms` may be an existing `CIF` or `CIFDataBlock`, which is written as-is,
    or an atomic structure. An `AtomCell` is converted with
    `CIFDataBlock.from_atomcell`; any other structure is converted with
    `CIFDataBlock.from_atoms`.
    """
    if isinstance(atoms, (CIF, CIFDataBlock)):
        cif = atoms
    elif isinstance(atoms, AtomCell):
        cif = CIF((CIFDataBlock.from_atomcell(atoms),))
    else:
        cif = CIF((CIFDataBlock.from_atoms(atoms),))

    cif.write(f)


def read_xyz(f: t.Union[FileOrPath, XYZ]) -> HasAtoms:
"""Read a structure from an XYZ file."""
if isinstance(f, XYZ):
Expand Down Expand Up @@ -172,7 +186,7 @@ def write_cfg(atoms: t.Union[HasAtoms, CFG], f: FileOrPath):

WriteFunc = t.Callable[[HasAtoms, FileOrPath], None]
_WRITE_TABLE: t.Mapping[FileType, t.Optional[WriteFunc]] = {
'cif': None,
'cif': write_cif,
'xyz': write_xyz,
'xsf': write_xsf,
'cfg': write_cfg,
Expand Down
196 changes: 96 additions & 100 deletions atomlib/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,21 @@
from numpy.typing import NDArray

from ..transform import AffineTransform3D
from ..expr import Parser, BinaryOp, BinaryOrUnaryOp, sub
from ..expr import Parser, BinaryOp, BinaryOrUnaryOp, sub, add
from ..util import open_file, FileOrPath
from ..cell import ortho_to_cell

if t.TYPE_CHECKING:
from ..atoms import HasAtoms
from ..atomcell import HasAtomCell


Value = t.Union[int, float, str, None]
_INT_RE = re.compile(r'[-+]?\d+')
# float regex with uncertainty
# float regex with uncertainty (e.g. '3.14159(3)')
_FLOAT_RE = re.compile(r'([-+]?\d*(\.\d*)?(e[-+]?\d+)?)(\(\d+\))?', re.I)


def _format_val(val: Value) -> str:
    """
    Format a single value as a CIF token.

    - ``None`` is written as ``'.'``.
    - Floats are written with 8 decimal places.
    - Strings containing a line break are written as a semicolon-delimited
      text field. Other strings are written bare when possible, and quoted
      when they are empty or contain whitespace. Bare and quoted strings are
      padded on the right to at least 3 characters (for element symbols).
    - Any other value is passed through `str`.
    """
    if val is None:
        return '.'
    if isinstance(val, float):
        return f"{val:.8f}"
    if isinstance(val, str):
        # `splitlines()` alone misses a single trailing line break ('abc\n'),
        # which would otherwise end up inside a quoted single-line token.
        has_line_break = len(val.splitlines()) > 1 or '\n' in val or '\r' in val
        # Kept from the previous implementation: a double quote immediately
        # followed by a single quote always forces a text field.
        if has_line_break or '"\'' in val:
            return f";\n{val}\n;"

        # An empty string must be quoted too; padding alone would emit only
        # blanks and the value would be lost.
        if val == '' or re.search(r'\s', val) is not None:
            # Pick a delimiter that does not occur inside the value.
            if "'" not in val:
                val = f"'{val}'"
            elif '"' not in val:
                val = f'"{val}"'
            else:
                # both quote characters present: only a text field is safe
                return f";\n{val}\n;"

        return f"{val:<3}"  # auto pad to 3 chars (for element symbols)

    return str(val)


@dataclass
class CIF:
data_blocks: t.Tuple[CIFDataBlock, ...]
Expand Down Expand Up @@ -113,6 +99,72 @@ def from_file(file: FileOrPath) -> t.Iterator[CIFDataBlock]:
with open_file(file) as f:
yield from CifReader(f).parse()

@staticmethod
def from_atoms(atoms: HasAtoms) -> CIFDataBlock:
    """
    Build an unnamed data block from a structure without cell information.

    Positions are written to the ``atom_site_Cartn_[xyz]`` tags. The symbol and
    coordinate columns are always written; ``label``, ``frac_occupancy`` and
    ``wobble`` are written only when `atoms` has the corresponding column.
    """
    # (CIF tag, column/expression to select, predicate: required column name or plain boolean)
    columns: t.Sequence[t.Tuple[str, t.Union[str, polars.Expr], t.Union[str, bool]]] = (
        ('atom_site_type_symbol', 'symbol', True),
        ('atom_site_label', 'label', 'label'),
        ('atom_site_occupancy', 'frac_occupancy', 'frac_occupancy'),
        ('atom_site_Cartn_x', 'x', True),
        ('atom_site_Cartn_y', 'y', True),
        ('atom_site_Cartn_z', 'z', True),
        ('atom_site_U_iso_or_equiv', 'wobble', 'wobble'),
    )

    site_table = {}
    for (tag, expr, pred) in columns:
        # a string predicate names a column which must be present
        wanted = atoms.try_get_column(pred) is not None if isinstance(pred, str) else pred
        if wanted:
            site_table[tag] = atoms.select(expr).to_series().to_list()

    return CIFDataBlock("", (
        ('audit_creation_method', 'Generated by atomlib'),
        CIFTable(site_table),
    ))

@staticmethod
def from_atomcell(atomcell: HasAtomCell) -> CIFDataBlock:
    """
    Build an unnamed data block from a structure with cell information.

    Atoms are fetched in the ``'cell_box'`` frame and written to the
    ``atom_site_fract_[xyz]`` tags. Cell lengths and angles come from
    `ortho_to_cell` applied to the linear part of the ``'local'`` ->
    ``'cell_box'`` transform, with the angles converted to degrees;
    ``cell_volume`` is that transform's determinant. A single identity
    symmetry operation (``x,y,z``) is declared.

    ``label``, ``frac_occupancy`` and ``wobble`` are written only when the
    atoms have the corresponding column.
    """
    site_atoms = atomcell.get_atoms('cell_box')
    ortho = atomcell.get_transform('local', 'cell_box').to_linear()
    (cell_size, cell_angle) = ortho_to_cell(ortho)
    cell_angle *= 180./numpy.pi  # radians -> degrees

    items: t.List[t.Union[t.Tuple[str, Value], CIFTable]] = [
        ('audit_creation_method', 'Generated by atomlib'),
        # symmetry information: identity operation only
        CIFTable({
            'space_group_symop_id': [1],
            'space_group_symop_operation_xyz': ['x,y,z'],
        }),
    ]

    # cell information
    for (i, axis) in enumerate(('a', 'b', 'c')):
        items.append((f'cell_length_{axis}', cell_size[i]))
    for (i, angle) in enumerate(('alpha', 'beta', 'gamma')):
        items.append((f'cell_angle_{angle}', cell_angle[i]))
    items.append(('cell_volume', ortho.det()))

    # (CIF tag, column/expression to select, predicate: required column name or plain boolean)
    columns: t.Sequence[t.Tuple[str, t.Union[str, polars.Expr], t.Union[str, bool]]] = (
        ('atom_site_type_symbol', 'symbol', True),
        ('atom_site_label', 'label', 'label'),
        ('atom_site_occupancy', 'frac_occupancy', 'frac_occupancy'),
        ('atom_site_fract_x', 'x', True),
        ('atom_site_fract_y', 'y', True),
        ('atom_site_fract_z', 'z', True),
        ('atom_site_U_iso_or_equiv', 'wobble', 'wobble'),
    )

    site_table = {}
    for (tag, expr, pred) in columns:
        # a string predicate names a column which must be present
        wanted = site_atoms.try_get_column(pred) is not None if isinstance(pred, str) else pred
        if wanted:
            site_table[tag] = site_atoms.select(expr).to_series().to_list()
    items.append(CIFTable(site_table))

    return CIFDataBlock("", tuple(items))

def write(self, file: FileOrPath):
with open_file(file, 'w') as f:
self._write(f)
Expand Down Expand Up @@ -193,7 +245,10 @@ def cell_angle(self) -> t.Optional[t.Tuple[float, float, float]]:
return None

def get_symmetry(self) -> t.Iterator[AffineTransform3D]:
syms = self.data_dict.get('symmetry_equiv_pos_as_xyz', None)
syms = self.data_dict.get('space_group_symop_operation_xyz')
if syms is None:
# old name for symmetry
syms = self.data_dict.get('symmetry_equiv_pos_as_xyz')
if syms is None:
syms = ()
if not hasattr(syms, '__iter__'):
Expand All @@ -218,87 +273,25 @@ def _write(self, f: TextIOBase):

print(file=f)

"""
@dataclass
class CIF:
name: t.Optional[str]
data: t.Dict[str, t.Union[t.List[Value], Value]]
@staticmethod
def from_file(file: FileOrPath) -> t.Iterator[CIF]:
with open_file(file) as f:
yield from CifReader(f).parse()
def stack_tags(self, *tags: str, dtype: t.Union[str, numpy.dtype, t.Iterable[t.Union[str, numpy.dtype]], None] = None,
rename: t.Optional[t.Iterable[t.Optional[str]]] = None, required: t.Union[bool, t.Iterable[bool]] = True) -> polars.DataFrame:
dtypes: t.Iterable[t.Optional[numpy.dtype]]
if dtype is None:
dtypes = repeat(None)
elif isinstance(dtype, (numpy.dtype, str)):
dtypes = (numpy.dtype(dtype),) * len(tags)
else:
dtypes = tuple(map(lambda ty: numpy.dtype(ty), dtype))
if len(dtypes) != len(tags):
raise ValueError(f"dtype list of invalid length")
if isinstance(required, bool):
required = repeat(required)
if rename is None:
rename = repeat(None)
d = {}
for (tag, ty, req, name) in zip(tags, dtypes, required, rename):
if tag not in self.data:
if req:
raise ValueError(f"Tag '{tag}' missing from CIF file")
continue
try:
arr = numpy.array(self.data[tag], dtype=ty)
d[name or tag] = arr
except TypeError:
raise TypeError(f"Tag '{tag}' of invalid or heterogeneous type.")
if len(d) == 0:
return polars.DataFrame({})
l = len(next(iter(d.values())))
if any(len(arr) != l for arr in d.values()):
raise ValueError(f"Tags of mismatching lengths: {tuple(map(len, d.values()))}")
return polars.DataFrame(d)

def cell_size(self) -> t.Optional[t.Tuple[float, float, float]]:
\"""Return cell size (in angstroms).\"""
try:
a = float(self['cell_length_a']) # type: ignore
b = float(self['cell_length_b']) # type: ignore
c = float(self['cell_length_c']) # type: ignore
return (a, b, c)
except (ValueError, TypeError, KeyError):
return None
def _format_val(val: Value) -> str:
if val is None:
# None -> '.' (or '?')
return '.'
if isinstance(val, float):
return f"{val:.8f}"
if isinstance(val, str):
if len(val.splitlines()) > 1 or re.search(r'"\'', val) is not None:
# multi-line string, use semicolon syntax
return f";\n{val}\n;"

def cell_angle(self) -> t.Optional[t.Tuple[float, float, float]]:
\"""Return cell angle (in degrees).\"""
try:
a = float(self['cell_angle_alpha']) # type: ignore
b = float(self['cell_angle_beta']) # type: ignore
g = float(self['cell_angle_gamma']) # type: ignore
return (a, b, g)
except (ValueError, TypeError, KeyError):
return None
if val.startswith('_') or re.search(r'\s', val) is not None:
# string needs to be quoted
val = f"'{val}'"

def get_symmetry(self) -> t.Iterator[AffineTransform3D]:
syms = self.data.get('symmetry_equiv_pos_as_xyz', None)
if syms is None:
syms = ()
if not hasattr(syms, '__iter__'):
syms = (syms,)
return map(parse_symmetry, map(str, syms)) # type: ignore
return f"{val:<3}" # auto pad to 3 chars (for element symbols)

def __getitem__(self, key: str) -> t.Union[Value, t.List[Value]]:
return self.data.__getitem__(key)
"""
return str(val)


class SymmetryVec:
Expand Down Expand Up @@ -331,6 +324,9 @@ def __add__(self, rhs: SymmetryVec) -> SymmetryVec:
def __neg__(self) -> SymmetryVec:
    """Unary negation: return a new `SymmetryVec` wrapping ``-self.inner``."""
    return SymmetryVec(-self.inner)

def __pos__(self) -> SymmetryVec:
    """Unary plus: return `self` unchanged (no copy is made)."""
    return self

def __sub__(self, rhs: SymmetryVec) -> SymmetryVec:
if self.is_scalar() and rhs.is_scalar():
return SymmetryVec(self.inner - rhs.inner)
Expand All @@ -349,7 +345,7 @@ def __truediv__(self, rhs: SymmetryVec) -> SymmetryVec:

SYMMETRY_PARSER: Parser[SymmetryVec, SymmetryVec] = Parser([
BinaryOrUnaryOp(['-'], sub, False, 5),
BinaryOp(['+'], operator.add, 5),
BinaryOrUnaryOp(['+'], add, False, 5),
BinaryOp(['*'], operator.mul, 6),
BinaryOp(['/'], operator.truediv, 6),
], SymmetryVec.parse)
Expand Down
Loading

0 comments on commit e2a368a

Please sign in to comment.