## Define Chemical Space

In [None]:
import os
import sys
import json
from itertools import combinations_with_replacement
from collections import defaultdict
from math import comb
import glob

import numpy as np
import pandas as pd
from rdkit import Chem
import rdkit
rdkit.__version__

from architector import build_complex, io_core
from nglview import show_rdkit

#core_geometries = io_core.Geometries().cn_geo_dict
core_geometries = {
    1: ['single'], 
    2: ['bent_109', 'bent_120', 'linear'], 
    3: ['t_shaped', 'trigonal_planar', 'trigonal_pyramidal'], 
    4: ['seesaw', 'square_planar', 'tetrahedral'], 
    5: ['pentagonal_planar', 'square_pyramidal', 'trigonal_bipyramidal'], 
    6: ['hexagonal_planar', 'octahedral', 'pentagonal_pyramidal', 'trigonal_prismatic'], 
    7: ['capped_octahedral', 'capped_trigonal_prismatic', 'hexagonal_pyramidal', 'pentagonal_bipyramidal'], 
    8: ['axial_bicapped_trigonal_prismatic', 'bicapped_trigonal_prismatic', 'dodecahedral', 'hexagonal_bipyramidal', 'square_antiprismatic', 'square_prismatic'], 
    9: ['capped_square_antiprismatic', 'tri_tri_mer_capped', 'tricapped_trigonal_prismatic'], 
    10: ['axial_bicapped_hexagonal_planar'], 
    12: ['penta_bi_capped_pyramidal']
}

def view3D(mol):
    """Display an RDKit molecule in 3D using nglview"""
    return show_rdkit(mol)



In [2]:
#import mendeleev
#sorted([x.oxidation_state for x in mendeleev.element("Fe")._oxidation_states])
metal_ox = {
#    "Fe": [2],
    "Fe": [0, 1, 2, 3, 4, 5, 6, 7],
#    "Pd": [1, 2, 3, 4, 5],
#    "Zn": [0, 1, 2],
#    "Cu": [0, 1, 2, 3, 4],
#    "Li": [1],
#    "Mg": [0, 1, 2],
}

In [3]:
# Removed:
#     "hydride": {"smiles":"H", "coordList":[0], "ligType":"mono", "bondType": "X", "charge": 0},
ligands = {}
with open("../ligand_dictionaries/ligands.json", "r") as f:
    ligands.update(json.load(f))

for lig_name, tmp in ligands.items():
    print(lig_name, tmp["smiles"])
    
ligands_list = sorted(ligands.items())

choride Cl
methyl [C-1]([H])([H])[H]
methanediide [C-2]([H])[H]
ammonia [N][H][H][H]
amine [N-1]([H])[H]
imido [N-2][H]
water [O][H][H]
hydroxyl [O-1][H]
oxo [O-2]
phosphine P([H])([H])[H]
phosphido [P-1]([H])[H]
hydrogen sulfide [S]([H])[H]
thiol [S-1][H]
sulfido [S-2]


## Check Fringe Chemical Space

In [4]:
## Fluoride
#ligands = [x[0] for x in ligands_list]
#
#ox_state = 3
#inpDict = {
#    "core": {"metal": "Fe", 'coreCN': 6, "coreType": ['trigonal_prismatic']},
#    "ligands": [x[1] for lig in ligands for x in ligands_list if x[0] == lig],
#    'parameters':{
#        "metal_ox": ox_state,
#        'assemble_method':'GFN-FF', # Switch to GFN-FF for faster assembly, 
#        'n_symmetries': 100, # build 5 structures with different ligand arrangements, duplicates will be filtered
#        'n_conformers': 1,
#        'debug': True,
#        'save_init_geos': True,  # Save pre-optimization structure
##        'seed': 42,  # Also set this (though OpenBabel still has randomness)
#    },
#    
#}
#
#out_dict = build_complex(inpDict)
#for x in out_dict.keys():
#    print(x)

In [5]:
#out_dict

