In [1]:
import os
import h5py
from datetime import timedelta
from toolz import complement, curry, compose, valmap, valfilter
from aging.organization.paths import FOLDERS

In [2]:
# Login name of whoever runs the notebook; used further down to build the
# per-user scratch path where Slurm writes its job logs.
user = os.environ['USER']

In [3]:
FOLDERS

(PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_females'),
 PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_males'),
 PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/longtogeny_pre_unet/Males'),
 PosixPath('/n/groups/datta/min/longtogeny_072023'),
 PosixPath('/n/groups/datta/min/wheel_062023'))

In [4]:
# HDF5 key looked up in each results file and forwarded to the inference script via --key
recon_key = "win_size_norm_frames_v4"
# upper cap applied to each folder's estimated runtime
max_time = timedelta(hours=8)
# lower threshold applied to each folder's estimated runtime
min_time = timedelta(minutes=5)
# safety margin added on top of each estimate before it is used as the job time limit
buffer = timedelta(minutes=5)

In [5]:
# old model trained on females
# model_path = '/n/groups/datta/win/longtogeny/size_norm/models/param_scan/bc632741-9de6-44bf-8b3c-e10e835948f8/Autoencoder-epoch=73-val_loss=5.84e-04.ckpt'

# new model trained on males
# model_path = '/n/groups/datta/win/longtogeny/size_norm/models/pre_final_model/model.pt'
# Checkpoint currently in use; it is passed to the inference script as the model path.
# NOTE(review): this lives under one user's home directory -- confirm it is
# readable by whoever submits the jobs.
model_path = '/home/wg41/groups/win/longtogeny/size_norm/models/pre_final_model-2023-08-02/model.pt'

In [6]:
# look through the contents of each folder and predict how long inference will take
def has_key(path, key=recon_key):
    """Report whether the HDF5 file at ``path`` already contains ``key``.

    Parameters
    ----------
    path : path-like
        Location of the HDF5 file to open read-only.
    key : str
        Name to look up at the top level of the file.

    Returns
    -------
    bool
        ``True`` if ``key`` is present. Also ``True`` when the file cannot be
        opened or inspected, so that broken files are treated as "nothing to
        do" by callers that count files still missing the key.
    """
    try:
        with h5py.File(path, "r") as handle:
            found = key in handle
    except Exception:
        # Any failure while opening, reading or closing: skip this file.
        found = True
    return found


# For every data folder, count the results files (searched recursively) that
# has_key reports as not yet containing the reconstruction key.
n_files = {
    folder: sum(
        not has_key(h5_path) for h5_path in folder.glob("**/results_00.h5")
    )
    for folder in FOLDERS
}

In [7]:
n_files

{PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_females'): 0,
 PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_males'): 0,
 PosixPath('/n/groups/datta/Dana/Ontogeny/raw_data/longtogeny_pre_unet/Males'): 0,
 PosixPath('/n/groups/datta/min/longtogeny_072023'): 54,
 PosixPath('/n/groups/datta/min/wheel_062023'): 0}

In [8]:
# Assumed inference cost per results file; multiplied by the pending-file
# count to size each job's time limit.
estimate_rate = 120  # seconds / file

In [9]:
# Turn the pending-file counts into a per-folder time limit for the batch job.
#
# 1. Keep only folders that actually have files left to process. Filtering on
#    the count (rather than on the time estimate) means a folder with just one
#    or two pending files is still submitted instead of being silently dropped
#    because its estimate falls below `min_time`.
# 2. Estimate the runtime as n_files * estimate_rate, clamped to the range
#    [min_time, max_time].
# 3. Add `buffer` on top as a safety margin.
pending_files = valfilter(lambda n: n > 0, n_files)
time_est = valmap(
    lambda n: max(min(timedelta(seconds=n * estimate_rate), max_time), min_time)
    + buffer,
    pending_files,
)

In [10]:
time_est

{PosixPath('/n/groups/datta/min/longtogeny_072023'): datetime.timedelta(seconds=6780)}

In [11]:
# Batch-script template, filled in with str.format before submission.
# Placeholders: {runtime} (job time limit), {user} (sub-path under the scratch
# users directory for the log file), {data_path} and {model_path} (positional
# arguments of 02-apply-dnn.py) and {key} (value of its --key option).
script = '''#!/bin/env bash
#SBATCH -c 1
#SBATCH -n 1
#SBATCH --mem=8G
#SBATCH -p gpu_quad
#SBATCH --gres=gpu:1
#SBATCH -t {runtime}
#SBATCH --output=/n/scratch3/users/{user}/tmp/win-size-norm-%j.out

source $HOME/.bashrc
conda activate aging
module load gcc/9.2.0
module load cuda/11.7
python $HOME/code/ontogeny/scripts/02-apply-dnn.py "{data_path}" "{model_path}" --key {key}
'''

In [None]:
for folder, time in time_est.items():
    new_script = script.format(
        data_path=folder, model_path=model_path, user=f"{user[0]}/{user}", key=recon_key, runtime=time
    )
    with open("tmp.sh", "w") as f:
        f.write(new_script)
    !sbatch tmp.sh
!rm tmp.sh