# Setup

In [None]:
import json
import numpy as np

from pymatgen import Structure
from pymatgen.io.ase import AseAtomsAdaptor

In [None]:
from colabfit.tools.configuration import Configuration
from colabfit.tools.database import MongoDatabase, load_data
from colabfit.tools.property_settings import PropertySettings

# Connect to the 'colabfit_rebuild' MongoDB database with nprocs=1.
# To start from scratch, also pass drop_database=True (presumably drops the
# existing contents -- confirm against MongoDatabase's documentation).
client = MongoDatabase('colabfit_rebuild', nprocs=1)

# Data loading

In [None]:
def _clean_name(tag, description):
    """Build a configuration name of the form '<tag>_<cleaned description>'.

    Spaces, '/', '(' and ')' are replaced with '_' and commas are dropped,
    e.g. _clean_name('T', 'a (b), c') -> 'T_a__b__c'.
    """
    cleaned = description.replace(' ', '_')
    for char in ('/', '(', ')'):
        cleaned = cleaned.replace(char, '_')
    cleaned = cleaned.replace(',', '')
    return f'{tag}_{cleaned}'


def _virial_to_stress(virial):
    """Expand a 6-component virial stress into a symmetric 3x3 tensor.

    Components are read as [xx, yy, zz, yz, xz, xy], the same index mapping
    this notebook has always used. Off-diagonal terms are mirrored so the
    result is the full (symmetric) tensor rather than only its upper triangle.

    NOTE(review): VASP's OUTCAR prints stress as XX YY ZZ XY YZ ZX; confirm
    that mlearn's 'virial_stress' really follows the ordering assumed here.
    """
    xx, yy, zz, yz, xz, xy = virial
    return np.array([
        [xx, xy, xz],
        [xy, yy, yz],
        [xz, yz, zz],
    ], dtype=float)


def reader(path):
    """Yield one colabfit Configuration per entry of an mlearn JSON file.

    Each JSON entry is read for: 'structure' (a pymatgen Structure dict),
    'group', 'description', 'tag', 'num_atoms', and 'outputs' with 'energy',
    'forces' and 'virial_stress'.

    Values attached to the ASE atoms before conversion:
      - info['_labels']: [group name, lower-cased]
      - info['_name']:   ['<tag>_<cleaned description>']
      - info['per-atom'] = True and info['energy'] = energy / num_atoms
      - arrays['forces']: forces as a numpy array
      - info['stress']:   symmetric 3x3 tensor built from 'virial_stress'
      - DFT settings: 'ke_cutoff', 'k-point-mesh', 'energy-convergence',
        'forces-convergence'
    """
    adaptor = AseAtomsAdaptor()

    # Load everything up front so the file is closed before the first yield.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for entry in data:
        struct = Structure.from_dict(entry['structure'])
        atoms = adaptor.get_atoms(struct)
        outputs = entry['outputs']

        # Labels and names are lists so each is kept whole; a bare string
        # would be iterated character by character downstream.
        atoms.info['_labels'] = [entry['group'].lower()]
        atoms.info['_name'] = [_clean_name(entry['tag'], entry['description'])]

        # DFT-computed values; the energy is stored per atom.
        atoms.info['per-atom'] = True
        atoms.info['energy'] = outputs['energy'] / entry['num_atoms']
        atoms.arrays['forces'] = np.array(outputs['forces'])
        atoms.info['stress'] = _virial_to_stress(outputs['virial_stress'])

        # DFT settings
        atoms.info['ke_cutoff'] = 520  # eV
        # Structures containing Li use a 3x3x3 mesh; all others use 4x4x4.
        atoms.info['k-point-mesh'] = (
            '3x3x3' if 'Li' in atoms.get_chemical_symbols() else '4x4x4'
        )
        atoms.info['energy-convergence'] = 1e-5  # eV
        atoms.info['forces-convergence'] = 0.02  # eV/Ang

        yield Configuration.from_ase(atoms)

In [None]:
from colabfit.tools.database import load_data

# Read every '*test.json' file in the mlearn folder through `reader`.
# To load the training split instead, use glob_string='*training.json'.
images = list(
    load_data(
        file_path='/colabfit/data/mlearn',
        file_format='folder',
        reader=reader,
        glob_string='*test.json',
        # key in Configuration.info holding the Configuration name
        name_field='_name',
        # fallback name used when `name_field` is not found
        default_name='mlearn',
        # order matters for CFG files, but not for other formats
        elements=['Cu', 'Ge', 'Li', 'Mo', 'Ni', 'Si'],
        verbose=True,
    )
)
len(images)

In [None]:
import itertools
# Distinct labels across all loaded configurations. A '_labels' value that is
# a bare string counts as ONE label; chain.from_iterable would otherwise split
# it into single characters. (If the value was already turned into a set of
# characters upstream, this cannot undo that -- check `reader`.)
all_labels = set(itertools.chain.from_iterable(
    [labels] if isinstance(labels, str) else labels
    for labels in (config.info['_labels'] for config in images)
))
all_labels

# Properties

