Refactor tests to use accelerate launch (#373)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
muellerzr and sgugger committed May 19, 2022
1 parent 6163e20 commit 23c0341
Showing 7 changed files with 132 additions and 99 deletions.
7 changes: 7 additions & 0 deletions examples/by_feature/checkpointing.py
@@ -103,6 +103,13 @@ def collate_fn(examples):
return train_dataloader, eval_dataloader


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
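The env-var gate above is the heart of the refactor: run normally, the script keeps its real get_dataloaders; launched by the test suite with TESTING_MOCKED_DATALOADERS=1, it silently swaps in the lightweight MRPC mock, since unittest.mock can no longer patch a function across the process boundary that accelerate launch introduces. A minimal sketch of driving the script this way (script path and flags are taken from the tests below; /tmp/ckpt is just an illustrative output directory):

import os
import subprocess

# Turn on the mocked dataloaders in the child process, then launch the
# example the same way the refactored tests do.
env = {**os.environ, "TESTING_MOCKED_DATALOADERS": "1"}
subprocess.run(
    ["accelerate", "launch", "examples/by_feature/checkpointing.py",
     "--checkpointing_steps", "epoch", "--output_dir", "/tmp/ckpt"],
    env=env,
    check=True,
)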
7 changes: 7 additions & 0 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -74,6 +74,13 @@ def __exit__(self, *exc):
# print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize accelerator
if args.with_tracking:
8 changes: 8 additions & 0 deletions examples/by_feature/memory.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import torch
from torch.utils.data import DataLoader
@@ -103,6 +104,13 @@ def collate_fn(examples):
return train_dataloader, eval_dataloader


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
8 changes: 8 additions & 0 deletions examples/by_feature/multi_process_metrics.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import torch
from torch.utils.data import DataLoader
@@ -104,6 +105,13 @@ def collate_fn(examples):
return train_dataloader, eval_dataloader


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
7 changes: 7 additions & 0 deletions examples/by_feature/tracking.py
@@ -103,6 +103,13 @@ def collate_fn(examples):
return train_dataloader, eval_dataloader


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize Accelerator

42 changes: 42 additions & 0 deletions src/accelerate/test_utils/training.py
@@ -14,6 +14,11 @@

import numpy as np
import torch
from torch.utils.data import DataLoader

from accelerate.utils.dataclasses import DistributedType
from datasets import load_dataset
from transformers import AutoTokenizer


class RegressionDataset:
@@ -43,3 +48,40 @@ def forward(self, x=None):
print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
self.first_batch = False
return x * self.a + self.b


def mocked_dataloaders(accelerator, batch_size: int = 16):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
datasets = load_dataset("csv", data_files=data_files)
label_list = datasets["train"].unique("label")

label_to_id = {v: i for i, v in enumerate(label_list)}

def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(
examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length"
)
if "label" in examples:
outputs["labels"] = [label_to_id[l] for l in examples["label"]]
return outputs

# Apply the method we just defined to all the examples in all the splits of the dataset
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["sentence1", "sentence2", "label"],
)

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
if accelerator.distributed_type == DistributedType.TPU:
return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
return tokenizer.pad(examples, padding="longest", return_tensors="pt")

# Instantiate dataloaders.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2)
eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)

return train_dataloader, eval_dataloader
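As a quick sanity check, the helper can also be called directly; a minimal sketch, assuming the repository root as working directory so the MRPC CSV fixtures under tests/test_samples/MRPC/ resolve:

from accelerate import Accelerator
from accelerate.test_utils.training import mocked_dataloaders

accelerator = Accelerator(cpu=True)
train_dataloader, eval_dataloader = mocked_dataloaders(accelerator)
batch = next(iter(train_dataloader))  # dict of padded tensors: input_ids, attention_mask, labels, ...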
152 changes: 53 additions & 99 deletions tests/test_examples.py
@@ -12,29 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import os
import sys
import re
import shutil
import subprocess
import tempfile
import unittest
from unittest import mock

from torch.utils.data import DataLoader

from accelerate import DistributedType
from accelerate.test_utils.examples import compare_against_test
from accelerate.test_utils.testing import TempDirTestCase, slow
from datasets import load_dataset
from transformers import AutoTokenizer

from accelerate.utils import write_basic_config

SRC_DIRS = [os.path.abspath(os.path.join("examples", "by_feature"))]
sys.path.extend(SRC_DIRS)

if SRC_DIRS is not None:
import checkpointing
import cross_validation
import multi_process_metrics
import tracking

# DataLoaders built from `test_samples/MRPC` for quick testing
# Should mock `{script_name}.get_dataloaders` via:
@@ -43,43 +33,6 @@
EXCLUDE_EXAMPLES = ["cross_validation.py", "multi_process_metrics.py", "memory.py", "fsdp_with_peak_mem_tracking.py"]


def mocked_dataloaders(accelerator, batch_size: int = 16):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"}
datasets = load_dataset("csv", data_files=data_files)
label_list = datasets["train"].unique("label")

label_to_id = {v: i for i, v in enumerate(label_list)}

def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(
examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length"
)
if "label" in examples:
outputs["labels"] = [label_to_id[l] for l in examples["label"]]
return outputs

# Apply the method we just defined to all the examples in all the splits of the dataset
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["sentence1", "sentence2", "label"],
)

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
if accelerator.distributed_type == DistributedType.TPU:
return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
return tokenizer.pad(examples, padding="longest", return_tensors="pt")

# Instantiate dataloaders.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2)
eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1)

return train_dataloader, eval_dataloader


class ExampleDifferenceTests(unittest.TestCase):
"""
This TestCase checks that all of the `complete_*` scripts contain all of the
@@ -159,88 +112,89 @@ def test_cv_examples(self):
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)


@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
class FeatureExamplesTests(TempDirTestCase):
clear_on_setup = False

@mock.patch("checkpointing.get_dataloaders", mocked_dataloaders)
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._tmpdir = tempfile.mkdtemp()
cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")

write_basic_config(save_location=cls.configPath)
cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]

@classmethod
def tearDownClass(cls):
super().tearDownClass()
shutil.rmtree(cls._tmpdir)

def test_checkpointing_by_epoch(self):
testargs = f"""
checkpointing.py
examples/by_feature/checkpointing.py
--checkpointing_steps epoch
--output_dir {self.tmpdir}
""".split()
with mock.patch.object(sys, "argv", testargs):
checkpointing.main()
self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "epoch_1")))
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "epoch_1")))

@mock.patch("checkpointing.get_dataloaders", mocked_dataloaders)
def test_checkpointing_by_steps(self):
testargs = f"""
checkpointing.py
examples/by_feature/checkpointing.py
--checkpointing_steps 2
--output_dir {self.tmpdir}
""".split()
with mock.patch.object(sys, "argv", testargs):
checkpointing.main()
self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "step_4")))
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE, env=os.environ)
self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "step_4")))

@mock.patch("checkpointing.get_dataloaders", mocked_dataloaders)
def test_load_states_by_epoch(self):
testargs = f"""
checkpointing.py
examples/by_feature/checkpointing.py
--resume_from_checkpoint {os.path.join(self.tmpdir, "epoch_1")}
""".split()
dummy_results = {"accuracy": mock.ANY, "f1": mock.ANY}
with mock.patch("accelerate.Accelerator.print") as mocked_print:
with mock.patch.object(sys, "argv", testargs):
checkpointing.main()
with self.assertRaises(AssertionError):
mocked_print.assert_any_call("epoch 0:", dummy_results)
with self.assertRaises(AssertionError):
mocked_print.assert_any_call("epoch 1:", dummy_results)
mocked_print.assert_any_call("epoch 2:", dummy_results)

@mock.patch("checkpointing.get_dataloaders", mocked_dataloaders)
output = subprocess.run(
self._launch_args + testargs, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
).stdout
self.assertNotIn("epoch 0:", output)
self.assertNotIn("epoch 1:", output)
self.assertIn("epoch 2:", output)

def test_load_states_by_steps(self):
testargs = f"""
checkpointing.py
examples/by_feature/checkpointing.py
--resume_from_checkpoint {os.path.join(self.tmpdir, "step_4")}
""".split()
dummy_results = {"accuracy": mock.ANY, "f1": mock.ANY}
with mock.patch("accelerate.Accelerator.print") as mocked_print:
with mock.patch.object(sys, "argv", testargs):
checkpointing.main()
with self.assertRaises(AssertionError):
mocked_print.assert_any_call("epoch 0:", dummy_results)
mocked_print.assert_any_call("epoch 1:", dummy_results)
mocked_print.assert_any_call("epoch 2:", dummy_results)
output = subprocess.run(
self._launch_args + testargs, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
).stdout
self.assertNotIn("epoch 0:", output)
self.assertIn("epoch 1:", output)
self.assertIn("epoch 2:", output)

@slow
def test_cross_validation(self):
testargs = """
cross_validation.py
examples/by_feature/cross_validation.py
--num_folds 2
""".split()
with mock.patch.object(sys, "argv", testargs):
with mock.patch("accelerate.Accelerator.print") as mocked_print:
cross_validation.main()
call = mocked_print.mock_calls[-1]
self.assertGreaterEqual(call.args[1]["accuracy"], 0.75)
with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "0"}):
output = subprocess.run(
self._launch_args + testargs, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
).stdout
results = ast.literal_eval(re.findall("({.+})", output)[-1])
self.assertGreaterEqual(results["accuracy"], 0.75)
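The re.findall plus ast.literal_eval pair above is what replaces the old mocked_print assertions: the launched script prints its metrics dict to stdout, and the test parses the last {...} it finds. A standalone sketch with made-up output:

import ast
import re

output = "epoch 1: {'accuracy': 0.81, 'f1': 0.86}\nepoch 2: {'accuracy': 0.84, 'f1': 0.89}\n"
results = ast.literal_eval(re.findall("({.+})", output)[-1])  # last printed dict wins
assert results["accuracy"] >= 0.75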

@mock.patch("multi_process_metrics.get_dataloaders", mocked_dataloaders)
def test_multi_process_metrics(self):
testargs = ["multi_process_metrics.py"]
with mock.patch.object(sys, "argv", testargs):
multi_process_metrics.main()
testargs = ["examples/by_feature/multi_process_metrics.py"]
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)

@mock.patch("tracking.get_dataloaders", mocked_dataloaders)
def test_tracking(self):
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
tracking.py
examples/by_feature/tracking.py
--with_tracking
--logging_dir {tmpdir}
""".split()
with mock.patch.object(sys, "argv", testargs):
tracking.main()
self.assertTrue(os.path.exists(os.path.join(tmpdir, "tracking")))
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
self.assertTrue(os.path.exists(os.path.join(tmpdir, "tracking")))
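One subtlety worth noting: the class-level mock.patch.dict on os.environ reaches the launched children because subprocess.run inherits the parent's (patched) environment whenever no explicit env= is passed, which is why most tests above do not forward it by hand. A minimal, self-contained sketch of that behavior:

import os
import subprocess
from unittest import mock

with mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"}):
    # The child sees the patched variable even though env= is not passed explicitly.
    out = subprocess.run(
        ["python", "-c", "import os; print(os.environ.get('TESTING_MOCKED_DATALOADERS'))"],
        stdout=subprocess.PIPE,
        universal_newlines=True,
    ).stdout
print(out.strip())  # -> 1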
