In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import shutil



In [2]:
def extract_hoppings_from_hr(hr_path, dim=2, use_degeneracy=True):
    """
    Extract hopping terms from a Wannier90 ``*_hr.dat`` file.

    Layout parsed here: line 0 is a free-text comment, line 1 holds num_wann,
    line 2 holds nrpts, then nrpts integer Wigner-Seitz degeneracies (spread
    over as many lines as needed), then the matrix-element lines
    ``Rx Ry Rz i j Re Im``, grouped by R with num_wann**2 lines per R point
    in the same order as the degeneracy list.

    Wannier90 writes H_ij(R) *without* dividing by the degeneracy ndegen(R);
    the Bloch Hamiltonian is H_ij(k) = sum_R exp(i k.R) H_ij(R) / ndegen(R).
    With ``use_degeneracy=True`` (default) every returned amplitude is already
    divided by ndegen(R), so it can be used directly as a tight-binding term.

    Parameters
    ----------
    hr_path : str or Path
        Path to wannier90_hr.dat
    dim : int
        Dimensionality of the model (2 or 3).
        For 2D, Rz is dropped from the returned R vectors.
        NOTE(review): entries with Rz != 0 are therefore folded onto the same
        in-plane R, which can yield duplicate (R, i, j) keys; an on-site-like
        entry (0, 0, Rz != 0, i == i) ends up in ``hoppings`` with R = [0, 0].
    use_degeneracy : bool
        Divide each H_ij(R) by ndegen(R). Set False to get the raw file values.

    Returns
    -------
    num_wann : int
        Number of Wannier orbitals
    onsite : np.ndarray (complex)
        On-site terms H_ii(R=0)
    hoppings : list of tuples
        Each element is (R, i, j, t),
        where:
            R : list[int]   (length = dim)
            i,j : int       (0-based)
            t : complex

    Raises
    ------
    ValueError
        If ``dim`` is not 2 or 3, or (with ``use_degeneracy``) the file has
        more than num_wann**2 * nrpts matrix-element lines.
    """
    if dim not in (2, 3):
        raise ValueError(f"dim must be 2 or 3, got {dim!r}")

    hr_path = Path(hr_path)
    lines = hr_path.read_text().splitlines()

    # --- header (line 0 is a comment) ---
    num_wann = int(lines[1].strip())
    nrpts = int(lines[2].strip())

    # --- degeneracy block: nrpts integers, possibly over several lines ---
    idx = 3
    degeneracies = []
    while len(degeneracies) < nrpts:
        for p in lines[idx].split():
            if len(degeneracies) < nrpts:
                degeneracies.append(int(p))
        idx += 1

    onsite = np.zeros(num_wann, dtype=complex)
    hoppings = []

    # Matrix elements come in consecutive blocks of num_wann**2 lines per R,
    # so the running count of data lines tells which degeneracy applies.
    block_size = max(num_wann * num_wann, 1)
    n_elem = 0

    # --- hopping data ---
    for line in lines[idx:]:
        parts = line.split()
        if len(parts) != 7:
            continue

        irpt = n_elem // block_size
        n_elem += 1

        Rx, Ry, Rz = map(int, parts[:3])
        i, j = map(int, parts[3:5])
        re, im = map(float, parts[5:7])

        t = re + 1j * im
        if use_degeneracy:
            if irpt >= nrpts:
                raise ValueError(
                    f"{hr_path}: more than num_wann**2 * nrpts = {block_size * nrpts} "
                    "matrix-element lines; cannot assign degeneracies"
                )
            # Wannier90 stores H(R) undivided: H(k) = sum_R e^{ik.R} H(R) / ndegen(R)
            t = t / degeneracies[irpt]

        i0, j0 = i - 1, j - 1  # Wannier90 is 1-based

        # on-site term
        if Rx == 0 and Ry == 0 and Rz == 0 and i0 == j0:
            onsite[i0] = t
        else:
            R = [Rx, Ry] if dim == 2 else [Rx, Ry, Rz]
            hoppings.append((R, i0, j0, t))

    return num_wann, onsite, hoppings


In [3]:
# ===== Lattice (2D square, lengths in Å) =====
a0 = 4.1630263607523368

# Square lattice: two orthogonal primitive vectors of length a0.
lat_vecs = [[a0, 0.0], [0.0, a0]]

