forked from flairNLP/flair
/
nn.py
188 lines (149 loc) · 6.25 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import List, Tuple, Union

import torch.nn

import flair
from flair.data import Sentence
from flair.training_utils import Result
class Model(torch.nn.Module):
    """Abstract base class for all downstream task models in Flair, such as
    SequenceTagger and TextClassifier. Every new type of model must implement
    these methods."""

    @abstractmethod
    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
        """Performs a forward pass and returns a loss tensor for backpropagation.
        Implement this to enable training."""
        pass

    @abstractmethod
    def predict(
        self, sentences: Union[List[Sentence], Sentence], mini_batch_size: int = 32
    ) -> List[Sentence]:
        """Predicts the labels/tags for the given list of sentences. The
        labels/tags are added directly to the sentences. Implement this to
        enable prediction."""
        pass

    @abstractmethod
    def evaluate(
        self,
        sentences: List[Sentence],
        eval_mini_batch_size: int = 32,
        embeddings_in_memory: bool = False,
        out_path: Path = None,
        num_workers: int = 8,
    ) -> Tuple[Result, float]:
        """Evaluates the model on a list of gold-labeled Sentences. Returns a
        Result object containing evaluation results and a loss value.
        Implement this to enable evaluation."""
        pass

    @abstractmethod
    def _get_state_dict(self):
        """Returns the state dictionary for this model. Implementing this
        enables the save() and save_checkpoint() functionality."""
        pass

    # defined without self and only ever called via cls in load()/load_checkpoint(),
    # so mark it static to prevent accidental mis-binding on instance access
    @staticmethod
    @abstractmethod
    def _init_model_with_state_dict(state):
        """Initialize the model from a state dictionary. Implementing this
        enables the load() and load_checkpoint() functionality."""
        pass

    @staticmethod
    @abstractmethod
    def _fetch_model(model_name) -> str:
        """Resolves a model name to a loadable model file path. The default
        implementation returns the name unchanged."""
        return model_name

    def save(self, model_file: Union[str, Path]):
        """
        Saves the current model to the provided file.
        :param model_file: the model file
        """
        model_state = self._get_state_dict()
        # pickle_protocol=4 supports serializing objects larger than 4 GB
        torch.save(model_state, str(model_file), pickle_protocol=4)

    def save_checkpoint(
        self,
        model_file: Union[str, Path],
        optimizer_state: dict,
        scheduler_state: dict,
        epoch: int,
        loss: float,
    ):
        """
        Saves the current model together with training state to the provided file.
        :param model_file: the checkpoint file
        :param optimizer_state: optimizer state dict to store
        :param scheduler_state: scheduler state dict to store
        :param epoch: epoch number to store
        :param loss: loss value to store
        """
        model_state = self._get_state_dict()
        # additional fields for model checkpointing
        model_state["optimizer_state_dict"] = optimizer_state
        model_state["scheduler_state_dict"] = scheduler_state
        model_state["epoch"] = epoch
        model_state["loss"] = loss
        torch.save(model_state, str(model_file), pickle_protocol=4)

    @classmethod
    def _load_state_and_model(cls, model_file: Union[str, Path]):
        """Loads a raw state dict from file (suppressing torch warnings),
        builds a model from it in eval mode on the flair device, and returns
        both the model and the raw state. Shared by load() and load_checkpoint()."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
            # see https://github.com/zalandoresearch/flair/issues/351
            f = flair.file_utils.load_big_file(str(model_file))
            state = torch.load(f, map_location=flair.device)
        model = cls._init_model_with_state_dict(state)
        model.eval()
        model.to(flair.device)
        return model, state

    @classmethod
    def load(cls, model: Union[str, Path]):
        """
        Loads the model from the given file.
        :param model: the model name or file
        :return: the loaded model
        """
        model_file = cls._fetch_model(str(model))
        loaded_model, _ = cls._load_state_and_model(model_file)
        return loaded_model

    @classmethod
    def load_checkpoint(cls, checkpoint_file: Union[str, Path]):
        """
        Loads a training checkpoint from the given file.
        :param checkpoint_file: the checkpoint file
        :return: dict with the model plus stored epoch, loss, optimizer and
                 scheduler states (each None when absent from the checkpoint)
        """
        model, state = cls._load_state_and_model(checkpoint_file)
        return {
            "model": model,
            "epoch": state.get("epoch"),
            "loss": state.get("loss"),
            "optimizer_state_dict": state.get("optimizer_state_dict"),
            "scheduler_state_dict": state.get("scheduler_state_dict"),
        }
class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Randomly drops out
    entire parameters in embedding space: a single Bernoulli mask over
    dims (1, 2) is broadcast across dim 0, so the same units are dropped
    at every position along the first dimension.
    """

    def __init__(self, dropout_rate=0.5, inplace=False):
        super(LockedDropout, self).__init__()
        # probability of zeroing a unit
        self.dropout_rate = dropout_rate
        # NOTE(review): inplace is stored and reported in extra_repr but not
        # used by forward — kept for interface compatibility
        self.inplace = inplace

    def forward(self, x):
        # no-op in eval mode or when dropout is disabled
        if not self.training or not self.dropout_rate:
            return x
        keep_prob = 1 - self.dropout_rate
        # torch.autograd.Variable is deprecated; x.new_empty(...) creates the
        # mask on the same device/dtype without the legacy wrapper
        mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(keep_prob)
        # inverted dropout: rescale survivors so the expected value is unchanged
        mask = mask.div_(keep_prob).expand_as(x)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or
    characters) in embedding space: each position along dim 0 is zeroed
    entirely with probability dropout_rate. Unlike standard (inverted)
    dropout, survivors are NOT rescaled.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super(WordDropout, self).__init__()
        # probability of zeroing a whole word/position
        self.dropout_rate = dropout_rate
        # NOTE(review): inplace is stored and reported in extra_repr but not
        # used by forward — kept for interface compatibility
        self.inplace = inplace

    def forward(self, x):
        # no-op in eval mode or when dropout is disabled
        if not self.training or not self.dropout_rate:
            return x
        # torch.autograd.Variable is deprecated; x.new_empty(...) creates the
        # mask on the same device/dtype without the legacy wrapper.
        # One Bernoulli draw per dim-0 position, broadcast over dims 1 and 2.
        mask = x.new_empty(x.size(0), 1, 1).bernoulli_(1 - self.dropout_rate)
        return mask.expand_as(x) * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)