# In [None]:
using LinearAlgebra
using SparseArrays
using JLD2
using Printf

"""
    sakurai_sugiura_diagonalize(H, energy_min, energy_max;
                                N_quad=32, solver=:gmres, tol=1e-10,
                                max_subspace_dim=500, verbose=true,
                                rank_tol=1e-10)

Diagonalize a sparse Hamiltonian using the Sakurai-Sugiura method to extract
eigenvalues and eigenvectors within a specified energy window.

The Sakurai-Sugiura method uses contour integration in the complex plane to
project out eigenstates within a given energy range, avoiding computation of
unwanted states outside the window.

Algorithm outline:
1. Draw `L = min(max_subspace_dim, n)` random orthonormal source vectors `V`.
2. Approximate `S = (1/2πi) ∮ (zI - H)⁻¹ V dz` with the trapezoidal rule on an
   ellipse whose real-axis intercepts are `energy_min` and `energy_max`. Then `S`
   approximately spans the eigenvectors whose eigenvalues lie inside the window.
3. Orthonormalize `S` with a truncated SVD. Directions whose singular value is
   below `rank_tol` times the largest are dropped. Rayleigh-Ritz is then
   performed on that orthonormal basis. Projecting onto the raw `S` would be
   wrong because `S` is neither orthonormal nor full rank.
4. Keep the Ritz pairs whose values fall inside the window.

At most `L` eigenpairs can be resolved. If the window holds more eigenvalues
than source vectors, increase `max_subspace_dim`.

# Arguments
- `H`: Sparse Hamiltonian matrix (should be Hermitian)
- `energy_min`: Lower bound of energy window
- `energy_max`: Upper bound of energy window (must exceed `energy_min`)
- `N_quad`: Number of quadrature points on the contour (default: 32)
- `solver`: Linear solver to use (:gmres, :bicgstabl, :direct) (default: :gmres)
- `tol`: Convergence tolerance for linear solver (default: 1e-10)
- `max_subspace_dim`: Maximum dimension of subspace (default: 500)
- `verbose`: Print progress information (default: true)
- `rank_tol`: Relative singular-value cutoff that sets the numerical rank of the
  moment matrix (default: 1e-10)

# Returns
- `eigenvalues`: Eigenvalues within the specified window
- `eigenvectors`: Corresponding orthonormal eigenvectors (as columns)
- `info`: Dictionary of diagnostics:
  - `"contour_center"`, `"contour_radius_real"`, `"contour_radius_imag"`
  - `"n_quadrature_points"`
  - `"subspace_dimension"`: number of source vectors `L`
  - `"effective_rank"`: dimension of the Rayleigh-Ritz basis
  - `"n_unconverged_solves"`: iterative solves that reported non-convergence
  - `"n_states_found"`

# Throws
- `DimensionMismatch` if `H` is not square.
- `ArgumentError` if `energy_max <= energy_min` or `N_quad < 1`.
- `ErrorException` (`"Unknown solver: ..."`) for an unsupported `solver`.
"""
function sakurai_sugiura_diagonalize(H, energy_min, energy_max;
                                     N_quad=32, solver=:gmres, tol=1e-10,
                                     max_subspace_dim=500, verbose=true,
                                     rank_tol=1e-10)
    n = size(H, 1)
    size(H, 2) == n ||
        throw(DimensionMismatch("H must be square, got size $(size(H))"))
    energy_max > energy_min || throw(ArgumentError(
        "energy_max must be greater than energy_min, got [$energy_min, $energy_max]"))
    N_quad >= 1 || throw(ArgumentError("N_quad must be at least 1, got $N_quad"))
    # Fail fast, before any expensive work, with the same error as the solve loop.
    solver in (:direct, :gmres, :bicgstabl) || error("Unknown solver: $solver")

    if verbose
        println("="^70)
        println("SAKURAI-SUGIURA METHOD FOR EXACT DIAGONALIZATION")
        println("="^70)
        println("Matrix size: $n × $n")
        println("Number of nonzeros: $(nnz(H))")
        println("Sparsity: $(100 * (1 - nnz(H) / n^2))%")
        println("Energy window: [$energy_min, $energy_max]")
        println("Quadrature points: $N_quad")
        println()
    end

    # Ellipse in the complex plane crossing the real axis at energy_min / energy_max
    gamma_center = (energy_max + energy_min) / 2     # Center of ellipse
    gamma_radius_a = (energy_max - energy_min) / 2   # Semi-major axis (real)
    gamma_radius_b = gamma_radius_a * 0.5            # Semi-minor axis (imaginary)

    if verbose
        println("Contour parameters:")
        println("  Center: $gamma_center")
        println("  Semi-major axis (real): $gamma_radius_a")
        println("  Semi-minor axis (imag): $gamma_radius_b")
        println()
    end

    # Random orthonormal source vectors (thin Q factor, n × L)
    L = min(max_subspace_dim, n)
    V = Matrix(qr(randn(ComplexF64, n, L)).Q)

    if verbose
        println("Generating subspace with $L random vectors...")
        println()
        println("Computing moment matrix via contour integration...")
        println("Progress:")
    end

    S, n_unconverged = _ss_contour_moment(H, V, gamma_center, gamma_radius_a,
                                          gamma_radius_b, N_quad, solver, tol,
                                          verbose)

    if n_unconverged > 0
        @warn "$n_unconverged shifted linear solves did not converge; " *
              "eigenpairs may be inaccurate"
    end

    if verbose
        println()
        println("Contour integration complete!")
        println()
        println("Building projected Hamiltonian in subspace...")
    end

    # Rayleigh-Ritz on an orthonormal basis of range(S)
    Q = _ss_orthonormal_basis(S, rank_tol)
    effective_rank = size(Q, 2)
    if effective_rank == L && L < n
        @warn "Moment matrix has full column rank $L; the window may contain more " *
              "eigenvalues than source vectors. Consider increasing max_subspace_dim."
    end

    H_proj = Q' * (H * Q)
    # Make Hermitian (remove numerical errors)
    H_proj = (H_proj + H_proj') / 2

    if verbose
        println("Subspace dimension: $effective_rank")
        println()
        println("Diagonalizing projected Hamiltonian...")
    end

    t_start = time()
    if effective_rank == 0
        # Nothing (numerically) inside the contour: skip the empty eigensolve.
        vals_proj = real(eltype(H_proj))[]
        vecs_proj = zeros(eltype(H_proj), 0, 0)
    else
        F = eigen(Hermitian(H_proj))
        vals_proj, vecs_proj = F.values, F.vectors
    end
    # Timing is only reported when verbose (previously @time printed unconditionally).
    verbose && @printf("  %.6f seconds\n", time() - t_start)

    # Ritz values of leaked out-of-window directions are discarded here.
    mask = (real.(vals_proj) .>= energy_min - 1e-10) .&
           (real.(vals_proj) .<= energy_max + 1e-10)

    eigenvalues = vals_proj[mask]

    if verbose
        println("Found $(length(eigenvalues)) eigenvalues in window")
        println()
        println("Transforming eigenvectors to original space...")
    end

    # Q has orthonormal columns and the Ritz vectors are unitary, so the
    # back-transformed eigenvectors are already orthonormal.
    eigenvectors = Q * vecs_proj[:, mask]

    if verbose
        println("Verifying eigenvalue equation...")
        max_residual = 0.0
        for i = 1:min(10, length(eigenvalues))
            residual = norm(H * eigenvectors[:, i] -
                            eigenvalues[i] * eigenvectors[:, i])
            max_residual = max(max_residual, residual)
        end
        println("Maximum residual (first 10 states): $max_residual")
        println()
    end

    info = Dict(
        "contour_center" => gamma_center,
        "contour_radius_real" => gamma_radius_a,
        "contour_radius_imag" => gamma_radius_b,
        "n_quadrature_points" => N_quad,
        "subspace_dimension" => size(S, 2),
        "effective_rank" => effective_rank,
        "n_unconverged_solves" => n_unconverged,
        "n_states_found" => length(eigenvalues)
    )

    return eigenvalues, eigenvectors, info
end

"""
    _ss_contour_moment(H, V, center, a, b, N_quad, solver, tol, verbose)

Trapezoidal-rule approximation of `(1/2πi) ∮ (zI - H)⁻¹ V dz`.
The contour is the ellipse `z(θ) = center + a cos θ + i b sin θ`.
Returns `(S, n_unconverged)`.
"""
function _ss_contour_moment(H, V, center, a, b, N_quad, solver, tol, verbose)
    S = zeros(ComplexF64, size(V))
    n_unconverged = 0
    for k = 1:N_quad
        theta = 2π * (k - 1) / N_quad
        z = center + a * cos(theta) + 1im * b * sin(theta)
        # z'(θ) times the trapezoidal weight 2π / N_quad
        dz = (-a * sin(theta) + 1im * b * cos(theta)) * 2π / N_quad

        X, n_failed = _ss_shifted_solve(z * I - H, V, solver, tol)
        S .+= X .* dz
        n_unconverged += n_failed

        if verbose && (k % max(1, N_quad ÷ 10) == 0)
            @printf("  %3d/%3d quadrature points completed (%.1f%%)\n",
                    k, N_quad, 100*k/N_quad)
        end
    end
    S .*= (1 / (2π * 1im))
    return S, n_unconverged
end

"""
    _ss_shifted_solve(A, V, solver, tol)

Solve `A * X = V` with the selected solver.
Returns `(X, n_unconverged)`, where `n_unconverged` counts the iterative
per-column solves that returned a non-zero info flag.
"""
function _ss_shifted_solve(A, V, solver, tol)
    solver === :direct && return A \ V, 0
    n, L = size(V)
    X = zeros(ComplexF64, n, L)
    n_unconverged = 0
    for col = 1:L
        if solver === :gmres
            x, flag = gmres(A, V[:, col]; tol=tol, restart=min(50, n),
                            maxiter=100, verbose=0)
        elseif solver === :bicgstabl
            x, flag = bicgstabl(A, V[:, col], 2; tol=tol, max_mv_products=1000)
        else
            error("Unknown solver: $solver")
        end
        X[:, col] = x
        n_unconverged += (flag != 0)
    end
    return X, n_unconverged
end

"""
    _ss_orthonormal_basis(S, rank_tol)

Return an orthonormal basis of the numerical range of `S`.
The basis consists of the left singular vectors whose singular values exceed
`rank_tol` times the largest singular value.
"""
function _ss_orthonormal_basis(S, rank_tol)
    if size(S, 1) == 0 || size(S, 2) == 0
        return zeros(eltype(S), size(S, 1), 0)
    end
    F = svd(S)
    sigma = F.S   # sorted in decreasing order
    iszero(sigma[1]) && return F.U[:, 1:0]
    numerical_rank = count(>(rank_tol * sigma[1]), sigma)
    return F.U[:, 1:numerical_rank]
end


"""
    gmres(A, b; tol=1e-6, restart=20, maxiter=100, verbose=0)

Simple restarted GMRES implementation for solving A*x = b, starting from x = 0.

Each restart cycle builds an Arnoldi basis with modified Gram-Schmidt and solves
the small Hessenberg least-squares problem at the end of the cycle. The Krylov
dimension per cycle is capped at `length(b)`.

# Arguments
- `A`: Matrix or linear operator (must support `A * v`)
- `b`: Right-hand side vector
- `tol`: Absolute tolerance on the residual 2-norm `‖b - A*x‖`
- `restart`: Number of Arnoldi iterations per cycle before restart (≥ 1)
- `maxiter`: Maximum number of restart cycles
- `verbose`: Verbosity level. Values ≥ 1 print the residual norm at the start of
  each cycle; 0 is silent.

# Returns
- `x`: Solution vector. Its element type is promoted so that complex `A` with
  real `b` works.
- `info`: 0 if converged, 1 if not converged

# Throws
- `ArgumentError` if `restart < 1`.
"""
function gmres(A, b; tol=1e-6, restart=20, maxiter=100, verbose=0)
    restart >= 1 || throw(ArgumentError("restart must be at least 1, got $restart"))
    n = length(b)
    # Initial residual for x = 0. Computing it through A (rather than copying b)
    # makes its element type reflect A, e.g. complex A with real b.
    r = b - A * zeros(eltype(b), n)
    T = float(eltype(r))
    x = zeros(T, n)
    m = min(restart, n)   # a Krylov space cannot exceed dimension n

    for iter = 1:maxiter
        beta = norm(r)
        verbose >= 1 && println("gmres: cycle $iter, residual norm = $beta")
        if beta < tol
            return x, 0  # Converged
        end

        V = zeros(T, n, m + 1)
        Hm = zeros(T, m + 1, m)
        V[:, 1] .= r ./ beta

        for j = 1:m
            w = A * V[:, j]

            # Modified Gram-Schmidt orthogonalization
            for i = 1:j
                vi = view(V, :, i)
                Hm[i, j] = dot(vi, w)
                w .-= Hm[i, j] .* vi
            end

            # Keep the subdiagonal as a real number: `<` is undefined for Complex.
            h_next = norm(w)
            Hm[j + 1, j] = h_next

            if h_next < 1e-14
                # Happy breakdown: the Krylov space is invariant, so the
                # square j×j system gives the exact solution.
                e1 = zeros(T, j)
                e1[1] = beta
                y = Hm[1:j, 1:j] \ e1
                x .+= view(V, :, 1:j) * y
                return x, 0
            end

            V[:, j + 1] .= w ./ h_next
        end

        # Solve least squares problem min ‖beta*e1 - Hm*y‖
        e1 = zeros(T, m + 1)
        e1[1] = beta
        y = Hm \ e1

        x .+= view(V, :, 1:m) * y
        r = b - A * x
    end

    # The update from the final cycle has not been tested yet.
    norm(r) < tol && return x, 0
    return x, 1  # Not converged
end


"""
    bicgstabl(A, b, l; tol=1e-6, max_mv_products=1000)

BiCGStab(l) iterative solver for A*x = b (Sleijpen & Fokkema), starting from x = 0.

Each cycle performs `l` BiCG steps, using 2 matrix-vector products each. This
yields `r̂_j = A^j r̂_0` and `û_j = A^j û_0` for j = 0..l. It then takes a
minimal-residual (MR) step: `γ` minimizes `‖r̂_0 - Σ γ_j r̂_j‖`, and
`x += Σ γ_j r̂_{j-1}`. The shadow residual is the initial residual.

# Arguments
- `A`: Matrix or linear operator (must support `A * v`)
- `b`: Right-hand side vector
- `l`: BiCGStab(l) parameter (typically 2; must be ≥ 1)
- `tol`: Absolute tolerance on the residual 2-norm `‖b - A*x‖`
- `max_mv_products`: Maximum number of matrix-vector products. The budget is
  checked at the start of each cycle, so the last cycle may exceed it by up to
  `2l - 1`.

# Returns
- `x`: Solution vector. Its element type is promoted so that complex `A` with
  real `b` works.
- `info`: 0 if converged, 1 if not converged or a (near-)breakdown occurred

# Throws
- `ArgumentError` if `l < 1`.
"""
function bicgstabl(A, b, l; tol=1e-6, max_mv_products=1000)
    l >= 1 || throw(ArgumentError("BiCGStab(l) requires l ≥ 1, got l = $l"))
    n = length(b)
    # Residual for x = 0, computed through A so its element type reflects A.
    r0 = b - A * zeros(eltype(b), n)
    T = float(eltype(r0))
    x = zeros(T, n)

    # Column j+1 holds r̂_j / û_j for j = 0..l.
    rs = zeros(T, n, l + 1)
    us = zeros(T, n, l + 1)
    rs[:, 1] .= r0   # r̂_0 must start as the actual residual

    norm(view(rs, :, 1)) < tol && return x, 0

    r_shadow = rs[:, 1]   # copy of the initial residual
    rho0 = one(T)
    alpha = zero(T)
    omega = one(T)
    mv_count = 0

    while mv_count < max_mv_products
        rho0 = -omega * rho0

        # Bi-CG part
        for j = 1:l
            iszero(rho0) && return x, 1   # breakdown (rho or omega vanished)
            rho1 = dot(r_shadow, view(rs, :, j))
            beta = alpha * rho1 / rho0
            rho0 = rho1

            for i = 1:j
                @views us[:, i] .= rs[:, i] .- beta .* us[:, i]
            end

            us[:, j + 1] = A * us[:, j]

            sigma = dot(r_shadow, view(us, :, j + 1))
            iszero(sigma) && return x, 1  # breakdown
            alpha = rho0 / sigma

            for i = 1:j
                @views rs[:, i] .-= alpha .* us[:, i + 1]
            end

            rs[:, j + 1] = A * rs[:, j]

            @views x .+= alpha .* us[:, 1]
        end
        mv_count += 2l

        # Converged during the BiCG sweep: stop before the MR least-squares
        # problem becomes degenerate.
        norm(view(rs, :, 1)) < tol && return x, 0

        # MR part: γ = argmin ‖r̂_0 - [r̂_1 … r̂_l] γ‖ (pivoted QR is rank-revealing)
        gamma = qr(rs[:, 2:l+1], ColumnNorm()) \ rs[:, 1]
        all(isfinite, gamma) || return x, 1

        x .+= rs[:, 1:l] * gamma          # uses r̂_0 .. r̂_{l-1} before r̂_0 changes
        rs[:, 1] .-= rs[:, 2:l+1] * gamma
        us[:, 1] .-= us[:, 2:l+1] * gamma
        omega = gamma[l]

        if norm(view(rs, :, 1)) < tol
            return x, 0
        end
    end

    return x, 1
end

