-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
36 lines (26 loc) · 1.15 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Author: lqxu
from torch import nn
from core.models import BaseConfig, BaseModel
from core.modules import EfficientGlobalPointer
class GPLinkerEEConfig(BaseConfig):
    """Configuration holder for ``GPLinkerEEModel``.

    Args:
        n_argument_labels: Number of tags given to the model's argument
            ``EfficientGlobalPointer`` (its ``num_tags``).
        head_size: Forwarded as ``head_size`` to every
            ``EfficientGlobalPointer`` the model builds.
        dropout: Drop probability for the model's ``nn.Dropout`` layer.
        **kwargs: Passed through unchanged to ``BaseConfig.__init__``.
    """

    def __init__(
        self,
        n_argument_labels: int,
        head_size: int = 64,
        dropout: float = 0.3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Plain attribute storage; nothing is validated here.
        self.head_size = head_size
        self.dropout = dropout
        self.n_argument_labels = n_argument_labels
class GPLinkerEEModel(BaseModel):
    """Model assembling three ``EfficientGlobalPointer`` scorers plus dropout.

    Sub-modules (registered in this order):
      * ``dropout`` -- ``nn.Dropout`` with probability ``config.dropout``.
      * ``argument_classifier`` -- ``config.n_argument_labels`` tags, RoPE on.
      * ``head_classifier`` -- a single tag, RoPE off.
      * ``tail_classifier`` -- a single tag, RoPE off.

    NOTE(review): no ``forward`` is defined in this file; presumably it is
    provided by ``BaseModel`` or elsewhere -- confirm.
    """

    def __init__(self, config: GPLinkerEEConfig):
        super().__init__(config)

        def single_tag_pointer():
            # Head and tail scorers share an identical configuration.
            return EfficientGlobalPointer(
                config=config.bert_config,
                num_tags=1,
                head_size=config.head_size,
                use_rope=False,
            )

        # Assignment order fixes nn.Module registration order (state_dict
        # keys / parameter iteration), so it mirrors the original exactly.
        # `bert_config` is not set by GPLinkerEEConfig itself; assumed to be
        # supplied by BaseConfig -- TODO confirm.
        self.dropout = nn.Dropout(config.dropout)
        self.argument_classifier = EfficientGlobalPointer(
            config=config.bert_config,
            num_tags=config.n_argument_labels,
            head_size=config.head_size,
            use_rope=True,
        )
        self.head_classifier = single_tag_pointer()
        self.tail_classifier = single_tag_pointer()