This notebook serves as an example of how to load and manipulate the [Si GAP dataset](https://www.repository.cam.ac.uk/handle/1810/317974) using a `Dataset` object.

# Uncomment for Google Colab

# Imports

In [None]:
import os
import itertools
import numpy as np

# Initialize the database

In [None]:
from colabfit.tools.database import MongoDatabase

# Connect to the 'colabfit_rebuild' database.
# Passing `drop_database=True` as an extra argument overwrites an existing
# database of that name.
client = MongoDatabase('colabfit_rebuild')

# Data loading

In [None]:
from colabfit.tools.database import load_data

# First pass: load the XYZ file without a naming field (`name_field=None`).
# The following cells inspect the result to find a suitable `info` key.
images = load_data(
    file_path='/colabfit/data/gap_si/gp_iter6_sparse9k.xml.xyz',
    file_format='xyz',
    name_field=None,
    elements=['Si'],    # order matters for CFG files, but not others
    verbose=True
)
images

In [None]:
# Convert to a list so that individual items can be indexed below.
images = list(images)

In [None]:
# Display the first loaded item.
images[0]

In [None]:
# List the `info` keys available on the first loaded item.
images[0].info.keys()

In [None]:
# 'config_type' is the key passed as `name_field` in the reload below.
images[0].info['config_type']

In [None]:
# Second pass: reload the same file, this time naming each configuration
# from its 'config_type' entry.
images = list(load_data(
    file_path='/colabfit/data/gap_si/gp_iter6_sparse9k.xml.xyz',
    file_format='xyz',
    name_field='config_type',  # key in Configuration.info to use as the Configuration name
    elements=['Si'],    # order matters for CFG files, but not others
    default_name='Si_PRX_GAP',  # default name used when `name_field` is not found
    verbose=True
))

In [None]:
# Presumably where load_data() stores the name derived from `name_field` — TODO confirm.
images[0].info['_name']

# Property definitions

In [None]:
# Union of every key found in `arrays` across all loaded items.
{key for img in images for key in img.arrays}

In [None]:
# Union of every key found in `info` across all loaded items.
{key for img in images for key in img.info}

In [None]:
# Property definition for the 'energy-forces-stress' property.
# Each field entry declares its type, whether it carries units, its extent
# (shape), whether it is required, and a human-readable description.
base_definition = {
    'property-id': 'energy-forces-stress',
    'property-title': 'Basic outputs from a static calculation',
    'property-description':
        'Energy, forces, and stresses from a calculation of a '\
        'static configuration. Energies must be specified to be '\
        'per-atom or supercell. If a reference energy has been '\
        'used, this must be specified as well.',

    'energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': True,
        'description':
            'The potential energy of the system.'
    },
    'forces': {
        'type': 'float',
        'has-unit': True,
        'extent': [":", 3],
        'required': False,
        'description':
            'The [x,y,z] components of the force on each particle.'
    },
    'stress': {
        'type': 'float',
        'has-unit': True,
        'extent': [3, 3],
        'required': False,
        'description':
            'The full Cauchy stress tensor of the simulation cell'
    },

    # The flag is named "per-atom", so True must mean the energy is per atom.
    # The previous wording described the opposite case, which contradicted
    # `tform`, where total energies are tagged with per-atom = False.
    'per-atom': {
        'type': 'bool',
        'has-unit': False,
        'extent': [],
        'required': True,
        'description':
            'If True, "energy" is the per-atom energy of the system, '\
            'and HAS been divided by the number of atoms in the '\
            'configuration. If False, "energy" is the total energy of '\
            'the system, and has NOT been divided by the number of atoms.'
    },
    'reference-energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': False,
        'description':
            'If provided, then "energy" is the energy (either of '\
            'the whole system, or per-atom) LESS the energy of '\
            'a reference configuration (E = E_0 - E_reference). '\
            'Note that "reference-energy" is just provided for '\
            'documentation, and that "energy" should already have '\
            'this value subtracted off. The reference energy must '\
            'have the same units as "energy".'
    },
}

In [None]:
# Register the 'energy-forces-stress' definition with the database.
client.insert_property_definition(base_definition)

# Adding to the database using `insert_data()`

