Skip to content

Commit

Permalink
remove_directory utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jun 6, 2024
1 parent 305b41a commit 9587bc4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
14 changes: 14 additions & 0 deletions optimum/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import importlib.util
import itertools
import os
import shutil
import subprocess
import sys
import unittest
Expand Down Expand Up @@ -184,3 +185,16 @@ def grid_parameters(
else:
returned_list = [test_name] + list(params) if add_test_name is True else list(params)
yield returned_list


def remove_directory(dirpath):
"""
Remove a directory and its content.
This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows.
Reference: https://github.com/python/cpython/issues/107408
"""
if os.path.exists(dirpath) and os.path.isdir(dirpath):
if os.name == "nt":
os.system(f"rmdir /S /Q {dirpath}")
else:
shutil.rmtree(dirpath)
37 changes: 8 additions & 29 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
import gc
import os
import shutil
import subprocess
import tempfile
import time
Expand Down Expand Up @@ -109,7 +108,7 @@
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
logging,
)
from optimum.utils.testing_utils import grid_parameters, require_hf_token, require_ort_rocm
from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm


logger = logging.get_logger()
Expand Down Expand Up @@ -184,12 +183,8 @@ def test_load_model_from_cache(self):

def test_load_model_from_empty_cache(self):
dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_MODEL_ID.replace("/", "--"))
remove_directory(dirpath)

if os.path.exists(dirpath) and os.path.isdir(dirpath):
if os.name == "nt":
os.system(f"rmdir /S /Q {dirpath}")
else:
shutil.rmtree(dirpath)
with self.assertRaises(Exception):
_ = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True)

Expand All @@ -205,12 +200,8 @@ def test_load_seq2seq_model_from_cache(self):

def test_load_seq2seq_model_from_empty_cache(self):
dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_SEQ2SEQ_MODEL_ID.replace("/", "--"))
remove_directory(dirpath)

if os.path.exists(dirpath) and os.path.isdir(dirpath):
if os.name == "nt":
os.system(f"rmdir /S /Q {dirpath}")
else:
shutil.rmtree(dirpath)
with self.assertRaises(Exception):
_ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True)

Expand All @@ -231,12 +222,8 @@ def test_load_stable_diffusion_model_from_empty_cache(self):
dirpath = os.path.join(
default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--")
)
remove_directory(dirpath)

if os.path.exists(dirpath) and os.path.isdir(dirpath):
if os.name == "nt":
os.system(f"rmdir /S /Q {dirpath}")
else:
shutil.rmtree(dirpath)
with self.assertRaises(Exception):
_ = ORTStableDiffusionPipeline.from_pretrained(
self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, local_files_only=True
Expand Down Expand Up @@ -1014,9 +1001,7 @@ def test_save_load_ort_model_with_external_data(self):
# verify loading from local folder works
model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False)
os.environ.pop("FORCE_ONNX_EXTERNAL_DATA")

if os.name == "nt":
os.system(f"rmdir /s /q {tmpdirname}")
remove_directory(tmpdirname)

@parameterized.expand([(False,), (True,)])
@pytest.mark.run_slow
Expand All @@ -1038,9 +1023,7 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool):
model = ORTModelForCausalLM.from_pretrained(
tmpdirname, use_cache=use_cache, export=False, use_io_binding=False
)

if os.name == "nt":
os.system(f"rmdir /s /q {tmpdirname}")
remove_directory(tmpdirname)

@parameterized.expand([(False,), (True,)])
def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool):
Expand All @@ -1063,9 +1046,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool):
# verify loading from local folder works
model = ORTModelForSeq2SeqLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False)
os.environ.pop("FORCE_ONNX_EXTERNAL_DATA")

if os.name == "nt":
os.system(f"rmdir /s /q {tmpdirname}")
remove_directory(tmpdirname)

def test_save_load_stable_diffusion_model_with_external_data(self):
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -1087,9 +1068,7 @@ def test_save_load_stable_diffusion_model_with_external_data(self):
# verify loading from local folder works
model = ORTStableDiffusionPipeline.from_pretrained(tmpdirname, export=False)
os.environ.pop("FORCE_ONNX_EXTERNAL_DATA")

if os.name == "nt":
os.system(f"rmdir /s /q {tmpdirname}")
remove_directory(tmpdirname)

@parameterized.expand([(False,), (True,)])
@unittest.skip("Skipping as this test consumes too much memory")
Expand Down

0 comments on commit 9587bc4

Please sign in to comment.