# Backtracking

Lets first start with a toy problem to motivate the need for backtracking


### Example - Generate Parentheses Pairs

Given n pairs of parentheses, write a function to generate all combinations of well-formed parentheses.
```
input: 
n=3

output:
[
  "((()))",
  "(()())",
  "(())()",
  "()(())",
  "()()()"
]
```

In [77]:
ans = []

def generate_parens_bf(n, current = []):
    if len(current) == 2*n:
        if valid(current):
            ans.append("".join(current))
    else:
        current.append('(')
        generate_parens_bf(n, current)
        current.pop()
        current.append(')')
        generate_parens_bf(n, current)
        current.pop()

def valid(current):
    bal = 0
    for c in current:
        if c == '(': 
            bal += 1
        else: 
            bal -= 1
        if bal < 0: 
            return False
    return bal == 0

generate_parens_bf(4)
print("Sample run n=3", ans)

Sample run n=3 ['(((())))', '((()()))', '((())())', '((()))()', '(()(()))', '(()()())', '(()())()', '(())(())', '(())()()', '()((()))', '()(()())', '()(())()', '()()(())', '()()()()']


But this soluiton is really inefficient and doesn't scale with large n.  Even an n of 10 is problematic here. 

Run time complexity: ```O(n * 2^(2n))```

### Backtracking Overview - 

Pseudo code

```
def find_solutions(n, other_params) :
    if (found a solution):
        # Save your solution
        solutions_found = solutions_found + 1

    for (val = first to last):
        if (is_valid(val, n)):
            apply_value(val, n)
            find_solutions(n+1, other_params)
            remove_value(val, n)
```            


In [64]:
def generate_parens_bt(n):
    ans = []
    def backtrack(current = '', left = 0, right = 0):
        if len(current) == 2 * n:
            ans.append(current)
            return
        
        # For loop part
        if left < n: # Is Valid
            backtrack(current+'(', left+1, right)
        if right < left: # Is valid
            backtrack(current+')', left, right+1)

    backtrack()
    return ans
    
print("Sample run n=3", generate_parens_bt(3))


Sample run n=3 ['((()))', '(()())', '(())()', '()(())', '()()()']


Much better!

Runtime complexity: ```O((4^n) / sqrt(n))```	

In [78]:
print("Brute Force")
%timeit -n1000 generate_parens_bf(5)
print("Backtracking")
%timeit -n1000 generate_parens_bt(5)

Brute Force
949 µs ± 12.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Backtracking
64.7 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [258]:
board = [['.', '3', '.', '.', '.', '.', '.', '.', '.'], 
         ['.', '.', '.', '1', '.', '.', '.', '.', '.'], 
         ['.', '.', '.', '3', '.', '.', '.', '6', '.'], 
         ['.', '5', '9', '7', '.', '1', '4', '2', '.'], 
         ['4', '2', '6', '8', '5', '3', '7', '9', '1'], 
         ['.', '.', '3', '9', '.', '4', '8', '5', '6'], 
         ['.', '6', '1', '5', '.', '.', '.', '.', '4'], 
         ['.', '.', '7', '4', '1', '9', '6', '.', '5'], 
         ['.', '.', '.', '.', '8', '.', '.', '.', '9']]

# board = [['.', '.', '.', '.', '.', '.', '.', '.', '.'], 
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.'],
#          ['.', '.', '.', '.', '.', '.', '.', '.', '.']]

In [259]:
class Sudoku_Solver:
    
    def is_valid_square(self, board, row, col):
        row_square = (row // 3) * 3
        col_square = (col // 3) * 3
        
        contents = {}
        for i in range(row_square, row_square + 3):
            for j in range(col_square, col_square + 3):
                c = board[i][j]
                if c.isdigit() and c in contents:
                    return False
                contents[c] = 1
        return True
    
    def is_valid_row(self, board, row):
        contents = {}
        for c in board[row]:
            if c.isdigit() and c in contents:
                return False
            contents[c] = 1
        return True
            
    def is_valid_col(self, board, col):
        contents = {}
        for i in range(9):
            c = board[i][col]
            if c.isdigit() and c in contents:
                return False
            contents[c] = 1
        return True
    
    def backtrack(self, board, i=0, j=0):
        self.counter += 1
        # i = row
        # j = col

        # Check for valid soluiton
        valid = True
        for row in range(9):
            for col in range(9):
                if not board[row][col].isdigit():
                    valid = False
        
        if valid:
            self.solutions.append(board)
            return True
        
        if board[i][j].isdigit():
            if j == 8:
                return self.backtrack(board, i + 1, 0)
            else:
                return self.backtrack(board, i, j+1)
        
        # On a "." position
        for num in range(1, 10):
            board[i][j] = str(num)
            # check validity
            if self.is_valid_row(board, i) and self.is_valid_col(board, j) and self.is_valid_square(board, i, j):
                self.backtrack(board, i, j)
            
            board[i][j] = "."
        #print("End of function")
        return False
    
    def is_board_valid(self, board):
        for i in range(9):
            for j in range(9):
                if not (self.is_valid_row(board, i) and self.is_valid_col(board, j) and self.is_valid_square(board, i, j)):
                    return False

        return True
    
    def solve_sudoku(self, board):
        self.solutions = []
        self.counter = 0

        if self.is_board_valid(board):
            #self.board = board
            self.backtrack(board, 0, 0)
        else:
            print("board in invalid state initial board")
        return len(self.solutions)


In [261]:
solver = Sudoku_Solver()
print("Number solutions: ", solver.solve_sudoku(board))
print("Number of recursive calls:", solver.counter)

Number solutions:  23
Number of recursive calls: 193589
