# Run learners in job scripts

## Define the learners

The `_learners.py` module must define the following variables:
* `learners`: a list of learners
* `combos`: a list of dicts with the parameters that describe each learner
* `fnames`: a list of filenames, one for each learner

In [None]:
%%writefile _learners.py

import adaptive
from functools import partial

import funcs

# Fixed system parameters, passed unchanged to every learner's function.
syst_pars = dict(a=4, L=40, r=20, shape="square", dim=3)

# Parameters shared by every learner; presumably `funcs.constants_InAs` holds
# InAs material constants (defined in `funcs`, not visible here).
params = dict(g=50, mu=100, B_y=0, B_x=0, **funcs.constants_InAs)

# Values to scan; every combination of them becomes one learner.
Ls = [1000, 2000, 3000, 5000, 10000]
l_Rs = [200] #np.geomspace(30, 1000, 10).tolist() + [np.inf]
l_es = [20, 50, 100, 200, 300, 500]
rs = [25]

# List of dicts, one per combination of the values above
# (here 6 * 5 * 1 * 1 = 30 combos).
combos = adaptive.utils.named_product(l_e=l_es, L=Ls, r=rs, l_R=l_Rs)

learners = []
fnames = []
folder = "data/q_phi_scaling_square_wire_new/"
for combo in combos:
    # Freeze everything except the learner's x coordinate; presumably
    # `x_name='B_z'` makes B_z that coordinate -- confirm in funcs.conductance_1D.
    f = partial(
        funcs.conductance_1D,
        x_name='B_z',
        value_dict=combo,
        syst_pars=syst_pars,
        params=params,
    )
    learner = adaptive.AverageLearner1D(f, bounds=(0, 0.25))
    # AverageLearner1D tuning knobs (see the adaptive documentation).
    learner.average_priority = 1
    learner.min_seeds_per_point = 20
    # NOTE(review): `folder` already ends in "/", so this yields names like
    # "data/.../_{'l_e': 20, 'L': 1000, 'r': 25, 'l_R': 200}" (leading
    # underscore, dict repr with spaces and quotes) -- confirm this is intended.
    fnames.append(f"{folder}_{combo}")
    learners.append(learner)

# Combine all learners into one, e.g. to load/plot them together in the
# notebook; the 'cycle' strategy takes points from each learner in turn.
learner = adaptive.BalancingLearner(learners, strategy='cycle')

In [None]:
# Execute the previous code block and plot the learners.
# The star import provides `adaptive`, `learner` (the BalancingLearner) and
# `fnames` from _learners.py.
from _learners import *
adaptive.notebook_extension()
learner.load(fnames)  # load each child learner's data from its own fname
learner.plot()

## Define helper functions for both the server (headnode) and the client (nodes)

In [None]:
%%writefile slurm.py

import getpass
import subprocess
import textwrap

def make_sbatch(name, cores, executable="run_learner.py", env="py37_min"):
    """Return the text of an sbatch script that runs *executable* under MPI.

    The job is named *name*, requests *cores* tasks, writes its output to
    ``<name>.out`` and is never requeued. BLAS/OpenMP threading is pinned to
    one thread per task, and *executable* is launched through
    ``mpi4py.futures`` with the ``python3`` of the conda environment *env*.
    """
    interpreter = f"~/miniconda3/envs/{env}/bin/python3"
    return textwrap.dedent(
        f"""\
        #!/bin/bash
        #SBATCH --job-name {name}
        #SBATCH --ntasks {cores}
        #SBATCH --output {name}.out
        #SBATCH --no-requeue

        export MKL_NUM_THREADS=1
        export OPENBLAS_NUM_THREADS=1
        export OMP_NUM_THREADS=1

        export MPI4PY_MAX_WORKERS=$SLURM_NTASKS
        srun -n $SLURM_NTASKS --mpi=pmi2 {interpreter} -m mpi4py.futures {executable}
        """
    )


