In [None]:
import json
import os
from pathlib import Path

import numpy as np
from flatten_dict import flatten, unflatten
from mpcontribs.client import Client, Attachment
from pymatgen.core import Structure

In [None]:
# Connect to MPContribs for the "ferroelectrics" project
# (later cells read `client.project` to look up existing contribution IDs).
project_name = "ferroelectrics"
client = Client(project=project_name)

In [None]:
# client.update_project(update={
#     "references": [
#         {"label": "SciData", "url": "https://doi.org/10.1038/s41597-020-0407-9"},
#         {"label": "PyMatGen", "url": "https://github.com/materialsproject/pymatgen/tree/master/pymatgen/analysis/ferroelectricity"},
#         {"label": "Atomate", "url": "https://github.com/hackingmaterials/atomate/blob/master/atomate/vasp/workflows/base/ferroelectric.py"},
#         {"label": "Website", "url": "https://blondegeek.github.io/ferroelectric_search_site/"},
#         {"label": "Figshare", "url": "https://dx.doi.org/10.6084/m9.figshare.6025634"}
#     ]
# })

In [None]:
# Directory holding the downloaded dataset (it appears to be the extracted Figshare
# dataset 10.6084/m9.figshare.6025634). Set FERROELECTRICS_DATADIR to point at your
# own copy instead of editing this machine-specific default path.
datadir = Path(os.environ.get("FERROELECTRICS_DATADIR", "/Users/patrick/Downloads/6025634"))
distortions_file = datadir / "distortions.json"
workflow_data_file = datadir / "workflow_data.json"

# JSON text is UTF-8 by specification; read it explicitly as UTF-8 rather than
# relying on the platform's locale encoding (the default of `open()`).
distortions = json.loads(distortions_file.read_text(encoding="utf-8"))
workflow_data = json.loads(workflow_data_file.read_text(encoding="utf-8"))

In [None]:
# Raw field name (key after flatten(..., reducer="dot")) -> MPContribs column config:
#   "name"   -> dotted column path used in the contribution data;
#   "unit"   -> unit appended to formatted numeric values; a falsy unit ("" or None)
#               means the raw value is stored unformatted (None presumably marks a
#               non-numeric column for init_columns — confirm);
#   "fields" -> list: sub-columns sharing "unit"; dict: sub-column -> its own unit.
# Commented-out entries are columns that are currently not exported.
columns = dict([
    ("bilbao_nonpolar_spacegroup", {"name": "bilbao.spacegroup.nonpolar", "unit": ""}),
    ("bilbao_polar_spacegroup", {"name": "bilbao.spacegroup.polar", "unit": ""}),
    # polarization and energy
    ("polarization_change_norm", {"name": "polarization.norm", "unit": "µC/cm²"}),
    ("polarization_change", {"name": "polarization.vector", "unit": "µC/cm²", "fields": ["a", "b", "c"]}),
    ("polarization_quanta", {"name": "polarization.quanta", "unit": "µC/cm²", "fields": ["a", "b", "c"]}),
    ("energies", {"name": "energy|diff", "unit": "eV"}),
    # workflow bookkeeping
    ("search_id", {"name": "workflow.id|search", "unit": ""}),
    ("workflow_status", {"name": "workflow.status", "unit": None}),
    ("category", {"name": "workflow.category", "unit": None}),  # set per workflow by get_category
    # distortion
    ("distortion.dmax", {"name": "distortion.dmax.before", "unit": "Å"}),
    ("calculated_max_distance", {"name": "distortion.dmax.after", "unit": "Å"}),
    # ("distortion.delta", {"name": "distortion.delta", "unit": ""}),
    # ("distortion.dav", {"name": "distortion.dav", "unit": ""}),
    # ("distortion.s", {"name": "distortion.s", "unit": ""}),
    ("bandgaps", {"name": "bandgap", "unit": "eV"}),
    # nonpolar endpoint
    # ("nonpolar_band_gap", {"name": "nonpolar.bandgap", "unit": "eV"}),
    ("nonpolar_icsd", {"name": "nonpolar.icsd", "unit": ""}),
    ("nonpolar_id", {"name": "nonpolar.mpid", "unit": None}),
    ("nonpolar_spacegroup", {"name": "nonpolar.spacegroup", "unit": ""}),
    # polar endpoint
    # ("polar_band_gap", {"name": "polar.bandgap", "unit": "eV"}),
    ("polar_icsd", {"name": "polar.icsd", "unit": ""}),
    ("polar_id", {"name": "polar.mpid", "unit": None}),
    ("polar_spacegroup", {"name": "polar.spacegroup", "unit": ""}),
    # smoothness diagnostics
    ("energies_per_atom_max_spline_jumps", {"name": "energies.jumps|max", "unit": "eV/atom"}),
    ("energies_per_atom_smoothness", {"name": "energies.smoothness", "unit": "eV/atom"}),
    ("polarization_max_spline_jumps", {"name": "polarizations.jumps", "fields": {"max": "µC/cm²", "index": ""}}),
    ("polarization_smoothness", {"name": "polarizations.smoothness", "fields": {"max": "µC/cm²", "index": ""}}),
])

