In [1]:
from tqdm import tqdm
from jarvis.db.figshare import data as jdata
from jarvis.core.atoms import Atoms,  pmg_to_atoms
import pandas as pd
from pymatgen.core import Structure
from pymatgen.io.vasp.inputs import Poscar
import sys
import ast
import json

In [2]:
# Fetch the JARVIS "dft_3d" dataset through the jarvis-tools figshare helper.
dft_3d = jdata("dft_3d")
# Record key read as the target value by the export loop further down.
prop = "optb88vdw_bandgap"

Obtaining 3D dataset 76k ...
Reference:https://www.nature.com/articles/s41524-020-00440-1
Other versions:https://doi.org/10.6084/m9.figshare.6815699
Loading the zipfile...
Loading completed.


In [3]:
# Number of records in the loaded dataset.
len(dft_3d)

75993

In [4]:
# Container type returned by jdata.
type(dft_3d)

list

In [None]:
# Inspect one raw record (the second entry) to see which fields it carries.
dft_3d[1]

In [None]:
# Export up to `max_samples` structures as POSCAR files, and index them in
# "id_prop.csv" as "<poscar file name>,<target value>" rows.
max_samples = 1
count = 0
# The context manager closes id_prop.csv even when a record fails to convert
# or write, which the previous bare open()/close() pair did not guarantee.
with open("id_prop.csv", "w") as f:
    for i in dft_3d:
        target = i[prop]
        # Records with no value for `prop` carry the string "na"; skip them
        # before doing any structure conversion.
        if target == "na":
            continue
        atoms = Atoms.from_dict(i["atoms"])
        # Show the exported structure so the sample can be eyeballed.
        print(type(atoms))
        print(atoms)
        jid = i["jid"]
        poscar_name = "POSCAR-" + jid + ".vasp"
        atoms.write_poscar(poscar_name)
        f.write("%s,%6f\n" % (poscar_name, target))
        count += 1
        if count == max_samples:
            break

In [4]:
# Load the surface-property table.
# NOTE(review): absolute, user-specific path; this cell only runs on a machine
# that has the file at exactly this location.
df = pd.read_csv('/home/mudaliar.k/github/comformer_uv/data/surface_prop_data_set_top_bottom.csv')

In [5]:
# Number of rows in the surface-property table.
len(df)

36852

In [None]:
# Keep the first row around; later cells inspect its 'slab' field.
row = df.iloc[0]
# Rich display of that same first row.
df.head(1)

In [None]:
# Column dtypes as parsed by read_csv.
df.dtypes

In [None]:
# Raw content of the 'slab' field for the first row.
row['slab']

In [None]:
# The 'slab' field is parsed with ast.literal_eval (literals only, no code
# execution) and the result is handed to pymatgen's Structure.from_dict.
struc = Structure.from_dict(ast.literal_eval(row['slab']))
type(struc)

In [None]:
# Convert the pymatgen structure with jarvis-tools' pmg_to_atoms helper and
# print the result for a visual check.
jarvis_atoms = pmg_to_atoms(struc)  
print(type(jarvis_atoms))
print(jarvis_atoms)

In [None]:
# Make the local comformer checkout importable for `comformer.graphs` below.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
sys.path.append('/home/mudaliar.k/github/comformer_uv')

# NOTE(review): `imp` is deprecated and was removed from the standard library
# in Python 3.12, so this line fails on newer interpreters. Nothing in the
# visible cells uses it.
import imp
import random
from pathlib import Path
from typing import Optional

import os
import torch
import numpy as np
import pandas as pd
from jarvis.core.atoms import Atoms
from comformer.graphs import PygGraph, PygStructureDataset
from jarvis.db.figshare import data as jdata
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from jarvis.db.jsonutils import dumpjson
from pandarallel import pandarallel
# Must run before any `parallel_apply` call (load_pyg_graphs relies on it).
pandarallel.initialize(progress_bar=True)
# from sklearn.pipeline import Pipeline
import pickle as pk
from sklearn.preprocessing import StandardScaler
# use pandas progress_apply



def load_dataset(
    name: str = "dft_3d",
    target=None,
    limit: Optional[int] = None,
    classification_threshold: Optional[float] = None,
):
    """Fetch a JARVIS dataset and return its usable records as a DataFrame.

    Args:
        name: Dataset identifier passed to ``jdata``.
        target: Record key holding the target value. Records whose value is
            the string "na" or a NaN are dropped.
        limit: If given, keep only the first ``limit`` surviving records.
        classification_threshold: If given, the target of each surviving
            record is overwritten in place with 0 (value <= threshold) or
            1 (value > threshold).

    Returns:
        A ``pandas.DataFrame`` built from the surviving records.

    Raises:
        ValueError: If a target value compares neither ``<=`` nor ``>`` to
            the threshold.
    """
    binarize = classification_threshold is not None
    kept = []
    for entry in jdata(name):
        value = entry[target]
        # Skip records that have no usable target value.
        if value == "na" or math.isnan(value):
            continue
        if binarize:
            if value <= classification_threshold:
                entry[target] = 0
            elif value > classification_threshold:
                entry[target] = 1
            else:
                raise ValueError(
                    "Check classification data type.",
                    value,
                    type(value),
                )
        kept.append(entry)

    # The limit applies after filtering, so it counts usable records only.
    if limit is not None:
        kept = kept[:limit]
    return pd.DataFrame(kept)