# ===== Orbital centers (fractional coordinates) =====
# 2 magnetic atoms x dz2 x (up, dn), spin-major ordering:
#   0: atom A ↑, 1: atom B ↑, 2: atom A ↓, 3: atom B ↓
orb_vecs = [
    [x, y]
    for _spin in ("up", "dn")
    for x, y in ((0.5, 0.0), (0.0, 0.5))  # atom A, atom B
]


In [4]:
# Quick preview of the hr files in the current directory.
# The glob matches the batch cell's pattern ("wannier90_*.dat") so the count shown
# here is the set that will actually be processed. Note that the batch cell uses
# its own hard-coded `workdir`, which overrides this one.
workdir = Path(".").resolve()

dat_files = sorted(workdir.glob("wannier90_*.dat"))

if not dat_files:
    raise RuntimeError("No wannier90_*.dat files found in current directory.")

n_preview = 20  # listing thousands of names floods the notebook output
print(f"Found {len(dat_files)} hr.dat file(s):")
for f in dat_files[:n_preview]:
    print("  -", f.name)
if len(dat_files) > n_preview:
    print(f"  ... and {len(dat_files) - n_preview} more")


Found 2329 hr.dat file(s):
  - wannier90_iE15_iE22_iT12_iT22_iT35_iT42.dat
  - wannier90_iE15_iE22_iT12_iT22_iT35_iT43.dat
  - wannier90_iE15_iE22_iT12_iT22_iT35_iT44.dat
  - wannier90_iE15_iE22_iT12_iT22_iT35_iT45.dat
  - wannier90_iE15_iE22_iT12_iT23_iT31_iT41.dat
  - wannier90_iE15_iE22_iT12_iT23_iT31_iT42.dat
  - wannier90_iE15_iE22_iT12_iT23_iT31_iT43.dat
  - wannier90_iE15_iE22_iT12_iT23_iT31_iT44.dat
  - wannier90_iE15_iE22_iT12_iT23_iT31_iT45.dat
  - wannier90_iE15_iE22_iT12_iT23_iT32_iT41.dat
  - wannier90_iE15_iE22_iT12_iT23_iT32_iT42.dat
  - wannier90_iE15_iE22_iT12_iT23_iT32_iT43.dat
  - wannier90_iE15_iE22_iT12_iT23_iT32_iT44.dat
  - wannier90_iE15_iE22_iT12_iT23_iT32_iT45.dat
  - wannier90_iE15_iE22_iT12_iT23_iT33_iT41.dat
  - wannier90_iE15_iE22_iT12_iT23_iT33_iT42.dat
  - wannier90_iE15_iE22_iT12_iT23_iT33_iT43.dat
  - wannier90_iE15_iE22_iT12_iT23_iT33_iT44.dat
  - wannier90_iE15_iE22_iT12_iT23_iT33_iT45.dat
  - wannier90_iE15_iE22_iT12_iT23_iT34_iT41.dat
  - wannier90

In [5]:
import numpy as np
from pathlib import Path

def save_band_dat(
    out_file,
    k_dist,
    evals,
    label=None,
    k_node=None,
    energy_shift=0.0,
    meta_dict=None,
    fmt="%.10f",
):
    """
    Write a numerical band structure (Band.dat) as a whitespace-separated table.

    One row per k-point: column 0 is ``k_dist``, columns 1..Nb are E1..ENb.
    A ``# ``-prefixed header records Nk, Nb, the energy shift, the column
    names, optional tick labels/positions and optional metadata.

    Parameters
    ----------
    out_file : str or Path
        Destination file; missing parent directories are created.
    k_dist : array_like
        Path coordinate of each k-point (flattened to 1D, length Nk).
    evals : array_like, shape (Nb, Nk)
        Band energies, one row per band.
    label, k_node : sequence, optional
        Tick labels and their path positions; written only when BOTH are given.
    energy_shift : float
        Recorded in the header only — the energies are written as given.
    meta_dict : dict, optional
        Extra ``key: value`` header lines.
    fmt : str
        ``numpy.savetxt`` format applied to every column.

    Raises
    ------
    ValueError
        If ``evals`` is not 2D, or its second axis does not match len(k_dist).
    """
    target = Path(out_file)
    target.parent.mkdir(parents=True, exist_ok=True)

    kd = np.asarray(k_dist, float).reshape(-1)
    energies = np.asarray(evals, float)

    if energies.ndim != 2:
        raise ValueError(f"evals 必须是 2D 数组 (Nb, Nk)，但当前形状是 {energies.shape}")
    n_band, n_k = energies.shape
    if n_k != kd.size:
        raise ValueError(f"Nk 不一致：k_dist={kd.size}, evals.shape={energies.shape}")

    header = [
        f"Nk={n_k} Nb={n_band} energy_shift={energy_shift:+.10f} (E -> E - energy_shift)",
        "columns: k_dist " + " ".join(f"E{n}" for n in range(1, n_band + 1)),
    ]
    if label is not None and k_node is not None:
        header.append("label: " + " ".join(map(str, label)))
        node_pos = np.asarray(k_node, float).ravel()
        header.append("k_node: " + " ".join(f"{x:.10f}" for x in node_pos))
    if meta_dict is not None:
        header.extend(f"{key}: {val}" for key, val in meta_dict.items())

    # Nk rows x (1 + Nb) columns: k_dist followed by each band.
    table = np.column_stack([kd, energies.T])
    np.savetxt(target, table, fmt=fmt, header="\n".join(header), comments="# ")


