In [None]:
from colabfit import SHORT_ID_STRING_NAME

from colabfit.tools.database import MongoDatabase, load_data
from colabfit.tools.property_settings import PropertySettings
from colabfit.tools.configuration import AtomicConfiguration

In [None]:
# Database handle used by every later cell. The first argument is presumably
# the database name and nprocs=1 presumably keeps work in a single process --
# confirm against the colabfit-tools documentation.
client = MongoDatabase(
    'colabfit',
    nprocs=1,
    configuration_type=AtomicConfiguration,
)

In [None]:
# Passed to load_data as the default configuration name.
name = 'CuPd_CMS2019'

# Loader options are gathered first so the call itself stays a one-liner.
# name_field=None together with default_name presumably means every
# configuration is labelled with `name` -- confirm against load_data's docs.
load_options = dict(
    file_path='../../../data/gubaev/CuPd/train.cfg',
    file_format='cfg',
    elements=['Cu', 'Pd'],
    name_field=None,
    default_name=name,
    verbose=True,
)

configurations = load_data(**load_options)

In [None]:
def _vasp_settings(description):
    """Build a fresh '_settings' dict for a VASP calculation.

    A new dict is returned on every call so the property definitions below
    never share (and cannot accidentally co-mutate) a settings object.
    """
    return {
        'method': 'VASP',
        'description': description,
        'files': None,
        'labels': None,
    }


# Maps each property name to a list of definitions. Every definition pairs a
# property key with the source field it is read from and that field's units,
# plus the '_settings' describing how the values were computed.
property_map = {
    'potential-energy': [{
        'energy': {'field': 'energy', 'units': 'eV'},
        'per-atom': {'value': False, 'units': None},
        # NOTE(review): lower-case 'static' differs from the two descriptions
        # below -- confirm whether the mismatch is intentional.
        '_settings': _vasp_settings('static calculation'),
    }],

    'atomic-forces': [{
        'forces': {'field': 'forces', 'units': 'eV/Ang'},
        '_settings': _vasp_settings('Static calculation'),
    }],

    'cauchy-stress': [{
        # NOTE(review): the stress is read from the 'virial' field and declared
        # in GPa -- confirm the units actually stored in the .cfg file.
        'stress': {'field': 'virial', 'units': 'GPa'},
        '_settings': _vasp_settings('Static calculation'),
    }],
}

In [None]:
# Insert the loaded configurations and their mapped properties. The result is
# materialized as a list of pairs so it can be split into two sequences below.
ids = list(client.insert_data(
    configurations,
    property_map=property_map,
    generator=False,
    verbose=True
))

# Without this guard an empty result fails in the unpacking below with an
# opaque "not enough values to unpack" ValueError; fail with the same exception
# type but a message that points at the likely cause.
if not ids:
    raise ValueError(
        'insert_data returned no ID pairs; check that the input file was '
        'loaded and that property_map matches the loaded fields'
    )

# Transpose [(co_id, pr_id), ...] into one tuple of configuration IDs and one
# tuple of property IDs.
all_co_ids, all_pr_ids = list(zip(*ids))

In [None]:
# Number of distinct configuration IDs among the inserted pairs.
len({*all_co_ids})

In [None]:
# Number of distinct property IDs among the inserted pairs.
len({*all_pr_ids})

In [None]:
# Each key is a regex matched against configuration names; each value is the
# description given to the configuration set built from the matches.
cs_regexes = {
    '.*': (
        'Configurations generated using active learning by iteratively '
        'fitting a MTP model, identifying configurations that required the MTP to '
        'extrapolate, re-computing the energies/forces/structures of those '
        'configurations with DFT, then retraining the MTP model.'
    ),
}

# IDs of the configuration sets created below, in insertion order.
cs_ids = []

for set_index, (pattern, set_description) in enumerate(cs_regexes.items()):
    # Only configurations inserted by this notebook whose names match the pattern.
    name_query = {
        SHORT_ID_STRING_NAME: {'$in': all_co_ids},
        'names': {'$regex': pattern},
    }
    matches = client.get_data(
        'configurations',
        fields=SHORT_ID_STRING_NAME,
        query=name_query,
        ravel=True
    )
    matched_co_ids = matches.tolist()

    # One summary line per set: index, right-aligned pattern, right-aligned count.
    pattern_label = f'({pattern}):'
    print(
        f'Configuration set {set_index}',
        f'{pattern_label:>22}',
        f'{len(matched_co_ids):>7}',
    )

    cs_ids.append(
        client.insert_configuration_set(matched_co_ids, description=set_description)
    )

In [None]:
# Dataset metadata is assembled first so the insert call below reads as a summary.
ds_authors = ['K. Gubaev', 'E. V. Podryabinkin', 'G. L. W. Hart', 'A. V. Shapeev']

ds_links = [
    'https://www.sciencedirect.com/science/article/pii/S0927025618306372?via%3Dihub',
    'https://gitlab.com/kgubaev/accelerating-high-throughput-searches-for-new-alloys-with-active-learning-data',
]

# Adjacent literals are concatenated by the parser into a single string.
ds_description = (
    'This dataset was generated using the following active '
    'learning scheme: 1) candidate structures relaxed by a partially-trained '
    'MTP model, 2) structures for which the MTP had to perform extrapolation '
    'are passed to DFT to be re-computed, 3) the MTP is retrained included '
    'the structures that were re-computed with DFT, 4) steps 1-3 are repeated '
    'until the MTP does not extrapolate on any of the original candidate '
    'structures. The original candidate structures for this dataset included '
    '40,000 unrelaxed configurations with BCC, FCC, and HCP lattices.'
)

# Register the dataset from the configuration sets and property IDs gathered above.
ds_id = client.insert_dataset(
    cs_ids=cs_ids,
    pr_ids=all_pr_ids,
    name='CuPd_CMS2019',
    authors=ds_authors,
    links=ds_links,
    description=ds_description,
    resync=True,
    verbose=True,
)

# Trailing bare expression so the notebook displays the new dataset ID.
ds_id

In [None]:
# Select every configuration inserted by this notebook via its short ID.
inserted_co_query = {SHORT_ID_STRING_NAME: {'$in': all_co_ids}}

# Tag the selected configurations of this dataset with 'active_learning'.
client.apply_labels(
    dataset_id=ds_id,
    collection_name='configurations',
    query=inserted_co_query,
    labels='active_learning',
    verbose=True,
)

In [None]:
# To inspect a dataset from an earlier session instead, assign its ID string
# to ds_id before running this cell, e.g. ds_id = 'DS_124291675899_000'.
dataset_record = client.get_dataset(ds_id, resync=True, verbose=True)
dataset = dataset_record['dataset']

# Print each aggregated-info entry on its own line.
for info_key, info_value in dataset.aggregated_info.items():
    print(info_key, info_value)

In [None]:
# Histogram the dataset's aggregated property fields, restricted to this
# dataset's property IDs, using matplotlib.
histogram_fields = dataset.aggregated_info['property_fields']

client.plot_histograms(
    histogram_fields,
    ids=dataset.property_ids,
    method='matplotlib',
)