def load_dataset_D2R2(
    name: str = "D2R2_surface_data",
    data_path: str = "/home/mudaliar.k/github/comformer_uv/data/surface_database_for_GNN.csv",
    target=None,
    limit: Optional[int] = None,
    classification_threshold: Optional[float] = None,
):
    """Load the D2R2 surface table from a CSV file and add a ``jid`` column.

    Args:
        name: Unused here; kept so the signature parallels ``load_dataset``.
        data_path: CSV file to read. Malformed lines are skipped. The default
            is an absolute, machine-specific path, so pass it explicitly when
            running elsewhere.
        target: Unused here.
        limit: If given, keep only the first ``limit`` rows.
        classification_threshold: Unused here.

    Returns:
        A ``pandas.DataFrame`` with the CSV columns plus ``jid``, built as
        ``"<mpid>_<miller_index>_<term>"``.

    Raises:
        KeyError: If the CSV lacks an ``mpid``, ``miller_index`` or ``term``
            column.
    """
    df = pd.read_csv(data_path, on_bad_lines="skip")
    if limit is not None:
        # Take an explicit copy of the positional slice: assigning "jid" to a
        # bare slice triggers SettingWithCopyWarning on pandas without
        # copy-on-write.
        df = df.iloc[:limit].copy()

    df["jid"] = (
        df["mpid"].astype(str)
        + "_"
        + df["miller_index"].astype(str)
        + "_"
        + df["term"].astype(str)
    )

    return df


def mean_absolute_deviation(data, axis=None):
    """Return the mean absolute deviation of ``data`` around its mean.

    Args:
        data: Array-like of numbers.
        axis: Axis along which to reduce; ``None`` reduces over all elements.

    Returns:
        A scalar when ``axis`` is ``None``, otherwise an array with ``axis``
        removed.
    """
    values = np.asarray(data)
    # keepdims leaves the reduced axis in place with length 1 so the mean
    # broadcasts back against `values` for every axis. Without it, axis=1 on
    # a 2-D input either fails to broadcast or, for square input, silently
    # subtracts along the wrong axis.
    center = np.mean(values, axis=axis, keepdims=True)
    return np.mean(np.absolute(values - center), axis=axis)

def load_pyg_graphs(
    df: pd.DataFrame,
    name: str = "dft_3d",
    neighbor_strategy: str = "k-nearest",
    cutoff: float = 8,
    max_neighbors: int = 12,
    cachedir: Optional[Path] = None,
    use_canonize: bool = False,
    use_lattice: bool = False,
    use_angle: bool = False,
):
    """Build one crystal graph per row of ``df`` from its ``"slab"`` column.

    Each ``"slab"`` value is parsed with ``ast.literal_eval``, turned into a
    pymatgen ``Structure`` via ``Structure.from_dict``, converted with
    ``pmg_to_atoms`` and passed to ``PygGraph.atom_dgl_multigraph`` with
    ``atom_features="atomic_number"`` and ``compute_line_graph=False``.

    The rows are processed with ``parallel_apply``, so pandarallel must have
    been initialised before this function is called.

    Args:
        df: Frame with a ``"slab"`` column of string-encoded structure dicts.
        name: Accepted but not used by this function.
        neighbor_strategy: Forwarded to ``PygGraph.atom_dgl_multigraph``.
        cutoff: Forwarded to ``PygGraph.atom_dgl_multigraph``.
        max_neighbors: Forwarded to ``PygGraph.atom_dgl_multigraph``.
        cachedir: Accepted but not used by this function.
        use_canonize: Forwarded to ``PygGraph.atom_dgl_multigraph``.
        use_lattice: Forwarded to ``PygGraph.atom_dgl_multigraph``.
        use_angle: Forwarded to ``PygGraph.atom_dgl_multigraph``.

    Returns:
        The ``.values`` array of the resulting Series: one entry per row, each
        being whatever ``PygGraph.atom_dgl_multigraph`` returned.
    """
    # Gather the graph-construction settings once; they are the same for
    # every row.
    graph_options = dict(
        neighbor_strategy=neighbor_strategy,
        cutoff=cutoff,
        atom_features="atomic_number",
        max_neighbors=max_neighbors,
        compute_line_graph=False,
        use_canonize=use_canonize,
        use_lattice=use_lattice,
        use_angle=use_angle,
    )

    def slab_to_graph(slab_repr):
        """Turn one string-encoded slab into a graph."""
        structure_dict = ast.literal_eval(slab_repr)
        pmg_structure = Structure.from_dict(structure_dict)
        jarvis_structure = pmg_to_atoms(pmg_structure)
        return PygGraph.atom_dgl_multigraph(jarvis_structure, **graph_options)

    slab_column = df["slab"]
    return slab_column.parallel_apply(slab_to_graph).values

In [None]:
# Try load_dataset on at most 10 records that have a band-gap value.
# NOTE(review): this rebinds `df`, which an earlier cell used for the surface
# CSV; re-running earlier cells afterwards will see the new frame.
df = load_dataset(target="optb88vdw_bandgap", limit = 10)
df.head()

In [None]:
# Load the D2R2 surface table from the function's default data_path.
df_d2r2 = load_dataset_D2R2()
df_d2r2.head()

In [None]:
# Number of rows that survived CSV parsing (bad lines are skipped).
len(df_d2r2)

In [None]:
# Build a graph for every row of the D2R2 table with the default settings.
graphs = load_pyg_graphs(df_d2r2)

In [None]:
# NOTE(review): this displays the whole result array; consider showing only a
# small slice when sharing the notebook.
graphs