In [None]:
"""
Module Name: Heisenberg.ipynb
            

Author: Kuan H Hsu
Date: 2024-12-01
Description:
    - Calculates effective Raman scattering operators on different lattice geometries for the Heisenberg model
"""

# !/usr/bin/env python

# coding:utf-8
from __future__ import print_function
import math
import numpy as np
#import scipy.linalg
import scipy.sparse
import scipy.sparse.linalg
import argparse
import time
# import matplotlib and scipy toolkits
import matplotlib as mpl
import matplotlib.pyplot as plt
import colorsys
import re

def parse_args():
    """Parse the command-line options of the exact-diagonalization run.

    Returns:
        argparse.Namespace with attributes Lx, Ly, Sz (int) and J1, J2 (float);
        defaults are Lx=4, Ly=4, J1=1.0, J2=0.0, Sz=0.
    """
    cli = argparse.ArgumentParser(
        description='Calculate the ground state of S=1/2 Heisenberg square for a given Sz')
    # (option name, converter, default, help text); the flag is "-<name>" and
    # both dest and metavar equal the name.
    option_table = (
        ('Lx', int, 4, 'set Lx (should be >=2)'),
        ('Ly', int, 4, 'set Ly (should be >=2)'),
        ('J1', float, 1.0, 'set J1'),
        ('J2', float, 0.0, 'set J2'),
        ('Sz', int, 0, 'set Sz'),
    )
    for name, converter, fallback, text in option_table:
        cli.add_argument('-' + name, metavar=name, dest=name, type=converter,
                         default=fallback, help=text)
    return cli.parse_args()

def snoob(x):
    """Return the next larger integer with the same number of set bits as x.

    This is the "snoob" bit trick from Hacker's Delight. Returns 0 when
    x <= 0.
    """
    if x <= 0:
        return 0
    lowest_bit = x & -x                 # isolate the lowest set bit
    carried = x + lowest_bit            # carry through the lowest run of ones
    # Ones that were cleared by the carry, moved back to the bottom.
    refill = ((x ^ carried) >> 2) // lowest_bit
    return carried | refill

def binomial(n,r):
    """Return the binomial coefficient C(n, r) = n! / ((n - r)! r!) as an int.

    Raises ValueError (from math.factorial) when n < 0, r < 0 or r > n.
    """
    numerator = math.factorial(n)
    denominator = math.factorial(n - r) * math.factorial(r)
    return numerator // denominator

def count_bit(n):
    """Return the number of set bits (population count) of n.

    n must be non-negative; a negative n never terminates.
    """
    total = 0
    while n:
        n &= n - 1      # clear the lowest set bit
        total += 1
    return total

def init_parameters(N,Sz):
    """Derive the basis bookkeeping constants for N spins in the sector Sz.

    Returns:
        (Nup, Nhilbert, ihfbit, irght, ilft, iup):
        Nup      = N//2 + Sz
        Nhilbert = C(N, Nup), the dimension of the sector
        ihfbit   = 2**(N//2), the weight of the first bit of the upper half
        irght    = mask selecting the lower N//2 bits
        ilft     = mask selecting the remaining upper bits
        iup      = smallest integer with N - Nup set bits (seed of the basis
                   enumeration). NOTE(review): this equals Nup set bits only
                   when N - Nup == Nup; C(N, N-Nup) == C(N, Nup), so the
                   dimension is unaffected.
    """
    Nup = N//2 + Sz
    Nhilbert = math.factorial(N) // (math.factorial(N - Nup) * math.factorial(Nup))
    half_bits = N//2
    ihfbit = 1 << half_bits
    irght = ihfbit - 1
    all_bits = (1 << N) - 1
    ilft = all_bits & ~irght
    iup = (1 << (N - Nup)) - 1
    return Nup, Nhilbert, ihfbit, irght, ilft, iup

