# In [None]:
import numpy as np
from scipy.linalg import eigh
from itertools import combinations
import enum
from qiskit_nature.second_q.mappers import JordanWignerMapper, BravyiKitaevMapper
from qiskit_nature.second_q.operators import FermionicOp


# --- Helper for Hubbard Charge Gap ---
def _get_hubbard_ground_state_energy(num_sites, N_up, N_down, t_hop, U_int, suppress_warnings=False):
    """Internal helper to get just the ground state energy for Hubbard."""
    if N_up < 0 or N_down < 0 or N_up > num_sites or N_down > num_sites:
        if not suppress_warnings:
            print(f"Warning: Invalid particle numbers for GS energy calc ({N_up}, {N_down}) on L={num_sites}. Returning inf.")
        return np.inf

    basis_states, _ = generate_hubbard_basis(num_sites, N_up, N_down)
    if not basis_states:
        if not suppress_warnings:
            print(f"Warning: No basis states for GS energy calc ({N_up}, {N_down}) on L={num_sites}. Returning inf.")
        return np.inf

    if len(basis_states) > 3000 and not suppress_warnings: # Increased limit slightly for internal calls
         print(f"Warning (internal GS calc): Hilbert space dim ({len(basis_states)}) large for L={num_sites}, N_up={N_up}, N_down={N_down}.")

    H = hubbard_hamiltonian(num_sites, N_up, N_down, t_hop, U_int)
    if H.shape[0] == 0:
        return np.inf
    
    eigenvalues = eigh(H, eigvals_only=True)
    return eigenvalues[0] if len(eigenvalues) > 0 else np.inf


# --- 1D Transverse Field Ising Model ---

def get_spin_config(state_idx, num_sites):
    """Decode a basis-state index into an array of spins (+1 = up, -1 = down).

    The index is written in binary, left-padded with zeros to ``num_sites``
    digits, and read left to right: site 0 is the most significant bit.
    """
    bit_string = bin(state_idx)[2:].zfill(num_sites)
    spin_values = (1 if digit == '1' else -1 for digit in bit_string)
    return np.fromiter(spin_values, dtype=int, count=len(bit_string))

def ising_hamiltonian(num_sites, J, Gamma, h):
    """Build the dense matrix of the 1D transverse-field Ising ring.

    H = -J * sum_i s_i s_{i+1}  -  h * sum_i s_i  -  Gamma * sum_i sigma^x_i

    with periodic neighbours ``(i + 1) % num_sites``.  Basis state ``idx``
    encodes the spins in its binary digits using the same convention as
    ``get_spin_config``: site ``s`` is bit ``num_sites - 1 - s`` (site 0 is the
    most significant bit), and a set bit means spin +1.

    Args:
        num_sites: Number of spins; the matrix is ``2**num_sites`` square.
        J: Nearest-neighbour coupling on the diagonal.
        Gamma: Transverse-field strength (single spin-flip matrix elements).
        h: Longitudinal field on the diagonal.

    Returns:
        A real symmetric ``numpy.ndarray`` of dtype float64.
    """
    num_states = 2**num_sites
    hamiltonian = np.zeros((num_states, num_states), dtype=np.float64)

    for i in range(num_states):
        # Decode spins with site 0 as the most significant bit.
        spins_i = [1 if (i >> (num_sites - 1 - site)) & 1 else -1
                   for site in range(num_sites)]

        h_diag = 0.0
        for site in range(num_sites):
            h_diag -= J * spins_i[site] * spins_i[(site + 1) % num_sites]
            h_diag -= h * spins_i[site]
        hamiltonian[i, i] = h_diag

        for site in range(num_sites):
            # Flipping site `site` toggles exactly the bit that encodes it.
            # (Rebuilding the index least-significant-bit first, as before,
            # produced the bit-reversed state and hence wrong matrix elements.)
            j_idx = i ^ (1 << (num_sites - 1 - site))
            hamiltonian[i, j_idx] -= Gamma
    return hamiltonian

def calculate_ising_magnetization(num_sites, ground_state_vector):
    """Calculates average magnetization M_z = (1/L) * Σ_i <s_z_i> for Ising model.

    s_z_i is +1 (up) or -1 (down).  Each basis state is weighted by the
    probability |amplitude|^2, so complex state vectors are handled correctly
    (squaring the raw amplitude is only valid for real vectors).

    Args:
        num_sites: Number of spins L.
        ground_state_vector: Sequence of ``2**num_sites`` amplitudes, or None.

    Returns:
        The average magnetization as a float, or ``np.nan`` when the vector is
        missing, has the wrong length, or ``num_sites`` is not positive.
    """
    if num_sites <= 0:
        return np.nan
    if ground_state_vector is None or len(ground_state_vector) != 2**num_sites:
        return np.nan

    probabilities = np.abs(np.asarray(ground_state_vector).reshape(-1)) ** 2

    # The total spin of a basis state depends only on how many bits are set:
    # (#up) - (#down) = 2 * popcount - L, independent of the site/bit ordering.
    up_counts = np.array([bin(idx).count("1") for idx in range(2**num_sites)])
    total_sz_expectation = float(np.dot(probabilities, 2 * up_counts - num_sites))

    return total_sz_expectation / num_sites


def solve_ising_model(num_sites, J, Gamma, h):
    """Diagonalise the transverse-field Ising Hamiltonian and report basic observables.

    Prints the ground-state energy, the first excited energy and gap (when the
    spectrum has more than one level) and the average magnetization.

    Returns:
        dict with keys ``eigenvalues``, ``ground_state_energy``,
        ``first_excited_energy`` / ``energy_gap`` (only if available),
        ``average_magnetization`` and ``ground_state_vector``.
    """
    print(f"\n--- 1D Transverse Field Ising Model (L={num_sites}, J={J}, Gamma={Gamma}, h={h}) ---")
    # Dense ED needs a 2**L x 2**L matrix; above this size it is merely a warning.
    if num_sites > 14:
        print("Warning: Large number of sites for Ising ED, may be slow/memory intensive.")

    spectrum, states = eigh(ising_hamiltonian(num_sites, J, Gamma, h))

    e_ground = spectrum[0]
    print(f"Ground state energy: {e_ground:.6f}")
    results = {"eigenvalues": spectrum, "ground_state_energy": e_ground}

    if len(spectrum) > 1:
        e_first = spectrum[1]
        gap = e_first - e_ground
        print(f"First excited state energy: {e_first:.6f}")
        print(f"Energy gap: {gap:.6f}")
        results["first_excited_energy"] = e_first
        results["energy_gap"] = gap

    psi0 = states[:, 0]
    m_avg = calculate_ising_magnetization(num_sites, psi0)
    print(f"Average magnetization <M_z>: {m_avg:.6f}")
    results["average_magnetization"] = m_avg
    results["ground_state_vector"] = psi0

    return results


# --- 1D Hubbard Model ---

def generate_hubbard_basis(num_sites, N_up, N_down):
    """Enumerate the occupation basis of one (N_up, N_down) sector.

    Each basis state is a tuple ``(mask_up, mask_down)`` of integer bit masks
    in which bit ``k`` is set when site ``k`` holds an electron of that spin.
    States are ordered with the up configuration as the outer loop and the
    down configuration as the inner loop, both in ``itertools.combinations``
    order.

    Returns:
        ``(basis_states, state_to_idx)``: the list of states and a dict mapping
        each state to its position in that list.  Both are empty when the
        particle numbers are outside ``[0, num_sites]``.
    """
    if N_up < 0 or N_down < 0 or N_up > num_sites or N_down > num_sites:
        return [], {}  # Invalid particle numbers

    def _occupation_mask(occupied_sites):
        mask = 0
        for pos in occupied_sites:
            mask |= (1 << pos)
        return mask

    up_masks = [_occupation_mask(cfg) for cfg in combinations(range(num_sites), N_up)]
    down_masks = [_occupation_mask(cfg) for cfg in combinations(range(num_sites), N_down)]

    basis_states = [(mask_up, mask_down) for mask_up in up_masks for mask_down in down_masks]
    state_to_idx = {state: idx for idx, state in enumerate(basis_states)}
    return basis_states, state_to_idx


