diff --git a/core/leras/device.py b/core/leras/device.py index 31d2f8809..a2ba3712c 100644 --- a/core/leras/device.py +++ b/core/leras/device.py @@ -161,6 +161,7 @@ def initialize_main_env(): if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): os.environ.pop('CUDA_VISIBLE_DEVICES') + os.environ['TF_DIRECTML_KERNEL_CACHE_SIZE'] = '2500' os.environ['CUDA_​CACHE_​MAXSIZE'] = '2147483647' os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only