adding output attention feature for better interpretation #16

Open · wants to merge 2 commits into base: master
15 changes: 14 additions & 1 deletion esim/layers.py
@@ -140,7 +140,8 @@ def forward(self,
premise_batch,
premise_mask,
hypothesis_batch,
hypothesis_mask):
hypothesis_mask,
output_attentions=False):
"""
Args:
premise_batch: A batch of sequences of vectors representing the
@@ -155,12 +156,23 @@ def forward(self,
hypothesis_mask: A mask for the sequences in the hypotheses batch,
to ignore padding data in the sequences during the computation
of the attention.
output_attentions: If True, also return the softmaxed attention
matrices computed for the premises and hypotheses during
cross-attention. Defaults to False.

Returns:
attended_premises: The sequences of attention vectors for the
premises in the input batch.
attended_hypotheses: The sequences of attention vectors for the
hypotheses in the input batch.

If output_attentions is True, the following are also returned:

hyp_prem_attn: The attention weights for each hypothesis token,
softmaxed over all premise tokens, i.e. a masked softmax using
the premise mask.

prem_hyp_attn: The attention weights for each premise token,
softmaxed over all hypothesis tokens, i.e. a masked softmax using
the hypothesis mask.

"""
# Dot product between premises and hypotheses in each sequence of
# the batch.
@@ -182,4 +194,5 @@
hyp_prem_attn,
hypothesis_mask)

if output_attentions:
    return attended_premises, attended_hypotheses, hyp_prem_attn, prem_hyp_attn
return attended_premises, attended_hypotheses
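
A minimal usage sketch at the layer level (an illustration, not part of the PR: it assumes the attention module in esim/layers.py is the SoftmaxAttention class from the upstream repository, and the tensor sizes are made up):

import torch
from esim.layers import SoftmaxAttention  # assumed class name from the upstream repo

attention = SoftmaxAttention()
# Encoded premises/hypotheses: (batch, seq_len, 2 * hidden_size); hidden_size of 300 assumed.
premises = torch.randn(8, 20, 600)
hypotheses = torch.randn(8, 15, 600)
premises_mask = torch.ones(8, 20)
hypotheses_mask = torch.ones(8, 15)

att_premises, att_hypotheses, hyp_prem_attn, prem_hyp_attn = attention(
    premises, premises_mask, hypotheses, hypotheses_mask, output_attentions=True)

# Per the docstring above, hyp_prem_attn is softmaxed over premise tokens and
# prem_hyp_attn over hypothesis tokens, so each can be plotted as a heatmap
# to inspect which tokens align across the two sentences.
print(hyp_prem_attn.shape, prem_hyp_attn.shape)
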
21 changes: 18 additions & 3 deletions esim/model.py
@@ -24,6 +24,7 @@ def __init__(self,
padding_idx=0,
dropout=0.5,
num_classes=3,
output_attentions=False,
device="cpu"):
"""
Args:
@@ -40,6 +41,8 @@
Defaults to 0.5.
num_classes: The number of classes in the output of the network.
Defaults to 3.
output_attentions: If True, the model also returns the cross-attention
weights computed for the premises and hypotheses. Defaults to False.
device: The name of the device on which the model is being
executed. Defaults to 'cpu'.
"""
@@ -50,6 +53,7 @@
self.hidden_size = hidden_size
self.num_classes = num_classes
self.dropout = dropout
self.output_attn = output_attentions
self.device = device

self._word_embedding = nn.Embedding(self.vocab_size,
@@ -128,9 +132,17 @@ def forward(self,
encoded_hypotheses = self._encoding(embedded_hypotheses,
hypotheses_lengths)

attended_premises, attended_hypotheses =\
self._attention(encoded_premises, premises_mask,
encoded_hypotheses, hypotheses_mask)
if self.output_attn:
    attended_premises, attended_hypotheses, hyp_attn, prem_attn = \
        self._attention(encoded_premises, premises_mask,
                        encoded_hypotheses, hypotheses_mask,
                        output_attentions=self.output_attn)
else:
    attended_premises, attended_hypotheses = \
        self._attention(encoded_premises, premises_mask,
                        encoded_hypotheses, hypotheses_mask)

enhanced_premises = torch.cat([encoded_premises,
attended_premises,
@@ -170,9 +182,12 @@ def forward(self,
logits = self._classification(v)
probabilities = nn.functional.softmax(logits, dim=-1)

if self.output_attn:
    # Return the cross-attention weights captured from the attention layer.
    return logits, probabilities, hyp_attn, prem_attn
return logits, probabilities



def _init_esim_weights(module):
"""
Initialise the weights of the ESIM model.
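
A minimal end-to-end sketch at the model level (again illustrative: the constructor arguments other than output_attentions, the forward signature, and the input tensors are assumptions based on the upstream ESIM implementation):

import torch
from esim.model import ESIM

model = ESIM(vocab_size=10000,      # assumed constructor arguments
             embedding_dim=300,
             hidden_size=300,
             num_classes=3,
             output_attentions=True,
             device="cpu")

# Dummy batches of token ids (padding_idx = 0) and their lengths.
premises = torch.randint(1, 10000, (8, 20))
premises_lengths = torch.full((8,), 20, dtype=torch.long)
hypotheses = torch.randint(1, 10000, (8, 15))
hypotheses_lengths = torch.full((8,), 15, dtype=torch.long)

logits, probabilities, hyp_attn, prem_attn = model(premises,
                                                   premises_lengths,
                                                   hypotheses,
                                                   hypotheses_lengths)

# With output_attentions=True the model returns the two cross-attention
# matrices in addition to the logits and class probabilities, which can be
# visualised to interpret which tokens the model aligns across sentences.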