# Understanding `ase.spacegroup.crystal`

**Content of this notebook**:
1. Parsing ase's `spacegroup.dat` and storing it as JSON for easier parsing in the future
2. Reproducing the `ase` logic in a stripped-down version for better grokking (this necessarily skips many convenience features of `ase`; if you are interested in special cases, a look at the original code is highly recommended)
3. Sanity checking the created crystals for correctness

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from ase import spacegroup, Atoms
import matplotlib.pyplot as plt
from ase.visualize.plot import plot_atoms
import json
import itertools
import numpy as np

from playing_with_ase import dat2json, equivalent_sites, get_symop

## The ASE experience

ase makes it super easy to generate relatively complex structures. Below are a few examples that we'll reproduce in this notebook.

In [None]:
crystal_specs = dict(
    NaCl = dict(
        # NaCl structure
        a = 5.64,
        symbols = ['Na', 'Cl'],
        basis = [(0, 0, 0), (0.5, 0.5, 0.5)], # scaled coordinates
        nr = 225,
        cellpar = [5.64, 5.64, 5.64, 90, 90, 90]
    ),
    Al_fcc = dict(
        # Al fcc structure
        a = 4.05,
        symbols = ['Al'],
        basis = [(0, 0, 0),], # scaled coordinates
        nr = 225,
        cellpar = [4.05, 4.05, 4.05, 90, 90, 90]
    ),
    Fe_bcc = dict(
        # Fe bcc structure
        a = 2.87,
        symbols = ['Fe'],
        basis = [(0, 0, 0),], # scaled coordinates
        nr = 229,
        cellpar = [2.87, 2.87, 2.87, 90, 90, 90]
    ),
    Mg_hcp = dict(
        # Mg hcp structure
        a = 3.21,
        c = 5.21,
        symbols = ['Mg'],
        basis = [(1./3., 2./3., 3./4.),], # scaled coordinates
        nr = 194,
        cellpar = [3.21, 3.21, 5.21, 90, 90, 120]
    ),
    Diamond = dict(
        # Diamond structure
        a = 3.57,
        symbols = ['C'],
        basis = [(0,0,0),], # scaled coordinates
        nr = 227,
        cellpar = [3.57, 3.57, 3.57, 90, 90, 90]
    ),
    Rutile = dict(
        # Rutile structure
        a = 4.6,
        c = 2.95,
        symbols = ['Ti', 'O'],
        basis = [(0,0,0), (.3,.3,0)], # scaled coordinates
        nr = 136,
        cellpar = [4.6, 4.6, 2.95, 90, 90, 90]
    ),
    Skudderudite = dict(
        # CoSb3 skutterudite
        a = 9.04,
        symbols = ['Co', 'Sb'],
        basis = [(0.25, 0.25, 0.25), (0.0, 0.335, 0.158)], # scaled coordinates
        nr = 204,
        cellpar = [9.04, 9.04, 9.04, 90, 90, 90]
    )
)

Generating an `ase.Atoms` object using the above specs

In [None]:
ab_normal = (0,0,1) # normal to the plane spanned by cell vectors a and b
a_direction = (1,0,0) # direction of the first cell vector a
name = "Skudderudite"
nr = crystal_specs[name]['nr']
basis = crystal_specs[name]['basis']
cellpar = crystal_specs[name]['cellpar']
symbols = crystal_specs[name]['symbols']

crystal = spacegroup.crystal(symbols, 
                             basis, 
                             spacegroup=nr,
                             cellpar=cellpar, 
                             a_direction=a_direction,
                             ab_normal=ab_normal)

Inspecting and visualizing `crystal`

In [None]:
crystal.todict()

In [None]:
fig, ax = plt.subplots()
plot_atoms(crystal, ax, radii=0.3, rotation=('0x,30y,0z'))
plt.show()

A lot happens in the background when calling `spacegroup.crystal`, but the core ase crystal algorithm boils down to four steps:

1. collect the symmetry operations for the `spacegroup` and setting (`1` or `2`) from `spacegroup.dat`

2. compute the `sites` equivalent to the scaled positions in `basis`
    - build all operations: for every parity (`1`, plus `-1` if the spacegroup is centrosymmetric) and every sub-translation, `newrot = parity * rotation` and `newtrans = (trans + subtrans) % 1`
    - iterate positions: apply every rotation and translation and fold the result back into $[0,1)$
    - drop the new site if it duplicates an existing one (within a tolerance `symprec`)

3. compute the `cell` vectors from `cellpar` and the orientation given by `a_direction` and `ab_normal`

4. compute real space positions by multiplying `sites` and `cell`
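The steps above can be condensed into a single function. This is a minimal sketch, not ase's code: `crystal_sketch` is a hypothetical name, the cell matrix is passed in directly (so step 3 is skipped), and the duplicate test uses a minimum-image difference. The toy input, a body-centred lattice with only the identity rotation, is an illustrative assumption rather than a real `spacegroup.dat` entry; it yields the two sites of a bcc cell.

```python
import itertools
import numpy as np


def crystal_sketch(basis, rotations, translations, subtrans, centrosymmetric,
                   cell, symprec=1e-3):
    """Toy version of steps 2 and 4: expand `basis` with all symmetry
    operations, drop duplicate sites and convert them to Cartesian positions."""
    parities = [1, -1] if centrosymmetric else [1]
    # combine every parity and sub-translation with every (rotation, translation)
    symops = [(p * np.asarray(rot), np.mod(np.asarray(tr) + np.asarray(st), 1.0))
              for p, st in itertools.product(parities, subtrans)
              for rot, tr in zip(rotations, translations)]
    sites, kinds = [], []
    for kind, pos in enumerate(np.asarray(basis, dtype=float)):
        for rot, trans in symops:
            site = np.mod(rot @ pos + trans, 1.0)  # apply the op, fold into [0, 1)
            # minimum-image difference to every stored site (shape (0, 3) if none yet)
            diff = site - np.asarray(sites).reshape(-1, 3)
            diff -= np.round(diff)
            if not np.any(np.all(np.abs(diff) < symprec, axis=1)):
                sites.append(site)
                kinds.append(kind)
    # rows of `cell` are the cell vectors, so scaled @ cell gives Cartesian positions
    return np.asarray(sites) @ np.asarray(cell, dtype=float), kinds


# toy body-centred example: identity rotation only, centring vector (1/2, 1/2, 1/2)
positions, kinds = crystal_sketch(
    basis=[(0, 0, 0)],
    rotations=[np.eye(3, dtype=int)],
    translations=[(0, 0, 0)],
    subtrans=[(0, 0, 0), (0.5, 0.5, 0.5)],
    centrosymmetric=True,
    cell=2.87 * np.eye(3),
)
print(kinds)      # [0, 0]
print(positions)  # rows (0, 0, 0) and (1.435, 1.435, 1.435)
```

The inversion operations only regenerate sites that already exist here, so the duplicate check leaves exactly two sites.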

## Re-implementing the ase crystal algorithm

### `spacegroup.dat` $\Rightarrow$ `spacegroup.json`

`spacegroup.dat` contains the symmetry operations defining all spacegroups. However, it is written in a format that requires fairly verbose loading functions. To simplify things, we use ase's specialised parser once and store the output as JSON.

This is done by `dat2json` (imported from `playing_with_ase`):

In [None]:
%%time
dat2json()
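`json` cannot serialise numpy arrays or numpy scalars directly, so some conversion along these lines is needed before dumping. This is a minimal sketch, not the actual `dat2json` from `playing_with_ase` (whose code is not shown here): `to_jsonable` is a hypothetical helper, and the toy entry only mimics the key format and the fields (`rotations`, `translations`, `subtrans`, `centrosymmetric`) that this notebook reads later.

```python
import json
import numpy as np


def to_jsonable(obj):
    """Recursively replace numpy arrays and scalars with plain Python objects
    so that the result can be passed to `json.dump`."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # numpy scalars such as np.float64 or np.bool_
        return obj.item()
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    return obj


# toy entry with the "<nr>: <setting>" key format and the fields read later in this notebook
spgs = {"2: 1": {"centrosymmetric": np.bool_(True),
                 "rotations": np.eye(3, dtype=int)[np.newaxis],
                 "translations": np.zeros((1, 3)),
                 "subtrans": np.zeros((1, 3))}}

text = json.dumps(to_jsonable(spgs))
restored = json.loads(text)
print(restored["2: 1"]["rotations"])  # [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
```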

In [None]:
%%time
with open('../src/spacegroup.json', 'r') as f:
    spgs = json.load(f)

Let's check the available `spgs` entries for `nr`

In [None]:
[k for k in spgs if k.split(':')[0].strip() == str(nr)] # exact match, so e.g. nr=22 does not pick up '220: 1'

In [None]:
setting = 1 # some spacegroups have 2 settings (1 and 2)
spg = spgs[f'{nr}: {setting}']
spg
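Looking up the settings of a spacegroup can be wrapped in a small helper that also turns the keys into numbers. This sketch assumes the `"<nr>: <setting>"` key format used in `spgs[f'{nr}: {setting}']`; `parse_key` and `settings_for` are hypothetical names. The toy key list shows why a plain substring test is risky: `'22'` is also contained in `'220: 1'`.

```python
def parse_key(key):
    """Split a key such as '204: 1' into the integers (spacegroup number, setting)."""
    nr, setting = key.split(":")
    return int(nr), int(setting)  # int() tolerates the surrounding whitespace


def settings_for(nr, keys):
    """Sorted list of settings available for spacegroup `nr` (exact number match)."""
    return sorted(setting for number, setting in map(parse_key, keys) if number == nr)


keys = ["22: 1", "220: 1", "227: 1", "227: 2"]
print(settings_for(22, keys))   # [1]
print(settings_for(227, keys))  # [1, 2]
```

In the notebook this would be `settings_for(nr, spgs.keys())`.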

### Computing equivalent sites

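Equivalent sites are generated by applying the spacegroup-specific rotations and translations stored in `spg` (printed above for our current spacegroup `nr`).

Each translation is additionally combined with every sub-translation (lattice centring vector) in `spg['subtrans']`, and, depending on whether or not the spacegroup is centrosymmetric, each rotation is also applied with inverted parity. So let's loop over parities and sub-translations and collect the rotation and translation operations in `symops`: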

In [None]:
parities = [1] if not spg['centrosymmetric'] else [1,-1]
symops = []

assert len(spg['rotations']) == len(spg['translations'])

for parity, trans_sub in itertools.product(parities,spg['subtrans']):
    for rot, trans in zip(spg['rotations'], spg['translations']):
        
        symops.append((
            parity * np.array(rot), # rotation op
            np.mod(np.array(trans) + np.array(trans_sub), 1) # translation op
        ))

In [None]:
symops[:3]
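A quick consistency check for `symops`: its length should equal the number of parities times the number of sub-translations times the number of rotations (in the notebook: `len(parities) * len(spg['subtrans']) * len(spg['rotations'])`). Below, the loop above is rewritten as a self-contained function; `build_symops` and the toy group (identity plus a two-fold rotation about c, body-centred and centrosymmetric) are illustrative assumptions, not an entry taken from `spacegroup.json`.

```python
import itertools
import numpy as np


def build_symops(spg):
    """Combine parities, sub-translations, rotations and translations into
    (rotation, translation) pairs, in the same order as the loop above."""
    parities = [1, -1] if spg["centrosymmetric"] else [1]
    return [(p * np.asarray(rot), np.mod(np.asarray(tr) + np.asarray(st), 1.0))
            for p, st in itertools.product(parities, spg["subtrans"])
            for rot, tr in zip(spg["rotations"], spg["translations"])]


# toy group: identity and a 2-fold rotation about c, body-centred, centrosymmetric
toy_spg = {
    "centrosymmetric": True,
    "subtrans": [[0, 0, 0], [0.5, 0.5, 0.5]],
    "rotations": [np.eye(3, dtype=int).tolist(), [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]],
    "translations": [[0, 0, 0], [0, 0, 0]],
}
ops = build_symops(toy_spg)
# 2 parities x 2 sub-translations x 2 rotations
print(len(ops))  # 8
```

The last operation combines inversion, the centring vector and the two-fold rotation, i.e. the rotation part `diag(1, 1, -1)` with translation `(1/2, 1/2, 1/2)`.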

Now we apply all operations in `symops` to every position in `basis`, keeping a generated site only if it is not already present (within `symprec`):

In [None]:
%%time
symprec = 1e-3
kinds, sites = [], []

for kind, pos in enumerate(np.array(basis)):
    for rot, trans in symops:
        
        site = np.mod(np.dot(rot, pos) + trans, 1.) # applying symmetry op to site
        
        # storing site
        if len(sites) == 0: # if we have no sites yet just store `kind` and generated `site`
            sites.append(site)
            kinds.append(kind)
        else: # skip the site if it already exists (modulo lattice translations)
            t = site - sites
            t -= np.round(t) # minimum-image difference: offsets of about ±1 are lattice translations
            isclose = (np.abs(t) < symprec).all(axis=1)
            if not np.any(isclose):
                sites.append(site)
                kinds.append(kind)
                
print(f'sites:\n{sites[:5]}, \n\nkinds:\n{kinds[:5]}')

`sites` now contains all positions equivalent to the scaled positions in our initial `basis`. `kinds` holds, for each site, the index of the `basis` entry it was generated from, which is also its index into `symbols`.
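A quick sanity check on `kinds` is the composition it implies: mapping each kind back to its symbol and counting should reproduce the expected stoichiometry (for CoSb3, a 1:3 ratio of Co to Sb). A minimal sketch with toy `kinds` (not the notebook's actual output); `composition` is a hypothetical helper, which in the notebook could be called as `composition(kinds, symbols)`.

```python
from collections import Counter
from fractions import Fraction


def composition(kinds, symbols):
    """Map each site's kind index to its chemical symbol and count the sites per symbol."""
    return Counter(symbols[k] for k in kinds)


# toy data: 2 sites of kind 0 and 6 of kind 1, i.e. the 1:3 ratio expected for CoSb3
toy_kinds = [0, 0, 1, 1, 1, 1, 1, 1]
counts = composition(toy_kinds, ["Co", "Sb"])
print(counts)                                # Counter({'Sb': 6, 'Co': 2})
print(Fraction(counts["Sb"], counts["Co"]))  # 3
```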

### Computing cell vectors

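To compute the `cell` we need an orthonormal coordinate system `xyz` and the cell vectors `abc`. First, let's construct `xyz`: `z` points along `ab_normal`, and `x` along the part of `a_direction` perpendicular to `z`.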

In [None]:
ab_normal, a_direction

In [None]:
def norm(x): return np.array(x) / np.linalg.norm(x)

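assert not np.allclose(np.cross(ab_normal, a_direction), 0), 'a_direction must not be parallel to ab_normal'
_a = np.asarray(a_direction, dtype=float)
z = norm(ab_normal)
x = norm(_a - np.dot(_a, z) * z) # project a_direction onto the plane normal to z, then normalise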

xyz = np.array([x, np.cross(z,x), z])
xyz
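Whatever `a_direction` and `ab_normal` are, the rows of `xyz` should form a right-handed orthonormal basis, which is easy to assert. The self-contained sketch below builds the frame in a hypothetical helper `frame` (intended to mirror the construction above, including the projection of a non-orthogonal `a_direction` onto the plane normal to `z`) and checks both properties.

```python
import numpy as np


def frame(a_direction, ab_normal):
    """Rows x, y, z: z along ab_normal, x along the part of a_direction
    perpendicular to z, and y = z cross x (right-handed)."""
    z = np.asarray(ab_normal, dtype=float)
    z /= np.linalg.norm(z)
    a = np.asarray(a_direction, dtype=float)
    x = a - np.dot(a, z) * z
    x /= np.linalg.norm(x)
    return np.array([x, np.cross(z, x), z])


xyz_test = frame(a_direction=(1, 1, 0), ab_normal=(0, 0, 2))
assert np.allclose(xyz_test @ xyz_test.T, np.eye(3))  # rows are orthonormal
assert np.isclose(np.linalg.det(xyz_test), 1.0)       # right-handed
```

A tilted `a_direction` such as `(1, 0, 1)` with `ab_normal = (0, 0, 1)` simply yields `x = (1, 0, 0)`.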

Second, compute the trigonometric quantities needed for `abc` from the cell parameters:

In [None]:
cellpar # a, b, c, alpha, beta, gamma

In [None]:
a, b, c, alpha, beta, gamma = cellpar
assert all([a>0, b>0, c>0, 0<alpha<180, 0<beta<180, -180.<=gamma<=180.])

In [None]:
def deg2rad(x): return x*np.pi/180.

cos_alpha = 0. if np.isclose(alpha,90) else np.cos(deg2rad(alpha))
cos_beta = 0. if np.isclose(beta,90) else np.cos(deg2rad(beta))
cos_gamma = 0. if np.isclose(abs(gamma),90) else np.cos(deg2rad(gamma))
sin_gamma = np.sign(gamma) if np.isclose(abs(gamma),90) else np.sin(deg2rad(gamma))
cos_alpha, cos_beta, cos_gamma, sin_gamma
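For reference, these are the formulas behind the `abc` cell below. Here $\alpha$ is the angle between $\mathbf{b}$ and $\mathbf{c}$, $\beta$ between $\mathbf{a}$ and $\mathbf{c}$, and $\gamma$ between $\mathbf{a}$ and $\mathbf{b}$; $\mathbf{a}$ is placed along the first axis, $\mathbf{b}$ in the plane of the first two axes, and $c_y$ follows from requiring $\mathbf{b}\cdot\mathbf{c} = bc\cos\alpha$:

$$
\begin{aligned}
\mathbf{a} &= a\,(1,\ 0,\ 0)\\
\mathbf{b} &= b\,(\cos\gamma,\ \sin\gamma,\ 0)\\
\mathbf{c} &= c\,\left(\cos\beta,\ c_y,\ \sqrt{1-\cos^2\beta-c_y^2}\right),
\qquad c_y=\frac{\cos\alpha-\cos\beta\cos\gamma}{\sin\gamma}
\end{aligned}
$$

For hcp Mg ($\alpha=\beta=90^\circ$, $\gamma=120^\circ$) this gives $c_y = 0$, $\mathbf{b} = b\,(-\tfrac{1}{2}, \tfrac{\sqrt{3}}{2}, 0)$ and $\mathbf{c} = (0, 0, c)$.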

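With these, we can assemble `abc`, whose rows are the cell vectors a, b and c in a local frame where a points along the first axis and b lies in the plane of the first two axes:

In [None]:
cy = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
cz_sqr = 1. - cos_beta*cos_beta - cy*cy
assert cz_sqr >= 0, 'cellpar does not describe a valid cell'
abc = np.array([
    [a, 0, 0],
    [b*cos_gamma, b*sin_gamma, 0],
    [c*cos_beta, 
     c*cy, 
     c*np.sqrt(cz_sqr)]
])

abc

Multiplying by `xyz` maps this local frame onto the `xyz` axes, which gives `cell`: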

In [None]:
cell = np.dot(abc, xyz)
cell
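Since `xyz` is orthonormal, rotating `abc` into it changes neither lengths nor angles, so the cell parameters recovered from `cell` should match `cellpar`. ase ships a function for this (`ase.geometry.cell_to_cellpar`); the version below is a dependency-free sketch, with `cell_to_cellpar` here being our own re-implementation rather than ase's code, checked on an hcp-like cell.

```python
import numpy as np


def cell_to_cellpar(cell):
    """Recover [a, b, c, alpha, beta, gamma] from a matrix whose rows are the cell vectors."""
    va, vb, vc = np.asarray(cell, dtype=float)

    def angle(u, v):
        # angle in degrees; clipping guards against rounding just outside [-1, 1]
        cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
        return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

    lengths = [np.linalg.norm(v) for v in (va, vb, vc)]
    # alpha = angle(b, c), beta = angle(a, c), gamma = angle(a, b)
    return np.array(lengths + [angle(vb, vc), angle(va, vc), angle(va, vb)])


# hcp-like cell built from cellpar = [3.21, 3.21, 5.21, 90, 90, 120]
gamma = np.radians(120.0)
hcp_cell = np.array([[3.21, 0.0, 0.0],
                     [3.21 * np.cos(gamma), 3.21 * np.sin(gamma), 0.0],
                     [0.0, 0.0, 5.21]])
print(np.round(cell_to_cellpar(hcp_cell), 6))
```

In the notebook, `assert np.allclose(cell_to_cellpar(cell), cellpar)` would be the corresponding check.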

### Computing atom positions 

Having computed `sites` and `cell`, we get the real-space (Cartesian, i.e. unscaled) atom `positions` by multiplying the scaled `sites` with `cell`, since each row of `cell` is one cell vector:

In [None]:
positions = np.dot(sites, cell)
positions[:5]

In [None]:
assert np.allclose(positions, crystal.positions), f'crystal.positions != positions:\n{crystal.positions}\n!=\n{positions}'
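`positions = sites @ cell` writes each position as $s_a\mathbf{a} + s_b\mathbf{b} + s_c\mathbf{c}$, so going back to scaled coordinates means solving a linear system. ase exposes this as `Atoms.get_scaled_positions()` (which also wraps into $[0,1)$ by default); the sketch below, with the hypothetical helper `to_scaled` and an arbitrary toy cell, only illustrates the linear algebra.

```python
import numpy as np


def to_scaled(positions, cell):
    """Invert positions = scaled @ cell by solving cell.T @ s = r for every position r."""
    return np.linalg.solve(np.asarray(cell, dtype=float).T, np.asarray(positions, dtype=float).T).T


toy_cell = np.array([[2.0, 0.0, 0.0],
                     [1.0, 3.0, 0.0],
                     [0.0, 0.0, 4.0]])
toy_sites = np.array([[0.0, 0.0, 0.0], [0.5, 0.25, 0.75]])
toy_positions = toy_sites @ toy_cell  # scaled -> Cartesian, as in the notebook
print(toy_positions[1])  # [1.25 0.75 3.  ]
```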

Using `kinds`, `positions` and `cell` we can generate an `ase.Atoms` object and compare the result to the one from `ase.spacegroup.crystal`

In [None]:
atoms = Atoms(symbols=[symbols[v] for v in kinds],
              positions=positions,
              cell=cell)

rotation = ('5x,30y,0z') # defines the view angle

fig, axs = plt.subplots(ncols=2, figsize=(10,9))
plot_atoms(atoms, axs[0], radii=0.3, rotation=rotation)
axs[0].set(title='nb code')
plot_atoms(crystal, axs[1], radii=0.3, rotation=rotation)
axs[1].set(title='original ase')
plt.show()
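The side-by-side plots are only a visual check, and the `assert` on `positions` above relies on our sites coming out in the same order as ase's. For structures produced by other tools the order may differ, so an order-independent comparison can be handy. A minimal sketch on scaled sites, assuming neither set contains duplicates; `same_sites` is a hypothetical helper, and species could be included by calling it once per chemical symbol.

```python
import numpy as np


def same_sites(sites1, sites2, symprec=1e-3):
    """True if both sets of scaled sites agree up to ordering and lattice translations."""
    s1 = np.asarray(sites1, dtype=float)
    s2 = np.asarray(sites2, dtype=float)
    if s1.shape != s2.shape:
        return False
    for site in s1:
        diff = s2 - site
        diff -= np.round(diff)  # minimum-image difference in scaled units
        if not np.any(np.all(np.abs(diff) < symprec, axis=1)):
            return False
    return True


sites_a = [[0, 0, 0], [0.5, 0.5, 0.5]]
sites_b = [[0.5, 0.5, 0.5], [0.9999, 0, 0]]  # same sites, other order, one folded to ~1
print(same_sites(sites_a, sites_b))  # True
```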

## Exporting ase crystal properties for tests in Julia

Generating an `ase.Atoms` object for every entry in `crystal_specs`

In [None]:
ab_normal = (0,0,1) # normal to the plane spanned by cell vectors a and b
a_direction = (1,0,0) # direction of the first cell vector a

ase_crystals = {k: spacegroup.crystal(d['symbols'], d['basis'], 
                             spacegroup=d['nr'],
                             cellpar=d['cellpar'], 
                             a_direction=a_direction,
                             ab_normal=ab_normal)
                for k, d in crystal_specs.items()}

In [None]:
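def atoms2dict(atoms):
    d = atoms.todict().copy()
    for k,v in d.items():
        if isinstance(v,np.ndarray):
            d[k] = v.tolist()
    # copy `info` before editing it: `todict()` may return `atoms.info` itself,
    # and assigning into it would replace the Spacegroup object on `atoms`
    d['info'] = dict(d['info'])
    d['info']['spacegroup'] = {
        'nr': d['info']['spacegroup'].no, 
        'setting': d['info']['spacegroup'].setting
    }
    d['symbols'] = atoms.get_chemical_symbols()
    return d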

def ase_crystals2json(crystals: dict, json_fname='../src/ase-atoms.json'):
    ds = {k: atoms2dict(v) for k,v in crystals.items()}
    print(f'Storing ase crystals as dictionaries: {json_fname}')
    with open(json_fname, "w") as f:
        json.dump(ds, f)

In [None]:
%%time
ase_crystals2json(ase_crystals, json_fname='../src/ase-atoms.json')

In [None]:
def json2ase_crystals(json_fname='../src/ase-atoms.json'):
    with open(json_fname, "r") as f:
        ds = json.load(f)
    ase_crystals = {}
    for k, d in ds.items():
        ase_crystals[k] = {
            v: np.array(d[v]) for v in ['numbers', 'positions', 'spacegroup_kinds', 'cell', 'pbc']
        }
        ase_crystals[k]['info'] = {'spacegroup': spacegroup.Spacegroup(d['info']['spacegroup']['nr'], 
                                                                     setting=d['info']['spacegroup']['setting'])}
        ase_crystals[k] = Atoms.fromdict(ase_crystals[k])
        
    return ase_crystals
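Because these JSON files are read back here (and, presumably, on the Julia side), it can help to confirm that serialisation alone is lossless before comparing structures. A minimal sketch on plain numpy arrays, so ase is not needed; `roundtrip_ok` is a hypothetical helper that only checks the four arrays named in it, and the toy data is a two-atom excerpt, not a full NaCl cell.

```python
import json
import numpy as np


def roundtrip_ok(arrays, float_keys=("positions", "cell"), exact_keys=("numbers", "pbc")):
    """Write the given arrays to a JSON string, read them back and compare."""
    keys = tuple(float_keys) + tuple(exact_keys)
    text = json.dumps({k: np.asarray(arrays[k]).tolist() for k in keys})
    loaded = {k: np.array(v) for k, v in json.loads(text).items()}
    floats_ok = all(np.allclose(loaded[k], arrays[k]) for k in float_keys)
    exact_ok = all(np.array_equal(loaded[k], arrays[k]) for k in exact_keys)
    return floats_ok and exact_ok


toy = {
    "numbers": np.array([11, 17]),
    "positions": np.array([[0.0, 0.0, 0.0], [2.82, 2.82, 2.82]]),
    "cell": 5.64 * np.eye(3),
    "pbc": np.array([True, True, True]),
}
print(roundtrip_ok(toy))  # True
```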

In [None]:
%%time
# json_fname='../src/ase-atoms.json'
json_fname='julia-atoms.json'
atoms_from_disk = json2ase_crystals(json_fname=json_fname)

In [None]:
fig, ax = plt.subplots()
plot_atoms(atoms_from_disk['Mg_hcp'], ax, radii=0.3, rotation=('0x,30y,0z'))
plt.show()