# Notebook 1: Descriptor Extraction
Extracts orbital-weight-based descriptors from `vasprun.xml` for all Rashba compounds.

**Descriptor Types:**
- **A** `WM_total`: sum(w_X * M_X) at VBM/CBM
- **B** `WM_p_only`: sum(p_X * M_X) using only p-orbital contributions
- **C** `WM_p_frac`: p_i*M_i / sum(p_j*M_j) for top 2 heaviest elements
- **D** `WM_indiv`: w_heavy1 * M_heavy1, w_heavy2 * M_heavy2 individually
- **E** `p_frac`: total p-orbital fraction
- **F** `p_heavy`: p_frac_of_heaviest * M_heaviest for top 2

**Windows:** 0.05, 0.1, 0.5 eV

**Target:** max(Rashba_parameter) per UID from rashba.csv

**Output CSVs:** one per (type, window) combo + combined CSVs

In [1]:
import numpy as np
import pandas as pd
import os
import glob
import warnings
from pathlib import Path
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.core import Spin, OrbitalType
from pymatgen.core.periodic_table import Element

warnings.filterwarnings('ignore')

In [2]:
# ============================================================
# PATHS - ADJUST THESE
# ============================================================
BASE_DIR = r"C:\Users\AbCMS_Lab\Desktop\Keshav-DDP"
RASHBA_CSV = os.path.join(BASE_DIR, "Data", "rashba.csv")
RASHBA_DIR = os.path.join(BASE_DIR, "Inverse-design", "rashba")
OUTPUT_DIR = os.path.join(BASE_DIR, "Weight-contribution", "contribution-model")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Windows
WINDOWS = [0.05, 0.1, 0.5]

