# Notes
- simple article: https://www.geeksforgeeks.org/introduction-to-beam-search-algorithm/
- many beam search descriptions are given in terms of NLP context
- I think A* might not be feasible for large graphs, but I'm not sure...
- It'd be nice to code up the Clifford problem in an object-oriented way that abstracts the graph structure, so I could use any standard graph traversal algorithm

```
Start 
Take the inputs 
NODE = Root_Node & Found = False
If : Node is the Goal Node,
     Then Found = True, 
Else : 
     Find SUCCs of NODE, if any, with their estimated costs &
     store them in OPEN List 
While (Found == false & able to proceed further), do
{
     Sort OPEN List
     Select top W elements from OPEN list and put it in
     W_OPEN list and empty the OPEN list.
     for each NODE from W_OPEN list
     {
         if NODE = Goal,
             then FOUND = true 
         else
             Find SUCCs of NODE. If any with its estimated
             cost & Store it in OPEN list
     }
}
If FOUND = True,
    then return Yes
else
    return No
Stop
```

In [1]:
import tqdm
import numpy as np
import torch
import rubiks.clifford as cl
import rubiks.lgf as lgf
from qiskit.quantum_info import Clifford

In [2]:
# --- Experiment configuration -------------------------------------------
num_qubits = 3

SEED = 123
use_qiskit = False
device = torch.device('cpu')
drop_phase_bits = True
scaling = 'log-linear'
# Directory holding the trained checkpoint for this (num_qubits, scaling) setup.
data_dir = f"data/data_n_{num_qubits}_drop_phase_bits_scaling_{scaling}/"
# Upper bound passed to cl.Problem as `high` by the search code below.
high = cl.max_random_sequence_length(num_qubits, scaling)

# --- Model --------------------------------------------------------------
lgf_model = lgf.LGFModel(
    num_qubits=num_qubits,
    device=device,
    rng=np.random.default_rng(SEED),
    hidden_layers=[32, 16, 4, 2],
    drop_phase_bits=drop_phase_bits,
    use_qiskit=use_qiskit,
).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
lgf_model.load_state_dict(torch.load(data_dir + "checkpoint", map_location=device))

<All keys matched successfully>

In [65]:
def beam_search(initial_state: Clifford, beam_width=1, max_count=100):
    """Beam search from ``initial_state`` towards the identity tableau.

    At every iteration each node in the beam is expanded into its neighbors
    (``Problem.generate_neighbors``), the neighbors are scored with the
    module-level ``lgf_model``, and the ``beam_width`` unvisited neighbors
    with the smallest score form the next beam.

    Relies on the notebook globals ``lgf_model``, ``high``,
    ``drop_phase_bits`` and ``device``.

    Args:
        initial_state: Starting state handed to ``cl.Problem``.
        beam_width: Maximum number of nodes kept in the beam per iteration.
        max_count: Maximum number of beam iterations.

    Returns:
        ``(count, visited_nodes)`` if the identity was found among the
        neighbors of a beam node, where ``count`` is the number of beam
        iterations completed *before* the one that found it (so ``0`` means
        the identity is a direct neighbor of ``initial_state``) and
        ``visited_nodes`` is the set of bitstrings visited so far.
        ``(count, None)`` if the beam emptied or ``max_count`` was reached.
    """
    # initialize the starting node
    problem = cl.Problem(
        lgf_model.num_qubits,
        high=high,
        seed=2305843009213693951 + 0,
        initial_state=initial_state,
    )
    # Use the problem's own qubit count everywhere rather than a notebook
    # global, so the search always matches the model it scores with.
    n_qubits = problem.num_qubits
    identity_array = 1 * cl.sequence_to_tableau([], n_qubits).tableau

    # perform the search
    start_str = problem.to_bitstring()
    beam = [start_str]
    visited_nodes = {start_str}
    count = 0
    while beam and count < max_count:

        # bitstring -> LGF score for every new, distinct neighbor found in
        # this iteration. Keying by bitstring removes duplicates reached from
        # several beam nodes, which would otherwise waste beam slots.
        candidates = {}
        for node_str in beam:

            # grab the neighbors
            node = cl.Problem(num_qubits=n_qubits, initial_state=node_str)
            neighbors = node.generate_neighbors()  # (M, 2N, 2N+1) array

            # check for solution
            for tableau in neighbors:
                if np.array_equal(tableau, identity_array):
                    return count, visited_nodes

            # compute LGF values for neighbors
            neighbors_torch = torch.tensor(
                neighbors, dtype=torch.float32, device=device
            )
            if drop_phase_bits:
                neighbors_torch = neighbors_torch[:, :, :-1]  # drop phase bits if necessary
            neighbors_torch = torch.flatten(neighbors_torch, start_dim=1)
            with torch.no_grad():
                lgf_of_neighbors = lgf_model.forward(neighbors_torch).cpu().numpy()
            # NOTE(review): assumes the model emits one scalar per neighbor
            # (shape (M,) or (M, 1)) -- confirm against LGFModel.forward.
            scores = np.asarray(lgf_of_neighbors, dtype=float).reshape(-1)

            # record unvisited neighbors, iterating scores and tableaus in
            # lockstep so they can never get out of alignment
            for tableau, score in zip(neighbors, scores):
                neighbor_str = cl.Problem(
                    num_qubits=n_qubits, initial_state=Clifford(tableau)
                ).to_bitstring()
                if neighbor_str in visited_nodes or neighbor_str in candidates:
                    continue
                candidates[neighbor_str] = float(score)

        # keep the beam_width candidates with the smallest LGF value
        # (ties broken by bitstring, as in a plain (value, string) sort)
        ranked = sorted(candidates.items(), key=lambda item: (item[1], item[0]))
        beam = [node_str for node_str, _ in ranked[:beam_width]]

        # update visited nodes: every candidate is marked, including the ones
        # pruned from the beam, so they are not re-expanded later
        visited_nodes.update(candidates)

        count += 1

    return count, None

- Put this in the same format as the other hill-climbing function
- Compare for different values of beam_width
    - should agree with previous results for beam_width=1
    - should be better for larger beam_widths
- Can it be made faster?

In [68]:
# Build a start state from a random gate sequence of length `seq_length`,
# then run the beam search on it and report the outcome.
seq_length = 10
initial_state = cl.sequence_to_tableau(
    cl.random_sequence(
        np.random.default_rng(),
        seq_length=seq_length,
        num_qubits=num_qubits,
    ),
    num_qubits,
)
count, out = beam_search(
    initial_state=initial_state,
    beam_width=3,
    max_count=100,
)

# `out` is None exactly when the search gave up without reaching the goal.
print(
    f"No solution found after {count} steps"
    if out is None
    else f"Solution found after {count} steps"
)

Solution found after 4 steps
