
Question about processing entity output in model #4

Open
liuyaduo opened this issue Aug 5, 2020 · 1 comment

Comments

liuyaduo commented Aug 5, 2020

Hi!
In the original paper, I found that they apply an activation and add a fully connected layer after the average operation to get a vector representation for each of the two target entities.
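As I understand it, that corresponds to computing roughly e_i = W · tanh(avg(H_start, ..., H_end)) + b for each entity, i.e. averaging the entity's token embeddings, then applying a tanh activation and a fully connected layer. In model.py, however, the entity output seems to be computed only as: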

def extract_entity(sequence_output, e_mask):
    # (batch, seq_len) -> (batch, 1, seq_len)
    extended_e_mask = e_mask.unsqueeze(1)
    # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, hidden)
    extended_e_mask = torch.bmm(
        extended_e_mask.float(), sequence_output).squeeze(1)
    return extended_e_mask.float()

e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
context = self.cls_dropout(pooled_output)
# concatenate the [CLS] context with the two entity vectors
pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)
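For reference, here is a tiny standalone check of what that mask multiplication computes (the tensor sizes are made up, and I am assuming the mask stores 1/len weights over the entity span so the bmm reproduces an average; with a plain 0/1 mask it would be a sum instead):

import torch

def extract_entity(sequence_output, e_mask):
    extended_e_mask = e_mask.unsqueeze(1)
    extended_e_mask = torch.bmm(
        extended_e_mask.float(), sequence_output).squeeze(1)
    return extended_e_mask.float()

batch, seq_len, hidden = 1, 6, 4                    # toy sizes, not from the repo
sequence_output = torch.randn(batch, seq_len, hidden)
e_mask = torch.zeros(batch, seq_len)
e_mask[0, 2:4] = 1.0 / 2                            # entity covers tokens 2-3, weights 1/len
e1_h = extract_entity(sequence_output, e_mask)      # shape (1, hidden)
assert torch.allclose(e1_h, sequence_output[0, 2:4].mean(dim=0, keepdim=True), atol=1e-6)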

Why can't I find the activation and the fully connected layer in model.py?

mickeysjm (Owner) commented

Hi @liuyaduo,

Thanks for the detailed comparison. Indeed, this code does not include the additional fully connected layer + activation function. You can easily add it as follows:

def __init__(self, config):
    super(BertForSequenceClassification, self).__init__(config)
    self.num_labels = config.num_labels
    self.bert = BertModel(config)
    self.cls_dropout = nn.Dropout(0.1)  # dropout on the transformed [CLS] token embedding
    self.ent_dropout = nn.Dropout(0.5)  # dropout on the averaged entity embeddings
    # added: fully connected layer + tanh activation applied to each entity vector
    self.ffn = nn.Linear(config.hidden_size, config.hidden_size)
    self.activation = nn.Tanh()
    # [CLS] context + two entity vectors are concatenated, hence hidden_size * 3
    self.classifier = nn.Linear(config.hidden_size * 3, self.config.num_labels)
    self.init_weights()

def forward(self, ......):
    ...
    e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
    e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
    # added: tanh activation followed by the fully connected layer on each entity vector
    e1_h = self.ffn(self.activation(e1_h))
    e2_h = self.ffn(self.activation(e2_h))
    context = self.cls_dropout(pooled_output)
    pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)
    ...
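Just to make the dimensions explicit (the sizes below are arbitrary toy values, not taken from the config): the classifier takes hidden_size * 3 features because the transformed [CLS] context and the two entity vectors each have hidden_size dimensions:

import torch
import torch.nn as nn

hidden_size, num_labels, batch = 8, 5, 2      # arbitrary toy values
context = torch.randn(batch, hidden_size)     # dropout-ed pooled [CLS] output
e1_h = torch.randn(batch, hidden_size)        # entity 1 after tanh + ffn
e2_h = torch.randn(batch, hidden_size)        # entity 2 after tanh + ffn
pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)

classifier = nn.Linear(hidden_size * 3, num_labels)
logits = classifier(pooled_output)            # shape (batch, num_labels)
assert pooled_output.shape == (batch, hidden_size * 3)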

I am not sure whether this will improve the performance, but you can easily try it.

Hope this answers your question.
