The `resources` definition below is incredibly important: it keeps Dask from assigning multiple tasks (i.e. TensorFlow training simulations) to one worker at the same time, which would kill the worker because it runs out of RAM. See the docs on [Resources](https://distributed.dask.org/en/stable/resources.html) and the relevant *Stack Overflow* question [one task per worker](https://stackoverflow.com/questions/45052535/dask-distributed-how-to-run-one-task-per-worker-making-that-task-running-on-a).

In [1]:
from distributed import LocalCluster
import dask

# Advertise one unit of a custom `PROCESS` resource per worker while the cluster
# is being created. Tasks submitted with resources={'PROCESS': 1} then occupy a
# worker's only unit, so at most one such task runs per worker at a time even
# though each worker has two threads.
with dask.config.set({"distributed.worker.resources.PROCESS": 1}):
    cluster = LocalCluster(
        n_workers=4,
        threads_per_worker=2,
        memory_limit='4GB'  # per-worker limit (rendered as 3.73 GiB in the summary)
    )

In [2]:
# Last expression in the cell: renders the cluster summary (dashboard, workers, threads, memory).
cluster

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 14.90 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:55983,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 14.90 GiB

0,1
Comm: tcp://127.0.0.1:56036,Total threads: 2
Dashboard: http://127.0.0.1:56038/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55986,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-wvqm0rdm,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-wvqm0rdm

0,1
Comm: tcp://127.0.0.1:56033,Total threads: 2
Dashboard: http://127.0.0.1:56034/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55987,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-m04ww1jx,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-m04ww1jx

0,1
Comm: tcp://127.0.0.1:56042,Total threads: 2
Dashboard: http://127.0.0.1:56045/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55988,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-t4fon4cz,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-t4fon4cz

0,1
Comm: tcp://127.0.0.1:56037,Total threads: 2
Dashboard: http://127.0.0.1:56043/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55989,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-tvf9dtn4,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-tvf9dtn4


In [3]:
from dask.distributed import Client

# Connect a client to the local cluster; later cells submit and gather work through it.
client = Client(cluster)

# Last expression in the cell: renders the client/cluster summary.
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 14.90 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:55983,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 14.90 GiB

0,1
Comm: tcp://127.0.0.1:56036,Total threads: 2
Dashboard: http://127.0.0.1:56038/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55986,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-wvqm0rdm,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-wvqm0rdm

0,1
Comm: tcp://127.0.0.1:56033,Total threads: 2
Dashboard: http://127.0.0.1:56034/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55987,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-m04ww1jx,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-m04ww1jx

0,1
Comm: tcp://127.0.0.1:56042,Total threads: 2
Dashboard: http://127.0.0.1:56045/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55988,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-t4fon4cz,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-t4fon4cz

0,1
Comm: tcp://127.0.0.1:56037,Total threads: 2
Dashboard: http://127.0.0.1:56043/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:55989,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-tvf9dtn4,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-tvf9dtn4


In [4]:
def worker_addr_generator(dask_client=None):
    """Yield the address of each worker round robin, forever.

    The worker list is re-read from the scheduler at the start of every lap,
    so workers that join or leave between laps are picked up.

    Args:
        dask_client: Object exposing ``scheduler_info()`` that returns a
            mapping with a ``'workers'`` entry keyed by worker address.
            Defaults to the notebook-level ``client``.

    Yields:
        Worker addresses, cycling through the currently known workers.

    Raises:
        RuntimeError: If the scheduler reports no workers. Without this check
            the generator would spin forever, polling the scheduler and never
            yielding.
    """
    source = client if dask_client is None else dask_client

    while True:
        worker_addrs = list(source.scheduler_info()['workers'])
        if not worker_addrs:
            raise RuntimeError("No workers are registered with the scheduler")
        yield from worker_addrs

In [5]:
import tensorflow as tf
from dask import delayed

@delayed
def load_data():
    """Load MNIST through Keras and scale the images by 1/255.

    Returns:
        A ``(X_train, y_train, X_test, y_test)`` tuple; only the two image
        arrays are rescaled, the labels are passed through unchanged.
    """
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    return train_images / 255.0, train_labels, test_images / 255.0, test_labels

In [6]:
# `load_data` is wrapped in `@delayed`, so this only builds a lazy task;
# the dataset is loaded when `.compute()` is called on it.
data_delayed = load_data()

In [7]:
# Ship the simulation module to every worker so tasks can `import TF_Simulation_FDA_CNN`.
client.upload_file('TF_Simulation_FDA_CNN.py')

{'tcp://127.0.0.1:56033': {'status': 'OK'},
 'tcp://127.0.0.1:56036': {'status': 'OK'},
 'tcp://127.0.0.1:56037': {'status': 'OK'},
 'tcp://127.0.0.1:56042': {'status': 'OK'}}

In [8]:
def worker_training_function(data_delayed, num_clients):
    """Run one simulation configuration on a worker and return its metrics.

    Args:
        data_delayed: Lazy object whose ``.compute()`` returns
            ``(X_train, y_train, X_test, y_test)``.
        num_clients: Number of simulated clients; passed to ``sim.run_tests``
            as the single-element ``num_clients_list``.

    Returns:
        The ``(all_epoch_metrics, all_round_metrics)`` pair produced by
        ``sim.run_tests``.

    Note:
        Cleanup (dropping the data references, forcing garbage collection and
        clearing the Keras session) runs in a ``finally`` block, so a failing
        task does not leave TensorFlow state behind in a reused worker process.
    """
    # Imported inside the task: the module is shipped to workers via upload_file.
    import TF_Simulation_FDA_CNN as sim
    import gc

    # Pre-bind so the `del` in `finally` is valid even if an early step raises.
    X_train = y_train = X_test = y_test = None
    train_dataset = test_dataset = None

    try:
        X_train, y_train, X_test, y_test = data_delayed.compute()

        train_dataset, test_dataset = sim.convert_to_tf_dataset(X_train, y_train, X_test, y_test)

        all_epoch_metrics, all_round_metrics = sim.run_tests(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            num_clients_list=[num_clients],
            batch_size_list=[32],
            num_steps_until_rtc_check_list=[1],
            theta_list=[1.],
            num_epochs=1,
            sketch_width=500,
            sketch_depth=7
        )

        return all_epoch_metrics, all_round_metrics
    finally:
        # Drop the large references first so the forced collection can reclaim them.
        del X_train, y_train, X_test, y_test, train_dataset, test_dataset

        gc.collect()  # force garbage collection
        sim.tf.keras.backend.clear_session()  # Clear TensorFlow session

In [9]:
# Round-robin iterator over worker addresses.
# NOTE(review): not referenced by any later cell visible here — confirm it is still needed.
live_worker_addr_gen = worker_addr_generator()

In [10]:
# Accumulates the futures appended by the submission cell. It lives in its own
# cell, so re-running the submission cell without re-running this one appends
# to the existing list.
futures = []

In [11]:
# Submit one training task per client count.
#
# `pure=False`: the task is not a pure function (it trains a model and resets
# TensorFlow state), so every submission must get its own key. With the default
# `pure=True`, identical arguments map to the same key, and the repeated 21
# below would reuse a single future instead of running a second time.
for num_clients in [20, 5, 9, 15, 4, 3, 21, 12, 11, 16, 21]:
    future = client.submit(
        worker_training_function,
        data_delayed=data_delayed,
        num_clients=num_clients,
        resources={'PROCESS': 1},  # Tell Dask that the resource `PROCESS` is consumed in one task!
        pure=False,
    )

    futures.append(future)

In [None]:
# Collect every task's return value, i.e. one (epoch_metrics, round_metrics)
# pair per submitted future, in the same order as `futures`.
results = client.gather(futures)

In [None]:
from itertools import chain

# `results` holds one (epoch_metrics, round_metrics) pair per task;
# regroup them into one tuple of epoch metrics and one of round metrics.
all_tests_epoch_metrics, all_tests_round_metrics = zip(*results)

# Flatten the per-task collections into real lists. A bare `chain` iterator is
# exhausted after one pass, so re-running a later cell that consumes it would
# silently see no data.
all_epoch_metrics = list(chain.from_iterable(all_tests_epoch_metrics))
all_round_metrics = list(chain.from_iterable(all_tests_round_metrics))

In [None]:
import pandas as pd

# Build one table per metric kind from the flattened results.
# NOTE(review): presumably each metrics record becomes one row — confirm
# against what `sim.run_tests` returns.
epoch_metrics_df = pd.DataFrame(all_epoch_metrics)
round_metrics_df = pd.DataFrame(all_round_metrics)

In [None]:
# Display the epoch-level metrics table.
epoch_metrics_df

In [None]:
# Display the round-level metrics table.
round_metrics_df

In [None]:
# Tear down in reverse order of creation: the client connection first,
# then the cluster it was connected to.
client.close()
cluster.close()