In [6]:
#for label, out in out_dict.items():
#    n_el = [out['xtb_n_unpaired_electrons'], out['calc_n_unpaired_electrons']]
#    if len(set(n_el)) > 1:
#        raise ValueError(f"N_unpaired_el should agree between keys: 'xtb_n_unpaired_electrons', 'calc_n_unpaired_electrons', 'metal_spin', resulting {n_el}")
#    chg = [out['total_charge'], out['xtb_total_charge']]
#    if len(set(chg)) > 1:
#        raise ValueError(f"Charge should agree between keys: 'total_charge', 'xtb_total_charge', resulting {chg}")
#    if ox_state != out["metal_ox"]:
#        raise ValueError("Metal oxidation state is not equal to assigned.")

In [7]:
#list(out_dict.values())[0]

In [8]:
#mol = Chem.MolFromMol2Block(out['mol2string'], removeHs=False)
#show_rdkit(mol)

In [9]:
#sys.exit()

## Generate Dataset

### Check Geometry of TMOS

In [10]:
import pickle

# Checkpoint file to save progress
CHECKPOINT_FILE = "dataset_checkpoint.pkl"
PROGRESS_LOG = "dataset_progress.txt"

def save_checkpoint(data_dict, progress_state):
    """Save current dataframe and progress state"""
    with open(CHECKPOINT_FILE, 'wb') as f:
        pickle.dump({'data': data_dict, 'state': progress_state}, f)
    print(f"✓ Checkpoint saved: {len(data_dict['metals'])} complexes", flush=True)

def load_checkpoint():
    """Load previous checkpoint if it exists"""
    try:
        with open(CHECKPOINT_FILE, 'rb') as f:
            checkpoint = pickle.load(f)
            print(f"✓ Loaded checkpoint: {len(checkpoint['data']['metals'])} complexes")
            return checkpoint['data'], checkpoint['state']
    except FileNotFoundError:
        print("No checkpoint found, starting fresh")
        return None, None

def make_state_key(metal, ox_state, cn, ligand_labels, coreType):
    """Create unique key for current state"""
    return f"{metal}_{ox_state}_{cn}_{coreType}_{'_'.join(sorted(ligand_labels))}"

### Strategy for Handling Kernel Crashes

Since the kernel can crash during `build_complex` (likely due to underlying C/Fortran code in xtb), we need to:
1. Save progress incrementally to disk
2. Restart from the last saved checkpoint if the kernel crashes

In [None]:

coordination_numbers = list(range(2, 13)) # to 13
exclude_functional_inds = True
errors = defaultdict(list)

# Load checkpoint if exists
checkpoint_data, last_state = load_checkpoint()
if checkpoint_data:
    df_metals = checkpoint_data['metals']
    df_cn = checkpoint_data['cn']
    df_ligands = checkpoint_data['ligands']
    df_ox = checkpoint_data['ox']
    df_chg = checkpoint_data['chg']
    df_multiplicity = checkpoint_data['multiplicity']
    df_mol2 = checkpoint_data['mol2']
    df_geo = checkpoint_data['geo']
    df_energy = checkpoint_data['energy']
    errors = checkpoint_data.get('errors', defaultdict(list))
    print(f"Resuming from checkpoint with {len(df_metals)} existing complexes")
else:
    # Start fresh
    df_metals, df_cn, df_ligands, df_ox, df_chg, df_multiplicity, df_mol2, df_geo, df_energy = [], [], [], [], [], [], [], [], []
    last_state = None

# Track processed states to avoid duplicates
if checkpoint_data:
    processed_states = set(make_state_key(
        checkpoint_data['metals'][i], 
        checkpoint_data['ox'][i], 
        checkpoint_data['cn'][i], 
        checkpoint_data['ligands'][i], 
        checkpoint_data['geo'][i],
    ) for i in range(len(checkpoint_data)))
else:
    processed_states = set()

checkpoint_counter = 0
CHECKPOINT_EVERY = 5  # Save every N successful complexes

n_metal_ox = sum(len(x) for x in metal_ox.values())
n_coord_geo = sum(comb(len(ligands_list) + cn - 1, cn) for cn in coordination_numbers)
print(f"Analyzing {n_metal_ox * n_coord_geo - len(df_metals)} structures")