In [None]:
from pythtb import tb_model
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import shutil  # 你后面用到了

# ==================================================
# Working directory and input files
# ==================================================
workdir = Path(r"E:\马睿骁\组会汇报\Nb2OSSe\pythTB\workflow")
dat_files = sorted(workdir.glob("wannier90_*.dat"))

# An empty list would make the batch loop below a silent no-op. Typical causes:
# the hard-coded path does not exist on this machine, or a previous run already
# moved every hr file into its result folder (step 7 of the loop).
if not dat_files:
    print(f"[WARN] No wannier90_*.dat files found in {workdir} (exists: {workdir.is_dir()})")

# ==================================================
# Phase criterion (insulator vs. metal) on the sampled k-path
# ==================================================

def classify_phase(evals, gap_min=1e-3, n_occ=None):
    """
    Classify a sampled band structure as "insulator" or "metal".

    For a 2D array of shape (nband, nk), the gap above band b is measured
    between whole bands:

        gap_b = min_k E_{b+1}(k) - max_k E_b(k)

    This is positive only when no eigenvalue at any sampled k lies inside the
    window. The flat criterion — the largest spacing between all sorted
    eigenvalues — also reports the k-sampling spacing inside a steep band, so
    a dispersive metal could be labelled an insulator. Input of any other
    rank still uses that flat criterion, for backward compatibility.

    Parameters
    ----------
    evals : array_like
        Eigenvalues, preferably shape (nband, nk).
    gap_min : float
        Gaps smaller than this are treated as closed.
    n_occ : int, optional
        2D input only: number of occupied bands. If given, only the gap above
        band ``n_occ - 1`` is tested; otherwise the largest inter-band gap is used.

    Returns
    -------
    dict
        Keys "has_gap" (bool), "gap" (float, >= 0 for the 2D criterion),
        "midgap" (float or None), "phase" ("insulator" / "metal").

    Raises
    ------
    ValueError
        If ``n_occ`` is given for 2D input and is outside 1 .. nband-1.
    """
    E = np.asarray(evals, float)
    metal_no_gap = {"has_gap": False, "gap": 0.0, "midgap": None, "phase": "metal"}

    if E.ndim != 2:
        # Legacy flat criterion: largest spacing between distinct eigenvalues.
        Es = np.unique(E.ravel())  # np.unique returns sorted values
        if Es.size < 10:
            return metal_no_gap
        gaps = np.diff(Es)
        igap = int(np.argmax(gaps))
        gap = float(gaps[igap])
        if gap < gap_min:
            return {"has_gap": False, "gap": gap, "midgap": None, "phase": "metal"}
        return {"has_gap": True, "gap": gap, "midgap": 0.5 * (Es[igap] + Es[igap + 1]), "phase": "insulator"}

    nband, nk = E.shape
    if n_occ is not None and not (1 <= n_occ < nband):
        raise ValueError(f"n_occ must be in 1..{nband - 1}, got {n_occ}")
    if nband < 2 or nk == 0:
        return metal_no_gap

    E = np.sort(E, axis=0)      # make the per-k ascending band order explicit
    vbm = E.max(axis=1)[:-1]    # top of band b
    cbm = E.min(axis=1)[1:]     # bottom of band b + 1
    gaps = cbm - vbm            # negative when the two bands overlap in energy

    ib = int(np.argmax(gaps)) if n_occ is None else n_occ - 1
    gap = float(gaps[ib])

    if gap < gap_min:
        return {"has_gap": False, "gap": max(gap, 0.0), "midgap": None, "phase": "metal"}

    return {"has_gap": True, "gap": gap, "midgap": float(0.5 * (vbm[ib] + cbm[ib])), "phase": "insulator"}

