# Bandstructure of AB Bilayer Graphene

<h1>Table of Contents<span class="tocSkip"></span></h1>

<div class="toc">
<ul class="toc-item">

<li><a href="#1-initialization">1. Initialization</a></li>

<li><a href="#2-constants-and-vectors-regarding-graphene-lattice-structure">2. Constants and Important Quantities</a>
  <ul>
    <li><a href="#21-visualize-brillouin-zone">a) Visualize Brillouin Zone</a></li>
  </ul>
</li>

<li><a href="#3-tight-binding-continuum-hamiltonian">3. Tight-Binding Continuum Hamiltonian</a></li>

<li><a href="#4-visualizers">4. Visualizers</a>
  <ul>
    <li><a href="#a-parameters">a) Parameters</a></li>
    <li><a href="#b-3d-visualizer-of-k-space-vs-energy-near-k-point">b) 3D Visualization of Energy Bands near K</a></li>
    <li><a href="#c-contour-of-bands-at-fermi-level">c) Contour of Bands at the Fermi Level</a></li>
    <li><a href="#d-2d-line-cut">d) 2D Line Cut Through the Dirac Point (K → Γ)</a></li>
    <li><a href="#e-density-of-states-approximation">e) Density of States Approximation</a></li>
  </ul>
</li>

</ul>
</div>


___
## 1. Initialization

In [1]:
# ======== Imports ==========
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed for 3D)
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from scipy.ndimage import gaussian_filter1d


___
## 2. Constants and Vectors Regarding Graphene Lattice Structure
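Nearest-neighbor vectors, lattice vectors, reciprocal-lattice vectors, and the high-symmetry points (K, K′, M, Γ) of the first Brillouin zone. Lengths are in Å, momenta in Å⁻¹.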

In [2]:

# Lattice constants (angstrom)
a = 2.46      # lattice constant |a1| = |a2|, in angstroms
a_cc = 1.42   # C-C bond length in angstroms (≈ a/√3)
# Convenience
pi = np.pi
sqrt3 = np.sqrt(3.0)
I = 1j  # imaginary unit

def vec(x, y):
    """Return a 2D numpy vector (float64) with components (x, y)."""
    return np.array([float(x), float(y)], dtype=float)

# --- Nearest-neighbor vectors (angstrom) ---
# The three bond vectors have the C-C bond length a/√3 (= a_cc to the quoted
# precision), NOT the lattice constant a. With this length the differences of
# bond vectors are lattice vectors: d1 - d2 = a1 and d1 - d3 = a2 (checked below).
d_nn = a / sqrt3
d1 = d_nn * vec(1, 0)
d2 = d_nn * vec(-1/2, sqrt3/2)
d3 = d_nn * vec(-1/2, -sqrt3/2)

# --- Lattice vectors (angstrom) ---
a1 = a * vec( sqrt3/2, -0.5 )
a2 = a * vec( sqrt3/2,  0.5 )

# --- Reciprocal base vectors (1/angstrom) ---
# b_i · a_j = 2π δ_ij for this convention
b1 = (2*np.pi/a) * vec(1/np.sqrt(3), -1.0)
b2 = (2*np.pi/a) * vec(1/np.sqrt(3),  1.0)

# --- High-symmetry K points (1/angstrom) ---
# Corners listed counter-clockwise; K_n are equivalent to K0 and K_n' to K0'
# (they differ by reciprocal lattice vectors). |K| = 4π/(3a).
K0  = (2*b1 +   b2) / 3
K0p = (  b1 + 2*b2) / 3
K1  = (- b1 +   b2) / 3
K1p = -K0
K2  = -K0p
K2p = -K1

# --- M points (1/angstrom): midpoints of the Brillouin-zone edges ---
M0 =  0.5 * b1
M1 =  0.5 * b2
M2 =  0.5 * (b1 + b2)
M3 = -M0
M4 = -M1
M5 = -M2

# K and M points keyed by label
K_points = {
    "K0": K0, "K0'": K0p, "K1": K1, "K1'": K1p, "K2": K2, "K2'": K2p,
}
M_points = {
    "M0": M0, "M1": M1, "M2": M2,
    "M3": M3, "M4": M4, "M5": M5,
}
# Γ (g) point for convenience
g = vec(0.0, 0.0)

# --- Cheap sanity checks on the geometry ---
np.testing.assert_allclose(
    [[b1 @ a1, b1 @ a2], [b2 @ a1, b2 @ a2]], 2 * pi * np.eye(2), atol=1e-12
)
np.testing.assert_allclose(d1 - d2, a1, atol=1e-12)
np.testing.assert_allclose(d1 - d3, a2, atol=1e-12)




## 2.1 Visualize Brillouin Zone

In [3]:
hexagon = np.array([K0, K0p, K1, K1p, K2, K2p, K0])  # hexagon corners, first repeated to close the loop

# Build every trace up front, then hand them to the figure in drawing order.
bz_outline = go.Scatter(
    x=hexagon[:, 0], y=hexagon[:, 1],
    mode="lines", line=dict(width=2),
    name="1st Brillouin Zone", hoverinfo="skip",
)
k_markers = go.Scatter(
    x=[kpt[0] for kpt in K_points.values()],
    y=[kpt[1] for kpt in K_points.values()],
    mode="markers+text", name="K points",
    marker=dict(size=10, symbol="diamond"),
    text=list(K_points.keys()), textposition="top center",
)
m_marker = go.Scatter(  # only M0 is drawn
    x=[M0[0]], y=[M0[1]],
    mode="markers+text", name="M points",
    marker=dict(size=9, symbol="square"),
    text=["M0"], textposition="bottom center",
)
gamma_marker = go.Scatter(  # zone centre
    x=[g[0]], y=[g[1]],
    mode="markers+text", name="Γ",
    marker=dict(size=11, symbol="circle"),
    text=["Γ"], textposition="top left",
)
fig = go.Figure(data=[bz_outline, k_markers, m_marker, gamma_marker])

# Single layout call. The x range is asymmetric ([-2.5, 2]) — presumably to
# leave room for the legend in the upper-left corner; scaleanchor keeps a
# 1:1 aspect ratio so the hexagon is not distorted.
fig.update_layout(
    width=500,
    height=500,
    template="plotly_dark",
    font=dict(family="DejaVu Sans", size=14, color="white"),
    title="Graphene Brillouin Zone",
    xaxis=dict(
        title=dict(text="kₓ (Å⁻¹)"),
        range=[-2.5, 2], autorange=False,
        showgrid=True, gridwidth=1, zeroline=True,
    ),
    yaxis=dict(
        title=dict(text="kᵧ (Å⁻¹)"),
        range=[-2, 2], autorange=False,
        showgrid=True, gridwidth=1, zeroline=True,
        scaleanchor="x", scaleratio=1,
    ),
    legend=dict(x=0.02, y=0.98, bgcolor="rgba(0,0,0,0.3)", borderwidth=0),
    margin=dict(l=60, r=40, t=60, b=60),
)

fig.show()

## 3. Tight Binding Continuum Hamiltonian

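Builds the 4×4 continuum Hamiltonian of AB-stacked bilayer graphene near the K point and returns its eigenenergies: 4 bands in total (2 valence bands and 2 conduction bands), sorted in ascending energy.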

In [4]:
# ============ Hopping parameters =================
# Cited reference (ABC-stacked *trilayer*):
#   F. Zhang, B. Sahu, H. Min, A. H. MacDonald,
#   "Band structure of ABC-stacked graphene trilayers",
#   Phys. Rev. B 82, 035409 (2010), doi:10.1103/PhysRevB.82.035409
#   https://link.aps.org/doi/10.1103/PhysRevB.82.035409
# The commented-out dict below is that paper's trilayer parameter set; it is NOT used.
# NOTE(review): the active bilayer values further down (g0=3.16, g1=0.381,
# g3=0.38, g4=0.14, Δ'=0.022 eV) appear to match the set tabulated in
# E. McCann & M. Koshino, Rep. Prog. Phys. 76, 056503 (2013) — confirm and cite.

# ---- unused ABC-trilayer set (kept for reference) -------

# t = {
#     "g0": 3.10, # NN intralayer hopping
#     "g1": 0.38, # NN interlayer hopping
#     "g2": -0.015, # hopping btw low E sites A_i and B_i+2
#     "g3": -0.377, # hopping btw low E sites of AB graphene: A_i and B_i+1 (i = 1, 2)
#     # "g4": 0.141, # couples A_i and A_i+1 and B_i and B_i+1
#     "g4": -0.044,
#     #"g_5": 0.0, #  neglected term
#     #"g_6": 0.0, #  neglected term
#     "delta": -0.0014, # energy difference btw high-energy (B1, A2, B2, A3) and low energy points (A1, B3)
#     "Delta2": -0.0023, # separation of middle layer potential to mean of L1 and L3. in macdonald paper this is "u_a"
# }

# ---- active bilayer parameters (energies in eV) -------
g0 = -3.16   # γ0 / t0, intralayer NN hopping; v = alpha * g0 (note the negative sign)
g1 = 0.381   # γ1 / t1, interlayer hopping between the directly coupled sites
g2 = 0.0     # γ2: not used by the bilayer Hamiltonian
g3 = 0.38    # γ3; v3 = alpha * g3
g4 = 0.14    # γ4; v4 = alpha * g4
# γ5, γ6: neglected
delta = 0.0  # energy difference; NOTE(review): not referenced by eigenenergies_AB
D2 = 0.022   # Δ' (Delta prime); NOTE(review): eigenenergies_AB does not read D2 directly
alpha = (sqrt3 / 2.0) * a   # Å; velocities are v_i = alpha * g_i (eV·Å)

# --- helper functions ---
def pi_calc(kx, ky):
    """Return the complex momentum π = kx + i·ky (works element-wise on arrays)."""
    return kx + 1j * ky

def pi_dagger_calc(kx, ky):
    """Return π† = conj(π) = kx − i·ky, the Hermitian conjugate of ``pi_calc``.

    Note: −kx + i·ky would be −π†, which breaks the Hermiticity of any
    Hamiltonian assembled from these helpers.
    """
    return kx - 1j * ky
    
# ---------- 4×4 Hamiltonian for bilayer graphene ------------

def eigenenergies_AB(kx, ky, D1=0.0, valley=+1, Dprime=0.0):
    """
    Eigenenergies of the 4×4 continuum Hamiltonian of AB bilayer graphene near K.

    Sites 1 and 2 of the basis are coupled directly by t1. With
    π = kx + i·ky and π† = conj(π) = kx − i·ky, the (Hermitian) matrix is

        H =
        [ -U/2    vπ†         -v4π†       v3π   ]
        [ vπ      -U/2 + Δ'   t1          -v4π† ]
        [ -v4π    t1          U/2 + Δ'    vπ†   ]
        [ v3π†    -v4π        vπ          U/2   ]

    where v = alpha*g0, v3 = alpha*g3, v4 = alpha*g4 and t1 = g1 are the
    module-level constants, U = D1 and Δ' = Dprime.

    Parameters
    ----------
    kx, ky : array_like
        Momentum components in 1/Å; broadcast against each other.
        (In this notebook they are measured from the K point.)
    D1 : float
        Interlayer potential difference U in eV.
    valley : int
        +1 or -1. valley=-1 evaluates the K-valley Hamiltonian at -k, which
        gives the K' eigenvalues at k by time-reversal symmetry.
    Dprime : float
        Energy offset Δ' (eV) of sites 1 and 2. Defaults to 0.0; the
        module-level D2 is NOT applied automatically — pass Dprime=D2 to use it.

    Returns
    -------
    ndarray of shape broadcast(kx, ky).shape + (4,)
        Band energies in eV, ascending along the last axis.
    """
    # broadcast kx, ky to a common shape
    kx = np.asarray(kx, dtype=float)
    ky = np.asarray(ky, dtype=float)
    kx, ky = np.broadcast_arrays(kx, ky)
    kshape = kx.shape

    # valley index: k -> valley * k
    qx = valley * kx
    qy = valley * ky

    p = qx + 1j * qy      # π
    p_dag = np.conj(p)    # π† (true Hermitian conjugate)

    v = alpha * g0
    v3 = alpha * g3
    v4 = alpha * g4

    H = np.zeros(kshape + (4, 4), dtype=complex)

    # diagonal: layer asymmetry ±U/2, plus Δ' on the directly coupled sites
    H[..., 0, 0] = -D1 / 2
    H[..., 1, 1] = -D1 / 2 + Dprime
    H[..., 2, 2] = D1 / 2 + Dprime
    H[..., 3, 3] = D1 / 2

    # lower triangle
    H[..., 1, 0] = v * p
    H[..., 2, 0] = -v4 * p
    H[..., 2, 1] = g1
    H[..., 3, 0] = v3 * p_dag
    H[..., 3, 1] = -v4 * p
    H[..., 3, 2] = v * p

    # upper triangle = conjugate transpose of the lower one, so H is truly
    # Hermitian (eigvalsh only reads one triangle, but H itself must be valid)
    H[..., 0, 1] = v * p_dag
    H[..., 0, 2] = -v4 * p_dag
    H[..., 1, 2] = g1
    H[..., 0, 3] = v3 * p
    H[..., 1, 3] = -v4 * p_dag
    H[..., 2, 3] = v * p_dag

    # eigenvalues of a Hermitian matrix, ascending; shape kshape + (4,)
    return np.linalg.eigvalsh(H)

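# ---- Example 1: evaluate only the K0, M0 and Γ points ----
# NOTE: eigenenergies_AB is a continuum expansion in q = k − K; absolute BZ points
# such as K0 or M0 lie far outside its range of validity, so treat this as a smoke test.
# k_values = np.array([K0, M0, vec(0.0, 0.0)])
# kx, ky = k_values[:, 0], k_values[:, 1]
# E_bands = eigenenergies_AB(kx, ky, 0.0) # returns a 3 × 4 array (3 k-points, 4 bands)
# print(E_bands)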


___

# 4. Visualizers

## a) parameters

In [5]:
dk = 0.040          # half-width (1/Å) of the square q-patch centred on K0 (q = k - K0)
N = 300             # grid points per axis; an odd N would put q = 0 (the K point) exactly on the grid
D1 = 0.00           # interlayer potential difference U in eV
fermi_level = 0.0   # Fermi level in eV

# Momenta are measured from K0, matching the continuum expansion of the Hamiltonian.
kx, ky = (np.linspace(-dk, dk, N) for _axis in ("x", "y"))
KX, KY = np.meshgrid(kx, ky, indexing="xy")   # rows follow ky, columns follow kx
QX, QY = KX, KY                               # aliases used by the plotting cells

e_bands = eigenenergies_AB(KX, KY, D1)        # shape (N, N, 4), bands ascending in energy
Z2, Z3 = e_bands[..., 1], e_bands[..., 2]     # highest valence band, lowest conduction band


In [6]:
# ------ Export the band patch around K (q-space vs E) to JSON ------
# "energy" is indexed energy[iy][ix][band]: meshgrid(indexing="xy") puts ky on
# the first axis, and bands are in ascending order (eigvalsh).
band_export = {
    "kx": kx.tolist(),
    "ky": ky.tolist(),
    "energy": e_bands.tolist(),
    "E_F": fermi_level,
    "U": D1,
}

bands_path = Path("bands") / "AB_K_Bands.json"
# Create the output directory so the cell also works on a fresh checkout.
bands_path.parent.mkdir(parents=True, exist_ok=True)
with bands_path.open("w") as f:
    json.dump(band_export, f, indent=2)

print(f"Saved {bands_path}")

Saved AB_K_Bands.json


## b) 3d visualizer of k-space vs energy near K point

In [7]:
pio.renderers.default = "jupyterlab"

# ====== plot only the central window q in [-dk/2, dk/2]², not the full 2dk × 2dk grid ======
ix_start = N // 4
ix_end   = 3 * N // 4
iy_start = N // 4
iy_end   = 3 * N // 4
QX_cr = QX[iy_start:iy_end, ix_start:ix_end]
QY_cr = QY[iy_start:iy_end, ix_start:ix_end]
Z2_cr = Z2[iy_start:iy_end, ix_start:ix_end]
Z3_cr = Z3[iy_start:iy_end, ix_start:ix_end]

# ========= energy bounds ============
z_cut = 0.06    # z-axis window is [-z_cut, z_cut] eV
z_pad = 0.002
zmin = float(np.nanmin([Z2_cr, Z3_cr]))
zmax = float(np.nanmax([Z2_cr, Z3_cr]))
# pad the data range, then clip it to the symmetric z-axis window
zmin = max(zmin - z_pad, -z_cut)
zmax = min(zmax + z_pad, z_cut)

# hide the parts of the surfaces outside [zmin, zmax] (NaN is not drawn)
Z2_plot = Z2_cr.copy()
Z3_plot = Z3_cr.copy()
Z2_plot[(Z2_plot < zmin) | (Z2_plot > zmax)] = np.nan
Z3_plot[(Z3_plot < zmin) | (Z3_plot > zmax)] = np.nan

# --- figure (dark theme) ---
fig = go.Figure(layout=dict(template="plotly_dark", width=700, height=700))

# --- band surfaces ---
fig.add_trace(go.Surface(
    x=QX_cr, y=QY_cr, z=Z2_plot,
    colorscale="viridis", opacity=0.85, showscale=False, name="Band 2"
))
fig.add_trace(go.Surface(
    x=QX_cr, y=QY_cr, z=Z3_plot,
    colorscale="viridis_r", opacity=0.85, showscale=False, name="Band 3"
))

# --- vertical line through the K point (q = 0) ---
fig.add_trace(go.Scatter3d(
    x=[0, 0], y=[0, 0], z=[zmin, 0.9 * zmax],
    mode="lines",
    line=dict(color="red", width=7),
    name="K-axis"
))

# --- Fermi level plane ---
Zplane = np.full_like(QX_cr, fermi_level)
fig.add_trace(go.Surface(
    x=QX_cr, y=QY_cr, z=Zplane,
    colorscale="Reds_r", opacity=0.3, showscale=False, name=f"E={fermi_level} plane"
))

# ========== style / annotations ==========

# ticks
xyticks = [-0.04, -0.03, -0.02, -0.01, 0, 0.01, 0.02, 0.03, 0.04]
spacing = 0.02
start = spacing * np.floor(zmin / spacing)
stop = spacing * np.ceil(zmax / spacing)
zticks = np.arange(start / spacing, stop / spacing + 0.5) * spacing
# half-width of the plotted window; true division (floor division would give 0.0 for dk < 2)
dk_cr = dk / 2

fig.update_layout(
    font=dict(family="DejaVu Sans", size=14, color="white"),
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(
        xaxis=dict(
            title=dict(text=""),  # hidden; labelled by annotation
            tickvals=xyticks,
            tickfont=dict(family="DejaVu Sans", size=12)
        ),
        yaxis=dict(
            title=dict(text=""),  # hidden; labelled by annotation
            tickvals=xyticks,
            tickfont=dict(family="DejaVu Sans", size=12)
        ),
        zaxis=dict(
            title=dict(text=""),  # hidden; labelled by annotation
            range=[-z_cut, z_cut],
            tickvals=zticks,
            tickfont=dict(family="DejaVu Sans", size=12)
        ),
        aspectmode="manual",
        aspectratio=dict(x=0.9, y=0.9, z=1.1),
        annotations=[
            dict(x=0, y=-dk_cr, z=zmin - z_pad, text="qx (in 1/Å)", showarrow=False,
                 font=dict(family="DejaVu Sans", size=18, color="white"),
                 xanchor="center", standoff=10),
            dict(x=dk_cr, y=0, z=zmin, text="qy (in 1/Å)", showarrow=False,
                 font=dict(family="DejaVu Sans", size=18, color="white"),
                 yanchor="middle", standoff=10),
            dict(x=dk_cr, y=-dk_cr, z=dk_cr, text="E(q) (in eV)", showarrow=False,
                 font=dict(family="DejaVu Sans", size=18, color="white"),
                 textangle=0, standoff=10),
            dict(x=0, y=0, z=zmax, text="K point", showarrow=False,
                 font=dict(family="DejaVu Sans", size=18, color="red"),
                 textangle=0, standoff=10),
        ]
    )
)

fig.show(config={"displayModeBar": False})



## c) Contour of bands at Fermi level

In [8]:


def _contour_edge_points(za, zb, xa, xb, ya, yb):
    """Interpolated zero crossings of z along grid edges a -> b; None when no edge crosses."""
    # an edge counts if either end sits exactly on zero or the ends differ in sign
    hit = (za == 0) | (zb == 0) | (za * zb < 0)
    if not np.any(hit):
        return None
    gap = za - zb
    with np.errstate(divide='ignore', invalid='ignore'):
        # fraction along the edge where z reaches 0; 0 when both ends are equal
        frac = np.divide(za, gap, out=np.zeros_like(za, dtype=float), where=(gap != 0))
    xs = xa + frac * (xb - xa)
    ys = ya + frac * (yb - ya)
    return np.column_stack([xs[hit], ys[hit]])


def contour_points_2d(QX, QY, Z, level, tol=0.001):
    """
    Sample points on the iso-line Z(QX, QY) = level of a gridded surface.

    Each edge between vertically or horizontally adjacent grid nodes whose end
    values bracket `level` contributes one point, found by linear interpolation
    along the edge. When `tol` is not None, grid nodes with |Z - level| <= tol
    are added as well. Points that coincide after rounding to 10 decimals are
    dropped, keeping first occurrences in their original order.

    Parameters
    ----------
    QX, QY : 2-D arrays
        Node coordinates, same shape as Z.
    Z : 2-D array
        Surface values; NaN entries never produce crossings.
    level : float
        Value of the iso-line.
    tol : float or None
        Node tolerance; None disables node sampling.

    Returns
    -------
    ndarray of shape (M, 2)
        (x, y) points; shape (0, 2) when nothing is found.
    """
    lo, hi = float(np.nanmin(Z)), float(np.nanmax(Z))
    if lo > level or hi < level:
        # the surface never reaches `level`; only near-level nodes can qualify
        if tol is None:
            return np.empty((0, 2))
        near = np.isfinite(Z) & (np.abs(Z - level) <= tol)
        return np.column_stack([QX[near], QY[near]])

    shifted = Z - level
    edge_sets = (
        # vertical edges: (j, i) -> (j+1, i)
        (shifted[:-1, :], shifted[1:, :], QX[:-1, :], QX[1:, :], QY[:-1, :], QY[1:, :]),
        # horizontal edges: (j, i) -> (j, i+1)
        (shifted[:, :-1], shifted[:, 1:], QX[:, :-1], QX[:, 1:], QY[:, :-1], QY[:, 1:]),
    )
    chunks = [c for c in (_contour_edge_points(*edges) for edges in edge_sets) if c is not None]

    if tol is not None:
        on_level = np.abs(shifted) <= tol
        if np.any(on_level):
            chunks.append(np.column_stack([QX[on_level], QY[on_level]]))

    if not chunks:
        return np.empty((0, 2))

    pts = np.vstack(chunks)
    _, first = np.unique(np.round(pts, 10), axis=0, return_index=True)
    return pts[np.sort(first)]

pts2 = contour_points_2d(QX, QY, Z2, level=fermi_level)
pts3 = contour_points_2d(QX, QY, Z3, level=fermi_level)

fig = go.Figure()

# One marker trace per band that actually crosses the Fermi level
for band_pts, band_name in ((pts2, "Band 2 @ E_F"), (pts3, "Band 3 @ E_F")):
    if band_pts.size:
        fig.add_trace(go.Scatter(
            x=band_pts[:, 0], y=band_pts[:, 1],
            mode="markers",
            marker=dict(size=6, opacity=0.95),
            name=band_name,
        ))

# Ticks and range follow the computed q-patch (±dk), so this cell does not
# depend on tick lists defined in the 3D-plot cell.
q_ticks = np.linspace(-dk, dk, 9)
q_range = [-dk, dk]

fig.update_layout(
    template="plotly_dark",
    width=700, height=700,
    margin=dict(l=0, r=0, t=40, b=0),
    font=dict(family="DejaVu Sans", size=14, color="white"),
    title=dict(
        text=f"Constant-energy cross section at E = {(fermi_level * 1000):g} meV",
        x=0.5, xanchor="center",
    ),
    legend=dict(x=0.02, y=0.98, xanchor="left", yanchor="top", bgcolor="rgba(0,0,0,0)")
)

# Axes are momenta measured from K0 (q = k - K0), in 1/Å
fig.update_xaxes(
    title_text="qx (1/Å)",
    tickvals=q_ticks,
    range=q_range,
    tickfont=dict(family="DejaVu Sans", size=12),
    showgrid=True, gridcolor="rgba(255,255,255,0.15)",
    zeroline=False,
    constrain="domain",
)

fig.update_yaxes(
    title_text="qy (1/Å)",
    tickvals=q_ticks,
    range=q_range,
    tickfont=dict(family="DejaVu Sans", size=12),
    showgrid=True, gridcolor="rgba(255,255,255,0.15)",
    zeroline=False,
    scaleanchor="x",  # equal aspect so the contour shape is undistorted
)

fig.show(config={"displayModeBar": False})



## d) 2d line cut

In [9]:
# --- Load zoomed k-path (produced outside this notebook; presumably momenta measured from K) ---
with open("kpaths/kpath_zoomed.json", "r") as file:
    kpath = json.load(file)

kx_z_line = np.array(kpath.get("qx", []), dtype=float)
ky_z_line = np.array(kpath.get("qy", []), dtype=float)
if kx_z_line.size == 0 or kx_z_line.shape != ky_z_line.shape:
    raise ValueError(
        "kpaths/kpath_zoomed.json must provide non-empty 'qx' and 'qy' lists of equal length"
    )
breaks = kpath.get("breaks", [])
k_labels = kpath.get("k_labels", ["start", "end"])

# --- Compute eigenenergies ---
energies = eigenenergies_AB(kx_z_line, ky_z_line, D1=D1)  # shape (n_k, nbands), ascending per k
nbands = energies.shape[1]
i_vb = nbands // 2 - 1   # highest valence band (index 1 for the 4-band model)
i_cb = nbands // 2       # lowest conduction band (index 2 for the 4-band model)

# --- Cumulative path length s (x-axis); s[0] == 0 ---
dx = np.diff(kx_z_line, prepend=kx_z_line[0])
dy = np.diff(ky_z_line, prepend=ky_z_line[0])
s = np.cumsum(np.hypot(dx, dy))

# --- Locate K and recenter x so that K -> 0 ---
# Symmetry tick positions remain for reference
tick_positions = [s[b] for b in breaks] + [s[-1]]
if isinstance(k_labels, list) and "K" in k_labels and len(tick_positions) >= len(k_labels):
    x_k = tick_positions[k_labels.index("K")]
else:
    x_k = 0.5 * (s[0] + s[-1])  # fallback: center of the path

s0 = s - x_k  # recentered x: K at 0

# --- Symmetric tick grid labelled with magnitudes (distance from K) ---
L = float(max(abs(s0.min()), abs(s0.max())))
tickvals = np.linspace(-L, L, 11)
ticktext = ["0" if abs(v) < 1e-12 else f"{abs(v):.3f}" for v in tickvals]

# --- Plotly figure ---
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=s0, y=energies[:, i_vb],
    mode="lines",
    line=dict(width=1.2),
    name="Highest Energy Valence Band"
))
fig.add_trace(go.Scatter(
    x=s0, y=energies[:, i_cb],
    mode="lines",
    line=dict(width=1.2),
    name="Lowest Energy Conduction Band"
))

# --- Axes: center at K=0, positive tick labels ---
fig.update_xaxes(
    title_text=r"$k$ path",
    range=[-L, L],
    tickmode="array",
    tickvals=tickvals,
    ticktext=ticktext,
    zeroline=True,
    zerolinewidth=1,
    zerolinecolor="white"
)

# --- Vertical separators at interior symmetry points (recentered) ---
centers = [xp - x_k for xp in tick_positions]
for x in centers[1:-1]:
    fig.add_vline(x=x, line_dash="dash", line_color="gray", opacity=0.5)

# --- Fermi level (horizontal) and K-point (vertical) guides ---
fig.add_hline(y=fermi_level, line_dash="dot", line_color="gray", opacity=0.4)
fig.add_vline(x=0.0, line_dash="dot", line_color="gray", opacity=0.6)

# --- Layout ---
fig.update_layout(
    title="Zoomed-in Bandstructure near K-point",
    yaxis_title="Energy (eV)",
    template="plotly_dark",
    font=dict(family="DejaVu Sans"),
    width=800,
    height=800,
    margin=dict(l=60, r=20, t=60, b=60),
    legend=dict(x=0.02, y=0.98, bgcolor="rgba(0,0,0,0.3)"),
)

fig.show()


## e) Density of States approximation

In [None]:

# Energy window for the DOS. Both bounds come from the full (uncropped) band grid,
# so this cell does not depend on the cropped/clamped zmin left over from the
# 3D-plot cell, and the window is consistent on both sides of E = 0.
zmin = float(np.nanmin([Z2, Z3]))
zmax = float(np.nanmax([Z2, Z3]))

# ========================================================================


def compute_dos(Z2, Z3, *, Emin=None, Emax=None, nE=2000, sigma=None,
                k_step=None, degeneracy=1):
    """
    Density of states from band energies sampled on a uniform square k-grid.

    The cumulative count N(E) = #{states with energy <= E} is evaluated on an
    energy grid and differentiated numerically (central differences inside,
    first-order one-sided at the two ends). Each k-sample is then converted to
    a real-space density of k_step**2 / (2π)**2 per Å², reported per cm².

    Parameters
    ----------
    Z2, Z3 : array_like
        Band energies in eV on the k-grid; non-finite entries are ignored.
    Emin, Emax : float, optional
        Energy window in eV. Defaults to the min / max finite energy.
    nE : int, default 2000
        Number of energy samples; must be at least 2.
    sigma : float, optional
        Width in eV of a Gaussian smoothing applied to the DOS.
        None (or a non-positive value) disables smoothing.
    k_step : float, optional
        k-grid spacing in 1/Å. Defaults to 2*dk/(N-1), the spacing of
        np.linspace(-dk, dk, N) built from the notebook-level dk and N.
    degeneracy : float, default 1
        Extra multiplicity factor (e.g. 4 for spin × valley); not applied by default.

    Returns
    -------
    E_grid : ndarray, shape (nE,)
        Energies in eV.
    DOS : ndarray, shape (nE,)
        Density of states in cm⁻² eV⁻¹.

    Raises
    ------
    ValueError
        If nE < 2, if Emax <= Emin, or if a bound must be inferred but no
        finite energies are available.
    """
    if nE < 2:
        raise ValueError(f"nE must be at least 2, got {nE}")

    # All finite energies from both bands, sorted for searchsorted
    energies = np.concatenate([np.ravel(Z2), np.ravel(Z3)])
    energies = np.sort(energies[np.isfinite(energies)])

    if Emin is None or Emax is None:
        if energies.size == 0:
            raise ValueError("no finite energies to infer the energy window from")
        if Emin is None:
            Emin = energies[0]
        if Emax is None:
            Emax = energies[-1]
    if not Emax > Emin:
        raise ValueError(f"empty energy window: Emin={Emin}, Emax={Emax}")

    E_grid = np.linspace(Emin, Emax, nE)
    dE = E_grid[1] - E_grid[0]

    # N(E): number of states with energy <= E (O(nE log n_states))
    cum_counts = np.searchsorted(energies, E_grid, side="right").astype(float)

    # dN/dE: central differences inside, one-sided differences at the ends
    DOS = np.gradient(cum_counts, dE)

    if sigma is not None and sigma > 0:
        # gaussian_filter1d takes its width in samples, so convert eV -> grid points
        DOS = gaussian_filter1d(DOS, sigma / dE)

    if k_step is None:
        k_step = 2 * dk / (N - 1)            # spacing of np.linspace(-dk, dk, N)
    area = (k_step / (2 * np.pi)) ** 2       # real-space density per k-sample, Å⁻²
    DOS *= area * 1e16 * degeneracy          # Å⁻² -> cm⁻², counts/eV -> cm⁻²·eV⁻¹

    return E_grid, DOS

# Use the central half of the energy window; nE = one energy sample per k-point of a single band.
dos_window = dict(Emin=0.5 * zmin, Emax=0.5 * zmax)
E_grid, DOS = compute_dos(Z2, Z3, nE=Z2.size, **dos_window)

dos_line = go.Scatter(
    x=E_grid,
    y=DOS,
    mode="lines",
    line=dict(width=2),
    name="DOS",
    hovertemplate="E = %{x:.3f} eV<br>DOS = %{y:.3e} cm⁻²·eV⁻¹<extra></extra>",
)

fig = go.Figure(
    data=[dos_line],
    layout=dict(
        template="plotly_dark",
        title="AB Graphene Density of States",
        font=dict(family="DejaVu Sans", size=14),
        xaxis_title="Energy (eV)",
        yaxis_title="DOS (cm⁻²eV⁻¹)",
        xaxis=dict(zeroline=True, zerolinewidth=1, zerolinecolor="white"),
        yaxis=dict(range=[0, 5e13], tickformat=".1e", zeroline=False),
        width=700,
        height=450,
        margin=dict(l=60, r=20, t=50, b=50),
    ),
)

fig.show(config=dict(displayModeBar=False))