def _hubbard_hop_sign(mask, site_a, site_b):
    """Fermionic sign for moving one particle between two sites of the same spin species.

    With creation operators ordered by site index inside each species, the sign
    is (-1)**(number of occupied sites strictly between the two sites).
    """
    low, high = (site_a, site_b) if site_a < site_b else (site_b, site_a)
    between = (mask >> (low + 1)) & ((1 << (high - low - 1)) - 1)
    return -1.0 if bin(between).count("1") % 2 else 1.0


def hubbard_hamiltonian(num_sites, N_up, N_down, t, U):
    """Build the dense Hubbard Hamiltonian on a ring for fixed (N_up, N_down).

    H = -t * sum_{<i,j>,sigma} (c†_{i,sigma} c_{j,sigma} + h.c.)
        + U * sum_i n_{i,up} n_{i,down}

    Bonds are the nearest-neighbour pairs of a periodic chain: ``num_sites``
    bonds for three or more sites, a single bond for two sites and none for
    one site.  Both hopping directions are written explicitly so the matrix is
    symmetric in both triangles (``scipy.linalg.eigh`` only reads one of them).
    Operators are ordered as all up-spin orbitals followed by all down-spin
    orbitals, each sorted by site; hops therefore carry the sign from
    ``_hubbard_hop_sign``, which is only non-trivial on the wrap-around bond.

    Args:
        num_sites: Number of lattice sites.
        N_up, N_down: Electron numbers defining the sector.
        t: Hopping amplitude.
        U: On-site interaction.

    Returns:
        Real symmetric float64 matrix in the ``generate_hubbard_basis`` ordering;
        a ``(0, 0)`` array when the sector has no states.
    """
    basis_states, state_to_idx = generate_hubbard_basis(num_sites, N_up, N_down)
    num_basis_states = len(basis_states)

    if num_basis_states == 0:
        # Shape (0, 0) so that callers testing ``H.shape[0] == 0`` detect it.
        return np.zeros((0, 0), dtype=np.float64)

    hamiltonian = np.zeros((num_basis_states, num_basis_states), dtype=np.float64)

    # Unique nearest-neighbour bonds of the ring.
    if num_sites >= 3:
        bonds = [(site, (site + 1) % num_sites) for site in range(num_sites)]
    elif num_sites == 2:
        bonds = [(0, 1)]
    else:
        bonds = []

    for current_idx, (mask_up, mask_down) in enumerate(basis_states):
        # On-site interaction: U for every doubly occupied site.
        hamiltonian[current_idx, current_idx] += U * bin(mask_up & mask_down).count("1")

        for site_a, site_b in bonds:
            for src, dst in ((site_a, site_b), (site_b, site_a)):
                hop_bits = (1 << src) | (1 << dst)

                # Up-spin hop src -> dst (source occupied, target empty).
                if (mask_up >> src) & 1 and not (mask_up >> dst) & 1:
                    target = (mask_up ^ hop_bits, mask_down)
                    sign = _hubbard_hop_sign(mask_up, src, dst)
                    hamiltonian[current_idx, state_to_idx[target]] -= t * sign

                # Down-spin hop src -> dst.  Passing the up-spin operators
                # contributes an even number of exchanges, so no extra sign.
                if (mask_down >> src) & 1 and not (mask_down >> dst) & 1:
                    target = (mask_up, mask_down ^ hop_bits)
                    sign = _hubbard_hop_sign(mask_down, src, dst)
                    hamiltonian[current_idx, state_to_idx[target]] -= t * sign

    return hamiltonian

def calculate_hubbard_local_magnetizations(num_sites, ground_state_vector, basis_states):
    """Calculates local site magnetization <m_z_k> = <(n_k↑ - n_k↓)/2> for Hubbard.

    Basis states are weighted by |amplitude|^2, so the result is also correct
    for complex state vectors (squaring the raw amplitude is not).

    Args:
        num_sites: Number of lattice sites.
        ground_state_vector: Amplitudes, one per entry of ``basis_states``.
        basis_states: Sequence of ``(mask_up, mask_down)`` bit-mask tuples.

    Returns:
        ``numpy.ndarray`` of length ``num_sites``; a list of ``np.nan`` of the
        same length when the inputs are missing or inconsistent.
    """
    if ground_state_vector is None or len(ground_state_vector) != len(basis_states) or not basis_states:
        return [np.nan] * num_sites

    probabilities = np.abs(np.asarray(ground_state_vector).reshape(-1)) ** 2

    local_mz = np.zeros(num_sites)
    for site_k in range(num_sites):
        # n_k_up - n_k_down for every basis state: +1, 0 or -1.
        imbalance = np.array(
            [((mask_up >> site_k) & 1) - ((mask_down >> site_k) & 1)
             for mask_up, mask_down in basis_states],
            dtype=np.float64,
        )
        local_mz[site_k] = np.dot(probabilities, imbalance) / 2.0
    return local_mz

def calculate_hubbard_charge_gap(num_sites, N_up_ref, N_down_ref, t_hop, U_int):
    """Calculates charge gap: E0(N+1) + E0(N-1) - 2*E0(N).

    E0(N+1) is the lower of the ground-state energies obtained by adding one
    up or one down electron to the reference sector; E0(N-1) likewise for
    removing one.  Returns ``np.nan`` (after printing the reason) when any of
    the three energies is unavailable.
    """
    print(f"Calculating charge gap around N_up={N_up_ref}, N_down={N_down_ref}:")

    def sector_energy(n_up, n_down):
        # Warnings are silenced: out-of-range sectors are expected while probing.
        return _get_hubbard_ground_state_energy(num_sites, n_up, n_down, t_hop, U_int, suppress_warnings=True)

    E0_N = sector_energy(N_up_ref, N_down_ref)
    if E0_N == np.inf:
        print("  Cannot compute charge gap: Reference state E0(N) is invalid.")
        return np.nan

    # Candidate sectors with one extra electron (up first, then down).
    plus_sectors = []
    if N_up_ref + 1 <= num_sites:
        plus_sectors.append((N_up_ref + 1, N_down_ref))
    if N_down_ref + 1 <= num_sites:
        plus_sectors.append((N_up_ref, N_down_ref + 1))
    E0_N_plus_1 = min((sector_energy(nu, nd) for nu, nd in plus_sectors), default=np.inf)

    # Candidate sectors with one electron removed (up first, then down).
    minus_sectors = []
    if N_up_ref - 1 >= 0:
        minus_sectors.append((N_up_ref - 1, N_down_ref))
    if N_down_ref - 1 >= 0:
        minus_sectors.append((N_up_ref, N_down_ref - 1))
    E0_N_minus_1 = min((sector_energy(nu, nd) for nu, nd in minus_sectors), default=np.inf)

    if E0_N_plus_1 == np.inf or E0_N_minus_1 == np.inf:
        print(f"  Cannot compute charge gap: E0(N+1)={E0_N_plus_1:.4f}, E0(N-1)={E0_N_minus_1:.4f}, E0(N)={E0_N:.4f}")
        return np.nan

    print(f"  E0(N+1)={E0_N_plus_1:.4f}, E0(N-1)={E0_N_minus_1:.4f}, E0(N)={E0_N:.4f}")
    return E0_N_plus_1 + E0_N_minus_1 - 2 * E0_N


