# **Imports**

In [37]:
from mp_api.client import MPRester
import csv
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import re
from collections import Counter
from IPython.display import display
import ast

In [82]:
pd.set_option('display.max_colwidth', None)

# **Data Download**

In [2]:
MP_API_KEY = "" # Give your key here. I have deleted it on purpose.

In [3]:
mpr = MPRester(MP_API_KEY)

In [4]:
results = mpr.materials.insertion_electrodes.search()

Retrieving InsertionElectrodeDoc documents:   0%|          | 0/5604 [00:00<?, ?it/s]

In [5]:
results_list = [result.dict() for result in results]
if results_list:
        keys = results_list[0].keys()

        # Save to a CSV file
        with open('insertion_electrodes_data_hash.csv', 'w', newline='') as output_file:
            dict_writer = csv.DictWriter(output_file, delimiter='#', fieldnames=keys)
            dict_writer.writeheader()
            dict_writer.writerows(results_list)

# **Dataset exploration**

In [161]:
df = pd.read_csv('insertion_electrodes_data_hash.csv',delimiter='#')

In [7]:
df.shape

(5604, 35)

In [8]:
# Displaying the first row in the specified format
first_row = df.iloc[0]  # Selects the first row
formatted_string = '\n'.join([f"{index}: {value}" for index, value in first_row.items()])
print(formatted_string)

