[WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests (#5614)

* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests

* AutoModels


Tiny tweaks

* Style

* Final changes before merge

* Re-order for simpler review

* Final fixes

* Addressing @sgugger's comments

* Test MultipleChoice
LysandreJik committed Jul 29, 2020
1 parent 8a8ae27 commit 3f94170
Showing 12 changed files with 652 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/transformers/__init__.py
@@ -278,6 +278,7 @@
XLMForTokenClassification,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLMForMultipleChoice,
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_bart import (
@@ -356,6 +357,8 @@
FlaubertForTokenClassification,
FlaubertForQuestionAnswering,
FlaubertForQuestionAnsweringSimple,
FlaubertForTokenClassification,
FlaubertForMultipleChoice,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)

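With these re-exports in place, the new heads become importable from the package root. A minimal sketch (assumes a PyTorch installation of this transformers version):

from transformers import (
    XLMForMultipleChoice,
    FlaubertForMultipleChoice,
    FlaubertForTokenClassification,
)

# The classes are defined in the modules patched in this commit.
print(XLMForMultipleChoice.__module__)       # transformers.modeling_xlm
print(FlaubertForMultipleChoice.__module__)  # transformers.modeling_flaubert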
5 changes: 5 additions & 0 deletions src/transformers/modeling_auto.py
@@ -98,6 +98,7 @@
)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_flaubert import (
FlaubertForMultipleChoice,
FlaubertForQuestionAnsweringSimple,
FlaubertForSequenceClassification,
FlaubertForTokenClassification,
@@ -142,6 +143,7 @@
from .modeling_t5 import T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMForTokenClassification,
@@ -338,6 +340,7 @@
(XLNetConfig, XLNetForTokenClassification),
(AlbertConfig, AlbertForTokenClassification),
(ElectraConfig, ElectraForTokenClassification),
(FlaubertConfig, FlaubertForTokenClassification),
]
)

@@ -353,6 +356,8 @@
(MobileBertConfig, MobileBertForMultipleChoice),
(XLNetConfig, XLNetForMultipleChoice),
(AlbertConfig, AlbertForMultipleChoice),
(XLMConfig, XLMForMultipleChoice),
(FlaubertConfig, FlaubertForMultipleChoice),
]
)

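These mapping entries are what let the Auto classes dispatch a FlaubertConfig or XLMConfig to the new heads. A minimal sketch of the lookup (the checkpoint name is one of the published FlauBERT checkpoints; from_config loads no pretrained weights, so the heads start randomly initialized):

from transformers import AutoConfig, AutoModelForMultipleChoice, AutoModelForTokenClassification

# The config class is the key into the MODEL_FOR_*_MAPPING dicts extended above.
config = AutoConfig.from_pretrained("flaubert/flaubert_base_cased")

mc_head = AutoModelForMultipleChoice.from_config(config)        # -> FlaubertForMultipleChoice
tok_head = AutoModelForTokenClassification.from_config(config)  # -> FlaubertForTokenClassification
print(type(mc_head).__name__, type(tok_head).__name__)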
20 changes: 20 additions & 0 deletions src/transformers/modeling_flaubert.py
@@ -25,6 +25,7 @@
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutput
from .modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
@@ -382,3 +383,22 @@ def __init__(self, config):
super().__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()


@add_start_docstrings(
"""Flaubert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
FLAUBERT_START_DOCSTRING,
)
class FlaubertForMultipleChoice(XLMForMultipleChoice):
"""
This class overrides :class:`~transformers.XLMForMultipleChoice`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""

config_class = FlaubertConfig

def __init__(self, config):
super().__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()
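Because the new class only swaps the backbone for FlaubertModel, it keeps the XLMForMultipleChoice forward signature, which takes inputs of shape (batch_size, num_choices, sequence_length). A small randomly initialized sketch (config values are arbitrary, not a released checkpoint):

import torch
from transformers import FlaubertConfig, FlaubertForMultipleChoice

config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=2)
model = FlaubertForMultipleChoice(config)

batch_size, num_choices, seq_len = 2, 4, 8
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
outputs = model(input_ids=input_ids)
print(outputs[0].shape)  # classification scores, expected shape (batch_size, num_choices)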
36 changes: 25 additions & 11 deletions src/transformers/modeling_tf_flaubert.py
@@ -22,14 +22,15 @@

from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_utils import keras_serializable, shape_list
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
from .modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMMainLayer,
TFXLMModel,
TFXLMPredLayer,
TFXLMWithLMHeadModel,
get_masks,
)
@@ -123,6 +124,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.layerdrop = getattr(config, "layerdrop", 0.0)
self.pre_norm = getattr(config, "pre_norm", False)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states

def call(
self,
Expand All @@ -135,9 +138,9 @@ def call(
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
output_attentions=False,
output_hidden_states=False,
):
# removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
@@ -150,7 +153,9 @@
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@@ -161,10 +166,15 @@
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs

output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@@ -257,9 +267,12 @@ def call(

# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
attn_outputs = self.attentions[i](
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
)
attn = attn_outputs[0]
attentions = attentions + (attn_outputs[1],)
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
@@ -269,7 +282,7 @@ def call(
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
)
attn = attn_outputs[0]
if output_attentions:
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
@@ -292,7 +305,7 @@ def call(
tensor = tensor * mask[..., tf.newaxis]

# Add last hidden state
if output_hidden_states:
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
hidden_states = hidden_states + (tensor,)

# update cache length
@@ -303,9 +316,9 @@ def call(
# tensor = tensor.transpose(0, 1)

outputs = (tensor,)
if output_hidden_states:
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
outputs = outputs + (hidden_states,)
if output_attentions:
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)

@@ -321,6 +334,7 @@ class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")


@add_start_docstrings(
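The main-layer changes thread output_attentions / output_hidden_states through call() and resolve them against the config defaults, with cast_bool_to_primitive guarding the tuple building so that flags arriving as TF tensors (e.g. during Keras serialization or graph tracing) fall back to the config value. A small eager-mode sketch with a randomly initialized model (arbitrary config values, not a released checkpoint; assumes the kwargs are forwarded to the main layer as in TFXLMModel):

import tensorflow as tf
from transformers import FlaubertConfig, TFFlaubertModel

config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=2, n_heads=2)
model = TFFlaubertModel(config)

input_ids = tf.constant([[4, 7, 1, 12, 9]])
# Per-call flags now override config.output_attentions / config.output_hidden_states.
outputs = model(input_ids, output_attentions=True, output_hidden_states=True)

sequence_output = outputs[0]
print(sequence_output.shape)  # expected (1, 5, 32)
print(len(outputs))           # expected 3: last hidden state, all hidden states, attentions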
30 changes: 26 additions & 4 deletions src/transformers/modeling_tf_xlm.py
@@ -19,6 +19,7 @@
import itertools
import logging
import math
import warnings

import numpy as np
import tensorflow as tf
@@ -827,6 +828,9 @@ def __init__(self, config, *inputs, **kwargs):

self.transformer = TFXLMMainLayer(config, name="transformer")
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
self.logits_proj = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)

@property
def dummy_inputs(self):
@@ -835,7 +839,10 @@ def dummy_inputs(self):
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
}

@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@@ -892,7 +899,7 @@ def call(
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
labels = inputs[11] if len(inputs) > 11 else labels
assert len(inputs) <= 11, "Too many inputs."
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@@ -921,24 +928,39 @@
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))
if inputs_embeds is not None
else None
)

if lengths is not None:
warnings.warn(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.",
FutureWarning,
)
lengths = None

flat_inputs = [
flat_input_ids,
flat_attention_mask,
langs,
flat_langs,
flat_token_type_ids,
flat_position_ids,
lengths,
cache,
head_mask,
inputs_embeds,
flat_inputs_embeds,
output_attentions,
output_hidden_states,
]

transformer_outputs = self.transformer(flat_inputs, training=training)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))

outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
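TFXLMForMultipleChoice follows the usual multiple-choice pattern: the per-example, per-choice sequences are flattened into one batch for the shared transformer, a single score per choice is produced (sequence_summary followed by the new logits_proj layer), and the choice dimension is restored at the end. A standalone shape sketch of that reshaping (plain TF, no model weights involved):

import tensorflow as tf

batch_size, num_choices, seq_len, hidden = 2, 4, 8, 16

# (batch, num_choices, seq_len) -> (batch * num_choices, seq_len),
# mirroring the flat_input_ids / flat_attention_mask / flat_langs lines above.
input_ids = tf.zeros((batch_size, num_choices, seq_len), dtype=tf.int32)
flat_input_ids = tf.reshape(input_ids, (-1, seq_len))    # (8, 8)

# Stand-in for transformer + sequence_summary + logits_proj: one score per flattened row.
pooled = tf.random.normal((batch_size * num_choices, hidden))
logits = tf.keras.layers.Dense(1)(pooled)                # (8, 1)

# Restore the choice dimension so a softmax can rank the choices per example.
reshaped_logits = tf.reshape(logits, (-1, num_choices))  # (2, 4)
print(flat_input_ids.shape, reshaped_logits.shape)

The new warning about lengths reflects that the flattened layout no longer lines up with per-example lengths, so padding information should be passed through the attention mask instead.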
[Diffs for the remaining 7 of the 12 changed files, including the new tests, were not loaded on this page.]
