In [None]:
## Solving sudoku using discrete driving + DMRG Sweep
## Major Update Jan 2025, fixed lamb*2

# require the ITensors and ITensorMPS package 
# for handling spin model, tensor train (matrix product) states, and density matrix renormalization group
using ITensors, ITensorMPS
# require LinearAlgebra for matrix manipulations 
using LinearAlgebra
# for saving and reading MPS
using HDF5
# for random numbers, computing sum/ave, file IO etc.
using Statistics
using Random 
using DelimitedFiles

# construct and return the QUBO matrix by applying the 4 constraints as penalties
# here lamb is the penalty weight passed to the function
# for simplicity, we use the same lambda for all penalties, so it drops out and can be set to 1
"""
    compose_Q(lamb)

Assemble the 729×729 QUBO matrix for an empty 9×9 Sudoku board and return it in
upper-triangular form.

The binary variable `x_{ijk}` is 1 when cell `(i, j)` holds digit `k` (all three
indices are 0-based here) and sits at the 1-based position `81i + 9j + k + 1`. For each
of the four constraint families (cell, column, row, block) every variable receives
`-lamb` on the diagonal and `+lamb` towards each partner of the same group; the two
mirrored off-diagonal entries are finally folded into the upper triangle.

# Arguments
- `lamb`: penalty weight shared by all four constraint families.

# Returns
- A `729×729` upper-triangular `Matrix{Float64}`.
"""
function compose_Q(lamb)
    Q = zeros(Float64, 729, 729)

    # 1-based position of x_{ijk}; equals 9*(9i + j) + k + 1, with 9i + j the cell index
    pos(i, j, k) = 81 * i + 9 * j + k + 1

    # one "exactly one of this group" contribution for variable u:
    # -lamb on the diagonal, +lamb towards every partner v of the group
    function penalize!(u, partners)
        Q[u, u] -= lamb
        for v in partners
            Q[u, v] += lamb   # no factor of 2 here; (v, u) is added when v is visited
        end
        return nothing
    end

    for i in 0:8, j in 0:8, k in 0:8
        u = pos(i, j, k)

        # (I) cell: the other digits in the same cell
        penalize!(u, (pos(i, j, d) for d in 0:8 if d != k))

        # (II) column: same digit, same j, other i
        penalize!(u, (pos(r, j, k) for r in 0:8 if r != i))

        # (III) row: same digit, same i, other j
        penalize!(u, (pos(i, c, k) for c in 0:8 if c != j))

        # (IV) block: same digit inside the 3×3 block of (i, j). Only partners that
        # differ in BOTH row and column are coupled here; partners sharing a row or a
        # column receive just the row/column coupling above.
        # NOTE(review): a literal (sum - 1)^2 block penalty would couple those partners
        # as well -- confirm that leaving them out is intended.
        r0 = 3 * (i ÷ 3)
        c0 = 3 * (j ÷ 3)
        penalize!(u, (pos(r, c, k) for r in r0:r0+2 for c in c0:c0+2 if r != i && c != j))
    end

    # fold the strictly lower triangle onto the upper one
    return triu(Q) + transpose(tril(Q, -1))
end

