In [None]:
#| hide
#| default_exp util
from nbdev import *

# Utility functions

> Helper functions

In [None]:
#| exporti
from numpy import dot
from spglib import find_primitive, get_symmetry_dataset
import itertools

In [None]:
#| hide
from ase.build import bulk
from numpy import allclose

In [None]:
#| exporti
# Lazily flatten one level of nesting,
# e.g. list(flatten([[1, 2], [3]])) == [1, 2, 3].
flatten = itertools.chain.from_iterable

In [None]:
#| exporti
def select_asap_model(comp='SiC'):
    '''
    Select the latest *working* OpenKIM model containing `comp` in its name.
    Required since some models are not loading properly and the names are
    not stable. Every available model whose name contains `comp` is probed
    by constructing an `asap3.OpenKIMcalculator`; models raising
    `asap3.AsapError` are skipped.

    INPUT
    -----
    comp : Substring to look for in the model names (default: 'SiC').

    OUTPUT
    ------
    Name of the model. If nothing is found returns None.
    The working candidates are ordered by the fourth `_`-separated field
    of their names (presumably the publication year in the usual OpenKIM
    naming scheme - not verified here) and the last one is returned.
    Names with fewer than four fields sort first instead of raising.
    '''
    import asap3

    def _sort_key(name):
        # Fourth '_'-separated field; '' for short names so they sort first.
        parts = name.split('_')
        return parts[3] if len(parts) > 3 else ''

    models = []
    for pot in asap3.OpenKIMavailable():
        if comp not in pot:
            continue
        try:
            # Only probing whether the model loads; the calculator is discarded.
            asap3.OpenKIMcalculator(pot)
        except asap3.AsapError:
            # Deliberate best-effort: a broken model is simply skipped.
            continue
        models.append(pot)

    if not models:
        return None
    # sorted() is stable, so among equal keys the last-listed model wins.
    return sorted(models, key=_sort_key)[-1]

In [None]:
#| exporti
def create_asap_calculator(model):
    '''
    Build an ASAP OpenKIM calculator for the OpenKIM model named `model`.

    `asap3` is imported only when this function is called, so importing
    this module does not itself require ASAP to be installed.
    '''
    from asap3 import OpenKIMcalculator
    calculator = OpenKIMcalculator(model)
    return calculator

In [None]:
#| export
def normalize_conf(c, base):
    '''
    Normalize the configuration `c` relative to the basic structure `base`.
    Normalization is performed by "unwrapping" the displacements of atoms
    when they cross the periodic boundary conditions in such a way that the
    atoms are not "jumping" from one side of the cell to the other.

    E.g. if the atom at r=(0,0,0) goes to the relative position (-0.01, 0, 0)
    it is "wrapped" by PBC to the r=(0.99, 0, 0). Thus if we naively calculate
    the displacement we will get a large positive displacement (0.99 of the cell
    vector) instead of a small negative one.

    This function reverses that process making the positions suitable for
    differentiation. The positions may be part of a continuous trajectory or
    just independent configurations. This makes it impossible for the described
    procedure to work if the displacements are above 1/2 of the unit cell.
    For safety this implementation is limited to displacements < 1/3 of the
    unit cell. If any coordinate changes by more than 1/3 the function
    will raise an AssertionError exception. The check is raised explicitly,
    so it stays active when Python runs with -O.

    This implementation is not suitable for tracking positions in the system
    with systematic drift (e.g. long MD trajectory with non-perfect momentum
    conservation). For stronger implementation suitable for such cases look
    at dxutils package.

    INPUT
    -----
    c    : Configuration to normalize (needs `get_scaled_positions()`).
    base : Reference structure (needs `get_cell()` and
           `get_scaled_positions()`).

    OUTPUT
    ------
    Tuple `(cartesian positions, fractional positions)` of the unwrapped
    configuration. The Cartesian positions are computed with the cell
    of `base`.
    '''
    cell = base.get_cell()
    spos = c.get_scaled_positions()
    bspos = base.get_scaled_positions()

    # Fractional displacement relative to base. A component beyond +/-1/2
    # is taken to have crossed the periodic boundary, so shift it back by
    # one cell vector.
    sdx = spos - bspos
    sdx += (sdx < -0.5)*1 - (sdx > 0.5)*1

    # Explicit raise instead of `assert`: assert statements are stripped
    # under `python -O`, which would silently disable this safety check.
    if not (abs(sdx) < 1/3).all():
        raise AssertionError(
            'Fractional displacement relative to base exceeds 1/3 of the '
            f'unit cell (max |dx| = {abs(sdx).max():.4f}).')

    # Unwrapped fractional positions
    spos = bspos + sdx

    # Return Cartesian positions, fractional positions
    return dot(spos, cell), spos

