Skip to content

Commit

Permalink
Revert "Don't reset the dataset type + plug for rm unused columns (hu…
Browse files Browse the repository at this point in the history
…ggingface#6683)"

This reverts commit 9d4f322.
  • Loading branch information
fabiocapsouza committed Nov 15, 2020
1 parent 957b18b commit b67e9b4
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 15 deletions.
7 changes: 1 addition & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ def __init__(
self.scaler = torch.cuda.amp.GradScaler()

def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
signature_columns = list(signature.parameters.keys())
Expand All @@ -257,10 +255,7 @@ def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[s
logger.info(
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)
ds_type = dataset.format["type"]
if ds_type == "python":
ds_type = None
dataset.set_format(type=ds_type, columns=columns)
dataset.set_format(columns=columns)

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ class TrainingArguments:
at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
forward method.
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
"""

output_dir: str = field(
Expand Down Expand Up @@ -239,10 +234,6 @@ class TrainingArguments:
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)

remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
)

@property
def train_batch_size(self) -> int:
"""
Expand Down

0 comments on commit b67e9b4

Please sign in to comment.