In [None]:
from colabfit.tools.database import MongoDatabase, load_data
from colabfit.tools.property_settings import PropertySettings

# Name of the MongoDB database this notebook (re)builds.
DATABASE_NAME = 'colabfit_rebuild'

# WARNING: drop_database=True is passed on every run of this cell.
# NOTE(review): presumably this wipes any existing database of the same name
# before the notebook repopulates it -- confirm before pointing this at a
# database that holds data you want to keep.
client = MongoDatabase(
    DATABASE_NAME,
    nprocs=1,
    drop_database=True,
)

In [None]:
# Parser options that are identical for both training-stage files.
cfg_load_options = dict(
    file_format='cfg',
    name_field=None,
    elements=['Al', 'Ni', 'Ti'],
    verbose=True,
    generator=False,
)

# First training stage. The configuration-set regexes used later in this
# notebook match on 'train_1st_stage' / 'train_2nd_stage'.
# NOTE(review): presumably default_name is what ends up in each
# configuration's names when name_field is None -- confirm.
configurations = load_data(
    file_path='/colabfit/data/gubaev/AlNiTi/train_1st_stage.cfg',
    default_name='train_1st_stage',
    **cfg_load_options,
)

# Second training stage, appended to the same collection.
configurations += load_data(
    file_path='/colabfit/data/gubaev/AlNiTi/train_2nd_stage.cfg',
    default_name='train_2nd_stage',
    **cfg_load_options,
)

In [None]:
# Register the 'energy-forces-stress' property definition with the database.
# Each field below declares its type, whether it carries units, its extent
# ([] for a scalar, [":", 3] for one 3-vector per particle, [3, 3] for a
# 3x3 tensor), whether it is required, and a human-readable description.
client.insert_property_definition({
    'property-id': 'energy-forces-stress',
    'property-title': 'Basic outputs from a static calculation',
    'property-description':
        'Energy, forces, and stresses from a calculation of a '
        'static configuration. Energies must be specified to be '
        'per-atom or supercell. If a reference energy has been '
        'used, this must be specified as well.',

    # --- Computed quantities (all optional) ---
    'energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': False,
        'description':
            'The potential energy of the system.'
    },
    'forces': {
        'type': 'float',
        'has-unit': True,
        'extent': [":", 3],
        'required': False,
        'description':
            'The [x,y,z] components of the force on each particle.'
    },
    'stress': {
        'type': 'float',
        'has-unit': True,
        'extent': [3, 3],
        'required': False,
        'description':
            'The full Cauchy stress tensor of the simulation cell'
    },

    # --- Metadata describing how "energy" should be interpreted ---
    'per-atom': {
        'type': 'bool',
        'has-unit': False,
        'extent': [],
        'required': True,
        # The previous wording stated that True meant the *total*, undivided
        # energy, which contradicted the field name. The description now
        # matches the name: True = per-atom energy, False = total energy.
        'description':
            'If True, "energy" is a per-atom energy, i.e. the total '
            'energy of the system has been divided by the number of '
            'atoms in the configuration. If False, "energy" is the '
            'total energy of the system.'
    },
    'reference-energy': {
        'type': 'float',
        'has-unit': True,
        'extent': [],
        'required': False,
        'description':
            'If provided, then "energy" is the energy (either of '
            'the whole system, or per-atom) LESS the energy of '
            'a reference configuration (E = E_0 - E_reference). '
            'Note that "reference-energy" is just provided for '
            'documentation, and that "energy" should already have '
            'this value subtracted off. The reference energy must '
            'have the same units as "energy".'
    },
})

In [None]:
# Maps each field of the 'energy-forces-stress' property onto the field of
# the loaded configurations that holds its value, plus the units of that value.
# Layout:  ColabFit name: {'field': ASE field name, 'units': str}
property_map = {
    'energy-forces-stress': [{
        'energy': dict(field='energy', units='eV'),
        'forces': dict(field='forces', units='eV/Ang'),
        # NOTE(review): the source field is called 'virial' but is labelled
        # as a stress in GPa -- confirm the loader really stores GPa here.
        'stress': dict(field='virial', units='GPa'),
        # Boolean flag (no units); this notebook fills it in via `tform`.
        'per-atom': dict(field='per-atom', units=None),

        # Calculation settings attached to every property built from this map.
        '_settings': dict(
            _method='VASP',
            _description='energy/forces/stresses',
            _files=None,
            _labels=None,
        ),
    }]
}

In [None]:
def tform(c):
    """Tag a configuration with ``info['per-atom'] = False``, in place.

    The property map reads the 'per-atom' field from each configuration, so
    this sets it before insertion.

    Parameters
    ----------
    c : object exposing a mutable mapping as ``c.info``.

    Returns
    -------
    None -- the configuration is modified in place.
    """
    per_atom_flag = False
    c.info['per-atom'] = per_atom_flag