for metal, ox_states in metal_ox.items():
    for ox_state in ox_states:
        for cn in coordination_numbers:
            for combo in combinations_with_replacement(ligands_list, cn):
                for coreType in core_geometries[cn]:
                    ligand_labels = [x[0] for x in combo]
                    state_key = make_state_key(metal, ox_state, cn, ligand_labels, coreType)
                    
                    # Skip if already processed
                    if state_key in processed_states:
                        continue
                    
                    ligand_dicts = [x[1] for x in combo]
                    if exclude_functional_inds:
                        ligand_dicts = [{k: v for k, v in lig.items() if k != 'functional_inds'} for lig in ligand_dicts]

                    # Log to file in case kernel dies
                    with open(PROGRESS_LOG, 'a') as f:
                        f.write(f"STARTING: {state_key}\n")
                    
                    print(f"\nProcessing: {metal=}, {ox_state=}, {cn=}, {coreType=}, ligands={ligand_labels}", flush=True)
                    
                    inpDict = {
                        "core": {"metal": metal, 'coreCN': cn, "coreType": [coreType]},
                        "ligands": ligand_dicts,
                        'parameters':{
                            "metal_ox": ox_state,
                            'assemble_method':'GFN-FF', # Switch to GFN-FF for faster assembly, 
                            'n_symmetries': 100, # build 5 structures with different ligand arrangements, duplicates will be filtered
                            'n_conformers': 1,
                            'return_only_1':True # Return just one geometry with lowest energy
                        },
                    }

                    try:
                        print("  Calling build_complex...", flush=True)
                        out = next(iter(build_complex(inpDict).values()))
                        print("  build_complex completed", flush=True)
                        
                        n_el = [out['xtb_n_unpaired_electrons'], out['calc_n_unpaired_electrons']]
                        if len(set(n_el)) > 1:
                            raise ValueError(f"N_unpaired_el should agree between keys: 'xtb_n_unpaired_electrons', 'calc_n_unpaired_electrons', 'metal_spin', resulting {n_el}")
                        chg = [out['total_charge'], out['xtb_total_charge']]
                        if len(set(chg)) > 1:
                            raise ValueError(f"Charge should agree between keys: 'total_charge', 'xtb_total_charge', resulting {chg}")
                        if ox_state != out["metal_ox"]:
                            raise ValueError("Metal oxidation state is not equal to assigned.")
                    except Exception as e:
                        error_msg = f"{type(e).__name__}: {str(e)[:100]}"
                        print(f"  ERROR: {error_msg}", flush=True)
                        errors[str(e)[:30]].append([metal, ox_state, cn, ligand_labels, coreType, e])
                        with open(PROGRESS_LOG, 'a') as f:
                            f.write(f"ERROR: {state_key} - {error_msg}\n")
                    else:
                        print("  SUCCESS", flush=True)
                        df_metals.append(metal)
                        df_ox.append(ox_state)
                        df_cn.append(cn)
                        df_geo.append(coreType)
                        df_ligands.append(ligand_labels)
                        df_chg.append(out['total_charge'])
                        df_multiplicity.append(out['xtb_n_unpaired_electrons'] + 1)
                        df_mol2.append(out['mol2string'])
                        df_energy.append(out['energy'])
                        
                        checkpoint_counter += 1
                        
                        # Save checkpoint periodically
                        if checkpoint_counter >= CHECKPOINT_EVERY:
                            data_dict = {
                                'metals': df_metals, 'cn': df_cn, 'ligands': df_ligands,
                                'ox': df_ox, 'chg': df_chg, 'multiplicity': df_multiplicity,
                                'mol2': df_mol2, 'geo': df_geo, 'energy': df_energy,
                                'errors': errors
                            }
                            save_checkpoint(data_dict, state_key)
                            checkpoint_counter = 0
                        
                        with open(PROGRESS_LOG, 'a') as f:
                            f.write(f"SUCCESS: {state_key}\n")
                    
                    processed_states.add(state_key)


# Final checkpoint
data_dict = {
    'metals': df_metals, 'cn': df_cn, 'ligands': df_ligands,
    'ox': df_ox, 'chg': df_chg, 'multiplicity': df_multiplicity,
    'mol2': df_mol2, 'geo': df_geo, 'energy': df_energy,
    'errors': errors
}
save_checkpoint(data_dict, "COMPLETED")

