In [2]:
from pymatgen.analysis.diffraction.tem import TEMCalculator
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from mp_api.client import MPRester

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math as m
from decimal import Decimal, ROUND_HALF_UP

import time
import warnings
import os
from joblib import dump, load

import preprocessing as pp

In [14]:
def get_crystals_from_file(filename: str, api_key: str) -> tuple:
    """
    Retrieve crystal structures and symmetry labels for the material IDs listed in a file.

    All IDs are looked up in a single Materials Project query via ``MPRester``.

    Args:
        filename (str): Path to a .txt file containing Materials Project material IDs
            (e.g. ``mp-149``), one per line. Surrounding whitespace is stripped and
            blank lines are ignored.
        api_key (str): API key for accessing the Materials Project database.

    Returns:
        tuple: ``(crystal_list, space_group_list, bravais_list, system_list, material_id_list)``:
            crystal_list (list): Crystal structures (each document's ``structure`` field).
            space_group_list (list): Space-group numbers (``symmetry.number``).
            bravais_list (list): Lattice-centering letter, i.e. the first character of the
                space-group symbol: 'P' (primitive), 'I' (body centered), 'F' (face centered),
                'A'/'C' (base centered) or 'R' (rhombohedral).
            system_list (list): Crystal system names as strings, e.g. 'Cubic', 'Hexagonal'.
            material_id_list (list): ``(material_id, formula_pretty)`` tuples.

        The five lists are index-aligned with each other. Their order is whatever the API
        returns; this function does not reorder results to match the order of ``filename``.
        If the file contains no IDs, five empty lists are returned and no query is made.
    """
    with open(filename, 'r') as file:
        # Drop blank lines (e.g. a trailing newline at EOF) so no empty ID reaches the API.
        material_ids = [line.strip() for line in file if line.strip()]

    # NOTE(review): an empty `material_ids` list may be treated by the client as
    # "no filter" rather than "no results", so short-circuit instead of querying.
    if not material_ids:
        return [], [], [], [], []

    with MPRester(api_key=api_key) as mpr:
        crystals = mpr.materials.search(
            material_ids=material_ids,
            fields=['structure', 'symmetry', 'material_id', 'formula_pretty', 'chemsys'],
        )

    crystal_list = [crystal.structure for crystal in crystals]
    space_group_list = [crystal.symmetry.number for crystal in crystals]
    # The leading letter of the symbol encodes the lattice centering (P, I, F, A, C, R).
    bravais_list = [str(crystal.symmetry.symbol)[0] for crystal in crystals]
    system_list = [str(crystal.symmetry.crystal_system) for crystal in crystals]
    material_id_list = [(crystal.material_id, crystal.formula_pretty) for crystal in crystals]

    return (
        crystal_list,
        space_group_list,
        bravais_list,
        system_list,
        material_id_list,
    )

In [15]:
# key = os.environ["MP_API_KEY"]
# upper = 2
# local_filename = 'test.txt'
# path = '/Users/jonathanchoi/Desktop/GitHub Projects/crystal_sim/raw_data'


# filename = os.path.join(path, local_filename) 

# crystal_list, space_group_list, bravais_list, system_list, material_id_list = get_crystals_from_file(filename=filename, api_key=key)
# directions = pp.get_cartesian_beam_directions(upper)
# directions = np.delete(directions, 3, axis=0)


# features, labels_regression, labels_classification_space, labels_classification_bravais, labels_classification_system, materials = pp.get_preprocessed_data(
#     crystal_list, space_group_list, bravais_list, system_list, material_id_list, directions, plot=False, vectors=True)

In [16]:
# for i in range(0, 1):
#     print(        
#         features[i],
#         labels_regression[i],
#         labels_classification_space[i],
#         labels_classification_bravais[i],
#         labels_classification_system[i],
#         material_id_list[i])

In [6]:
# Materials Project API key. Read it from the environment instead of hardcoding it:
# a literal key in a notebook leaks through version control and shared copies.
# Set it before launching Jupyter, e.g.  export MP_API_KEY="<your key>"
# NOTE: any key previously committed to this notebook should be treated as
# compromised and regenerated on the Materials Project dashboard.
key = os.environ.get("MP_API_KEY")
if not key:
    raise RuntimeError(
        "MP_API_KEY is not set; export your Materials Project API key before running this notebook."
    )

In [18]:
# def retrieve_crystals(api_key: str, crystal_system: str, base_dir: str, write: bool) -> tuple:
#     """
#     Retrieves ALL mp-ids that correspond to the crystal system.
    
#     Args:
#         api_key (str): API key for accessing the Materials Project database.
#         crystal_system (str): Crystal system of interest.
#         base_dir (str): Base directory for the project.
#         write (bool): If True, all the ids are written to mp-ids/crystal_system.

#     Returns:
#         bravais_dict: Dictionary with keys as Bravais lattice types and values as corresponding mp-ids.
#         least_datapoints: Least number of datapoints of all the values in the dictionary.
#         api_key: API key.
#     """
    
#     bravais_dict = {
#         'P': [],  # primitive
#         'I': [],  # body centered
#         'A': [],  # base centered on A
#         'F': [],  # face centered
#         'C': [],  # centered on C
#         'R': [],  # rhombohedral
#     }

#     conversion_dict = {
#         'P': ['Primitive'],
#         'I': ['Body Centered'],
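#         # NOTE(review): the In [7] output prints "F: Centered on A", so pp.retrieve_crystals
#         # likely carries the same swapped A/F labels — fix it there too.
#         'A': ['Centered on A'],
#         'F': ['Face Centered'],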
#         'C': ['Centered on C'],
#         'R': ['Rhombohedral']
#     }

#     with MPRester(api_key=api_key) as mpr:
#         crystals = mpr.materials.search(
#             elements=["Si", "O"],
#             crystal_system=[crystal_system.capitalize()],
#             fields =['material_id', 'symmetry']
#         )
    
#     for crystal in crystals:
#         bravais_dict[str(crystal.symmetry.symbol)[0]].append(crystal.material_id)
    
#     non_empty_keys = [key for key in list(bravais_dict.keys()) if len(bravais_dict[key]) > 0]    
    
#     sizes = []
    
#     if write:
#         new_folder = os.path.join(base_dir, "mp_ids")
        
#         if not os.path.exists(new_folder):
#             os.makedirs(new_folder)
        
#         file_dir = os.path.join(new_folder, crystal_system+".txt")
#         file = open(file_dir, 'w')
            
#         for non_empty_key in non_empty_keys:
#             length_space_group = len(bravais_dict[non_empty_key])
#             sizes.append(length_space_group)
#             file.write(f"~{non_empty_key}-{length_space_group}\n")
            
#             for value in list(bravais_dict[non_empty_key]):
#                 file.write(f"{value}\n")
#         file.close()
    
#     bravais_dict = {k: v for k, v in bravais_dict.items() if k in non_empty_keys}
    
#     print(f"_________________Summary: {crystal_system} System_________________\n")
#     for bravais in non_empty_keys:
#         print(f"{bravais}: {conversion_dict[bravais][0]}")
#     for x, y in zip(non_empty_keys, sizes):
#         print(f"{x}: number of datapoints - {y}")
#     least_datapoints = min(sizes)
#     print(f"\nSmallest number of sample points: {least_datapoints}")
#     print(f"\nTotal number of samples: {sum(sizes)}\n")
    
#     return bravais_dict, least_datapoints, api_key


#     ## TODO 
#     # - implement the download crystal function with the read crystals functions
#     # - figure out how to sample equal numbers of the bravais lattice types 
#     # - consider how to deal with multiple lattice systems rather than just one

In [19]:
# def get_crystal_info(bravais_dict: str, min_size: int, api_key: str) -> tuple:
#     """
#     Retrieves crystal structure properties, crystal system classifications, 
#     and material IDs from a dictionary of mp-ids

#     Args:
#         bravais_dict (dict): See above documentation.
#         min_size (int): See above documentation.
#         api_key (str): See above documentation.

#     Returns:
#         crystal_list (list): A list of crystal structures.
#         space_group_list (list): A list of spacegroup numbers.
#         bravais_list (list): A list of Bravais Lattice types, which can be one of the following:
#                 'Simple (P)', 'Body centered (I)', 'Face centered (F)'
#         system_list (list): A list of crystal system classifications, which can be one of the following:
#                 'Triclinic', 'Trigonal', 'Orthorhombic', 'Cubic', 'Monoclinic', 'Tetragonal', 'Hexagonal'.
#         material_id_list (list): A list of tuples, where each tuple contains a material ID and its formula, e.g., [(id_1, formula_1), ... , (id_n, formula_n)].
#     """
#     crystal_list, space_group_list, bravais_list, system_list, material_id_list = [], [], [], [], []
    
#     for k in bravais_dict.keys():
#         shuffled_array = np.array(bravais_dict[k])
#         np.random.shuffle(shuffled_array)
#         shuffle_list = shuffled_array.tolist()
        
#         with MPRester(api_key=api_key) as mpr:
#             crystals = mpr.materials.search(
#                 material_ids=shuffle_list[:min_size],
#                 fields=['structure', 'symmetry',
#                         'material_id', 'formula_pretty', 'chemsys']
#             )
#             # print(crystals)

#         crystal_list += [crystal.structure for crystal in crystals]
#         space_group_list += [crystal.symmetry.number for crystal in crystals]
#         bravais_list += [str(crystal.symmetry.symbol)[0] for crystal in crystals]
#         system_list += [str(crystal.symmetry.crystal_system)
#                     for crystal in crystals]
#         material_id_list += [(crystal.material_id, crystal.formula_pretty)
#                             for crystal in crystals]

#     return (
#         crystal_list,
#         space_group_list,
#         bravais_list,
#         system_list,
#         material_id_list
#     )


In [7]:
# Project root passed to retrieve_crystals (with write=True it stores the retrieved
# mp-ids there). Override with the CRYSTAL_SIM_DIR environment variable rather than
# editing this machine-specific default.
BASE_DIR = os.environ.get(
    "CRYSTAL_SIM_DIR", "/Users/jonathanchoi/Desktop/GitHub Projects/crystal_sim"
)
CRYSTAL_SYSTEM = "Cubic"
SEED = 42

# Seed NumPy's global RNG so the per-lattice subsample drawn by get_crystal_info is
# the same on every Restart & Run All.
# NOTE(review): assumes pp.get_crystal_info shuffles with np.random (as the
# commented-out reference implementation in this notebook does) — confirm.
np.random.seed(SEED)

bravais_dict, least_datapoints, api_key = pp.retrieve_crystals(key, CRYSTAL_SYSTEM, BASE_DIR, True)
crystal_list, space_group_list, bravais_list, system_list, material_id_list = pp.get_crystal_info(
    bravais_dict, least_datapoints, api_key
)

Retrieving MaterialsDoc documents: 100%|██████████| 260/260 [00:00<00:00, 1857783.71it/s]
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


_________________Summary:           Cubic System_________________

P: Primitive
I: Body Centered
F: Centered on A
P: number of datapoints - 100
I: number of datapoints - 128
F: number of datapoints - 32

Smallest number of sample points: 32

Total number of samples: 260



Retrieving MaterialsDoc documents: 100%|██████████| 32/32 [00:00<00:00, 745654.04it/s]
Retrieving MaterialsDoc documents: 100%|██████████| 32/32 [00:00<00:00, 741534.41it/s]
Retrieving MaterialsDoc documents: 100%|██████████| 32/32 [00:00<00:00, 741534.41it/s]


In [8]:
# Sanity check on the sampled data: the first ten space-group numbers and crystal
# systems (one list per line), then how many primitive ('P') lattices were drawn.
print(space_group_list[:10], system_list[:10], sep="\n")
print(bravais_list.count("P"))

[221, 221, 218, 218, 221, 221, 212, 221, 221, 221]
['Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic', 'Cubic']
32
