# Notebook 1: Descriptor Extraction
Extracts orbital-weight-based descriptors from `vasprun.xml` for all Rashba compounds.

**Descriptor Types:**
- **A** `WM_total`: sum(w_X * M_X) at VBM/CBM
- **B** `WM_p_only`: sum(p_X * M_X) using only p-orbital contributions
- **C** `WM_p_frac`: p_i*M_i / sum(p_j*M_j) for top 2 heaviest elements
- **D** `WM_indiv`: w_heavy1 * M_heavy1, w_heavy2 * M_heavy2 individually
- **E** `p_frac`: total p-orbital fraction
- **F** `p_heavy`: p_frac_of_heaviest * M_heaviest for top 2

**Windows:** 0.05, 0.1, 0.5, 1.0 eV (integration width below the VBM / above the CBM)

**Target:** max(Rashba_parameter) per UID from rashba.csv

**Output CSVs:** one per (type, window) combo + combined CSVs

In [None]:
import numpy as np
import pandas as pd
import os
import glob
import warnings
from pathlib import Path
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.core import Spin, OrbitalType
from pymatgen.core.periodic_table import Element

# Silence every warning category for the rest of the session (presumably to
# keep parser chatter out of the batch log; narrow this if warnings matter).
warnings.simplefilter('ignore')

In [None]:
# ============================================================
# PATHS - ADJUST THESE
# ============================================================
# BASE_DIR may be overridden with the RASHBA_BASE_DIR environment variable so
# the notebook can run on another machine without editing this cell; the
# default is the original lab path.
BASE_DIR = os.environ.get("RASHBA_BASE_DIR", r"C:\Users\AbCMS_Lab\Desktop\Keshav-DDP")
RASHBA_CSV = os.path.join(BASE_DIR, "Data", "rashba.csv")
RASHBA_DIR = os.path.join(BASE_DIR, "Inverse-design", "rashba")
OUTPUT_DIR = os.path.join(BASE_DIR, "Weight-contribution", "contribution-model")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Energy-window widths (eV) integrated below the VBM and above the CBM
WINDOWS = [0.05, 0.1, 0.5, 1.0]

print(f"Rashba CSV: {RASHBA_CSV}")
print(f"Rashba compounds dir: {RASHBA_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

In [None]:
# ============================================================
# LOAD RASHBA CSV & GET TARGET (max alpha_R per UID)
# ============================================================
df_rashba = pd.read_csv(RASHBA_CSV)
print(f"Total rows in rashba.csv: {len(df_rashba)}")
print(f"Unique UIDs: {df_rashba['uid'].nunique()}")
print(f"Columns: {list(df_rashba.columns)}")

# One row per UID: the largest Rashba parameter, plus the first formula seen
# for that UID (kept for reference only).
target = (
    df_rashba.groupby('uid')
    .agg(alpha_R_max=('Rashba_parameter', 'max'), Formula=('Formula', 'first'))
    .reset_index()
)

print(f"\nTarget: {len(target)} compounds with max alpha_R")
print(target.head())

In [None]:
# ============================================================
# FIND VASPRUN FILES FOR EACH UID
# ============================================================
# Folder structure: rashba/AsBrTe-671e6de2497a/ss_2d%2FAsBrTe-671e6de2497a%2Fbands_ncl%2Fvasprun.xml
# UID in CSV: 671e6de2497a (last part after dash)


def _find_vasprun(folder_path, folder):
    """Return the vasprun.xml path inside one compound folder, or None if none exists.

    Lookup order: the URL-encoded file name of the download layout, then any
    file whose name contains 'vasprun', then a recursive search for
    'vasprun.xml'. Candidates are sorted so the pick does not depend on
    filesystem order, and the folder path is glob-escaped in case it contains
    glob metacharacters such as '['.
    """
    expected = os.path.join(folder_path, f"ss_2d%2F{folder}%2Fbands_ncl%2Fvasprun.xml")
    if os.path.exists(expected):
        return expected
    base = glob.escape(folder_path)
    candidates = sorted(glob.glob(os.path.join(base, "*vasprun*")))
    if not candidates:
        candidates = sorted(glob.glob(os.path.join(base, "**", "vasprun.xml"), recursive=True))
    return candidates[0] if candidates else None


compound_dirs = {}
for folder in sorted(os.listdir(RASHBA_DIR)):  # sorted -> reproducible results
    folder_path = os.path.join(RASHBA_DIR, folder)
    if not os.path.isdir(folder_path):
        continue

    vasprun_path = _find_vasprun(folder_path, folder)
    if vasprun_path is None:
        continue

    # UID = everything after the LAST dash (whole name if there is no dash),
    # e.g. AsBrTe-671e6de2497a -> 671e6de2497a, Bi2P2S6-287dcf4f1a19 -> 287dcf4f1a19
    uid_from_folder = folder.rpartition('-')[2]
    if not uid_from_folder:
        # Name ends with '-': an empty UID would prefix-match every CSV UID below.
        print(f"WARNING: cannot extract a UID from folder '{folder}' - skipped")
        continue
    if uid_from_folder in compound_dirs:
        print(f"WARNING: duplicate UID {uid_from_folder}: keeping "
              f"{compound_dirs[uid_from_folder]['folder']}, ignoring {folder}")
        continue

    compound_dirs[uid_from_folder] = {
        'folder': folder,
        'vasprun': vasprun_path,
    }

print(f"Found {len(compound_dirs)} compound directories with vasprun.xml")

# Match with target UIDs from CSV
folder_uids = list(compound_dirs)  # snapshot: aliases added below are not match candidates
matched = 0
unmatched_uids = []
ambiguous_uids = []
for _, row in target.iterrows():
    uid = row['uid']
    if uid in compound_dirs:  # direct match
        matched += 1
        continue
    # The CSV uid might be longer/shorter than the folder uid
    # (e.g. CSV has 'da5fd2bb4' but the folder has 'da5fd2bb4xxx', or vice versa).
    # Accept a prefix match only when exactly one folder fits: taking the first
    # of several would silently attach another compound's DOS to this target.
    hits = [f_uid for f_uid in folder_uids if uid.startswith(f_uid) or f_uid.startswith(uid)]
    if len(hits) == 1:
        compound_dirs[uid] = compound_dirs[hits[0]]
        matched += 1
    elif hits:
        ambiguous_uids.append((uid, row['Formula'], hits))
    else:
        unmatched_uids.append((uid, row['Formula']))

print(f"Matched: {matched}/{len(target)}")
if unmatched_uids:
    print(f"\nUnmatched ({len(unmatched_uids)}):")
    for uid, formula in unmatched_uids[:15]:
        print(f"  {formula} - {uid}")
if ambiguous_uids:
    print(f"\nAmbiguous prefix matches, left unmatched ({len(ambiguous_uids)}):")
    for uid, formula, hits in ambiguous_uids[:15]:
        print(f"  {formula} - {uid} -> {', '.join(hits)}")

# Show a few matches for sanity check
print(f"\nSample matches:")
for uid, info in list(compound_dirs.items())[:3]:
    print(f"  UID: {uid} -> {info['folder']}")
    print(f"    vasprun: {os.path.basename(info['vasprun'])}")
    print(f"    exists: {os.path.exists(info['vasprun'])}")

In [None]:
# ============================================================
# CORE FUNCTIONS
# ============================================================

def get_dos_array(dos_obj):
    """Return the density values of *dos_obj* as a single array.

    When the object carries a spin-down channel it is added to the spin-up
    channel; otherwise the spin-up array is returned as-is.
    """
    channels = dos_obj.densities
    spin_up = channels[Spin.up]
    if Spin.down not in channels:
        return spin_up
    return spin_up + channels[Spin.down]

# np.trapezoid only exists from NumPy 2.0 on; older releases provide np.trapz.
_trapz = getattr(np, "trapezoid", None) or np.trapz


def integrate_window(energies, dos_vals, window):
    """
    Integrate a DOS curve over the energy window ``[window[0], window[1]]``.

    Uses the trapezoidal rule, with the two window edges added as extra nodes.
    The DOS at each edge is linearly interpolated from the grid. Using only the
    grid points inside the window would drop the partial intervals at both
    ends, and would give 0 for any window narrower than the grid spacing.

    Parameters
    ----------
    energies : array-like
        Energy grid. Assumed 1-D and sorted ascending (required by np.interp).
    dos_vals : array-like
        DOS values on ``energies`` (same length).
    window : sequence of two floats
        (lower, upper) integration bounds, same units as ``energies``. The
        window is clipped to the sampled energy range.

    Returns
    -------
    float
        The integral, or 0.0 if the (clipped) window is empty or the grid has
        fewer than two points.
    """
    energies = np.asarray(energies, dtype=float)
    dos_vals = np.asarray(dos_vals, dtype=float)
    if energies.size < 2:
        return 0.0

    # Nothing outside the sampled range can be integrated.
    lo = max(float(window[0]), energies[0])
    hi = min(float(window[1]), energies[-1])
    if hi <= lo:
        return 0.0

    # Strict inequalities: the edges themselves are added explicitly below.
    inside = (energies > lo) & (energies < hi)
    e_w = np.concatenate(([lo], energies[inside], [hi]))
    d_w = np.concatenate((
        np.interp([lo], energies, dos_vals),
        dos_vals[inside],
        np.interp([hi], energies, dos_vals),
    ))
    return float(_trapz(d_w, e_w))

from functools import lru_cache


@lru_cache(maxsize=2)
def _load_complete_dos(vasprun_path):
    """Parse *vasprun_path* and return ``(complete_dos, efermi)``.

    Memoised because extract_contributions is called once per energy window
    for the same file; without the cache, the whole XML is re-parsed for every
    window. The cache holds only the last two files, since a CompleteDos keeps
    every projected DOS of the structure. Call ``_load_complete_dos.cache_clear()``
    if a file changes on disk during the session.
    """
    vr = Vasprun(vasprun_path, parse_dos=True, parse_eigen=False)
    return vr.complete_dos, vr.efermi


def extract_contributions(vasprun_path, window_size):
    """
    Integrate element- and orbital-resolved DOS at the band edges of one compound.

    The VBM window is ``[VBM - window_size, VBM]``. The CBM window is
    ``[CBM, CBM + window_size]``. Both edges come from
    ``CompleteDos.get_cbm_vbm()``. Only s, p and d projections are used, so any
    f contribution is ignored and all shares are relative to the s+p+d total.

    Parameters
    ----------
    vasprun_path : str
        Path to a vasprun.xml containing DOS data.
    window_size : float
        Window width, in the units of the DOS energy grid.

    Returns
    -------
    dict
        ``{element_symbol: {...}}``. Each inner dict has these keys:

        - 'mass', 'Z'
        - 's_VBM', 'd_VBM', 's_CBM', 'd_CBM': raw integrated DOS in the window.
        - '{s,p,d}_VBM_norm', '{s,p,d}_CBM_norm': share of the s+p+d total
          at that edge (0 when the total is 0).
        - 'w_VBM', 'w_CBM': element share, i.e. the sum of its three norms.
        - 'p_VBM', 'p_CBM': the *normalized* p share. The raw p integral
          first stored under this key is overwritten.
    """
    cdos, e_fermi = _load_complete_dos(vasprun_path)

    cbm, vbm = cdos.get_cbm_vbm()  # returns (cbm, vbm); call it once

    # Put the grid and both edges on the same E_F-relative reference.
    energies = cdos.energies - e_fermi
    vbm_s = vbm - e_fermi
    cbm_s = cbm - e_fermi
    vbm_win = [vbm_s - window_size, vbm_s]
    cbm_win = [cbm_s, cbm_s + window_size]

    orbitals = ('s', 'p', 'd')
    orbital_map = {'s': OrbitalType.s, 'p': OrbitalType.p, 'd': OrbitalType.d}

    # Raw integrated contributions per element and orbital
    raw = {}
    for element in cdos.get_element_dos():
        # A Species key stringifies with its oxidation state (e.g. "Fe2+"),
        # which Element() rejects; use the bare symbol when available.
        el = getattr(element, 'symbol', str(element))
        elem = Element(el)
        raw[el] = {'mass': float(elem.atomic_mass), 'Z': elem.Z}

        spd = cdos.get_element_spd_dos(element)
        for o in orbitals:
            orb_type = orbital_map[o]
            if orb_type in spd:
                dos_vals = get_dos_array(spd[orb_type])
                raw[el][f'{o}_VBM'] = integrate_window(energies, dos_vals, vbm_win)
                raw[el][f'{o}_CBM'] = integrate_window(energies, dos_vals, cbm_win)
            else:
                raw[el][f'{o}_VBM'] = 0.0
                raw[el][f'{o}_CBM'] = 0.0

    # Totals per band edge (computed before any key is overwritten below)
    total_vbm = sum(raw[el][f'{o}_VBM'] for el in raw for o in orbitals)
    total_cbm = sum(raw[el][f'{o}_CBM'] for el in raw for o in orbitals)

    for el in raw:
        for o in orbitals:
            raw[el][f'{o}_VBM_norm'] = raw[el][f'{o}_VBM'] / total_vbm if total_vbm > 0 else 0
            raw[el][f'{o}_CBM_norm'] = raw[el][f'{o}_CBM'] / total_cbm if total_cbm > 0 else 0
        # Element-level share (sum of its orbital shares)
        raw[el]['w_VBM'] = sum(raw[el][f'{o}_VBM_norm'] for o in orbitals)
        raw[el]['w_CBM'] = sum(raw[el][f'{o}_CBM_norm'] for o in orbitals)
        # p-orbital share of this element. This deliberately replaces the raw p
        # integral, because compute_descriptors reads 'p_VBM'/'p_CBM' as shares.
        raw[el]['p_VBM'] = raw[el]['p_VBM_norm']
        raw[el]['p_CBM'] = raw[el]['p_CBM_norm']

    return raw

def compute_descriptors(raw):
    """
    Build the A-F descriptor set from one compound's contribution dict.

    ``raw`` maps element symbol -> dict holding at least 'mass', 'w_VBM',
    'w_CBM', 'p_VBM' and 'p_CBM' (the layout produced by extract_contributions).
    Elements are ranked by mass, and the heaviest two (h1, h2) feed types C, D
    and F. If fewer than two elements exist, the missing slot's descriptors
    are 0 and its symbol is reported as 'NA'.

    Returns a flat dict: descriptors A..F (VBM and CBM for each), followed by
    heavy-element metadata.
    """
    symbols = list(raw)
    by_mass = sorted(symbols, key=lambda s: raw[s]['mass'], reverse=True)
    heavy1 = by_mass[0] if by_mass else None
    heavy2 = by_mass[1] if len(by_mass) > 1 else None
    heavies = (('h1', heavy1), ('h2', heavy2))
    edges = ('VBM', 'CBM')

    def weighted_mass(sym, key):
        # weight `key` of element `sym`, scaled by its atomic mass
        return raw[sym][key] * raw[sym]['mass']

    out = {}

    # A: sum over elements of (total weight x mass)
    for edge in edges:
        out[f'A_WM_{edge}'] = sum(weighted_mass(s, f'w_{edge}') for s in symbols)

    # B: same sum, but using only the p-orbital weight
    for edge in edges:
        out[f'B_WMp_{edge}'] = sum(weighted_mass(s, f'p_{edge}') for s in symbols)

    # C: each heavy element's fraction of the B sum (B is the denominator)
    for tag, sym in heavies:
        for edge in edges:
            denom = out[f'B_WMp_{edge}']
            out[f'C_pfrac_{tag}_{edge}'] = (
                weighted_mass(sym, f'p_{edge}') / denom if sym and denom > 0 else 0
            )

    # D: total weight x mass of each heavy element
    for tag, sym in heavies:
        for edge in edges:
            out[f'D_wm_{tag}_{edge}'] = weighted_mass(sym, f'w_{edge}') if sym else 0

    # E: total p-orbital share at each edge
    for edge in edges:
        out[f'E_pfrac_{edge}'] = sum(raw[s][f'p_{edge}'] for s in symbols)

    # F: p weight x mass of each heavy element
    for tag, sym in heavies:
        for edge in edges:
            out[f'F_p{tag}_{edge}'] = weighted_mass(sym, f'p_{edge}') if sym else 0

    # Metadata
    for rank, sym in ((1, heavy1), (2, heavy2)):
        out[f'heavy{rank}_el'] = sym if sym else 'NA'
    for rank, sym in ((1, heavy1), (2, heavy2)):
        out[f'heavy{rank}_mass'] = raw[sym]['mass'] if sym else 0
    out['n_elements'] = len(symbols)

    return out

In [None]:
# ============================================================
# BATCH PROCESS ALL COMPOUNDS
# ============================================================

all_results = {w: [] for w in WINDOWS}
failed = []

for idx, row in target.iterrows():
    uid = row['uid']
    formula = row['Formula']
    alpha_R = row['alpha_R_max']

    if uid not in compound_dirs:
        failed.append({'uid': uid, 'formula': formula, 'reason': 'no_folder'})
        continue

    vasprun_path = compound_dirs[uid]['vasprun']

    print(f"[{idx+1}/{len(target)}] {formula} ({uid[:12]}...)", end=" ")

    # Build the rows for every window first, and commit them only if all
    # windows succeed. If a later window fails, the earlier windows must not
    # keep this compound. Otherwise the per-window tables (and the
    # multi-window merge) would disagree on which compounds they contain.
    try:
        rows = []
        for w in WINDOWS:
            desc = compute_descriptors(extract_contributions(vasprun_path, w))
            desc['uid'] = uid
            desc['Formula'] = formula
            desc['alpha_R'] = alpha_R
            desc['window'] = w
            rows.append(desc)
    except Exception as e:  # one bad vasprun must not abort the whole batch
        failed.append({'uid': uid, 'formula': formula, 'reason': str(e)[:80]})
        print(f"FAILED: {str(e)[:60]}")
        continue

    for w, desc in zip(WINDOWS, rows):
        all_results[w].append(desc)
    print("OK")

print(f"\n{'='*50}")
print(f"Processed: {len(all_results[WINDOWS[0]])} compounds")
print(f"Failed: {len(failed)}")
if failed:
    print("Failed compounds:")
    for f in failed[:10]:
        print(f"  {f['formula']} ({f['uid'][:12]}): {f['reason']}")
    if len(failed) > 10:
        print(f"  ... and {len(failed) - 10} more")

In [None]:
# ============================================================
# BUILD DATAFRAMES & SAVE CSVs
# ============================================================


def _wtag(w):
    """Tag used in file/column names for a window: 0.05 -> '005', 0.5 -> '05', 1.0 -> '10'."""
    return str(w).replace('.', '')


def _save_csv(df, csv_name):
    """Write *df* (without index) to OUTPUT_DIR/csv_name and record the name in saved_csvs."""
    df.to_csv(os.path.join(OUTPUT_DIR, csv_name), index=False)
    saved_csvs.append(csv_name)


# Descriptor type -> feature columns
type_cols = {
    'A': ['A_WM_VBM', 'A_WM_CBM'],
    'B': ['B_WMp_VBM', 'B_WMp_CBM'],
    'C': ['C_pfrac_h1_VBM', 'C_pfrac_h1_CBM', 'C_pfrac_h2_VBM', 'C_pfrac_h2_CBM'],
    'D': ['D_wm_h1_VBM', 'D_wm_h1_CBM', 'D_wm_h2_VBM', 'D_wm_h2_CBM'],
    'E': ['E_pfrac_VBM', 'E_pfrac_CBM'],
    'F': ['F_ph1_VBM', 'F_ph1_CBM', 'F_ph2_VBM', 'F_ph2_CBM'],
}
meta_cols = ['uid', 'Formula', 'alpha_R', 'heavy1_el', 'heavy2_el', 'heavy1_mass', 'heavy2_mass', 'n_elements']
descriptor_cols = [c for cols in type_cols.values() for c in cols]
all_feature_cols = meta_cols + descriptor_cols

saved_csvs = []

# --- Per-window CSVs: one per descriptor type, plus an ALL file ---
# (For w = 1.0 this also writes desc_E_w10.csv; no separate E-only pass is needed.)
window_dfs = {}
for w in WINDOWS:
    if not all_results[w]:
        # pd.DataFrame([]) has no columns, so the selections below would raise KeyError.
        print(f"WARNING: no results for window {w} eV - skipping its CSVs")
        continue
    df = pd.DataFrame(all_results[w])
    window_dfs[w] = df
    w_str = _wtag(w)
    for t_name, t_cols in type_cols.items():
        _save_csv(df[meta_cols + t_cols], f"desc_{t_name}_w{w_str}.csv")
    _save_csv(df[all_feature_cols], f"desc_ALL_w{w_str}.csv")

# --- Multi-window combined: one row per uid, descriptors suffixed with _w<tag> ---
if window_dfs and len(window_dfs) == len(WINDOWS):
    dfs_w = {
        w: window_dfs[w][['uid', 'Formula', 'alpha_R'] + descriptor_cols].rename(
            columns={c: f"{c}_w{_wtag(w)}" for c in descriptor_cols}
        )
        for w in WINDOWS
    }
    df_multi = dfs_w[WINDOWS[0]]
    for w in WINDOWS[1:]:
        df_multi = df_multi.merge(dfs_w[w].drop(columns=['Formula', 'alpha_R']), on='uid')
    _save_csv(df_multi, "desc_ALL_multiwindow.csv")
else:
    print("WARNING: not every window has results - skipping desc_ALL_multiwindow.csv")

# --- Type E combo: 0.5 + 1.0 eV windows side by side (4 features) ---
if 0.5 in window_dfs and 1.0 in window_dfs:
    e_cols = ['E_pfrac_VBM', 'E_pfrac_CBM']
    df_e05 = window_dfs[0.5][['uid', 'Formula', 'alpha_R'] + e_cols].rename(
        columns={'E_pfrac_VBM': 'E_pfrac_VBM_w05', 'E_pfrac_CBM': 'E_pfrac_CBM_w05'})
    df_e10_slim = window_dfs[1.0][['uid'] + e_cols].rename(
        columns={'E_pfrac_VBM': 'E_pfrac_VBM_w10', 'E_pfrac_CBM': 'E_pfrac_CBM_w10'})
    df_e_combo = df_e05.merge(df_e10_slim, on='uid')
    _save_csv(df_e_combo, 'desc_E_w05_w10.csv')
    print(f'\nE combo (0.5+1.0) shape: {df_e_combo.shape}')
    print(df_e_combo.head())

print(f"\nSaved {len(saved_csvs)} CSV files to {OUTPUT_DIR}:")
for name in sorted(saved_csvs):
    fpath = os.path.join(OUTPUT_DIR, name)
    size = os.path.getsize(fpath) if os.path.exists(fpath) else 0
    print(f"  {name} ({size/1024:.1f} KB)")

# ============================================================
# ENHANCED E VARIANTS: p_frac + elemental properties
# ============================================================
# Key elemental properties for SOC/Rashba:
# - Z^4: SOC scales as Z^4 (THE most physically relevant)
# - atomic_mass: heavier = more SOC
# - atomic_radius: larger atoms = more diffuse orbitals
# - electronegativity (X): controls charge transfer / dipole
# - ionization_energy: relates to orbital energy levels
#
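# For each compound, compute_elemental_features returns:
# - max_mass / max_Z / max_Z4: largest atomic mass, Z and Z^4 among its elements
# - X_diff / X_mean: electronegativity spread (max - min) and mean
# - radius_diff / radius_mean: atomic-radius spread (max - min) and mean
# - per band edge (suffix _VBM / _CBM), DOS-weighted sums over elements:
#   WM = sum(w_X * M_X), WZ4 = sum(w_X * Z_X^4), pZ4 = sum(p_X * Z_X^4),
#   Wr = sum(w_X * r_X), WX = sum(w_X * electronegativity_X)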

from pymatgen.core.periodic_table import Element as Elem

def _prop_or(value, default):
    """Return ``float(value)``, or *default* when the property is missing (None) or NaN.

    pymatgen can report missing elemental data as NaN instead of None (e.g. the
    Pauling electronegativity of noble gases). NaN is truthy, so an
    ``x if x else default`` test would let it through into the features.
    """
    if value is None:
        return default
    value = float(value)
    return default if np.isnan(value) else value


def compute_elemental_features(raw_contrib):
    """
    Compute elemental-property descriptors for one compound.

    Parameters
    ----------
    raw_contrib : dict
        Element symbol -> contribution dict (as returned by
        extract_contributions). Only the 'w_VBM', 'w_CBM', 'p_VBM' and
        'p_CBM' weights are read.

    Returns
    -------
    dict
        Unweighted features over the elements present: max_mass, max_Z,
        max_Z4, X_diff, X_mean, radius_diff, radius_mean. For each band edge
        X in (VBM, CBM), DOS-weighted sums WM_X, WZ4_X, pZ4_X, Wr_X and WX_X.
        A missing atomic radius falls back to 1.0 and a missing/NaN
        electronegativity to 2.0.

    Raises
    ------
    ValueError
        If ``raw_contrib`` is empty (the max/min features are undefined).
    """
    elements = list(raw_contrib.keys())
    if not elements:
        raise ValueError("compute_elemental_features: raw_contrib has no elements")

    # Gather properties
    props = {}
    for el in elements:
        e = Elem(el)
        props[el] = {
            'Z': e.Z,
            'Z4': e.Z ** 4,
            'mass': float(e.atomic_mass),
            'radius': _prop_or(e.atomic_radius, 1.0),
            'X': _prop_or(e.X, 2.0),
        }

    feat = {}

    # --- Raw elemental (no DOS weighting) ---
    masses = [props[el]['mass'] for el in elements]
    feat['max_mass'] = max(masses)
    feat['max_Z'] = max(props[el]['Z'] for el in elements)
    feat['max_Z4'] = max(props[el]['Z4'] for el in elements)
    Xs = [props[el]['X'] for el in elements]
    feat['X_diff'] = max(Xs) - min(Xs)  # electronegativity difference
    feat['X_mean'] = np.mean(Xs)
    radii = [props[el]['radius'] for el in elements]
    feat['radius_diff'] = max(radii) - min(radii)
    feat['radius_mean'] = np.mean(radii)

    # --- DOS-weighted elemental ---
    for band in ['VBM', 'CBM']:
        w_key = f'w_{band}'
        p_key = f'p_{band}'

        # Weighted mass (WM)
        feat[f'WM_{band}'] = sum(raw_contrib[el][w_key] * props[el]['mass'] for el in elements)
        # Weighted Z^4
        feat[f'WZ4_{band}'] = sum(raw_contrib[el][w_key] * props[el]['Z4'] for el in elements)
        # p-weighted Z^4
        feat[f'pZ4_{band}'] = sum(raw_contrib[el][p_key] * props[el]['Z4'] for el in elements)
        # Weighted radius
        feat[f'Wr_{band}'] = sum(raw_contrib[el][w_key] * props[el]['radius'] for el in elements)
        # Weighted electronegativity
        feat[f'WX_{band}'] = sum(raw_contrib[el][w_key] * props[el]['X'] for el in elements)

    return feat

# Recompute for 0.5 and 1.0 windows with elemental features
print('\nComputing enhanced E variants with elemental properties...')

# Window-independent columns: kept without a window suffix in the combined file.
shared = ['uid', 'Formula', 'alpha_R', 'max_mass', 'max_Z', 'max_Z4',
          'X_diff', 'X_mean', 'radius_diff', 'radius_mean']
enhanced_dfs = {}  # window tag -> non-empty enhanced DataFrame (merged below, in memory)

for w, w_str in [(0.5, 'w05'), (1.0, 'w10')]:
    if w not in all_results or not all_results[w]:
        continue
    enhanced_rows = []
    enh_failed = []
    for entry in all_results[w]:
        uid = entry['uid']
        if uid not in compound_dirs:
            continue
        try:
            raw = extract_contributions(compound_dirs[uid]['vasprun'], w)
            elem_feat = compute_elemental_features(raw)
        except Exception as exc:
            # Best effort: skip this compound, but report it instead of dropping it silently.
            enh_failed.append((entry['Formula'], uid, str(exc)[:60]))
            continue
        row = {
            'uid': uid,
            'Formula': entry['Formula'],
            'alpha_R': entry['alpha_R'],
            'E_pfrac_VBM': entry['E_pfrac_VBM'],
            'E_pfrac_CBM': entry['E_pfrac_CBM'],
        }
        row.update(elem_feat)
        enhanced_rows.append(row)

    if enh_failed:
        print(f'  {w_str}: {len(enh_failed)} compounds skipped')
        for formula, uid, reason in enh_failed[:10]:
            print(f'    {formula} ({uid[:12]}): {reason}')

    df_enh = pd.DataFrame(enhanced_rows)
    csv_name = f'desc_E_enhanced_{w_str}.csv'
    df_enh.to_csv(os.path.join(OUTPUT_DIR, csv_name), index=False)
    saved_csvs.append(csv_name)
    print(f'  {csv_name}: {df_enh.shape}')
    if enhanced_rows:
        enhanced_dfs[w_str] = df_enh

# Combined 0.5 + 1.0 enhanced. Only built when both windows produced rows.
# Merging from memory avoids re-reading CSVs that may be missing or empty.
if 'w05' in enhanced_dfs and 'w10' in enhanced_dfs:
    df_05 = enhanced_dfs['w05']
    df_10 = enhanced_dfs['w10']
    # Rename window-dependent columns to avoid collisions
    df_05r = df_05.rename(columns={c: f'{c}_w05' for c in df_05.columns if c not in shared})
    df_10r = df_10.rename(columns={c: f'{c}_w10' for c in df_10.columns if c not in shared})
    df_combo = df_05r.merge(
        df_10r.drop(columns=[c for c in shared if c in df_10r.columns and c != 'uid']),
        on='uid',
    )
    csv_name = 'desc_E_enhanced_w05_w10.csv'
    df_combo.to_csv(os.path.join(OUTPUT_DIR, csv_name), index=False)
    saved_csvs.append(csv_name)
    print(f'  {csv_name}: {df_combo.shape}')

print(f'\nTotal CSVs saved: {len(saved_csvs)}')
for name in sorted(saved_csvs):
    print(f'  {name}')


In [None]:
# ============================================================
# QUICK SANITY CHECK
# ============================================================
# Inspect the combined CSV of the first configured window. The file name is
# derived from WINDOWS with the same tag rule used when saving, so it stays
# valid if the window list changes.
first_window_csv = f"desc_ALL_w{str(WINDOWS[0]).replace('.', '')}.csv"
df_check = pd.read_csv(os.path.join(OUTPUT_DIR, first_window_csv))
print(f"Shape: {df_check.shape}")
print("\nFirst 5 rows:")
print(df_check.head())
print("\nDescriptor stats:")
feature_cols = [c for c in df_check.columns if c.startswith(('A_', 'B_', 'C_', 'D_', 'E_', 'F_'))]
print(df_check[feature_cols].describe().round(3))
print(f"\nalpha_R range: [{df_check['alpha_R'].min():.3f}, {df_check['alpha_R'].max():.3f}]")