### Cluster Init

Since each cluster worker has a 4-core CPU, Dask will try to assign 4 tasks to a single worker (running in parallel). First of all, Dask does not know that a single task (which is a TensorFlow simulation) will likely utilize all 4 cores anyway, and, more importantly, it does not take into account the very limited RAM (~3GB) each worker has. Hence, the workers will run out of memory if we do not do something about this.

We can use the `resources` functionality to define custom resources for our workers. We define a `PROCESS` resource and set its quantity to one. When we later `.submit` tasks, we inform Dask that a single task uses all of a worker's `PROCESS` resource, i.e., `{"PROCESS" : 1}`, so that Dask will not assign another task to that worker at the same time. See docs [Resources](https://distributed.dask.org/en/stable/resources.html) and relevant *stackoverflow* question [one task per worker](https://stackoverflow.com/questions/45052535/dask-distributed-how-to-run-one-task-per-worker-making-that-task-running-on-a).

Note: Dask obviously does not understand what the `PROCESS` resource means; it is purely conceptual. Dask just knows that each worker has one unit of this arbitrary resource named `PROCESS` (it could stand for a GPU, CPU, RAM, or whatever we decide it represents).

In [1]:
from distributed import LocalCluster
import dask

# Give every worker one unit of the custom `PROCESS` resource. The setting is
# applied through the Dask config for the duration of the cluster construction.
with dask.config.set({"distributed.worker.resources.PROCESS": 1}):
    # Two local workers with four threads each; `memory_limit` is the per-worker cap.
    cluster = LocalCluster(
        n_workers=2,
        threads_per_worker=4,
        memory_limit='9GB'
    )

In [2]:
# Bare last expression: renders the cluster summary (dashboard link, workers, threads, memory).
cluster

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 2
Total threads: 8,Total memory: 16.76 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:60572,Workers: 2
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 16.76 GiB

0,1
Comm: tcp://127.0.0.1:60583,Total threads: 4
Dashboard: http://127.0.0.1:60584/status,Memory: 8.38 GiB
Nanny: tcp://127.0.0.1:60575,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-bat4_3dw,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-bat4_3dw

0,1
Comm: tcp://127.0.0.1:60586,Total threads: 4
Dashboard: http://127.0.0.1:60587/status,Memory: 8.38 GiB
Nanny: tcp://127.0.0.1:60576,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-eplupafd,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-eplupafd


### Client Init

In [3]:
from dask.distributed import Client

# Connect a client to the local cluster created earlier in this notebook.
client = Client(cluster)

# Bare last expression: renders the client / cluster summary.
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 2
Total threads: 8,Total memory: 16.76 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:60572,Workers: 2
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 16.76 GiB

0,1
Comm: tcp://127.0.0.1:60583,Total threads: 4
Dashboard: http://127.0.0.1:60584/status,Memory: 8.38 GiB
Nanny: tcp://127.0.0.1:60575,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-bat4_3dw,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-bat4_3dw

0,1
Comm: tcp://127.0.0.1:60586,Total threads: 4
Dashboard: http://127.0.0.1:60587/status,Memory: 8.38 GiB
Nanny: tcp://127.0.0.1:60576,
Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-eplupafd,Local directory: C:\Users\miket\AppData\Local\Temp\dask-worker-space\worker-eplupafd


In [4]:
# Show the worker memory-management settings currently in the Dask config,
# one "name: value" line per entry.
memory_configs = dask.config.get("distributed.worker.memory")
for key in memory_configs:
    value = memory_configs[key]
    print(f"{key}: {value}")

recent-to-old-time: 30s
rebalance: {'measure': 'optimistic', 'sender-min': 0.3, 'recipient-max': 0.6, 'sender-recipient-gap': 0.1}
transfer: 0.1
target: 0.6
spill: 0.7
pause: 0.8
terminate: 0.95
max-spill: False
monitor-interval: 100ms


### Load Data Lazily

In [5]:
import tensorflow as tf
from dask import delayed

@delayed
def load_data():
    """Load MNIST through `tf.keras` and scale the image arrays by 1/255.

    The function is wrapped in `dask.delayed`, so calling it only builds a
    lazy task; the dataset is fetched when that task is computed.

    Returns
    -------
    tuple
        `(X_train, y_train, X_test, y_test)` where both image arrays have
        been divided by 255.0 and the label arrays are returned unchanged.
    """
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    return train_images / 255.0, train_labels, test_images / 255.0, test_labels

In [6]:
# Calling the @delayed function only creates a lazy Delayed object; the data is
# loaded later, when a task calls `.compute()` on it.
data_delayed = load_data()

In [7]:
# Ship the simulation module to the workers so that tasks can import
# `TF_Simulation_FDA_CNN` there.
client.upload_file('TF_Simulation_FDA_CNN.py')

{'tcp://127.0.0.1:60583': {'status': 'OK'},
 'tcp://127.0.0.1:60586': {'status': 'OK'}}

In [8]:
from dask.distributed import get_worker