# ==================================================
# "Canonical R" test: picks one member of each Hermitian hopping pair
# ==================================================
def canonical_R(R):
    """
    Return True if R is the lexicographically larger element of the pair (R, -R).

    The first non-zero component decides: R is canonical iff that component is
    positive. The zero vector equals its own negative and is never canonical.
    """
    for comp in R:
        if comp != 0:
            return comp > 0
    return False

# ==================================================
# Mid-gap energy, used only for plot alignment
# ==================================================
def find_midgap(E_all, search_window=2.0, gap_min=1e-3):
    """
    Return the mid-gap energy of the largest gap inside [-search_window, search_window].

    For a 2D array of shape (nband, nk), only gaps between whole bands are
    considered: gap_b = min_k E_{b+1}(k) - max_k E_b(k). Both edges of the gap
    must lie inside the window. Any other rank uses the original flat
    criterion: the largest spacing between distinct eigenvalues in the window.
    The flat criterion can mistake the k-sampling spacing inside a steep band
    for a gap.

    Parameters
    ----------
    E_all : array_like
        Eigenvalues; preferably shape (nband, nk).
    search_window : float
        Half-width of the energy window around 0 that is searched.
    gap_min : float
        Smaller gaps are ignored.

    Returns
    -------
    float or None
        Mid-gap energy, or None if fewer than 10 eigenvalues fall in the window
        or no usable gap is found.
    """
    E = np.asarray(E_all, dtype=float)
    mask = (E >= -search_window) & (E <= search_window)
    if np.count_nonzero(mask) < 10:
        return None

    if E.ndim == 2:
        if E.shape[0] < 2:
            return None  # a single band cannot host an inter-band gap
        E = np.sort(E, axis=0)      # make the per-k ascending band order explicit
        vbm = E.max(axis=1)[:-1]    # top of band b
        cbm = E.min(axis=1)[1:]     # bottom of band b + 1
        gaps = cbm - vbm
        usable = (gaps >= gap_min) & (vbm >= -search_window) & (cbm <= search_window)
        if not usable.any():
            return None
        ib = int(np.argmax(np.where(usable, gaps, -np.inf)))
        return float(0.5 * (vbm[ib] + cbm[ib]))

    # Legacy flat criterion.
    Es = np.unique(E[mask])  # np.unique returns sorted values
    if Es.size < 10:
        return None

    gaps = np.diff(Es)
    igap = int(np.argmax(gaps))
    if gaps[igap] < gap_min:
        return None

    return 0.5 * (Es[igap] + Es[igap + 1])

# ==================================================
# High-symmetry path: Γ–X–M–Y–Γ
# ==================================================
# Path corners in fractional reciprocal coordinates, visited in this order.
path = [[kx, ky] for kx, ky in ((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.0, 0.5), (0.0, 0.0))]
# Tick labels, one per corner above.
label = (r"$\Gamma$", r"$X$", r"$M$", r"$Y$", r"$\Gamma$")
# Number of k-points handed to model.k_path.
nk = 301