def solve_hubbard_model(num_sites, N_up, N_down, t_hop, U_int, calculate_charge_gap_flag=False):
    """Diagonalise the Hubbard Hamiltonian in one (N_up, N_down) sector and report observables.

    Prints and stores the sector dimension, ground and first excited energies,
    the (trivial) total S_z, local magnetizations and, optionally, the charge gap.

    Returns:
        dict of results; empty when the particle numbers are invalid or the
        sector has no states.
    """
    print(f"\n--- 1D Hubbard Model (L={num_sites}, N_up={N_up}, N_down={N_down}, t={t_hop}, U={U_int}) ---")

    results = {}
    counts_in_range = 0 <= N_up <= num_sites and 0 <= N_down <= num_sites
    if not counts_in_range:
        print("Error: Invalid particle numbers.")
        return results

    basis_states, state_to_idx = generate_hubbard_basis(num_sites, N_up, N_down)
    if not basis_states:
        print("Error: No basis states found for this configuration.")
        return results

    results["basis_states"] = basis_states
    results["state_to_idx"] = state_to_idx

    hilbert_dim = len(basis_states)
    print(f"Hilbert space dimension for this sector: {hilbert_dim}")
    # Advisory only: dense ED cost grows quickly with the sector dimension.
    if hilbert_dim > 2500:
        print(f"Warning: Hilbert space dimension ({hilbert_dim}) is large, ED may be slow.")

    H = hubbard_hamiltonian(num_sites, N_up, N_down, t_hop, U_int)
    if H.shape[0] == 0:
        return results

    spectrum, states = eigh(H)
    results["eigenvalues"] = spectrum
    results["eigenvectors"] = states

    e_ground = spectrum[0]
    print(f"Ground state energy: {e_ground:.6f}")
    results["ground_state_energy"] = e_ground

    if len(spectrum) > 1:
        e_first = spectrum[1]
        # Excitation gap inside this particle-number sector, not the charge gap.
        sector_gap = e_first - e_ground
        print(f"First excited state energy (within sector): {e_first:.6f}")
        print(f"Energy gap (within sector): {sector_gap:.6f}")
        results["first_excited_energy"] = e_first
        results["energy_gap_in_sector"] = sector_gap

    # S_z^tot is fixed by the sector, so no expectation value is needed.
    sz_total = (N_up - N_down) / 2.0
    print(f"Total magnetization S_z^tot: {sz_total:.1f} ( (N_up-N_down)/2 )")
    results["total_magnetization_Sz"] = sz_total

    psi0 = states[:, 0]
    results["ground_state_vector"] = psi0
    site_mz = calculate_hubbard_local_magnetizations(num_sites, psi0, basis_states)
    print(f"Local magnetizations <m_z_k>: {np.round(site_mz, 4)}")
    results["local_magnetizations"] = site_mz

    if calculate_charge_gap_flag:
        delta_c = calculate_hubbard_charge_gap(num_sites, N_up, N_down, t_hop, U_int)
        print(f"Charge gap Δ_c: {delta_c:.6f}")
        results["charge_gap"] = delta_c

    return results


if __name__ == '__main__':
    # Demo driver: runs the Ising and Hubbard solvers on small systems whose
    # results can be compared with the hand calculations noted below.

    # --- Ising Model Examples ---
    ising_results_1 = solve_ising_model(num_sites=4, J=1.0, Gamma=1.0, h=0.0)
    ising_results_2 = solve_ising_model(num_sites=3, J=1.0, Gamma=0.5, h=0.1)

    # --- Hubbard Model Examples ---
    hubbard_results_1 = solve_hubbard_model(num_sites=4, N_up=2, N_down=2, t_hop=1.0, U_int=4.0, calculate_charge_gap_flag=True)
    # NOTE(review): the original note quotes E0 ≈ -2.86 for L=4, N_up=2, N_down=2, U=4, t=1 with PBC
    # (attributed to Bethe Ansatz or DMRG); this value is not verified here — confirm against a reference.

    hubbard_results_non_int = solve_hubbard_model(num_sites=4, N_up=1, N_down=0, t_hop=1.0, U_int=0.0, calculate_charge_gap_flag=True)
    # Expected E0 for L=4, N_up=1, N_down=0, U=0, t=1: -2t*cos(0) = -2.0
    # Expected eigenvalues: -2t, 0, 0, 2t (for k = 0, pi/2, -pi/2, pi)
    if hubbard_results_non_int and "eigenvalues" in hubbard_results_non_int:
         print(f"Calculated eigenvalues for L=4,N_up=1,N_down=0,U=0,t=1: {np.round(hubbard_results_non_int['eigenvalues'], 5)}")

    hubbard_L2_N11 = solve_hubbard_model(num_sites=2, N_up=1, N_down=1, t_hop=1.0, U_int=4.0, calculate_charge_gap_flag=True)
    # Expected E0 for L=2, (1,1), U=4, t=1: U/2 - sqrt((U/2)^2 + 4t^2) = 2 - sqrt(4+4) = 2 - 2*sqrt(2) ≈ -0.828
    # Expected eigenvalues: -0.828, 0, U, U/2 + sqrt((U/2)^2 + 4t^2)  =>  -0.828, 0, 4.0, 4.828
    if hubbard_L2_N11 and "eigenvalues" in hubbard_L2_N11:
        print(f"Calculated eigenvalues for L=2,N_up=1,N_down=1,U=4,t=1: {np.round(hubbard_L2_N11['eigenvalues'], 5)}")

    # Charge-gap check on a half-filled two-site system: L=2, N_up=1, N_down=1, U=8, t=1.
    #
    # E0(N):   E0(1,1) = U/2 - sqrt((U/2)^2 + 4t^2) = 4 - sqrt(16+4) = 4 - 4.472 = -0.472
    #
    # E0(N+1): sector (2,1) on L=2 (the (1,2) sector is equivalent).
    #   With N_up=2 both sites hold an up electron (up_mask = 0b11); the basis is
    #     up_mask=0b11, down_mask=0b01 (down on site 0) -> |↑↓, ↑>   diagonal energy U
    #     up_mask=0b11, down_mask=0b10 (down on site 1) -> |↑, ↑↓>   diagonal energy U
    #   H = [[U, -t], [-t, U]], so E = U ± t and E0(2,1) = U - t = 8 - 1 = 7.
    #
    # E0(N-1): sector (1,0) on L=2 (the (0,1) sector is equivalent).
    #   H = [[0, -t], [-t, 0]], so E = ±t and E0(1,0) = -t = -1.
    #
    # Charge gap = E0(N+1) + E0(N-1) - 2*E0(N)
    #            = (U - t) + (-t) - 2*(U/2 - sqrt((U/2)^2 + 4t^2))
    #            = -2t + 2*sqrt((U/2)^2 + 4t^2)
    # For U=8, t=1: -2 + 2*sqrt(20) = -2 + 8.944 = 6.944
    # Compare with the value printed by the code:
    print("\n--- Test charge gap for L=2, (1,1), U=8, t=1 ---")
    solve_hubbard_model(num_sites=2, N_up=1, N_down=1, t_hop=1.0, U_int=8.0, calculate_charge_gap_flag=True)

import numpy as np
import ml_collections
from ml_collections import config_dict
from functools import partial

# First, ensure you have openfermion installed:
# pip install openfermion

import openfermion as of

def jw_sigma_z_fermionic(site_index):
    """
    Build the fermionic image of σ_z on one site as an OpenFermion FermionOperator.

    σ_z^(i) = 2 c†_i c_i - 1
    """
    # 2 * n_i, with n_i = c†_i c_i
    doubled_number_op = of.FermionOperator(f"{site_index}^ {site_index}", 2.0)
    # The empty term string denotes the identity; coefficient -1.
    minus_identity = of.FermionOperator("", -1.0)
    return doubled_number_op + minus_identity

def jw_sigma_x_fermionic(site_index, num_sites):
    """
    Build the fermionic image of σ_x on one site as an OpenFermion FermionOperator.

    σ_x^(i) = (Π_{j<i} (2c†_j c_j - 1)) * (c†_i + c_i)

    ``num_sites`` is accepted for interface symmetry but is not used.
    """
    # String over all sites to the left of site_index; identity for site 0.
    string_op = of.FermionOperator("", 1.0)
    for left_site in range(site_index):
        string_op = string_op * jw_sigma_z_fermionic(left_site)

    # Local part: c†_i + c_i
    ladder_sum = of.FermionOperator(f"{site_index}^", 1.0) + of.FermionOperator(f"{site_index}", 1.0)

    return string_op * ladder_sum