def worker_single_fda_simulation(data_delayed, fda_name, num_clients, batch_size, num_steps_until_rtc_check,
                                 theta, num_epochs, sketch_width=-1, sketch_depth=-1, bench_test=False):
    """Run a single FDA simulation on a Dask worker and record memory usage around it.

    Parameters
    ----------
    data_delayed
        Lazy object whose `.compute()` yields `(X_train, y_train, X_test, y_test)`.
    fda_name, num_clients, batch_size, num_steps_until_rtc_check, theta, num_epochs
        Forwarded positionally to `TF_Simulation_FDA_CNN.single_simulation`.
    sketch_width, sketch_depth, bench_test
        Forwarded as keyword arguments to `single_simulation`.

    Returns
    -------
    tuple
        `(epoch_metrics, round_metrics, logs)`. `logs` is a dict holding the
        worker address, the main run parameters and the process RSS (in GiB)
        sampled at four points: start, before the simulation, after the
        simulation, and after cleanup.
    """
    import gc
    import psutil

    bytes_per_gib = 1024 * 1024 * 1024
    this_process = psutil.Process()

    def rss_in_gib():
        """Return the resident set size of the current process in GiB."""
        return this_process.memory_info().rss / bytes_per_gib

    logs = {
        "worker": get_worker().address,
        "bench_test": bench_test,
        "fda_name": fda_name,
        "num_clients": num_clients,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
    }
    logs["mem_usage_start"] = rss_in_gib()

    # Imported after the first memory sample; the module must be importable on the worker.
    import TF_Simulation_FDA_CNN as sim

    # Materialise the lazily-defined dataset inside the task.
    train_images, train_labels, test_images, test_labels = data_delayed.compute()
    train_dataset, test_dataset = sim.convert_to_tf_dataset(
        train_images, train_labels, test_images, test_labels
    )

    # Drop the local references to the raw arrays once they have been converted.
    del train_images, train_labels, test_images, test_labels

    logs["mem_usage_before_simulation"] = rss_in_gib()

    epoch_metrics, round_metrics = sim.single_simulation(
        fda_name,
        num_clients,
        train_dataset,
        test_dataset,
        batch_size,
        num_steps_until_rtc_check,
        theta,
        num_epochs,
        sketch_width=sketch_width,
        sketch_depth=sketch_depth,
        bench_test=bench_test,
    )

    logs["mem_usage_after_simulation"] = rss_in_gib()

    # Cleanup: drop the datasets, force a garbage collection, then clear the Keras session.
    del train_dataset, test_dataset
    gc.collect()
    sim.tf.keras.backend.clear_session()

    logs["mem_usage_after_gc"] = rss_in_gib()

    return epoch_metrics, round_metrics, logs

In [9]:
# Parameter grid: the submission cell iterates over every combination of the four
# lists below (and over each FDA method) and submits one simulation per combination.
num_clients_list = [20, 50, 55, 60]
batch_size_list = [32]
num_steps_until_rtc_check_list = [1]
theta_list = [1.]
num_epochs = 1

# Sketch dimensions; the submission cell passes them only for the "sketch" method
# and passes -1 for the other methods.
sketch_width = 500
sketch_depth = 7

In [10]:
# Submit one simulation task per (grid point, FDA method) combination. The `for`
# clauses keep the original nesting order, so tasks are submitted in the same sequence:
# num_clients -> batch_size -> num_steps_until_rtc_check -> theta -> fda_name.
futures = [
    client.submit(
        worker_single_fda_simulation,
        data_delayed=data_delayed,
        fda_name=fda_name,
        num_clients=num_clients,
        batch_size=batch_size,
        num_steps_until_rtc_check=num_steps_until_rtc_check,
        theta=theta,
        num_epochs=num_epochs,
        # Sketch dimensions are only passed for the "sketch" method; the others get -1.
        sketch_width=sketch_width if fda_name == "sketch" else -1,
        sketch_depth=sketch_depth if fda_name == "sketch" else -1,
        bench_test=False,
        # Tell Dask that the resource `PROCESS` is consumed in one task!
        resources={'PROCESS': 1},
    )
    for num_clients in num_clients_list
    for batch_size in batch_size_list
    for num_steps_until_rtc_check in num_steps_until_rtc_check_list
    for theta in theta_list
    for fda_name in ["naive", "linear", "sketch"]
]

In [11]:
from dask.distributed import as_completed

# Collect the per-task logs in completion order and release each future once its
# result has been consumed.
logs = []
num_completed = 0
total_futures = len(futures)

for future, result in as_completed(futures, with_results=True):
    # Only the worker logs are kept; the metrics are unpacked but not stored here.
    epoch_metrics, round_metrics, worker_logs = result

    num_completed += 1
    # Record this task's 1-based completion index. Using the total number of
    # futures here would give every entry the same `task_num`, which makes the
    # field useless as a sort key and as a logged value.
    worker_logs["task_num"] = num_completed
    logs.append(worker_logs)

    # Drop the client's reference to the finished future.
    future.release()

    print(f"\rProgress on Gathered-Saved Results: {num_completed} / {total_futures}", end="", flush=True)  # Print progress

Progress on Gathered-Saved Results: 12 / 12

In [22]:
import json

# Order the entries by worker address, then by task number, and append them to
# logs.txt as one JSON document per line.
sorted_logs = sorted(logs, key=lambda entry: (entry['worker'], entry['task_num']))

with open("logs.txt", "a") as log_file:
    # Two leading newlines separate this batch from whatever is already in the file.
    log_file.write("\n\n")
    log_file.writelines(json.dumps(entry) + "\n" for entry in sorted_logs)

### Terminate `Client` and `Cluster`

In [26]:
# Shut down the client first, then the cluster it was connected to.
client.close()
cluster.close()