In [None]:
# Property definition for static-calculation outputs (energy, forces, stress).
# Each field entry gives its value type, whether it carries units, its array
# extent and whether it is required.
base_definition = {
    'property-id': 'energy-forces-stress',
    'property-title': 'Basic outputs from a static calculation',
    'property-description':
        'Energy, forces, and stresses from a calculation of a '\
        'static configuration. Energies must be specified to be '\
        'per-atom or supercell. If a reference energy has been '\
        'used, this must be specified as well.',

    'energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': True,
        'description':
            'The potential energy of the system.'
    },
    'forces': {
        'type': 'float',
        'has-unit': True,
        'extent': [":", 3],
        'required': False,
        'description':
            'The [x,y,z] components of the force on each particle.'
    },
    'stress': {
        'type': 'float',
        'has-unit': True,
        'extent': [3, 3],
        'required': False,
        'description':
            'The full Cauchy stress tensor of the simulation cell'
    },

    'per-atom': {
        'type': 'bool',
        'has-unit': False,
        'extent': [],
        'required': True,
        # True means the energy HAS been normalized per atom; this matches
        # `reader`, which sets per-atom=True for energy / num_atoms.
        'description':
            'If True, "energy" is the per-atom energy, i.e. the total '\
            'energy of the system divided by the number of atoms in the '\
            'configuration. If False, "energy" is the total energy.'
    },
    'reference-energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': False,
        'description':
            'If provided, then "energy" is the energy (either of '\
            'the whole system, or per-atom) LESS the energy of '\
            'a reference configuration (E = E_0 - E_reference). '\
            'Note that "reference-energy" is just provided for '\
            'documentation, and that "energy" should already have '\
            'this value subtracted off. The reference energy must '\
            'have the same units as "energy".'
    },
}

In [None]:
# Register the 'energy-forces-stress' definition so properties can be mapped onto it.
client.insert_property_definition(base_definition)

In [None]:
# Map the 'energy-forces-stress' definition onto the keys that `reader` stores
# on each configuration: 'field' names the info/arrays key to read and
# 'units' gives the units of the stored values.
property_map = {
    'energy-forces-stress': [
        {
            # The stored energy is per atom (`reader` sets 'per-atom' True).
            'energy': {'field': 'energy', 'units': 'eV'},
            'forces': {'field': 'forces', 'units': 'eV/Ang'},
            'stress': {'field': 'stress', 'units': 'kilobar'},
            'per-atom': {'field': 'per-atom', 'units': None},
            
            # Calculation settings; the field-based entries read the DFT
            # settings keys that `reader` writes into atoms.info.
            '_settings': {
                '_method': 'VASP',
                '_description': 'VASP 5.4.1 calculations using the projector augmented wave approach',
                '_files': None,
                '_labels': ['PBE', 'GGA'],
                
                'kinetic-energy-cutoff': {'field': 'ke_cutoff',    'units': 'eV'} ,
                'k-point-mesh':          {'field': 'k-point-mesh', 'units': None},
                'energy-convergence':    {'field': 'energy-convergence', 'units': 'eV'},
                'forces-convergence':    {'field': 'forces-convergence', 'units': 'eV/Ang'},
            }
        }
    ]
}

# Inserting data

In [None]:
# Insert every loaded configuration together with the properties extracted via
# `property_map`; the result is unpacked as (configuration id, property id) pairs.
ids = client.insert_data(images, property_map=property_map, verbose=True)

In [None]:
# Transpose the (configuration id, property id) pairs into two parallel tuples.
all_co_ids, all_pr_ids = zip(*ids)

In [None]:
# Number of inserted configuration ids and property ids.
tuple(map(len, (all_co_ids, all_pr_ids)))

# Building Configuration Sets

In [None]:
# Used for building groups of configurations for easier analysis/exploration.
# Keys are regexes matched against configuration names (after a 'test.*'
# prefix); values are the descriptions stored on each configuration set.
configuration_set_regexes = {
    'Ground|relaxed':
        'Ground state structure',
    'Vacancy':
        'NVT AIMD simulations of the bulk supercells (similar to those used for '\
        'the strained structures) with a single vacancy performed at 300 K and 2.0× '\
        'of the melting point of each element. The bulk supercells were heated from '\
        '0 K to the target temperatures and equilibrated for 20 000 time steps. A '\
        'total of 40 snapshots were obtained from the subsequent production run of '\
        'each AIMD simulation at an interval of 0.1 ps.',
    'AIMD_NVT':
        'NVT ab initio molecular dynamics (AIMD) simulations of the bulk supercells '\
        '(similar to those used for the strained structures) performed at 300 K and '\
        '0.5×, 0.9×, 1.5×, and 2.0× of the melting point of each element with a time '\
        'step of 2 fs. The bulk supercells were heated from 0 K to the target '\
        'temperatures and equilibrated for 20 000 time steps. A total of 20 snapshots '\
        'were obtained from the subsequent production run in each AIMD simulation at '\
        'an interval of 0.1 ps.',
    'surface':
        'Slab structures up to a maximum Miller index of three, including (100), (110), '\
        '(111), (210), (211), (310), (311), (320), (321), (322), (331), and (332), as '\
        'obtained from the Crystalium database.',
    'strain':
        'Strained structures constructed by applying strains of −10% to 10% at 2% '\
        'intervals to the bulk supercell in six different modes, as described in the '\
        'work by de Jong et al. The supercells used are the 3 × 3 × 3, 3 × 3 × 3, and '\
        '2 × 2 × 2 of the conventional bcc, fcc, and diamond unit cells, respectively.',
}