In [None]:
# Data stored on atoms needs to be cleaned
def tform(img):
    """Normalize the ``info`` and ``arrays`` of one configuration in place.

    Steps, in order:

    1. Set ``info['per-atom']`` to ``False``.
    2. Rename every ``info`` key (except the reserved ``_name``, ``_labels``
       and ``_constraints``) and every ``arrays`` key to lower case with
       underscores replaced by hyphens.
    3. Convert selected string-valued settings such as ``'300 K'`` to floats
       by parsing the text before the first space.
    4. Reshape ``dft-virial`` and ``gap-virial`` to ``(3, 3)``.

    Parameters
    ----------
    img : object exposing dict-like ``info`` and ``arrays`` attributes.

    Returns
    -------
    None. ``img`` is modified in place.
    """
    img.info['per-atom'] = False

    # Renaming some fields to be consistent. These keys keep their exact
    # spelling and are skipped by the renaming.
    reserved_keys = ('_name', '_labels', '_constraints')

    # Iterate over a snapshot, because the dict is mutated inside the loop.
    for key, value in list(img.info.items()):
        if key in reserved_keys:
            continue

        del img.info[key]
        img.info[key.replace('_', '-').lower()] = value

    for key, value in list(img.arrays.items()):
        del img.arrays[key]
        img.arrays[key.replace('_', '-').lower()] = value

    # Converting some string values to floats
    for key in [
        'md-temperature', 'md-cell-t', 'smearing-width', 'md-delta-t',
        'md-ion-t', 'cut-off-energy', 'elec-energy-tol',
        ]:
        if key in img.info:
            try:
                img.info[key] = float(img.info[key].split(' ')[0])
            except (AttributeError, ValueError):
                # AttributeError: the value is not a string (e.g. it is
                # already numeric). ValueError: the leading token is not a
                # number. In both cases the value is left as it is; any other
                # exception now propagates instead of being silently swallowed.
                pass

    # Reshaping shape (9,) stress vector to (3, 3) to match definition
    for key in ('dft-virial', 'gap-virial'):
        if key in img.info:
            img.info[key] = img.info[key].reshape((3, 3))

# Property maps

```python
property_map = {
    'property-id': [
        { # mapping configuration fields to property keys and property settings
            
            key: {'field': ..., 'units': ...}, # fields for definition keys
            
            '_settings': {
                key: {'field': ..., 'units': ...}, # fields for settings keys
                
                '_method':      ...,
                '_description': ...,
                '_files':       ...,
                '_labels':      ...,
            }
        }
    ]
}
```

## Data being loaded:

1. `energy-forces-stress` PIs computed using CASTEP
2. `energy-forces-stress` PIs computed using a trained GAP model
3. Additional settings information for each

In [None]:
# Maps each 'energy-forces-stress' definition key to the configuration field
# holding the DFT value. Field names are the post-`tform` (hyphenated,
# lower-case) names.
# NOTE(review): 'dft-virial' is mapped onto 'stress' with units of GPa —
# confirm the stored values are stresses rather than virials in energy units.
dft_map = {
    # Property Definition field: {'field': ASE field, 'units': ASE-readable units}
    'energy': {'field': 'dft-energy', 'units': 'eV'},
    'forces': {'field': 'dft-force',  'units': 'eV/Ang'},
    'stress': {'field': 'dft-virial', 'units': 'GPa'},
    'per-atom': {'field': 'per-atom', 'units': None},
}

In [None]:
# Same layout as the DFT map, but reading the 'gap-*' fields.
# NOTE(review): 'gap-virial' is mapped onto 'stress' with units of GPa —
# confirm the stored values are stresses rather than virials in energy units.
gap_map = {
    # Property Definition field: {'field': ASE field, 'units': ASE-readable units}
    'energy': {'field': 'gap-energy', 'units': 'eV'},
    'forces': {'field': 'gap-force',  'units': 'eV/Ang'},
    'stress': {'field': 'gap-virial', 'units': 'GPa'},
    'per-atom': {'field': 'per-atom', 'units': None},
}

In [None]:
# Configuration fields (post-`tform` names) that are recorded as settings.
# Each becomes one entry of the settings maps built in the following cells.
settings_keys = [
    'mix-history-length',
    'castep-file-name',
    'grid-scale',
    'popn-calculate',
    'n-neighb',
    'oldpos',
    'i-step',
    'md-temperature',
    'positions',
    'task',
    'data-distribution',
    'avg-ke',
    'force-nlpot',
    'continuation',
    'castep-run-time',
    'calculate-stress',
    'minim-hydrostatic-strain',
    'avgpos',
    'frac-pos',
    'hamiltonian',
    'md-cell-t',
    'cutoff-factor',
    'momenta',
    'elec-energy-tol',
    'mixing-scheme',
    'minim-lattice-fix',
    'in-file',
    'travel',
    'thermostat-region',
    'time',
    'temperature',
    'kpoints-mp-grid',
    'cutoff',
    'xc-functional',
    'smearing-width',
    'pressure',
    'reuse',
    'fix-occupancy',
    'map-shift',
    'md-num-iter',
    'damp-mask',
    'opt-strategy',
    'spin-polarized',
    'nextra-bands',
    'fine-grid-scale',
    'masses',
    'iprint',
    'finite-basis-corr',
    'enthalpy',
    'opt-strategy-bias',
    'force-ewald',
    'num-dump-cycles',
    'velo',
    'md-delta-t',
    'md-ion-t',
    'force-locpot',
    'numbers',
    'max-scf-cycles',
    'mass',
    'minim-constant-volume',
    'cut-off-energy',
    'virial',
    'nneightol',
    'max-charge-amp',
    'md-thermostat',
    'md-ensemble',
    'acc',
]