In [None]:
def get_category(wf):
    """Classify one ferroelectric-search workflow record.

    Parameters
    ----------
    wf : dict
        A workflow record. ``polarization_len`` is always read; ``static_len``,
        ``workflow_status``, ``polarization_max_spline_jumps`` and
        ``energies_per_atom_max_spline_jumps`` are read only when the earlier
        checks require them (a missing key raises ``KeyError`` only then).

    Returns
    -------
    str or None
        ``"smooth"``: 10 polarization points, every polarization spline jump
        <= 1 and energy-per-atom spline jump <= 1e-2.
        ``"unsmooth"``: 10 polarization points and a polarization norm, but an
        energy jump > 1e-2 or some polarization jump > 1.
        ``"static"``: 10 static calculations, no polarization norm, and status
        COMPLETED or DEFUSED.
        ``"incomplete"``: fewer than 10 polarization points (or no norm) and the
        workflow FIZZLED/RUNNING, or DEFUSED with fewer than 10 statics.
        ``None``: none of the above matched.
    """
    has_pol_jumps = 'polarization_max_spline_jumps' in wf
    has_pol_norm = 'polarization_change_norm' in wf

    if wf['polarization_len'] == 10 and has_pol_jumps:
        pol_jumps = np.array(wf['polarization_max_spline_jumps'])
        # Both tests are spelled out (not as complements of each other) so NaN
        # values fail both, exactly like the original pair of conditions.
        if np.all(pol_jumps <= 1) and wf['energies_per_atom_max_spline_jumps'] <= 1e-2:
            return "smooth"
        if has_pol_norm and (wf['energies_per_atom_max_spline_jumps'] > 1e-2
                             or np.any(pol_jumps > 1)):
            return "unsmooth"

    if (wf['static_len'] == 10
            and not has_pol_norm
            and wf['workflow_status'] in ("COMPLETED", "DEFUSED")):
        return "static"

    if (wf['polarization_len'] < 10 or not has_pol_norm) and (
            (wf['workflow_status'] == "DEFUSED" and wf['static_len'] < 10)
            or wf['workflow_status'] in ("FIZZLED", "RUNNING")):
        return "incomplete"

    return None

In [None]:
# Collect the exported scalar columns of every distortion, keyed by
# "<nonpolar_id>_<polar_id>" so each workflow can look up its distortion later.
# (Structure and attachment uploads are currently disabled.)
contribs_distortions = {}

for distortion in distortions:
    key = "{}_{}".format(distortion["nonpolar_id"], distortion["polar_id"])
    record = {}
    contribs_distortions[key] = {"data": record}

    flat_items = flatten(distortion, reducer="dot", max_flatten_depth=2).items()
    for field, value in flat_items:
        # Skip "*_pre" fields, "_id*" fields, and anything still nested after flattening.
        if field.endswith("_pre") or field.startswith("_id"):
            continue
        if isinstance(value, (dict, list)):
            continue
        conf = columns.get(field)
        if not conf:
            continue  # not an exported column
        name, unit = conf["name"], conf["unit"]
        if not unit:
            record[name] = value  # "" / None unit: store the raw value
        else:
            dec = conf.get('dec', '')  # optional format spec; '' formats like str()
            record[name] = f"{float(value):{dec}} {unit}"

In [None]:
# Map each identifier already stored in the project to its contribution ID, so
# that workflows which were uploaded before can carry their ID in the next cell.
existing = client.get_all_ids(fmt="map")[client.project]
ids = {}
for identifier, entry in existing.items():
    ids[identifier] = entry["id"]
len(ids)

In [None]:
# Build one MPContribs contribution per workflow. Each starts from the scalar
# columns of the matching distortion (keyed "<nonpolar_id>_<polar_id>") and adds
# the workflow's own fields, formatted according to `columns`.
# Structure and attachment uploads are currently disabled.
contributions = []
# Raw structure entries in the workflow records; skipped when collecting data.
structure_keys = ("orig_nonpolar_structure", "orig_polar_structure")