In [None]:
#| export
def write_dfset(fn, c):
    '''
    Append displacement-force data from the configuration record `c`
    to the file `fn`. The format is suitable for use as ALAMODE DFSET file.
    The file need not exist prior to the first call; if it does not exist
    it will be created.

    INPUT
    -----
    fn : Path of the DFSET file (opened in text append mode).
    c  : Tuple `(n, i, x, f, e)`:
         n - set number (integer, written zero-padded to 4 digits),
         i - configuration number (integer, zero-padded to 4 digits),
         x - per-atom displacement vectors, in Angstrom (converted to Bohr),
         f - per-atom force vectors, in eV/Angstrom (converted to Ry/Bohr),
         e - energy written to the header line (labelled eV/at).
    '''
    # ASE unit constants for the Angstrom/eV -> Bohr/Ry conversions below.
    from ase import units as un

    n, i, x, f, e = c
    # Three fixed-point displacements, a gap, three scientific-notation forces.
    row_fmt = 3*'%15.7f ' + '     ' + 3*'%15.8e '
    with open(fn, 'at') as dfset:
        print(f'# set: {n:04d} config: {i:04d}  energy: {e:8e} eV/at', file=dfset)
        for ui, fi in zip(x, f):
            print(row_fmt % (tuple(ui/un.Bohr) + tuple(fi*un.Bohr/un.Ry)),
                  file=dfset)

In [None]:
#| export
def calc_init_xscale(cryst, xsl, skip=None):
    '''
    Calculate initial xscale amplitude correction coefficients
    from the history exported from the previous calculation
    (with `xscale_list` argument).

    INPUT
    -----
    cryst : ASE structure
    xsl   : List of amplitude correction coefficients. The shape of
            each element of the list must be `cryst.get_positions().shape`
    skip  : Number of samples to skip at the start of the xsl list
            (capped at half of the list length)

    OUTPUT
    ------
    Array of amplitude correction coefficients with the same shape as
    `cryst.get_positions().shape`. May be directly plugged into
    `xscale_init` argument of `HECSS_Sampler` or `HECSS`.
    Each atom gets the coefficient of its element: the mean over all
    retained samples, all atoms of that element and all coordinates.
    An empty `xsl` yields an array of ones (no correction).
    '''
    from numpy import array, ones, unique
    elmap = cryst.get_atomic_numbers()
    if len(xsl) == 0:
        # No history available: neutral (no-op) correction.
        return ones(cryst.get_positions().shape)
    if skip is not None:
        # Never skip more than half of the history.
        skip = min(skip, len(xsl)//2)
    xs = array(xsl)[skip:]
    xscale = ones(xs[0].shape)
    for el in unique(elmap):
        mask = elmap == el
        xscale[mask] = xs[:, mask, :].mean()
    return xscale

In [None]:
#| hide

# Rocksalt places most of the atoms on the surface of the unit cell,
# which makes it a demanding test case for the PBC unwrapping.
b = bulk('NaCl', 'rocksalt', a=3.6, cubic=True).repeat((3, 3, 3))
c = b.copy()

# Random, reproducible (fixed seed) displacements of the atoms
c.rattle(stdev=0.5, seed=42)

# Reference: fractional positions before folding them into the cell
unwrapped = c.get_scaled_positions(wrap=False)

# Fold the atoms back into the cell, as the PBC would do
c.set_scaled_positions(c.get_scaled_positions(wrap=True))

# normalize_conf must recover the unfolded fractional positions
assert allclose(unwrapped, normalize_conf(c, b)[1])