In [1]:
# Hand-picked export configurations:
#   export name -> dataset name -> tuple of transformation chains.
# Each chain is a tuple of config keys identifying one generated dataset; the
# same tuples are what export_data() passes as `transformations=`.
# NOTE(review): "WyFomer" looks like a typo of "WyFormer"; the key is kept
# verbatim in case other code looks it up by this exact string.
export_configs = {
    "WyFomer generated datasets": dict(
        mp_20=(
            ("WyckoffTransformer",),
            ("WyckoffTransformer", "DiffCSP++10k"),
            ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free", "DFT"),
            ("WyckoffTransformer", "DiffCSP++10k", "CHGNet_free", "DFT-GGA-relax-1"),
            ("WyckoffTransformer", "CrySPR", "CHGNet_fix"),
            ("WyckoffTransformer", "CrySPR", "CHGNet_fix", "DFT"),
            ("WyckoffTransformer", "DiffCSP++"),
            ("WyckoffTransformer", "DiffCSP++", "DFT"),
        ),
        mpts_52=(
            ("WyckoffTransformer",),
            ("WyckoffTransformer", "CrySPR", "CHGNet_fix"),
        ),
    ),
}

In [2]:
import sys
sys.path.append("../..")
from evaluation.generated_dataset import GeneratedDataset, DATA_KEYS

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [3]:
from pathlib import Path
from omegaconf import OmegaConf
from collections import defaultdict

# Registry of every generated dataset; it is pruned and rewritten below before
# being saved as the public config.
all_data = OmegaConf.load("../../generated/datasets.yaml")
# FlowMM results are left out of the public export.
del all_data["mp_20"]["FlowMM"]
# These data are not used and might contain errors
del all_data["carbon_24"], all_data["perov_5"]
# dataset name -> set of transformation tuples; populated by flatten_config.
export_configs["generated_public"] = defaultdict(set)
def flatten_config(dataset, config, prefix=()):
    """Walk one dataset's config tree, registering and rewriting data entries.

    Every key path from the dataset root to a key in ``DATA_KEYS`` is recorded
    as a transformation tuple in ``export_configs["generated_public"][dataset]``.
    The matching data entry is rewritten in place for the public release:
    - if it has a ``path``, the path is redirected to ``data.csv.gz`` in the
      same directory and ``storage_type`` is set to ``"monty"``;
    - ``cache_key`` and ``storage_key`` are removed when present.
    If a level holds both ``structures`` and ``wyckoffs``, ``wyckoffs`` is
    dropped so the data is not exported twice.

    Args:
        dataset: Top-level dataset name (e.g. ``"mp_20"``).
        config: Mapping for the current level of the tree; mutated in place.
        prefix: Keys walked so far from the dataset root. Any sequence is
            accepted. It defaults to an immutable tuple to avoid the
            shared-mutable-default pitfall.
    """
    for key, value in config.items():
        if key in DATA_KEYS:
            export_configs["generated_public"][dataset].add(tuple(prefix))
            if "path" in value:
                value["path"] = str(Path(value["path"]).parent / "data.csv.gz")
                value["storage_type"] = "monty"
            # Presumably internal cache/storage bookkeeping; not published.
            for private_field in ("cache_key", "storage_key"):
                if private_field in value:
                    del value[private_field]
        else:
            flatten_config(dataset, value, (*prefix, key))
    if "structures" in config and "wyckoffs" in config:
        # No need to export two times
        del config["wyckoffs"]

for dataset_name, dataset_config in all_data.items():
    flatten_config(dataset_name, dataset_config)

In [4]:
from pathlib import Path
from monty.json import MontyEncoder
encoder = MontyEncoder()
def to_json(obj):
    """Serialise a single DataFrame cell for CSV export.

    Strings are returned unchanged. Frozensets are first turned into tuples,
    presumably because the JSON encoder cannot serialise sets. Every other
    value is encoded with the module-level MontyEncoder instance ``encoder``.
    """
    if not isinstance(obj, str):
        payload = tuple(obj) if isinstance(obj, frozenset) else obj
        return encoder.encode(payload)
    return obj

In [None]:
from tqdm.auto import tqdm
from pickle import UnpicklingError
from scripts.cache_generated_datasets import compute_fields_and_cache
def export_data(export_path, export_config):
    """Write every configured processed dataset as a gzipped CSV.

    Args:
        export_path: Root directory of the export; created if missing.
        export_config: Mapping ``dataset name -> iterable of transformation
            tuples``. Each tuple is written to
            ``export_path/<dataset>/<t1>/<t2>/.../data.csv.gz``.

    For each entry the processed dataset is read with
    ``GeneratedDataset.from_cache``. If that raises FileNotFoundError or
    UnpicklingError, the dataset is rebuilt from its transformations and
    passed through ``compute_fields_and_cache``. Cell values are serialised
    with ``to_json``, and the index is written as ``material_id``.
    """
    export_path = Path(export_path)
    export_path.mkdir(parents=True, exist_ok=True)
    for dataset, transformation_tuples in tqdm(export_config.items()):
        for these_transformations in tqdm(transformation_tuples):
            dataset_path = export_path.joinpath(dataset, *these_transformations) / "data.csv.gz"
            # Report the target only once it is known. Printing it before the
            # assignment raised UnboundLocalError on the first iteration.
            print(f"Exporting {dataset_path}")
            dataset_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                dataset_processed = GeneratedDataset.from_cache(
                    transformations=these_transformations,
                    dataset=dataset)
            except (FileNotFoundError, UnpicklingError):
                # NOTE(review): the UnpicklingError suggests the cache is pickled;
                # only load caches from trusted locations.
                dataset_raw = GeneratedDataset.from_transformations(
                    transformations=these_transformations,
                    dataset=dataset)
                dataset_processed = compute_fields_and_cache(dataset_raw)
            # Work on a renamed copy rather than mutating the (possibly cached)
            # dataset's frame in place.
            data = dataset_processed.data
            last_step = these_transformations[-1] if these_transformations else ""
            if "CHGNet" in last_step:
                data = data.rename(columns={
                    "energy_per_atom": "chgnet_energy_per_atom",
                    "corrected_chgnet_ehull": "chgnet_e_above_hull_corrected",
                })
            elif "DFT" in last_step:
                data = data.rename(columns={
                    "e_above_hull_corrected": "dft_e_above_hull_corrected",
                    "e_uncorrected": "dft_e_uncorrected",
                    "e_corrected": "dft_e_corrected",
                })
            # Columns excluded from the export. filter() keeps only those that
            # are actually present, so drop() never sees a missing label.
            excluded_columns = data.filter(
                ["cdvae_crystal", "fingerprint", "composition", "naive_validity",
                 "spacegroup_number", "density"], axis=1).columns
            data.drop(columns=excluded_columns).map(to_json).to_csv(
                dataset_path, index_label="material_id")