# Units for the fields that carry them. Fields missing from this table are
# given units of None when the settings maps are built.
# The 'virial' entry used to be listed twice with the same value; the
# duplicate literal key has been removed. The resulting dict is unchanged.
units = {
    'energy': 'eV',
    'forces': 'eV/Ang',
    'virial': 'GPa',
    'oldpos': 'Ang',
    'md-temperature': 'K',
    'positions': 'Ang',
    'avg-ke': 'eV',
    'force-nlpot': 'eV/Ang',
    'castep-run-time': 's',
    'avgpos': 'Ang',
    'md-cell-t': 'ps',
    'time': 's',
    'temperature': 'K',
    'gap-force': 'eV/Ang',
    'gap-energy': 'eV',
    'cutoff': 'Ang',
    'smearing-width': 'eV',
    'pressure': 'GPa',
    'gap-virial': 'GPa',
    'masses': '_amu',
    'enthalpy': 'eV',
    'force-ewald': 'eV/Ang',
    'velo': 'Ang/s',
    'md-delta-t': 'fs',
    'md-ion-t': 'ps',
    'force-locpot': 'eV/Ang',
    'mass': 'g',
    'cut-off-energy': 'eV',
}

In [None]:
# Settings map for the CASTEP calculations: one entry per settings key, with
# units looked up in `units` (None when the key has no entry there).
dft_settings_map = {
    name: {'field': name, 'units': units.get(name)}
    for name in settings_keys
}

# Reserved metadata entries describing the method itself.
dft_settings_map.update({
    '_method': 'CASTEP',
    '_description': 'DFT calculations using the CASTEP software',
    '_files': None,
    '_labels': ['Monkhorst-Pack'],
})

In [None]:
# Settings map for the GAP predictions: a shallow copy of the CASTEP settings
# map with the reserved metadata entries replaced.
gap_settings_map = {
    **dft_settings_map,
    '_method': 'GAP',
    '_description': 'Predictions using a trained GAP potential',
    '_files': None,
    '_labels': ['GAP', 'classical'],
}

In [None]:
# Attach each settings map to its property map under the reserved '_settings' key.
dft_map['_settings'] = dft_settings_map
gap_map['_settings'] = gap_settings_map

In [None]:
# Two maps under one property ID: one for the DFT fields and one for the GAP fields.
property_map = {
    'energy-forces-stress': [
        dft_map,
        gap_map,
    ]
}

In [None]:
# Insert the loaded configurations, with `tform` passed as the transform and
# `property_map` describing the properties.
# `ids` is unpacked later into configuration IDs and property IDs.
ids = client.insert_data(
    images,
    property_map=property_map,
    transform=tform,
    verbose=True
)

In [None]:
# Count the property documents whose 'methods' field matches 'CASTEP'.
client.properties.count_documents({'methods': 'CASTEP'})

In [None]:
# Count the property documents whose 'methods' field matches 'GAP'.
client.properties.count_documents({'methods': 'GAP'})

In [None]:
# Show the first entry returned by insert_data().
ids[0]

# Building Configuration Sets