def check_running(me_only=True):
    """Return the PENDING/RUNNING SLURM jobs, keyed by job id.

    Parameters
    ----------
    me_only : bool
        If True (default), only query jobs of the current user.

    Returns
    -------
    dict
        ``{job_id: dict(job_name=..., state=..., n_nodes=..., reason_list=...)}``;
        all values are strings except ``n_nodes`` (an int).

    Raises
    ------
    RuntimeError
        If ``squeue`` exits non-zero or prints one of its known error messages.
    """
    cmd = [
        "/usr/bin/squeue",
        r'--Format=",jobid:100,name:100,state:100,numnodes:100,reasonlist:400,"',
        "--noheader",
        "--array",
    ]
    if me_only:
        username = getpass.getuser()
        cmd.append(f"--user={username}")
    proc = subprocess.run(cmd, text=True, capture_output=True)

    # squeue reports errors on stderr, so look for the messages in both streams.
    output = proc.stdout + proc.stderr
    if (
        "squeue: error" in output
        or "slurm_load_jobs error" in output
        or proc.returncode != 0
    ):
        raise RuntimeError("SLURM is too busy.")

    return _parse_squeue_output(proc.stdout)


def _parse_squeue_output(stdout):
    """Parse ``squeue`` output into ``{job_id: info}`` for PENDING/RUNNING jobs."""
    allowed = ("PENDING", "RUNNING")
    running = {}
    for line in stdout.splitlines():
        # The reason list is the last column and may itself contain spaces
        # (e.g. "(ReqNodeNotAvail, Reserved for maintenance)"), so split at
        # most 4 times to keep it in one piece.
        fields = line.strip().split(maxsplit=4)
        if not fields:
            continue  # blank line
        job_id, job_name, state, n_nodes, *rest = fields
        if state not in allowed:
            continue
        running[job_id] = dict(
            job_name=job_name,
            state=state,
            n_nodes=int(n_nodes),
            reason_list=rest[0] if rest else "",
        )
    return running

In [None]:
%%writefile server_support.py

import asyncio
import os
import subprocess
import time

from tinydb import TinyDB, Query
import zmq
import zmq.asyncio
from concurrent.futures import ProcessPoolExecutor

from slurm import make_sbatch, check_running


# asyncio-compatible ZeroMQ context; `manage_database` creates its REP socket from it.
ctx = zmq.asyncio.Context()


def dispatch(request, db_fname):
    """Handle one ``(request_type, request_arg)`` message from a worker.

    * ``("start", job_id)``: reserve a learner for the worker's SLURM job id
      and return whatever ``choose_combo`` returns.
    * ``("stop", fname)``: mark the learner saved under ``fname`` as finished
      and return whatever ``done_with_learner`` returns.

    Any other request type is reported on stdout and ``None`` is returned.
    """
    kind, payload = request

    if kind == "start":
        # `payload` is the worker's SLURM job id, stored with the learner.
        return choose_combo(db_fname, payload)

    if kind == "stop":
        # `payload` is the fname the worker received on "start".
        return done_with_learner(db_fname, payload)

    print(f"unknown request type: {kind}")


async def manage_database(address, db_fname):
    """Serve worker requests on a ZeroMQ REP socket bound to *address*.

    Each received request is passed to ``dispatch`` and its return value is
    sent back as the reply. Runs until cancelled; the socket is closed on exit.
    """
    socket = ctx.socket(zmq.REP)
    socket.bind(address)
    try:
        while True:
            request = await socket.recv_pyobj()
            try:
                reply = dispatch(request, db_fname)
            except Exception as err:
                # A REP socket must answer every request before it can receive
                # the next one. Previously an exception here ended this task,
                # leaving every later worker blocked forever; instead, report
                # it and reply with None so the server keeps serving.
                print(f"error while handling request {request!r}: {err!r}")
                reply = None
            await socket.send_pyobj(reply)
    finally:
        socket.close()


async def manage_jobs(job_names, db_fname, ioloop, cores=8, interval=30):
    """Keep a SLURM job queued or running for every name in *job_names*.

    Every *interval* seconds:
    1. query SLURM with ``check_running``;
    2. reset the ``job_id`` of database entries whose job is no longer
       PENDING/RUNNING (``update_db``);
    3. submit (``start_job`` with *cores*) every job name missing from the queue.

    Submissions run in a process pool so the retrying ``sbatch`` calls do not
    block the event loop. Runs until cancelled.
    """
    with ProcessPoolExecutor() as ex:
        while True:
            try:
                running = check_running()
            except RuntimeError as err:
                # squeue can fail transiently when SLURM is busy; previously this
                # ended the task for good. Skip this tick and retry later.
                print(f"manage_jobs: {err} Retrying in {interval} s.")
            else:
                update_db(db_fname, running)  # in case some jobs died
                running_job_names = {job["job_name"] for job in running.values()}
                for job_name in job_names:
                    if job_name not in running_job_names:
                        await ioloop.run_in_executor(ex, start_job, job_name, cores)
            await asyncio.sleep(interval)


