From bef8c4eda5bcc984871ef6ebd2813e4173c4daf4 Mon Sep 17 00:00:00 2001
From: Nayef Ahmed
Date: Fri, 28 Jan 2022 09:12:05 -0800
Subject: [PATCH] Enable doc classification recipe to work with all text datasets

Summary:
## Summary
- Updated the datamodule to work with any torchtext dataset
- Removed the check that each dataset is an instance of `SST2Dataset`
- Updated the `DocClassificationDataModuleConf` dataclass to take user-provided `columns` and `label_column` fields, since different datasets have different column orderings
- Updated tests to patch in mocked datasets, similar to what is done in OSS for the [AmazonReviewPolarity dataset test](https://github.com/pytorch/text/pull/1532)
- Removed the dataset test from torchrecipes since the torchtext repo's unit tests provide adequate coverage

## Followup Items
- [ ] Update the dataset instantiation call to use the functional API rather than the class API once the SST2 dataset has been migrated out of experimental ([reference GH issue](https://github.com/pytorch/text/issues/1494))

Reviewed By: abhinavarora, mthrok, parmeet

Differential Revision: D33775443

fbshipit-source-id: 1e6545949808ec5bd0e13cf3f9e7aaea08d68a59
---
 .../text/doc_classification/conf/common.py    |  1 -
 .../conf/datamodule/dataset/sst2_dataset.yaml |  1 -
 .../doc_classification_datamodule.yaml        |  4 ++
 .../datamodule/doc_classification.py          | 41 ++++++++++-------
 .../tests/test_doc_classification_config.py   | 15 +++++-
 .../test_doc_classification_datamodule.py     | 18 +++++++-
 .../tests/test_doc_classification_dataset.py  | 46 -------------------
 .../tests/test_doc_classification_module.py   | 12 ++++-
 .../test_doc_classification_train_app.py      | 14 +++++-
 9 files changed, 82 insertions(+), 70 deletions(-)
 delete mode 100644 torchrecipes/text/doc_classification/tests/test_doc_classification_dataset.py

diff --git a/torchrecipes/text/doc_classification/conf/common.py b/torchrecipes/text/doc_classification/conf/common.py
index 0a749f9..31197f4 100644
--- a/torchrecipes/text/doc_classification/conf/common.py
+++ b/torchrecipes/text/doc_classification/conf/common.py
@@ -27,7 +27,6 @@ class DatasetConf:
 class SST2DatasetConf(DatasetConf):
     _target_: str = get_class_name_str(SST2)
     root: str = MISSING
-    validate_hash: Optional[bool] = True
 
 
 @dataclass
diff --git a/torchrecipes/text/doc_classification/conf/datamodule/dataset/sst2_dataset.yaml b/torchrecipes/text/doc_classification/conf/datamodule/dataset/sst2_dataset.yaml
index 98b85bb..1843cd6 100644
--- a/torchrecipes/text/doc_classification/conf/datamodule/dataset/sst2_dataset.yaml
+++ b/torchrecipes/text/doc_classification/conf/datamodule/dataset/sst2_dataset.yaml
@@ -1,3 +1,2 @@
 _target_: torchtext.experimental.datasets.sst2.SST2
 root: ~/.torchtext/cache
-validate_hash: True
diff --git a/torchrecipes/text/doc_classification/conf/datamodule/doc_classification_datamodule.yaml b/torchrecipes/text/doc_classification/conf/datamodule/doc_classification_datamodule.yaml
index b800324..b2f4a4b 100644
--- a/torchrecipes/text/doc_classification/conf/datamodule/doc_classification_datamodule.yaml
+++ b/torchrecipes/text/doc_classification/conf/datamodule/doc_classification_datamodule.yaml
@@ -1,4 +1,8 @@
 _target_: torchrecipes.text.doc_classification.datamodule.doc_classification.DocClassificationDataModule.from_config
+columns:
+- text
+- label
+label_column: label
 batch_size: 16
 num_workers: 0
 drop_last: False
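For a dataset whose splits yield rows in a different order, only the two new keys need to change. A minimal sketch in Python using the dataclass conf added in this patch; `transform_conf` and `my_dataset_conf` are hypothetical stand-ins, not names from this diff:

```python
# Hypothetical config sketch: my_dataset_conf stands in for any DatasetConf
# whose splits yield (label, text) rows instead of (text, label).
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
    DocClassificationDataModuleConf,
)

datamodule_conf = DocClassificationDataModuleConf(
    transform=transform_conf,   # assumed DocClassificationTransformConf
    dataset=my_dataset_conf,    # hypothetical stand-in dataset conf
    columns=["label", "text"],  # must match the dataset's row ordering
    label_column="label",
    batch_size=16,
)
```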
diff --git a/torchrecipes/text/doc_classification/datamodule/doc_classification.py b/torchrecipes/text/doc_classification/datamodule/doc_classification.py
index 9025c98..65e887c 100644
--- a/torchrecipes/text/doc_classification/datamodule/doc_classification.py
+++ b/torchrecipes/text/doc_classification/datamodule/doc_classification.py
@@ -3,8 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass
-from typing import Tuple
+from dataclasses import dataclass, field
+from typing import List, Tuple
 
 import hydra
 import pytorch_lightning as pl
@@ -23,7 +23,6 @@
     config_entry,
     get_class_config_method,
 )
-from torchtext.experimental.datasets.sst2 import SST2Dataset
 from torchtext.functional import to_tensor
 
 
@@ -35,6 +34,8 @@ def __init__(
         test_dataset: IterDataPipe[Tuple[str, str]],
         transform: nn.Module,
         label_transform: nn.Module,
+        columns: List[str],
+        label_column: str,
         batch_size: int,
         num_workers: int = 0,
         drop_last: bool = False,
@@ -48,6 +49,9 @@ def __init__(
         self.transform = transform
         self.label_transform = label_transform
 
+        self.columns = columns
+        self.label_column = label_column
+
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.drop_last = drop_last
@@ -58,42 +62,43 @@ def __init__(
     def from_config(
         transform: DocClassificationTransformConf,
         dataset: DatasetConf,
+        columns: List[str],
+        label_column: str,
         batch_size: int,
         num_workers: int = 0,
         drop_last: bool = False,
         pin_memory: bool = False,
     ) -> "DocClassificationDataModule":
         train_dataset, val_dataset, test_dataset = hydra.utils.call(dataset)
-
-        # check if all datasets belongs to a subset of supported datasets
-        for dataset in (train_dataset, val_dataset, test_dataset):
-            if not isinstance(dataset, (SST2Dataset)):
-                raise NotImplementedError(f"{type(dataset)} not supported")
-
         text_transform = hydra.utils.instantiate(transform.transform, _recursive_=False)
         label_transform = hydra.utils.instantiate(
             transform.label_transform,
             _recursive_=False,
         )
         return DocClassificationDataModule(
-            train_dataset,
-            val_dataset,
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
             # TODO: Note that the following line should be replaced by
             # `test_dataset` once we update the lightning module to support
             # test data with and without labels
-            val_dataset,
-            text_transform,
-            label_transform,
-            batch_size,
+            test_dataset=val_dataset,
+            transform=text_transform,
+            label_transform=label_transform,
+            columns=columns,
+            label_column=label_column,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            drop_last=drop_last,
+            pin_memory=pin_memory,
         )
 
     def _get_data_loader(self, dataset: IterDataPipe[Tuple[str, str]]) -> DataLoader:
-        dataset = dataset.batch(self.batch_size).rows2columnar(["text", "label"])
+        dataset = dataset.batch(self.batch_size).rows2columnar(self.columns)
         dataset = dataset.map(self.transform)
         dataset = dataset.map(
             lambda x: {
                 **x,
-                "label_ids": to_tensor(self.label_transform(x["label"])),
+                "label_ids": to_tensor(self.label_transform(x[self.label_column])),
             }
         )
         dataset = dataset.add_index()
@@ -123,6 +128,8 @@ class DocClassificationDataModuleConf(DataModuleConf):
     _target_: str = get_class_config_method(DocClassificationDataModule)
     transform: DocClassificationTransformConf = MISSING
     dataset: DatasetConf = MISSING
+    columns: List[str] = field(default_factory=lambda: ["text", "label"])
+    label_column: str = "label"
     batch_size: int = 16
     num_workers: int = 0
     drop_last: bool = False
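To make the columnarization concrete, here is a toy run of the `batch(...).rows2columnar(...)` step used in `_get_data_loader` above, assuming torchdata is installed; the two-row datapipe stands in for a real dataset:

```python
# Toy illustration of the batch -> rows2columnar step: tuples of
# (text, label) become a dict keyed by the configured column names.
from torchdata.datapipes.iter import IterableWrapper

rows = IterableWrapper([("a great movie", "1"), ("dull and slow", "0")])
columnar = rows.batch(2).rows2columnar(["text", "label"])
for batch in columnar:
    print(batch)
    # {'text': ['a great movie', 'dull and slow'], 'label': ['1', '0']}
```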
diff --git a/torchrecipes/text/doc_classification/tests/test_doc_classification_config.py b/torchrecipes/text/doc_classification/tests/test_doc_classification_config.py
index 7054d7e..8d157d9 100644
--- a/torchrecipes/text/doc_classification/tests/test_doc_classification_config.py
+++ b/torchrecipes/text/doc_classification/tests/test_doc_classification_config.py
@@ -7,6 +7,7 @@
 # pyre-strict
 import os.path
 from typing import Tuple
+from unittest.mock import patch
 
 import hydra
 import testslide
@@ -26,6 +27,18 @@
 
 
 class TestDocClassificationConfig(testslide.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        # patch the _hash_check() fn output to make it work with the dummy dataset
+        self.patcher = patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        )
+        self.patcher.start()
+
+    def tearDown(self) -> None:
+        self.patcher.stop()
+        super().tearDown()
+
     @tempdir
     def test_doc_classification_task(self, root_dir: str) -> None:
         # copy the asset files into their expected download locations
@@ -44,7 +57,6 @@ def test_doc_classification_task(self, root_dir: str) -> None:
             "module.model.checkpoint=null",
             "module.model.freeze_encoder=True",
             f"datamodule.dataset.root={root_dir}",
-            "datamodule.dataset.validate_hash=False",
             f"trainer.default_root_dir={root_dir}",
             "trainer.logger=False",
             "trainer.checkpoint_callback=False",
@@ -84,7 +96,6 @@ def test_doc_classification_task_torchscript(self, root_dir: str) -> None:
             "module.model.checkpoint=null",
             "module.model.freeze_encoder=True",
             f"datamodule.dataset.root={root_dir}",
-            "datamodule.dataset.validate_hash=False",
             f"trainer.default_root_dir={root_dir}",
             "trainer.logger=False",
             "trainer.checkpoint_callback=False",
diff --git a/torchrecipes/text/doc_classification/tests/test_doc_classification_datamodule.py b/torchrecipes/text/doc_classification/tests/test_doc_classification_datamodule.py
index 84b9e9b..c714a5b 100644
--- a/torchrecipes/text/doc_classification/tests/test_doc_classification_datamodule.py
+++ b/torchrecipes/text/doc_classification/tests/test_doc_classification_datamodule.py
@@ -8,6 +8,8 @@
 
 # pyre-strict
 
+from unittest.mock import patch
+
 import hydra
 import testslide
 import torch
@@ -28,6 +30,18 @@
 
 
 class TestDocClassificationDataModule(testslide.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        # patch the _hash_check() fn output to make it work with the dummy dataset
+        self.patcher = patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        )
+        self.patcher.start()
+
+    def tearDown(self) -> None:
+        self.patcher.stop()
+        super().tearDown()
+
     def get_datamodule(self) -> DocClassificationDataModule:
         doc_transform_conf = DocClassificationTextTransformConf(
             vocab_path=get_asset_path("vocab_example.pt"),
@@ -40,10 +54,12 @@ def get_datamodule(self) -> DocClassificationDataModule:
             label_transform=label_transform_conf,
         )
 
-        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH, validate_hash=False)
+        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH)
         datamodule_conf = DocClassificationDataModuleConf(
             transform=transform_conf,
             dataset=dataset_conf,
+            columns=["text", "label"],
+            label_column="label",
             batch_size=8,
         )
         return hydra.utils.instantiate(
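For reference, the label path exercised by this datamodule (`label_transform` followed by `to_tensor`) reduces to the following; this assumes the label transform is torchtext's `LabelToIndex` over the two SST-2 string labels, which this diff does not itself show:

```python
# Standalone sketch of the label_transform -> to_tensor step; the
# LabelToIndex vocabulary here assumes SST-2's two string labels.
from torchtext.functional import to_tensor
from torchtext.transforms import LabelToIndex

label_transform = LabelToIndex(label_names=["0", "1"])
label_ids = to_tensor(label_transform(["1", "0", "1"]))
print(label_ids)  # tensor([1, 0, 1])
```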
diff --git a/torchrecipes/text/doc_classification/tests/test_doc_classification_dataset.py b/torchrecipes/text/doc_classification/tests/test_doc_classification_dataset.py
deleted file mode 100644
index 9f62eb3..0000000
--- a/torchrecipes/text/doc_classification/tests/test_doc_classification_dataset.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import hashlib
-import json
-
-import hydra
-import testslide
-from torchrecipes.text.doc_classification.conf.common import SST2DatasetConf
-from torchrecipes.text.doc_classification.tests.common.assets import _DATA_DIR_PATH
-from torchtext.experimental.datasets import sst2
-
-
-class TestDocClassificationDataset(testslide.TestCase):
-    def test_doc_classification_sst2_dataset(self) -> None:
-        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH, validate_hash=False)
-        train_dataset, dev_dataset, test_dataset = hydra.utils.call(
-            dataset_conf, _recursive_=False
-        )
-
-        # verify datasets objects are instances of SST2Dataset
-        for dataset in (train_dataset, dev_dataset, test_dataset):
-            self.assertTrue(isinstance(dataset, sst2.SST2Dataset))
-
-        # verify hashes of first line in dataset
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["train"],
-        )
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["dev"],
-        )
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["test"],
-        )
diff --git a/torchrecipes/text/doc_classification/tests/test_doc_classification_module.py b/torchrecipes/text/doc_classification/tests/test_doc_classification_module.py
index a41773f..aeb3bb7 100644
--- a/torchrecipes/text/doc_classification/tests/test_doc_classification_module.py
+++ b/torchrecipes/text/doc_classification/tests/test_doc_classification_module.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 import os
+from unittest.mock import patch
 
 import hydra
 from pytorch_lightning.trainer import Trainer
@@ -37,6 +38,15 @@
 class TestDocClassificationModule(TaskTestCaseBase):
     def setUp(self) -> None:
         self.base_dir = os.path.join(os.path.dirname(__file__), "data")
+        # patch the _hash_check() fn output to make it work with the dummy dataset
+        self.patcher = patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        )
+        self.patcher.start()
+
+    def tearDown(self) -> None:
+        self.patcher.stop()
+        super().tearDown()
 
     def get_transform_conf(self) -> DocClassificationTransformConf:
         doc_transform_conf = DocClassificationTextTransformConf(
@@ -72,7 +82,7 @@ def get_standard_task(self) -> DocClassificationModule:
 
     def get_datamodule(self) -> DocClassificationDataModule:
         transform_conf = self.get_transform_conf()
-        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH, validate_hash=False)
+        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH)
         datamodule_conf = DocClassificationDataModuleConf(
             transform=transform_conf,
             dataset=dataset_conf,
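The setUp/tearDown pair above could equally be written with `unittest`'s `addCleanup`, which unregisters the patch even if a later `setUp` step fails; shown here as an alternative pattern, not what this patch does:

```python
# Alternative sketch using addCleanup instead of an explicit tearDown.
from unittest.mock import patch

def setUp(self) -> None:
    super().setUp()
    patcher = patch(
        "torchdata.datapipes.iter.util.cacheholder._hash_check",
        return_value=True,
    )
    patcher.start()
    self.addCleanup(patcher.stop)
```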
diff --git a/torchrecipes/text/doc_classification/tests/test_doc_classification_train_app.py b/torchrecipes/text/doc_classification/tests/test_doc_classification_train_app.py
index 2ba6b33..39da932 100644
--- a/torchrecipes/text/doc_classification/tests/test_doc_classification_train_app.py
+++ b/torchrecipes/text/doc_classification/tests/test_doc_classification_train_app.py
@@ -8,6 +8,7 @@
 
 # pyre-strict
 import os.path
+from unittest.mock import patch
 
 import torchrecipes.text.doc_classification.conf  # noqa
 from torchrecipes.core.base_train_app import BaseTrainApp
@@ -21,6 +22,18 @@
 
 
 class TestDocClassificationTrainApp(BaseTrainAppTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        # patch the _hash_check() fn output to make it work with the dummy dataset
+        self.patcher = patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        )
+        self.patcher.start()
+
+    def tearDown(self) -> None:
+        self.patcher.stop()
+        super().tearDown()
+
     def get_train_app(self, root_dir: str) -> BaseTrainApp:
         # copy the asset files into their expected download locations
         # note we need to do this anywhere we use hydra overrides
@@ -38,7 +51,6 @@ def get_train_app(self, root_dir: str) -> BaseTrainApp:
             "module.model.checkpoint=null",
             "module.model.freeze_encoder=True",
             f"datamodule.dataset.root={root_dir}",
-            "datamodule.dataset.validate_hash=False",
             f"trainer.default_root_dir={root_dir}",
             "trainer.logger=False",
             "trainer.checkpoint_callback=False",
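Since the same setUp/tearDown pair now appears in all four test classes, a possible follow-up cleanup (hypothetical, not part of this diff) is a shared mixin:

```python
# Hypothetical shared mixin factoring out the repeated _hash_check patch.
from unittest.mock import patch


class HashCheckPatchMixin:
    def setUp(self) -> None:
        super().setUp()
        self.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check",
            return_value=True,
        )
        self.patcher.start()

    def tearDown(self) -> None:
        self.patcher.stop()
        super().tearDown()
```

A test case would then list the mixin first, e.g. `class TestDocClassificationTrainApp(HashCheckPatchMixin, BaseTrainAppTestCase):`, so cooperative `super()` calls reach both classes.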