Add language-generation and token-detection training tasks. Closes #414
…. Closes #415.
davidmezzetti committed Jan 31, 2023
1 parent 0bfb596 commit 592f8cf
Showing 5 changed files with 184 additions and 8 deletions.
11 changes: 10 additions & 1 deletion docs/pipeline/train/trainer.md
@@ -35,7 +35,16 @@ model, tokenizer = trainer("bert-base-uncased", dt,
                           learning_rate=3e-5, num_train_epochs=5)
```

All [TrainingArguments](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments) are supported as function arguments to the trainer call.
All [TrainingArguments](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments) are supported as function arguments to the trainer call. The trainer supports building new models and/or fine-tuning existing models for the following training tasks.

| Task | Description |
|:-----|:------------|
| language-generation | Causal language model for text generation (e.g. GPT) |
| language-modeling | Masked language model for general tasks (e.g. BERT) |
| question-answering | Extractive question-answering model, typically with the SQuAD dataset |
| sequence-sequence | Sequence-to-sequence model (e.g. T5) |
| text-classification | Classify text with a set of labels |
| token-detection | ELECTRA-style pre-training with replaced token detection |
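
A training task is selected with the `task` keyword argument. The following is a minimal sketch reusing the `dt` dataset from the example above; the tiny test checkpoints are illustrative placeholders, not recommendations:

```python
from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Causal language model for text generation (GPT-style)
model, tokenizer = trainer("hf-internal-testing/tiny-random-gpt2", dt, task="language-generation")

# ELECTRA-style pre-training with replaced token detection
model, tokenizer = trainer("hf-internal-testing/tiny-random-electra", dt, task="token-detection")
```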

See the links below for more detailed examples.

1 change: 1 addition & 0 deletions src/python/txtai/models/__init__.py
@@ -6,3 +6,4 @@
from .onnx import OnnxModel
from .pooling import MeanPooling, Pooling
from .registry import Registry
from .tokendetection import TokenDetection
121 changes: 121 additions & 0 deletions src/python/txtai/models/tokendetection.py
@@ -0,0 +1,121 @@
"""
Token Detection module
"""

import inspect
import os

import torch

from transformers import PreTrainedModel


class TokenDetection(PreTrainedModel):
"""
Runs the replaced token detection training objective. This method was first proposed by the ELECTRA model.
The method consists of a masked language model generator feeding data to a discriminator that determines
which of the tokens are incorrect. More on this training objective can be found in the ELECTRA paper.
"""

def __init__(self, generator, discriminator, tokenizer, weight=50.0):
"""
Creates a new TokenDetection class.
Args:
generator: Generator model, must be a masked language model
discriminator: Discriminator model, must be a model that can detect replaced tokens. Any model can
can be customized for this task. See ElectraForPretraining for more.
"""

# Initialize model with discriminator config
super().__init__(discriminator.config)

self.generator = generator
self.discriminator = discriminator

# Tokenizer to save with generator and discriminator
self.tokenizer = tokenizer

# Discriminator weight
self.weight = weight

# Share embeddings if both models are the same type
# Embeddings must be same size
if self.generator.config.model_type == self.discriminator.config.model_type:
self.discriminator.set_input_embeddings(self.generator.get_input_embeddings())

# Set attention mask present flags
self.gattention = "attention_mask" in inspect.signature(self.generator.forward).parameters
self.dattention = "attention_mask" in inspect.signature(self.discriminator.forward).parameters

# pylint: disable=E1101
def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
"""
Runs a forward pass through the model. This method runs the masked language model then randomly samples
the generated tokens and builds a binary classification problem for the discriminator (detecting if each token is correct).
Args:
input_ids: token ids
labels: token labels
attention_mask: attention mask
token_type_ids: segment token indices
Returns:
(loss, generator outputs, discriminator outputs, discriminator labels)
"""

# Copy input ids
dinputs = input_ids.clone()

# Run inputs through masked language model
inputs = {"attention_mask": attention_mask} if self.gattention else {}
goutputs = self.generator(input_ids, labels=labels, token_type_ids=token_type_ids, **inputs)

# Get predictions
preds = torch.softmax(goutputs[1], dim=-1)
preds = preds.view(-1, self.config.vocab_size)

tokens = torch.multinomial(preds, 1).view(-1)
tokens = tokens.view(dinputs.shape[0], -1)

# Labels have a -100 value to ignore loss from unchanged tokens
mask = labels.ne(-100)

# Replace the masked out tokens of the input with the generator predictions
dinputs[mask] = tokens[mask]

# Turn mask into new target labels - 1 (True) for corrupted, 0 otherwise.
# If the prediction was correct, mark it as uncorrupted.
correct = tokens == labels
dlabels = mask.long()
dlabels[correct] = 0

# Run token classification, predict whether each token was corrupted
inputs = {"attention_mask": attention_mask} if self.dattention else {}
doutputs = self.discriminator(dinputs, labels=dlabels, token_type_ids=token_type_ids, **inputs)

# Compute combined loss
loss = goutputs[0] + self.weight * doutputs[0]
return loss, goutputs[1], doutputs[1], dlabels

def save_pretrained(self, output, state_dict=None):
"""
Saves current model to output directory.
Args:
output: output directory
state_dict: model state
"""

# Save combined model to support training from checkpoints
super().save_pretrained(output, state_dict)

# Save generator tokenizer and model
gpath = os.path.join(output, "generator")
self.tokenizer.save_pretrained(gpath)
self.generator.save_pretrained(gpath)

# Save discriminator tokenizer and model
dpath = os.path.join(output, "discriminator")
self.tokenizer.save_pretrained(dpath)
self.discriminator.save_pretrained(dpath)
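
For context, a minimal sketch of how this class is wired up, mirroring the trainer change below (the checkpoint name is illustrative; the generator must be a masked language model and the discriminator a pre-training head such as `ElectraForPreTraining`):

```python
from transformers import AutoModelForMaskedLM, AutoModelForPreTraining, AutoTokenizer

from txtai.models import TokenDetection

# Illustrative base checkpoint; the trainer builds both models from the same base
base = "google/electra-small-discriminator"

tokenizer = AutoTokenizer.from_pretrained(base)
model = TokenDetection(
    AutoModelForMaskedLM.from_pretrained(base),
    AutoModelForPreTraining.from_pretrained(base),
    tokenizer,
)
```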
26 changes: 21 additions & 5 deletions src/python/txtai/pipeline/train/hftrainer.py
@@ -7,8 +7,10 @@

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForPreTraining,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
@@ -17,7 +19,7 @@
from transformers import TrainingArguments as HFTrainingArguments

from ...data import Labels, Questions, Sequences, Texts
from ...models import Models
from ...models import Models, TokenDetection
from ..tensors import Tensors


@@ -76,7 +78,13 @@ def __call__(
        collator, labels = None, None

        # Prepare datasets
        if task == "language-modeling":
        if task == "language-generation":
            # Default tokenizer pad token if it's not set
            tokenizer.pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token

            process = Texts(tokenizer, columns, maxlength)
            collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8 if args.fp16 else None)
        elif task in ("language-modeling", "token-detection"):
            process = Texts(tokenizer, columns, maxlength)
            collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
        elif task == "question-answering":
@@ -92,7 +100,7 @@
        train, validation = process(train, validation, os.cpu_count() if tokenizers and isinstance(tokenizers, bool) else tokenizers)

        # Create model to train
        model = self.model(task, base, config, labels)
        model = self.model(task, base, config, labels, tokenizer)

        # Add model to collator
        if collator:
@@ -174,7 +182,7 @@ def load(self, base, maxlength):

        return (config, tokenizer, maxlength)

    def model(self, task, base, config, labels):
    def model(self, task, base, config, labels, tokenizer):
        """
        Loads the base model to train.
@@ -183,6 +191,7 @@ def model(self, task, base, config, labels):
            base: base model - supports a file path or (model, tokenizer) tuple
            config: model configuration
            labels: number of labels
            tokenizer: model tokenizer

        Returns:
@@ -194,15 +203,22 @@

        # pylint: disable=E1120
        # Unpack existing model or create new model from config
        if isinstance(base, (list, tuple)):
        if isinstance(base, (list, tuple)) and not isinstance(base[0], str):
            return base[0]
        if task == "language-generation":
            return AutoModelForCausalLM.from_pretrained(base, config=config)
        if task == "language-modeling":
            return AutoModelForMaskedLM.from_pretrained(base, config=config)
        if task == "question-answering":
            return AutoModelForQuestionAnswering.from_pretrained(base, config=config)
        if task == "sequence-sequence":
            return AutoModelForSeq2SeqLM.from_pretrained(base, config=config)
        if task == "token-detection":
            return TokenDetection(
                AutoModelForMaskedLM.from_pretrained(base, config=config), AutoModelForPreTraining.from_pretrained(base, config=config), tokenizer
            )

        # Default task
        return AutoModelForSequenceClassification.from_pretrained(base, config=config)


33 changes: 31 additions & 2 deletions test/python/testpipeline/testtrainer.py
@@ -38,6 +38,17 @@ def testBasic(self):
        labels = Labels((model, tokenizer), dynamic=False)
        self.assertEqual(labels("cat")[0][0], 1)

    def testCLM(self):
        """
        Tests training a model with causal language modeling.
        """

        trainer = HFTrainer()
        model, _ = trainer("hf-internal-testing/tiny-random-gpt2", self.data, maxlength=16, task="language-generation")

        # Test model completed successfully
        self.assertIsNotNone(model)

    def testCustom(self):
        """
        Test training a model with custom parameters
@@ -175,11 +186,11 @@ def testEmpty(self):

    def testMLM(self):
        """
        Tests training a masked language model.
        Tests training a model with masked language modeling.
        """

        trainer = HFTrainer()
        model, _ = trainer("google/bert_uncased_L-2_H-128_A-2", self.data, task="language-modeling")
        model, _ = trainer("hf-internal-testing/tiny-random-bert", self.data, task="language-modeling")

        # Test model completed successfully
        self.assertIsNotNone(model)
@@ -238,6 +249,24 @@ def testRegression(self):
        # Regression tasks return a single entry with the regression output
        self.assertGreater(labels("cat")[0][1], 0.5)

    def testRTD(self):
        """
        Tests training a language model with replaced token detection
        """

        # Save directory
        output = os.path.join(tempfile.gettempdir(), "trainer.rtd")

        trainer = HFTrainer()
        model, _ = trainer("hf-internal-testing/tiny-random-electra", self.data, task="token-detection", output_dir=output)

        # Test model completed successfully
        self.assertIsNotNone(model)

        # Test output directories exist
        self.assertTrue(os.path.exists(os.path.join(output, "generator")))
        self.assertTrue(os.path.exists(os.path.join(output, "discriminator")))

    def testSeqSeq(self):
        """
        Tests training a sequence-sequence model