def make_list(N,Nup,Nhilbert,ihfbit,irght,ilft,iup):
    """Enumerate the fixed-Sz basis and build the two-table index.

    Starting from ``iup`` and stepping with ``snoob`` (next integer with the
    same number of set bits), the ``Nhilbert`` basis configurations are stored
    in increasing order in ``list_1``. Each configuration ``ii`` is split into
    a lower half ``ia = ii & irght`` and an upper half
    ``ib = (ii & ilft) // ihfbit``. Because the configurations increase, equal
    upper halves form contiguous runs. ``list_jb[ib]`` stores the start of the
    run and ``list_ja[ia]`` the position inside it, so that
    ``list_ja[ia] + list_jb[ib]`` is the index of ``ii`` in ``list_1``
    (see ``get_ja_plus_jb``).

    Parameters:
        N, Nup: accepted for interface compatibility; not used here.
        Nhilbert, ihfbit, irght, ilft, iup: as returned by ``init_parameters``.

    Returns:
        (list_1, list_ja, list_jb) integer numpy arrays.
    """
    list_1 = np.zeros(Nhilbert,dtype=int)
    list_ja = np.zeros(ihfbit,dtype=int)
    # The upper half holds N - N//2 bits, one more than the lower half when N
    # is odd. Size the table from ``ilft`` so every ``ib`` value is a valid
    # index; for even N this equals ``ihfbit`` as before.
    list_jb = np.zeros(ilft//ihfbit + 1,dtype=int)
    ii = iup
    ja = 0
    jb = 0
    ib_old = (ii & ilft) // ihfbit
    for i in range(Nhilbert):
        ia = ii & irght
        ib = (ii & ilft) // ihfbit
        if ib != ib_old:
            # A new run of upper halves starts right after the previous one.
            jb += ja + 1
            ja = 0
        elif i > 0:
            ja += 1
        list_1[i] = ii
        list_ja[ia] = ja
        list_jb[ib] = jb
        ib_old = ib
        ii = snoob(ii)
    return list_1, list_ja, list_jb

def get_ja_plus_jb(ii,irght,ilft,ihfbit,list_ja,list_jb):
    """Return the basis index of configuration ``ii`` from the two lookup tables.

    The lower half (``ii & irght``) indexes ``list_ja`` and the upper half
    (``(ii & ilft) // ihfbit``) indexes ``list_jb``; the index is their sum.
    """
    lower_half = ii & irght
    upper_half = (ii & ilft) // ihfbit
    return list_ja[lower_half] + list_jb[upper_half]

def make_hamiltonian(Jxx,Jzz,list_isite1,list_isite2,N,Nint,Nhilbert,irght,ilft,ihfbit,list_1,list_ja,list_jb):
    """Build the XXZ Hamiltonian in the fixed-Sz basis as a CSR matrix.

    For every bond k between sites list_isite1[k] and list_isite2[k] the local term
        H_k = Jzz[k] sgmz.sgmz + Jxx[k] (sgmx.sgmx + sgmy.sgmy)
            = Jzz[k] sgmz.sgmz + 2*Jxx[k] (S+.S- + S-.S+)
    is added (Pauli normalisation: diagonal +-Jzz, exchange element 2*Jxx).

    The COO triplets are laid out in ``Nint + 1`` slabs of ``Nhilbert`` entries.
    Slab k < Nint holds the exchange element of bond k for every basis state
    (an explicit 0 in column 0 when the bond is parallel). The last slab holds
    the accumulated diagonal. Duplicate (row, col) pairs are summed by scipy.

    Parameters:
        Jxx, Jzz: per-bond couplings (length Nint).
        list_isite1, list_isite2: bond end sites (length Nint).
        N: unused (kept for interface compatibility).
        Nint: number of bonds.
        Nhilbert: dimension of the basis.
        irght, ilft, ihfbit, list_1, list_ja, list_jb: basis tables from
            ``init_parameters`` / ``make_list``.

    Returns:
        scipy.sparse.csr_matrix of shape (Nhilbert, Nhilbert).
    """
    configs = np.asarray(list_1)[:Nhilbert]
    ja_table = np.asarray(list_ja)
    jb_table = np.asarray(list_jb)
    n_triplets = (Nint+1)*Nhilbert
    listki = np.tile(np.arange(Nhilbert), Nint+1)    # row of every triplet
    loc = np.zeros(n_triplets,dtype=int)              # column of every triplet
    elemnt = np.zeros(n_triplets,dtype=float)         # value of every triplet
    diag_start = Nint*Nhilbert
    loc[diag_start:] = np.arange(Nhilbert)
    diag_vals = elemnt[diag_start:]                   # view on the diagonal slab
    for k in range(Nint): # loop for all interactions
        is12 = (1 << list_isite1[k]) + (1 << list_isite2[k])
        ibit = configs & is12
        # (00) or (11): parallel pair, only sgmz.sgmz contributes +Jzz.
        aligned = (ibit == 0) | (ibit == is12)
        diag_vals += np.where(aligned, Jzz[k], -Jzz[k])
        # (01) or (10): -Jzz on the diagonal and 2*Jxx to the exchanged state.
        flipped = np.flatnonzero(~aligned)
        iexchg = configs[flipped] ^ is12
        elemnt[k*Nhilbert + flipped] = 2.0*Jxx[k]
        loc[k*Nhilbert + flipped] = ja_table[iexchg & irght] + jb_table[(iexchg & ilft) // ihfbit]
    HamCSR = scipy.sparse.csr_matrix((elemnt,(listki,loc)),shape=(Nhilbert,Nhilbert))
    return HamCSR

def calc_zcorr(Nhilbert,Ncorr,list_corr_isite1,list_corr_isite2,psi,list_1,q=(),rspt=()):
    """Compute <S^z_i S^z_j> per bond and a phase-weighted S^z vector.

    For bond k (i = list_corr_isite1[k], j = list_corr_isite2[k]):
        szz[k] = 0.25 * sum_c (+1 if bits i, j of c agree else -1) * psi[c]**2
    and szz[k] = 0.25 when i == j. psi is treated as real (psi**2, no conjugate).

    The returned vector is
        vec0 = sum_k exp(i q . rspt[k]) * S^z_j |psi>,   S^z_j = +-1/2 by bit j,
    with the phase omitted when ``q`` is empty.

    Parameters:
        Nhilbert: basis dimension.
        Ncorr: number of bonds to evaluate.
        list_corr_isite1, list_corr_isite2: bond end sites.
        psi: state amplitudes in the ``list_1`` basis.
        list_1: basis configurations (bit patterns).
        q: momentum vector; empty (default) disables the phase.
        rspt: per-bond position vectors used with ``q``.

    Returns:
        (szz, vec0): float array of length Ncorr, complex array of length Nhilbert.
    """
    configs = np.asarray(list_1)[:Nhilbert]
    amp = np.asarray(psi)[:Nhilbert]
    weight = amp**2                    # psi[i]: real
    use_phase = len(q) != 0
    szz = np.zeros(Ncorr,dtype=float)
    vec0 = np.zeros(Nhilbert,dtype=complex)
    for k in range(Ncorr): # loop for all bonds for correlations
        isite1 = list_corr_isite1[k]
        isite2 = list_corr_isite2[k]
        is2 = 1 << isite2
        is12 = (1 << isite1) + is2
        ibit = configs & is12
        aligned = (ibit == 0) | (ibit == is12)   # (00) or (11): factor +1
        szz[k] = 0.25 * np.sum(np.where(aligned, weight, -weight))
        if (isite1 == isite2):
            szz[k] = 0.25
        # S^z of site 2: +1/2 if its bit is set, -1/2 otherwise.
        sz2 = np.where((configs & is2) == is2, 0.5, -0.5)
        term = sz2*amp
        if use_phase:
            # The phase depends only on the bond, so compute it once per bond.
            term = term*np.exp(1j*np.dot(q,rspt[k]))
        vec0 += term
    return szz, vec0

def calc_xcorr(Nhilbert,Ncorr,list_corr_isite1,list_corr_isite2,psi,irght,ilft,ihfbit,list_1,list_ja,list_jb):
    """Compute <S^x_i S^x_j> for each bond in a fixed-Sz state.

    In a fixed-Sz sector only the exchange part survives, so for bond k
        sxx[k] = 0.25 * sum over configurations c whose bits i, j differ of
                 psi[c] * psi[c with bits i, j swapped]
    with sxx[k] = 0.25 when i == j. psi is treated as real.

    Returns:
        float array of length Ncorr.
    """
    sxx = np.zeros(Ncorr,dtype=float)
    for k in range(Ncorr):
        site_a = list_corr_isite1[k]
        site_b = list_corr_isite2[k]
        bit_a = 1 << site_a
        bit_b = 1 << site_b
        pair = bit_a + bit_b
        total = 0.0
        for idx in range(Nhilbert):
            cfg = list_1[idx]
            masked = cfg & pair
            if masked != bit_a and masked != bit_b:
                continue            # parallel pair: no exchange matrix element
            partner = cfg ^ pair
            target = list_ja[partner & irght] + list_jb[(partner & ilft) // ihfbit]
            total += psi[idx]*psi[target]
        sxx[k] = 0.25 if site_a == site_b else 0.25 * total
    return sxx

def make_lattice(Lx,Ly,J1,J2,shape="",lattice="sq"):
    """Build the list of exchange bonds for a lattice / cluster.

    For every site of the cluster, a fixed set of displacement vectors is
    resolved with ``find_neighbor``. Each resulting bond (site1, site2) is
    recorded with an isotropic coupling (Jxx == Jzz):

    * "sq":  J1 on [1,0],[0,1];  J2 on [1,1],[1,-1]
    * "tr":  J1 on [1,0],[0,1],[1,1];  J2 on [2,1],[1,2],[1,-1]
    * "kgm": J1 on [1,0],[0,1],[1,1]
    * "hny": J1/2 on six displacements (each bond is found from both ends)

    Any other ``lattice`` value yields no bonds. Displacements for which
    ``find_neighbor`` finds no partner (-1) are skipped for every lattice,
    because a -1 site would later be used as a negative bit shift.

    Parameters:
        Lx, Ly: torus dimensions, used when ``shape`` is "".
        J1, J2: first- and second-neighbour couplings.
        shape: cluster label; its leading integer is the number of sites.
        lattice: "sq", "tr", "kgm" or "hny".

    Returns:
        (Jxx, Jzz, list_isite1, list_isite2, Nint)
    """
    # Per lattice: (displacements, coupling) groups, in the order bonds are
    # emitted for every site.
    bond_groups = {
        "sq": [([[1,0],[0,1]], J1), ([[1,1],[1,-1]], J2)],
        "tr": [([[1,0],[0,1],[1,1]], J1), ([[2,1],[1,2],[1,-1]], J2)],
        "kgm": [([[1,0],[0,1],[1,1]], J1)],
        # Divide by 2 to avoid double counting bonds.
        "hny": [([[1,2],[-2,-1],[1,-1],[-1,-2],[-1,1],[2,1]], J1/2)],
    }.get(lattice, [])
    if (shape == ""): cell_size = Lx * Ly
    else: cell_size = int(re.findall(r'\d+',shape)[0])
    Jxx = []
    Jzz = []
    list_isite1 = []
    list_isite2 = []
    for site1 in range(cell_size):
        for displacements, coupling in bond_groups:
            for NN in displacements:
                site2,_ = find_neighbor(NN,site1,shape,Lx,Ly,lattice)
                if (site2 == -1):
                    continue
                list_isite1.append(site1)
                list_isite2.append(site2)
                Jxx.append(coupling)
                Jzz.append(coupling)
    Nint = len(list_isite1)
    return Jxx, Jzz, list_isite1, list_isite2, Nint

def cluster_shape(shape,lattice="sq",real_space=False):
    """Return the site coordinates and periodic translation vectors of a named cluster.

    Coordinates are integer components along the two lattice basis vectors.
    The "sq" and "tr" lattices share one set of cluster tables; "kgm" and
    "hny" have their own.

    Parameters:
        shape (str): cluster label, e.g. "8A", "16B", "18T" (sq/tr),
            "12A" (kgm), "6" (hny).
        lattice (str): "sq", "tr", "kgm" or "hny".
        real_space (bool): if True and lattice != "sq", convert all vectors
            with (x, y) -> (x - y/2, y*sqrt(3)/2), i.e. to Cartesian coordinates
            of the basis a1 = (1, 0), a2 = (-1/2, sqrt(3)/2).

    Returns:
        (nbd_list, ucx, ucy): nbd_list is an (Nsite, 2) array of site
        coordinates; ucx, ucy are the superlattice translation vectors.

    Raises:
        ValueError: if the (shape, lattice) combination is not tabulated.
    """
    nbd_list = None
    ucx = None
    ucy = None
    if (lattice == "sq" or lattice == "tr"):
        if (shape == "8A"):
            nbd_list = np.asarray([[-1,1],[-1,2],[0,0],[0,1],
                                   [ 0,2],[ 0,3],[1,1],[1,2]])
            ucx = np.asarray([2,2])
            ucy = np.asarray([-2,2])
        elif (shape == "10B"):
            nbd_list = np.asarray([[-2,1],[-2,2],[-2,3],[-1,1],[-1,2],
                                   [-1,3],[ 0,0],[ 0,1],[ 0,2],[ 0,3]])
            ucx = np.asarray([1,3])
            ucy = np.asarray([-3,1])
        elif (shape == "12D"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],
                         [1,1],[2,1],[3,1],[4,1],
                         [1,2],[2,2],[3,2],[4,2]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([1,3])
        elif (shape == "12A"):
            nbd_list = np.asarray([[0,0],[1,1],[0,1],[1,2],
                         [-1,1],[0,2],[-2,1],[-1,2],
                         [-2,2],[-1,3],[-3,2],[-2,3]])
            ucx = np.asarray([2,2])
            ucy = np.asarray([-4,2])
        elif (shape == "12B"):
            nbd_list = np.asarray([[0,0],[0,1],[-1,2],[0,2],
                         [1,2],[-1,3],[0,3],[1,3],
                         [-1,4],[0,4],[1,4],[0,5]])
            ucx = np.asarray([2,3])
            ucy = np.asarray([-2,3])
        elif (shape == "12C"):
            nbd_list = np.asarray([[0,0],[1,1],[2,1],[1,2],
                                   [2,2],[3,2],[2,3],[3,3],
                                   [4,3],[3,4],[4,4],[5,5]])
            ucx = np.asarray([4,2])
            ucy = np.asarray([2,4])
        elif (shape == "12E"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],
                         [0,1],[1,1],[2,1],[3,1],
                         [0,2],[1,2],[2,2],[3,2]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([0,3])
        elif (shape == "16A"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],
                         [1,1],[2,1],[3,1],[4,1],
                         [1,2],[2,2],[3,2],[4,2],
                         [2,3],[3,3],[4,3],[5,3]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([2,4])
        elif (shape == "16B"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],
                         [0,1],[1,1],[2,1],[3,1],
                         [0,2],[1,2],[2,2],[3,2],
                         [0,3],[1,3],[2,3],[3,3]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([0,4])
        elif (shape == "16C"):
            nbd_list = np.asarray([[0,0],[1,1],[2,2],[0,1],
                         [1,2],[2,3],[0,2],[1,3],
                         [-1,2],[0,3],[1,4],[-1,3],
                         [0,4],[1,5],[-1,4],[0,5]])
            ucx = np.asarray([3,2])
            ucy = np.asarray([-2,4])
        elif (shape == "16D"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],
                         [1,1],[2,1],[3,1],[4,1],
                         [1,2],[2,2],[3,2],[4,2],
                         [1,3],[2,3],[3,3],[4,3]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([1,4])
        elif (shape == "18B"):
            nbd_list = np.asarray([[0,0],[1,1],[2,2],[0,1],[1,2],[2,3],
                                   [-1,1],[0,2],[1,3],[-1,2],[0,3],[1,4],
                                 [-2,2],[-1,3],[0,4],[-2,3],[-1,4],[0,5]])
            ucx = np.asarray([3,3])
            ucy = np.asarray([-3,3])
        elif (shape == "18A"):
            nbd_list = np.asarray([[0,0],[1,1],[2,2],[0,1],[1,2],[2,3],
                                   [0,2],[1,3],[2,4],[-1,2],[0,3],[1,4],
                                 [-1,3],[0,4],[1,5],[-1,4],[0,5],[1,6]])
            ucx = np.asarray([3,3])
            ucy = np.asarray([-2,4])
        elif (shape == "18C"):
            nbd_list = np.asarray([[0,0],[2,1],[1,1],[3,2],[0,1],[2,2],
                                   [1,2],[3,3],[0,2],[2,3],[1,3],[3,4],
                                 [0,3],[2,4],[1,4],[3,5],[0,4],[2,5]])
            ucx = np.asarray([4,2])
            ucy = np.asarray([-1,4])
        elif (shape == "18D"):
            nbd_list = np.asarray([[0,0],[0,1],[0,2],[0,3],[1,1],[1,2],
                                   [1,3],[1,4],[2,2],[2,3],[2,4],[3,2],
                                 [3,3],[3,4],[3,5],[4,3],[4,4],[4,5]])
            ucx = np.asarray([5,3])
            ucy = np.asarray([-1,3])
        elif (shape == "18T"):
            nbd_list = np.asarray([[ 0,0],[ 0,1],[ 0,2],[ 0,3],[ 0,4],[ 0,5],
                                   [-2,2],[-2,3],[-2,4],[-2,5],[-2,6],[-2,7],
                                   [-1,1],[-1,2],[-1,3],[-1,4],[-1,5],[-1,6]])
            ucx = np.asarray([1,5])
            ucy = np.asarray([-3,3])
        elif (shape == "20B"):
            nbd_list = np.asarray([[0, 0],[0, 1],[0, 2],[0, 3],
                                   [1, 0],[1, 1],[1, 2],[1, 3],
                                   [2, 0],[2, 1],[2, 2],[2, 3],
                                   [3, 0],[3, 1],[3, 2],[3, 3],
                                   [4, 0],[4, 1],[4, 2],[4, 3]])
            ucx = np.asarray([5,0])
            ucy = np.asarray([0,4])
    elif (lattice == "kgm"):
        if (shape == "6x"):
            # 1x2x3 Kagome
            nbd_list = np.asarray([[0,0],[0,1],[1,0],
                                   [2,0],[2,1],[3,0]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([0,2])
        elif (shape == "6y"):
            # 1x2x3 Kagome
            nbd_list = np.asarray([[0,0],[1,0],[0,1],
                                 [0,2],[1,2],[0,3]])
            ucx = np.asarray([2,0])
            ucy = np.asarray([0,4])
        elif (shape == "12A"):
            # 2x2x3 Kagome
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],[2,1],[4,1],
                         [1,2],[2,2],[3,2],[4,2],[2,3],[4,3]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([2,4])
        elif (shape == "12B"):
            # 2x2x3 Kagome
            nbd_list = np.asarray([[0,0],[0,1],[1,0],[2,0],[2,1],[3,0],
                         [0,2],[0,3],[1,2],[2,2],[2,3],[3,2]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([0,4])
        elif (shape == "18A"):
            nbd_list = np.asarray([[0,0],[0,1],[0,2],[0,3],[1,2],[1,4],
                                   [2,2],[2,3],[2,4],[2,5],[3,4],[3,6],
                                   [4,4],[4,5],[4,6],[4,7],[5,6],[-1,2]])
            ucx = np.asarray([6,6])
            ucy = np.asarray([-2,2])
        elif (shape == "18B"):
            nbd_list = np.asarray([[0,0],[0,1],[0,2],[0,3],[1,2],[1,4],
                                   [2,1],[2,2],[2,3],[2,4],[3,2],[3,4],
                                   [4,2],[4,3],[4,4],[4,5],[5,2],[5,4]])
            ucx = np.asarray([6,2])
            ucy = np.asarray([0,4])
        elif (shape == "18C"):
            nbd_list = np.asarray([[0,0],[1,0],[2,0],[3,0],[0,1],[2,1],
                                   [0,2],[1,2],[2,2],[3,2],[0,3],[2,3],
                                   [0,4],[1,4],[2,4],[3,4],[0,5],[2,5]])
            ucx = np.asarray([4,0])
            ucy = np.asarray([0,6])
    elif (lattice == "hny"):
        # This is a 3x time basis unit cell
        if (shape == "6"):
            nbd_list = np.asarray([[0,0],[1,2],[3,0],
                                   [4,2],[6,0],[7,2]])
            ucx = np.asarray([9,0])
            ucy = np.asarray([0,3])
        elif (shape == "8"):
            nbd_list = np.asarray([[0, 0],[0, 3],[1, 2],[1, 5],
                                   [3, 0],[3, 3],[4, 2],[4, 5]])
            ucx = np.asarray([6,0])
            ucy = np.asarray([0,6])
        elif (shape == "12A"):
            # 12 site Honeycomb
            nbd_list = np.asarray([[0, 0],[0, 3],[1, 2],[1, 5],
                                   [3, 3],[3, 6],[4, 2],[4, 5],
                                   [6, 3],[6, 6],[7, 5],[7, 8]])
            ucx = np.asarray([9,3])
            ucy = np.asarray([0,6])
        elif (shape == "12B"):
            # 12 site Honeycomb (3,0) (0,2)
            nbd_list = np.asarray([[0, 0],[0, 3],[1, 2],[1, 5],
                                   [3, 0],[3, 3],[4, 2],[4, 5],
                                   [6, 0],[6, 3],[7, 2],[7, 5]])
            ucx = np.asarray([9,0])
            ucy = np.asarray([0,6])
        elif (shape == "18A"):
            # 3x3x2 Honeycomb (3,0) (0,3) # Derived from 9A
            nbd_list = np.asarray([[0,0],[3,0],[6,0],[1,2],[4,2],[7,2],
                                   [0,3],[3,3],[6,3],[1,5],[4,5],[7,5],
                                   [0,6],[3,6],[6,6],[1,8],[4,8],[7,8]])
            ucx = np.asarray([9,0])
            ucy = np.asarray([0,9])
        elif (shape == "18B"):
            # 3x3x2 Honeycomb (3,0) (1,3) # Derived from 9B
            nbd_list = np.asarray([[ 0,  0],[ 1,  2],[ 3,  0],[ 3,  3],
                                   [ 3,  6],[ 4,  2],[ 4,  5],[ 4,  8],
                                   [ 6,  0],[ 6,  3],[ 6,  6],[ 7,  2],
                                   [ 7,  5],[ 7,  8],[ 9,  3],[ 9,  6],
                                   [10,  5],[10,  8]])
            ucx = np.asarray([9,0])
            ucy = np.asarray([3,9])
    if nbd_list is None:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError("unknown cluster shape %r for lattice %r" % (shape, lattice))
    if (real_space and lattice != "sq"):
        ucx = np.asarray(ucx,dtype=float)
        ucy = np.asarray(ucy,dtype=float)
        nbd_list = np.asarray(nbd_list,dtype=float)
        ucx[0] = ucx[0] - 0.5*ucx[1]
        ucx[1] *= (np.sqrt(3)/2)
        ucy[0] = ucy[0] - 0.5*ucy[1]
        ucy[1] = np.sqrt(3)/2*ucy[1]
        # Same per-row transform as above, applied to all sites at once.
        nbd_list[:, 0] = nbd_list[:, 0] - 0.5*nbd_list[:, 1]
        nbd_list[:, 1] = nbd_list[:, 1]*np.sqrt(3)/2
    return nbd_list, ucx, ucy

def find_neighbor(vec,site,shape="",Lx=4,Ly=4,lattice="sq"):
    """Return the site reached from ``site`` by the displacement ``vec``.

    With ``shape == ""`` the lattice is an Lx x Ly torus with site index
    ix + iy*Lx. The result is (index, [jx, jy]) after periodic wrapping.

    Otherwise the coordinates of cluster ``shape`` are taken from
    ``cluster_shape(shape, lattice)``. The first site whose coordinate differs
    from that of ``site`` by ``vec`` is returned, modulo one of the
    translations 0, +-ucx, +-ucy, +-(ucx+ucy), +-(ucx-ucy). The result is
    (index, coordinate), or (-1, [0, 0]) when there is no such site.
    """
    if shape != "":
        nbd_list, ucx, ucy = cluster_shape(shape,lattice)
        origin = nbd_list[site]
        # Superlattice images searched for each candidate, in fixed order.
        images = [[0,0],ucx,-ucx,ucy,-ucy,ucx+ucy,ucx-ucy,ucy-ucx,-ucx-ucy]
        for site2, pos in enumerate(nbd_list):
            if any((pos-origin+du == vec).all() for du in images):
                return site2, nbd_list[site2]
        return -1,[0,0]
    ix = site % Lx
    iy = int(site/Lx)
    jx = (ix + vec[0]) % Lx
    jy = (iy + vec[1]) % Ly
    return jx + jy*Lx, [jx,jy]

def _ss_bond_sum(center,nbh,prefac,psi,Nhilbert,Lx,Ly,irght,ilft,ihfbit,list_1,
                 list_ja,list_jb,shape,lattice):
    """Return 0.5 * sum_n prefac[n] * (S_center . S_{center+nbh[n]}) |psi>.

    Displacements whose partner lies outside the cluster (find_neighbor
    returns -1) are skipped. psi is treated as real.
    """
    out = np.zeros(Nhilbert)
    is1 = 1 << center
    for n,nvec in enumerate(nbh):
        # The partner is a neighbour of ``center``.
        j,_ = find_neighbor(nvec,center,shape,Lx,Ly,lattice)
        if (j == -1):
            continue
        is12 = is1 + (1 << j)
        weight = prefac[n]
        for i in range(Nhilbert):
            ii = list_1[i]
            ibit = ii & is12
            if (ibit==0 or ibit==is12): # parallel pair: S^z S^z = +1/4
                out[i] += 0.25*psi[i]*weight
            else:                        # antiparallel: -1/4 diagonal, 1/2 exchange
                out[i] -= 0.25*psi[i]*weight
                iexchg = ii ^ is12
                newcfg = get_ja_plus_jb(iexchg,irght,ilft,ihfbit,list_ja,list_jb)
                out[newcfg] += 0.5*psi[i]*weight
    out *= 0.5
    return out


def calc_SScorr(Nhilbert,Lx,Ly,Ncorr,list_corr_isite1,list_corr_isite2,psi,irght,ilft,ihfbit,list_1,
                list_ja,list_jb,nbh,prefac,shape="",lattice="sq"):
    """Correlate weighted bond operators centred on two sites.

    For bond k, O(s) = 0.5 * sum_n prefac[n] * S_s . S_{s+nbh[n]} is applied
    to psi for s = list_corr_isite1[k] (vecr) and s = list_corr_isite2[k]
    (vecl), and scol[k] = vecr . vecl.

    Bug fix: the neighbour search previously used the loop index ``k`` as the
    centre site instead of ``list_corr_isite1[k]`` / ``list_corr_isite2[k]``.

    Returns:
        (scol, vec0): float array of length Ncorr, and the average of the
        vecl vectors over the Ncorr bonds.
    """
    scol = np.zeros(Ncorr,dtype=float)
    vec0 = np.zeros(Nhilbert,dtype=float)
    for k in range(Ncorr): # loop for all bonds for correlations
        vecr = _ss_bond_sum(list_corr_isite1[k],nbh,prefac,psi,Nhilbert,Lx,Ly,irght,ilft,
                            ihfbit,list_1,list_ja,list_jb,shape,lattice)
        vecl = _ss_bond_sum(list_corr_isite2[k],nbh,prefac,psi,Nhilbert,Lx,Ly,irght,ilft,
                            ihfbit,list_1,list_ja,list_jb,shape,lattice)
        scol[k] = np.dot(vecr,vecl)
        vec0 += vecl
    return scol, vec0/Ncorr
            
def _chiral_rotation_sum(center,sign,psi,Nhilbert,Lx,Ly,irght,ilft,ihfbit,list_1,
                         list_ja,list_jb,shape,lattice):
    """Accumulate the four rotated three-spin terms anchored at ``center``.

    For rotation r the triangle is s1 = center + loc[r], j = s1 + nbh1[r],
    l = j + nbh2[r], weighted by sign[r]. A rotation whose triangle leaves the
    cluster is skipped. Each triangle contributes the cyclic terms (a, b, c)
    in {(s1,j,l), (j,l,s1), (l,s1,j)}: the (b, c) pair is exchanged with sign
    -1 when the bit of c was the set one, weighted by +1/2 or -1/2 according
    to bit a. Returns the raw (undivided) vector.
    """
    loc = np.asarray([[0,0],[1,1],[0,1],[1,0]])
    nbh1 = [[1,0],[-1,0],[0,-1],[0,1]]
    nbh2 = [[0,1],[0,-1],[1,0],[-1,0]]
    out = np.zeros(Nhilbert)
    for r in range(4): # Rotate 4 times
        # Resolve one site at a time so a missing site (-1) is never used as a
        # lookup index or as a negative shift count.
        s1,_ = find_neighbor(loc[r],center,shape,Lx,Ly,lattice)
        if (s1 == -1): continue
        j,_ = find_neighbor(nbh1[r],s1,shape,Lx,Ly,lattice)
        if (j == -1): continue
        l,_ = find_neighbor(nbh2[r],j,shape,Lx,Ly,lattice)
        if (l == -1): continue
        is1, is2, is3 = 1 << s1, 1 << j, 1 << l
        triples = ((is1,is2,is3),(is2,is3,is1),(is3,is1,is2))
        for i in range(Nhilbert):
            ii = list_1[i]
            for bit_a,bit_b,bit_c in triples:
                pair = bit_b + bit_c
                ibit = ii & pair
                if (ibit != bit_b and ibit != bit_c):
                    continue        # parallel pair: no exchange
                iexchg = ii ^ pair
                newcfg = get_ja_plus_jb(iexchg,irght,ilft,ihfbit,list_ja,list_jb)
                prefac = -1 if ibit == bit_c else 1
                prefac *= 0.5 if (ii & bit_a) else -0.5
                out[newcfg] += psi[i]*prefac*sign[r]
    return out


def calc_chiralcorr(Nhilbert,Lx,Ly,Ncorr,list_corr_isite1,list_corr_isite2,psi,irght,
                    ilft,ihfbit,list_1,list_ja,list_jb,shape="",lattice="sq",symm="A2g"):
    """Correlate rotated three-spin (chirality-type) operators on two sites.

    For bond k the operator is built around list_corr_isite1[k] (vecr) and
    around list_corr_isite2[k] (vecl), each divided by 2, and
    schi[k] = vecr . vecl. With symm == "A2g" all four rotations carry
    weight +1; otherwise the weights are [-1, -1, 1, 1]. Progress is printed
    for every k.

    Bug fixes: a missing triangle site (-1) previously caused ``1 << -1``
    (ValueError). An invalid triangle around the first site also skipped the
    second site's contribution for that rotation.

    Returns:
        (schi, vec0): float array of length Ncorr, and the average of the
        vecl vectors over the Ncorr bonds.
    """
    schi = np.zeros(Ncorr,dtype=float)
    vec0 = np.zeros(Nhilbert,dtype=float)
    sign = [1,1,1,1] if symm == "A2g" else [-1,-1,1,1]
    for k in range(Ncorr): # loop for all bonds for correlations, kjl
        isite1 = list_corr_isite1[k]
        isite2 = list_corr_isite2[k]
        _,loc1r = find_neighbor([0,0],isite1,shape,Lx,Ly,lattice)
        _,loc1l = find_neighbor([0,0],isite2,shape,Lx,Ly,lattice)
        vecr = _chiral_rotation_sum(isite1,sign,psi,Nhilbert,Lx,Ly,irght,ilft,ihfbit,
                                    list_1,list_ja,list_jb,shape,lattice) / 2
        vecl = _chiral_rotation_sum(isite2,sign,psi,Nhilbert,Lx,Ly,irght,ilft,ihfbit,
                                    list_1,list_ja,list_jb,shape,lattice) / 2
        schi[k] = np.dot(vecr,vecl)
        if (k == 0): print("vec*vec: ",np.dot(vecr,vecl))
        else:
            print("point: ",loc1r,loc1l)
            print("Cr, Cr/C0: ",schi[k],schi[k]/schi[0])
        vec0 += vecl
    return schi, vec0/Ncorr
    
def calc_a2gcorr(Nhilbert,Lx,Ly,Ncorr,list_corr_isite1,list_corr_isite2,psi,irght,
                 ilft,ihfbit,list_1,list_ja,list_jb,shape="",lattice="sq",inp=None):
    """Apply a weighted sum of three-spin terms to psi and return the vector.

    Each row of ``inp`` is [coeff, x1, y1, x2, y2, x3, y3]. The three offsets
    are expanded by ``generate_rotation`` (defined elsewhere) into a list of
    triangles. Each triangle is placed relative to every list_corr_isite2[k]
    and contributes coeff times the cyclic terms (a, b, c): the (b, c) pair is
    exchanged with sign -1 when the bit of c was the set one, weighted by +1/2
    or -1/2 according to bit a. Rows with coinciding offsets are skipped with
    a warning; triangles leaving the cluster are skipped. Diagnostic output is
    printed.

    Parameters:
        list_corr_isite1: unused (kept for interface compatibility).
        inp: parameter rows; None means the default [[1,0,0,1,0,0,1]].

    Returns:
        The accumulated vector divided by 2*Ncorr.
    """
    if inp is None:     # avoid a shared mutable default
        inp = [[1,0,0,1,0,0,1]]
    vec0 = np.zeros(Nhilbert,dtype=float)
    inp = np.asarray(inp)
    for param in inp:
        coeff = param[0]
        nbhd = param[1:].reshape(3,2)
        nbhd_list = generate_rotation(nbhd,lattice)
        print(coeff,nbhd,nbhd_list)
        if ((nbhd[0] == nbhd[1]).all() or (nbhd[0] == nbhd[2]).all() or (nbhd[1] == nbhd[2]).all()):
            print("WARNING, chiral operator invalid, skipping this parameter: ", param)
            continue
        for k in range(Ncorr): # loop for all bonds for correlations, kjl
            center = list_corr_isite2[k]
            for nbh in nbhd_list:
                s1,loc1 = find_neighbor(nbh[0],center,shape,Lx,Ly,lattice)
                j,loc2 = find_neighbor(nbh[1],center,shape,Lx,Ly,lattice)
                l,loc3 = find_neighbor(nbh[2],center,shape,Lx,Ly,lattice)
                if (s1 == -1 or j == -1 or l == -1): continue
                print(s1,j,l,loc1,loc2,loc3)
                is1, is2, is3 = 1 << s1, 1 << j, 1 << l
                # Cyclic (a, b, c): exchange the (b, c) pair, weight by bit a.
                triples = ((is1,is2,is3),(is2,is3,is1),(is3,is1,is2))
                for i in range(Nhilbert):
                    ii = list_1[i]
                    for bit_a,bit_b,bit_c in triples:
                        pair = bit_b + bit_c
                        ibit = ii & pair
                        if (ibit != bit_b and ibit != bit_c):
                            continue    # parallel pair: no exchange
                        iexchg = ii ^ pair
                        newcfg = get_ja_plus_jb(iexchg,irght,ilft,ihfbit,list_ja,list_jb)
                        prefac = -1 if ibit == bit_c else 1
                        prefac *= 0.5 if (ii & bit_a) else -0.5
                        vec0[newcfg] += psi[i]*prefac*coeff
    return vec0/2/Ncorr


def ContFracExpan(H_0,groundH,E_0,Ham,Ediv=0.1,maxE=15,eps=0.1,niter_CFE=150,
                  breakdown_tol=1e-10,verbose=True):
    """
    Compute a broadened spectrum with the Lanczos continued-fraction method.

    Starting from |H_0>, a Lanczos tridiagonalisation of ``Ham`` gives coefficients a_n, b_n,
    and the spectrum is

        I(w) = -1/pi * Im[ <H_0|H_0> G(z) - gscs / (z - E_0) ],   z = w + E_0 + i*eps,
        G(z) = 1 / (z - a_0 - b_1^2 / (z - a_1 - b_2^2 / (z - a_2 - ...))),

    where gscs = sum_g |<g|H_0>|^2 over the supplied ground-state vectors, i.e. the
    elastic pole at w = 0 is removed.

    Parameters:
        H_0 (ndarray): starting vector. It is NOT modified (a normalised copy is used).
        groundH (sequence of ndarray): ground-state vectors whose overlap with H_0 is
            subtracted as the elastic pole; may be empty.
        E_0 (float): ground-state energy; the output energy axis is measured from it.
        Ham: Hamiltonian (sparse matrix or ndarray, anything with ``.dot``); assumed
            Hermitian.
        Ediv (float): spacing of the output energy grid.
        maxE (float): largest energy of the output grid.
        eps (float): Lorentzian broadening (imaginary part added to the energy).
        niter_CFE (int): maximum number of Lanczos iterations (>= 1).
        breakdown_tol (float): stop the recursion once b_n drops below this value (the
            Krylov space is exhausted); continuing would divide by ~0.
        verbose (bool): print the diagnostic output the routine has always printed.

    Returns:
        (specX, specY): ndarrays of length int(maxE/Ediv)+1 with the energy grid and the
        spectral intensity.

    Raises:
        ValueError: if niter_CFE < 1.
    """
    if niter_CFE < 1:
        raise ValueError("niter_CFE must be >= 1")
    nedos = int(maxE/Ediv)+1
    specX = np.arange(nedos)*Ediv
    # Complex energies at which the resolvent is evaluated (absolute energy scale).
    z = specX + E_0 + 1j*eps

    phi = np.asarray(H_0)
    # Weight of the starting vector on the (possibly degenerate) ground states: this is the
    # elastic pole at w = 0, subtracted from the spectrum at the end.
    gscs = 0
    for gs in groundH:
        gscs += abs(np.vdot(gs, phi))**2
    factor = np.vdot(phi, phi).real
    if verbose: print("factor, gscs",factor,gscs)
    if factor == 0:
        # Nothing to propagate: avoid 0/0 in the normalisation and return a flat spectrum.
        return specX, np.zeros(nedos)
    # Normalise a copy so the caller's vector is left untouched.
    phi = phi / np.sqrt(factor)

    alpha = np.zeros(niter_CFE)
    betha = np.zeros(niter_CFE)
    phil = np.zeros_like(phi)   # previous Lanczos vector |f_{n-1}>
    nsteps = niter_CFE          # number of valid (alpha, betha) levels
    for ii in range(niter_CFE):
        phip = Ham.dot(phi)                                   # H|f_n>
        alpha[ii] = np.vdot(phi, phip).real                   # a_n = <f_n|H|f_n>
        phipp = phip - betha[ii]*phil - alpha[ii]*phi         # H|f_n> - b_n|f_{n-1}> - a_n|f_n>
        if verbose: print("iter,a,b", ii, alpha[ii],betha[ii])
        if ii == niter_CFE - 1:
            break
        bnext = np.linalg.norm(phipp)
        if bnext < breakdown_tol:
            # Invariant subspace reached: the continued fraction terminates exactly here.
            nsteps = ii + 1
            break
        betha[ii+1] = bnext
        phil = phi
        phi = phipp / bnext

    # Evaluate the continued fraction from the deepest level upwards, for all z at once.
    G = z - alpha[nsteps-1]
    for n in range(nsteps-2, -1, -1):
        G = z - alpha[n] - betha[n+1]**2 / G
    specY = -1/np.pi * np.imag(factor/G - gscs/(z-E_0))
    if verbose: print(factor/G[0], gscs/(z[0]-E_0))
    return specX, specY

def generate_rotation(vec,lattice="sq"):
    """
    Build the rotation orbit of each input vector.

    For lattice == "sq" each vector is multiplied (as a row vector) by the matrix
    [[0,-1],[1,0]] repeatedly, giving a 4-element orbit that starts with the vector itself.
    For every other lattice the vector is looked up in four hard-coded 6-element tables of
    integer vectors. The matching table is cyclically shifted so that it starts at the input
    vector; the all-zero table is used as-is. A vector found in no table prints a warning
    and is left out of the result.

    Parameters:
        vec: iterable of 2-component vectors (lists or ndarrays).
        lattice (str): lattice label; only "sq" is treated differently.

    Returns:
        ndarray of shape (n_rot, n_found, 2): entry [r, k] is the r-th rotated copy of the
        k-th vector that could be rotated (n_rot = 4 for "sq", 6 otherwise).
    """
    quarter_turn = [[0,-1],[1,0]]
    six_fold_tables = (
        [[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]],
        [[1,0],[1,1],[0,1],[-1,0],[-1,-1],[0,-1]],
        [[2,1],[1,2],[-1,1],[-2,-1],[-1,-2],[1,-1]],
        [[2,0],[2,2],[0,2],[-2,0],[-2,-2],[0,-2]],
    )
    orbits = []
    for start in vec:
        if lattice == "sq":
            orbit = [start]
            for _ in range(3):
                orbit.append(np.matmul(orbit[-1], quarter_turn))
            orbits.append(orbit)
            continue
        key = start.tolist() if isinstance(start, np.ndarray) else start
        table = next((t for t in six_fold_tables if key in t), None)
        if table is None:
            print("WARNING!!! Cannot rotate this vector")
        elif table is six_fold_tables[0]:
            orbits.append(table)
        else:
            # Shift the table so the orbit begins at the requested vector.
            orbits.append(np.roll(table, -table.index(key), axis=0))
    return np.asarray(orbits).transpose((1,0,2))

def pbin(bit,L=18):
    """Return the binary digits of ``bit`` (the '0b' prefix stripped), left-padded with '0' to width L."""
    digits = bin(bit)[2:]
    return digits.rjust(L, "0")

def get_recip(shape,lattice):
    """
    Enumerate the momentum points allowed by a finite cluster.

    The supercell vectors (ax, ay) come from cluster_shape. Reciprocal vectors are built
    so that bx.ax = by.ay = 2*pi and bx.ay = by.ax = 0. Every combination x*bx + y*by with
    integer x, y in [-5, 15) whose components both lie in [0, 2*pi - 0.001) is kept. It is
    converted to integer units of 2*pi/kunit.

    Parameters:
        shape (str): cluster label passed to cluster_shape (e.g. "12C", "18A").
        lattice (str): lattice label passed to cluster_shape (e.g. "sq", "tr").

    Returns:
        kpt_arr (ndarray): integer k-points, stably sorted by their second component.
        kunit (int): length of the first value returned by cluster_shape (the site list
            in main); k-points are expressed in units of 2*pi/kunit.
        high_symm (list[int]): hard-coded indices into kpt_arr for known triangular
            clusters (presumably a high-symmetry path); empty for anything else.
    """
    # Hand-picked k-point indices for the triangular clusters that have been set up.
    tri_high_symm = {
        "12C": [0,2,4,1],
        "16A": [0,4,8,6,2],
        "16B": [0,1,2,5],
        "18A": [0,15,13,10,14],
        "18B": [0,15,12,9],
        "18C": [0,15,6,9],
        "18D": [0,6,9],
    }
    nbd_list, ax, ay = cluster_shape(shape,lattice)
    rot90 = np.asarray([[0,-1],[1,0]],dtype=float)
    # A 90-degree rotation of one supercell vector is orthogonal to it; normalising against
    # the other supercell vector gives b_i . a_j = 2*pi*delta_ij.
    perp_ay = np.matmul(rot90,ay)
    perp_ax = np.matmul(rot90,ax)
    bx = perp_ay/np.dot(ax,perp_ay)*2*np.pi
    by = perp_ax/np.dot(ay,perp_ax)*2*np.pi

    kunit = len(nbd_list)
    # Small margin below 2*pi, presumably so points equivalent to k = 0 are not duplicated
    # through round-off.
    upper = 2*np.pi-0.001
    kpt_arr = []
    for x in range(-5,15):
        for y in range(-5,15):
            kx = x*bx[0]+y*by[0]
            ky = x*bx[1]+y*by[1]
            if 0 <= kx < upper and 0 <= ky < upper:
                kpt_arr.append([round(kx/(2*np.pi)*kunit),round(ky/(2*np.pi)*kunit)])

    kpt_arr = np.asarray(sorted(kpt_arr,key=lambda k: k[1]))
    high_symm = list(tri_high_symm.get(shape, [])) if lattice == "tr" else []
    return kpt_arr, kunit, high_symm

def make_SSlist(Lx,Ly,shape,lattice,centered=True):
    """
    Build the (site1, site2) index pairs used for spin-spin correlations.

    N is Lx*Ly when shape is empty, otherwise the first integer found in ``shape``.

    centered=False: every ordered pair (i, j), i and j in range(N), with i the slow index.
    centered=True: for every Nsite-th site index i (Nsite = 3 for "kgm", 2 for "sq", else 1)
        find_neighbor([0,0], i, ...) gives site1 and a vector; then, for every j in range(N),
        find_neighbor(that vector, j, ...) gives site2.

    Returns:
        (ndarray, ndarray): the site1 list and the site2 list, of equal length.
    """
    N = Lx*Ly if shape == "" else int(re.findall(r'\d+',shape)[0])
    if not centered:
        first = [a for a in range(N) for _ in range(N)]
        second = [b for _ in range(N) for b in range(N)]
        return np.asarray(first), np.asarray(second)

    stride = 3 if lattice == "kgm" else (2 if lattice == "sq" else 1)
    first, second = [], []
    for i in range(0, N, stride):
        site1, origin = find_neighbor([0,0],i,shape,Lx,Ly,lattice)
        for j in range(N):
            site2, _ = find_neighbor(origin,j,shape,Lx,Ly,lattice)
            first.append(site1)
            second.append(site2)
    return np.asarray(first), np.asarray(second)

def main(Lx,Ly,Sz,J1,J2,shape="",lattice="sq",symm_in=["A1g"],save_spec=False,maxE=16,Ediv=0.01):
    """
    Diagonalise the spin-1/2 Heisenberg model on a finite cluster and compute Raman spectra.

    Workflow:
        1. Build the fixed-Sz basis and the sparse Hamiltonian (divided by 4 after
           make_hamiltonian returns it).
        2. Find the lowest eigenpairs with eigsh. Collect every state within 1e-6 of the
           ground-state energy.
        3. For each ground state, print spin-spin correlations. Report as "SPIN GAP" the
           first eigenstate with non-negligible overlap with the vector returned by
           calc_zcorr.
        4. For each requested Raman channel, build the operator acting on each ground
           state and evaluate its spectrum with ContFracExpan. Sum over degenerate ground
           states, then save the result under ./HSB-final/ or plot it.

    Parameters:
        Lx, Ly (int): cluster dimensions; N = Lx*Ly is used only when shape == "".
        Sz (int): total-Sz sector (passed to init_parameters).
        J1, J2 (float): exchange couplings.
        shape (str): cluster label; its first integer is taken as the number of sites N.
        lattice (str): "sq", "tr", "kgm" or "hny".
        symm_in (str or list of str): Raman channels among "A1g", "B1g", "B2g", "Eg1",
            "Eg2", "xx", "yy", "A2g". A channel with no vertex defined for ``lattice`` is
            skipped with a warning.
        save_spec (bool): write spectra to text files instead of plotting them.
        maxE, Ediv (float): energy range and spacing of the spectra.

    Raises:
        ValueError: if cluster_shape returns no site at [0,0].
    """
    import os

    if (shape == ""): N = Lx*Ly
    else: N = int(re.findall(r'\d+',shape)[0])
    Nup, Nhilbert, ihfbit, irght, ilft, iup = init_parameters(N,Sz)
    binirght = np.binary_repr(irght,width=N)
    binilft = np.binary_repr(ilft,width=N)
    biniup = np.binary_repr(iup,width=N)
    print("Lx=",Lx)
    print("Ly=",Ly)
    print("J1=",J1)
    print("J2=",J2)
    print("N=",N)
    print("Sz=",Sz)
    print("Nup=",Nup)
    print("Nhilbert=",Nhilbert)
    print("ihfbit=",ihfbit)
    print("irght,binirght=",irght,binirght)
    print("ilft,binilft=",ilft,binilft)
    print("iup,biniup=",iup,biniup)

    # Basis states of the Sz sector and their lookup tables.
    start = time.time()
    list_1, list_ja, list_jb = make_list(N,Nup,Nhilbert,ihfbit,irght,ilft,iup)
    end = time.time()
    print (end - start)

    Jxx, Jzz, list_isite1, list_isite2, Nint = make_lattice(Lx,Ly,J1,J2,shape,lattice)
    print ("Jxx",Jxx)
    print ("Jzz",Jzz)
    print ("list_isite1",list_isite1)
    print ("list_isite2",list_isite2)
    print("Nint=",Nint)

    start = time.time()
    HamCSR = make_hamiltonian(Jxx,Jzz,list_isite1,list_isite2,N,Nint,Nhilbert,irght,ilft,ihfbit,list_1,list_ja,list_jb)
    end = time.time()
    HamCSR /= 4
    print (end - start)

    # Lowest eigenpairs. eigsh requires k < matrix dimension, so cap the request for
    # small Hilbert spaces.
    start = time.time()
    neig = min(10, Nhilbert-1)
    ene,vec = scipy.sparse.linalg.eigsh(HamCSR,which='SA',k=neig)
    end = time.time()
    print (end - start)
    print ("# GS energy:",ene[0])
    print ("# energy:",ene[:neig])
    print ("# energy shifted:",ene[:neig]-ene[0])
    print ("# energy per site:",ene[:neig]/N)

    #### Dealing with degenerate ground state
    psi = vec[:,0] # choose the ground state
    psi_all = [psi]
    ndegen = 1
    for n in range(1,neig):
        if (abs(ene[n]-ene[0]) < 1e-6):
            ndegen += 1
            psi_all = np.vstack((vec[:,n],psi_all))
    print("Number of degeneracy: ",ndegen)

    Ncorr = N # number of total correlations
    # Index of the cluster site located at [0,0]; correlations are paired with this site.
    cluster_sites,_,_ = cluster_shape(shape,lattice)
    i_origin = None
    for ii,site1 in enumerate(cluster_sites):
        if (np.asarray(site1) == [0,0]).all():
            i_origin = ii
            break
    if i_origin is None:
        raise ValueError("cluster_shape(%r, %r) has no site at [0,0]" % (shape, lattice))
    list_corr_isite1 = [i_origin for k in range(Ncorr)] # site 1
    list_corr_isite2 = [k for k in range(Ncorr)] # site 2
    print ("corr1",list_corr_isite1)
    print ("corr2",list_corr_isite2)

    # All ordered site pairs (i, j), with i the slow index (the non-centered pair list).
    centered = False
    srlist1 = [a for a in range(N) for _ in range(N)]
    srlist2 = [b for _ in range(N) for b in range(N)]
    SrNcorr = len(srlist1)
    print ("Sr corr1",srlist1)
    print ("Sr corr2",srlist2)
    print("MAXE, EDIV: ",maxE,Ediv)

    ### Spin correlations and spin excitation gap ####
    for psi in psi_all:
        szz, sz_vec = calc_zcorr(Nhilbert,SrNcorr,srlist1,srlist2,psi,list_1)
        sxx = calc_xcorr(Nhilbert,SrNcorr,srlist1,srlist2,psi,irght,ilft,ihfbit,list_1,list_ja,list_jb)
        ss = szz+sxx+sxx
        NDIM = int(ss.shape[0]/N)
        ss = np.sum(np.reshape(ss,(NDIM,N)),axis=0)/NDIM
        stot2 = N*np.sum(ss)
        np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
        print ("#### Degenerate Ground State #####")
        print ("# szz:",repr(np.reshape(szz,(NDIM,N))*3))
        print ("# sxx:",repr(np.reshape(sxx,(NDIM,N))))
        if (not centered): print("!!WARNING!! Lattice not centered, Not printing spin-spin correlation")
        else: print ("# ss:",repr(ss))
        print ("# stot(stot+1):",stot2)
        print ("# Cross Section, aeig, ediff, cs (Sz)")
        for cs in range(neig):
            eigvec = np.asarray(vec[:,cs],dtype=complex)
            # Compare the magnitude of the overlap, not the signed/complex value itself.
            overlap = np.linalg.norm(np.dot(sz_vec,eigvec))
            if overlap > 1e-5:
                print("SPIN GAP: ",ene[cs]-ene[0])
                print("Neig, cs: ",cs,overlap)
                break

    if isinstance(symm_in, str): symm_in = [symm_in]
    for symm in symm_in:
        if (symm in ["","A1g","B1g","B2g","Eg1","Eg2","xx","yy"]):
            print("SYMMETRY: ", symm)
            # Reset per channel so an undefined lattice/symmetry combination cannot silently
            # reuse the vertex of the previous channel.
            nbh = None
            prefac = None
            if (lattice == "sq"):
                if (symm == "A1g"):
                    # Square A1g
                    nbh = [[1,0],[-1,0],[0,1],[0,-1],
                           [2,0],[-2,0],[0,2],[0,-2]]
                    prefac = [1*J1,1*J1,1*J1,1*J1,
                              2*J2,2*J2,2*J2,2*J2]
                elif (symm == "B1g" or symm == "Eg1"):
                    # Square B1g
                    nbh = [[1,0],[-1,0],[0,1],[0,-1],
                           [2,0],[-2,0],[0,2],[0,-2]]
                    prefac = [1*J1,1*J1,-1*J1,-1*J1,
                              2*J2,2*J2,-2*J2,-2*J2]
                elif (symm == "B2g" or symm == "Eg2"):
                    # Square B2g
                    nbh = [[1,1],[-1,-1],[-1,1],[1,-1]]
                    prefac = [2*J1,2*J1,-2*J1,-2*J1]
            elif (lattice == "tr"):
                if (symm == "A1g"):
                    # Triangle A1g
                    nbh = [[1,0],[0,1],[1,1],[-1,0],[0,-1],[-1,-1]]
                    prefac = [1*J1,1*J1,1*J1,1*J1,1*J1,1*J1]
                elif (symm == "Eg1"):
                    # Triangle Eg1
                    nbh = [[1,0],[0,1],[1,1],[2,1],[1,-1],[1,2]]
                    prefac = [1*J1,-0.5*J1,-0.5*J1,1.5*J2,1.5*J2,-3*J2]
                elif (symm == "Eg2"):
                    # Triangle Eg2
                    nbh = [[0,1],[1,1],[1,-1],[2,1]]
                    prefac = [-1*J1,1*J1,-3*J2,3*J2]
                elif (symm == "xx"):
                    # Triangle x in x out
                    nbh = [[1,0],[0,1],[1,1]]
                    prefac = [2*J1,0.5*J1,0.5*J1]
                elif (symm == "yy"):
                    # Triangle y in y out
                    nbh = [[0,1],[1,1]]
                    prefac = [3/2*J1,3/2*J1]
            elif (lattice == "kgm"):
                if (symm == "A1g"):
                    # Kagome A1g
                    nbh = [[1,0],[0,1],[1,1],[-1,0],[0,-1],[-1,-1]]
                    prefac = [1*J1,1*J1,1*J1,1*J1,1*J1,1*J1]
                elif (symm == "Eg1"):
                    # Kagome Eg1
                    nbh = [[1,0],[0,1],[1,1],[-1,0],[0,-1],[-1,-1]]
                    prefac = [1*J1,-0.5*J1,-0.5*J1,1*J1,-0.5*J1,-0.5*J1]
                elif (symm == "Eg2"):
                    # Kagome Eg2
                    C1 = np.sqrt(3)/2
                    nbh = [[1,1],[-1,-1],[0,1],[0,-1]]
                    prefac = [C1*J1,C1*J1,-C1*J1,-C1*J1]
                elif (symm == "xx"):
                    # Kagome x in x out
                    nbh = [[1,0],[-1,0],[1,1],[-1,-1],[0,1],[0,-1]]
                    prefac = [2*J1,2*J1,0.5*J1,0.5*J1,0.5*J1,0.5*J1]
                elif (symm == "yy"):
                    # Kagome y in y out
                    C1 = 3/4 * 2
                    nbh = [[1,1],[-1,-1],[0,1],[0,-1]]
                    prefac = [C1*J1,C1*J1,C1*J1,C1*J1]
            elif (lattice == "hny"):
                if (symm == "Eg1"):
                    # Honeycomb Eg1
                    C1 = 2/3
                    nbh = [[2,1],[-1,1],[1,-1],[-2,-1],[-1,-2],[1,2]]
                    prefac = [C1*J1,C1*J1,C1*J1,C1*J1,-2*C1*J1,-2*C1*J1]
                elif (symm == "Eg2"):
                    # Honeycomb Eg2
                    C1 = 2/np.sqrt(3)
                    nbh = [[2,1],[-1,1],[1,-1],[-2,-1]]
                    prefac = [C1*J1,-C1*J1,-C1*J1,C1*J1]
            if nbh is None:
                print("WARNING!!! No Raman vertex defined for symmetry", symm, "on lattice", lattice, "- skipping")
                continue
            nedos = int(maxE/Ediv)+1
            specY = np.zeros(nedos)
            # Sum the spectra of all degenerate ground states.
            for psi in psi_all:
                scol, colvec = calc_SScorr(Nhilbert,Lx,Ly,Ncorr,list_corr_isite1,list_corr_isite2,psi,
                                           irght,ilft,ihfbit,list_1,list_ja,list_jb,nbh,prefac,shape,lattice)
                specX, specY_temp = ContFracExpan(colvec,psi_all,ene[0],HamCSR,maxE=maxE,Ediv=Ediv,niter_CFE=100,eps=0.1)
                specY = specY + specY_temp
            if (save_spec):
                os.makedirs("./HSB-final", exist_ok=True)
                np.savetxt("./HSB-final/"+shape+"-"+lattice+"-"+symm+"-J1_1-J2_"+str(J2)+"-dense.dat",
                           np.vstack((specX,specY)).T)
            else:
                if lattice not in ("sq", "hny"): plt.plot(specX,specY*200,label=symm+", x200")
                else: plt.plot(specX,specY*0.01,label=symm)
                plt.title(lattice+", "+symm)
                plt.xlim([0,10])
                plt.legend()
        if (symm == "A2g"):
            inp = None
            if (lattice == "sq"):
                # Square A2g
                # Full combination (currently not used):
                # inp = [[-34,0,0,1,0,0,1],[9,0,0,1,0,2,1],[9,0,0,1,2,0,1],[30,0,1,2,1,1,0],[9,0,0,1,0,0,2],[9,0,0,2,0,0,1]]
                inp = [[9,0,0,1,0,2,1],[9,0,0,1,2,0,1]]
            elif (lattice == "tr"):
                # Rotating in triangular lattice, [1,0],[1,1],[0,1],[-1,0],[-1,-1],[0,-1]
                # Rotating in triangular lattice, [2,1],[1,2],[-1,1],[-2,-1],[-1,-2],[1,-1]
                # Rotating in triangular lattice, [2,0],[2,2],[0,2],[-2,0],[-2,-2],[0,-2]
                print("Triangular A2g")
                inp = [[1,1,0,2,1,2,0],[-1,1,0,1,-1,2,0],[1,0,0,2,1,1,0],[-1,0,0,1,-1,1,0]
                       ,[1,0,0,2,1,2,0],[-1,0,0,1,-1,2,0]]
            elif (lattice == "kgm"):
                # Kagome A2g
                print("Kagome A2g")
                # May have double counted this
                inp = [[3,0,0,1,0,1,1],[1,0,0,1,0,0,1]]
#                 inp = [[3,0,0,1,0,1,1],[3,0,0,-1,0,-1,-1],[1,0,0,0,1,-1,-1],[1,0,0,0,-1,1,1],
#                        [1,0,0,1,0,0,1],[1,0,0,-1,0,0,-1],[1,0,0,-1,-1,1,0],[1,0,0,1,1,-1,0]]
#                 inp = [[3,0,0,0,1,-1,0],[3,0,0,-1,0,1,0],[1,0,0,-1,-2,0,-1],[1,0,0,1,2,0,1],
#                        [1,0,0,1,0,0,1],[1,0,0,-1,0,0,-1],[1,0,0,-1,0,-2,-1],[1,0,0,1,0,2,1]]
            elif (lattice == "hny"):
                # Honeycomb A2g
                print("Honeycomb A2g")
                C1 = 2*np.sqrt(3)
                # May have double counted this
                inp = [[C1,0,0,-1,1,-1,-2],[C1,0,0,-1,-2,2,1],[C1,0,0,2,1,-1,1],
                       [C1,0,0,1,2,-2,-1],[C1,0,0,-2,-1,1,-1],[C1,0,0,1,-1,1,2]]
            if inp is None:
                print("WARNING!!! No A2g vertex defined for lattice", lattice, "- skipping")
                continue
            nedos = int(maxE/Ediv)+1
            specY = np.zeros(nedos)
            # Sum the spectra of all degenerate ground states.
            for psi in psi_all:
                chiralvec = calc_a2gcorr(Nhilbert,Lx,Ly,Ncorr,list_corr_isite1,list_corr_isite2,
                                            psi,irght,ilft,ihfbit,list_1,list_ja,list_jb,shape,lattice,inp)
                specX, specY_temp = ContFracExpan(chiralvec,psi_all,ene[0],HamCSR,maxE=maxE,Ediv=Ediv,niter_CFE=100,eps=0.1)
                specY = specY + specY_temp
            if (save_spec):
                os.makedirs("./HSB-final", exist_ok=True)
                np.savetxt("./HSB-final/"+shape+"-"+lattice+"-A2g-J1_1-J2_"+str(J2)+"-dense.dat", np.vstack((specX,specY)).T)
            else:
                plt.plot(specX,specY,label=symm)
                plt.title(lattice+", "+symm)
                plt.xlim([0,10])
                plt.legend()

In [None]:
# Example run: 12-site triangular cluster ("12C"; N is taken from the shape label), Sz = 0,
# J1 = 1, J2 = 0. The Eg1, Eg2 and A2g spectra (maxE = 15, Ediv = 0.01) are written under
# ./HSB-final/ because save_spec=True.
main(6,6,0,1,0,"12C","tr",["Eg1","Eg2","A2g"],maxE=15,Ediv=0.01,save_spec=True)

factor, gscs 0.014651575152875386 1.6341585859830906e-34
iter,a,b 0 -6.131688083104132 0.0
iter,a,b 1 -2.67977196562108 0.892185355267761
iter,a,b 2 -2.551773166381021 1.149989018204877
iter,a,b 3 -1.7056421367338943 0.41183976772522113
iter,a,b 4 1.0688753518401262 1.428601530501468
iter,a,b 5 7.647684576832401 2.9975180030078387e-12
iter,a,b 6 -1.7621778471544514 2.6896485022553125
iter,a,b 7 2.351022468172619 5.7860029131728865
iter,a,b 8 3.3043244331930195 1.3896978827096627
iter,a,b 9 -2.742472077881876 4.086398261417943
iter,a,b 10 1.835157347719706 2.3535666067040313
iter,a,b 11 -1.4292999494736007 2.3902996188228376
iter,a,b 12 -2.0614046464944646 3.6146033719975037
iter,a,b 13 -0.5698175802660437 2.4882982250677004
iter,a,b 14 -2.2778594269828356 4.021785281908443
iter,a,b 15 -1.2286529777766595 2.463105212505408
iter,a,b 16 -0.5483385969266954 3.167437928490622
iter,a,b 17 -0.7724728420644869 2.6370632980121314
iter,a,b 18 -0.5504280727235417 2.5416507882662014
iter,a,b 19 -0

10 9 7 [4 4] [3 4] [3 3]
3 8 1 [1 2] [4 3] [1 1]
1 [[0 0]
 [2 1]
 [1 0]] [[[ 0  0]
  [ 2  1]
  [ 1  0]]

 [[ 0  0]
  [ 1  2]
  [ 1  1]]

 [[ 0  0]
  [-1  1]
  [ 0  1]]

 [[ 0  0]
  [-2 -1]
  [-1  0]]

 [[ 0  0]
  [-1 -2]
  [-1 -1]]

 [[ 0  0]
  [ 1 -1]
  [ 0 -1]]]
0 2 9 [0 0] [2 1] [3 4]
0 3 1 [0 0] [1 2] [1 1]
0 7 8 [0 0] [3 3] [4 3]
0 2 5 [0 0] [2 1] [3 2]
0 3 11 [0 0] [1 2] [5 5]
0 7 6 [0 0] [3 3] [2 3]
1 5 2 [1 1] [3 2] [2 1]
1 6 4 [1 1] [2 3] [2 2]
1 10 3 [1 1] [4 4] [1 2]
1 5 8 [1 1] [3 2] [4 3]
1 6 0 [1 1] [2 3] [0 0]
1 10 9 [1 1] [4 4] [3 4]
2 0 11 [2 1] [0 0] [5 5]
2 7 5 [2 1] [3 3] [3 2]
2 3 4 [2 1] [1 2] [2 2]
2 0 1 [2 1] [0 0] [1 1]
2 7 9 [2 1] [3 3] [3 4]
2 3 10 [2 1] [1 2] [4 4]
3 7 4 [1 2] [3 3] [2 2]
3 0 6 [1 2] [0 0] [2 3]
3 2 11 [1 2] [2 1] [5 5]
3 7 10 [1 2] [3 3] [4 4]
3 0 8 [1 2] [0 0] [4 3]
3 2 1 [1 2] [2 1] [1 1]
4 8 5 [2 2] [4 3] [3 2]
4 9 7 [2 2] [3 4] [3 3]
4 11 6 [2 2] [5 5] [2 3]
4 8 3 [2 2] [4 3] [1 2]
4 9 1 [2 2] [3 4] [1 1]
4 11 2 [2 2] [5 5] [2 1]
5 1 0 