In [2]:
import pickle
import numpy as np

# Load the training records, validate them, compute per-zeolite HOA statistics
# and attach a per-zeolite normalized HOA ('norm_hoa') to every record.
#
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only point this at a file from a trusted source.
with open('combined_train.pickle', 'rb') as file:
    crystal_dict_list = pickle.load(file)
# The file is closed here; everything below only needs the in-memory records.

# Assert that each record in the list has an "angles" property with 3 entries
assert all('angles' in record for record in crystal_dict_list), \
    "at least one record has no 'angles' key"
assert all(len(record['angles']) == 3 for record in crystal_dict_list), \
    "at least one record does not have exactly 3 angles"

# Extract HOA and zeolite codes as parallel arrays (one element per record)
hoa = np.array([entry['hoa'] for entry in crystal_dict_list])
zeo_code = np.array([entry['zeolite_code'] for entry in crystal_dict_list])

# Find unique zeolite codes (np.unique also returns them sorted)
unique_zeo_codes = np.unique(zeo_code)

# Mean and standard deviation of the HOA for each zeolite type
mean_hoa_per_zeo_code = {}
std_hoa_per_zeo_code = {}
for code in unique_zeo_codes:
    # Boolean mask over the parallel arrays selects this type's HOA values
    # without re-scanning the whole record list for every code.
    hoa_values_for_code = hoa[zeo_code == code]
    mean_hoa_per_zeo_code[code] = np.mean(hoa_values_for_code)
    std_hoa_per_zeo_code[code] = np.std(hoa_values_for_code)

    print(code, mean_hoa_per_zeo_code[code], std_hoa_per_zeo_code[code])

# Make sure for each crystal we have at least 60% Si atoms.
# Anything else should indicate potential issues.
# NOTE(review): assumes 'atom_types' holds atomic numbers (13 = Al, 14 = Si)
# -- TODO confirm. Counting `== 14` gives the same value as the previous
# `(atoms - 13).mean()` whenever every entry is 13 or 14, and stays a real
# fraction if any other value shows up.
SI_ATOMIC_NUMBER = 14
MIN_SI_FRACTION = 0.6
si_fractions = np.array([
    np.mean(np.asarray(entry['atom_types']) == SI_ATOMIC_NUMBER)
    for entry in crystal_dict_list
])
# `~(x > t)` rather than `x <= t` so that NaN fractions are flagged as well.
low_si = ~(si_fractions > MIN_SI_FRACTION)
low_si_indices = np.flatnonzero(low_si)
# Report a summary instead of one line per crystal; this stays a report (not
# an assert) so the normalization below still runs when crystals are flagged.
print(f"{low_si.sum()} of {len(si_fractions)} crystals have an Si fraction "
      f"not above {MIN_SI_FRACTION}")
if len(si_fractions) > 0:
    print("lowest Si fraction:", si_fractions.min())

# Add normalized HOA (z-score within the record's own zeolite type)
for entry in crystal_dict_list:
    mean_hoa = mean_hoa_per_zeo_code[entry['zeolite_code']]
    std_hoa = std_hoa_per_zeo_code[entry['zeolite_code']]
    if std_hoa > 0:
        entry['norm_hoa'] = (entry['hoa'] - mean_hoa) / std_hoa
    else:
        # Zero spread (e.g. a type with a single record): every value equals
        # the mean, so use 0.0 instead of dividing by zero and storing nan/inf.
        entry['norm_hoa'] = 0.0

# Get all BEC zeolites from the dataset
# bec_zeolites = [entry for entry in crystal_dict_list if entry['zeolite_code'] == 'BEC']
# print(bec_zeolites)

BEC 36.12035214723927 5.207546641542933
CHA 35.43490182926829 4.046374741553484
DDRch1 37.89276666666667 4.306715302743948
DDRch2 40.20668737864078 4.25027433708935
ERI 37.02220609756098 4.253100323231868
FAU 33.67515588235294 3.4331279807629267
FAUch 32.401106875 4.187938209297057
FER 45.878861585365854 4.099577084886531
HEU 43.33665975609756 7.555871471825489
ITW 46.42235283018868 4.629900153004801
LTA 33.60020059523809 3.6691072456307574
LTL 31.714803048780485 2.7195185525630468
MEL 48.42311804511278 5.110353257742805
MELch 39.92625342465754 5.4316184147370095
MER 37.98797423312883 4.294771470335011
MFI 45.58538013605443 10.22711065082431
MOR 40.5782435 3.0957076050812278
MTW 44.28428242424242 6.362011254512175
NAT 45.265499999999996 3.2903534517723716
RHO 29.31394279835391 3.278250677822803
TON 50.30659677419355 5.876923297730619
TON2 47.05439310344827 5.466002888824894
TON3 48.96967741935484 5.753205118955906
TON4 51.114217741935484 5.642312667335579
TONch 49.81748518518518 8.3985

AssertionError: 

In [1]:
import pickle

# Collect every zeolite code found in the training records together with a
# 'lengths' value, then give each code an integer index.
with open('combined_train.pickle', 'rb') as file:
    data = pickle.load(file)

# Keys keep the order in which codes are first seen; when a code occurs in
# several records, the 'lengths' of the last such record is the one kept.
unique_zeolite_codes = {}
for record in data:
    unique_zeolite_codes[record["zeolite_code"]] = record['lengths']
n_codes = len(unique_zeolite_codes)
print(n_codes)


print(unique_zeolite_codes)
print(n_codes)

# Index each code by its position in the dict's insertion order.
unique_zeolite_codes_mapping = dict(zip(unique_zeolite_codes, range(n_codes)))
print(unique_zeolite_codes_mapping)

26
{'DDRch1': [13.795, 13.795, 40.75], 'DDRch2': [13.795, 13.795, 40.75], 'FAU': [24.345, 24.345, 24.345], 'FAUch': [24.345, 24.345, 24.345], 'ITW': [10.45, 8.954, 8.954], 'MEL': [20.27, 20.27, 13.459], 'MELch': [20.27, 20.27, 13.459], 'MFI': [20.09, 19.738, 13.142], 'MOR': [18.256, 20.534, 7.542], 'RHO': [15.031, 15.031, 15.031], 'TON': [14.1, 17.84, 5.25], 'TON2': [14.105, 17.842, 5.256], 'TON3': [14.105, 17.842, 5.256], 'TON4': [14.105, 17.842, 5.256], 'TONch': [14.105, 17.842, 5.256], 'BEC': [12.77, 12.77, 12.977], 'CHA': [13.675, 13.675, 14.767], 'ERI': [13.054, 13.054, 15.175], 'FER': [19.018, 14.303, 7.541], 'HEU': [17.523, 17.644, 7.401], 'LTA': [11.919, 11.919, 11.919], 'LTL': [18.126, 18.126, 7.567], 'MER': [14.012, 14.012, 9.954], 'MTW': [25.552, 5.256, 12.117], 'NAT': [13.85, 13.85, 6.42], 'YFI': [18.181, 31.841, 12.641]}
26
{'DDRch1': 0, 'DDRch2': 1, 'FAU': 2, 'FAUch': 3, 'ITW': 4, 'MEL': 5, 'MELch': 6, 'MFI': 7, 'MOR': 8, 'RHO': 9, 'TON': 10, 'TON2': 11, 'TON3': 12, 'TON4