# ==================================================
# Main loop: no filtering — compute and export the bulk bands of every input file.
# NOTE: step 7 moves each processed hr file into its result folder, so re-running
# this cell finds no inputs in `workdir` unless the files are moved back first.
# ==================================================
for dat in dat_files:
    name = dat.stem
    calc_dir = workdir / name

    print(f"\nProcessing {dat.name}")

    # --------------------------------------------------
    # 1) Read the Wannier90 hr file
    # --------------------------------------------------
    num_wann, onsite, hoppings = extract_hoppings_from_hr(dat, dim=2)

    if num_wann != len(orb_vecs):
        print(f"  [SKIP] Orbital mismatch: {num_wann} vs {len(orb_vecs)}")
        continue

    # Create the result folder only for files that are actually processed,
    # so skipped inputs do not leave empty directories behind.
    calc_dir.mkdir(exist_ok=True)

    # --------------------------------------------------
    # 2) Build the TB model
    # --------------------------------------------------
    model = tb_model(dim_k=2, dim_r=2, lat=lat_vecs, orb=orb_vecs)

    # Only the real part of the on-site terms is used.
    model.set_onsite(onsite.real.tolist())

    nhop_used = 0
    for R, i, j, t in hoppings:
        R = tuple(R)
        # Keep one member of each Hermitian pair (i, j, R) / (j, i, -R): all R for
        # i < j, and only the canonical R for i == j. Presumably the model adds the
        # conjugate partner itself — confirm against the pythtb set_hop docs.
        if (i > j) or (i == j and not canonical_R(R)):
            continue
        model.set_hop(t, i, j, list(R))
        nhop_used += 1

    print(f"  [INFO] Used hoppings: {nhop_used}")

    # --------------------------------------------------
    # 3) Bulk bands along the path (raw eigenvalues)
    # --------------------------------------------------
    k_vec, k_dist, k_node = model.k_path(path, nk, report=False)
    evals = model.solve_ham(k_pts=k_vec).T  # (nband, nk)

    # --------------------------------------------------
    # 3.5) Mid-gap alignment, for visualization only.
    # The 2D (nband, nk) array is passed instead of a flattened copy so that
    # band-index information is available to find_midgap.
    # --------------------------------------------------
    midgap = find_midgap(evals, search_window=2.0)

    if midgap is not None:
        evals_plot = evals - midgap
        title_suffix = " (mid-gap aligned)"
        print(f"  [INFO] mid-gap = {midgap:+.6f} eV → shifted to 0")
    else:
        evals_plot = evals
        title_suffix = " (no clear gap)"
        print("  [INFO] No clear gap found (possible metal or phase boundary)")

    # --------------------------------------------------
    # 4) + 5) Plot and save the figure
    # --------------------------------------------------
    fig, ax = plt.subplots(figsize=(6, 4.5))
    try:
        ax.set_xlim(k_node[0], k_node[-1])
        ax.set_xticks(k_node)
        ax.set_xticklabels(label)

        for x in k_node:
            ax.axvline(x=x, lw=0.5, color="k")

        for band in evals_plot:
            ax.plot(k_dist, band, color="black", lw=1)

        ax.axhline(0.0, color="red", ls="--", lw=0.8)
        ax.set_xlabel("Path in k-space")
        ax.set_ylabel(r"$E$ (eV)")
        ax.set_title(f"{name}{title_suffix}")

        fig.tight_layout()

        fig.savefig(calc_dir / "bandstructure.png", dpi=300)
        fig.savefig(calc_dir / "bandstructure.pdf")
    finally:
        # Always release the figure, even if saving fails; the loop may create thousands.
        plt.close(fig)
    print(f"  [OK] Saved figure to {calc_dir.name}/")

    # --------------------------------------------------
    # 5.5) Export Band.dat (must stay inside the for loop)
    # Band_raw.dat holds the unshifted eigenvalues; Band.dat holds the plotted
    # (possibly mid-gap shifted) ones, with the shift recorded in its header.
    # --------------------------------------------------
    save_band_dat(
        out_file=calc_dir / "Band_raw.dat",
        k_dist=k_dist,
        evals=evals,
        label=label,
        k_node=k_node,
        energy_shift=0.0,
        meta_dict={"name": name},
    )

    shift = float(midgap) if (midgap is not None) else 0.0
    save_band_dat(
        out_file=calc_dir / "Band.dat",
        k_dist=k_dist,
        evals=evals_plot,
        label=label,
        k_node=k_node,
        energy_shift=shift,
        meta_dict={"name": name, "midgap_eV": (midgap if midgap is not None else "NaN")},
    )
    print(f"  [OK] Saved Band.dat to {calc_dir.name}/")

    # --------------------------------------------------
    # 6) Record phase information (from the unshifted eigenvalues)
    # --------------------------------------------------
    phase_info = classify_phase(evals)

    info_file = calc_dir / "phase_info.txt"
    with open(info_file, "w", encoding="utf-8") as f:
        f.write(f"name        = {name}\n")
        f.write(f"nhop_used  = {nhop_used}\n")
        f.write(f"nband      = {evals.shape[0]}\n")
        f.write(f"nk         = {evals.shape[1]}\n")
        f.write(f"has_gap    = {phase_info['has_gap']}\n")
        f.write(f"gap_eV     = {phase_info['gap']:.8f}\n")
        f.write(f"midgap_eV  = {phase_info['midgap'] if phase_info['midgap'] is not None else 'NaN'}\n")
        f.write(f"phase      = {phase_info['phase']}\n")

    # --------------------------------------------------
    # 7) Move the hr file into its result folder
    # --------------------------------------------------
    hr_target = calc_dir / dat.name
    if not hr_target.exists():
        shutil.move(str(dat), hr_target)
        print(f"  [INFO] Moved {dat.name} → {calc_dir.name}/")
    else:
        print(f"  [WARN] {dat.name} already exists in {calc_dir.name}/, skip move")

    print(f"  [OK] Finished {calc_dir.name}/")
