In [2]:
import numpy as np
import pandas as pd
import networkx as nx
import torch
import copy
import itertools

from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis import local_env

from networkx.algorithms.components import is_connected

from sklearn.metrics import accuracy_score, recall_score, precision_score

from torch_scatter import scatter

from p_tqdm import p_umap

In [3]:
def build_crystal(crystal_str, niggli=True, primitive=False):
    """Parse a CIF string and rebuild the structure from its lattice parameters.

    Args:
        crystal_str: CIF-formatted text describing one crystal.
        niggli: when true, replace the structure with its
            ``get_reduced_structure()`` result.
        primitive: when true, replace the structure with its
            ``get_primitive_structure()`` result (applied before the
            reduction step).

    Returns:
        A new ``Structure`` constructed from the processed structure's lattice
        parameters, species and fractional coordinates.
    """
    parsed = Structure.from_str(crystal_str, fmt='cif')

    if primitive:
        parsed = parsed.get_primitive_structure()
    if niggli:
        parsed = parsed.get_reduced_structure()

    # Re-create the lattice from (a, b, c, alpha, beta, gamma) only, so the
    # returned structure depends on nothing but the parameters and the
    # fractional coordinates.
    rebuilt_lattice = Lattice.from_parameters(*parsed.lattice.parameters)

    # A match with `parsed` is guaranteed because a CIF only carries lattice
    # parameters and fractional coordinates.
    # assert canonical_crystal.matches(crystal)
    return Structure(
        lattice=rebuilt_lattice,
        species=parsed.species,
        coords=parsed.frac_coords,
        coords_are_cartesian=False,
    )

In [4]:
def build_crystal_graph(crystal, graph_method='crystalnn'):
    """Flatten a crystal structure (and optionally its bond graph) into arrays.

    Args:
        crystal: structure exposing ``frac_coords``, ``atomic_numbers`` and
            ``lattice``.
        graph_method: ``'crystalnn'`` builds a ``StructureGraph`` with
            ``StructureGraph.with_local_env_strategy(crystal, CrystalNN)``;
            ``'none'`` skips graph construction and yields empty edge arrays.

    Returns:
        Tuple ``(frac_coords, atom_types, lengths, angles, edge_indices,
        to_jimages, num_atoms)``. Each edge ``(i, j, to_jimage)`` of the graph
        contributes two entries: ``[j, i]`` with ``to_jimage`` and ``[i, j]``
        with the negated ``to_jimage``.

    Raises:
        NotImplementedError: if ``graph_method`` is neither ``'crystalnn'``
            nor ``'none'``.
        AssertionError: if the lattice matrix does not agree with
            ``lattice_params_to_matrix`` applied to the lattice parameters.

    NOTE(review): ``CrystalNN`` and ``lattice_params_to_matrix`` are not
    defined in the cells shown here; they must already exist in the kernel
    namespace before this function is called.
    """
    with_edges = graph_method != 'none'
    if graph_method == 'crystalnn':
        crystal_graph = StructureGraph.with_local_env_strategy(
            crystal, CrystalNN)
    elif with_edges:
        raise NotImplementedError

    frac_coords = crystal.frac_coords
    atom_types = np.array(crystal.atomic_numbers)
    cell_params = crystal.lattice.parameters
    lengths, angles = cell_params[:3], cell_params[3:]

    # Sanity check: the stored matrix must be the one implied by the six
    # lattice parameters.
    expected_matrix = lattice_params_to_matrix(*lengths, *angles)
    assert np.allclose(crystal.lattice.matrix, expected_matrix)

    edge_indices, to_jimages = [], []
    if with_edges:
        for src, dst, image in crystal_graph.graph.edges(data='to_jimage'):
            # Emit both directions of every edge; the opposite direction
            # carries the negated periodic-image offset.
            edge_indices += [[dst, src], [src, dst]]
            to_jimages += [image, tuple(-offset for offset in image)]

    return (
        frac_coords,
        atom_types,
        np.array(lengths),
        np.array(angles),
        np.array(edge_indices),
        np.array(to_jimages),
        atom_types.shape[0],
    )

In [5]:
def preprocess(input_file, num_workers, niggli, primitive, graph_method,
               prop_list):
    """Build crystals and graph arrays for every row of a CSV file in parallel.

    Args:
        input_file: CSV path; rows must provide ``cif`` and ``material_id``.
        num_workers: passed to ``p_umap`` as ``num_cpus``.
        niggli: forwarded to ``build_crystal``.
        primitive: forwarded to ``build_crystal``.
        graph_method: forwarded to ``build_crystal_graph``.
        prop_list: column names to copy into each result when the row has them.

    Returns:
        List of dicts (keys ``mp_id``, ``cif``, ``graph_arrays`` plus the
        copied properties), in the same order as the CSV rows. Results are
        keyed by ``material_id`` for re-ordering, so rows sharing an id all
        map to the last result collected for that id.
    """
    df = pd.read_csv(input_file)
    n_rows = len(df)

    def process_one(row, niggli, primitive, graph_method, prop_list):
        cif_text = row['cif']
        structure = build_crystal(
            cif_text, niggli=niggli, primitive=primitive)
        arrays = build_crystal_graph(structure, graph_method)
        record = {
            'mp_id': row['material_id'],
            'cif': cif_text,
            'graph_arrays': arrays,
        }
        # Copy over only the requested properties that the row actually has.
        for prop in prop_list:
            if prop in row.keys():
                record[prop] = row[prop]
        return record

    def repeated(value):
        # p_umap zips its iterables, so constants are repeated once per row.
        return [value] * n_rows

    unordered_results = p_umap(
        process_one,
        [df.iloc[pos] for pos in range(n_rows)],
        repeated(niggli),
        repeated(primitive),
        repeated(graph_method),
        repeated(prop_list),
        num_cpus=num_workers)

    # p_umap returns results in completion order; restore the CSV row order.
    by_id = {}
    for record in unordered_results:
        by_id[record['mp_id']] = record

    return [by_id[mp_id] for mp_id in df['material_id']]

In [6]:
def preprocess_space_group(input_file):
    """Build a crystal for every row of a CSV file in parallel.

    Args:
        input_file: CSV path; rows must provide ``cif`` and ``material_id``.

    Returns:
        List of dicts with keys ``mp_id``, ``cif`` and ``crystal`` (the result
        of ``build_crystal`` with its default options), in the same order as
        the CSV rows.
    """
    df = pd.read_csv(input_file)

    def process_one(row):
        cif_text = row['cif']
        structure = build_crystal(cif_text)
        return {
            'mp_id': row['material_id'],
            'cif': cif_text,
            'crystal': structure,
        }

    rows = [df.iloc[pos] for pos in range(len(df))]
    unordered_results = p_umap(process_one, rows)

    # p_umap returns results in completion order; restore the CSV row order.
    by_id = {}
    for record in unordered_results:
        by_id[record['mp_id']] = record

    return [by_id[mp_id] for mp_id in df['material_id']]

In [5]:
# Load the mp_20 "val" CSV (path is relative to the notebook's working directory).
df = pd.read_csv("../data/mp_20/val.csv")

In [6]:
# Inspect the first row to see which columns the CSV provides.
df.iloc[0]

Unnamed: 0                                                                  65
material_id                                                          mp-865981
formation_energy_per_atom                                            -0.436368
band_gap                                                                   0.0
pretty_formula                                                         TmMgHg2
e_above_hull                                                               0.0
elements                                                    ['Hg', 'Mg', 'Tm']
cif                          # generated using pymatgen\ndata_TmMgHg2\n_sym...
spacegroup.number                                                          225
Name: 0, dtype: object

In [7]:
# Keep the CIF string of the first row for a quick parsing check.
ddd = df.iloc[0].cif

In [8]:
# Show the raw CIF text.
ddd

"# generated using pymatgen\ndata_TmMgHg2\n_symmetry_space_group_name_H-M   'P 1'\n_cell_length_a   5.04880040\n_cell_length_b   5.04880040\n_cell_length_c   5.04880040\n_cell_angle_alpha   60.00000000\n_cell_angle_beta   60.00000000\n_cell_angle_gamma   60.00000000\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   TmMgHg2\n_chemical_formula_sum   'Tm1 Mg1 Hg2'\n_cell_volume   91.00172128\n_cell_formula_units_Z   1\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  Tm  Tm0  1  0.00000000  0.00000000  0.00000000  1\n  Mg  Mg1  1  0.50000000  0.50000000  0.50000000  1\n  Hg  Hg2  1  0.25000000  0.25000000  0.25000000  1\n  Hg  Hg3  1  0.75000000  0.75000000  0.75000000  1\n"

In [9]:
# Parse the CIF with the reduction step enabled; `primitive` keeps its default (False).
cry = build_crystal(ddd,niggli=True)

In [10]:
# Display the summary of the rebuilt structure.
cry

Structure Summary
Lattice
    abc : 5.0488004 5.048800399999999 5.048800399999999
 angles : 59.99999999999999 60.00000000000001 60.00000000000001
 volume : 91.00172125826344
      A : 4.372389405037036 0.0 2.5244001999999996
      B : 1.457463135012344 4.122328264386535 2.5244002
      C : 0.0 0.0 5.048800399999999
    pbc : True True True
PeriodicSite: Tm (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]
PeriodicSite: Mg (2.915, 2.061, 5.049) [0.5, 0.5, 0.5]
PeriodicSite: Hg (4.372, 3.092, 7.573) [0.75, 0.75, 0.75]
PeriodicSite: Hg (1.457, 1.031, 2.524) [0.25, 0.25, 0.25]

In [11]:
# Group the rows of `df` by space-group number.
space_group = df.groupby('spacegroup.number')

In [19]:
# Number of rows in each space group of `df`.
print(space_group.size())

spacegroup.number
1       147
2       211
3         4
4        50
5        77
       ... 
223      12
224       2
225    1300
227     117
229      28
Length: 157, dtype: int64


In [26]:
# Load ../data/data_all/mp_20.csv; presumably the combined mp_20 data — TODO confirm.
df_all = pd.read_csv("../data/data_all/mp_20.csv")

In [33]:
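# Crystal systems and the space-group number range each one covers:
#   1: "Triclinic"      1-2
#   2: "Monoclinic"     3-15
#   3: "Orthorhombic"   16-74
#   4: "Tetragonal"     75-142
#   5: "Trigonal"       143-167
#   6: "Hexagonal"      168-194
#   7: "Cubic"          195-230
# Commented-out export of one crystal system to its own CSV. Note that the
# variable is named `mono`, but the range below (195-230) selects the cubic groups.
# mono = df_all[df_all["spacegroup.number"].isin(range(195,231))]
# mono.to_csv("../data/mp_20_processed/Cubic.csv")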

In [9]:
# Print the frame for a quick look at its columns and row count.
print(df_all)

       Unnamed: 0.1  Unnamed: 0 material_id  formation_energy_per_atom  \
0                 0        6000    mp-10009                  -0.575092   
1                 1       37702  mp-1218989                  -0.942488   
2                 2       42245  mp-1225695                   0.064863   
3                 3         780  mp-1220884                  -1.456116   
4                 4       35749  mp-1224266                   0.024139   
...             ...         ...         ...                        ...   
45224         27131       37856   mp-568116                  -0.988502   
45225         27132       11955   mp-865529                  -0.640955   
45226         27133       26119  mp-1189241                  -0.756019   
45227         27134       30556  mp-1104538                  -0.104870   
45228         27135       32933   mp-756354                  -3.712252   

       band_gap pretty_formula  e_above_hull                elements  \
0        0.8980           GaTe      0.0

In [27]:
# Group the rows of `df_all` by space-group number.
space_group_all = df_all.groupby('spacegroup.number')

In [28]:
# Number of rows in each space group of `df_all`.
print(space_group_all.size())

spacegroup.number
1       781
2      1099
3        52
4       248
5       427
       ... 
223      83
224       7
225    6714
227     499
229     152
Length: 177, dtype: int64


In [43]:
# Row count of the last (highest-numbered) space group present in `df_all`.
# The group is looked up explicitly rather than through a loop variable left
# over from another cell, so this cell also works on a fresh kernel run from
# top to bottom.
last_spg = df_all['spacegroup.number'].max()
len(space_group_all.get_group(last_spg))

152

In [44]:
# List every space group present in `df_all` together with its row count.
# Iterating a GroupBy yields (key, sub-frame) pairs, unpacked here by name.
for spg_num, members in space_group_all:
    print(spg_num, len(members))

1 781
2 1099
3 52
4 248
5 427
6 271
7 162
8 606
9 185
10 212
11 448
12 2141
13 100
14 927
15 848
16 1
17 1
18 12
19 74
20 35
21 32
22 3
23 7
24 22
25 209
26 135
28 3
29 19
30 2
31 211
33 74
34 7
35 28
36 284
37 4
38 616
39 11
40 59
41 17
42 71
43 60
44 305
46 83
47 143
49 1
51 110
52 7
53 14
54 2
55 84
56 3
57 52
58 97
59 172
60 32
61 10
62 1979
63 1504
64 138
65 437
66 32
67 22
68 4
69 80
70 136
71 702
72 193
74 500
75 2
76 4
77 4
79 12
81 8
82 218
83 15
84 20
85 11
86 27
87 184
88 131
89 1
90 3
92 18
96 5
97 42
98 3
99 132
100 5
102 20
105 4
107 264
108 13
109 80
111 24
112 5
113 30
114 7
115 178
116 8
118 12
119 210
120 1
121 170
122 213
123 1042
124 20
125 32
127 406
128 23
129 820
130 2
131 24
132 14
135 1
136 254
137 42
138 5
139 2111
140 465
141 313
142 2
143 14
144 1
145 1
146 121
147 63
148 602
149 12
150 31
152 31
153 5
154 16
155 116
156 519
157 9
159 6
160 491
161 75
162 69
163 78
164 879
166 1787
167 137
173 14
174 26
176 114
179 1
180 31
181 9
182 39
183 9
186 295
187 551

In [30]:
# One entry per crystal system, in id order (1..7): the highest space-group
# number belonging to the system, and the system's name.
_SYSTEM_TABLE = (
    (2, "Triclinic"),       # space groups 1-2
    (15, "Monoclinic"),     # 3-15
    (74, "Orthorhombic"),   # 16-74
    (142, "Tetragonal"),    # 75-142
    (167, "Trigonal"),      # 143-167
    (194, "Hexagonal"),     # 168-194
    (230, "Cubic"),         # 195-230
)

# Crystal-system id -> name.
Crystal_system_dict = {
    system_id: name
    for system_id, (_, name) in enumerate(_SYSTEM_TABLE, start=1)
}

# Lookup list indexed directly by space-group number (1..230); index 0 is a
# placeholder so that no offset is needed when indexing.
Crystal_system = [0]
for number in range(1, 231):
    Crystal_system.append(next(
        system_id
        for system_id, (upper, _) in enumerate(_SYSTEM_TABLE, start=1)
        if number <= upper
    ))

# Row counts per crystal system (ids 1..7) for mp_20:
# [1880, 6627, 8839, 7660, 5063, 4781, 10379]

In [31]:
# Accumulate the number of rows per crystal system (slot 0 = system id 1, ...).
cnt = [0] * 7
for spg_num, members in space_group_all:
    system_slot = Crystal_system[spg_num] - 1
    cnt[system_slot] += len(members)
print(cnt)

[1880, 6627, 8839, 7660, 5063, 4781, 10379]
