In [4]:
%pip install --upgrade pip

In [5]:
%pip install gevent

In [6]:
from gevent import monkey
monkey.patch_all() # required for gevent to work properly in Jupyter notebooks

In [9]:
%pip install --upgrade celery[redis] numpy scipy numba matplotlib tqdm ipympl ffmpeg watermark circuitree
%load_ext watermark
seed = 2024

# Parallel MCTS with CircuiTree

MCTS is an iterative sampling algorithm in which the reward found in each iteration affects sampling in later iterations. While perfect parallel execution isn't possible, we can achieve quite good performance using the so-called lock-free method [[1]](https://doi.org/10.1007/978-3-642-12993-3_2), where multiple search threads on the same CPU (the *main node*) run MCTS concurrently, taking turns editing the search graph. We will implement this in detail later in the tutorial. In brief, instead of computing the (usually expensive) reward function itself, each search thread on the main node sends a request to a group of worker CPUs (the *worker node*) that does the actual computation. While that thread waits for the result, other search threads can use the main CPU. As long as each reward evaluation takes significantly longer than sending the request and receiving the result, we should see a performance boost!

First let's watch a parallel search in action using an example case, a parallelized version of the "bistability" circuit search from Tutorial 1. Here, we will make each reward evaluation take 0.1 seconds longer by setting the flag `expensive=True`.

In [8]:
from tutorial_2_parallel_example import ParallelBistabilityTree
from time import perf_counter

# Create the tree search object
tree = ParallelBistabilityTree(root="ABC::")

# # Run the search sequentially, with an expensive reward function (17 minutes)
# tree.search_mcts(
#     n_steps=10_000, run_kwargs=dict(expensive=True)  
# )

# # Search in parallel with 64 threads (ideally, 64x faster = 16 seconds)
# start_time = perf_counter()
# tree.search_mcts_parallel(
#     n_steps=10_000, n_threads=64, run_kwargs=dict(expensive=True)
# )
# end_time = perf_counter()

# print("Done!")
# print(f"Elapsed time: {end_time - start_time:.2f} seconds")
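
The timing estimates in the comments above follow from simple arithmetic. The sketch below reproduces them under the assumption that the run time is dominated by the extra 0.1 seconds per reward evaluation; it ignores the base cost of the reward function and any search or messaging overhead. The variable names are illustrative only.

```python
# Back-of-the-envelope estimate of the ideal speed-up (illustrative names).
n_steps = 10_000              # search iterations, as in the cell above
extra_seconds_per_step = 0.1  # delay added by `expensive=True`
n_threads = 64                # search threads in the parallel run

# Sequential: every iteration waits for its own reward evaluation.
sequential_seconds = n_steps * extra_seconds_per_step

# Ideal parallel: the waits of all threads overlap perfectly.
ideal_parallel_seconds = sequential_seconds / n_threads

print(f"Sequential: about {sequential_seconds / 60:.1f} minutes")      # 16.7
print(f"Ideal parallel: about {ideal_parallel_seconds:.1f} seconds")   # 15.6
```

In practice the speed-up will be lower, since the worker node has a finite number of CPUs and every task pays a small messaging cost.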

## Parallel CircuiTree on a single machine

In order to parallelize the search on a local machine, we can nominate a group of CPUs in our own computer to be the worker node that performs reward function evaluations. We can coordinate the main and worker nodes using a *producer-consumer* queue. The main node will produce tasks (calls to the reward function) that get added to the queue, and the worker node will consume tasks from the queue and return the result to a shared database where the main node can look up the result. We'll manage this task queue with the Python utility `celery`. 

Here's a schematic of how that infrastructure looks.

![Local-Infrastructure](./local_parallel_infrastructure.png)
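
To make the producer-consumer pattern described above concrete, here is a minimal sketch that uses only the Python standard library. It is not how `celery` works internally: a `queue.Queue` stands in for the task queue, a dictionary stands in for the shared database, and `toy_reward` and the example state strings are hypothetical placeholders.

```python
import queue
import threading

task_queue = queue.Queue()       # stand-in for the task queue
results = {}                     # stand-in for the shared result database
results_lock = threading.Lock()  # protects `results` across worker threads


def toy_reward(state: str) -> float:
    """Hypothetical placeholder for an expensive reward function."""
    return float(len(state))


def worker() -> None:
    """Consume tasks until a `None` sentinel arrives."""
    while True:
        item = task_queue.get()
        if item is None:  # sentinel: no more work for this worker
            task_queue.task_done()
            break
        task_id, state = item
        value = toy_reward(state)
        with results_lock:
            results[task_id] = value
        task_queue.task_done()


# The "worker node": a small pool of consumer threads.
workers = [threading.Thread(target=worker) for _ in range(4)]
for w in workers:
    w.start()

# The "main node": produce one task per state (illustrative strings).
states = ["ABC::", "ABC::AAa", "ABC::ABi_BAi"]
for task_id, state in enumerate(states):
    task_queue.put((task_id, state))

# One sentinel per worker, then wait for everything to finish.
for _ in workers:
    task_queue.put(None)
task_queue.join()
for w in workers:
    w.join()

print(results)
```

The main node never calls the reward function itself. It only enqueues work and later looks up the result by task ID, which is the role Redis plays in the rest of this tutorial.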

### The 4 steps to setting up a local parallel search
1) Set up a simple database.
2) Package the reward function into a `celery` app.
3) Define a `CircuiTree` subclass that calls the reward function in (2).
4) Launch some workers.

### 1. Database installation

We will be using a lightweight database called Redis (https://redis.io/).

If you are running this notebook on Colab, you can uncomment and run the next code block to build Redis from source. If you are running it on your own machine and Redis is not installed yet, please skip the next code block and follow the installation instructions [here](https://redis.io/docs/latest/operate/oss_and_stack/install/install-redis/) instead. 

If you are using a Redis server hosted somewhere else, you can skip the next code block and change the `host` and `port` arguments later in the notebook to point to your server.

In [4]:
####################################################################################
### If you are using Colab, uncomment and run this :) ##############################
####################################################################################
# # Download the latest stable release and make from source (can take a few minutes)
# !curl -o ./redis-server.tar.gz -fsSL https://download.redis.io/redis-stable.tar.gz
# !tar -xf ./redis-server.tar.gz
# !cd ./redis-stable && make # Can take a while (5+ minutes)
# !/content/redis-stable/src/redis-server --daemonize yes

**Be sure to test your installation!!**

In [11]:
## This should return "PONG"

# Colab notebook users
# !/content/redis-stable/src/redis-cli ping 

# Local installations
!redis-cli ping 

### 2. Making a `celery` app with the reward function 
The app is a Python script that tells `celery` where the database is and which tasks it will be managing. For instance, here is the script for the bistability app.

In [12]:
from pathlib import Path

print(Path("tutorial_app.py").read_text())

We use the `Celery` class to create an app that uses the Redis database to pass messages (the `broker` argument) and store results (the `backend` argument). The URL here points to the default location for a local database (port `6379` on `localhost`). Any function with the `@app.task` decorator becomes a `celery` *task* that can be executed by a worker. We'll see how this looks in the next section.

### 3. Calling the reward function as a `celery` task

Unlike a normal function call, a call to a `celery` task is *asynchronous*. This means that when the main node calls the function, it dispatches a task to the workers, and the result can be requested later. This uses different syntax: instead of running `reward = get_reward(...)` directly, we run `result = get_reward_celery.delay(...)` to dispatch the task from the main node to the workers. This immediately returns an `AsyncResult` object that can be inspected to monitor progress. Then, once we need the reward, we call `result.get()` and wait for it to arrive. While one thread is waiting for the reply, another thread can take over the main node and run a search iteration. 
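
The same dispatch-now, collect-later pattern can be tried without a broker or workers. The sketch below is an analogy built on the standard library's `concurrent.futures`, not `celery` itself: `pool.submit(...)` plays the role of `.delay(...)`, `future.result()` plays the role of `.get()`, and `slow_reward` is a hypothetical stand-in for the reward function.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_reward(state: str) -> float:
    """Hypothetical stand-in for a reward that takes a while to compute."""
    time.sleep(0.2)
    return float(len(state))


with ThreadPoolExecutor(max_workers=2) as pool:
    start = time.perf_counter()

    # Dispatch: returns a future immediately, like `.delay(...)`.
    future = pool.submit(slow_reward, "ABC::")
    dispatch_time = time.perf_counter() - start

    # The caller is free to do other work while the reward is computed.
    other_work = sum(range(1000))

    # Collect: blocks until the value is ready, like `.get()`.
    reward = future.result()
    total_time = time.perf_counter() - start

print(f"Dispatch took {dispatch_time:.4f} s, reward arrived after {total_time:.2f} s")
```

Dispatching returns almost instantly, and the wait only happens at the point where the value is actually needed.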

All we need to do in this step is make a new subclass of `CircuiTree` that runs the reward function using the app. Here's what that looks like in our bistability example.

In [7]:
print(Path("tutorial_2_parallel_example.py").read_text())

Python's `threading` module works well for up to a few dozen threads, but we want to run a search with hundreds to thousands of threads. For this, we will use the `gevent` module, which re-defines many standard-library functions in order to support its highly scalable "green threads." Re-defining built-in code is called "monkey patching," and it has to run at the very top of the file where we define the class, before anything else is imported. (We also ran `monkey.patch_all()` at the beginning of this notebook. This is only necessary for notebooks, not for scripts.)
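
As a rough illustration of how cheaply green threads scale, the sketch below spawns 500 greenlets that each wait 0.1 seconds. It is separate from the `CircuiTree` implementation, and `wait_for_reward` is a hypothetical stand-in for a thread waiting on a worker. It uses `gevent.sleep` directly, so it does not depend on monkey patching.

```python
import time

import gevent


def wait_for_reward(i: int) -> int:
    """Hypothetical stand-in for a search thread waiting on a worker."""
    gevent.sleep(0.1)  # cooperative wait: other greenlets run in the meantime
    return i * 2


start = time.perf_counter()

# Spawn 500 green threads and wait for all of them to finish.
greenlets = [gevent.spawn(wait_for_reward, i) for i in range(500)]
gevent.joinall(greenlets)

elapsed = time.perf_counter() - start
values = [g.value for g in greenlets]

print(f"{len(values)} greenlets finished in {elapsed:.2f} s")
```

Run one after another, these waits would add up to 50 seconds. Because they overlap, the whole batch should finish in a small fraction of that time.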

### 4. Launching a worker node

We can launch a worker node using `celery`'s command line interface. To do so, open a separate terminal, `cd` to the folder with the app, and run the following command, replacing the `XX` with the number of CPUs to use. If you are using a virtual environment, be sure to activate that first. (If you aren't, you should be!)

```
# Launch a worker with 'XX' CPUs, specifying the app with the `.app` suffix
celery --app tutorial_app.app worker --concurrency=XX 
```

---