In [None]:
# Used for building groups of configurations for easier analysis/exploration
# Maps a regex matched against configuration names to the description of the
# resulting configuration set.
# NOTE(review): the patterns are unanchored, so some overlap: 'sp' also
# matches 'sp2', 'vacancy' also matches 'divacancy', 'dia' also matches
# 'hex_diamond', and 'surface_111' also matches its '_pandey' / '_3x3_das'
# variants. Confirm whether overlapping sets are intended.
configuration_set_regexes = {
    'isolated_atom': 'Reference atom',
    'bt': 'Beta-tin',
    'dia': 'Diamond',
    'sh': 'Simple hexagonal',
    'hex_diamond': 'Hexagonal diamond',
    'bcc': 'Body-centered-cubic',
    'bc8': 'BC8',
    'fcc': 'Face-centered-cubic',
    'hcp': 'Hexagonal-close-packed',
    'st12': 'ST12',
    'liq': 'Liquid',
    'amorph': 'Amorphous',
    'surface_001': 'Diamond surface (001)',
    'surface_110': 'Diamond surface (110)',
    'surface_111': 'Diamond surface (111)',
    'surface_111_pandey': 'Pandey reconstruction of diamond (111) surface',
    'surface_111_3x3_das': 'Dimer-adatom-stacking-fault (DAS) reconstruction',
    '111adatom': 'Configurations with adatom on (111) surface',
    'crack_110_1-10': 'Small (110) crack tip',
    'crack_111_1-10': 'Small (111) crack tip',
    'decohesion': 'Decohesion of diamond-structure Si along various directions',
    'divacancy': 'Diamond divacancy configurations',
    'interstitial': 'Diamond interstitial configurations',
    'screw_disloc': 'Si screw dislocation core',
    'sp': 'sp bonded configurations',
    'sp2': 'sp2 bonded configurations',
    'vacancy': 'Diamond vacancy configurations'
}

In [None]:
# Create one configuration set per regex and collect the resulting IDs.
cs_ids = []

i = 0
for regex, desc in configuration_set_regexes.items():
    # IDs of every configuration whose names match the pattern
    name_query = {'names': {'$regex': regex}}
    co_ids = client.get_data(
        'configurations', fields='_id', query=name_query, ravel=True
    ).tolist()

    # One aligned summary line per set: index, pattern, match count
    pattern_column = f'({regex}):'.rjust(22)
    count_column = f'{len(co_ids)}'.rjust(7)
    print(f'Configuration set {i}', pattern_column, count_column)

    cs_id = client.insert_configuration_set(co_ids, description=desc)
    cs_ids.append(cs_id)
    i += 1

# Building the Dataset

In [None]:
# Transpose the entries of `ids` into two parallel tuples.
# Presumably these are configuration IDs and property IDs, going by the names — TODO confirm.
all_co_ids, all_pr_ids = list(zip(*ids))  # returned by insert_data()
len(all_pr_ids)

In [None]:
# Number of configuration sets created above.
len(cs_ids)

In [None]:
# Number of property IDs that will be attached to the dataset.
len(all_pr_ids)

In [None]:
# Create the dataset from the configuration sets and property IDs gathered
# above, with its bibliographic metadata. The returned ID is reused when
# applying labels and resyncing.
ds_id = client.insert_dataset(
    cs_ids=cs_ids,
    pr_ids=all_pr_ids,
    name='Si_PRX_GAP',
    authors=[
        'Albert P. Bartók', 'James Kermode', 'Noam Bernstein', 'Gábor Csányi'
    ],
    links=[
        'https://journals.aps.org/prx/abstract/10.1103/PhysRevX.8.041048',
        'https://www.repository.cam.ac.uk/handle/1810/317974'
    ],
    description=\
        "The original DFT training data for the general-purpose silicon "\
        "interatomic potential described in the associated publication."\
        " The kinds of configuration that we include are chosen using "\
        "intuition and past experience to guide what needs to be included "\
        "to obtain good coverage pertaining to a range of properties.",
    verbose=True,
)
ds_id

In [None]:
# List every dataset document, projected down to its 'name' field.
list(client.datasets.find({}, {'name'}))

# Adding labels

In [None]:
# Used to apply metadata labels to configurations for future queries
# Maps a regex matched against configuration names to the label, or list of
# labels, applied to every matching configuration.
configuration_label_regexes = {
    'isolated_atom': 'isolated_atom',
    'bt': 'a5',
    'dia': 'diamond',
    'sh': 'sh',
    'hex_diamond': 'lonsdaleite',  # hexagonal diamond; was misspelled 'sonsdaleite'
    'bcc': 'bcc',
    'bc8': 'bc8',
    'fcc': 'fcc',
    'hcp': 'hcp',
    'st12': 'st12',
    'liq': 'liquid',
    'amorph': 'amorphous',
    'surface_001': ['surface', '001'],
    'surface_110': ['surface', '110'],
    'surface_111': ['surface', '111'],
    'surface_111_pandey': ['surface', '111'],
    'surface_111_3x3_das': ['surface', '111', 'das'],
    '111adatom': ['surface', '111', 'adatom'],
    'crack_110_1-10': ['crack', '110'],
    'crack_111_1-10': ['crack', '111'],  # was 'crac'; now matches the (110) entry
    'decohesion': ['diamond', 'decohesion'],
    'divacancy': ['diamond', 'vacancy', 'divacancy'],
    'interstitial': ['diamond', 'interstitial'],
    'screw_disloc': ['screw', 'dislocation'],
    'sp': 'sp',
    'sp2': 'sp2',
    'vacancy': ['diamond', 'vacancy']
}