for wf in workflow_data:
    key = f"{wf['nonpolar_id']}_{wf['polar_id']}"  # NOTE could also use search_id for this
    # Copy instead of aliasing: writing into contribs_distortions[key]["data"]
    # leaked this workflow's fields into every other workflow with the same
    # nonpolar/polar pair and made re-running this cell non-idempotent.
    # A shallow copy suffices because the distortion cell stores only scalars.
    data = dict(contribs_distortions[key]["data"])
    data["workflow.category"] = get_category(wf)
    contrib = {"identifier": wf["wfid"], "formula": wf["pretty_formula"], "data": data}
    if wf["wfid"] in ids:
        # Already in the project: pass its existing contribution ID along.
        contrib["id"] = ids[wf["wfid"]]

    for k, v in flatten(wf, reducer="dot").items():
        conf = columns.get(k)
        if conf and k.startswith("polarization") and isinstance(v, list):
            name = conf["name"]
            data.setdefault(name, {})
            if "unit" not in conf:
                # Keep only the largest entry and its position in the list.
                vmax, unit = max(v), conf["fields"]["max"]
                data[name]["max"] = f"{round(vmax, 3)} {unit}" if unit else vmax
                data[name]["index"] = v.index(vmax)
            else:
                # v[0] holds the per-component values, labelled by conf["fields"].
                unit = conf["unit"]
                data[name] = {i: f"{j} {unit}" for i, j in zip(conf["fields"], v[0])}
        elif conf and k == "energies" and isinstance(v, list):
            # `columns` registers this under "energies"; the previous check for
            # "energies_per_atom" had no config, so "energy|diff" was never set.
            # Presumably first = nonpolar and last = polar endpoint, as for
            # bandgaps below — TODO confirm against the dataset description.
            if len(v) >= 2 and v[0] is not None and v[-1] is not None:
                data[conf["name"]] = f"{v[0] - v[-1]:.3g} {conf['unit']}"
        elif conf and k == "bandgaps":
            name, unit = conf["name"], conf["unit"]
            data.setdefault(name, {})
            data[name]["nonpolar"] = f"{v[0]:.3g} {unit}"
            data[name]["polar"] = f"{v[-1]:.3g} {unit}"
        elif k.startswith(("_id", "cid")) or isinstance(v, list) or k.startswith(structure_keys):
            continue
        elif conf:
            name, unit = conf["name"], conf["unit"]
            if not unit:
                data[name] = v
            else:
                # One decimal place for the polarization norm (the former `.1g`
                # kept a single significant figure, e.g. 45.3 -> "5e+01");
                # three significant figures for everything else.
                spec = ".1f" if name == "polarization.norm" else ".3g"
                data[name] = f"{v:{spec}} {unit}"

    contrib["data"] = unflatten(data, splitter="dot")
    contributions.append(contrib)

len(contributions)

In [None]:
# Flatten `columns` into the {column path: unit} mapping passed to init_columns.
# List-valued "fields" share the parent's unit; dict-valued "fields" carry their own.
columns_map = {}

for spec in columns.values():
    base = spec["name"]
    if "fields" not in spec:
        if "unit" in spec:
            columns_map[base] = spec["unit"]
        continue
    fields = spec["fields"]
    if isinstance(fields, list):
        columns_map.update((f"{base}.{field}", spec["unit"]) for field in fields)
    elif isinstance(fields, dict):
        columns_map.update((f"{base}.{field}", field_unit) for field, field_unit in fields.items())

columns_map

In [None]:
# Optional full reset (left commented out): presumably deletes every existing
# contribution in the project — uncomment only when rebuilding from scratch.
#client.delete_contributions()
# NOTE(review): init_columns({}) presumably clears the current column
# definitions before they are re-declared from `columns_map` — confirm against
# the mpcontribs-client documentation.
client.init_columns({})
client.init_columns(columns_map)

In [None]:
# Upload all contributions; those that carry an "id" (taken from `ids`) refer to
# contributions already in the project.
# NOTE(review): ignore_dupes=True presumably tolerates duplicate entries instead
# of failing the submission — confirm in the mpcontribs-client docs.
client.submit_contributions(contributions, ignore_dupes=True)
# Declare the columns again after the upload.
# NOTE(review): presumably needed so the column metadata reflects the newly
# submitted data — confirm this second call is required.
client.init_columns(columns_map)

In [None]:
# client.make_public()