# Scenario

We have a cluster that partially consists of preemptible resources. That is, we have to deal with workers suddenly being shut down during computation.

## A cluster

In [1]:
from dask.distributed import Client, LocalCluster

In [2]:
cluster = LocalCluster(threads_per_worker=1, n_workers=5, memory_limit=200e6)

In [3]:
client = Client(cluster)
client

Client: Scheduler: tcp://127.0.0.1:36509, Dashboard: /user/dask-dask-examples-n2kj24go/proxy/8787/status
Cluster: Workers: 5, Cores: 5, Memory: 1000.00 MB


## Increase resilience

Whenever a worker shuts down, the scheduler increments the suspiciousness counter of _all_ tasks that were assigned to that worker (not necessarily computing on it). Once the suspiciousness of a task exceeds a certain threshold (3 by default), the task is considered broken. We want to compute many tasks on only a few workers, with workers shutting down randomly, so we expect the suspiciousness of all tasks to grow rapidly. Let's effectively disable the threshold:

In [4]:
cluster.scheduler.allowed_failures = int(1e32)
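
To make the counting rule concrete, here is a minimal toy model of it in plain Python. It is only a sketch of the behaviour described above, not the scheduler's actual implementation: the function `simulate_worker_failures` and its arguments are hypothetical, and the model assumes that tasks stay assigned to the same worker across failures.

```python
from collections import defaultdict


def simulate_worker_failures(assignments, failures, allowed_failures=3):
    """Toy model: return the set of tasks that would be marked broken.

    assignments: dict mapping a worker name to the tasks assigned to it
    failures: worker names, in the order in which the workers shut down
    """
    suspicious = defaultdict(int)
    broken = set()
    for worker in failures:
        # Every task assigned to the failed worker becomes more suspicious,
        # whether or not it was actually running at the time.
        for task in assignments.get(worker, []):
            suspicious[task] += 1
            if suspicious[task] > allowed_failures:
                broken.add(task)
    return broken


assignments = {"w1": ["a", "b"], "w2": ["c"]}
four_failures = ["w1"] * 4

# With the default threshold, the fourth failure pushes both tasks over it.
print(sorted(simulate_worker_failures(assignments, four_failures)))

# With a very large threshold, nothing is ever marked broken.
print(sorted(simulate_worker_failures(assignments, four_failures,
                                      allowed_failures=int(1e32))))
```

In this model, raising the threshold far beyond the number of failures that can realistically occur has the same effect as switching the check off, which is what the cell above aims for.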

## A simple workload

We'll multiply a range of numbers by two, add some sleep to simulate some real work, and then reduce the whole sequence of doubled numbers by summing them.

In [5]:
N = 5000

In [6]:
from time import sleep

In [7]:
def multiply_by_two(x):
    sleep(0.02)
    return 2 * x

In [8]:
from dask import bag as db

In [9]:
x = db.from_sequence(range(N), npartitions=N // 10)
x

dask.bag<from_se..., npartitions=500>

In [10]:
mults = x.map(multiply_by_two)

In [11]:
summed = mults.reduction(sum, sum)
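
For intuition, the same map and two-stage reduction can be sketched without Dask. The `partition` helper below is hypothetical and only approximates how a bag might split a sequence; the point is that `reduction(sum, sum)` first sums within each partition and then sums the per-partition results.

```python
def multiply_by_two(x):
    # The sleep from the notebook is left out to keep this sketch fast.
    return 2 * x


def partition(seq, npartitions):
    """Split seq into at most npartitions contiguous chunks (hypothetical helper)."""
    items = list(seq)
    size = -(-len(items) // npartitions)  # ceiling division
    return [items[i:i + size] for i in range(0, len(items), size)]


N = 5000
parts = partition(range(N), N // 10)

# First stage: one partial sum per partition.
per_partition = [sum(multiply_by_two(x) for x in part) for part in parts]

# Second stage: combine the partial sums into the final result.
total = sum(per_partition)
```

Because each partial sum depends on a single partition only, losing a worker means recomputing just the partitions it held, not the whole sequence.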

## Suddenly shutting down workers

Let's get the PIDs of all workers:

In [12]:
import os
import random

In [13]:
def _get_worker_pids(cluster):
    return (v.pid for k, v in cluster.scheduler.workers.items())

In [14]:
worker_pids = list(_get_worker_pids(cluster))

In [15]:
print(worker_pids)

[13257, 13259, 13261, 13255, 13263]


Let's define two of them as non-preemptible.

In [16]:
non_preemptible_workers = worker_pids[:2]

In [17]:
def _get_preemptible_worker_pids(non_preemptible_workers, cluster):
    return filter(lambda p: p not in non_preemptible_workers, _get_worker_pids(cluster))

In [18]:
print(list(_get_preemptible_worker_pids(non_preemptible_workers, cluster)))

[13261, 13255, 13263]
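Note that `filter` returns a lazy iterator that can be consumed only once, which is why the result is wrapped in `list` before printing. As a small standalone sketch, the same selection can be written as a list comprehension over the PIDs printed above; using a set for the membership test is a choice made here, not something the notebook does.

```python
# PIDs copied from the output above; they will differ on every run.
worker_pids = [13257, 13259, 13261, 13255, 13263]

# The first two workers are treated as non-preemptible.
non_preemptible = set(worker_pids[:2])

# Keep every worker that is not protected, preserving the original order.
preemptible = [pid for pid in worker_pids if pid not in non_preemptible]
```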


Wrap shutting down preemptible workers in a function:

In [19]:
def maybe_kill_n_perc_of_workers_and_wait(prob_killing_worker=0.1, wait_for=2,
                                          non_preemptible_workers=None, cluster=None):
    for w in _get_preemptible_worker_pids(non_preemptible_workers, cluster):
        if random.random() < prob_killing_worker:
            os.kill(w, 15)  # signal 15 is SIGTERM: ask the worker process to terminate
    sleep(wait_for)
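
To get a feeling for how aggressive a given `prob_killing_worker` is, the following sketch computes the chance that at least one worker is hit in a single round and the expected number of kills per round. Both helper names are hypothetical, and the numbers assume that each worker is killed independently and that the number of preemptible workers alive at the start of a round is known.

```python
def prob_at_least_one_kill(prob_killing_worker, n_preemptible):
    """Probability that one round kills at least one preemptible worker."""
    # Each worker survives with probability (1 - p); all survive with (1 - p) ** n.
    return 1 - (1 - prob_killing_worker) ** n_preemptible


def expected_kills_per_round(prob_killing_worker, n_preemptible):
    """Expected number of workers killed in one round."""
    return prob_killing_worker * n_preemptible


p_round = prob_at_least_one_kill(0.05, 3)
kills_round = expected_kills_per_round(0.05, 3)
```

With a per-worker probability of 0.05 and three preemptible workers, roughly one round in seven kills at least one worker.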

## Start the computation and keep shutting down workers while it runs

In [20]:
summed = summed.persist()

In [21]:
sleep(1)
while cluster.scheduler.tasks[summed.key].state != 'memory':
    maybe_kill_n_perc_of_workers_and_wait(
        prob_killing_worker=0.05, wait_for=0.5,
        non_preemptible_workers=non_preemptible_workers, cluster=cluster)



## Check if results match

In [22]:
print(summed.compute())
print(N * (N-1))  # Gauss' summation trick, starting at n=0,
                  # and accounting for the x2 applied above

24995000
24995000
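
The closed form used in the check follows from the sum of the first N non-negative integers, 0 + 1 + ... + (N-1) = N(N-1)/2; doubling every term doubles the sum, giving N(N-1). As a quick sanity check, this can be verified by brute force for a few values (the helper name is hypothetical):

```python
def doubled_sum_closed_form(n):
    """Closed form for sum(2 * k for k in range(n))."""
    return n * (n - 1)


# Compare the closed form against an explicit sum for several sizes.
for n in (0, 1, 2, 10, 5000):
    assert sum(2 * k for k in range(n)) == doubled_sum_closed_form(n)
```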
