Skip to content

Commit

Permalink
TF GPU CI (#3085)
Browse files Browse the repository at this point in the history
* debug env

* Restrict TF GPU memory

* Fixup

* One more test

* rm debug logs

* Fixup
  • Loading branch information
julien-c committed Mar 2, 2020
1 parent d3eb7d2 commit f169957
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/self-push.yml
Expand Up @@ -40,7 +40,8 @@ jobs:
- name: Run all non-slow tests on GPU
env:
TF_FORCE_GPU_ALLOW_GROWTH: yes
TF_FORCE_GPU_ALLOW_GROWTH: "true"
# TF_GPU_MEMORY_LIMIT: 4096
OMP_NUM_THREADS: 1
USE_CUDA: yes
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/self-scheduled.yml
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: Run all tests on GPU
env:
TF_FORCE_GPU_ALLOW_GROWTH: yes
TF_FORCE_GPU_ALLOW_GROWTH: "true"
OMP_NUM_THREADS: 1
RUN_SLOW: yes
USE_CUDA: yes
Expand Down
16 changes: 14 additions & 2 deletions tests/test_modeling_tf_common.py
Expand Up @@ -21,14 +21,26 @@

from transformers import is_tf_available, is_torch_available

from .utils import require_tf
from .utils import _tf_gpu_memory_limit, require_tf


if is_tf_available():
import tensorflow as tf
import numpy as np

# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.experimental.set_virtual_device_configuration(
gpu, [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)


def _config_zero_init(config):
Expand Down
14 changes: 14 additions & 0 deletions tests/utils.py
Expand Up @@ -29,8 +29,22 @@ def parse_flag_from_env(key, default=False):
return _value


def parse_int_from_env(key, default=None):
try:
value = os.environ[key]
except KeyError:
_value = default
else:
try:
_value = int(value)
except ValueError:
raise ValueError("If set, {} must be a int.".format(key))
return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)


def slow(test_case):
Expand Down
5 changes: 4 additions & 1 deletion utils/link_tester.py
Expand Up @@ -14,6 +14,9 @@
REGEXP_FIND_S3_LINKS = r"""([\"'])(https:\/\/s3)(.*)?\1"""


S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"


def list_python_files_in_repository():
""" List all python files in the repository.
Expand All @@ -36,7 +39,7 @@ def find_all_links(file_paths):
for path in file_paths:
links += scan_code_for_links(path)

return links
return [link for link in links if link != S3_BUCKET_PREFIX]


def scan_code_for_links(source):
Expand Down

0 comments on commit f169957

Please sign in to comment.