print(f"Rashba CSV: {RASHBA_CSV}")
print(f"Rashba compounds dir: {RASHBA_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

Rashba CSV: C:\Users\AbCMS_Lab\Desktop\Keshav-DDP\Data\rashba.csv
Rashba compounds dir: C:\Users\AbCMS_Lab\Desktop\Keshav-DDP\Inverse-design\rashba
Output dir: C:\Users\AbCMS_Lab\Desktop\Keshav-DDP\Weight-contribution\contribution-model


In [3]:
# ============================================================
# LOAD RASHBA CSV & GET TARGET (max alpha_R per UID)
# ============================================================
df_rashba = pd.read_csv(RASHBA_CSV)
print(f"Total rows in rashba.csv: {len(df_rashba)}")
print(f"Unique UIDs: {df_rashba['uid'].nunique()}")
print(f"Columns: {list(df_rashba.columns)}")

# Max Rashba parameter per UID
target = df_rashba.groupby('uid')['Rashba_parameter'].max().reset_index()
target.columns = ['uid', 'alpha_R_max']

# Also keep formula for reference
uid_formula = df_rashba.groupby('uid')['Formula'].first().reset_index()
target = target.merge(uid_formula, on='uid')

print(f"\nTarget: {len(target)} compounds with max alpha_R")
print(target.head())

Total rows in rashba.csv: 205
Unique UIDs: 99
Columns: ['Formula', 'uid', 'spacegroup', 'ehull', 'bandgap', 'band', 'kpath', 'Rashba_parameter', 'SS', 'dE', 'anticrossing']

Target: 99 compounds with max alpha_R
            uid  alpha_R_max  Formula
0  001e03f2c095        3.288     SSeW
1  03bcf7dcdaf2        4.804   Sn2Te2
2  04fdd7d1ec5c        1.018   ClSbTe
3  05a06afa3b20        1.643  WMo3Se8
4  0b7696e1f4c9        1.756  CrW3Se8


In [4]:
# ============================================================
# FIND VASPRUN FILES FOR EACH UID
# ============================================================
compound_dirs = {}
for folder in os.listdir(RASHBA_DIR):
    folder_path = os.path.join(RASHBA_DIR, folder)
    if not os.path.isdir(folder_path):
        continue
    # Extract UID from folder name (e.g., ISbSe-343d2125478e -> 343d2125478e)
    parts = folder.split('-')
    if len(parts) >= 2:
        uid_candidate = parts[-1]
        # Find vasprun.xml
        vasprun_pattern = os.path.join(folder_path, "**", "vasprun.xml")
        vasprun_files = glob.glob(vasprun_pattern, recursive=True)
        if vasprun_files:
            compound_dirs[uid_candidate] = {
                'folder': folder,
                'vasprun': vasprun_files[0]
            }

print(f"Found {len(compound_dirs)} compound directories with vasprun.xml")

# Match with target UIDs
# UIDs in CSV might be partial matches with folder names
matched = 0
unmatched_uids = []
for _, row in target.iterrows():
    uid = row['uid']
    # Try direct match first
    if uid in compound_dirs:
        matched += 1
    else:
        # Try partial match (uid might be substring of folder uid)
        found = False
        for folder_uid in compound_dirs:
            if uid.startswith(folder_uid) or folder_uid.startswith(uid):
                compound_dirs[uid] = compound_dirs[folder_uid]
                matched += 1
                found = True
                break
        if not found:
            unmatched_uids.append(uid)

print(f"Matched: {matched}/{len(target)}")
if unmatched_uids:
    print(f"Unmatched UIDs (first 10): {unmatched_uids[:10]}")

Found 0 compound directories with vasprun.xml
Matched: 0/99
Unmatched UIDs (first 10): ['001e03f2c095', '03bcf7dcdaf2', '04fdd7d1ec5c', '05a06afa3b20', '0b7696e1f4c9', '0c0fbdaf8f4a', '0f02957b17cf', '114b3382699c', '11db0908d9ef', '159f028a85d0']


In [5]:
# ============================================================
# CORE FUNCTIONS
# ============================================================

def get_dos_array(dos_obj):
    """Extract DOS values, summing spin channels for SOC."""
    if Spin.down in dos_obj.densities:
        return dos_obj.densities[Spin.up] + dos_obj.densities[Spin.down]
    return dos_obj.densities[Spin.up]

def integrate_window(energies, dos_vals, window):
    """Integrate DOS in energy window using trapezoidal rule."""
    mask = (energies >= window[0]) & (energies <= window[1])
    e_w = energies[mask]
    d_w = dos_vals[mask]
    if len(e_w) < 2:
        return 0.0
    return np.trapezoid(d_w, e_w)

def extract_contributions(vasprun_path, window_size):
    """
    Parse vasprun.xml and extract orbital/atomic contributions at VBM/CBM.
    Returns dict with per-element per-orbital contributions (normalized).
    """
    vr = Vasprun(vasprun_path, parse_dos=True, parse_eigen=False)
    cdos = vr.complete_dos
    e_fermi = vr.efermi
    
    vbm = cdos.get_cbm_vbm()[1]
    cbm = cdos.get_cbm_vbm()[0]
    
    vbm_s = vbm - e_fermi
    cbm_s = cbm - e_fermi
    energies = cdos.energies - e_fermi
    
    vbm_win = [vbm_s - window_size, vbm_s]
    cbm_win = [cbm_s, cbm_s + window_size]
    
    element_dos = cdos.get_element_dos()
    orbital_map = {'s': OrbitalType.s, 'p': OrbitalType.p, 'd': OrbitalType.d}
    
    # Raw contributions
    raw = {}
    for element in element_dos:
        el = str(element)
        mass = Element(el).atomic_mass
        Z = Element(el).Z
        raw[el] = {'mass': float(mass), 'Z': Z}
        
        spd = cdos.get_element_spd_dos(element)
        for orb_str in ['s', 'p', 'd']:
            orb_type = orbital_map[orb_str]
            if orb_type in spd:
                dos_vals = get_dos_array(spd[orb_type])
                raw[el][f'{orb_str}_VBM'] = integrate_window(energies, dos_vals, vbm_win)
                raw[el][f'{orb_str}_CBM'] = integrate_window(energies, dos_vals, cbm_win)
            else:
                raw[el][f'{orb_str}_VBM'] = 0.0
                raw[el][f'{orb_str}_CBM'] = 0.0
    
    # Total per band edge
    total_vbm = sum(raw[el][f'{o}_VBM'] for el in raw for o in ['s','p','d'])
    total_cbm = sum(raw[el][f'{o}_CBM'] for el in raw for o in ['s','p','d'])
    
    # Normalized
    for el in raw:
        for o in ['s','p','d']:
            raw[el][f'{o}_VBM_norm'] = raw[el][f'{o}_VBM'] / total_vbm if total_vbm > 0 else 0
            raw[el][f'{o}_CBM_norm'] = raw[el][f'{o}_CBM'] / total_cbm if total_cbm > 0 else 0
        # Element-level (sum of orbitals)
        raw[el]['w_VBM'] = sum(raw[el][f'{o}_VBM_norm'] for o in ['s','p','d'])
        raw[el]['w_CBM'] = sum(raw[el][f'{o}_CBM_norm'] for o in ['s','p','d'])
        # p-orbital fraction for this element
        raw[el]['p_VBM'] = raw[el]['p_VBM_norm']
        raw[el]['p_CBM'] = raw[el]['p_CBM_norm']
    
    return raw

def compute_descriptors(raw):
    """
    From raw contributions dict, compute all descriptor types.
    Returns flat dict of descriptors.
    """
    desc = {}
    elements = list(raw.keys())
    
    # Sort by mass (heaviest first)
    sorted_els = sorted(elements, key=lambda x: raw[x]['mass'], reverse=True)
    heavy1 = sorted_els[0] if len(sorted_els) >= 1 else None
    heavy2 = sorted_els[1] if len(sorted_els) >= 2 else None
    
    # --- Type A: WM_total = sum(w_X * M_X) ---
    desc['A_WM_VBM'] = sum(raw[el]['w_VBM'] * raw[el]['mass'] for el in elements)
    desc['A_WM_CBM'] = sum(raw[el]['w_CBM'] * raw[el]['mass'] for el in elements)
    
    # --- Type B: WM_p_only = sum(p_X * M_X) ---
    desc['B_WMp_VBM'] = sum(raw[el]['p_VBM'] * raw[el]['mass'] for el in elements)
    desc['B_WMp_CBM'] = sum(raw[el]['p_CBM'] * raw[el]['mass'] for el in elements)
    
    # --- Type C: WM_p_frac = p_i*M_i / sum(p_j*M_j) for top 2 heaviest ---
    denom_vbm = sum(raw[el]['p_VBM'] * raw[el]['mass'] for el in elements)
    denom_cbm = sum(raw[el]['p_CBM'] * raw[el]['mass'] for el in elements)
    
    if heavy1:
        desc['C_pfrac_h1_VBM'] = (raw[heavy1]['p_VBM'] * raw[heavy1]['mass']) / denom_vbm if denom_vbm > 0 else 0
        desc['C_pfrac_h1_CBM'] = (raw[heavy1]['p_CBM'] * raw[heavy1]['mass']) / denom_cbm if denom_cbm > 0 else 0
    else:
        desc['C_pfrac_h1_VBM'] = 0
        desc['C_pfrac_h1_CBM'] = 0
    
    if heavy2:
        desc['C_pfrac_h2_VBM'] = (raw[heavy2]['p_VBM'] * raw[heavy2]['mass']) / denom_vbm if denom_vbm > 0 else 0
        desc['C_pfrac_h2_CBM'] = (raw[heavy2]['p_CBM'] * raw[heavy2]['mass']) / denom_cbm if denom_cbm > 0 else 0
    else:
        desc['C_pfrac_h2_VBM'] = 0
        desc['C_pfrac_h2_CBM'] = 0
    
    # --- Type D: WM_indiv = w_heavyN * M_heavyN ---
    if heavy1:
        desc['D_wm_h1_VBM'] = raw[heavy1]['w_VBM'] * raw[heavy1]['mass']
        desc['D_wm_h1_CBM'] = raw[heavy1]['w_CBM'] * raw[heavy1]['mass']
    else:
        desc['D_wm_h1_VBM'] = 0
        desc['D_wm_h1_CBM'] = 0
    
    if heavy2:
        desc['D_wm_h2_VBM'] = raw[heavy2]['w_VBM'] * raw[heavy2]['mass']
        desc['D_wm_h2_CBM'] = raw[heavy2]['w_CBM'] * raw[heavy2]['mass']
    else:
        desc['D_wm_h2_VBM'] = 0
        desc['D_wm_h2_CBM'] = 0
    
    # --- Type E: p_frac = total p-orbital % ---
    desc['E_pfrac_VBM'] = sum(raw[el]['p_VBM'] for el in elements)
    desc['E_pfrac_CBM'] = sum(raw[el]['p_CBM'] for el in elements)
    
    # --- Type F: p_heavy = p_frac_of_heaviest * M_heaviest ---
    if heavy1:
        desc['F_ph1_VBM'] = raw[heavy1]['p_VBM'] * raw[heavy1]['mass']
        desc['F_ph1_CBM'] = raw[heavy1]['p_CBM'] * raw[heavy1]['mass']
    else:
        desc['F_ph1_VBM'] = 0
        desc['F_ph1_CBM'] = 0
    
    if heavy2:
        desc['F_ph2_VBM'] = raw[heavy2]['p_VBM'] * raw[heavy2]['mass']
        desc['F_ph2_CBM'] = raw[heavy2]['p_CBM'] * raw[heavy2]['mass']
    else:
        desc['F_ph2_VBM'] = 0
        desc['F_ph2_CBM'] = 0
    
    # --- Metadata ---
    desc['heavy1_el'] = heavy1 if heavy1 else 'NA'
    desc['heavy2_el'] = heavy2 if heavy2 else 'NA'
    desc['heavy1_mass'] = raw[heavy1]['mass'] if heavy1 else 0
    desc['heavy2_mass'] = raw[heavy2]['mass'] if heavy2 else 0
    desc['n_elements'] = len(elements)
    
    return desc

In [6]:
# ============================================================
# BATCH PROCESS ALL COMPOUNDS
# ============================================================

all_results = {w: [] for w in WINDOWS}
failed = []

for idx, row in target.iterrows():
    uid = row['uid']
    formula = row['Formula']
    alpha_R = row['alpha_R_max']
    
    if uid not in compound_dirs:
        failed.append({'uid': uid, 'formula': formula, 'reason': 'no_folder'})
        continue
    
    vasprun_path = compound_dirs[uid]['vasprun']
    
    print(f"[{idx+1}/{len(target)}] {formula} ({uid[:12]}...)", end=" ")
    
    try:
        for w in WINDOWS:
            raw = extract_contributions(vasprun_path, w)
            desc = compute_descriptors(raw)
            desc['uid'] = uid
            desc['Formula'] = formula
            desc['alpha_R'] = alpha_R
            desc['window'] = w
            all_results[w].append(desc)
        print("OK")
    except Exception as e:
        failed.append({'uid': uid, 'formula': formula, 'reason': str(e)[:80]})
        print(f"FAILED: {str(e)[:60]}")

print(f"\n{'='*50}")
print(f"Processed: {len(all_results[WINDOWS[0]])} compounds")
print(f"Failed: {len(failed)}")
if failed:
    print("Failed compounds:")
    for f in failed[:10]:
        print(f"  {f['formula']} ({f['uid'][:12]}): {f['reason']}")


Processed: 0 compounds
Failed: 99
Failed compounds:
  SSeW (001e03f2c095): no_folder
  Sn2Te2 (03bcf7dcdaf2): no_folder
  ClSbTe (04fdd7d1ec5c): no_folder
  WMo3Se8 (05a06afa3b20): no_folder
  CrW3Se8 (0b7696e1f4c9): no_folder
  ClSbSe (0c0fbdaf8f4a): no_folder
  ISbTe (0f02957b17cf): no_folder
  AsITe (114b3382699c): no_folder
  BiBrSe (11db0908d9ef): no_folder
  CrMo3Te8 (159f028a85d0): no_folder


In [7]:
# ============================================================
# BUILD DATAFRAMES & SAVE CSVs
# ============================================================

# Descriptor type -> column prefixes
type_cols = {
    'A': [c for c in ['A_WM_VBM', 'A_WM_CBM']],
    'B': [c for c in ['B_WMp_VBM', 'B_WMp_CBM']],
    'C': [c for c in ['C_pfrac_h1_VBM', 'C_pfrac_h1_CBM', 'C_pfrac_h2_VBM', 'C_pfrac_h2_CBM']],
    'D': [c for c in ['D_wm_h1_VBM', 'D_wm_h1_CBM', 'D_wm_h2_VBM', 'D_wm_h2_CBM']],
    'E': [c for c in ['E_pfrac_VBM', 'E_pfrac_CBM']],
    'F': [c for c in ['F_ph1_VBM', 'F_ph1_CBM', 'F_ph2_VBM', 'F_ph2_CBM']],
}
meta_cols = ['uid', 'Formula', 'alpha_R', 'heavy1_el', 'heavy2_el', 'heavy1_mass', 'heavy2_mass', 'n_elements']

saved_csvs = []

for w in WINDOWS:
    df = pd.DataFrame(all_results[w])
    w_str = str(w).replace('.', '')
    
    # Individual type CSVs
    for t_name, t_cols in type_cols.items():
        cols = meta_cols + t_cols
        csv_name = f"desc_{t_name}_w{w_str}.csv"
        csv_path = os.path.join(OUTPUT_DIR, csv_name)
        df[cols].to_csv(csv_path, index=False)
        saved_csvs.append(csv_name)
    
    # Combined: all types for this window
    all_feature_cols = meta_cols + [c for cols in type_cols.values() for c in cols]
    csv_name = f"desc_ALL_w{w_str}.csv"
    csv_path = os.path.join(OUTPUT_DIR, csv_name)
    df[all_feature_cols].to_csv(csv_path, index=False)
    saved_csvs.append(csv_name)

# Also: some useful combos across windows
# Combo 1: Type A across all windows
combo_rows = []
for w in WINDOWS:
    df_w = pd.DataFrame(all_results[w])
    w_str = str(w).replace('.', '')
    rename = {c: f"{c}_w{w_str}" for c in ['A_WM_VBM','A_WM_CBM','B_WMp_VBM','B_WMp_CBM',
              'E_pfrac_VBM','E_pfrac_CBM','D_wm_h1_VBM','D_wm_h1_CBM',
              'F_ph1_VBM','F_ph1_CBM']}
    for old, new in rename.items():
        if old in df_w.columns:
            df_w[new] = df_w[old]

# Multi-window combined
dfs_w = {}
for w in WINDOWS:
    w_str = str(w).replace('.', '')
    df_w = pd.DataFrame(all_results[w])
    feature_cols_w = [c for cols in type_cols.values() for c in cols]
    df_w = df_w[['uid', 'Formula', 'alpha_R'] + feature_cols_w]
    df_w = df_w.rename(columns={c: f"{c}_w{w_str}" for c in feature_cols_w})
    dfs_w[w] = df_w

# Merge all windows
df_multi = dfs_w[WINDOWS[0]]
for w in WINDOWS[1:]:
    df_multi = df_multi.merge(dfs_w[w].drop(columns=['Formula', 'alpha_R']), on='uid')

csv_name = "desc_ALL_multiwindow.csv"
csv_path = os.path.join(OUTPUT_DIR, csv_name)
df_multi.to_csv(csv_path, index=False)
saved_csvs.append(csv_name)

print(f"\nSaved {len(saved_csvs)} CSV files to {OUTPUT_DIR}:")
for name in sorted(saved_csvs):
    fpath = os.path.join(OUTPUT_DIR, name)
    size = os.path.getsize(fpath) if os.path.exists(fpath) else 0
    print(f"  {name} ({size/1024:.1f} KB)")

KeyError: "None of [Index(['uid', 'Formula', 'alpha_R', 'heavy1_el', 'heavy2_el', 'heavy1_mass',\n       'heavy2_mass', 'n_elements', 'A_WM_VBM', 'A_WM_CBM'],\n      dtype='str')] are in the [columns]"

In [None]:
# ============================================================
# QUICK SANITY CHECK
# ============================================================
df_check = pd.read_csv(os.path.join(OUTPUT_DIR, f"desc_ALL_w005.csv"))
print(f"Shape: {df_check.shape}")
print(f"\nFirst 5 rows:")
print(df_check.head())
print(f"\nDescriptor stats:")
feature_cols = [c for c in df_check.columns if c.startswith(('A_','B_','C_','D_','E_','F_'))]
print(df_check[feature_cols].describe().round(3))
print(f"\nalpha_R range: [{df_check['alpha_R'].min():.3f}, {df_check['alpha_R'].max():.3f}]")