# take a set of Sudoku clues, return two clamped sets: list_0 and list_1
# clamp the variable of each filled cell to 1; then clamp to 0 the other digits of that cell
# and the same digit in every other cell of the affected column, row, and block.
"""
    clamp_sets(clues)

Translate Sudoku clues into the two sets of QUBO variables that are fixed in advance.

Each clue is a `(row, column, digit)` triple with all entries in `1:9`. The variable of
cell `(i, j)` and digit `k` (0-based) has the 1-based position `81i + 9j + k + 1`.

# Returns
- `list_0::Set{Int}`: positions clamped to 0, i.e. the other digits of every clue cell and
  the clue digit in every other cell of the same column, row and 3×3 block.
- `list_1::Set{Int}`: positions clamped to 1, one per clue.

# Throws
- `ArgumentError` if a clue entry is not an integer in `1:9`, or if the clues contradict
  each other (some variable would be clamped to both 0 and 1).
"""
function clamp_sets(clues)
    # 1-based position of x_{ijk}; i, j, k themselves are 0-based
    pos(i, j, k) = 81 * i + 9 * j + k + 1

    list_1 = Set{Int}()  # variables clamped to 1
    list_0 = Set{Int}()  # variables clamped to 0

    for (row, col, digit) in clues
        if !all(v -> isinteger(v) && 1 <= v <= 9, (row, col, digit))
            throw(ArgumentError("invalid clue ($row, $col, $digit): row, column and digit must be integers in 1:9"))
        end
        # clues are 1-based, the variable indices are 0-based
        i, j, k = Int(row) - 1, Int(col) - 1, Int(digit) - 1

        # the clue itself
        push!(list_1, pos(i, j, k))

        # same cell, every other digit
        for d in 0:8
            d != k && push!(list_0, pos(i, j, d))
        end

        # same column j, same digit, every other row
        for r in 0:8
            r != i && push!(list_0, pos(r, j, k))
        end

        # same row i, same digit, every other column
        for c in 0:8
            c != j && push!(list_0, pos(i, c, k))
        end

        # same 3×3 block, same digit, every other cell (cells sharing the row or the
        # column were already added above; a set ignores the repeats)
        r0, c0 = 3 * (i ÷ 3), 3 * (j ÷ 3)
        for r in r0:r0+2, c in c0:c0+2
            (r, c) != (i, j) && push!(list_0, pos(r, c, k))
        end
    end

    # a variable in both sets means two clues exclude each other; downstream sizes
    # (number of free variables) would silently be wrong
    conflicts = intersect(list_0, list_1)
    if !isempty(conflicts)
        throw(ArgumentError("contradictory clues: $(length(conflicts)) variable(s) are clamped to both 0 and 1"))
    end

    return list_0, list_1
end

# reduce the full Q matrix to a lean_qubo matrix by "clamping"
# size_Q to size_Qp
# by a sequence of manipulations: Q --> T0 ---> Qpp ---> Qp
# return Qp, offset (due to clamping), and kappa (the Q->Qp mapping matrix) 
"""
    reduce_to_Qp(Q, size_Q, size_Qp, list_0, list_1)

Shrink the full QUBO matrix `Q` to the lean matrix of the free variables by clamping.

A selection matrix `T0` of size `size_Q × (size_Qp + 1)` sends every free variable to its
own column and every variable clamped to 1 to the extra last column; variables clamped
to 0 get an all-zero row. `T0' * Q * T0` is folded into upper-triangular form; its last
column (couplings of free variables to the clamped-to-1 column) is moved onto the
diagonal of the reduced matrix and its bottom-right entry is the constant offset.

# Arguments
- `Q`: full QUBO matrix.
- `size_Q::Int`: number of variables of `Q`.
- `size_Qp::Int`: number of free variables.
- `list_0`, `list_1`: 1-based positions clamped to 0 and to 1.

# Returns
- `Qp`: reduced `size_Qp × size_Qp` matrix.
- `offset`: constant contributed by the clamped variables.
- `kappa`: vector of length `size_Q` mapping a full index to its free-variable index;
  `-1` marks a variable clamped to 1 and `-2` one clamped to 0.
"""
function reduce_to_Qp(Q, size_Q::Int, size_Qp::Int, list_0, list_1)
    const_col = size_Qp + 1     # column collecting the variables clamped to 1
    T0 = zeros(Float64, size_Q, const_col)
    kappa = zeros(Int, size_Q)

    nfree = 0                   # running index of the free variables
    for row in 1:size_Q
        if row in list_1
            T0[row, const_col] = 1.0
            kappa[row] = -1     # arbitrary negative flag, for debugging
            continue
        end
        if row in list_0
            kappa[row] = -2     # T0 row stays zero; flag only
            continue
        end

        # free variable: give it the next column
        nfree += 1
        if nfree > size_Qp
            println("out of bound in T0[i,m] ", nfree)
        else
            T0[row, nfree] = 1.0
        end
        kappa[row] = nfree
    end

    # project and fold into upper-triangular form
    projected = T0' * Q * T0
    Qpp = triu(projected) + transpose(tril(projected, -1))

    # bottom-right corner: constant energy of the clamped variables
    offset = Qpp[const_col, const_col]

    # the last column becomes a linear (diagonal) term of the free variables
    free = 1:size_Qp
    Qp = Qpp[free, free] + Diagonal(Qpp[free, const_col])

    return Qp, offset, kappa