In [None]:
# Insert every loaded configuration together with the properties described
# by `property_map`; `tform` is passed as the per-configuration transform.
inserted = client.insert_data(
    configurations,
    property_map=property_map,
    generator=False,
    transform=tform,
    verbose=True,
)
ids = list(inserted)

# Each entry of `ids` is unpacked as a (configuration ID, property ID) pair;
# transpose the pairs into two parallel tuples.
all_co_ids, all_pr_ids = zip(*ids)

In [None]:
# Number of distinct configuration IDs among the inserted pairs.
unique_co_ids = set(all_co_ids)
len(unique_co_ids)

In [None]:
# Number of distinct property IDs among the inserted pairs.
unique_pr_ids = set(all_pr_ids)
len(unique_pr_ids)

In [None]:
# Regex on configuration names -> description of the configuration set that
# collects every matching configuration. '.*' matches all names, so the first
# set spans everything inserted above.
cs_regexes = {
    '.*': (
        'Configurations generated using active learning by iteratively '
        'fitting a MTP model, identifying configurations that required the '
        'MTP to extrapolate, re-computing the energies/forces/structures of '
        'those configurations with DFT, then retraining the MTP model.'
    ),
    'train_1st_stage': 'Configurations used in the first stage of training',
    'train_2nd_stage': 'Configurations used in the second stage of training',
}

# IDs of the configuration sets, in the same order as `cs_regexes`.
cs_ids = []

for index, (pattern, description) in enumerate(cs_regexes.items()):
    # Restrict the search to configurations inserted by this notebook whose
    # names match the current pattern.
    name_query = {'_id': {'$in': all_co_ids}, 'names': {'$regex': pattern}}
    matches = client.get_data(
        'configurations',
        fields='_id',
        query=name_query,
        ravel=True,
    )
    co_ids = matches.tolist()

    # One summary line per set: index, right-aligned pattern, match count.
    pattern_label = f'({pattern}):'
    print(f'Configuration set {index}', f'{pattern_label:>22}', f'{len(co_ids):>7}')

    cs_id = client.insert_configuration_set(
        co_ids,
        description=description,
        verbose=True,
    )
    cs_ids.append(cs_id)

In [None]:
# Dataset-level metadata, kept separate from the call so it is easy to review.
dataset_authors = [
    'K. Gubaev',
    'E. V. Podryabinkin',
    'G. L. W. Hart',
    'A. V. Shapeev',
]

dataset_links = [
    'https://www.sciencedirect.com/science/article/pii/S0927025618306372?via%3Dihub',
    'https://gitlab.com/kgubaev/accelerating-high-throughput-searches-for-new-alloys-with-active-learning-data',
]

dataset_description = (
    'This dataset was generated using the following active '
    'learning scheme: 1) candidate structures relaxed by a partially-trained '
    'MTP model, 2) structures for which the MTP had to perform extrapolation '
    'are passed to DFT to be re-computed, 3) the MTP is retrained included '
    'the structures that were re-computed with DFT, 4) steps 1-3 are repeated '
    'until the MTP does not extrapolate on any of the original candidate '
    'structures. The original candidate structures for this dataset included '
    'about 375,000 binary and ternary structures enumerating all possible '
    'unit cells with different symmetries (BCC, FCC, and HCP) and different '
    'number of atoms'
)

# Bundle the configuration sets and all inserted properties into one dataset.
ds_id = client.insert_dataset(
    cs_ids=cs_ids,
    pr_ids=all_pr_ids,
    name='AlNiTi_CMS2019',
    authors=dataset_authors,
    links=dataset_links,
    description=dataset_description,
    resync=True,
    verbose=True,
)

# Show the new dataset's ID as the cell output.
ds_id

In [None]:
# Label every configuration inserted by this notebook as 'active_learning'.
inserted_configurations_query = {'_id': {'$in': all_co_ids}}

client.apply_labels(
    dataset_id=ds_id,
    collection_name='configurations',
    query=inserted_configurations_query,
    labels='active_learning',
    verbose=True,
)

In [None]:
# To inspect a previously built dataset instead of the one created above,
# assign its ID by hand before running this cell, e.g.:
#     ds_id = 'DS_930103518389_000'

# Labels were applied after the dataset was inserted, so the dataset is
# fetched with resync=True.
# NOTE(review): presumably resync refreshes the aggregated info -- confirm.
fetched = client.get_dataset(ds_id, resync=True, verbose=True)
dataset = fetched['dataset']

# Print each aggregated statistic as "<name> <value>".
for info_key, info_value in dataset.aggregated_info.items():
    print(info_key, info_value)

In [None]:
# Property fields recorded in the dataset's aggregated info.
property_fields = dataset.aggregated_info['property_fields']
property_fields

In [None]:
# Histogram every aggregated property field over the dataset's properties,
# using a log-scaled y axis. The result is bound to `_` so the cell does not
# echo the return value.
histogram_fields = dataset.aggregated_info['property_fields']
_ = client.plot_histograms(
    histogram_fields,
    ids=dataset.property_ids,
    method='matplotlib',
    yscale='log',
)