diff --git a/script/get-ml-model-rgat/customize.py b/script/get-ml-model-rgat/customize.py index dbe679243..1b8bb5385 100644 --- a/script/get-ml-model-rgat/customize.py +++ b/script/get-ml-model-rgat/customize.py @@ -26,17 +26,18 @@ def postprocess(i): env = i['env'] - if env.get('RGAT_CHECKPOINT_PATH', '') == '': - env['RGAT_CHECKPOINT_PATH'] = os.path.join( - env['RGAT_DIR_PATH'], "RGAT.pt") + if env.get('MLC_DOWNLOAD_MODE', '') != "dry": + if env.get('RGAT_CHECKPOINT_PATH', '') == '': + env['RGAT_CHECKPOINT_PATH'] = os.path.join( + env['RGAT_DIR_PATH'], "RGAT.pt") - if env.get('MLC_ML_MODEL_RGAT_CHECKPOINT_PATH', '') == '': - env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] = env['RGAT_CHECKPOINT_PATH'] + if env.get('MLC_ML_MODEL_RGAT_CHECKPOINT_PATH', '') == '': + env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] = env['RGAT_CHECKPOINT_PATH'] - if env.get('MLC_ML_MODEL_PATH', '') == '': - env['MLC_ML_MODEL_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] + if env.get('MLC_ML_MODEL_PATH', '') == '': + env['MLC_ML_MODEL_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] - env['RGAT_CHECKPOINT_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] - env['MLC_GET_DEPENDENT_CACHED_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] + env['RGAT_CHECKPOINT_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] + env['MLC_GET_DEPENDENT_CACHED_PATH'] = env['MLC_ML_MODEL_RGAT_CHECKPOINT_PATH'] return {'return': 0}