end

# show the board from <Sz>, by constructing the full qubo solution
# check the row, column, and block violations
"""
    display_solution(Z, N, list_0, list_1, kappa)

Turn the measured `Z` values into a Sudoku board, show it, and report constraint
violations.

Entry `Z[i]` is thresholded into bit `i` of the lean QUBO solution (1 when
`abs(Z[i] - 0.5) < 0.5`). The 729 full variables are then rebuilt: positions in `list_0`
are 0, positions in `list_1` are 1, and every other position `i` takes bit `kappa[i]`.
The board is displayed and each row, column and 3×3 block is compared with the digits
`1:9`. A 0 on the board means no variable of that cell was set.

# Arguments
- `Z`: per-spin values to threshold.
- `N`: number of free variables (length of the bit vector).
- `list_0`, `list_1`: 1-based positions clamped to 0 and to 1.
- `kappa`: map from a full position to its free-variable index.

Returns `nothing`; all results are printed.
"""
function display_solution(Z, N, list_0, list_1, kappa)
    # threshold into bits: spin up == 1, spin down == 0
    bn = zeros(Int8, N)
    for site in eachindex(Z)
        abs(Z[site] - 0.5) < 0.5 && (bn[site] = 1)
    end
    println("# of spin ups at the final stage: ", sum(bn))

    # rebuild the full binary vector; list_0 takes precedence over list_1
    x = zeros(Int, 729)
    for v in 1:729
        if v in list_0
            x[v] = 0
        elseif v in list_1
            x[v] = 1
        elseif 0 < kappa[v] <= length(bn)   # extra precaution
            x[v] = bn[kappa[v]]
        else
            println("error: out of range in kappa[i]")
        end
    end

    println("Done. The sudoku board is ")
    sodu = zeros(Int, 9, 9)
    for v in findall(>(0.95), x)
        # v - 1 = 81*row + 9*col + digit, everything 0-based
        cell, digit = divrem(v - 1, 9)
        row, col = divrem(cell, 9)
        sodu[row + 1, col + 1] = digit + 1
    end

    display(sodu)

    # every row, column and block has to contain exactly the digits 1..9
    pass = true
    solution_set = Set(1:9)
    for n in 1:9
        if Set(sodu[n, :]) != solution_set
            println("! violation in row ", n)
            pass = false
        end
        if Set(sodu[:, n]) != solution_set
            println("! violation in col ", n)
            pass = false
        end
    end

    for u in 0:2, v in 0:2
        rows = (3 * u + 1):(3 * u + 3)
        cols = (3 * v + 1):(3 * v + 3)
        if Set(sodu[rows, cols]) != solution_set
            println("! violation in block (", u + 1, ",", v + 1, ")")
            pass = false
        end
    end

    println(pass ? "Check complete, solution is valid." : "Check complete, solution invalid.")

    return
end


## Main algorithm: For given clues, solve the sudoku by quantum annealing based on DMRG
## tuning parameters:
## hx_const: constant term in hx, 1
## hx_random_scale: overall scale for random fluctuations in hx, 0.333
## bond_dimension: the bond dimension of the annealing process, 40-60
## numsteps: number of annealing steps, 10 or 5
## numsweeps: number of sweeps in each DMRG, 4-10

