Skip to content

Commit

Permalink
ignore transformers warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed May 23, 2024
1 parent 86953e5 commit 889a48d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 10 additions & 3 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pickle
import subprocess
import warnings
from functools import partial
from pathlib import Path
from tempfile import gettempdir
Expand Down Expand Up @@ -87,15 +88,21 @@ def encode(x):
return tokenizer(x)

# TODO: add hash consistency tests across sessions
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hash1 = Hasher.hash(tokenizer)
hash1_lambda = Hasher.hash(lambda x: tokenizer(x))
hash1_encode = Hasher.hash(encode)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
hash2 = Hasher.hash(tokenizer)
hash2_lambda = Hasher.hash(lambda x: tokenizer(x))
hash2_encode = Hasher.hash(encode)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hash3 = Hasher.hash(tokenizer)
hash3_lambda = Hasher.hash(lambda x: tokenizer(x))
hash3_encode = Hasher.hash(encode)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_metric_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import os
import re
import warnings
from contextlib import contextmanager
from functools import wraps
from unittest.mock import patch
Expand Down Expand Up @@ -105,7 +106,8 @@ def test_load_metric(self, metric_name):
parameters = inspect.signature(metric._compute).parameters
self.assertTrue(all(p.kind != p.VAR_KEYWORD for p in parameters.values())) # no **kwargs
# run doctest
with self.patch_intensive_calls(metric_name, metric_module.__name__):
with self.patch_intensive_calls(metric_name, metric_module.__name__), warnings.catch_warnings():
warnings.simplefilter("ignore")
with self.use_local_metrics():
try:
results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
Expand Down

0 comments on commit 889a48d

Please sign in to comment.