battery_type: insertion
battery_id: mp-1022721_Al
thermo_type: nan
battery_formula: Al1-3Cu
working_ion: Al
num_steps: 1
max_voltage_step: 0.0
last_updated: 2025-02-05 18:56:15.307000+00:00
framework: {'Cu': 1.0}
framework_formula: Cu
elements: ['Cu']
nelements: 1
chemsys: Cu
formula_anonymous: A
formula_charge: AlCu
formula_discharge: Al3Cu
max_delta_volume: 1.243652762037662
average_voltage: -0.013976348999999
capacity_grav: 1112.936545960678
capacity_vol: 4418.979753458019
energy_grav: -15.554789581200644
energy_vol: -61.76120325826192
fracA_charge: 0.5
fracA_discharge: 0.75
stability_charge: 0.074061238
stability_discharge: 0.096245834999999
id_charge: mp-1022721
id_discharge: mp-1183161
host_structure: {'@module': 'pymatgen.core.structure', '@class': 'Structure', 'charge': 0, 'lattice': {'matrix': [[2.997509, 0.0, 0.0], [0.0, 2.997509, 0.0], [0.0, 0.0, 2.997509]], 'pbc': [True, True, True], 'a': 2.997509, 'b': 2.997509, 'c': 2.997509, 'alpha': 90.0, 'beta': 90.0, 'gamma': 90.0, 'v

In [49]:
# Displaying the first row in the specified format
first_row = df.iloc[5603]  # Selects the first row
formatted_string = '\n'.join([f"{index}: {value}" for index, value in first_row.items()])
print(formatted_string)

battery_type: insertion
battery_id: mp-997084_Li
thermo_type: nan
battery_formula: Li0-1RbAgO2
working_ion: Li
num_steps: 1
max_voltage_step: 0.0
last_updated: 2025-02-05 19:25:13.214000+00:00
framework: {'Rb': 1.0, 'Ag': 1.0, 'O': 2.0}
framework_formula: RbAgO2
elements: ['Rb', 'Ag', 'O']
nelements: 3
chemsys: Ag-O-Rb
formula_anonymous: ABC2
formula_charge: RbAgO2
formula_discharge: RbLiAgO2
max_delta_volume: 0.072649154894966
average_voltage: 1.255989316666666
capacity_grav: 115.38645500462766
capacity_vol: 558.5858568598221
energy_grav: 144.92415477385134
energy_vol: 701.5778686570324
fracA_charge: 0.0
fracA_discharge: 0.2
stability_charge: 0.048474679999999
stability_discharge: 0.206693272749999
id_charge: mp-997084
id_discharge: mp-1236284
host_structure: {'@module': 'pymatgen.core.structure', '@class': 'Structure', 'charge': 0, 'lattice': {'matrix': [[3.10698, 0.0, 0.0], [0.0, 3.94405, 0.0], [0.0, 0.0, 6.06149]], 'pbc': [True, True, True], 'a': 3.10698, 'b': 3.94405, 'c': 6.06149

In [60]:
df['host_structure'][344]

"{'@module': 'pymatgen.core.structure', '@class': 'Structure', 'charge': 0, 'lattice': {'matrix': [[-6.278e-05, 4.8132646, 4.8132909], [4.81326783, -6.604000000000001e-05, 4.81329416], [4.81319534, 4.8131964, 5.42e-06]], 'pbc': [True, True, True], 'a': 6.807002673833565, 'b': 6.80700726298748, 'c': 6.806886877714367, 'alpha': 60.000510736661774, 'beta': 60.00048094410407, 'gamma': 60.00070447368375, 'volume': 223.02381350935656}, 'properties': {}, 'sites': [{'species': [{'element': 'Mn', 'occu': 1}], 'abc': [0.625, 0.625, 0.125], 'properties': {}, 'label': 'Mn', 'xyz': [3.6099025737500003, 3.60989865, 6.01661634]}, {'species': [{'element': 'Mn', 'occu': 1}], 'abc': [0.625, 0.125, 0.625], 'properties': {}, 'label': 'Mn', 'xyz': [3.60986632875, 6.016529869999999, 3.6099719699999997]}, {'species': [{'element': 'Mn', 'occu': 1}], 'abc': [0.125, 0.625, 0.625], 'properties': {}, 'label': 'Mn', 'xyz': [6.0165316337500006, 3.6098645499999997, 3.6099735999999996]}, {'species': [{'element': 'Mn'

In [9]:
df.columns

Index(['battery_type', 'battery_id', 'thermo_type', 'battery_formula',
       'working_ion', 'num_steps', 'max_voltage_step', 'last_updated',
       'framework', 'framework_formula', 'elements', 'nelements', 'chemsys',
       'max_delta_volume', 'average_voltage', 'capacity_grav', 'capacity_vol',
       'energy_grav', 'energy_vol', 'fracA_charge', 'fracA_discharge',
       'stability_charge', 'stability_discharge', 'id_charge', 'id_discharge',
       'host_structure', 'adj_pairs', 'material_ids',
       'entries_composition_summary', 'electrode_object',
       'fields_not_requested'],
      dtype='object')

In [14]:
unique_working_ions = df["working_ion"].unique()

In [15]:
print(f"Unique working ions: {unique_working_ions}")
print(f"Unique working ions count: {len(unique_working_ions)}")

Unique working ions: ['Al' 'Li' 'Cs' 'Na' 'K' 'Ca' 'Rb' 'Y' 'Zn' 'Mg']
Unique working ions count: 10


# **Small dataset - creation and preprocessing**

## Creation and exploration

In [54]:
df.head(200).to_csv('first_20_rows_hash.csv', index=False, sep='#')

In [4]:
df_small = pd.read_csv('first_20_rows_hash.csv',delimiter='#')

In [56]:
df_small.shape

(200, 35)

In [57]:
# Displaying the first row in the specified format
first_row = df_small.iloc[0]  # Selects the first row
formatted_string = '\n'.join([f"{index}: {value}" for index, value in first_row.items()])
print(formatted_string)

battery_type: insertion
battery_id: mp-1022721_Al
thermo_type: nan
battery_formula: Al1-3Cu
working_ion: Al
num_steps: 1
max_voltage_step: 0.0
last_updated: 2025-02-05 18:56:15.307000+00:00
framework: {'Cu': 1.0}
framework_formula: Cu
elements: ['Cu']
nelements: 1
chemsys: Cu
formula_anonymous: A
formula_charge: AlCu
formula_discharge: Al3Cu
max_delta_volume: 1.243652762037662
average_voltage: -0.013976348999999
capacity_grav: 1112.936545960678
capacity_vol: 4418.979753458019
energy_grav: -15.554789581200644
energy_vol: -61.76120325826192
fracA_charge: 0.5
fracA_discharge: 0.75
stability_charge: 0.074061238
stability_discharge: 0.096245834999999
id_charge: mp-1022721
id_discharge: mp-1183161
host_structure: {'@module': 'pymatgen.core.structure', '@class': 'Structure', 'charge': 0, 'lattice': {'matrix': [[2.997509, 0.0, 0.0], [0.0, 2.997509, 0.0], [0.0, 0.0, 2.997509]], 'pbc': [True, True, True], 'a': 2.997509, 'b': 2.997509, 'c': 2.997509, 'alpha': 90.0, 'beta': 90.0, 'gamma': 90.0, 'v

In [None]:
# Displaying the first row in the specified format
first_row = df.iloc[0]  # Selects the first row
formatted_string = '\n'.join([f"{index}: {value}" for index, value in first_row.items()])
print(formatted_string)

## Processing

### working_ion

In [107]:
df_wi = pd.get_dummies(df_small['working_ion'], prefix='working_ion')

In [108]:
df_wi.shape

(200, 10)

In [109]:
df_wi.columns

Index(['working_ion_Al', 'working_ion_Ca', 'working_ion_Cs', 'working_ion_K',
       'working_ion_Li', 'working_ion_Mg', 'working_ion_Na', 'working_ion_Rb',
       'working_ion_Y', 'working_ion_Zn'],
      dtype='object')

In [110]:
df_wi

Unnamed: 0,working_ion_Al,working_ion_Ca,working_ion_Cs,working_ion_K,working_ion_Li,working_ion_Mg,working_ion_Na,working_ion_Rb,working_ion_Y,working_ion_Zn
0,True,False,False,False,False,False,False,False,False,False
1,True,False,False,False,False,False,False,False,False,False
2,True,False,False,False,False,False,False,False,False,False
3,True,False,False,False,False,False,False,False,False,False
4,True,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...
195,False,False,False,False,False,True,False,False,False,False
196,False,False,False,False,False,True,False,False,False,False
197,False,False,False,False,False,True,False,False,False,False
198,False,False,False,False,False,True,False,False,False,False


### formula_anonymous

In [111]:
df_fa = df_small[['formula_anonymous']]

In [112]:
df_fa.shape

(200, 1)

In [113]:
df_fa.iloc[198] = "A3B8C16"
df_fa.iloc[199] = "A2B3C12"


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_fa.iloc[198] = "A3B8C16"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_fa.iloc[199] = "A2B3C12"


In [114]:
df_fa

Unnamed: 0,formula_anonymous
0,A
1,A
2,A
3,A
4,A
...,...
195,A
196,A
197,A
198,A3B8C16


In [115]:
df_fa['formula_anonymous'].unique()

array(['A', 'A3B8C16', 'A2B3C12'], dtype=object)

In [116]:
unique_formulas = df_fa['formula_anonymous'].unique()

In [117]:
unique_formulas

array(['A', 'A3B8C16', 'A2B3C12'], dtype=object)

In [118]:
# Function to extract elements from a formula
def extract_elements(formula):
    elements = re.findall(r'[A-Z][a-z]*', formula)
    return elements

In [119]:
# Extract elements from all formulas
all_elements = set()  # Using a set to avoid duplicates
for formula in unique_formulas:
    all_elements.update(extract_elements(formula))

In [120]:
# Convert set to a sorted list 
predetermined_elements = sorted(all_elements)

In [121]:
print("Predetermined Elements:", predetermined_elements)

Predetermined Elements: ['A', 'B', 'C']


In [122]:
# Function to parse formula and extract counts
def parse_formula(formula):
    element_counts = {el: 0 for el in predetermined_elements}  # Initialize with zeroes
    matches = re.findall(r'([A-Z][a-z]*)(\d*)', formula)  # Extract elements and counts
    for element, count in matches:
        if element in element_counts:
            element_counts[element] = int(count) if count else 1  # Default to 1 if no count
    return list(element_counts.values())

In [123]:
# Apply function to each formula
encoded_features = np.array([parse_formula(f) for f in df_fa['formula_anonymous']])


In [124]:
# Convert to DataFrame
df_fa_2 = pd.DataFrame(encoded_features, columns=[f'FA_{el}' for el in predetermined_elements])

In [125]:
df_fa_2

Unnamed: 0,FA_A,FA_B,FA_C
0,1,0,0
1,1,0,0
2,1,0,0
3,1,0,0
4,1,0,0
...,...,...,...
195,1,0,0
196,1,0,0
197,1,0,0
198,3,8,16


### formula_charge & formula discharge

In [126]:
df_small['formula_charge']

0       AlCu
1         Mn
2         Mo
3         Re
4      Al10V
       ...  
195       Cr
196    CsMg7
197     MgCd
198     MgSn
199        C
Name: formula_charge, Length: 200, dtype: object

In [127]:
def enhanced_parse_chemical_formula(formula):
    """
    Enhanced parser for chemical formulas to handle elements, parentheses, and multipliers.
    Correctly accounts for cases like 'TiBi(PO4)3', identifying 3 Ps and 12 Os.
    """
    # Split formula into parts by parentheses, considering potential multipliers
    parts = re.split(r'(\([A-Za-z0-9]*\)\d*)', formula)
    element_counts = Counter()

    for part in parts:
        if part.startswith('('):
            # Extract content inside parentheses and the multiplier
            content, multiplier = re.match(r'\(([A-Za-z0-9]*)\)(\d+)', part).groups()
            multiplier = int(multiplier)
            # Parse the content inside parentheses
            sub_counts = Counter({element: int(count) * multiplier if count else multiplier 
                                  for element, count in re.findall(r'([A-Z][a-z]*)(\d*)', content)})
            element_counts.update(sub_counts)
        else:
            # Handle parts of the formula outside parentheses
            sub_counts = Counter({element: int(count) if count else 1 
                                  for element, count in re.findall(r'([A-Z][a-z]*)(\d*)', part)})
            element_counts.update(sub_counts)

    return element_counts

In [128]:
# Test the enhanced parser with a complex formula example
test_formula = "TiBi(PO4)3"
enhanced_parse_chemical_formula(test_formula)

Counter({'O': 12, 'P': 3, 'Ti': 1, 'Bi': 1})

In [129]:
# Test the enhanced parser with a complex formula example
test_formula = "Al10V"
enhanced_parse_chemical_formula(test_formula)

Counter({'Al': 10, 'V': 1})

In [130]:
# Test the enhanced parser with a complex formula example
test_formula = "NaCl"
enhanced_parse_chemical_formula(test_formula)

Counter({'Na': 1, 'Cl': 1})

In [131]:
# Test the enhanced parser with a complex formula example
test_formula = "NO"
enhanced_parse_chemical_formula(test_formula)

Counter({'N': 1, 'O': 1})

In [132]:
# Apply the function to both columns
df_charge_parsed = df_small['formula_charge'].apply(enhanced_parse_chemical_formula)
df_dicharge_parsed = df_small['formula_discharge'].apply(enhanced_parse_chemical_formula)


In [133]:
# Convert dictionaries to DataFrames
df_charge_df = pd.DataFrame(df_charge_parsed.tolist()).fillna(0)  # Convert to DF, fill NaN with 0
df_dicharge_df = pd.DataFrame(df_dicharge_parsed.tolist()).fillna(0)

In [134]:
# Rename columns to indicate charge/discharge
df_charge_df.columns = [f'charge_{col}' for col in df_charge_df.columns]
df_dicharge_df.columns = [f'dicharge_{col}' for col in df_dicharge_df.columns]

In [135]:
# Merge the processed features into a single DataFrame
df_fc_dc = pd.concat([df_small[['formula_charge', 'formula_discharge']], df_charge_df, df_dicharge_df], axis=1)


In [136]:
!pip install ace_tools



In [138]:
display(df_fc_dc)

Unnamed: 0,formula_charge,formula_discharge,charge_Al,charge_Cu,charge_Mn,charge_Mo,charge_Re,charge_V,charge_Sb,charge_Fe,...,dicharge_Sm,dicharge_Si,dicharge_O,dicharge_Ba,dicharge_Sr,dicharge_Ga,dicharge_Rh,dicharge_In,dicharge_Os,dicharge_Y
0,AlCu,Al3Cu,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Mn,MnAl12,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Mo,Al12Mo,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Re,Al12Re,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Al10V,Al41V4,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,Cr,Mg3Cr,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
196,CsMg7,CsMg149,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
197,MgCd,Mg3Cd,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
198,MgSn,Mg3Sn,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [139]:
df_fc_dc.columns

Index(['formula_charge', 'formula_discharge', 'charge_Al', 'charge_Cu',
       'charge_Mn', 'charge_Mo', 'charge_Re', 'charge_V', 'charge_Sb',
       'charge_Fe',
       ...
       'dicharge_Sm', 'dicharge_Si', 'dicharge_O', 'dicharge_Ba',
       'dicharge_Sr', 'dicharge_Ga', 'dicharge_Rh', 'dicharge_In',
       'dicharge_Os', 'dicharge_Y'],
      dtype='object', length=108)

In [140]:
df_fc_dc = df_fc_dc.drop(['formula_charge', 'formula_discharge'], axis=1)

In [141]:
display(df_fc_dc)

Unnamed: 0,charge_Al,charge_Cu,charge_Mn,charge_Mo,charge_Re,charge_V,charge_Sb,charge_Fe,charge_W,charge_Sn,...,dicharge_Sm,dicharge_Si,dicharge_O,dicharge_Ba,dicharge_Sr,dicharge_Ga,dicharge_Rh,dicharge_In,dicharge_Os,dicharge_Y
0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
196,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
197,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
198,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [142]:
df_fc_dc.columns

Index(['charge_Al', 'charge_Cu', 'charge_Mn', 'charge_Mo', 'charge_Re',
       'charge_V', 'charge_Sb', 'charge_Fe', 'charge_W', 'charge_Sn',
       ...
       'dicharge_Sm', 'dicharge_Si', 'dicharge_O', 'dicharge_Ba',
       'dicharge_Sr', 'dicharge_Ga', 'dicharge_Rh', 'dicharge_In',
       'dicharge_Os', 'dicharge_Y'],
      dtype='object', length=106)

### host structure

#### Lattice

In [143]:
def extract_structure_features(struct_dict):
    try:
        lattice = struct_dict['lattice']
        # Scalar lattice parameters
        features = {
            'a': lattice.get('a', 0),
            'b': lattice.get('b', 0),
            'c': lattice.get('c', 0),
            'alpha': lattice.get('alpha', 0),
            'beta': lattice.get('beta', 0),
            'gamma': lattice.get('gamma', 0),
            'volume': lattice.get('volume', 0),
        }
        # Flatten lattice matrix
        matrix_flat = [val for row in lattice['matrix'] for val in row]
        matrix_features = {f'm{i+1}': val for i, val in enumerate(matrix_flat)}
        features.update(matrix_features)
        return features
    except Exception as e:
        # Return NaNs if something goes wrong
        return {key: float('nan') for key in ['a', 'b', 'c', 'alpha', 'beta', 'gamma', 'volume'] + [f'm{i+1}' for i in range(9)]}


In [144]:
# Parse the structure strings into dicts on-the-fly
structure_dicts = df_small['host_structure'].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

In [145]:
# Now apply your existing feature extractor on this parsed series
structure_features_list = structure_dicts.apply(extract_structure_features)

In [146]:

# Convert list of dicts to DataFrame
df_lattice = pd.DataFrame(structure_features_list.tolist())

In [147]:
df_lattice

Unnamed: 0,a,b,c,alpha,beta,gamma,volume,m1,m2,m3,m4,m5,m6,m7,m8,m9
0,2.997509,2.997509,2.997509,90.000000,90.000000,90.000000,26.932799,2.997509,0.000000,0.000000,0.000000,2.997509,0.000000,0.000000,0.000000,2.997509
1,2.413221,2.413221,2.413221,109.471221,109.471221,109.471221,10.818557,-1.393274,-1.393274,1.393274,-1.393274,1.393274,-1.393274,1.393274,-1.393274,-1.393274
2,2.743238,2.743238,2.743238,109.471229,109.471231,109.471218,15.891628,2.586349,0.000000,-0.914413,-1.293175,2.239844,-0.914413,0.000000,0.000000,2.743238
3,2.703704,2.703704,2.703704,109.471221,109.471221,109.471221,15.214418,-1.560984,1.560984,1.560984,1.560984,-1.560984,1.560984,1.560984,1.560984,-1.560984
4,10.129421,10.129421,10.129420,59.999906,59.999908,59.999913,734.916418,8.772332,0.000005,5.064717,2.924116,8.270632,5.064717,0.000008,0.000006,10.129420
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,2.274707,2.700294,4.236838,90.000000,90.000000,114.910144,23.603206,0.000000,-2.274707,0.000000,-2.449084,1.137354,0.000000,0.000000,0.000000,-4.236838
196,7.015440,7.015440,5.445853,90.000000,90.000000,119.999998,232.116684,3.507720,-6.075549,0.000000,3.507720,6.075549,0.000000,0.000000,0.000000,5.445853
197,3.314288,3.314288,3.835538,90.000000,90.000000,90.000000,42.131485,3.314288,-0.000000,-0.000000,0.000000,3.314288,0.000000,0.000000,0.000000,3.835538
198,3.623148,3.623148,3.623148,90.000000,90.000000,90.000000,47.561787,3.623148,0.000000,-0.000000,0.000000,3.623148,-0.000000,0.000000,0.000000,3.623148


### sites

In [148]:
def extract_site_attributes(struct_str, max_sites=5):
    try:
        # Parse if the input is a string
        struct = ast.literal_eval(struct_str) if isinstance(struct_str, str) else struct_str
        sites = struct.get('sites', [])
        
        features = {}

        for i, site in enumerate(sites[:max_sites]):  # Only process up to max_sites
            prefix = f"site{i+1}_"
            species = site.get('species', [])
            
            # Get element and occupancy (handle multi-species by taking first only)
            if species:
                features[prefix + 'element'] = species[0].get('element', '')
                features[prefix + 'occu'] = species[0].get('occu', 0.0)
            else:
                features[prefix + 'element'] = ''
                features[prefix + 'occu'] = 0.0

            # Cartesian and fractional coordinates
            xyz = site.get('xyz', [float('nan')] * 3)
            abc = site.get('abc', [float('nan')] * 3)

            features[prefix + 'x'] = xyz[0]
            features[prefix + 'y'] = xyz[1]
            features[prefix + 'z'] = xyz[2]

            features[prefix + 'a'] = abc[0]
            features[prefix + 'b'] = abc[1]
            features[prefix + 'c'] = abc[2]

        # Fill remaining site slots with NaN/empty if fewer than max_sites
        for j in range(i+1, max_sites):
            prefix = f"site{j+1}_"
            features.update({
                prefix + 'element': '',
                prefix + 'occu': 0.0,
                prefix + 'x': float('nan'),
                prefix + 'y': float('nan'),
                prefix + 'z': float('nan'),
                prefix + 'a': float('nan'),
                prefix + 'b': float('nan'),
                prefix + 'c': float('nan'),
            })

        return features
    
    except Exception as e:
        print(f"Error: {e}")
        # Return empty features on failure
        base_keys = ['element', 'occu', 'x', 'y', 'z', 'a', 'b', 'c']
        return {
            f"site{i+1}_{k}": ('' if k == 'element' else float('nan')) 
            for i in range(max_sites) for k in base_keys
        }


In [149]:
max_sites = df_small['host_structure'].apply(
    lambda x: len(ast.literal_eval(x)['sites']) if isinstance(x, str) else len(x['sites'])
).max()

In [150]:
print(f"Maximum number of sites: {max_sites}")


Maximum number of sites: 120


In [151]:
site_features_list = df_small['host_structure'].apply(
    lambda x: extract_site_attributes(x, max_sites=max_sites)
)
df_sites = pd.DataFrame(site_features_list.tolist())

In [152]:
df_sites

Unnamed: 0,site1_element,site1_occu,site1_x,site1_y,site1_z,site1_a,site1_b,site1_c,site2_element,site2_occu,...,site119_b,site119_c,site120_element,site120_occu,site120_x,site120_y,site120_z,site120_a,site120_b,site120_c
0,Cu,1,1.498754,1.498754,1.498754,0.500000,0.500000,0.50,,0.0,...,,,,0.0,,,,,,
1,Mn,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.00,,0.0,...,,,,0.0,,,,,,
2,Mo,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.00,,0.0,...,,,,0.0,,,,,,
3,Re,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.00,,0.0,...,,,,0.0,,,,,,
4,V,1,0.000000,0.000000,0.000000,0.000000,-0.000000,-0.00,V,1.0,...,,,,0.0,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,Cr,1,-1.766762,-1.137354,-3.177628,0.860699,0.721397,0.75,Cr,1.0,...,,,,0.0,,,,,,
196,Cs,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.00,,0.0,...,,,,0.0,,,,,,
197,Cd,1,0.000000,0.000000,0.000000,0.000000,0.000000,-0.00,,0.0,...,,,,0.0,,,,,,
198,Sn,1,1.811574,1.811574,1.811574,0.500000,0.500000,0.50,,0.0,...,,,,0.0,,,,,,


In [153]:
df_sites.columns

Index(['site1_element', 'site1_occu', 'site1_x', 'site1_y', 'site1_z',
       'site1_a', 'site1_b', 'site1_c', 'site2_element', 'site2_occu',
       ...
       'site119_b', 'site119_c', 'site120_element', 'site120_occu',
       'site120_x', 'site120_y', 'site120_z', 'site120_a', 'site120_b',
       'site120_c'],
      dtype='object', length=960)

In [154]:
# Replace '' with 'None' in all site#_element columns
element_cols = [col for col in df_sites.columns if col.endswith('_element')]
df_sites[element_cols] = df_sites[element_cols].replace('', 'None')

In [155]:
encoded_parts = [
    pd.get_dummies(df_sites[col], prefix=col)
    for col in element_cols
]

In [156]:
df_sites_encoded = pd.concat([df_sites] + encoded_parts, axis=1)
df_sites_encoded.drop(columns=element_cols, inplace=True)

In [157]:
for cl in df_sites_encoded:
    print(cl)

site1_occu
site1_x
site1_y
site1_z
site1_a
site1_b
site1_c
site2_occu
site2_x
site2_y
site2_z
site2_a
site2_b
site2_c
site3_occu
site3_x
site3_y
site3_z
site3_a
site3_b
site3_c
site4_occu
site4_x
site4_y
site4_z
site4_a
site4_b
site4_c
site5_occu
site5_x
site5_y
site5_z
site5_a
site5_b
site5_c
site6_occu
site6_x
site6_y
site6_z
site6_a
site6_b
site6_c
site7_occu
site7_x
site7_y
site7_z
site7_a
site7_b
site7_c
site8_occu
site8_x
site8_y
site8_z
site8_a
site8_b
site8_c
site9_occu
site9_x
site9_y
site9_z
site9_a
site9_b
site9_c
site10_occu
site10_x
site10_y
site10_z
site10_a
site10_b
site10_c
site11_occu
site11_x
site11_y
site11_z
site11_a
site11_b
site11_c
site12_occu
site12_x
site12_y
site12_z
site12_a
site12_b
site12_c
site13_occu
site13_x
site13_y
site13_z
site13_a
site13_b
site13_c
site14_occu
site14_x
site14_y
site14_z
site14_a
site14_b
site14_c
site15_occu
site15_x
site15_y
site15_z
site15_a
site15_b
site15_c
site16_occu
site16_x
site16_y
site16_z
site16_a
site16_b
site16_c
site17_

In [158]:
df_sites[[f'site{i}_element' for i in range(1, 121)]].nunique()

site1_element      52
site2_element      24
site3_element       9
site4_element       9
site5_element       4
                   ..
site116_element     2
site117_element     2
site118_element     2
site119_element     2
site120_element     2
Length: 120, dtype: int64

### Final small dataset

In [159]:
# Concatenate all parts along columns
df_final = pd.concat([
    df_wi,
    df_fa_2,
    df_fc_dc,
    df_lattice,
    df_sites_encoded,
    df_small['average_voltage'],
    df_small['capacity_grav'],
    df_small['energy_grav'],
    df_small['max_delta_volume']
], axis=1)

In [160]:
# Save to CSV
df_final.to_csv('final_dataset_small.csv', index=False)

# **Full dataset creation**

### working_ion

In [162]:
df_wi = pd.get_dummies(df['working_ion'], prefix='working_ion')

In [163]:
df_wi.shape

(5604, 10)

In [164]:
df_wi.columns

Index(['working_ion_Al', 'working_ion_Ca', 'working_ion_Cs', 'working_ion_K',
       'working_ion_Li', 'working_ion_Mg', 'working_ion_Na', 'working_ion_Rb',
       'working_ion_Y', 'working_ion_Zn'],
      dtype='object')

In [165]:
df_wi

Unnamed: 0,working_ion_Al,working_ion_Ca,working_ion_Cs,working_ion_K,working_ion_Li,working_ion_Mg,working_ion_Na,working_ion_Rb,working_ion_Y,working_ion_Zn
0,True,False,False,False,False,False,False,False,False,False
1,True,False,False,False,False,False,False,False,False,False
2,True,False,False,False,False,False,False,False,False,False
3,True,False,False,False,False,False,False,False,False,False
4,True,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...
5599,False,False,False,False,True,False,False,False,False,False
5600,False,False,False,False,True,False,False,False,False,False
5601,False,False,False,False,True,False,False,False,False,False
5602,False,False,False,False,True,False,False,False,False,False


### formula_anonymous

In [166]:
df_fa = df[['formula_anonymous']]

In [167]:
df_fa.shape

(5604, 1)

In [168]:
df_fa

Unnamed: 0,formula_anonymous
0,A
1,A
2,A
3,A
4,A
...,...
5599,A2B3C12
5600,AB3C8
5601,ABC3
5602,A3B8C16


In [169]:
df_fa['formula_anonymous'].unique()

array(['A', 'AB3', 'AB2', 'AB', 'A2B3', 'A3B4', 'AB6', 'AB4', 'A5B8',
       'A2B9', 'A3B8', 'A3B5', 'AB8', 'A6B11', 'A3B7', 'A9B11', 'A8B17',
       'A11B17', 'A4B9', 'A2B5', 'A9B13', 'A5B7', 'AB5', 'A2B7', 'A7B16',
       'A6B13', 'A8B13', 'A15B19', 'A7B13', 'A11B24', 'A15B32', 'A12B29',
       'A9B22', 'A21B40', 'AB16', 'A7B12', 'A5B12', 'AB2C8', 'A2B3C8',
       'A3B4C12', 'A2B3C7', 'AB4C7', 'ABC4', 'AB2C12', 'A2B2C7', 'ABC5',
       'ABC6', 'AB6C14', 'AB3C4', 'AB4C12', 'AB2C7', 'AB5C15', 'AB2C6',
       'A3B5C19', 'A4B5', 'A19B30', 'AB2C4', 'AB6C15', 'ABC2', 'AB4C8',
       'AB3C9', 'A2B3C12', 'A4B7C24', 'AB2C3', 'A2B3C13', 'AB12C40',
       'A4B4C15', 'A3B3C11', 'A6B7C24', 'AB2C2', 'A3B10', 'AB7', 'A7B15',
       'A15B28', 'A13B28', 'A11B28', 'A5B16', 'A9B10', 'A9B20', 'A23B32',
       'A19B48', 'A2B2C9', 'A3B3C13', 'ABC7', 'AB4C10', 'A3B5C18',
       'A3B4C14', 'ABC', 'AB3C8', 'ABC3', 'A3B3C8', 'A3B5C5', 'A3B4C16',
       'A2B3C10', 'AB7C28', 'AB2C9', 'A3B5C16', 'A2B2C11', 'AB5C

In [170]:
unique_formulas = df_fa['formula_anonymous'].unique()

In [171]:
unique_formulas

array(['A', 'AB3', 'AB2', 'AB', 'A2B3', 'A3B4', 'AB6', 'AB4', 'A5B8',
       'A2B9', 'A3B8', 'A3B5', 'AB8', 'A6B11', 'A3B7', 'A9B11', 'A8B17',
       'A11B17', 'A4B9', 'A2B5', 'A9B13', 'A5B7', 'AB5', 'A2B7', 'A7B16',
       'A6B13', 'A8B13', 'A15B19', 'A7B13', 'A11B24', 'A15B32', 'A12B29',
       'A9B22', 'A21B40', 'AB16', 'A7B12', 'A5B12', 'AB2C8', 'A2B3C8',
       'A3B4C12', 'A2B3C7', 'AB4C7', 'ABC4', 'AB2C12', 'A2B2C7', 'ABC5',
       'ABC6', 'AB6C14', 'AB3C4', 'AB4C12', 'AB2C7', 'AB5C15', 'AB2C6',
       'A3B5C19', 'A4B5', 'A19B30', 'AB2C4', 'AB6C15', 'ABC2', 'AB4C8',
       'AB3C9', 'A2B3C12', 'A4B7C24', 'AB2C3', 'A2B3C13', 'AB12C40',
       'A4B4C15', 'A3B3C11', 'A6B7C24', 'AB2C2', 'A3B10', 'AB7', 'A7B15',
       'A15B28', 'A13B28', 'A11B28', 'A5B16', 'A9B10', 'A9B20', 'A23B32',
       'A19B48', 'A2B2C9', 'A3B3C13', 'ABC7', 'AB4C10', 'A3B5C18',
       'A3B4C14', 'ABC', 'AB3C8', 'ABC3', 'A3B3C8', 'A3B5C5', 'A3B4C16',
       'A2B3C10', 'AB7C28', 'AB2C9', 'A3B5C16', 'A2B2C11', 'AB5C

In [172]:
# Function to extract elements from a formula
def extract_elements(formula):
    elements = re.findall(r'[A-Z][a-z]*', formula)
    return elements

In [173]:
# Extract elements from all formulas
all_elements = set()  # Using a set to avoid duplicates
for formula in unique_formulas:
    all_elements.update(extract_elements(formula))

In [174]:
# Convert set to a sorted list 
predetermined_elements = sorted(all_elements)

In [175]:
print("Predetermined Elements:", predetermined_elements)

Predetermined Elements: ['A', 'B', 'C', 'D', 'E', 'F']


In [176]:
# Function to parse formula and extract counts
def parse_formula(formula):
    element_counts = {el: 0 for el in predetermined_elements}  # Initialize with zeroes
    matches = re.findall(r'([A-Z][a-z]*)(\d*)', formula)  # Extract elements and counts
    for element, count in matches:
        if element in element_counts:
            element_counts[element] = int(count) if count else 1  # Default to 1 if no count
    return list(element_counts.values())

In [177]:
# Apply function to each formula
encoded_features = np.array([parse_formula(f) for f in df_fa['formula_anonymous']])


In [178]:
# Convert to DataFrame
df_fa_2 = pd.DataFrame(encoded_features, columns=[f'FA_{el}' for el in predetermined_elements])

In [179]:
df_fa_2

Unnamed: 0,FA_A,FA_B,FA_C,FA_D,FA_E,FA_F
0,1,0,0,0,0,0
1,1,0,0,0,0,0
2,1,0,0,0,0,0
3,1,0,0,0,0,0
4,1,0,0,0,0,0
...,...,...,...,...,...,...
5599,2,3,12,0,0,0
5600,1,3,8,0,0,0
5601,1,1,3,0,0,0
5602,3,8,16,0,0,0


### formula_charge & formula discharge

In [180]:
df['formula_charge']

0               AlCu
1                 Mn
2                 Mo
3                 Re
4              Al10V
            ...     
5599    Li3Mn2(PO4)3
5600         Co3TeO8
5601    Li5Mn6(BO3)6
5602       La8Cu3O16
5603          RbAgO2
Name: formula_charge, Length: 5604, dtype: object

In [181]:
def enhanced_parse_chemical_formula(formula):
    """
    Enhanced parser for chemical formulas to handle elements, parentheses, and multipliers.
    Correctly accounts for cases like 'TiBi(PO4)3', identifying 3 Ps and 12 Os.
    """
    # Split formula into parts by parentheses, considering potential multipliers
    parts = re.split(r'(\([A-Za-z0-9]*\)\d*)', formula)
    element_counts = Counter()

    for part in parts:
        if part.startswith('('):
            # Extract content inside parentheses and the multiplier
            content, multiplier = re.match(r'\(([A-Za-z0-9]*)\)(\d+)', part).groups()
            multiplier = int(multiplier)
            # Parse the content inside parentheses
            sub_counts = Counter({element: int(count) * multiplier if count else multiplier 
                                  for element, count in re.findall(r'([A-Z][a-z]*)(\d*)', content)})
            element_counts.update(sub_counts)
        else:
            # Handle parts of the formula outside parentheses
            sub_counts = Counter({element: int(count) if count else 1 
                                  for element, count in re.findall(r'([A-Z][a-z]*)(\d*)', part)})
            element_counts.update(sub_counts)

    return element_counts

In [182]:
# Test the enhanced parser with a complex formula example
test_formula = "TiBi(PO4)3"
enhanced_parse_chemical_formula(test_formula)

Counter({'O': 12, 'P': 3, 'Ti': 1, 'Bi': 1})

In [183]:
# Test the enhanced parser with a complex formula example
test_formula = "Al10V"
enhanced_parse_chemical_formula(test_formula)

Counter({'Al': 10, 'V': 1})

In [184]:
# Test the enhanced parser with a complex formula example
test_formula = "NaCl"
enhanced_parse_chemical_formula(test_formula)

Counter({'Na': 1, 'Cl': 1})

In [185]:
# Test the enhanced parser with a complex formula example
test_formula = "NO"
enhanced_parse_chemical_formula(test_formula)

Counter({'N': 1, 'O': 1})

In [186]:
# Apply the function to both columns
df_charge_parsed = df['formula_charge'].apply(enhanced_parse_chemical_formula)
df_dicharge_parsed = df['formula_discharge'].apply(enhanced_parse_chemical_formula)


In [187]:
# Convert dictionaries to DataFrames
df_charge_df = pd.DataFrame(df_charge_parsed.tolist()).fillna(0)  # Convert to DF, fill NaN with 0
df_dicharge_df = pd.DataFrame(df_dicharge_parsed.tolist()).fillna(0)

In [188]:
# Rename columns to indicate charge/discharge
df_charge_df.columns = [f'charge_{col}' for col in df_charge_df.columns]
df_dicharge_df.columns = [f'dicharge_{col}' for col in df_dicharge_df.columns]

In [189]:
# Merge the processed features into a single DataFrame
df_fc_dc = pd.concat([df_small[['formula_charge', 'formula_discharge']], df_charge_df, df_dicharge_df], axis=1)


In [190]:
display(df_fc_dc)

Unnamed: 0,formula_charge,formula_discharge,charge_Al,charge_Cu,charge_Mn,charge_Mo,charge_Re,charge_V,charge_Sb,charge_Fe,...,dicharge_Au,dicharge_Ho,dicharge_Hf,dicharge_Er,dicharge_Ru,dicharge_Tm,dicharge_Eu,dicharge_Lu,dicharge_Dy,dicharge_Tc
0,AlCu,Al3Cu,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Mn,MnAl12,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Mo,Al12Mo,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Re,Al12Re,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Al10V,Al41V4,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5599,,,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5600,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5601,,,0.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5602,,,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [191]:
df_fc_dc.columns

Index(['formula_charge', 'formula_discharge', 'charge_Al', 'charge_Cu',
       'charge_Mn', 'charge_Mo', 'charge_Re', 'charge_V', 'charge_Sb',
       'charge_Fe',
       ...
       'dicharge_Au', 'dicharge_Ho', 'dicharge_Hf', 'dicharge_Er',
       'dicharge_Ru', 'dicharge_Tm', 'dicharge_Eu', 'dicharge_Lu',
       'dicharge_Dy', 'dicharge_Tc'],
      dtype='object', length=160)

In [192]:
df_fc_dc = df_fc_dc.drop(['formula_charge', 'formula_discharge'], axis=1)

In [193]:
display(df_fc_dc)

Unnamed: 0,charge_Al,charge_Cu,charge_Mn,charge_Mo,charge_Re,charge_V,charge_Sb,charge_Fe,charge_W,charge_Sn,...,dicharge_Au,dicharge_Ho,dicharge_Hf,dicharge_Er,dicharge_Ru,dicharge_Tm,dicharge_Eu,dicharge_Lu,dicharge_Dy,dicharge_Tc
0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5599,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5600,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5601,0.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5602,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [194]:
df_fc_dc.columns

Index(['charge_Al', 'charge_Cu', 'charge_Mn', 'charge_Mo', 'charge_Re',
       'charge_V', 'charge_Sb', 'charge_Fe', 'charge_W', 'charge_Sn',
       ...
       'dicharge_Au', 'dicharge_Ho', 'dicharge_Hf', 'dicharge_Er',
       'dicharge_Ru', 'dicharge_Tm', 'dicharge_Eu', 'dicharge_Lu',
       'dicharge_Dy', 'dicharge_Tc'],
      dtype='object', length=158)

### host structure

#### Lattice

In [195]:
def extract_structure_features(struct_dict):
    try:
        lattice = struct_dict['lattice']
        # Scalar lattice parameters
        features = {
            'a': lattice.get('a', 0),
            'b': lattice.get('b', 0),
            'c': lattice.get('c', 0),
            'alpha': lattice.get('alpha', 0),
            'beta': lattice.get('beta', 0),
            'gamma': lattice.get('gamma', 0),
            'volume': lattice.get('volume', 0),
        }
        # Flatten lattice matrix
        matrix_flat = [val for row in lattice['matrix'] for val in row]
        matrix_features = {f'm{i+1}': val for i, val in enumerate(matrix_flat)}
        features.update(matrix_features)
        return features
    except Exception as e:
        # Return NaNs if something goes wrong
        return {key: float('nan') for key in ['a', 'b', 'c', 'alpha', 'beta', 'gamma', 'volume'] + [f'm{i+1}' for i in range(9)]}


In [196]:
# Parse the structure strings into dicts on-the-fly
structure_dicts = df['host_structure'].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

In [197]:
# Now apply your existing feature extractor on this parsed series
structure_features_list = structure_dicts.apply(extract_structure_features)

In [198]:

# Convert list of dicts to DataFrame
df_lattice = pd.DataFrame(structure_features_list.tolist())

In [199]:
df_lattice

Unnamed: 0,a,b,c,alpha,beta,gamma,volume,m1,m2,m3,m4,m5,m6,m7,m8,m9
0,2.997509,2.997509,2.997509,90.000000,90.000000,90.000000,26.932799,2.997509,0.000000,0.000000,0.000000,2.997509,0.000000,0.000000,0.000000,2.997509
1,2.413221,2.413221,2.413221,109.471221,109.471221,109.471221,10.818557,-1.393274,-1.393274,1.393274,-1.393274,1.393274,-1.393274,1.393274,-1.393274,-1.393274
2,2.743238,2.743238,2.743238,109.471229,109.471231,109.471218,15.891628,2.586349,0.000000,-0.914413,-1.293175,2.239844,-0.914413,0.000000,0.000000,2.743238
3,2.703704,2.703704,2.703704,109.471221,109.471221,109.471221,15.214418,-1.560984,1.560984,1.560984,1.560984,-1.560984,1.560984,1.560984,1.560984,-1.560984
4,10.129421,10.129421,10.129420,59.999906,59.999908,59.999913,734.916418,8.772332,0.000005,5.064717,2.924116,8.270632,5.064717,0.000008,0.000006,10.129420
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5599,9.156049,9.156049,16.562408,64.610198,64.610198,56.335864,1009.733957,4.322181,8.071678,0.000000,-4.322181,8.071678,0.000000,0.000000,8.055575,14.471388
5600,5.693795,5.693795,9.038161,90.000000,90.000000,119.999990,253.754893,2.846898,-4.930971,0.000000,2.846898,4.930971,0.000000,0.000000,0.000000,9.038161
5601,10.473469,8.811388,8.872691,107.886701,98.848306,98.504411,753.212469,10.473469,0.000000,0.000000,-1.303077,8.714502,0.000000,-1.364788,-2.959494,8.252478
5602,7.107274,7.540503,8.421401,116.529290,96.772879,105.506357,374.068318,7.107274,0.000000,0.000000,-2.015918,7.266035,0.000000,-0.993169,-4.179098,7.243532


### sites

In [200]:
def extract_site_attributes(struct_str, max_sites=5):
    try:
        # Parse if the input is a string
        struct = ast.literal_eval(struct_str) if isinstance(struct_str, str) else struct_str
        sites = struct.get('sites', [])
        
        features = {}

        for i, site in enumerate(sites[:max_sites]):  # Only process up to max_sites
            prefix = f"site{i+1}_"
            species = site.get('species', [])
            
            # Get element and occupancy (handle multi-species by taking first only)
            if species:
                features[prefix + 'element'] = species[0].get('element', '')
                features[prefix + 'occu'] = species[0].get('occu', 0.0)
            else:
                features[prefix + 'element'] = ''
                features[prefix + 'occu'] = 0.0

            # Cartesian and fractional coordinates
            xyz = site.get('xyz', [float('nan')] * 3)
            abc = site.get('abc', [float('nan')] * 3)

            features[prefix + 'x'] = xyz[0]
            features[prefix + 'y'] = xyz[1]
            features[prefix + 'z'] = xyz[2]

            features[prefix + 'a'] = abc[0]
            features[prefix + 'b'] = abc[1]
            features[prefix + 'c'] = abc[2]

        # Fill remaining site slots with NaN/empty if fewer than max_sites
        for j in range(i+1, max_sites):
            prefix = f"site{j+1}_"
            features.update({
                prefix + 'element': '',
                prefix + 'occu': 0.0,
                prefix + 'x': float('nan'),
                prefix + 'y': float('nan'),
                prefix + 'z': float('nan'),
                prefix + 'a': float('nan'),
                prefix + 'b': float('nan'),
                prefix + 'c': float('nan'),
            })

        return features
    
    except Exception as e:
        print(f"Error: {e}")
        # Return empty features on failure
        base_keys = ['element', 'occu', 'x', 'y', 'z', 'a', 'b', 'c']
        return {
            f"site{i+1}_{k}": ('' if k == 'element' else float('nan')) 
            for i in range(max_sites) for k in base_keys
        }


In [203]:
max_sites = df['host_structure'].apply(
    lambda x: len(ast.literal_eval(x)['sites']) if isinstance(x, str) else len(x['sites'])
).max()

In [204]:
print(f"Maximum number of sites: {max_sites}")


Maximum number of sites: 160


In [205]:
site_features_list = df['host_structure'].apply(
    lambda x: extract_site_attributes(x, max_sites=max_sites)
)
df_sites = pd.DataFrame(site_features_list.tolist())

In [206]:
df_sites

Unnamed: 0,site1_element,site1_occu,site1_x,site1_y,site1_z,site1_a,site1_b,site1_c,site2_element,site2_occu,...,site159_b,site159_c,site160_element,site160_occu,site160_x,site160_y,site160_z,site160_a,site160_b,site160_c
0,Cu,1,1.498754,1.498754,1.498754,0.500000,0.500000,0.500000,,0.0,...,,,,0.0,,,,,,
1,Mn,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,,0.0,...,,,,0.0,,,,,,
2,Mo,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,,0.0,...,,,,0.0,,,,,,
3,Re,1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,,0.0,...,,,,0.0,,,,,,
4,V,1,0.000000,0.000000,0.000000,0.000000,-0.000000,-0.000000,V,1.0,...,,,,0.0,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5599,Mn,1,-2.143763,7.702198,0.441392,0.213897,0.709888,0.030501,Mn,1.0,...,,,,0.0,,,,,,
5600,Co,1,2.846898,3.249253,1.954322,0.170526,0.829474,0.216230,Co,1.0,...,,,,0.0,,,,,,
5601,Mn,1,7.070290,-0.774553,6.767890,0.805527,0.189631,0.820104,Mn,1.0,...,,,,0.0,,,,,,
5602,La,1,1.408171,-0.697960,2.341899,0.268808,0.089895,0.323309,La,1.0,...,,,,0.0,,,,,,


In [207]:
df_sites.columns

Index(['site1_element', 'site1_occu', 'site1_x', 'site1_y', 'site1_z',
       'site1_a', 'site1_b', 'site1_c', 'site2_element', 'site2_occu',
       ...
       'site159_b', 'site159_c', 'site160_element', 'site160_occu',
       'site160_x', 'site160_y', 'site160_z', 'site160_a', 'site160_b',
       'site160_c'],
      dtype='object', length=1280)

In [208]:
# Replace '' with 'None' in all site#_element columns
element_cols = [col for col in df_sites.columns if col.endswith('_element')]
df_sites[element_cols] = df_sites[element_cols].replace('', 'None')

In [209]:
encoded_parts = [
    pd.get_dummies(df_sites[col], prefix=col)
    for col in element_cols
]

In [210]:
df_sites_encoded = pd.concat([df_sites] + encoded_parts, axis=1)
df_sites_encoded.drop(columns=element_cols, inplace=True)

In [211]:
for cl in df_sites_encoded:
    print(cl)

site1_occu
site1_x
site1_y
site1_z
site1_a
site1_b
site1_c
site2_occu
site2_x
site2_y
site2_z
site2_a
site2_b
site2_c
site3_occu
site3_x
site3_y
site3_z
site3_a
site3_b
site3_c
site4_occu
site4_x
site4_y
site4_z
site4_a
site4_b
site4_c
site5_occu
site5_x
site5_y
site5_z
site5_a
site5_b
site5_c
site6_occu
site6_x
site6_y
site6_z
site6_a
site6_b
site6_c
site7_occu
site7_x
site7_y
site7_z
site7_a
site7_b
site7_c
site8_occu
site8_x
site8_y
site8_z
site8_a
site8_b
site8_c
site9_occu
site9_x
site9_y
site9_z
site9_a
site9_b
site9_c
site10_occu
site10_x
site10_y
site10_z
site10_a
site10_b
site10_c
site11_occu
site11_x
site11_y
site11_z
site11_a
site11_b
site11_c
site12_occu
site12_x
site12_y
site12_z
site12_a
site12_b
site12_c
site13_occu
site13_x
site13_y
site13_z
site13_a
site13_b
site13_c
site14_occu
site14_x
site14_y
site14_z
site14_a
site14_b
site14_c
site15_occu
site15_x
site15_y
site15_z
site15_a
site15_b
site15_c
site16_occu
site16_x
site16_y
site16_z
site16_a
site16_b
site16_c
site17_

In [213]:
df_sites[[f'site{i}_element' for i in range(1, 161)]].nunique()

site1_element      74
site2_element      77
site3_element      74
site4_element      75
site5_element      69
                   ..
site156_element     2
site157_element     2
site158_element     2
site159_element     2
site160_element     2
Length: 160, dtype: int64

### Final dataset

In [217]:
# Concatenate all parts along columns
df_final = pd.concat([
    df_wi,
    df_fa_2,
    df_fc_dc,
    df_lattice,
    df_sites_encoded,
    df['average_voltage'],
    df['capacity_grav'],
    df['energy_grav'],
    df['max_delta_volume']
], axis=1)

In [218]:
# Save to CSV
df_final.to_csv('final_dataset.csv', index=False)

In [219]:
for column in df_final.columns:
    print(column)

working_ion_Al
working_ion_Ca
working_ion_Cs
working_ion_K
working_ion_Li
working_ion_Mg
working_ion_Na
working_ion_Rb
working_ion_Y
working_ion_Zn
FA_A
FA_B
FA_C
FA_D
FA_E
FA_F
charge_Al
charge_Cu
charge_Mn
charge_Mo
charge_Re
charge_V
charge_Sb
charge_Fe
charge_W
charge_Sn
charge_Ti
charge_Ag
charge_Bi
charge_Ce
charge_Ca
charge_C
charge_Li
charge_Cr
charge_Nb
charge_Cd
charge_Ac
charge_S
charge_Se
charge_Pt
charge_Cs
charge_Br
charge_H
charge_Te
charge_Co
charge_Na
charge_Ni
charge_K
charge_Ge
charge_U
charge_Zn
charge_Mg
charge_Hg
charge_Pm
charge_N
charge_Rb
charge_Tl
charge_Nd
charge_Pb
charge_Sm
charge_Si
charge_O
charge_Ba
charge_Sr
charge_Ga
charge_Rh
charge_In
charge_Os
charge_Y
charge_Ta
charge_Cl
charge_Be
charge_F
charge_Pd
charge_I
charge_Ir
charge_Zr
charge_B
charge_Sc
charge_Pr
charge_La
charge_P
charge_As
charge_Gd
charge_Tb
charge_Au
charge_Ho
charge_Hf
charge_Er
charge_Ru
charge_Tm
charge_Eu
charge_Lu
charge_Dy
charge_Tc
dicharge_Al
dicharge_Cu
dicharge_Mn
dicharge_M