# Notes on backtracking 

References:
1. https://cs.lmu.edu/~ray/notes/backtracking/ 
2. https://stackoverflow.com/questions/59121447/backtracking-to-find-n-element-vectors-whose-elements-add-up-to-less-than-k
3. Skiena, p. 231

```python
import numpy as np
from functools import partial
```

## Problem Statement

**Problem: find all n-element vectors such that the sum of their elements is less than or equal to some number K. Each element in the vector is a non-negative integer.**

For example, let's set n = 3 and K = 10. Then: 

[1, 1, 1] is a solution (sum 3)

[9, 0, 0] is a solution (sum 9)

[9, 9, 0] is NOT a solution (sum 18 > 10)

[5, 5, 1] is NOT a solution (sum 11 > 10)
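To make the definition concrete, here is a minimal sketch that checks the four examples directly and then brute-forces every 3-element vector of digits 0 to 9 for K = 10. The helper name `is_solution` is hypothetical (not part of these notes). This exhaustive enumeration over all 10^n candidates is exactly the work that backtracking tries to avoid.

```python
from itertools import product

def is_solution(vector, K):
    """Return True if the entries of `vector` sum to at most K (hypothetical helper)."""
    return sum(vector) <= K

# The four examples from the problem statement (n = 3, K = 10)
print(is_solution([1, 1, 1], 10))  # True
print(is_solution([9, 0, 0], 10))  # True
print(is_solution([9, 9, 0], 10))  # False (sum is 18)
print(is_solution([5, 5, 1], 10))  # False (sum is 11)

# Brute force: test all 10**3 candidate vectors with digits 0-9
count = sum(1 for v in product(range(10), repeat=3) if is_solution(v, 10))
print(count)
```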

## Recursive implementation of backtracking

```python
def solve(values, safe_up_to, size):
    """Generates every solution to a backtracking problem.

    values     -- a sequence of values to try, in order. For a map coloring
                  problem, this may be a list of colors, such as ['red',
                  'green', 'yellow', 'purple']
    safe_up_to -- a function of one argument, the solution list, that
                  returns whether the values assigned so far (unfilled slots
                  hold None) satisfy the problem constraints.
    size       -- the total number of "slots" you are trying to fill

    Returns a generator that yields each complete solution as a NumPy array.
    """
    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value          # tentatively fill this slot
            if safe_up_to(solution):            # only continue if still feasible
                if position >= size - 1:        # every slot is filled
                    yield np.array(solution)
                else:
                    yield from extend_solution(position + 1)
        solution[position] = None               # undo the choice (backtrack)

    return extend_solution(0)
```
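To see how much work the pruning saves, here is a sketch that repeats the loop from `solve()` but counts calls to the safety check instead of yielding solutions. The names `count_checks` and `sums_to_at_most_4` are hypothetical, introduced only for this illustration; the checker is a one-argument stand-in for the `safe_up_to` function defined below, with the target fixed at 4.

```python
def count_checks(values, safe_up_to, size):
    """Run the same backtracking loop as solve(), counting checks and solutions."""
    solution = [None] * size
    stats = {"checks": 0, "solutions": 0}

    def extend_solution(position):
        for value in values:
            solution[position] = value
            stats["checks"] += 1                # one call to the safety check
            if safe_up_to(solution):
                if position >= size - 1:
                    stats["solutions"] += 1
                else:
                    extend_solution(position + 1)
        solution[position] = None               # backtrack

    extend_solution(0)
    return stats

def sums_to_at_most_4(partial_solution):
    # Hypothetical stand-in: ignore unfilled (None) slots, require sum <= 4
    return sum(x for x in partial_solution if x is not None) <= 4

stats = count_checks(range(10), sums_to_at_most_4, 7)
print(stats["solutions"])                                      # 330
print(stats["checks"], "checks vs", 10**7, "brute-force vectors")  # 4620 checks vs 10000000 brute-force vectors
```

Because any prefix whose sum already exceeds 4 is abandoned immediately, the search performs 4,620 checks rather than examining all 10^7 digit vectors.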




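Note that one of the arguments to the generic "engine", `solve()`, is a custom function specific to the problem, named `safe_up_to()`. The engine calls it after each tentative assignment and abandons (prunes) any branch for which it returns `False`.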

Here's the custom function we define: 

```python
def safe_up_to(target, partial_solution):
    """
    Checks that a partial solution sums to at most `target`.

    The partial solution is passed to the function as a list whose unfilled
    slots are None, e.g. [1, 5, None, None]; those slots are ignored.
    """
    # Sum only the slots that have been filled in so far
    return sum(x for x in partial_solution if x is not None) <= target
```
    

Finally, here is how we combine the two to enumerate all solutions.

### Print all solutions 

```python
# Find all 7-element vectors whose elements sum to 4 or less (each element is a 1-digit integer):
for sol in solve(values=range(10), safe_up_to=partial(safe_up_to, 4), size=7):
    print(sol, sol.sum())
```

```text
[0 0 0 0 0 0 0] 0
[0 0 0 0 0 0 1] 1
[0 0 0 0 0 0 2] 2
[0 0 0 0 0 0 3] 3
[0 0 0 0 0 0 4] 4
[0 0 0 0 0 1 0] 1
[0 0 0 0 0 1 1] 2
[0 0 0 0 0 1 2] 3
[0 0 0 0 0 1 3] 4
[0 0 0 0 0 2 0] 2
[0 0 0 0 0 2 1] 3
[0 0 0 0 0 2 2] 4
[0 0 0 0 0 3 0] 3
[0 0 0 0 0 3 1] 4
[0 0 0 0 0 4 0] 4
[0 0 0 0 1 0 0] 1
[0 0 0 0 1 0 1] 2
[0 0 0 0 1 0 2] 3
[0 0 0 0 1 0 3] 4
[0 0 0 0 1 1 0] 2
[0 0 0 0 1 1 1] 3
[0 0 0 0 1 1 2] 4
[0 0 0 0 1 2 0] 3
[0 0 0 0 1 2 1] 4
[0 0 0 0 1 3 0] 4
[0 0 0 0 2 0 0] 2
[0 0 0 0 2 0 1] 3
[0 0 0 0 2 0 2] 4
[0 0 0 0 2 1 0] 3
[0 0 0 0 2 1 1] 4
[0 0 0 0 2 2 0] 4
[0 0 0 0 3 0 0] 3
[0 0 0 0 3 0 1] 4
[0 0 0 0 3 1 0] 4
[0 0 0 0 4 0 0] 4
[0 0 0 1 0 0 0] 1
[0 0 0 1 0 0 1] 2
[0 0 0 1 0 0 2] 3
[0 0 0 1 0 0 3] 4
[0 0 0 1 0 1 0] 2
[0 0 0 1 0 1 1] 3
[0 0 0 1 0 1 2] 4
[0 0 0 1 0 2 0] 3
[0 0 0 1 0 2 1] 4
[0 0 0 1 0 3 0] 4
[0 0 0 1 1 0 0] 2
[0 0 0 1 1 0 1] 3
[0 0 0 1 1 0 2] 4
[0 0 0 1 1 1 0] 3
[0 0 0 1 1 1 1] 4
[0 0 0 1 1 2 0] 4
[0 0 0 1 2 0 0] 3
[0 0 0 1 2 0 1] 4
[0 0 0 1 2 1 0] 4
[0 0 0 1 3 0 0] 4
...
```

### Save all solutions in a list

```python
# Find all 7-element vectors whose elements sum to 4 or less (each element is a 1-digit integer):
list_of_solutions = []
for sol in solve(values=range(10), safe_up_to=partial(safe_up_to, 4), size=7):
    list_of_solutions.append(sol)

len(list_of_solutions)
```

```text
330
```
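The count of 330 can be checked independently. By the standard stars-and-bars argument, the number of n-tuples of non-negative integers whose sum is at most K is C(K + n, n) (add one slack variable so the sum is exactly K). For n = 7 and K = 4 no entry can exceed 4, so the 0 to 9 digit limit in the code above never binds and the formula applies. A small sketch using `math.comb` (Python 3.8+) plus a brute-force cross-check:

```python
from itertools import product
from math import comb

n, K = 7, 4

# Stars and bars: C(K + n, n) tuples of n non-negative integers sum to at most K
formula = comb(K + n, n)

# Brute force over entries 0..K (larger entries can never appear in a solution)
brute = sum(1 for v in product(range(K + 1), repeat=n) if sum(v) <= K)

print(formula, brute)  # 330 330
```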

# Todo: 

1. Explain how the `partial()` function from the `functools` module works.
2. Why use `yield` in a function instead of `return`?
3. What class of object does `solve()` return if we call it without a `for` loop?
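As a starting point for items 2 and 3, a small sketch with a hypothetical `count_up_to` function (not part of these notes): calling any function that contains `yield` returns a generator object, and its body runs lazily, one value per `next()` call. Since `solve()` returns `extend_solution(0)`, and `extend_solution` contains `yield`, calling `solve()` without a loop should likewise hand back a generator that produces solutions one at a time instead of building them all up front.

```python
import types
from itertools import islice

def count_up_to(limit):
    """Generator function: yields 0, 1, ..., limit - 1 on demand."""
    for i in range(limit):
        yield i

def list_up_to(limit):
    """Ordinary function: builds and returns the whole list at once."""
    return list(range(limit))

gen = count_up_to(3)
print(type(gen))                             # <class 'generator'>
print(isinstance(gen, types.GeneratorType))  # True
print(next(gen), next(gen), next(gen))       # 0 1 2
# A fourth next(gen) would raise StopIteration

# Laziness: take the first 5 values of a huge range without computing the rest
print(list(islice(count_up_to(10**12), 5)))  # [0, 1, 2, 3, 4]
print(list_up_to(3))                         # [0, 1, 2]
```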