def ising_hamiltonian_jw(num_sites, J_coupling, h_field, pbc=False):
    """
    Assemble the Ising Hamiltonian as a sum of FermionOperator terms.

    H_Ising = -J Σ_{i} σ_z^(i) σ_z^(i+1) - h Σ_{i} σ_x^(i)

    Each Pauli operator is replaced by the expression returned by
    ``jw_sigma_z_fermionic`` / ``jw_sigma_x_fermionic``.  One line is printed
    for every Pauli term that is added.

    Args:
        num_sites (int): Number of sites (spins/fermions).
        J_coupling (float): Interaction strength J.
        h_field (float): Transverse field strength h.
        pbc (bool): If True, also couple the last site back to site 0.
                    Ignored (with a warning) for a single site.

    Returns:
        openfermion.FermionOperator: The fermionic Hamiltonian; an empty
        operator when ``num_sites <= 0``.
    """
    if num_sites <= 0:
        return of.FermionOperator()
    if pbc and num_sites == 1:
        print("Warning: PBC for N=1 interaction term is ill-defined, treating as OBC.")
        pbc = False

    # Number of σ_z σ_z bonds: none for one site, L on a ring, L - 1 on an open chain.
    if num_sites == 1:
        bond_count = 0
    elif pbc:
        bond_count = num_sites
    else:
        bond_count = num_sites - 1

    fermionic_H = of.FermionOperator()

    # 1. Interaction term: -J Σ_{i} σ_z^(i) σ_z^(i+1)
    for left in range(bond_count):
        right = (left + 1) % num_sites  # wraps to site 0 on the closing bond
        zz_term = jw_sigma_z_fermionic(left) * jw_sigma_z_fermionic(right)
        fermionic_H += -J_coupling * zz_term
        print(f"Original Pauli term: (-{J_coupling}) * σ_z^{left} * σ_z^{right}")

    # 2. Transverse field term: -h Σ_{i} σ_x^(i)
    for site in range(num_sites):
        x_term = jw_sigma_x_fermionic(site, num_sites)
        fermionic_H += -h_field * x_term
        print(f"Original Pauli term: (-{h_field}) * σ_x^{site}")

    return fermionic_H

# --- Example Usage ---
if __name__ == "__main__":
    # Demo: build and print the fermionic Ising Hamiltonian for open and
    # periodic boundaries and for a single site, then show individual operators.
    N_sites = 3
    J = 1.0
    h = 0.5

    print(f"Constructing JW transformed Ising Hamiltonian for N={N_sites}, J={J}, h={h} (OBC):\n")
    H_fermionic_obc = ising_hamiltonian_jw(N_sites, J, h, pbc=False)
    print("\n--- Total Fermionic Hamiltonian (OBC, terms simplified and combined) ---")
    print(H_fermionic_obc)
    # OpenFermion might compress terms, e.g. (2.0+0j) [0^ 0] + (1.0+0j) [0^ 0] = (3.0+0j) [0^ 0]
    # and perform normal ordering.

    print("-" * 50)
    # Example with Periodic Boundary Conditions (PBC)
    # The σ_z σ_z interaction term (N-1, 0) is local in fermions.
    # If the model had σ_x σ_x interactions, PBC would make the (N-1,0) term highly non-local.
    if N_sites > 1:
        print(f"\nConstructing JW transformed Ising Hamiltonian for N={N_sites}, J={J}, h={h} (PBC):\n")
        H_fermionic_pbc = ising_hamiltonian_jw(N_sites, J, h, pbc=True)
        print("\n--- Total Fermionic Hamiltonian (PBC, terms simplified and combined) ---")
        print(H_fermionic_pbc)
        
    # Example for N=1 (no interaction term)
    print("-" * 50)
    print(f"\nConstructing JW transformed Ising Hamiltonian for N=1, J={J}, h={h}:\n")
    H_fermionic_N1 = ising_hamiltonian_jw(1, J, h, pbc=False) # for N=1 the function forces OBC even if pbc=True is passed
    print("\n--- Total Fermionic Hamiltonian (N=1, terms simplified and combined) ---")
    print(H_fermionic_N1)


    # To demonstrate the symbolic representation of individual Pauli operators:
    print("\n--- Symbolic JW representation of individual Pauli operators ---")
    site = 1
    N = 3 # context for sigma_x (only used in the printed labels and passed through)
    
    sz1_symbolic = f"(2c†_{site}c_{site} - 1)"
    print(f"σ_z^{site}  ->  {sz1_symbolic}")
    
    # For σ_x^1 (site=1), the string is Π_{j<1} Z_j = Z_0
    # Z_0 = (2c†_0c_0 - 1)
    # σ_x^1 -> (2c†_0c_0 - 1) * (c†_1 + c_1)
    string_part_symb = ""
    if site > 0:
        string_part_symb = " * ".join([f"(2c†_{k}c_{k} - 1)" for k in range(site)]) + " * "
    sx1_symbolic = f"{string_part_symb}(c†_{site} + c_{site})"
    print(f"σ_x^{site} (for N={N}) -> {sx1_symbolic}")

    print("\n--- Corresponding OpenFermion objects (simplified) ---")
    sz1_of = jw_sigma_z_fermionic(site)
    print(f"σ_z^{site} (OpenFermion): {sz1_of}")
    sx1_of = jw_sigma_x_fermionic(site, N)
    print(f"σ_x^{site} (OpenFermion, for N={N}): {sx1_of}")
# Hubbard Model Configuration for QMC Simulation
class Hubbard1D:
    """Holds an L x L ring matrix with -t between neighbouring sites and U on the diagonal.

    NOTE(review): the stored matrix has one row per site, i.e. it is a
    single-particle (tight-binding-like) matrix with a uniform diagonal U,
    not a many-body Hubbard Hamiltonian — confirm this is what callers expect.
    """

    def __init__(self, L=10, U=4.0, t=1.0):
        self.L = L  # Lattice length
        self.U = U  # Value placed on the diagonal
        self.t = t  # Magnitude of the nearest-neighbour off-diagonal entries
        self.hamiltonian = self.build_hamiltonian()

    def build_hamiltonian(self):
        """Return the L x L matrix: -t on periodic nearest-neighbour entries, U on the diagonal."""
        size = self.L
        matrix = np.zeros((size, size))

        here = np.arange(size)
        there = (here + 1) % size  # periodic right neighbour

        # Plain assignment (not accumulation); the diagonal is written last so
        # that it wins when a site is its own neighbour (L = 1).
        matrix[here, there] = -self.t
        matrix[there, here] = -self.t
        matrix[here, here] = self.U

        return matrix

    def get_energy(self, state):
        """Return the quadratic form state^T · H · state for the stored matrix."""
        h_applied = np.dot(self.hamiltonian, state)
        return np.dot(state.T, h_applied)


def default_hubbard_config() -> ml_collections.ConfigDict:
    """Return a ConfigDict of default settings for the Hubbard QMC run.

    The config has four top-level entries: ``batch_size``, ``optim``,
    ``system`` and ``network``.
    """
    learning_rate = {
        'rate': 0.01,  # Learning rate
        'decay': 0.95,  # Learning rate decay
        'delay': 10000.0,
    }
    optim_settings = {
        'objective': 'vmc',  # Objective is VMC energy minimization
        'iterations': 1000000,
        'optimizer': 'adam',  # Optimizer to use
        'lr': learning_rate,
        'clip_local_energy': 5.0,
        'clip_median': False,
        'reset_if_nan': True,
        'spin_energy': 0.0,
    }
    system_settings = {
        'type': 'hubbard',  # System type
        'lattice_size': 10,  # Number of sites in the 1D Hubbard model
        'on_site_interaction': 4.0,  # U parameter (on-site interaction)
        'hopping_strength': 1.0,  # t parameter (hopping strength)
        'electrons': (5, 5),  # Number of electrons in the system (spin-up, spin-down)
        'ndim': 1,  # Dimensionality (1D for Hubbard model)
        'states': 0,
        'units': 'angstrom',
        'use_pp': False,
    }
    psiformer_settings = {
        'num_layers': 4,
        'num_heads': 4,
        'heads_dim': 64,
        'mlp_hidden_dims': (256,),
        'use_layer_norm': True,
    }
    network_settings = {
        'network_type': 'psiformer',  # Use PsiFormer architecture
        'psiformer': psiformer_settings,
    }

    return ml_collections.ConfigDict({
        'batch_size': 4096,  # Batch size for QMC sampling
        'optim': optim_settings,
        'system': system_settings,
        'network': network_settings,
    })