In [None]:
# OmegaConf.save opens the target file directly and does not create parent
# directories. export_data only creates the export root afterwards, so create
# it first or the save fails on a fresh checkout.
Path("generated_public").mkdir(parents=True, exist_ok=True)
OmegaConf.save(all_data, "generated_public/datasets.yaml")
export_data("generated_public", export_configs["generated_public"])

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/57 [00:00<?, ?it/s]

Attempt 0 failed to convert structure Full Formula (K16 Fe8 S32)
Reduced Formula: K2FeS4
abc   :  10.788521  10.788521  10.788521
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (56)
  #  SP           a         b         c
---  ----  --------  --------  --------
  0  Fe    0.375     0.375     0.375
  1  Fe    0.125     0.625     0.125
  2  Fe    0.375     0.875     0.875
  3  Fe    0.125     0.125     0.625
  4  Fe    0.875     0.375     0.875
  5  Fe    0.625     0.625     0.625
  6  Fe    0.875     0.875     0.375
  7  Fe    0.625     0.125     0.125
  8  K     0         0         0
  9  K     0.75      0.25      0.5
 10  K     0.25      0.5       0.75
 11  K     0.5       0.75      0.25
 12  K     0         0.5       0.5
 13  K     0.75      0.75      0
 14  K     0.25      0         0.25
 15  K     0.5       0.25      0.75
 16  K     0.5       0         0.5
 17  K     0.25      0.25      0
 18  K     0.75      0.5       0.25
 19  K     0     



  0%|          | 0/10000 [00:00<?, ?it/s]

Exported generated_public/mp_20/DiffCSP++/data.csv.gz
Read 10000 CIFs


Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.
No structure parsed for section 1 in CIF.
Occupancy 2.0 exceeded tolerance.


Valid records: 9580


Attempt 0 failed to convert structure Full Formula (Zr3 U6 Ti6 O24)
Reduced Formula: ZrU2Ti2O8
abc   :   3.984417   3.984417  41.079651
angles:  90.000000  90.000000 120.000000
pbc   :       True       True       True
Sites (39)
  #  SP           a         b         c
---  ----  --------  --------  --------
  0  Zr    0.666667  0.333333  0.333333
  1  Zr    0.333333  0.666667  0.666667
  2  Zr    0         1         1
  3  U     0.666667  0.333333  0.086914
  4  U     0.666667  0.333333  0.579753
  5  U     0.333333  0.666667  0.420247
  6  U     0.333333  0.666667  0.913086
  7  U     0         1         0.75358
  8  U     0         1         0.24642
  9  Ti    0.666667  0.333333  0.875719
 10  Ti    0.666667  0.333333  0.790948
 11  Ti    0.333333  0.666667  0.209052
 12  Ti    0.333333  0.666667  0.124281
 13  Ti    0         1         0.542385
 14  Ti    0         1         0.457615
 15  O     0.666667  0.333333  0.710223
 16  O     0.666667  0.333333  0.956443
 17  O     0.333333 



  0%|          | 0/9580 [00:00<?, ?it/s]

Exported generated_public/mp_20/SymmCD/CHGNet_fix/data.csv.gz
Parsing CIFs...


spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.
Attempt 0 failed to convert structure Full Formula (La5 Al12)
Reduced Formula: La5Al12
abc   :   4.886029   7.298154  10.345625
angles:  92.352150  89.993080  89.851997
pbc   :       True       True       True
Sites (17)
  #  SP           a         b         c
---  ----  --------  --------  --------
  0  La    0.084383  0.861644  0.373639
  1  La    0.571402  0.262115  0.019464
  2  La    0.580098  0.73917   0.881362
  3  La    0.57886   0.265734  0.013849
  4  La    0.071925  0.381974  0.498743
  5  Al    0.577567  0.160058  0.371349
  6  Al    0.826148  0.066954  0.685959
  7  Al    0.075005  0.397448  0.841664
  8  Al    0.586849  0.511265  0.278165
  9  Al    0.576856  0.713101  0.540676
 10  Al    0.321668  0.065745  0.681167
 11  Al    0.585998  0.414096  0.703641
 12  Al    0.073589  0.2093