df = pd.DataFrame({
    "metal": df_metals,
    "oxidation_state": df_ox,
    "coordination_number": df_cn,
    "geometry": df_geo,
    "ligand_names": df_ligands,
    "total_charge": df_chg,
    "multiplicity": df_multiplicity,
    "mol2string": df_mol2,
    "xtb_energy": df_energy,
}) # Create a dataframe

print(f"\nGenerated {len(df)} complexes with {len(errors)} error types.")

No checkpoint found, starting fresh
Analyzing 14520 structures

Processing: metal='Fe', ox_state=0, cn=2, coreType='bent_109', ligands=['amine', 'amine']
  Calling build_complex...
  build_complex completed
  SUCCESS

Processing: metal='Fe', ox_state=0, cn=2, coreType='bent_120', ligands=['amine', 'amine']
  Calling build_complex...
  build_complex completed
  SUCCESS

Processing: metal='Fe', ox_state=0, cn=2, coreType='linear', ligands=['amine', 'amine']
  Calling build_complex...
  build_complex completed
  SUCCESS

Processing: metal='Fe', ox_state=0, cn=2, coreType='bent_109', ligands=['amine', 'fluoride']
  Calling build_complex...
  ERROR: StopIteration: 

Processing: metal='Fe', ox_state=0, cn=2, coreType='bent_120', ligands=['amine', 'fluoride']
  Calling build_complex...
  ERROR: StopIteration: 

Processing: metal='Fe', ox_state=0, cn=2, coreType='linear', ligands=['amine', 'fluoride']
  Calling build_complex...
  ERROR: StopIteration: 

Processing: metal='Fe', ox_state=0, cn=2

KeyboardInterrupt: 

In [None]:
len(df)

310

In [None]:
keys = list(errors.keys())
keys

['']

In [None]:
df.head()

Unnamed: 0,metal,oxidation_state,coordination_number,geometry,ligand_names,total_charge,multiplicity,mol2string,xtb_energy
0,Fe,2,2,bent_109,"[amine, amine]",2,5,@<TRIPOS>MOLECULE\nbent_109_0_nunpairedes_4_ch...,-298.173926
1,Fe,2,2,bent_120,"[amine, amine]",2,5,@<TRIPOS>MOLECULE\nbent_109_0_nunpairedes_4_ch...,-298.174912
2,Fe,2,2,linear,"[amine, amine]",2,5,@<TRIPOS>MOLECULE\nbent_109_0_nunpairedes_4_ch...,-298.173683
3,Fe,2,2,bent_109,"[amine, bromide]",2,5,@<TRIPOS>MOLECULE\nbent_109_0_nunpairedes_4_ch...,-301.972133
4,Fe,2,2,bent_120,"[amine, bromide]",2,5,@<TRIPOS>MOLECULE\nbent_109_0_nunpairedes_4_ch...,-301.971854


In [None]:
# Analyze each column independently for unique values and their counts
columns_to_analyze = ['metal', 'oxidation_state', 'coordination_number', 'geometry', 'ligand_names', 'total_charge', 'multiplicity']

for col in columns_to_analyze:
    print(f"\nColumn: {col}")
    print(f"{'='*60}")
    value_counts = df[col].value_counts().sort_index()
    print(value_counts)
    print(f"Total unique values: {len(value_counts)}")


Column: metal
metal
Fe    310
Name: count, dtype: int64
Total unique values: 1

Column: oxidation_state
oxidation_state
0      1
2    309
Name: count, dtype: int64
Total unique values: 2

Column: coordination_number
coordination_number
2    310
Name: count, dtype: int64
Total unique values: 1

Column: geometry
geometry
bent_109    104
bent_120    103
linear      103
Name: count, dtype: int64
Total unique values: 3

Column: ligand_names
ligand_names
[amine, amine]                              4
[amine, bromide]                            3
[amine, choride]                            3
[amine, dimethylamine]                      3
[amine, fluoride]                           3
                                           ..
[phosphine, thiol]                          3
[phosphine, trimethylphosphine]             3
[thiol, thiol]                              3
[thiol, trimethylphosphine]                 3
[trimethylphosphine, trimethylphosphine]    3
Name: count, Length: 103, dtype: int64
T