# Initialize the Hubbard model and configuration.
# NOTE: these statements are at module level, so they run on import.
hubbard_model = Hubbard1D(L=10, U=4.0, t=1.0)
hubbard_config = default_hubbard_config()

# Example helper: evaluate the model's energy for a supplied state
def sample_energy(model, state):
    """Return ``model.get_energy(state)`` unchanged."""
    return model.get_energy(state)

# Generate a random state for testing.
# np.random.rand draws from [0, 1) and no seed is set here, so the printed
# energy differs from run to run.
random_state = np.random.rand(hubbard_model.L)
random_state = random_state / np.linalg.norm(random_state)  # Normalize to unit Euclidean length

# Calculate energy for the random state (runs at module import)
energy = sample_energy(hubbard_model, random_state)
print(f"Calculated energy: {energy}")

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule: start from the package's default config and override fields.
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
# Two hydrogen atoms on the z axis at z = -1 and z = +1.
# NOTE(review): coordinate units are not specified here — confirm against the ferminet docs.
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

# NOTE: this call is at module level, so training starts as soon as the file is executed/imported.
train.train(cfg)

import numpy as np
import ml_collections
from ml_collections import config_dict
from functools import partial

# Ising Model Configuration for QMC Simulation
class Ising1D:
    """Classical 1D Ising ring: stores a spin array and evaluates its energy.

    Energy of the current configuration:
        E = -J * sum_i s_i s_{i+1} - h * sum_i s_i   (periodic neighbours)
    """

    def __init__(self, L=10, J=1.0, h=0.0):
        self.L, self.J, self.h = L, J, h  # number of spins, coupling, field
        self.spins = np.ones(self.L)  # all spins start at +1

    def hamiltonian(self):
        """Return the energy of the spin configuration currently stored."""
        total = 0
        for site in range(self.L):
            # Right neighbour, wrapping from the last site back to site 0.
            neighbour = site + 1 if site + 1 < self.L else 0
            total -= self.J * self.spins[site] * self.spins[neighbour]
            total -= self.h * self.spins[site]
        return total

    def flip_spin(self, index):
        """Negate the spin stored at ``index`` in place."""
        self.spins[index] *= -1

    def get_energy(self):
        """Alias for :meth:`hamiltonian`."""
        return self.hamiltonian()


def default_ising_config() -> ml_collections.ConfigDict:
    """Return a ConfigDict of default settings for the Ising QMC run.

    The config has four top-level entries: ``batch_size``, ``optim``,
    ``system`` and ``network``.
    """
    learning_rate = {
        'rate': 0.01,
        'decay': 0.95,
        'delay': 10000.0,
    }
    optim_settings = {
        'objective': 'vmc',
        'iterations': 1000000,
        'optimizer': 'adam',
        'lr': learning_rate,
        'clip_local_energy': 5.0,
        'clip_median': False,
        'reset_if_nan': True,
        'spin_energy': 0.0,
    }
    # NOTE(review): 'electrons' and 'units' mirror the Hubbard config; whether
    # they are meaningful for a spin system should be confirmed with the consumer.
    system_settings = {
        'type': 'ising',  # System type
        'lattice_size': 10,  # Number of spins in 1D Ising model
        'interaction_strength': 1.0,  # J parameter (interaction strength)
        'external_field': 0.0,  # External magnetic field (h)
        'electrons': (5, 5),
        'ndim': 1,  # Dimensionality (1D for Ising model)
        'states': 0,
        'units': 'angstrom',
        'use_pp': False,
    }
    psiformer_settings = {
        'num_layers': 4,
        'num_heads': 4,
        'heads_dim': 64,
        'mlp_hidden_dims': (256,),
        'use_layer_norm': True,
    }
    network_settings = {
        'network_type': 'psiformer',  # Use PsiFormer architecture
        'psiformer': psiformer_settings,
    }

    return ml_collections.ConfigDict({
        'batch_size': 4096,
        'optim': optim_settings,
        'system': system_settings,
        'network': network_settings,
    })

# Initialize the Ising model and configuration.
# NOTE: these statements are at module level, so they run on import.
ising_model = Ising1D(L=10, J=1.0, h=0.0)
ising_config = default_ising_config()

# Example helper: evaluate the energy of the model's current spin configuration
def sample_ising_energy(model):
    """Return ``model.get_energy()`` unchanged."""
    return model.get_energy()

# Calculate energy for the current state (all spins +1, as set by Ising1D.__init__).
# Runs at module import.
energy_ising = sample_ising_energy(ising_model)
print(f"Calculated energy for Ising model: {energy_ising}")

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from collections import namedtuple

# Lattice records: a site (id, position, sublattice label, unit-cell index) and
# a bond (two site ids, bond type, coupling, and the (start, end) coordinates).
Site = namedtuple('Site', ['id', 'coords', 'sublattice', 'unit_cell_idx'])
Bond = namedtuple('Bond', ['site1_id', 'site2_id', 'type', 'J_coupling', 'vector_coords'])

def generate_honeycomb_lattice(Lx, Ly, Jx, Jy, Jz, bond_length=1.0, periodic=False):
    """Generate sites and typed bonds of an Lx x Ly honeycomb lattice.

    Every unit cell (n1, n2) holds an A site at ``n1*a1 + n2*a2`` and a B site
    displaced from it by ``(bond_length, 0)``.  Each A site is linked to the B
    site of cell (n1, n2) by an 'x' bond, of cell (n1, n2-1) by a 'y' bond and
    of cell (n1-1, n2) by a 'z' bond.  With ``periodic=False`` bonds whose
    target cell lies outside the lattice are skipped; with ``periodic=True``
    the target cell index wraps around, while the stored end coordinate is the
    unwrapped position.

    Returns:
        ``(sites_list, bonds_list)``: lists of ``Site`` and ``Bond`` tuples.
    """
    print(f"\n--- Generating Honeycomb Lattice ({Lx}x{Ly} unit cells) ---")
    print(f"Bond length: {bond_length}, Periodic: {periodic}")
    print(f"Couplings: Jx={Jx}, Jy={Jy}, Jz={Jz}")

    # Lattice vectors of the unit cell and the A -> B offset inside a cell.
    half_height = 0.5 * np.sqrt(3) * bond_length
    a1_uc = np.array([1.5 * bond_length, half_height])
    a2_uc = np.array([1.5 * bond_length, -half_height])
    b_offset = np.array([bond_length, 0.0])

    sites_list = []
    site_map = {}
    for n1 in range(Lx):
        for n2 in range(Ly):
            cell_origin = n1 * a1_uc + n2 * a2_uc
            # A first, then B, so ids alternate A, B within each cell.
            for label, position in (('A', cell_origin), ('B', cell_origin + b_offset)):
                new_id = len(sites_list)
                sites_list.append(Site(id=new_id, coords=position, sublattice=label, unit_cell_idx=(n1, n2)))
                site_map[(n1, n2, label)] = new_id

    print(f"Generated {len(sites_list)} sites.")

    # (cell offset of the B partner, bond type, coupling), in x, y, z order.
    neighbour_rules = (
        ((0, 0), 'x', Jx),
        ((0, -1), 'y', Jy),
        ((-1, 0), 'z', Jz),
    )

    bonds_list = []
    for n1_A in range(Lx):
        for n2_A in range(Ly):
            site_A_id = site_map[(n1_A, n2_A, 'A')]
            start_coords = sites_list[site_A_id].coords

            for (dn1, dn2), bond_type, J_val in neighbour_rules:
                target_n1, target_n2 = n1_A + dn1, n2_A + dn2

                if periodic:
                    wrapped_key = (target_n1 % Lx, target_n2 % Ly, 'B')
                    if wrapped_key not in site_map:
                        continue
                    site_B_id = site_map[wrapped_key]
                    # Keep the unwrapped end point so the bond has its natural length.
                    end_coords = target_n1 * a1_uc + target_n2 * a2_uc + b_offset
                elif 0 <= target_n1 < Lx and 0 <= target_n2 < Ly:
                    site_B_id = site_map[(target_n1, target_n2, 'B')]
                    end_coords = sites_list[site_B_id].coords
                else:
                    continue

                bonds_list.append(Bond(site_A_id, site_B_id, bond_type, J_val, (start_coords, end_coords)))

    print(f"Generated {len(bonds_list)} bonds.")
    return sites_list, bonds_list

