## 15-puzzle solver

First, we need to decide how to represent states. I will write the four rows one after another as a flat sequence of 16 entries, and use 16 to denote the empty space.
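
To make the encoding concrete, here is a minimal sketch of the flattened representation. The names `BLANK`, `GOAL` and `row_col` are illustrative only and are not used by the notebook code further down.

```python
BLANK = 16  # the empty space is encoded as 16, as described above

# Solved board: the four rows written one after another, empty space last.
GOAL = tuple(range(1, 17))

def row_col(index):
    """Map a flat index 0..15 to (row, column) on the 4x4 board."""
    return divmod(index, 4)

state = (7, 8, 4, 5, 3, 12, 15, 1, 14, 13, 2, 9, 10, 11, 16, 6)
blank_at = state.index(BLANK)
print(blank_at, row_col(blank_at))
```

Moving the empty space up or down changes its flat index by 4, and moving it left or right changes it by 1, which is the arithmetic the move functions below rely on.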

Now for further analysis, as explained in the Medium post, it is best to solve the puzzle row by row: first the first row, then the second row, and finally the third and fourth rows together. Otherwise the number of states would be too large to handle. Breaking the task up this way greatly reduces the number of states we need to deal with at a time.
However, it is worth noting that we lose optimality when we do this. That is, a solution with fewer moves might exist.
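
A rough count shows how much the staging helps. This is a sketch that assumes each stage tracks only the tiles of the rows being solved plus the empty space, which is the masking described later in this notebook; the variable names are illustrative.

```python
import math

# Stage 1: tiles 1-4 and the empty space placed on any of the 16 cells.
stage1 = math.perm(16, 5)
# Stage 2: the first row is fixed, so tiles 5-8 and the empty space use 12 cells.
stage2 = math.perm(12, 5)
# Stage 3: tiles 9-15 and the empty space arranged on the last 8 cells.
stage3 = math.factorial(8)
# Whole puzzle at once: half of the 16! arrangements are reachable.
full = math.factorial(16) // 2

print(stage1, stage2, stage3)
print(full)
```

The three stages together enumerate well under a million states, against roughly 10^13 for the unstaged puzzle.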

Building on this high-level idea, for each of the 3 parts we should be able to model the MDP without too much difficulty. We can then apply value iteration to each MDP to build a near-optimal policy. Finally, we take the puzzle to be solved as input and apply the three policies one after another to solve it.
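
As a reminder of what value iteration does, here is a minimal sketch on a hypothetical three-state chain. The MDP layout (`mdp[state][action] = (next_state, reward)`), the function name `value_iteration` and the toy states are assumptions made for illustration; the puzzle code below uses its own, slightly different structure.

```python
def value_iteration(mdp, gamma=0.9, theta=1e-4):
    """mdp[state][action] = (next_state, reward); transitions are deterministic."""
    values = {s: 0.0 for s in mdp}
    while True:
        delta = 0.0
        for s in mdp:
            # Bellman update: best immediate reward plus discounted next value.
            best = max(r + gamma * values[s2] for s2, r in mdp[s].values())
            delta = max(delta, abs(best - values[s]))
            values[s] = best
        if delta < theta:
            break
    # Greedy policy with respect to the converged values.
    policy = {
        s: max(mdp[s], key=lambda a: mdp[s][a][1] + gamma * values[mdp[s][a][0]])
        for s in mdp
    }
    return values, policy

# Hypothetical chain A -> B -> GOAL with reward 1 on reaching GOAL.
toy = {
    "A": {"stay": ("A", 0), "go": ("B", 0)},
    "B": {"stay": ("B", 0), "go": ("GOAL", 1)},
    "GOAL": {"stay": ("GOAL", 0)},
}
values, policy = value_iteration(toy)
print(policy["A"], policy["B"])
```

Because the reward is discounted by `gamma` at every step, states closer to the goal end up with higher values, and the greedy policy follows a shortest route.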

The idea is now mostly covered, but the remaining key challenges/steps are likely to be:
1. Creating the MDP. For this part, apart from the tiles of the row being dealt with and the empty space, we treat everything else as zeroes. The number of states is huge, so we'll have to automate the MDP creation process. We give a reward only when a move reaches a state where the required row is set properly.
2. Value iteration should be fairly simple once the MDP is created.
3. The final implementation details of applying the policies to the input puzzle seem a bit complicated, as the policies will have 0 in place of the remaining elements. I guess in the policy we will store actions as swaps of index i and index j, and then apply that action to the input puzzle. This also makes me think it would be cool to have a separate function that prints the puzzle in a nice form from the form I'll be using.
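
The masking in items 1 and 3 can be sketched as follows. The helper `mask` and the variable names are hypothetical; the notebook's own version is the `presence` filter used inside `solve` further down.

```python
def mask(board, keep):
    """Replace every tile not in `keep` with 0, leaving positions unchanged."""
    return tuple(t if t in keep else 0 for t in board)

board = (7, 8, 4, 5, 3, 12, 15, 1, 14, 13, 2, 9, 10, 11, 16, 6)

# View used while solving the first row: tiles 1-4 and the empty space (16).
stage1_view = mask(board, keep={1, 2, 3, 4, 16})
print(stage1_view)
```

The empty space sits at the same index in the masked view and in the full puzzle, so an action chosen for the masked view can be applied to the full puzzle unchanged. The code below stores direction names ("up", "down", "left", "right") as actions rather than index pairs, which transfers in the same way.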

In [107]:
import numpy as np
from itertools import permutations


def printpuzzle(a):
    # Print the flat 16-entry puzzle as a 4x4 grid; each cell is 3 characters wide.
    for i in range(4):  # Loop through 4 rows
        print(" | ".join(f"{a[j]:>3}" for j in range(4 * i, 4 * i + 4)))
        if i < 3:  # Add a horizontal separator, with "+" aligned under each "|"
            print("----+-----+-----+----")
    print()

In [108]:
def findindex(a):
    for i in range(0,16):
        if (a[i]==16):
            return i
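def moveup(a, part):
    # Rows above row `part` are already solved, so the empty space must not
    # move up once it is in row `part` (or above): index < 4*(1+part).
    b=list(a)
    index=findindex(b)
    if (index<4*(1+part)):
        return tuple(b)
    else:
        b[index]=b[index-4]
        b[index-4]=16
        return tuple(b)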
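def movedown(a, part):
    # Only the bottom edge of the board limits a downward move, whichever part
    # is being solved, so `part` is unused here; it is kept for a uniform signature.
    b=list(a)
    index=findindex(b)
    if (index>=12):
        return tuple(b)
    else:
        b[index]=b[index+4]
        b[index+4]=16
        return tuple(b)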
def moveleft(a):
    b=list(a)
    index=findindex(b)
    if ((index%4)==0):
        return tuple(b)
    else:
        b[index]=b[index-1]
        b[index-1]=16
        return tuple(b)
def moveright(a):
    b=list(a)
    index=findindex(b)
    if ((index%4)==3):
        return tuple(b)
    else:
        b[index]=b[index+1]
        b[index+1]=16
        return tuple(b)
def move(s, a, which_part):
    if (s=="up"):
        return moveup(a, which_part)
    elif (s=="down"):
        return movedown(a, which_part)
    elif (s=="left"):
        return moveleft(a)
    else:
        return moveright(a)


In [109]:
actions=["up","down","left","right"]

In [110]:
def terminality(which_part):
    if (which_part==0):
        return lambda a: (a[0]==1 and a[1]==2 and a[2]==3 and a[3]==4)
    elif (which_part==1): 
        return lambda a: (a[4]==5 and a[5]==6 and a[6]==7 and a[7]==8)
    else:
        return lambda a: (a[8]==9 and a[9]==10 and a[10]==11 and a[11]==12 and a[12]==13 and a[13]==14 and a[14]==15 and a[15]==16)

def presence(which_part):
    if (which_part==0):
        return lambda a: ((1<=a and a<=4) or a==16)
    elif (which_part==1):
        return lambda a: ((5<=a and a<=8) or a==16)
    else:
        return lambda a: ((9<=a and a<=16))

In [111]:
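def createstates(part):
    states=[]
    if (part==0 or part==1):
        # Choose cells for the 4 tiles of this row plus the empty space among
        # the cells that are not yet solved; every other cell stays 0.
        tuples=list(permutations(range(4*part,16), 5))
        for tups in tuples:
            tu=[0]*16
            for i in range(4):
                tu[tups[i]]=1+i+(4*part)
            tu[tups[4]]=16
            tupletu=tuple(tu)
            states.append(tupletu)
    else:
        # Last part: arrange tiles 9..15 and the empty space (16) over the last two rows.
        tuples=list(permutations(range(8,16), 8))
        for tups in tuples:
            tu=[0]*16
            for i in range(8):
                tu[tups[i]]=9+i
            tupletu=tuple(tu)
            states.append(tupletu)
    return states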


def createmdp(states, which_part):
    mdp={}
    isterminal=terminality(which_part)
    for state in states:
        transitions={}
        for action in actions:
            new_arr=move(action,state, which_part)
            if (isterminal(new_arr)):
                transitions[action]=[(tuple(new_arr),1)]
            else:
                transitions[action]=[(tuple(new_arr),0)]
        mdp[state]=transitions
    return mdp


In [112]:
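def createpolicy(states, mdp):
    theta=0.0001          # stop once no value changes by more than this in a sweep
    delta=float("inf")    # forces at least one sweep
    values = {}
    pi = {}
    gamma=0.9             # discount factor, so shorter routes to the goal score higher
    while (delta>theta):
        delta=0
        for state in states:
            v = values.get(state, 0)
            m=0
            for action in mdp[state]:
                new_value=0
                # Each case is a (next_state, reward) pair.
                for case in mdp[state][action]:
                    new_value+=((case[1]+(gamma*values.get(case[0],0))))
                # Strict comparison: a state that cannot reach the goal keeps
                # value 0 and gets no entry in pi.
                if (new_value>m):
                    m=new_value
                    pi[state]=action
            values[state]=m
            delta=max(delta,abs(v-values.get(state,0)))
    return pi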

In [113]:
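def solve(pi,which_part,inp):
    isterminal=terminality(which_part)
    pres=presence(which_part)

    # Masked view of the puzzle: tiles that do not belong to this part become 0.
    ou=[0]*16
    for i in range(0,16):
        if (pres(inp[i])):
            ou[i]=inp[i]
    while(not isterminal(ou)):
        out=tuple(ou)
        if out not in pi:
            # No recorded action means the goal is not reachable from this state.
            raise ValueError("no policy for this state; the puzzle may be unsolvable")
        act=pi[out]
        # Apply the same move to the masked view and to the real puzzle.
        ou=list(move(act,out,which_part))
        inp=list(move(act,inp, which_part))
        printpuzzle(inp)
    return inp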

In [114]:
def running(inp,which_part):
    # Build the states, MDP and policy for one part, then apply the policy to the puzzle.
    states=createstates(which_part)
    mdp=createmdp(states,which_part)
    pi=createpolicy(states,mdp)
    ind=solve(pi,which_part,inp)
    return ind

In [115]:
#inp=[0]*16
#for i in range(16):
#   inp[i]=int(input())
inp=[7,8,4,5,3,12,15,1,14,13,2,9,10,11,16,6]
printpuzzle(inp)
ina=running(inp,0)
inb=running(ina,1)
inc=running(inb,2)
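
The staged policies can only succeed on a solvable arrangement, and half of all arrangements of the 15-puzzle are not solvable. Below is a minimal sketch of the standard parity test for a 4x4 board, assuming the same flat encoding with 16 as the empty space. `is_solvable` is a hypothetical helper and is not part of the notebook above.

```python
def is_solvable(board, blank=16):
    """Parity test for a 4x4 sliding puzzle given as a flat sequence of 16 values."""
    tiles = [t for t in board if t != blank]
    # An inversion is a pair of tiles that appear in the wrong relative order.
    inversions = sum(
        1
        for i in range(len(tiles))
        for j in range(i + 1, len(tiles))
        if tiles[i] > tiles[j]
    )
    # Row of the empty space counted from the bottom, starting at 1.
    blank_row_from_bottom = 4 - list(board).index(blank) // 4
    # On an even-width board: solvable iff inversions + that row number is odd.
    return (inversions + blank_row_from_bottom) % 2 == 1

print(is_solvable([7, 8, 4, 5, 3, 12, 15, 1, 14, 13, 2, 9, 10, 11, 16, 6]))
print(is_solvable(list(range(1, 14)) + [15, 14, 16]))
```

The first call checks the example input used above. The second is the classic arrangement with tiles 14 and 15 swapped, which cannot be solved. Running a check like this before `running(inp,0)` avoids starting a solve that cannot finish.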