In [None]:
import sys 
sys.path.append('/Users/mbvalentin/scripts/wsbmr/dev')
sys.path.append('/Users/mbvalentin/scripts/fkeras')
sys.path.append('/Users/mbvalentin/scripts/qkeras')
sys.path.append('/Users/mbvalentin/scripts/DynamicTable')
sys.path.append('/Users/mbvalentin/scripts/tensorplot')

# Basic modules 
import os 

# Let's add our custom wsbmr code 
import wsbmr

# Hashlib
import hashlib

In [None]:
# Handles into wsbmr's nodus database: the connector itself, and the job
# manager that is used further down to create command jobs and wait on them.
nodus_db = wsbmr.nodus_db
jm = nodus_db.job_manager

In [None]:
# Experiment grid. The outer submission loop sweeps benchmarks x quants x
# prunes; protection_range, ber_range and methods are forwarded to every job.
benchmarks = ['autompg']

# Quantization specs, passed verbatim as "--bits_config" values.
# Other candidates used previously: "num_bits=8 integer=0", "num_bits=16 integer=0".
quants = ["num_bits=6 integer=0"]

# Pruning factors. Alternative sweep: wsbmr.config.DEFAULT_PRUNINGS.
prunes = [0.0]

# Values joined with spaces into "--protection_range".
protection_range = [0.0, 0.2, 0.4, 0.6, 0.8]

# Values joined with spaces into "--ber_range".
ber_range = [
    0.001, 0.00167, 0.00278, 0.00464, 0.00774, 0.01,
    0.01292, 0.02154, 0.03594, 0.05995, 0.1,
]

# Method names; each must be a key of wsbmr.config.config_per_method.
methods = [
    'bitwise_msb',
    'random',
    'layerwise_first',
    'layerwise_last',
    'weight_abs_value',
    'hirescam_norm',
    'hiresdelta',
    'hessian',
    'hessiandelta',
]

# Define the common format
def format_cmd(benchmark, prune, protection_range, ber_range,
               bits_config="num_bits=6 integer=0"):
    """Build the base ``python wsbmr`` command line for one configuration.

    Parameters
    ----------
    benchmark : str
        Value passed to ``--benchmark``.
    prune : float
        Value passed to ``--prune``.
    protection_range : iterable
        Values converted with ``str`` and joined with single spaces into
        ``--protection_range``.
    ber_range : iterable
        Values converted with ``str`` and joined with single spaces into
        ``--ber_range``.
    bits_config : str, optional
        Quantization spec interpolated verbatim (unquoted) after
        ``--bits_config``. Defaults to ``"num_bits=6 integer=0"``, the value
        this helper used to hard-code, so existing callers get the same
        command.

    Returns
    -------
    str
        A multi-line command whose lines are joined by backslash-newline
        continuations. It does NOT end with a continuation, so extra flags
        must be appended with a leading space + backslash + newline (which is
        exactly what ``format_extra`` returns).
    """
    return f"""python wsbmr \\
    --benchmarks_dir /Users/mbvalentin/scripts/wsbmr/benchmarks \\
    --datasets_dir /Users/mbvalentin/scripts/wsbmr/datasets \\
    --benchmark {benchmark} \\
    --bits_config {bits_config} \\
    --prune {prune} \\
    --protection_range {' '.join(map(str, protection_range))} \\
    --ber_range {' '.join(map(str, ber_range))}"""


def format_extra(cfg):
    """Build the method-specific flags to append to a ``format_cmd`` command.

    ``cfg`` must provide the keys ``'method'``, ``'method_suffix'`` and
    ``'method_kws'``; their values are interpolated verbatim (unquoted).
    The returned text starts with a space, a backslash and a newline, so it
    can be concatenated directly onto ``format_cmd``'s output. It always
    ends with the ``--plot`` flag.
    """
    return f""" \\
        --method {cfg['method']} \\
        --method_suffix {cfg['method_suffix']} \\
        --method_kws {cfg['method_kws']} \\
        --plot """


# Main submission loop: one parent job per (benchmark, quantization, prune)
# combination, plus one job per method that depends on that parent.
# NOTE(review): prev_dependency is not used in this cell; kept in case later
# notebook cells reference it.
prev_dependency = []
for b in benchmarks:
    for q in quants:
        for p in prunes:
            # Caller name that groups all jobs of this combination
            name = f"user::{b}::{q}::{p}"

            # Common part of every command for this combination. It is also
            # submitted on its own as the parent job, which (per the original
            # intent) trains and saves the model. The per-method jobs list the
            # parent in `dependencies`, so the job manager should only start
            # them after the parent is done.
            common = format_cmd(
                benchmark=b,
                prune=p,
                protection_range=protection_range,
                ber_range=ber_range,
                bits_config=q,  # previously ignored: every job used 6 bits
            )

            # Hash of the exact command, used to make the job name unique
            parent_digest = hashlib.md5(common.encode()).hexdigest()

            # Create job for parent
            job_id_parent, job_parent = jm.create_job(
                name=f"{b}_{parent_digest}",  # was hard-coded to "autompg_"
                parent_caller=name,
                job_type="command",
                command=common,
            )

            # Create one dependent job per method
            for m in methods:
                # format_extra() already opens with " \" + newline, so
                # concatenate directly. The old extra "\\ " separator left an
                # escaped space glued to the last --ber_range value.
                job_cmd = common + format_extra(wsbmr.config.config_per_method[m])

                job_digest = hashlib.md5(job_cmd.encode()).hexdigest()

                # NOTE(review): the parent job is the one that trains the model;
                # the "model_training_" prefix here looks like a leftover.
                # Confirm nothing looks jobs up by this name before renaming.
                job_id, job = jm.create_job(
                    name=f"model_training_{job_digest}",
                    parent_caller=f'{name}::{m}',
                    job_type="command",
                    command=job_cmd,
                    dependencies=[job_id_parent],
                )

            # Before moving to the next benchmark, quantization and prune
            # factor, wait for the parent job. Previously this waited on the
            # last child's `job_id`, which is undefined/stale when `methods`
            # is empty.
            jm.wait_for_job_completion(job_id_parent)
            