def plot_lattice(sites_list, bonds_list, title="Honeycomb Lattice"):
    """Draw lattice sites and bonds in a matplotlib figure and show it.

    Args:
        sites_list: Sequence of site objects exposing ``coords`` (indexable
            as ``coords[0]``, ``coords[1]``), ``sublattice`` (``'A'`` or
            ``'B'``) and ``id``.
        bonds_list: Sequence of bond objects exposing ``vector_coords`` (a
            ``(start, end)`` pair of coordinates) and ``type`` (``'x'``,
            ``'y'`` or ``'z'``).
        title: Title placed above the axes.

    Returns:
        None. When ``sites_list`` is empty a message is printed and no
        figure is created.
    """
    if not sites_list:
        print("No sites to plot.")
        return

    # The figure is sized from the module-level unit-cell counts Lx, Ly when
    # the calling script has defined them. If they are missing (function used
    # outside that script) fall back to a fixed size instead of raising
    # NameError.
    try:
        cells_x, cells_y = Lx, Ly
    except NameError:
        cells_x = cells_y = 0
    if cells_x > 0 and cells_y > 0:
        figsize = (cells_x * 1.5 + 2, cells_y * 1.5 + 2)
    else:
        figsize = (6, 6)
    fig, ax = plt.subplots(figsize=figsize)

    # One scatter call per sublattice, followed by the id labels of its sites.
    for tag, colour in (('A', 'blue'), ('B', 'red')):
        members = [s for s in sites_list if s.sublattice == tag]
        points = np.array([s.coords for s in members])
        if points.size > 0:
            ax.scatter(points[:, 0], points[:, 1], c=colour,
                       label=f'Sublattice {tag}', s=50, zorder=2)
            for s in members:
                ax.text(s.coords[0]+0.1, s.coords[1]+0.1, str(s.id), fontsize=8)

    bond_colors = {'x': 'magenta', 'y': 'green', 'z': 'orange'}
    line_segments = []
    colors = []
    drawn_bonds_for_legend = set()

    for bond in bonds_list:
        start_coords, end_coords = bond.vector_coords
        line_segments.append([start_coords, end_coords])
        colors.append(bond_colors[bond.type])

        # An empty plot per bond type only creates the legend entry; the
        # actual segments are drawn through the LineCollection below.
        if bond.type not in drawn_bonds_for_legend:
            ax.plot([], [], color=bond_colors[bond.type], label=f'{bond.type}-bond (J{bond.type})')
            drawn_bonds_for_legend.add(bond.type)

    if line_segments:
        lc = LineCollection(line_segments, colors=colors, linewidths=1.5, zorder=1)
        ax.add_collection(lc)

    ax.set_aspect('equal')
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend(loc='best')
    plt.grid(True, linestyle=':', alpha=0.5)
    plt.tight_layout()
    plt.show()

def print_spin_hamiltonian_symbolic(bonds_list):
    """Print a symbolic summary of the bond-directional spin Hamiltonian.

    One term ``-J * sigma_i^a * sigma_j^a`` is formatted per bond from the
    bond's ``J_coupling``, ``site1_id``, ``site2_id`` and ``type``
    attributes. At most the first five terms are printed; any remaining
    terms are summarised by a count. Nothing is returned.
    """
    print("\n--- Original Spin Hamiltonian (H_spin) ---")
    print("H_spin = sum over bonds <i,j>alpha of: -J_alpha * sigma_i^alpha * sigma_j^alpha")
    print("Where alpha is the bond type (x, y, or z) and direction.")

    if not bonds_list:
        print("No bonds to define Hamiltonian terms.")
        return

    preview_limit = 5
    rendered = [
        f"-{b.J_coupling} * sigma_{b.site1_id}^{b.type} * sigma_{b.site2_id}^{b.type}"
        for b in bonds_list
    ]

    print("Symbolic terms (first few if many):")
    for number, text in enumerate(rendered[:preview_limit], start=1):
        print(f"  Term {number}: {text}")

    hidden = len(rendered) - preview_limit
    if hidden > 0:
        print(f"  ... and {hidden} more terms.")
    print("-" * 40)

def apply_kitaev_transformation(bonds_list, num_sites):
    """Build symbolic Majorana-form terms, one per bond, and print a preview.

    Each bond (attributes ``site1_id``, ``site2_id``, ``type``,
    ``J_coupling``) becomes a dict describing the term
    ``(i * J * u_i_j^type) * (c_i * c_j)``. The first five term strings are
    printed, followed by a count of the rest.

    Args:
        bonds_list: Sequence of bond objects.
        num_sites: Accepted for interface compatibility; not used here.

    Returns:
        List of dicts with keys ``coeff_numeric``, ``u_operator``,
        ``c_operators``, ``site_indices``, ``bond_type``, ``J_coupling``
        and ``full_term_str``. Empty when ``bonds_list`` is empty.
    """
    if not bonds_list:
        print("No bonds to transform.")
        return []

    print("\nSymbolic Transformed Hamiltonian Terms (first few if many):")

    preview_limit = 5
    records = []
    for position, bond in enumerate(bonds_list, start=1):
        a, b = bond.site1_id, bond.site2_id
        kind = bond.type
        strength = bond.J_coupling

        gauge_label = f"u_{a}_{b}^{kind}"
        majorana_labels = (f"c_{a}", f"c_{b}")
        text = f"(i * {strength} * {gauge_label}) * ({majorana_labels[0]} * {majorana_labels[1]})"

        records.append({
            "coeff_numeric": 1j * strength,
            "u_operator": gauge_label,
            "c_operators": majorana_labels,
            "site_indices": (a, b),
            "bond_type": kind,
            "J_coupling": strength,
            "full_term_str": text,
        })
        if position <= preview_limit:
            print(f"  Term {position}: {text}")

    if len(bonds_list) > preview_limit:
        print(f"  ... and {len(bonds_list)-preview_limit} more terms.")
    print("-" * 40)
    return records

