# trainer.py
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange

from nlp_practice.case.translation.training.utils import (
create_masks,
create_padding_masks,
)
from nlp_practice.model.decoder import Decoder
from nlp_practice.model.encoder import EncoderRNN
from nlp_practice.model.transformer import Seq2SeqTransformer


class Seq2SeqTrainer:
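    """Trainer for an RNN encoder-decoder (seq2seq) translation model."""
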
def __init__(
self,
train_dataloader: DataLoader,
encoder: EncoderRNN,
decoder: Decoder,
num_epochs: int,
learning_rate: float,
print_log_frequency: int = 10,
):
self.train_dataloader = train_dataloader
self.encoder = encoder
self.decoder = decoder
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.print_log_frequency = print_log_frequency
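        # NLLLoss expects log-probabilities, so the decoder is assumed to
        # apply log_softmax in its output layer.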
self._criterion = nn.NLLLoss()
self._encoder_optimizer = optim.Adam(
self.encoder.parameters(), lr=self.learning_rate
)
self._decoder_optimizer = optim.Adam(
self.decoder.parameters(), lr=self.learning_rate
)

    def _train_per_epoch(self) -> float:
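        """Run one epoch over the dataloader and return the mean batch loss."""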
total_loss = 0
num_batches = len(self.train_dataloader)
if num_batches == 0:
raise ValueError("Empty dataloader. Cannot train without any batches.")
for input_tensor, target_tensor in self.train_dataloader:
self._encoder_optimizer.zero_grad()
self._decoder_optimizer.zero_grad()
encoder_outputs, encoder_hidden = self.encoder(input_tensor)
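            # Passing the full target sequence enables teacher forcing in the
            # decoder (an assumption based on the Decoder interface here).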
decoder_outputs, _, _ = self.decoder(
encoder_outputs, encoder_hidden, target_tensor
)
loss = self._criterion(
decoder_outputs.view(-1, decoder_outputs.size(-1)),
target_tensor.view(-1),
)
loss.backward()
self._encoder_optimizer.step()
self._decoder_optimizer.step()
total_loss += loss.item()
return total_loss / num_batches

    def train(self) -> list[float]:
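        """Train for `num_epochs` epochs and return the mean loss per epoch."""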
return [self._train_per_epoch() for _ in trange(self.num_epochs)]


class TransformerTrainer:
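    """Trainer for a Transformer-based sequence-to-sequence translation model."""
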
def __init__(
self,
train_dataloader: DataLoader,
transformer: Seq2SeqTransformer,
num_epochs: int,
learning_rate: float,
print_log_frequency: int = 10,
):
self.train_dataloader = train_dataloader
self.transformer = transformer
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.print_log_frequency = print_log_frequency
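        # NLLLoss expects log-probabilities; if Seq2SeqTransformer returns raw
        # logits (as the variable name "logits" below suggests), then
        # nn.CrossEntropyLoss would be the appropriate criterion instead.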
self._criterion = nn.NLLLoss()
self._optimizer = optim.Adam(
self.transformer.parameters(), lr=self.learning_rate
)

    def _train_per_epoch(self) -> float:
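        """Run one epoch over the dataloader and return the mean batch loss."""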
total_loss = 0
num_batches = len(self.train_dataloader)
if num_batches == 0:
raise ValueError("Empty dataloader. Cannot train without any batches.")
for input_tensor, target_tensor in self.train_dataloader:
self._optimizer.zero_grad()
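            # Shift targets by one token: positions [0, T-1) feed the decoder,
            # positions [1, T) are the prediction targets.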
target_input, target_output = target_tensor[:, :-1], target_tensor[:, 1:]
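            # create_masks is assumed to return an all-pass attention mask for
            # the source and a causal (look-ahead) mask for the shifted target;
            # the padding masks mark PAD positions so attention ignores them.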
input_mask, output_mask = create_masks(input_tensor, target_input)
input_padding_mask, output_padding_mask = create_padding_masks(
input_tensor, target_input
)
logits = self.transformer(
input=input_tensor,
output=target_input,
input_mask=input_mask,
output_mask=output_mask,
memory_mask=None,
input_padding_mask=input_padding_mask,
output_padding_mask=output_padding_mask,
memory_key_padding_mask=None,
)
loss = self._criterion(
logits.reshape(-1, logits.size(-1)),
target_output.reshape(-1),
)
loss.backward()
self._optimizer.step()
total_loss += loss.item()
return total_loss / num_batches

    def train(self) -> list[float]:
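        """Train for `num_epochs` epochs and return the mean loss per epoch."""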
return [self._train_per_epoch() for _ in trange(self.num_epochs)]
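

# Usage sketch (illustrative, not part of the original file): the dataset,
# batch size, and hyperparameters below are assumptions made for the sake of
# example; the real values live elsewhere in this repository.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(translation_dataset, batch_size=32, shuffle=True)
#     model = Seq2SeqTransformer(...)  # constructor arguments omitted
#     trainer = TransformerTrainer(
#         train_dataloader=loader,
#         transformer=model,
#         num_epochs=10,
#         learning_rate=1e-4,
#     )
#     epoch_losses = trainer.train()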