Skip to content

Commit

Permalink
Merge pull request #207 from namisan/xiaodl/tf-upgrade
Browse files Browse the repository at this point in the history
fix adv train
  • Loading branch information
namisan committed Feb 16, 2021
2 parents 7ef6fd9 + 1f43b6e commit 471f717
Showing 1 changed file with 4 additions and 66 deletions.
70 changes: 4 additions & 66 deletions mt_dnn/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import torch
import torch.nn as nn
from pretrained_models import MODEL_CLASSES
from transformers import BertConfig

from module.dropout_wrapper import DropoutWrapper
from module.san import SANClassifier, MaskLmHeader
from module.san_model import SanModel
Expand Down Expand Up @@ -96,86 +94,26 @@ def __init__(self, opt, bert_config=None, initial_from_local=False):


def embed_encode(self, input_ids, token_type_ids=None, attention_mask=None):
# support BERT now
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
embedding_output = self.bert.embeddings(input_ids, token_type_ids)
return embedding_output


def encode(self, input_ids, token_type_ids, attention_mask):
def encode(self, input_ids, token_type_ids, attention_mask, inputs_embeds=None):
if self.encoder_type == EncoderModelType.T5:
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
else:
outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask)
attention_mask=attention_mask, inputs_embeds=inputs_embeds)
last_hidden_state = outputs.last_hidden_state
all_hidden_states = outputs.hidden_states # num_layers + 1 (embeddings)
return last_hidden_state, all_hidden_states

def embed_forward(self, embed, attention_mask=None, output_all_encoded_layers=True):
device = embed.device
input_shape = embed.size()[:-1]
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.bert.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
encoder_extended_attention_mask = None

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = [None] * self.bert.config.num_hidden_layers

#head_mask = self.bert.get_head_mask(head_mask, self.bert.config.num_hidden_layers)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
#extended_attention_mask = self.bert.get_extended_attention_mask(
# attention_mask, input_shape, device
#)
encoder_outputs = self.bert.encoder(
embed,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
)
sequence_output = encoder_outputs[0]
pooled_output = self.bert.pooler(sequence_output)
outputs = sequence_output, pooled_output
return outputs

def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0, fwd_type=0, embed=None):
if fwd_type == 2:
assert embed is not None
sequence_output, pooled_output = self.embed_forward(embed, attention_mask)
last_hidden_state, all_hidden_states = self.encode(None, token_type_ids, attention_mask, embed)
elif fwd_type == 1:
return self.embed_encode(input_ids, token_type_ids, attention_mask)
else:
Expand Down

0 comments on commit 471f717

Please sign in to comment.