def create_empty_db(db_fname, fnames, combos):
    """(Re)create the TinyDB file *db_fname* with one entry per learner.

    Each entry holds the learner's ``fname`` and ``combo``, ``job_id=None``
    (not assigned to a job) and ``is_done=False``. Items are paired with
    ``zip``, so extra items in the longer of *fnames*/*combos* are ignored.
    An existing file at *db_fname* is deleted first.
    """
    rows = []
    for fname, combo in zip(fnames, combos):
        rows.append({"fname": fname, "combo": combo, "job_id": None, "is_done": False})

    if os.path.exists(db_fname):
        os.remove(db_fname)

    with TinyDB(db_fname) as db:
        db.insert_multiple(rows)


def update_db(db_fname, running):
    """If the job_id isn't running anymore, replace it with None.

    *running* is a mapping keyed by SLURM job id (as returned by
    ``check_running``); every entry whose ``job_id`` is not a key of it gets
    ``job_id=None``.
    """
    with TinyDB(db_fname) as db:
        vanished = []
        for doc in db.all():
            if doc["job_id"] not in running:
                vanished.append(doc.doc_id)
        db.update({"job_id": None}, doc_ids=vanished)


def choose_combo(db_fname, job_id):
    """Reserve an unfinished, unassigned learner for SLURM job *job_id*.

    Returns
    -------
    tuple or None
        ``(fname, combo)`` of the reserved entry, or ``None`` when every
        learner is either finished or already assigned to a job. Returning
        None avoids the previous AttributeError on ``entry.doc_id``, which
        took down the database server.
    """
    Entry = Query()
    with TinyDB(db_fname) as db:
        # `done_with_learner` also resets job_id to None, so without the
        # `is_done` filter finished learners would be handed out again.
        free = (Entry.job_id == None) & (Entry.is_done == False)  # noqa: E711, E712
        entry = db.get(free)
        if entry is None:
            return None
        db.update({"job_id": job_id}, doc_ids=[entry.doc_id])
    return entry["fname"], entry["combo"]


def done_with_learner(db_fname, fname):
    """Mark every entry stored under *fname* as finished and unassigned.

    Sets ``is_done=True`` and ``job_id=None`` on the matching entries and
    returns None.
    """
    matches_fname = Query().fname == fname
    with TinyDB(db_fname) as db:
        db.update({"job_id": None, "is_done": True}, matches_fname)



def start_job(name, cores=8, *, job_script_function=make_sbatch):
    """Write ``<name>.sbatch`` and submit it with ``sbatch``.

    The script text is ``job_script_function(name, cores)``. Submission is
    retried, with a 0.5 s pause after every attempt, until ``sbatch`` exits
    with status 0; there is no retry limit. ``sbatch``'s stderr is captured
    and discarded.
    """
    script_path = name + ".sbatch"
    with open(script_path, "w") as fh:
        fh.write(job_script_function(name, cores))

    submit_cmd = f"sbatch {script_path}".split()
    while True:
        result = subprocess.run(submit_cmd, stderr=subprocess.PIPE)
        time.sleep(0.5)
        if result.returncode == 0:
            break

In [None]:
%%writefile client_support.py

import os
import zmq

from _learners import learners, combos

# ZeroMQ context shared by `get_learner` and `tell_done` for their REQ sockets.
ctx = zmq.Context()

def get_learner(url):
    """Ask the server at *url* which learner this job should run.

    Sends ``("start", job_id)``, where ``job_id`` is ``$SLURM_JOB_ID`` (or
    ``"UNKNOWN"`` outside SLURM), and looks up the learner whose combo equals
    the one in the reply.

    Returns
    -------
    (learner, fname)

    Raises
    ------
    RuntimeError
        If the server replies with ``None`` (no learner to hand out).
    LookupError
        If the returned combo matches none of ``_learners.combos``.
    """
    with ctx.socket(zmq.REQ) as socket:
        socket.connect(url)
        job_id = os.environ.get("SLURM_JOB_ID", "UNKNOWN")
        socket.send_pyobj(("start", job_id))
        reply = socket.recv_pyobj()
    if reply is None:
        raise RuntimeError(f"The server at {url} returned no learner to run.")
    fname, combo = reply
    # Use `next` with a default so a missing combo gives a clear error instead
    # of leaking a bare StopIteration.
    learner = next((lrn for lrn, c in zip(learners, combos) if c == combo), None)
    if learner is None:
        raise LookupError(f"No learner matches combo {combo!r} (fname={fname!r}).")
    return learner, fname