"""
    tensor_train_express(clues; hx_const=1, hx_random_scale=0.333, numsteps=5,
                         bond_dimension=40, numsweeps=5, rng=Random.default_rng())

Solve the Sudoku given by `clues` with a stepwise driven sequence of DMRG runs and print
the resulting board together with its validity check.

The clues are turned into clamp sets, the full QUBO matrix is reduced to the free
variables, and the reduced QUBO is rewritten as a spin Hamiltonian `h1` in terms of
`"Sz"` operators. At step `s` of `numsteps` the Hamiltonian `(1 - sp) * h1 + sp * h2` with
`sp = 1 - s / numsteps` is handed to `dmrg`, starting from the state of the previous
step; `h2` is a sum of `"Sx"` terms with site- and step-dependent random coefficients.

# Arguments
- `clues`: iterable of `(row, column, digit)` triples, forwarded to `clamp_sets`.

# Keywords
- `hx_const`: constant part of the `"Sx"` coefficient.
- `hx_random_scale`: scale of the normally distributed part of the `"Sx"` coefficient.
- `numsteps`: number of driving steps.
- `bond_dimension`: `maxdim` passed to `dmrg`.
- `numsweeps`: `nsweeps` passed to `dmrg`.
- `rng`: random number generator for the random fields; the default draws from the
  global generator exactly as `randn(numsteps, N)` does.

Returns `nothing`; progress, energies and the board are printed.
"""
function tensor_train_express(clues; hx_const = 1, hx_random_scale=0.333, numsteps=5, bond_dimension=40, numsweeps=5, rng=Random.default_rng())
    ## create the clamp sets from the clues
    list_0, list_1 = clamp_sets(clues)
    num_clamp = length(list_0) + length(list_1)

    ## create the original bloated Q matrix, in upper triangular form
    Q0 = compose_Q(1)

    ## reduce the bloated QUBO to a lean QUBO by applying the clues;
    ## list_0, list_1, and kappa are needed again to expand the solution to the full vector
    size_Q0 = size(Q0, 1)
    size_Qp = size_Q0 - num_clamp
    Q, c_off_set, kappa = reduce_to_Qp(Q0, size_Q0, size_Qp, list_0, list_1)

    ## the size of the spin problem
    (N, M) = size(Q)
    ## double check
    if N != M
        println("Q matrix is not square!")
    else
        println("The total number of spins is ", N)
    end

    ## the shift in energy due to the transformation from binary variable x = s + 1/2:
    ## diagonal terms contribute Q_ii/2, upper-triangular pairs Q_ij/4
    energy_shift = 0.25*(tr(Q) + sum(Q))
    ## sum up all the constant shifts, its negative is the energy target
    ## (81*4: the four constraint families with 81 constraints each, weight 1 from compose_Q(1))
    en_target = energy_shift + c_off_set + 81*4

    ## onsite fields of the reduced QUBO: Q_ii plus half of every coupling touching site i
    ha = zeros(N)
    for i in eachindex(ha)
        ha[i] = Q[i, i] + 0.5*(sum(Q[i, i+1:N]) + sum(Q[1:i-1, i]))
    end

    ## important note: with no conserved QN, otherwise one cannot use h2, the transverse hamiltonian
    sites = siteinds("S=1/2", N)

    ## set up h1, the problem Hamiltonian, as OpSum
    h1 = OpSum()
    for i in 1:N
        h1 += (ha[i], "Sz", i)
        for j in i+1:N
            h1 += (Q[i, j], "Sz", i, "Sz", j)
        end
    end

    ## print Nup, Magnetization, and energy target, mostly for debugging
    println("# of clues: ", length(clues))
    Nup = 81 - length(clues)
    println("Expected # of spin ups: ", Nup)
    println("The energy target is ", -en_target)
    totalSZ = (Nup - 0.5*N)
    println("Expected total SZ: ", totalSZ)

    ## start from the product state with every site in the "-" state
    psi = MPS(sites, ["-" for n in 1:N])
    println("Starting from |->^N state:")

    ## alternative start from some random mps, the bond dimension can be tuned separately:
    #psi = random_mps(sites; linkdims=3)
    #println("Starting from random mps state:")

    ## to save the initial MPS, if needed (modify the filename):
    #f = h5open("init-mps-save-j12.h5","w")
    #write(f,"psi",psi)
    #close(f)

    ## random part of hx, one row per driving step
    rd = hx_random_scale*randn(rng, numsteps, N)

    ## to save the random hx, if needed:
    #open("rand-save-jl2.txt", "w") do io
    #       writedlm(io, rd)
    #end

    en_list = zeros(Float64, numsteps)
    sz_list = zeros(Float64, numsteps)
    sx_list = zeros(Float64, numsteps)
    for step in 1:numsteps
        sp = 1 - step/numsteps
        println("driving sp  = ", sp)

        ## set up the transverse Hamiltonian with added randomness
        h2 = OpSum()
        for i in 1:N
            h2 += (hx_const + rd[step, i], "Sx", i)
        end

        H = MPO((1.0 - sp)*h1 + sp*h2, sites)

        ## warm start: psi from the previous step is the initial state of this one
        energy, psi = dmrg(H, psi; nsweeps=numsweeps, maxdim=bond_dimension, noise=0.003)
        en_list[step] = energy + en_target

        ## totals of <Sz> and <Sx> after this step, reported below
        sz_list[step] = sum(expect(psi, "Sz"))
        sx_list[step] = sum(expect(psi, "Sx"))
    end

    println("EN: ", en_list)
    println("SZ: ", sz_list)
    println("SX: ", sx_list)

    ### compute the expectation value of Sz
    Z = expect(psi, "Sz")

    ### pass Z, construct the full qubo binary vector, and display the sudoku board
    display_solution(Z, N, list_0, list_1, kappa)

    return