In [None]:
# Spot-check: display one loaded configuration.
images[1]

In [None]:
# Spot-check: the info dict of the same configuration (populated by `reader`).
images[1].info

In [None]:
# For each element, group its test configurations into one configuration set
# per regex in `configuration_set_regexes`, recording the set ids per element.
cs_ids = {elem: [] for elem in ['Cu', 'Ge', 'Li', 'Mo', 'Ni', 'Si']}

# Every configuration id that landed in some set, kept for manual re-checks.
co_ids_recheck = []
for elem in cs_ids:
    print(elem)
    count = 0
    for i, (regex, desc) in enumerate(configuration_set_regexes.items()):
        # Group the pattern so alternations like 'Ground|relaxed' stay tied to
        # the 'test.*' prefix. Ungrouped, 'test.*Ground|relaxed' also matches
        # any name that merely contains 'relaxed'.
        name_pattern = f'test.*(?:{regex})'
        co_ids = client.get_data(
            'configurations',
            fields='_id',
            query={
                'names': {'$regex': name_pattern},
                'elements': elem,
            },
            ravel=True
        ).tolist()

        if not co_ids:
            continue

        co_ids_recheck += co_ids

        print(f'\tConfiguration set {i}', f'({regex}):'.rjust(22), f'{len(co_ids)}'.rjust(7))

        cs_id = client.insert_configuration_set(co_ids, description=desc)
        cs_ids[elem].append(cs_id)

        count += len(co_ids)

    print('\tTotal:', count)

In [None]:
# Build one dataset per element from that element's configuration sets and
# the properties computed on its configurations.
ds_ids = []
for elem, e_cs_ids in cs_ids.items():
    # Restrict properties to this element's configurations. Passing all_pr_ids
    # would give every per-element dataset the properties of all six elements.
    # all_co_ids[k] and all_pr_ids[k] come from the same pair returned by
    # insert_data.
    elem_co_ids = set(client.get_data(
        'configurations',
        fields='_id',
        query={'elements': elem},
        ravel=True
    ).tolist())
    e_pr_ids = [
        pr_id for co_id, pr_id in zip(all_co_ids, all_pr_ids)
        if co_id in elem_co_ids
    ]
    # Guard against an id-format mismatch silently producing an empty dataset.
    assert e_pr_ids, f'no inserted properties matched the {elem} configurations'

    ds_id = client.insert_dataset(
        cs_ids=e_cs_ids,
        pr_ids=e_pr_ids,
        name='mlearn_'+elem+'_test',
        authors=[
            'Yunxing Zuo', 'Chi Chen', 'Xiangguo Li',
            'Zhi Deng', 'Yiming Chen', 'Jörg Behler',
            'Gábor Csányi', 'Alexander V. Shapeev',
            'Aidan P. Thompson', 'Mitchell A. Wood',
            'Shyue Ping Ong'
        ],
        links=[
            'https://pubs.acs.org/doi/10.1021/acs.jpca.9b08723',
            'https://arxiv.org/abs/1906.08888',
            'https://github.com/materialsvirtuallab/mlearn'
        ],
        # Adjacent literals are concatenated before .format is applied, so
        # the '{} ' placeholder (with its trailing space) fills in the element.
        description=(
            'A comprehensive DFT data set was generated for six '
            'elements - Li, Mo, Ni, Cu, Si, and Ge. These elements '
            'were chosen to span a variety of chemistries (main group '
            'metal, transition metal, and semiconductor), crystal '
            'structures (bcc, fcc, and diamond) and bonding types '
            '(metallic and covalent). This dataset comprises only the {} '
            'configurations.'.format(elem)
        ),
        verbose=True,
    )

    ds_ids.append(ds_id)

In [None]:
# One-line summary per dataset (elements, configuration count, site count,
# property-type counts), re-synced from the database first.
summary_fields = ('elements', 'nconfigurations', 'nsites', 'property_types_counts')
for ds_id in ds_ids:
    aggregated = client.get_dataset(ds_id, resync=True)['dataset'].aggregated_info
    print(*(aggregated[key] for key in summary_fields))

# Exploration

In [None]:
# Plot log-scale histograms of every property field of each dataset, skipping
# the boolean 'per-atom' flag.
for did in ds_ids:
    dataset = client.get_dataset(did, resync=True)['dataset']

    # Filtered copy instead of .remove(): .remove() mutated the dataset's own
    # aggregated_info list and raised ValueError when the field was absent.
    fields = [
        field for field in dataset.aggregated_info['property_fields']
        if field != 'energy-forces-stress.per-atom'
    ]

    # Header naming the element whose histograms follow.
    print(dataset.aggregated_info['elements'][0])

    client.plot_histograms(
        fields,
        ids=dataset.property_ids,
        yscale='log',
    )