def tell_done(url, fname):
    """Tell the server at *url* that the learner saved as *fname* is finished.

    Sends ``("stop", fname)`` and waits for the reply, whose content is ignored.
    """
    message = ("stop", fname)
    with ctx.socket(zmq.REQ) as sock:
        sock.connect(url)
        sock.send_pyobj(message)
        # A REQ socket must receive a reply after every send.
        sock.recv_pyobj()

## The Python script that is being run in the job

In [None]:
# Make sure to use the headnode's IP below.
# Prints a candidate server address: this machine's IP plus a port that pyzmq
# reports as free. Put it into `url` in run_learner.py and use the same port
# in the address passed to `server_support.manage_database`.
import socket
import zmq.ssh
ip = socket.gethostbyname(socket.gethostname())
port = zmq.ssh.tunnel.select_random_ports(1)[0]
print(f'tcp://{ip}:{port}')

In [None]:
%%writefile run_learner.py

import adaptive
from mpi4py.futures import MPIPoolExecutor

import client_support

# Address of the database server started from the notebook; its port must
# match the one passed to `server_support.manage_database`.
url = "tcp://10.76.0.5:57681"

if __name__ == "__main__":
    # Ask the server which learner (and data file) this job should work on.
    learner, fname = client_support.get_learner(url)
    learner.load(fname)  # continue from data previously saved under `fname`
    ex = MPIPoolExecutor()
    runner = adaptive.Runner(
        learner,
        executor=ex,
        goal=None,
        shutdown_executor=True,
        ioloop=None,
        retries=10,
        raise_if_retries_exceeded=False,
    )
    runner.start_periodic_saving(dict(fname=fname), interval=600)
    # NOTE(review): with `goal=None` the runner has no stopping condition, so
    # presumably it runs until the SLURM job is killed and the lines below are
    # only reached if `runner.task` completes normally -- confirm intent.
    runner.ioloop.run_until_complete(runner.task)
    # The periodic saver writes only every `interval` seconds; save once more
    # so the final results are not lost.
    learner.save(fname)
    # client_support defines `tell_done`; the previous call to the
    # non-existent `client_support.is_done` raised AttributeError.
    client_support.tell_done(url, fname)

## Import the files that were created

In [None]:
import asyncio
from importlib import reload

from pprint import pprint
from tinydb import TinyDB

# Importing run_learner also imports mpi4py and client_support (which imports
# _learners and zmq), so those must be installed here too.
import server_support, _learners, run_learner

# Reload so edits written by the %%writefile cells above take effect without
# restarting the kernel. NOTE: `slurm` and `client_support` are not reloaded.
reload(server_support)
reload(_learners)
reload(run_learner)

db_fname = 'running.tinydb'

In [None]:
# Create a new database.
# Any existing `db_fname` is deleted first, so all job assignments and
# `is_done` flags are reset.
server_support.create_empty_db(db_fname, _learners.fnames, _learners.combos)

## Check the running learners
All entries whose `job_id` is `None` belong to jobs that are still `PENDING` or have not been scheduled yet (or, if `is_done` is `True`, the learner has finished).

In [None]:
# Print every database entry (fname, combo, job_id, is_done).
with TinyDB(db_fname) as db:
    pprint(db.all())

## Start the job scripts

In [None]:
# Get some unique names for the jobs (one job per learner).
job_names = [f"WAL-{i}" for i in range(len(_learners.learners))]

ioloop = asyncio.get_event_loop()

# Serve worker requests; the port must match `url` in run_learner.py.
database_task = ioloop.create_task(
    server_support.manage_database("tcp://*:57681", db_fname)
)

# Every `interval` seconds, (re)submit any job name missing from the queue;
# `cores` becomes `#SBATCH --ntasks` in each job script.
job_task = ioloop.create_task(
    server_support.manage_jobs(job_names, db_fname, ioloop, cores=50*8, interval=60)
)

In [None]:
# Stop both background tasks (cancelling database_task closes its socket).
job_task.cancel(), database_task.cancel()

In [None]:
# Show where job_task is suspended, or its traceback if it raised.
job_task.print_stack()

In [None]:
# Show where database_task is suspended, or its traceback if it raised.
database_task.print_stack()