# Demonstration script: build a small honeycomb lattice, print the spin
# Hamiltonian and its symbolic Majorana form, then assemble and diagonalise
# the Majorana matrix for the sector in which every u_ij is set to +1.
if __name__ == "__main__":
    print("="*50)
    print("Elaborate Kitaev Transformation Demonstration")
    print("="*50)

    # Lattice size in unit cells, bond couplings, bond length and boundary
    # conditions. NOTE: these become module-level names; plot_lattice reads
    # Lx and Ly from module scope to size its figure.
    Lx, Ly = 2, 1
    Jx, Jy, Jz = 1.0, 1.0, 1.0
    bond_len = 1.0
    use_pbc = False

    sites, bonds = generate_honeycomb_lattice(Lx, Ly, Jx, Jy, Jz, bond_length=bond_len, periodic=use_pbc)
    num_total_sites = len(sites)

    if num_total_sites > 0 :
        plot_lattice(sites, bonds, title=f"Honeycomb Lattice ({Lx}x{Ly} cells, PBC: {use_pbc})")
    else:
        print("Lattice generation resulted in 0 sites. Check parameters (Lx, Ly).")

    print_spin_hamiltonian_symbolic(bonds)
    transformed_H = apply_kitaev_transformation(bonds, num_total_sites)


    # The remainder only runs when at least one bond term was produced.
    if transformed_H and num_total_sites > 0:
        print("\n--- Example: Building the Majorana Matrix M for a specific flux sector ---")
        print("Assuming all u_ij^alpha = +1.")
        
        # Real square matrix indexed directly by site id on both axes.
        M = np.zeros((num_total_sites, num_total_sites), dtype=float)
        
        # Fix every bond variable to +1, keyed by the symbolic u-operator label.
        u_values_fixed = {term['u_operator']: 1.0 for term in transformed_H} 
        print(f"Assuming all {len(u_values_fixed)} u_ij^alpha = +1.")

        for term_info in transformed_H:
            idx_i, idx_j = term_info['site_indices']
            J_val = term_info['J_coupling']
            u_op_key = term_info['u_operator']
            u_val = u_values_fixed.get(u_op_key, 1.0)
            
            coeff_K = J_val * u_val
            
            # Each bond adds +2*J*u at (i, j) and -2*J*u at (j, i), so M stays
            # antisymmetric by construction.
            M[idx_i, idx_j] += 2 * coeff_K
            M[idx_j, idx_i] -= 2 * coeff_K

        print(f"\nMajorana matrix M (shape {M.shape}):")
        if num_total_sites <= 6:
             # Small enough to print in full; this changes NumPy's global
             # print options for the rest of the session.
             np.set_printoptions(precision=2, suppress=True)
             print(M)
        else:
             print(f"(Matrix is {M.shape}, too large to print fully)")
             print("Slice of M (e.g., M[:4,:4]):")
             print(M[:min(4,num_total_sites), :min(4,num_total_sites)])
        
        # Sanity check on the construction above.
        is_antisymmetric = np.allclose(M, -M.T)
        print(f"Is M antisymmetric? {is_antisymmetric}")
        if not is_antisymmetric and num_total_sites > 0:
            print("M should be antisymmetric! Debug needed if False.")

        if num_total_sites > 0 and is_antisymmetric:
            if num_total_sites % 2 == 0 :
                # 1j*M is Hermitian when M is real antisymmetric, so the
                # Hermitian eigenvalue solver applies.
                eigvals_iM = np.linalg.eigvalsh(1j * M)
                # Keep only eigenvalues above a small positive threshold.
                positive_energies = np.sort(eigvals_iM[eigvals_iM > 1e-9])
                print(f"\nSingle Majorana fermion mode energies (positive values of eig(iM)):")
                if positive_energies.size > 0:
                    print(positive_energies)
                else:
                    print("No positive energies found (e.g. all zero, or N=0).")
                # Reported ground-state energy: minus half the sum of the
                # positive eigenvalues.
                gs_energy = -0.5 * np.sum(positive_energies)
                print(f"Ground state energy (for this flux sector): {gs_energy:.4f}")
            else:
                # Odd site count: only the raw spectrum of iM is printed.
                print("Number of sites is odd. Pfaffian/GS energy calculation typically for even N.")
                eigvals_iM = np.linalg.eigvalsh(1j * M)
                print(f"Eigenvalues of iM: {np.sort(eigvals_iM)}")

import numpy as np
from scipy.sparse import kron, identity, lil_matrix
from scipy.sparse.linalg import eigsh
import time

# Single-site operators as 2x2 sparse (LIL) matrices. With this basis
# ordering s_z is diag(+1/2, -1/2); s_plus has its only non-zero entry at
# (0, 1) and s_minus at (1, 0).
s_plus = lil_matrix(np.array([[0, 1], [0, 0]], dtype=float))
s_minus = lil_matrix(np.array([[0, 0], [1, 0]], dtype=float))
# s_x = (s_plus + s_minus) / 2 and s_y = -i (s_plus - s_minus) / 2 (complex).
s_x = 0.5 * (s_plus + s_minus)
s_y = -0.5j * (s_plus - s_minus) 
s_z = lil_matrix(np.array([[0.5, 0], [0, -0.5]], dtype=float))
# 2x2 identity for a single site.
ident = identity(2, dtype=float, format='lil')

# Run parameters read by the driver loop below.
L_total = 10      # total number of sites (block-size bookkeeping in the sweeps)
m_states = 20     # passed as m_keep: max states kept per block after truncation
num_sweeps = 5    # number of full (L->R then R->L) sweeps
J = 1.0           # coupling multiplying every bond term built below

def build_superblock_hamiltonian(H_L, Sz_L, Sp_L, Sm_L,
                                H_R, Sz_R, Sp_R, Sm_R):
    """Assemble the left-block / site / site / right-block Hamiltonian.

    The product space is ordered as ``left (x) site1 (x) site2 (x) right``.
    The result is the sum of both block Hamiltonians plus three bond terms
    (left-site1, site1-site2, site2-right), each of the form
    ``J * (Sz Sz + 0.5 * (S+ S- + S- S+))`` built from the module-level
    single-site operators and coupling ``J``.

    Args:
        H_L, Sz_L, Sp_L, Sm_L: Hamiltonian and Sz / S+ / S- operators of the
            left block (square sparse matrices of equal size).
        H_R, Sz_R, Sp_R, Sm_R: Same for the right block.

    Returns:
        The superblock Hamiltonian converted to CSR format.
    """
    one_L = identity(H_L.shape[0], format='lil')
    one_R = identity(H_R.shape[0], format='lil')
    one_s = ident

    def four_fold(a, b, c, d):
        # Kronecker product in the fixed order left, site1, site2, right.
        return kron(a, kron(b, kron(c, d)))

    # Block Hamiltonians acting on their own factor only.
    total = four_fold(H_L, one_s, one_s, one_R) + four_fold(one_L, one_s, one_s, H_R)

    # Bond between the left block and site 1.
    left_bond = J * four_fold(Sz_L, s_z, one_s, one_R)
    left_bond = left_bond + J * 0.5 * four_fold(Sp_L, s_minus, one_s, one_R)
    left_bond = left_bond + J * 0.5 * four_fold(Sm_L, s_plus, one_s, one_R)
    total = total + left_bond

    # Bond between the two free sites (4x4 operator embedded in the middle).
    pair_bond = J * (0.5 * (kron(s_plus, s_minus) + kron(s_minus, s_plus)) + kron(s_z, s_z))
    total = total + kron(one_L, kron(pair_bond, one_R))

    # Bond between site 2 and the right block.
    right_bond = J * four_fold(one_L, one_s, s_z, Sz_R)
    right_bond = right_bond + J * 0.5 * four_fold(one_L, one_s, s_plus, Sm_R)
    right_bond = right_bond + J * 0.5 * four_fold(one_L, one_s, s_minus, Sp_R)
    total = total + right_bond

    return total.tocsr()

