/
test_doc_classification_datamodule.py
81 lines (66 loc) · 2.72 KB
/
test_doc_classification_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python3
# pyre-strict
from unittest.mock import patch
import hydra
import testslide
import torch
from torchrecipes.text.doc_classification.conf.common import (
SST2DatasetConf,
LabelTransformConf,
DocClassificationTransformConf,
)
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
DocClassificationDataModuleConf,
DocClassificationDataModule,
)
from torchrecipes.text.doc_classification.tests.common.assets import _DATA_DIR_PATH
from torchrecipes.text.doc_classification.tests.common.assets import get_asset_path
from torchrecipes.text.doc_classification.transform.doc_classification_text_transform import (
DocClassificationTextTransformConf,
)
class TestDocClassificationDataModule(testslide.TestCase):
    """Tests that DocClassificationDataModule can be built from config via
    hydra and that its train dataloader yields correctly shaped batches."""

    def setUp(self) -> None:
        super().setUp()
        # Patch the _hash_check() fn output to make it work with the dummy
        # dataset (its files won't match the real SST-2 hashes).
        self.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        self.patcher.start()
        # addCleanup instead of tearDown: cleanups run even if setUp raises
        # after start(), so the global patch can never leak into other tests.
        self.addCleanup(self.patcher.stop)

    def get_datamodule(self) -> DocClassificationDataModule:
        """Build a DocClassificationDataModule from test-asset configs.

        Returns:
            A datamodule wired to the dummy SST-2 dataset under
            ``_DATA_DIR_PATH`` with a batch size of 8.
        """
        doc_transform_conf = DocClassificationTextTransformConf(
            vocab_path=get_asset_path("vocab_example.pt"),
            spm_model_path=get_asset_path("spm_example.model"),
        )
        label_transform_conf = LabelTransformConf(label_names=["0", "1"])
        transform_conf = DocClassificationTransformConf(
            transform=doc_transform_conf,
            label_transform=label_transform_conf,
        )
        dataset_conf = SST2DatasetConf(root=_DATA_DIR_PATH)
        datamodule_conf = DocClassificationDataModuleConf(
            transform=transform_conf,
            dataset=dataset_conf,
            columns=["text", "label"],
            label_column="label",
            batch_size=8,
        )
        # _recursive_=False: hand the nested conf objects to the datamodule
        # untouched instead of letting hydra eagerly instantiate them.
        return hydra.utils.instantiate(
            datamodule_conf,
            _recursive_=False,
        )

    def test_doc_classification_datamodule(self) -> None:
        datamodule = self.get_datamodule()
        self.assertIsInstance(datamodule, DocClassificationDataModule)

        dataloader = datamodule.train_dataloader()
        batch = next(iter(dataloader))

        self.assertTrue(torch.is_tensor(batch["label_ids"]))
        self.assertTrue(torch.is_tensor(batch["token_ids"]))
        # One label per row (batch_size=8); token_ids are padded to a common
        # sequence length — 35 for the dummy dataset (TODO: confirm this value
        # is pinned by the test assets rather than tokenizer internals).
        self.assertEqual(batch["label_ids"].size(), torch.Size([8]))
        self.assertEqual(batch["token_ids"].size(), torch.Size([8, 35]))