# Fe₂O₃ (Hematite) Magnetic Oxide Study with SHALOM

This notebook demonstrates SHALOM's capabilities for **magnetic transition metal
oxides** — a challenging class of materials requiring:
- Spin-polarized DFT (nspin=2)
- GGA+U correction for Fe 3d electrons
- High ecutwfc (90 Ry for Fe from SSSP)
- Careful treatment of starting magnetization

## Material: α-Fe₂O₃ (Hematite)
- **Crystal structure**: Corundum (R3̅c, #167, trigonal)
- **Lattice**: a = 5.0356 Å, c = 13.7489 Å (hexagonal setting)
- **Atoms**: 10 in primitive cell (4 Fe + 6 O), 30 in conventional cell
- **Magnetic**: Antiferromagnetic below the Néel temperature Tₙ = 955 K (this notebook nonetheless starts from a ferromagnetic initialization)
- **Electronic**: Charge-transfer insulator, gap ~2.2 eV (exp)

## What We Will Compute

| Section | Analysis | Time Estimate |
|---------|----------|---------------|
| 1 | Structure creation | < 1 sec |
| 2 | Crystal symmetry | < 1 sec |
| 3 | XRD pattern | < 1 sec |
| 4 | SHALOM auto-detection inspection | < 1 sec |
| 5 | Full workflow (PRECISE, GGA+U) | ~30–45 min |
| 6 | Spin-resolved band structure | < 1 sec |
| 7 | Spin-resolved DOS | < 1 sec |
| 8 | Combined band+DOS | < 1 sec |
| 9 | Electronic analysis | < 1 sec |
| 10 | Magnetic analysis | < 1 sec |

**Total: ~45–60 minutes**

## Prerequisites

- Quantum ESPRESSO (pw.x, dos.x) in WSL
- SSSP pseudopotentials: `python -m shalom setup-qe --elements Fe,O --download`
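
Before launching a long run it can help to confirm that a pseudopotential file exists for every element. The sketch below is not part of SHALOM: `missing_pseudopotentials` is a hypothetical helper, and it assumes that UPF file names begin with the element symbol followed by `.`, `_` or `-`, as SSSP files commonly do.

```python
from pathlib import Path


def missing_pseudopotentials(pseudo_dir, elements):
    """Return the elements that have no matching UPF file in pseudo_dir.

    Assumption: file names start with the element symbol followed by
    '.', '_' or '-' (for example 'Fe.pbe-...UPF' or 'o_pbe_...UPF').
    """
    directory = Path(pseudo_dir)
    if not directory.is_dir():
        # Nothing can be found in a directory that does not exist.
        return list(elements)

    upf_names = [p.name.lower() for p in directory.iterdir()
                 if p.suffix.lower() == ".upf"]

    missing = []
    for element in elements:
        prefixes = tuple(element.lower() + sep for sep in "._-")
        if not any(name.startswith(prefixes) for name in upf_names):
            missing.append(element)
    return missing
```

Calling `missing_pseudopotentials(PSEUDO_DIR, ["Fe", "O"])` after the configuration cell should return an empty list when both files are in place.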

In [None]:
import os
import time
import warnings
from collections import Counter, defaultdict

import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore', category=DeprecationWarning)

# === Configuration ===
PSEUDO_DIR = r"C:\Users\Sejong\pseudopotentials"  # path to SSSP UPF files
OUTPUT_ROOT = os.path.expanduser("~/Desktop/shalom-tutorials/fe2o3_study")
os.makedirs(OUTPUT_ROOT, exist_ok=True)

WSL = True
NPROCS = 2

print(f"Pseudo dir:      {PSEUDO_DIR}")
print(f"Output directory: {OUTPUT_ROOT}")
print(f"WSL mode: {WSL}")
print(f"MPI processes: {NPROCS}")

## 1. Structure Creation

We build Fe₂O₃ in the corundum (R3̅c) structure using ASE's `crystal()` function.

- Fe at Wyckoff 12c: (0, 0, 0.3553)
- O at Wyckoff 18e: (0.3059, 0, 0.25)
- Hexagonal cell: a = 5.0356 Å, c = 13.7489 Å

This creates 30 atoms in the conventional hexagonal cell.
SHALOM's workflow will automatically convert to the 10-atom
primitive rhombohedral cell via seekpath.
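
The atom counts and cell volume quoted above can be checked by hand. This is a small sketch using only the lattice parameters and Wyckoff multiplicities from this section. It relies on the fact that an R-centred hexagonal cell contains three lattice points, so the primitive rhombohedral cell holds one third of the atoms.

```python
import math

# Hexagonal lattice parameters (Å) and Wyckoff multiplicities from the text
A, C = 5.0356, 13.7489
WYCKOFF_MULTIPLICITY = {"Fe": 12, "O": 18}  # sites 12c and 18e


def hexagonal_volume(a, c):
    """Volume of a hexagonal cell: V = (sqrt(3) / 2) * a^2 * c."""
    return math.sqrt(3.0) / 2.0 * a ** 2 * c


n_conventional = sum(WYCKOFF_MULTIPLICITY.values())
# R-centring: the hexagonal cell is three times the primitive cell
n_primitive = n_conventional // 3
volume = hexagonal_volume(A, C)

print(f"Conventional cell: {n_conventional} atoms")
print(f"Primitive cell:    {n_primitive} atoms")
print(f"Volume:            {volume:.2f} Å³ ({volume / n_conventional:.2f} Å³/atom)")
```

The printed atom count and volume should agree with the output of the ASE cell below.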

In [None]:
from ase.spacegroup import crystal

fe2o3 = crystal(
    symbols=['Fe', 'O'],
    basis=[(0, 0, 0.3553), (0.3059, 0, 0.25)],
    spacegroup=167,  # R-3c
    cellpar=[5.0356, 5.0356, 13.7489, 90, 90, 120],
)

print(f"Formula:       {fe2o3.get_chemical_formula()}")
print(f"Atoms:         {len(fe2o3)}")
print(f"Composition:   {dict(Counter(fe2o3.get_chemical_symbols()))}")
print(f"Cell (\u00c5):")
for i, v in enumerate(fe2o3.cell):
    print(f"  a{i+1} = [{v[0]:10.4f}, {v[1]:10.4f}, {v[2]:10.4f}]")
print(f"Volume:        {fe2o3.get_volume():.2f} \u00c5\u00b3")
print(f"Volume/atom:   {fe2o3.get_volume() / len(fe2o3):.2f} \u00c5\u00b3")

## 2. Symmetry Analysis

In [None]:
from shalom.analysis import analyze_symmetry, is_spglib_available

if is_spglib_available():
    sym = analyze_symmetry(fe2o3)
    print("=== Crystal Symmetry (Fe\u2082O\u2083 Conventional Cell) ===")
    print(f"Space group:    {sym.space_group_symbol} (#{sym.space_group_number})")
    print(f"Crystal system: {sym.crystal_system}")
    print(f"Point group:    {sym.point_group}")
    print(f"Lattice type:   {sym.lattice_type}")
    print(f"Is primitive:   {sym.is_primitive}")
    print(f"Symmetry ops:   {sym.n_operations}")

    # Wyckoff analysis per species
    wyckoff_counts = {}
    for wl, elem in zip(sym.wyckoff_letters, fe2o3.get_chemical_symbols()):
        key = f"{elem}({wl})"
        wyckoff_counts[key] = wyckoff_counts.get(key, 0) + 1
    print(f"\nWyckoff positions:")
    for site, count in sorted(wyckoff_counts.items()):
        print(f"  {site}: {count} atoms")
else:
    print("spglib not installed. Run: pip install spglib")

## 3. X-ray Diffraction Pattern
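
As a sanity check on the simulated pattern, a peak position can be estimated by hand from Bragg's law and the hexagonal d-spacing formula. This sketch is independent of SHALOM. The wavelength of 1.5418 Å is an assumed Cu Kα average; the value behind the `"CuKa"` label in `calculate_xrd` may differ slightly.

```python
import math

CU_KALPHA = 1.5418  # Å, assumed Cu Kα average wavelength


def d_hexagonal(h, k, l, a, c):
    """d-spacing for a hexagonal lattice:
    1/d^2 = (4/3) * (h^2 + h*k + k^2) / a^2 + l^2 / c^2
    """
    inv_d2 = 4.0 / 3.0 * (h * h + h * k + k * k) / a ** 2 + l * l / c ** 2
    return 1.0 / math.sqrt(inv_d2)


def two_theta(d, wavelength=CU_KALPHA):
    """Bragg's law (first order): lambda = 2 d sin(theta); returns 2θ in degrees."""
    return 2.0 * math.degrees(math.asin(wavelength / (2.0 * d)))


# (1 0 4) reflection of hematite with the lattice parameters used in this notebook
d_104 = d_hexagonal(1, 0, 4, a=5.0356, c=13.7489)
print(f"d(104) = {d_104:.4f} Å, 2θ = {two_theta(d_104):.2f}°")
```

The (1 0 4) reflection should come out with d close to 2.70 Å and 2θ a little above 33°, which is where a strong hematite peak is expected in the table printed below.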

In [None]:
from shalom.analysis import calculate_xrd, is_xrd_available
from shalom.plotting import XRDPlotter

if is_xrd_available():
    xrd_result = calculate_xrd(fe2o3, wavelength="CuKa", two_theta_range=(10, 80))

    print(f"Number of peaks: {xrd_result.n_peaks}")
    print(f"\nTop 8 peaks:")
    print(f"{'2θ':>8}  {'I':>6}  {'(h k l)':>10}  {'d (Å)':>8}")
    print("-" * 40)
    order = np.argsort(xrd_result.intensities)[::-1]
    for i in order[:8]:
        hkl = xrd_result.hkl_indices[i]
        print(f"{xrd_result.two_theta[i]:8.2f}  {xrd_result.intensities[i]:6.1f}  "
              f"({hkl[0]:2d} {hkl[1]:2d} {hkl[2]:2d})  {xrd_result.d_spacings[i]:8.4f}")

    xrd_path = os.path.join(OUTPUT_ROOT, "fe2o3_xrd.png")
    fig = XRDPlotter(xrd_result).plot(
        output_path=xrd_path,
        title="\u03b1-Fe\u2082O\u2083 (Hematite) \u2014 Simulated XRD (Cu K\u03b1)",
    )
    plt.show()
    print(f"\nSaved: {xrd_path}")
else:
    print("pymatgen not installed. Run: pip install pymatgen")

## 4. SHALOM Auto-Detection Inspection

Before running the workflow, let's examine what SHALOM auto-detects for Fe₂O₃.
This demonstrates the structure-aware configuration system:
- `nspin=2` (Fe is magnetic)
- `starting_magnetization` from MAGMOM / z_valence
- `lda_plus_u` with Hubbard U for Fe 3d (PRECISE accuracy)
- `ecutwfc=90 Ry` (from SSSP Fe)
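
The "MAGMOM / z_valence" rule can be illustrated with a few lines of arithmetic. QE's `starting_magnetization` is a fractional spin polarization between -1 and 1, so an initial moment in μB is divided by the number of valence electrons of the pseudopotential. This is a sketch of that rule only, not SHALOM's implementation, and the numbers (5.0 μB and 16 valence electrons for Fe) are illustrative assumptions; the cell below prints the values SHALOM actually uses.

```python
def starting_magnetization(magmom, z_valence):
    """Convert an initial moment (μB) into QE's starting_magnetization.

    The result is clamped to the range [-1, 1] accepted by pw.x.
    """
    if z_valence <= 0:
        raise ValueError("z_valence must be positive")
    return max(-1.0, min(1.0, magmom / z_valence))


# Illustrative values only (assumptions, not read from SHALOM)
print(starting_magnetization(5.0, 16.0))  # Fe
print(starting_magnetization(0.0, 6.0))   # O
```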

In [None]:
from shalom.backends.qe_config import (
    get_qe_preset, QECalculationType, SSSP_ELEMENTS,
)
from shalom.backends._physics import (
    AccuracyLevel, DEFAULT_MAGMOM, HUBBARD_U_VALUES,
)

# SSSP metadata for Fe and O
print("=== SSSP Metadata ===")
for el in ['Fe', 'O']:
    entry = SSSP_ELEMENTS[el]
    print(f"  {el}: ecutwfc={entry['ecutwfc']} Ry, ecutrho={entry['ecutrho']} Ry, "
          f"z_valence={entry['z_valence']}, pseudo={entry['pseudo']}")

# Magnetic data
print(f"\n=== Magnetic Settings ===")
print(f"  Fe MAGMOM:   {DEFAULT_MAGMOM.get('Fe', 'N/A')} Bohr mag")
print(f"  Fe Hubbard U: {HUBBARD_U_VALUES.get('Fe', 'N/A')}")

# Get PRECISE preset (enables GGA+U)
config = get_qe_preset(QECalculationType.SCF, accuracy=AccuracyLevel.PRECISE, atoms=fe2o3)

print(f"\n=== Auto-Detected QE Config (PRECISE) ===")
print(f"  ecutwfc:      {config.system.get('ecutwfc')} Ry")
print(f"  ecutrho:      {config.system.get('ecutrho')} Ry")
print(f"  nspin:        {config.system.get('nspin')}")
print(f"  lda_plus_u:   {config.system.get('lda_plus_u')}")

for key, val in sorted(config.system.items()):
    if 'starting_magnetization' in key or 'Hubbard_U' in key:
        print(f"  {key}: {val}")

print(f"\nNote: SHALOM auto-converts lda_plus_u/Hubbard_U to the QE 7.1+")
print(f"HUBBARD (ortho-atomic) card syntax when writing pw.x input files.")

## 5. Full Workflow (PRECISE Accuracy)

We run the 5-step workflow at **PRECISE** accuracy to enable GGA+U:

1. **vc-relax** — Cell + position optimization with spin polarization
2. **scf** — Converged ground state with GGA+U
3. **bands** — Spin-resolved band structure
4. **nscf** — Dense k-mesh for DOS
5. **dos.x** — Spin-resolved density of states

**Estimated time: ~30–45 minutes**

> **Tip:** Set `SKIP_RELAX = True` below to use the experimental structure
> directly (saves ~15–20 min). This is reasonable when lattice parameters
> are well-known from experiment.
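
For orientation, the GGA+U setting used in the SCF step ends up in the pw.x input as a `HUBBARD` card on QE 7.1 and later. The sketch below renders such a card from a small dictionary. `hubbard_card` is a hypothetical helper rather than a SHALOM function, and the 4.0 eV value is a placeholder, not the U that SHALOM applies.

```python
def hubbard_card(u_values, projector="ortho-atomic"):
    """Render a QE 7.1+ HUBBARD card.

    u_values maps a species label to (manifold, U in eV),
    for example {"Fe": ("3d", 4.0)}.
    """
    lines = [f"HUBBARD ({projector})"]
    for species, (manifold, u_ev) in u_values.items():
        lines.append(f"U {species}-{manifold} {u_ev:.2f}")
    return "\n".join(lines)


# Placeholder U value for illustration only
print(hubbard_card({"Fe": ("3d", 4.0)}))
```

Comparing this with the card in the generated `pw.in` file is a quick way to confirm which U value and projector type were actually used.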

In [None]:
from shalom.workflows import StandardWorkflow

wf_dir = os.path.join(OUTPUT_ROOT, "workflow")

SKIP_RELAX = True  # Set False to run vc-relax (~15+ min extra)

print(f"Starting 5-step QE workflow for Fe\u2082O\u2083 (PRECISE, GGA+U)...")
print(f"Skip vc-relax: {SKIP_RELAX}")
print(f"Output: {wf_dir}\n")

t0 = time.time()
wf = StandardWorkflow(
    atoms=fe2o3,
    output_dir=wf_dir,
    nprocs=NPROCS,
    wsl=WSL,
    pseudo_dir=PSEUDO_DIR,
    accuracy="precise",
    skip_relax=SKIP_RELAX,
    timeout=7200,
    dos_emin=-15.0,
    dos_emax=10.0,
)
result = wf.run()
elapsed = time.time() - t0

print(f"\n{'='*60}")
print(f"Workflow completed in {elapsed:.1f} s ({elapsed/60:.1f} min)")
print(f"{'='*60}")

for step in result["step_results"]:
    status = "OK" if step.success else "FAIL"
    t_str = f"({step.elapsed_seconds:.1f}s)" if step.elapsed_seconds > 0 else ""
    summ = f" \u2014 {step.summary}" if step.summary else ""
    print(f"  [{step.step_number}/5] {step.name:12s}  {status}  {t_str}{summ}")

fermi = result["fermi_energy"]
print(f"\nFermi energy: {fermi:.4f} eV" if fermi is not None else "\nFermi energy: not found")
print(f"Primitive cell atoms: {len(result['atoms'])}")
print(f"\nNote: With ecutwfc=90+ Ry and spin polarization, the bands step")
print(f"may take very long. If it times out, DOS analysis still works.")

## 6. Spin-Resolved Band Structure

Fe₂O₃ is a spin-polarized insulator. The band structure shows separate
spin-up (blue, solid) and spin-down (red, dashed) channels.
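
The plotting cell in this section also tidies the k-path before drawing. Where the path jumps between disconnected segments (labels containing `|`), the jump in cumulative distance is removed so that the two segments touch on the x-axis. Here is a plain-Python sketch of that step, with a hypothetical function name and toy data:

```python
def collapse_path_breaks(distances, break_indices):
    """Remove the distance jump that follows each break index.

    distances     : cumulative k-path distances, one per k-point
    break_indices : indices of k-points labelled as a discontinuity
    """
    out = list(distances)
    for idx in sorted(break_indices):
        if idx + 1 < len(out):
            gap = out[idx + 1] - out[idx]
            if gap > 0.0:
                # Shift every later point back by the size of the jump
                for j in range(idx + 1, len(out)):
                    out[j] -= gap
    return out


# Toy path: the jump from 1.0 to 2.5 after index 2 is a discontinuity
print(collapse_path_breaks([0.0, 0.5, 1.0, 2.5, 3.0], break_indices=[2]))
```

After the collapse, the point following the break shares its x-position with the break point, which is why such labels are drawn as a single tick such as `X|Y`.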

In [None]:
from shalom.backends import parse_xml_bands, find_xml_path
from shalom.plotting import BandStructurePlotter
from shalom.backends.qe_config import generate_band_kpath
from IPython.display import display, Image

bands_dir = result["calc_dirs"]["bands"]
fermi = result["fermi_energy"] or 0.0

# Try to parse bands (may fail if bands step timed out)
bs_data = None
xml_path = find_xml_path(bands_dir)
if xml_path is None:
    xml_path = find_xml_path(os.path.join(result["calc_dirs"]["scf"], "tmp"))

if xml_path:
    try:
        bs_data = parse_xml_bands(xml_path, fermi_energy=fermi)
        print(f"Spin polarized: {bs_data.is_spin_polarized}")
        print(f"K-points: {bs_data.nkpts}")
        print(f"Bands: {bs_data.nbands}")

        # Set high-symmetry labels from kpath
        calc_atoms = result["atoms"]
        kpath_cfg = generate_band_kpath(calc_atoms, npoints=40, is_2d=False)
        if kpath_cfg.kpath_labels:
            cumulative_idx = 0
            label_by_idx = {}
            if kpath_cfg.kpath_points:
                for seg_idx, (_, npts) in enumerate(kpath_cfg.kpath_points):
                    label = kpath_cfg.kpath_labels.get(seg_idx)
                    if label:
                        label_by_idx[cumulative_idx] = label
                    cumulative_idx += npts
            bs_data.high_sym_labels = label_by_idx

            # Collapse discontinuity gaps
            if len(bs_data.kpath_distances) > 1:
                dist = bs_data.kpath_distances.copy()
                for k_idx in sorted(label_by_idx):
                    if "|" in label_by_idx[k_idx] and k_idx + 1 < len(dist):
                        gap = dist[k_idx + 1] - dist[k_idx]
                        if gap > 0.0:
                            dist[k_idx + 1:] -= gap
                bs_data.kpath_distances = dist

        bands_path = os.path.join(OUTPUT_ROOT, "fe2o3_bands_spin.png")
        fig = BandStructurePlotter(bs_data).plot(
            output_path=bands_path,
            title="\u03b1-Fe\u2082O\u2083 \u2014 Spin-Resolved Band Structure",
            energy_window=(-8, 6),
            color_up="royalblue",
            color_down="crimson",
        )
        plt.show()
        print(f"Saved: {bands_path}")
    except Exception as e:
        print(f"Band parsing failed (possibly incomplete data from a timed-out bands step): {e}")
        bs_data = None
else:
    print("Bands XML not found \u2014 bands step may have timed out.")
    print("Continuing with DOS-only analysis.")

## 7. Spin-Resolved DOS

Mirror-convention DOS plot: spin-up on positive y-axis,
spin-down mirrored on negative y-axis.
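
The mirror convention amounts to negating the spin-down curve. The same two curves also give the net moment: integrating (up − down) over the occupied states yields the magnetization in μB. The sketch below uses plain lists and toy data. The function names are hypothetical, and the integration is a simple trapezoid rule over grid intervals that lie entirely at or below the Fermi energy.

```python
def mirror_dos(dos_up, dos_down):
    """Return (up, -down) for a mirror-convention plot."""
    return list(dos_up), [-d for d in dos_down]


def net_moment(energies, dos_up, dos_down, fermi):
    """Trapezoidal integral of (up - down) over intervals with E <= fermi."""
    total = 0.0
    for i in range(len(energies) - 1):
        e0, e1 = energies[i], energies[i + 1]
        if e1 > fermi:
            break  # energies are assumed to be sorted in ascending order
        diff0 = dos_up[i] - dos_down[i]
        diff1 = dos_up[i + 1] - dos_down[i + 1]
        total += 0.5 * (diff0 + diff1) * (e1 - e0)
    return total


# Toy data: constant DOS with an excess of 0.5 states/eV in the up channel
energies = [-2.0, -1.0, 0.0, 1.0]
up, down = [1.0] * 4, [0.5] * 4
print(mirror_dos(up, down))
print(net_moment(energies, up, down, fermi=0.0))
```

On real data the integrated value should be close to the total magnetization reported by pw.x, up to broadening and grid effects.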

In [None]:
from shalom.backends import parse_dos_file
from shalom.plotting import DOSPlotter

nscf_dir = result["calc_dirs"]["nscf"]
dos_file = os.path.join(nscf_dir, "pwscf.dos")

if os.path.exists(dos_file):
    dos_data = parse_dos_file(dos_file)
    dos_data.fermi_energy = fermi

    print(f"Spin polarized: {dos_data.is_spin_polarized}")
    print(f"Energy range: [{dos_data.energies.min():.1f}, {dos_data.energies.max():.1f}] eV")
    print(f"Data points: {len(dos_data.energies)}")

    dos_fig_path = os.path.join(OUTPUT_ROOT, "fe2o3_dos_spin.png")
    fig = DOSPlotter(dos_data).plot(
        output_path=dos_fig_path,
        title="\u03b1-Fe\u2082O\u2083 \u2014 Spin-Resolved DOS",
        energy_window=(-10, 6),
    )
    plt.show()
    print(f"Saved: {dos_fig_path}")
else:
    print(f"pwscf.dos not found at {dos_file}")

## 8. Combined Band Structure + DOS

Publication-quality side-by-side figure with spin-resolved bands and DOS.

In [None]:
from shalom.plotting import CombinedPlotter

if bs_data is not None and os.path.exists(dos_file):
    combined_path = os.path.join(OUTPUT_ROOT, "fe2o3_combined.png")
    fig = CombinedPlotter(bs_data, dos_data).plot(
        output_path=combined_path,
        title="\u03b1-Fe\u2082O\u2083 \u2014 Electronic Structure (GGA+U)",
        energy_window=(-8, 6),
    )
    plt.show()
    print(f"Saved: {combined_path}")
else:
    print("Combined plot skipped \u2014 band structure data not available.")
    print("See DOS plot above for electronic structure.")

## 9. Electronic Structure Analysis

Quantitative analysis of the spin-polarized band structure:
- Band gap from each spin channel (narrower gap is reported)
- VBM/CBM positions
- Metal vs insulator classification
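
To make the bullet points concrete, here is a minimal sketch of a per-channel gap search on toy eigenvalues. It is not the code behind `analyze_band_structure`; the function names and data layout (`bands[k][n]` in eV) are assumptions, and it simply follows the rule stated above of reporting the narrower of the two spin-channel gaps.

```python
def channel_gap(bands, fermi):
    """Gap of one spin channel, or None if the channel is metallic.

    bands[k][n] is the eigenvalue (eV) of band n at k-point k.
    """
    n_bands = len(bands[0])
    for n in range(n_bands):
        column = [eigs[n] for eigs in bands]
        if min(column) <= fermi < max(column):
            return None  # a band crosses the Fermi level

    occupied = [(e, k) for k, eigs in enumerate(bands) for e in eigs if e <= fermi]
    empty = [(e, k) for k, eigs in enumerate(bands) for e in eigs if e > fermi]
    if not occupied or not empty:
        return None

    vbm, vbm_k = max(occupied)  # highest occupied eigenvalue and its k index
    cbm, cbm_k = min(empty)     # lowest empty eigenvalue and its k index
    return {"gap": cbm - vbm, "vbm": vbm, "cbm": cbm, "direct": vbm_k == cbm_k}


def narrower_gap(bands_up, bands_down, fermi):
    """Report the smaller of the two spin-channel gaps (None if either is metallic)."""
    results = [channel_gap(b, fermi) for b in (bands_up, bands_down)]
    if any(r is None for r in results):
        return None
    return min(results, key=lambda r: r["gap"])


# Toy data: two k-points, two bands per spin channel
toy_up = [[-1.0, 2.0], [-0.5, 1.5]]
toy_down = [[-2.0, 1.0], [-1.5, 3.0]]
print(narrower_gap(toy_up, toy_down, fermi=0.0))
```

In this toy case the spin-up channel has the narrower gap, and it is direct because its VBM and CBM sit at the same k-point.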

In [None]:
from shalom.analysis import analyze_band_structure

if bs_data is not None:
    dos_for_analysis = dos_data if os.path.exists(dos_file) else None
    elec = analyze_band_structure(bs_data, dos_data=dos_for_analysis)

    print("=== Electronic Structure Analysis (Fe\u2082O\u2083) ===")
    if elec.bandgap_eV is not None:
        gap_type = "direct" if elec.is_direct else "indirect"
        print(f"Band gap:       {elec.bandgap_eV:.3f} eV ({gap_type})")
    else:
        print(f"Band gap:       metallic (no gap)")
    print(f"Is metal:       {elec.is_metal}")
    if elec.vbm_energy is not None:
        print(f"VBM energy:     {elec.vbm_energy:.4f} eV")
    if elec.cbm_energy is not None:
        print(f"CBM energy:     {elec.cbm_energy:.4f} eV")
    if elec.dos_at_fermi is not None:
        print(f"DOS at E_F:     {elec.dos_at_fermi:.4f} states/eV")
elif os.path.exists(dos_file):
    # DOS-only gap estimation (needs dos_data parsed in Section 7)
    print("=== Electronic Structure Analysis (DOS-only) ===")
    print("Band structure unavailable \u2014 estimating gap from DOS")
    total = np.array(dos_data.dos)
    energies = np.array(dos_data.energies)
    window = 5.0
    mask = (energies > fermi - window) & (energies < fermi + window)
    e_w = energies[mask]
    d_w = total[mask]
    thr = max(d_w) * 0.01

    below = e_w[e_w <= fermi]
    d_below = d_w[e_w <= fermi]
    above = e_w[e_w > fermi]
    d_above = d_w[e_w > fermi]

    vbm_idx = np.where(d_below > thr)[0]
    vbm = below[vbm_idx[-1]] if len(vbm_idx) > 0 else fermi
    cbm_idx = np.where(d_above > thr)[0]
    cbm = above[cbm_idx[0]] if len(cbm_idx) > 0 else fermi
    dos_gap = cbm - vbm

    print(f"DOS-estimated band gap: {dos_gap:.2f} eV")
    print(f"  VBM (approx): {vbm:.3f} eV")
    print(f"  CBM (approx): {cbm:.3f} eV")

    if dos_data.is_spin_polarized:
        for label, d_arr in [("Spin-up", dos_data.dos_up), ("Spin-down", dos_data.dos_down)]:
            d_s = np.array(d_arr)[mask]
            thr_s = max(d_s) * 0.01
            d_b = d_s[e_w <= fermi]
            d_a = d_s[e_w > fermi]
            vi = np.where(d_b > thr_s)[0]
            ci = np.where(d_a > thr_s)[0]
            v = below[vi[-1]] if len(vi) > 0 else fermi
            c = above[ci[0]] if len(ci) > 0 else fermi
            print(f"  {label} gap: {c - v:.2f} eV")

print(f"\nNote: GGA+U typically gives ~2 eV gap for Fe\u2082O\u2083 (exp: ~2.0\u20132.2 eV)")

## 10. Magnetic Analysis

SHALOM parses the QE `pw.out` file for:
- **Total magnetization** per cell (in Bohr magnetons, μB)
- **Site-resolved magnetization** for each atom
- **Lowdin charges** by s/p/d/f channels
- **Magnetic element identification**

For hematite, we expect Fe sites to carry moments of ~4 μB each,
with O sites having negligible moments.
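
The difference between the total moment and the per-site moments matters here because the run starts from a ferromagnetic state while real hematite is antiferromagnetic. The toy numbers below are idealized values of ±4 μB on Fe and 0 on O, not computed results. They show how the same site moments give a large total in one ordering and zero in the other. `summarize_moments` is a hypothetical helper.

```python
def summarize_moments(symbols, moments, element="Fe"):
    """Return (total moment, mean absolute moment on the chosen element)."""
    selected = [m for s, m in zip(symbols, moments) if s == element]
    if not selected:
        raise ValueError(f"no {element} sites found")
    total = sum(moments)
    mean_abs = sum(abs(m) for m in selected) / len(selected)
    return total, mean_abs


symbols = ["Fe"] * 4 + ["O"] * 6          # 10-atom primitive cell
ferro = [4.0, 4.0, 4.0, 4.0] + [0.0] * 6  # idealized ferromagnetic ordering
antiferro = [4.0, -4.0, -4.0, 4.0] + [0.0] * 6  # idealized antiferromagnetic ordering

print("FM :", summarize_moments(symbols, ferro))
print("AFM:", summarize_moments(symbols, antiferro))
```

This is why dividing the total magnetization by the number of Fe atoms only estimates the local moment for a ferromagnetic solution.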

In [None]:
from shalom.analysis import analyze_magnetism
from shalom.backends import QEBackend

# Parse SCF pw.out for magnetization data
scf_dir = result["calc_dirs"]["scf"]
scf_pw_out = os.path.join(scf_dir, "pw.out")

backend = QEBackend()
dft_result = backend.parse_output(scf_dir)

# Use primitive cell atoms (what the workflow actually computed on)
prim_atoms = result["atoms"]
prim_symbols = prim_atoms.get_chemical_symbols()

mag = analyze_magnetism(
    dft_result,
    prim_atoms,
    pw_out_path=scf_pw_out,
)

print("=== Magnetic Analysis (Fe\u2082O\u2083) ===")
if mag.total_magnetization is not None:
    print(f"Total magnetization: {mag.total_magnetization:.4f} Bohr mag/cell")
print(f"Is magnetic:         {mag.is_magnetic}")
print(f"Is spin polarized:   {mag.is_spin_polarized}")
print(f"Magnetic elements:   {mag.magnetic_elements if mag.magnetic_elements else 'N/A (need verbosity=high)'}")
print(f"Dominant element:    {mag.dominant_moment_element or 'N/A'}")

# Per-site moments (only available with verbosity='high')
if mag.site_magnetizations:
    print(f"\nPer-site magnetization (primitive cell, {len(prim_symbols)} atoms):")
    print(f"{'Atom':>6}  {'Element':>8}  {'Moment (μB)':>12}")
    print("-" * 35)
    for i, (elem, moment) in enumerate(zip(prim_symbols, mag.site_magnetizations)):
        print(f"{i+1:6d}  {elem:>8}  {moment:12.4f}")

    elem_moments = defaultdict(list)
    for elem, moment in zip(prim_symbols, mag.site_magnetizations):
        elem_moments[elem].append(moment)
    print(f"\nAverage moments per element:")
    for elem, moments in sorted(elem_moments.items()):
        avg = np.mean(np.abs(moments))
        print(f"  {elem}: mean |m| = {avg:.4f} Bohr mag ({len(moments)} sites)")
else:
    print(f"\nSite-resolved magnetizations not available.")
    print(f"(Requires verbosity='high' in QE &CONTROL namelist)")
    n_fe = sum(1 for s in prim_symbols if s == 'Fe')
    if mag.total_magnetization and n_fe > 0:
        avg_fe = mag.total_magnetization / n_fe
        print(f"\nEstimated average Fe moment: {avg_fe:.1f} Bohr mag")
        print(f"  ({n_fe} Fe atoms in primitive cell, total = {mag.total_magnetization:.1f} \u03bcB)")
        print(f"  Note: With ferromagnetic initialization, this is an upper bound.")
        print(f"  Real Fe\u2082O\u2083 is antiferromagnetic (net M \u2248 0).")

## 11. Summary

In [None]:
print("=" * 60)
print("Fe\u2082O\u2083 Magnetic Oxide Study \u2014 Summary")
print("=" * 60)

if 'sym' in dir():
    print(f"\nStructure:")
    print(f"  Space group:     {sym.space_group_symbol} (#{sym.space_group_number})")
    print(f"  Crystal system:  {sym.crystal_system}")
    print(f"  Conventional:    {len(fe2o3)} atoms")
    print(f"  Primitive:       {len(result['atoms'])} atoms")

print(f"\nDFT Settings (auto-detected by SHALOM):")
print(f"  Accuracy:        PRECISE (GGA+U)")
print(f"  ecutwfc:         {config.system.get('ecutwfc')} Ry")
print(f"  nspin:           2 (spin-polarized)")
print(f"  Hubbard U (Fe):  {HUBBARD_U_VALUES.get('Fe', 'N/A')}")

print(f"\nElectronic:")
if fermi:
    print(f"  Fermi energy:    {fermi:.4f} eV")
if 'elec' in dir() and hasattr(elec, 'bandgap_eV') and elec.bandgap_eV:
    gap_type = "direct" if elec.is_direct else "indirect"
    print(f"  Band gap:        {elec.bandgap_eV:.3f} eV ({gap_type})")
elif 'dos_gap' in dir():
    print(f"  DOS-est. gap:    {dos_gap:.2f} eV")
print(f"  Bands available: {'Yes' if bs_data is not None else 'No (timed out)'}")

if 'mag' in dir():
    print(f"\nMagnetic:")
    if mag.total_magnetization is not None:
        print(f"  Total moment:    {mag.total_magnetization:.1f} Bohr mag/cell")
    print(f"  Is magnetic:     {mag.is_magnetic}")
    n_fe = sum(1 for s in prim_symbols if s == 'Fe')
    if mag.total_magnetization and n_fe > 0:
        print(f"  Avg Fe moment:   ~{mag.total_magnetization / n_fe:.1f} Bohr mag (ferro init)")

print(f"\nAll outputs saved to: {OUTPUT_ROOT}")