def dmrg_step_left_to_right(H_L, Sz_L, Sp_L, Sm_L, H_R, Sz_R, Sp_R, Sm_R, m_keep):
    """Grow the first ("left") block by one site and truncate it.

    Steps: build the superblock Hamiltonian, find its lowest eigenpair,
    form the reduced density matrix of (left block + site 1), keep the
    ``m_keep`` eigenvectors with the largest eigenvalues, and project the
    enlarged block Hamiltonian and the new edge-site operators onto them.

    Args:
        H_L, Sz_L, Sp_L, Sm_L: Operators of the block to be enlarged.
        H_R, Sz_R, Sp_R, Sm_R: Operators of the opposite (environment) block.
        m_keep: Maximum number of density-matrix eigenvectors retained.

    Returns:
        Tuple ``(H_new, Sz_new, Sp_new, Sm_new, gs_energy, truncation_error)``
        where the operators are LIL matrices in the truncated basis,
        ``gs_energy`` is the lowest superblock eigenvalue and
        ``truncation_error`` is one minus the sum of the kept density-matrix
        eigenvalues. For an empty superblock the inputs are returned
        unchanged together with ``0.0`` and ``1.0``.
    """
    left_dim = H_L.shape[0]
    right_dim = H_R.shape[0]
    site_dim = 2

    superblock = build_superblock_hamiltonian(H_L, Sz_L, Sp_L, Sm_L, H_R, Sz_R, Sp_R, Sm_R)
    total_dim = superblock.shape[0]

    if total_dim <= 0:  # Should not happen with proper block dimensions
        return H_L, Sz_L, Sp_L, Sm_L, 0.0, 1.0

    def dense_spectrum():
        # Full dense diagonalisation; used for tiny problems and as fallback.
        return np.linalg.eigh(superblock.toarray())

    if total_dim <= 6:
        levels, states = dense_spectrum()
    else:
        try:
            levels, states = eigsh(superblock, k=1, which='SA', tol=1e-7, maxiter=total_dim*3)
        except Exception:  # Fallback if eigsh fails
            levels, states = dense_spectrum()
    gs_energy = levels[0]
    psi_gs = states[:, 0]

    # Rows index (left block, site 1); columns index (site 2, right block).
    psi_matrix = psi_gs.reshape((left_dim * site_dim, site_dim * right_dim), order='C')
    rho = np.dot(psi_matrix, psi_matrix.T.conjugate())

    weights, vectors = np.linalg.eigh(rho)
    descending = np.argsort(weights)[::-1]
    kept = descending[:min(m_keep, len(descending))]
    basis = vectors[:, kept]
    truncation_error = 1.0 - np.sum(weights[kept])

    # Enlarged (block + site 1) Hamiltonian before truncation. The explicit
    # zero 2x2 term is kept from the original construction (it adds nothing).
    one_L = identity(left_dim, format='lil')
    zero_site = lil_matrix(np.zeros((2,2)))
    enlarged_H = kron(H_L, ident) + kron(one_L, zero_site)
    block_site_bond = J * (kron(Sz_L, s_z) +
                           0.5 * kron(Sp_L, s_minus) +
                           0.5 * kron(Sm_L, s_plus))
    enlarged_H = enlarged_H + block_site_bond

    def project(operator):
        # Change of basis onto the kept density-matrix eigenvectors.
        return lil_matrix(basis.T.conjugate() @ operator.toarray() @ basis)

    H_L_new = project(enlarged_H)
    # The new edge operators act on the freshly added site only.
    Sz_L_new = project(kron(one_L, s_z))
    Sp_L_new = project(kron(one_L, s_plus))
    Sm_L_new = project(kron(one_L, s_minus))

    return H_L_new, Sz_L_new, Sp_L_new, Sm_L_new, gs_energy, truncation_error

# --- DMRG driver: repeated L->R and R->L passes over the chain ---
print(f"Starting DMRG for L={L_total} sites, m={m_states} states, {num_sweeps} sweeps.")

# All-zero 2x2 matrix used as the Hamiltonian of a single-site block.
H_site_0 = lil_matrix((2,2), dtype=float)

# Initialize all block operators with single-site versions (crude but makes it runnable)
# Index k of each list holds the operators of a block of k+1 sites: the loops
# below read index (size - 1) and store the enlarged block at index (size).
H_L_blocks = [H_site_0.copy() for _ in range(L_total)]
Sz_L_blocks = [s_z.copy() for _ in range(L_total)]
Sp_L_blocks = [s_plus.copy() for _ in range(L_total)]
Sm_L_blocks = [s_minus.copy() for _ in range(L_total)]

H_R_blocks = [H_site_0.copy() for _ in range(L_total)]
Sz_R_blocks = [s_z.copy() for _ in range(L_total)]
Sp_R_blocks = [s_plus.copy() for _ in range(L_total)]
Sm_R_blocks = [s_minus.copy() for _ in range(L_total)]

# Energy returned by the most recent step; overwritten on every step.
gs_energy_dmrg = 0.0

for sweep in range(num_sweeps):
    start_time = time.time()
    print(f"\n--- Sweep {sweep + 1}/{num_sweeps} ---")

    # L->R pass: the left block grows from 1 to L_total-3 sites while the
    # right block size is chosen so that left + 2 free sites + right = L_total.
    # NOTE(review): the right-block lists are only written in the R->L pass
    # below, so during the first L->R pass every right block used here is
    # still the single-site placeholder from the initialization above.
    for i in range(L_total - 3): 
        num_sites_L = i + 1
        num_sites_R = L_total - num_sites_L - 2
        
        H_L_curr  = H_L_blocks[num_sites_L - 1]
        Sz_L_curr = Sz_L_blocks[num_sites_L - 1]
        Sp_L_curr = Sp_L_blocks[num_sites_L - 1]
        Sm_L_curr = Sm_L_blocks[num_sites_L - 1]

        H_R_curr  = H_R_blocks[num_sites_R - 1]
        Sz_R_curr = Sz_R_blocks[num_sites_R - 1]
        Sp_R_curr = Sp_R_blocks[num_sites_R - 1]
        Sm_R_curr = Sm_R_blocks[num_sites_R - 1]
        
        H_L_new, Sz_L_new, Sp_L_new, Sm_L_new, gs_energy_dmrg, trunc_err = \
            dmrg_step_left_to_right(H_L_curr, Sz_L_curr, Sp_L_curr, Sm_L_curr,
                                   H_R_curr, Sz_R_curr, Sp_R_curr, Sm_R_curr, m_states)
        
        # Store the enlarged left block (num_sites_L + 1 sites).
        H_L_blocks[num_sites_L] = H_L_new
        Sz_L_blocks[num_sites_L] = Sz_L_new
        Sp_L_blocks[num_sites_L] = Sp_L_new
        Sm_L_blocks[num_sites_L] = Sm_L_new
        
        print(f"    L->R, L-size: {num_sites_L+1:2d}, R-size: {num_sites_R:2d}, E: {gs_energy_dmrg:+.6f}, TrErr: {trunc_err:.2e}")

    # R->L pass: same step function with the roles swapped, i.e. the right
    # block is passed as the block to enlarge and the left block as environment.
    for i in range(L_total - 3): 
        num_sites_R = i + 1
        num_sites_L = L_total - num_sites_R - 2

        H_R_curr  = H_R_blocks[num_sites_R - 1]
        Sz_R_curr = Sz_R_blocks[num_sites_R - 1]
        Sp_R_curr = Sp_R_blocks[num_sites_R - 1]
        Sm_R_curr = Sm_R_blocks[num_sites_R - 1]

        H_L_curr  = H_L_blocks[num_sites_L - 1]
        Sz_L_curr = Sz_L_blocks[num_sites_L - 1]
        Sp_L_curr = Sp_L_blocks[num_sites_L - 1]
        Sm_L_curr = Sm_L_blocks[num_sites_L - 1]
        
        H_R_new, Sz_R_new, Sp_R_new, Sm_R_new, gs_energy_dmrg, trunc_err = \
            dmrg_step_left_to_right(H_R_curr, Sz_R_curr, Sp_R_curr, Sm_R_curr,
                                   H_L_curr, Sz_L_curr, Sp_L_curr, Sm_L_curr,
                                   m_states)
        
        # Store the enlarged right block (num_sites_R + 1 sites).
        H_R_blocks[num_sites_R] = H_R_new
        Sz_R_blocks[num_sites_R] = Sz_R_new
        Sp_R_blocks[num_sites_R] = Sp_R_new
        Sm_R_blocks[num_sites_R] = Sm_R_new

        print(f"    R->L, L-size: {num_sites_L:2d}, R-size: {num_sites_R+1:2d}, E: {gs_energy_dmrg:+.6f}, TrErr: {trunc_err:.2e}")
    
    # The energy reported per sweep is the one from the last R->L step.
    end_time = time.time()
    print(f"Sweep {sweep+1} took {end_time - start_time:.2f} seconds. Final Energy: {gs_energy_dmrg:.8f}")

print(f"\nDMRG finished after {num_sweeps} sweeps.")
print(f"Final Ground State Energy: {gs_energy_dmrg:.8f}")
print(f"Ground State Energy per site: {gs_energy_dmrg/L_total:.8f}")