Enable doc classification recipe to work with all text datasets
## Summary
- Updated the datamodule to work with any torchtext dataset
   - No longer checks whether the dataset is an instance of `SST2Dataset`
   - Updated the `DocClassificationDataModuleConf` dataclass to accept user-provided `columns` and `label_column` fields, since different datasets order their columns differently (see the sketch after this list)
- Updated tests to use patching with mocked datasets, similar to what is done in OSS for the [AmazonReviewPolarity dataset test](pytorch/text#1532)
- Removed the dataset test from torchrecipes since the torchtext repo's unit tests provide adequate coverage
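
For illustration, a minimal sketch of configuring the new fields for a hypothetical dataset whose rows are ordered `(label, text)` rather than `(text, label)`. The import paths and the elided transform conf are assumptions, not part of this diff:

```python
import hydra

from torchrecipes.text.doc_classification.conf.common import SST2DatasetConf
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
    DocClassificationDataModuleConf,  # import path assumed
)

# `transform_conf` would be built as in the unit tests below; elided here.
dataset_conf = SST2DatasetConf(root="~/.torchtext/cache")
datamodule_conf = DocClassificationDataModuleConf(
    transform=transform_conf,
    dataset=dataset_conf,
    columns=["label", "text"],  # must match the dataset's row ordering
    label_column="label",       # the field fed to the label transform
    batch_size=16,
)
datamodule = hydra.utils.instantiate(datamodule_conf, _recursive_=False)
```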

## Followup Items
- [ ] Update the dataset instantiation call to use the functional API rather than the class API once the SST2 dataset has been migrated out of experimental ([reference GH issue](pytorch/text#1494)); a sketch of the two call styles follows below
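
A rough sketch of the two call styles. The functional signature is an assumption based on the convention of other torchtext datasets, since the migration had not landed yet:

```python
# Class API (current, experimental): the hydra target is called once and
# yields all three splits, as this recipe's datamodule expects.
from torchtext.experimental.datasets.sst2 import SST2

train_dataset, val_dataset, test_dataset = SST2(root="~/.torchtext/cache")

# Functional API (post-migration, assumed): splits are requested explicitly.
from torchtext.datasets import SST2 as sst2

train_dataset, val_dataset, test_dataset = sst2(
    root="~/.torchtext/cache", split=("train", "dev", "test")
)
```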

Reviewed By: abhinavarora, mthrok, parmeet

Differential Revision: D33775443

fbshipit-source-id: 1e6545949808ec5bd0e13cf3f9e7aaea08d68a59
Nayef211 authored and facebook-github-bot committed Jan 28, 2022
1 parent dd50d86 commit bef8c4e
Showing 9 changed files with 82 additions and 70 deletions.
1 change: 0 additions & 1 deletion torchrecipes/text/doc_classification/conf/common.py
@@ -27,7 +27,6 @@ class DatasetConf:
class SST2DatasetConf(DatasetConf):
_target_: str = get_class_name_str(SST2)
root: str = MISSING
validate_hash: Optional[bool] = True


@dataclass
@@ -1,3 +1,2 @@
_target_: torchtext.experimental.datasets.sst2.SST2
root: ~/.torchtext/cache
validate_hash: True
@@ -1,4 +1,8 @@
_target_: torchrecipes.text.doc_classification.datamodule.doc_classification.DocClassificationDataModule.from_config
columns:
- text
- label
label_column: label
batch_size: 16
num_workers: 0
drop_last: False
@@ -3,8 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Tuple
from dataclasses import dataclass, field
from typing import List, Tuple

import hydra
import pytorch_lightning as pl
@@ -23,7 +23,6 @@
config_entry,
get_class_config_method,
)
from torchtext.experimental.datasets.sst2 import SST2Dataset
from torchtext.functional import to_tensor


@@ -35,6 +34,8 @@ def __init__(
test_dataset: IterDataPipe[Tuple[str, str]],
transform: nn.Module,
label_transform: nn.Module,
columns: List[str],
label_column: str,
batch_size: int,
num_workers: int = 0,
drop_last: bool = False,
@@ -48,6 +49,9 @@ def __init__(
self.transform = transform
self.label_transform = label_transform

self.columns = columns
self.label_column = label_column

self.batch_size = batch_size
self.num_workers = num_workers
self.drop_last = drop_last
@@ -58,42 +62,43 @@ def __init__(
def from_config(
transform: DocClassificationTransformConf,
dataset: DatasetConf,
columns: List[str],
label_column: str,
batch_size: int,
num_workers: int = 0,
drop_last: bool = False,
pin_memory: bool = False,
) -> "DocClassificationDataModule":
train_dataset, val_dataset, test_dataset = hydra.utils.call(dataset)

# check if all datasets belongs to a subset of supported datasets
for dataset in (train_dataset, val_dataset, test_dataset):
if not isinstance(dataset, (SST2Dataset)):
raise NotImplementedError(f"{type(dataset)} not supported")

text_transform = hydra.utils.instantiate(transform.transform, _recursive_=False)
label_transform = hydra.utils.instantiate(
transform.label_transform,
_recursive_=False,
)
return DocClassificationDataModule(
train_dataset,
val_dataset,
train_dataset=train_dataset,
val_dataset=val_dataset,
# TODO: Note that the following line should be replaced by
# `test_dataset` once we update the lightning module to support
# test data with and without labels
val_dataset,
text_transform,
label_transform,
batch_size,
test_dataset=val_dataset,
transform=text_transform,
label_transform=label_transform,
columns=columns,
label_column=label_column,
batch_size=batch_size,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=pin_memory,
)

def _get_data_loader(self, dataset: IterDataPipe[Tuple[str, str]]) -> DataLoader:
dataset = dataset.batch(self.batch_size).rows2columnar(["text", "label"])
dataset = dataset.batch(self.batch_size).rows2columnar(self.columns)
dataset = dataset.map(self.transform)
dataset = dataset.map(
lambda x: {
**x,
"label_ids": to_tensor(self.label_transform(x["label"])),
"label_ids": to_tensor(self.label_transform(x[self.label_column])),
}
)
dataset = dataset.add_index()
@@ -123,6 +128,8 @@ class DocClassificationDataModuleConf(DataModuleConf):
_target_: str = get_class_config_method(DocClassificationDataModule)
transform: DocClassificationTransformConf = MISSING
dataset: DatasetConf = MISSING
columns: List[str] = field(default_factory=lambda: ["text", "label"])
label_column: str = "label"
batch_size: int = 16
num_workers: int = 0
drop_last: bool = False
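As a runnable illustration of the `rows2columnar` change in `_get_data_loader` above, a toy pipeline with hypothetical rows ordered `(label, text)`; the data and column names are illustrative, not part of this diff:

```python
from torchdata.datapipes.iter import IterableWrapper

# Toy rows for a hypothetical dataset ordered (label, text).
rows = [("0", "great movie"), ("1", "terrible plot")]

# Mirrors _get_data_loader: batch the rows, then pivot each batch into a
# columnar dict keyed by the configured column names.
pipe = IterableWrapper(rows).batch(2).rows2columnar(["label", "text"])

batch = next(iter(pipe))
print(dict(batch))
# {'label': ['0', '1'], 'text': ['great movie', 'terrible plot']}
# With label_column="label", the datamodule's x[self.label_column] lookup
# then picks out the label field for the label transform.
```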
@@ -7,6 +7,7 @@
# pyre-strict
import os.path
from typing import Tuple
from unittest.mock import patch

import hydra
import testslide
@@ -26,6 +27,18 @@


class TestDocClassificationConfig(testslide.TestCase):
def setUp(self) -> None:
super().setUp()
# patch the _hash_check() fn output to make it work with the dummy dataset
self.patcher = patch(
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
)
self.patcher.start()

def tearDown(self) -> None:
self.patcher.stop()
super().tearDown()

@tempdir
def test_doc_classification_task(self, root_dir: str) -> None:
# copy the asset files into their expected download locations
@@ -44,7 +57,6 @@ def test_doc_classification_task(self, root_dir: str) -> None:
"module.model.checkpoint=null",
"module.model.freeze_encoder=True",
f"datamodule.dataset.root={root_dir}",
"datamodule.dataset.validate_hash=False",
f"trainer.default_root_dir={root_dir}",
"trainer.logger=False",
"trainer.checkpoint_callback=False",
@@ -84,7 +96,6 @@ def test_doc_classification_task_torchscript(self, root_dir: str) -> None:
"module.model.checkpoint=null",
"module.model.freeze_encoder=True",
f"datamodule.dataset.root={root_dir}",
"datamodule.dataset.validate_hash=False",
f"trainer.default_root_dir={root_dir}",
"trainer.logger=False",
"trainer.checkpoint_callback=False",
@@ -8,6 +8,8 @@

# pyre-strict

from unittest.mock import patch

import hydra
import testslide
import torch
@@ -28,6 +30,18 @@


class TestDocClassificationDataModule(testslide.TestCase):
def setUp(self) -> None:
super().setUp()
# patch the _hash_check() fn output to make it work with the dummy dataset
self.patcher = patch(
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
)
self.patcher.start()

def tearDown(self) -> None:
self.patcher.stop()
super().tearDown()

def get_datamodule(self) -> DocClassificationDataModule:
doc_transform_conf = DocClassificationTextTransformConf(
vocab_path=get_asset_path("vocab_example.pt"),
@@ -40,10 +54,12 @@ def get_datamodule(self) -> DocClassificationDataModule:
label_transform=label_transform_conf,
)

dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH, validate_hash=False)
dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH)
datamodule_conf = DocClassificationDataModuleConf(
transform=transform_conf,
dataset=dataset_conf,
columns=["text", "label"],
label_column="label",
batch_size=8,
)
return hydra.utils.instantiate(

This file was deleted.

@@ -6,6 +6,7 @@

# pyre-strict
import os
from unittest.mock import patch

import hydra
from pytorch_lightning.trainer import Trainer
@@ -37,6 +38,15 @@
class TestDocClassificationModule(TaskTestCaseBase):
def setUp(self) -> None:
self.base_dir = os.path.join(os.path.dirname(__file__), "data")
# patch the _hash_check() fn output to make it work with the dummy dataset
self.patcher = patch(
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
)
self.patcher.start()

def tearDown(self) -> None:
self.patcher.stop()
super().tearDown()

def get_transform_conf(self) -> DocClassificationTransformConf:
doc_transform_conf = DocClassificationTextTransformConf(
@@ -72,7 +82,7 @@ def get_standard_task(self) -> DocClassificationModule:

def get_datamodule(self) -> DocClassificationDataModule:
transform_conf = self.get_transform_conf()
dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH, validate_hash=False)
dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH)
datamodule_conf = DocClassificationDataModuleConf(
transform=transform_conf,
dataset=dataset_conf,
@@ -8,6 +8,7 @@

# pyre-strict
import os.path
from unittest.mock import patch

import torchrecipes.text.doc_classification.conf # noqa
from torchrecipes.core.base_train_app import BaseTrainApp
@@ -21,6 +22,18 @@


class TestDocClassificationTrainApp(BaseTrainAppTestCase):
def setUp(self) -> None:
super().setUp()
# patch the _hash_check() fn output to make it work with the dummy dataset
self.patcher = patch(
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
)
self.patcher.start()

def tearDown(self) -> None:
self.patcher.stop()
super().tearDown()

def get_train_app(self, root_dir: str) -> BaseTrainApp:
# copy the asset files into their expected download locations
# note we need to do this anywhere we use hydra overrides
@@ -38,7 +51,6 @@ def get_train_app(self, root_dir: str) -> BaseTrainApp:
"module.model.checkpoint=null",
"module.model.freeze_encoder=True",
f"datamodule.dataset.root={root_dir}",
"datamodule.dataset.validate_hash=False",
f"trainer.default_root_dir={root_dir}",
"trainer.logger=False",
"trainer.checkpoint_callback=False",