In [None]:
# Apply each label (or list of labels) to the configurations of the dataset
# whose names match the corresponding regex.
for regex, labels in configuration_label_regexes.items():
    client.apply_labels(
        dataset_id=ds_id,
        collection_name='configurations',
        query={'names': {'$regex': regex}},
        labels=labels,
        verbose=True
    )

In [None]:
# Aggregated configuration labels before resyncing.
# NOTE(review): the empty filter returns whichever dataset document is found
# first, which is only `ds_id` if the collection holds a single dataset — confirm.
client.datasets.find_one({}, {'aggregated_info.configuration_labels'})

In [None]:
# Resync the dataset. The same query is run before and after this cell to
# compare the aggregated labels.
client.resync_dataset(ds_id, verbose=True)

In [None]:
# Aggregated configuration labels after resyncing.
# NOTE(review): the empty filter returns whichever dataset document is found
# first, which is only `ds_id` if the collection holds a single dataset — confirm.
client.datasets.find_one({}, {'aggregated_info.configuration_labels'})

# Next up: exploring the dataset

In [None]:
# Re-import and reconnect to the same database.
# Presumably this lets the exploration section run from a fresh kernel — TODO confirm.
from colabfit.tools.database import MongoDatabase, load_data

client = MongoDatabase('colabfit_rebuild')

In [None]:
# NOTE(review): this hard-coded ID replaces the `ds_id` returned by
# insert_dataset() earlier in the notebook. Confirm it matches the dataset in
# your database, or reuse the earlier value, before running.
ds_id = 'DS_238054846094_000'
dataset = client.get_dataset(ds_id, resync=True)['dataset']
dataset

In [None]:
# Print each aggregated-info entry as its key followed by the value on an indented line.
for k,v in dataset.aggregated_info.items():
    print(k, '\n\t', v)

In [None]:
# Shape of the forces field fetched from the 'properties' collection with concatenate=True.
client.get_data('properties', 'energy-forces-stress.forces', concatenate=True).shape

In [None]:
# Count the property documents whose 'methods' field matches 'CASTEP'.
client.properties.count_documents({'methods': 'CASTEP'})

In [None]:
# Count the GAP property documents that actually contain an energy field.
client.properties.count_documents({'methods': 'GAP', 'energy-forces-stress.energy': {'$exists': 1}})

In [None]:
# Histograms (log-scaled y axis) of energy, forces and stress for the
# dataset's CASTEP properties.
fig = client.plot_histograms(
    ['energy-forces-stress.energy', 'energy-forces-stress.forces', 'energy-forces-stress.stress'],
    query={'methods': 'CASTEP'},
    yscale='log',
    ids=dataset.property_ids,
    verbose=True,
)

In [None]:
# Histograms (log-scaled y axis) of energy, forces and stress for the
# dataset's GAP properties.
fig = client.plot_histograms(
    ['energy-forces-stress.energy', 'energy-forces-stress.forces', 'energy-forces-stress.stress'],
    query={'methods': 'GAP'},
    yscale='log',
    ids=dataset.property_ids,
    verbose=True,
)

In [None]:
# Statistics of energy, forces and stress for the dataset's CASTEP properties.
client.get_statistics(
    ['energy-forces-stress.energy', 'energy-forces-stress.forces', 'energy-forces-stress.stress'],
    ids=dataset.property_ids,
    query={'methods': 'CASTEP'},
    verbose=True
)

In [None]:
# Statistics of energy, forces and stress for the dataset's GAP properties.
client.get_statistics(
    ['energy-forces-stress.energy', 'energy-forces-stress.forces', 'energy-forces-stress.stress'],
    ids=dataset.property_ids,
    query={'methods': 'GAP'},
    verbose=True
)

In [None]:
# Look up one configuration named 'st12', returning only its ID and names.
client.configurations.find_one({'names': 'st12'}, {'_id', 'names'})

In [None]:
# NOTE(review): hard-coded configuration ID. Presumably it was copied from the
# output of the 'st12' lookup in the previous cell — confirm it exists in your
# database before running.
conf = client.get_configuration('CO_204525967030_000')

from ase.visualize import view

# Creates a Jupyter Widget; may require `pip install nglview` first
view([conf], viewer='nglview')