end

# ## NY times, Jan 8, 2025, medium level. The list below encodes 23 clues (this note originally said 24 -- TODO confirm that no clue is missing).

# jan8m
# clues as (row, column, digit) triples, all 1-based, grouped by board row
cjan8m = [
    (1, 3, 4), (1, 6, 1), (1, 7, 7),
    (2, 2, 2), (2, 7, 9),
    (3, 1, 6), (3, 3, 8),
    (4, 4, 9),
    (5, 1, 7), (5, 5, 4), (5, 7, 5), (5, 8, 3),
    (6, 1, 2), (6, 2, 9), (6, 3, 3), (6, 8, 7),
    (7, 3, 6), (7, 8, 1),
    (8, 4, 8), (8, 6, 3), (8, 8, 5),
    (9, 1, 1), (9, 2, 5),
]

# run the solver on this puzzle with its own tuning parameters
tensor_train_express(
    cjan8m;
    hx_const=1.5,
    hx_random_scale=0.3,
    bond_dimension=20,
    numsteps=5,
    numsweeps=5,
)

The total number of spins is 232
# of clues: 23
Expected # of spin ups: 58
The energy target is -461.5
Expected total SZ: -58.0
Starting from |->^N state:
driving sp  = 0.9
After sweep 1 energy=-171.32742649640318  maxlinkdim=15 maxerr=4.15E-04 time=1.178
After sweep 2 energy=-171.33586218383203  maxlinkdim=15 maxerr=4.15E-04 time=1.938
After sweep 3 energy=-171.34021966487563  maxlinkdim=15 maxerr=4.16E-04 time=2.110
After sweep 4 energy=-171.3433415445993  maxlinkdim=15 maxerr=4.16E-04 time=2.128
After sweep 5 energy=-171.3454514952473  maxlinkdim=15 maxerr=4.16E-04 time=2.857
driving sp  = 0.8
After sweep 1 energy=-184.31612800919635  maxlinkdim=15 maxerr=3.57E-04 time=1.851
After sweep 2 energy=-184.3226518111536  maxlinkdim=15 maxerr=3.53E-04 time=2.180
After sweep 3 energy=-184.32749671195984  maxlinkdim=15 maxerr=3.52E-04 time=2.650
After sweep 4 energy=-184.33384042168964  maxlinkdim=15 maxerr=3.52E-04 time=2.130
After sweep 5 energy=-184.3367632119078  maxlinkdim=15 maxerr=3.5

9×9 Matrix{Int64}:
 9  3  4  6  2  1  7  8  5
 5  2  7  3  8  4  9  6  1
 6  1  8  7  7  9  3  4  4
 8  4  5  9  3  7  1  2  6
 7  6  1  2  4  8  5  3  9
 2  9  3  1  5  6  4  7  8
 3  8  6  5  9  2  4  1  7
 4  7  9  8  1  3  6  5  2
 1  5  2  4  6  7  8  9  3