In [1]:
import torch 
import numpy as np
import pandas as pd
from pdos_gnn.models.crystal_model import ProDosNet
from pdos_gnn.utilities.preprocess import CrystalGraphPDOS
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Predict PDOS fingerprints for Materials Space 
We use the pretrained ProDosNet model to predict the projected density of states (PDOS) of every material in the [Materials Project Database](https://next-gen.materialsproject.org/). The predicted PDOS then serves as a materials fingerprint for building a structured materials space, in which compounds are grouped by the similarity of their electronic properties. ProDosNet outputs the electronic PDOS of every orbital of every atom in the unit cell of the input material (node-level predictions). Materials contain different numbers of atoms, so these raw predictions have different sizes. To obtain a fixed-size representation, we sum the PDOS over atoms and over the individual orbitals within each of the s, p and d groups, then normalize by the number of atoms. The result is a fixed-size PDOS fingerprint containing the s, p and d orbital electronic densities of the entire material (graph-level prediction).

In [2]:
class Predictor:
    """Run a pretrained ProDosNet model on crystal graphs and tabulate per-orbital PDOS.

    Parameters
    ----------
    model_path : str
        Path to a checkpoint saved with ``torch.save`` that contains a
        ``'state_dict'`` entry with ProDosNet weights.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        # These hyper-parameters must describe the same architecture the
        # checkpoint was saved from, otherwise ``load_state_dict`` fails.
        self.model = ProDosNet(orig_atom_fea_len=98, nbr_fea_len=4, n_conv=3, use_cdf=True)

        self._load_model_state()

    def _load_model_state(self):
        """Load the checkpoint weights from ``self.model_path`` onto the CPU."""
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        pretrained_model = torch.load(self.model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(pretrained_model['state_dict'])
        self.model.eval()

    def get_prediction(self, graph, include_target=False):
        """Predict the PDOS of every orbital in ``graph`` and return it as a table.

        Parameters
        ----------
        graph
            Crystal graph exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch``, ``atoms_batch``, ``elements``, ``sites``,
            ``orbital_types`` and ``material_id`` (plus ``target_pdos`` when
            ``include_target`` is True). The graph is not modified.
        include_target : bool, default False
            If True, append the reference PDOS columns after the predictions.

        Returns
        -------
        pandas.DataFrame
            One row per orbital. The columns are ``id``, ``element``,
            ``atom_number`` and ``orbital_type``, followed by integer-labelled
            predicted PDOS columns. When requested, the target PDOS columns
            come after those.
        """
        self.model.eval()
        # Squeeze into a local variable instead of overwriting graph.edge_attr,
        # so the caller's graph object is left untouched.
        edge_attr = torch.squeeze(graph.edge_attr, 1)
        # Inference only: avoid building an autograd graph for every material.
        with torch.no_grad():
            output_pdos, _, _ = self.model(
                graph.x, graph.edge_index, edge_attr, graph.batch, graph.atoms_batch
            )

        orbital_types = np.array(graph.orbital_types)
        material_ids = np.array([graph.material_id] * len(orbital_types))
        frames = [
            pd.DataFrame(material_ids),
            pd.DataFrame(np.array(graph.elements)),
            pd.DataFrame(np.array(graph.sites)),
            pd.DataFrame(orbital_types),
            pd.DataFrame(output_pdos.detach().cpu().numpy()),
        ]
        if include_target:
            frames.append(pd.DataFrame(graph.target_pdos.detach().cpu().numpy()))

        # ignore_index=True relabels the columns 0..N-1; the first four are metadata.
        output_and_id = pd.concat(frames, axis=1, ignore_index=True, sort=False)
        return output_and_id.rename(
            {0: 'id', 1: 'element', 2: 'atom_number', 3: 'orbital_type'}, axis='columns'
        )
    

### Function to combine orbital PDOS into spd PDOS

In [3]:
def get_spd_dos(predicted_data):
    """Collapse per-orbital PDOS rows into one concatenated s/p/d PDOS vector.

    Parameters
    ----------
    predicted_data : pandas.DataFrame
        Table shaped like the output of ``Predictor.get_prediction``. It has
        the metadata columns ``id``, ``element``, ``atom_number`` and
        ``orbital_type``, plus numeric PDOS columns (presumably one per
        energy grid point).

    Returns
    -------
    numpy.ndarray
        1-D array ``[s_total, p_total, d_total]``. Each part is the PDOS summed
        over all rows whose ``orbital_type`` belongs to that group, across all
        atoms. A group with no matching rows contributes zeros, so the length
        is always three times the number of PDOS columns.
    """
    # Orbital labels as emitted by the preprocessing step. "dx2" presumably
    # stands for d_{x^2-y^2} (pymatgen-style naming) -- confirm against
    # CrystalGraphPDOS.
    orbital_groups = (
        ("s",),
        ("px", "py", "pz"),
        ("dxy", "dyz", "dz2", "dxz", "dx2"),
    )
    # Metadata must not be summed into the fingerprint. atom_number in
    # particular can be numeric and would otherwise survive numeric_only=True.
    metadata_columns = ["id", "element", "atom_number", "orbital_type"]
    pdos_columns = predicted_data.drop(columns=metadata_columns, errors="ignore")

    group_totals = []
    for orbitals in orbital_groups:
        mask = predicted_data["orbital_type"].isin(orbitals)
        group_totals.append(pdos_columns.loc[mask].sum(numeric_only=True).to_numpy())
    return np.concatenate(group_totals)

### Get predicted PDOS
**Before running, make sure to download the material structures (CIF files) using `download_data.ipynb`.**

The predicted spd PDOS will be used to build a structured materials space and visualize it with UMAP in `materials_space_umap.ipynb`. It will also be used to search for compounds similar to a specific target material in `search_for_similar_materials.ipynb`.

In [10]:
# Predict a per-atom s/p/d PDOS fingerprint for every material listed in
# MATERIALS_CSV and save them as spd_pdos_<N>_materials.csv.
# Paths are relative to this notebook's directory.
CIF_DIR = "../data/cif_dir"
MATERIALS_CSV = "../data/materials_prop.csv"
MODEL_PATH = "../pdos_gnn/pretrained/pretrained_model.pth.tar"

id_list = pd.read_csv(MATERIALS_CSV)["material_id"].tolist()

graph_generator = CrystalGraphPDOS(cif_dir=CIF_DIR, dos_dir=None)
predictor = Predictor(model_path=MODEL_PATH)

predicted_id_list = []
spd_dos_list = []
failed_ids = {}  # material_id -> reason it was skipped
for material_id in tqdm(id_list):
    try:
        graph = graph_generator.get_crystal_pdos_graph_pred(f"{CIF_DIR}/{material_id}.cif")
    except Exception as e:
        # Best effort: report and skip structures that cannot be turned into a graph.
        print(material_id)
        print(e)
        failed_ids[material_id] = str(e)
        continue
    if graph is None:
        continue

    predicted_data = predictor.get_prediction(graph)
    # Normalize by the number of distinct atom_number values so the fingerprint is per atom.
    n_atoms = len(predicted_data["atom_number"].unique())
    if n_atoms == 0:
        # Dividing by zero would store NaN/inf rows in the fingerprint table.
        failed_ids[material_id] = "empty prediction"
        continue

    spd_dos_list.append(get_spd_dos(predicted_data) / n_atoms)
    predicted_id_list.append(material_id)

if not spd_dos_list:
    raise RuntimeError(
        f"No PDOS fingerprints were predicted for {len(id_list)} materials; "
        f"check that the CIF files exist in {CIF_DIR!r}."
    )

spd_dos_array = np.array(spd_dos_list)
spd_dos_df = pd.DataFrame(spd_dos_array, columns=range(spd_dos_array.shape[1]))
spd_dos_df.insert(0, "id", predicted_id_list)

spd_dos_df.to_csv(f'spd_pdos_{len(predicted_id_list)}_materials.csv')
    

	------------------------------------------------
        |        Data Preprocessing Parameters         |
        ------------------------------------------------
            - dos_dir:      None
            - cif_dir:      ../data/cif_dir
            - radius:       8
            - max_num_nbr:  12
            - sigma:        0.3
            - bound_low:    -20.0
            - bound_high:   10.0
            - grid:         256
            - max_element:  83
            - n_orbitals:   9
            - norm_pdos:    False
        ------------------------------------------------
        


100%|██████████| 3/3 [00:00<00:00, 108.01it/s]
