
[#22] learning rate scheduler added
eubinecto committed Jun 4, 2022
1 parent e1f95cc commit 39bffac
Showing 6 changed files with 65 additions and 43 deletions.
12 changes: 6 additions & 6 deletions Makefile
@@ -7,21 +7,21 @@ kor2eng:
 
 train:
 	python3 main_train.py \
-	--max_epochs=60 \
-	--save_top_k=5 \
+	--max_epochs=1000 \
+	--batch_size=128 \
 	--save_on_train_epoch_end=1 \
 	--every_n_epochs=1 \
-	--log_every_n_steps=2 \
+	--log_every_n_steps=10 \
 	--check_val_every_n_epoch=1
 
 train_check:
 	python3 main_train.py \
 	--fast_dev_run \
-	--max_epochs=60 \
-	--save_top_k=5 \
+	--max_epochs=1000 \
+	--batch_size=3 \
 	--save_on_train_epoch_end=1 \
 	--every_n_epochs=1 \
-	--log_every_n_steps=2 \
+	--log_every_n_steps=10 \
 	--check_val_every_n_epoch=1


2 changes: 1 addition & 1 deletion cleanformer/callbacks.py
@@ -129,7 +129,7 @@ def on_any_epoch_end(self, key: str):
             columns=["input", "prediction", "answer", "losses"],
             data=list(zip(inputs, predictions, answers, losses)),
         )
-        self.logger.log_metrics({"Train/BLEU": float(F.bleu_score(answers, predictions))})
+        self.logger.log_metrics({f"{key}/BLEU": float(F.bleu_score(answers, predictions))})
 
     def on_train_epoch_end(self, *args, **kwargs):
         self.on_any_epoch_end("Train")
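Note (not part of the commit): the only change in callbacks.py is the metric key, which previously pinned every BLEU score to the Train/ namespace even when the hook ran at the end of a validation or test epoch. A minimal sketch of the before/after behaviour, using a made-up stand-in for the wandb logger:

class FakeLogger:
    """Stand-in for the wandb logger; just prints what would be logged."""
    def log_metrics(self, metrics: dict) -> None:
        print(metrics)

def on_any_epoch_end(logger: FakeLogger, key: str, bleu: float) -> None:
    logger.log_metrics({f"{key}/BLEU": bleu})  # was hard-coded to "Train/BLEU"

on_any_epoch_end(FakeLogger(), "Train", 0.31)       # {'Train/BLEU': 0.31}
on_any_epoch_end(FakeLogger(), "Validation", 0.27)  # {'Validation/BLEU': 0.27}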
32 changes: 26 additions & 6 deletions cleanformer/models/transformer.py
@@ -2,11 +2,12 @@
 import torch  # noqa
 from pytorch_lightning import LightningModule
 from tokenizers import Tokenizer  # noqa
+from torch.optim.lr_scheduler import ReduceLROnPlateau  # noqa
+from torchmetrics import functional as metricsF  # noqa
+from torch.nn import functional as torchF  # noqa
 from cleanformer.models.decoder import Decoder
 from cleanformer.models.encoder import Encoder
 from cleanformer.models import functional as cleanF  # noqa
-from torch.nn import functional as torchF  # noqa
-from torchmetrics import functional as metricsF  # noqa
 
 
 class Transformer(LightningModule):  # lgtm [py/missing-call-to-init]
@@ -21,9 +22,10 @@ def __init__(
         depth: int,
         dropout: float,
         lr: float,  # noqa
+        **kwargs  # noqa
     ):
         super().__init__()
-        self.save_hyperparameters(ignore="tokenizer")
+        self.save_hyperparameters()
         self.token_embeddings = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
         self.encoder = Encoder(hidden_size, ffn_size, max_length, heads, depth, dropout)  # the encoder stack
         self.decoder = Decoder(hidden_size, ffn_size, max_length, heads, depth, dropout)  # the decoder stack
@@ -135,8 +137,26 @@ def validation_step(
     def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs) -> dict:
         return self.training_step(batch, *args, **kwargs)
 
-    def configure_optimizers(self):
+    # --- for optimisation --- #
+    def configure_optimizers(self) -> dict:
         optimizer = torch.optim.Adam(
-            params=self.parameters(), lr=self.hparams["lr"], betas=(0.9, 0.98), eps=1e-9
+            params=self.parameters(),
+            lr=self.hparams["lr"],
+            betas=self.hparams['betas'],
+            eps=self.hparams['eps'],
+            weight_decay=self.hparams['weight_decay']
         )
-        return optimizer
+        scheduler = ReduceLROnPlateau(
+            optimizer,
+            verbose=True,
+            mode=self.hparams['mode'],
+            patience=self.hparams['patience'],
+            cooldown=self.hparams['cooldown']
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": self.hparams["monitor"]
+            }
+        }
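Note (not part of the commit): unlike most schedulers, ReduceLROnPlateau is not stepped on a fixed cadence; it watches a metric and multiplies the learning rate by a factor (0.1 by default) once the metric has failed to improve for patience epochs, then waits cooldown epochs before counting again. In Lightning, the "monitor" key in the returned dict tells the trainer which logged metric to feed into scheduler.step(...), so that metric has to be logged by the module. A minimal plain-PyTorch sketch of the same mechanism, with made-up loss values:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)  # toy model
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.98), eps=1e-9)
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=10, cooldown=1, verbose=True)

fake_losses = [1.0, 0.9, 0.8] + [0.8] * 15  # the monitored metric plateaus after epoch 2
for epoch, loss in enumerate(fake_losses):
    # ... one epoch of training would happen here ...
    scheduler.step(loss)  # pass the monitored value explicitly
    print(epoch, optimizer.param_groups[0]["lr"])  # lr drops 0.1 -> 0.01 once patience runs out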
34 changes: 21 additions & 13 deletions config.yaml
@@ -1,19 +1,25 @@
 # --- config for training a transformer --- #
 transformer:
-  best: "transformer:overfit"
   hidden_size: 512
-  ffn_size: 512
-  heads: 32
-  depth: 3
-  max_length: 149
-  batch_size: 3
-  lr: 0.0001
-  dropout: 0.0
+  ffn_size: 2048
+  heads: 8
+  depth: 5
+  max_length: 150
+  lr: 0.1
+  dropout: 0.1
+  tokenizer: "tokenizer:v20"
+  # for dataloader
   seed: 410
   shuffle: true
-  tokenizer: "tokenizer:v20"
-  monitor: Validation/Loss_epoch
+  # for ADAM
+  eps: 0.000000001
+  betas: [0.9, 0.98]
+  weight_decay: 0.0005
+  # for ReduceOnPlateau
+  monitor: Train/Loss_epoch
   mode: min
+  patience: 10
+  cooldown: 1
 
 # --- config for building a tokenizer --- #
 tokenizer:
@@ -29,6 +35,8 @@ tokenizer:
   eos: "[EOS]"
   eos_id: 3
 
-# --- config for building a dataset --- #
-kor2eng:
-  # TODO: do this later, when you need preprocessing the dataset (e.g. upsampling)
+# --- config for recommended versions of artifacts --- #
+recommended:
+  transformer: "transformer:overfit"
+  tokenizer: "tokenizer:v20"
+  kor2eng: "kor2eng:v0"
7 changes: 5 additions & 2 deletions main_test.py
@@ -18,8 +18,11 @@
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
-    parser.add_argument("--fast_dev_run", action="store_true", default=False)
+    required = parser.add_argument_group("required arguments")
+    required.add_argument("--batch_size", type=int, required=True)
+    optional = parser.add_argument_group("optional arguments")
+    optional.add_argument("--fast_dev_run", action="store_true", default=False)
+    optional.add_argument("--num_workers", type=int, default=os.cpu_count())
     args = parser.parse_args()
     config = fetch_config()["transformer"]
     config.update(vars(args))
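Note (not part of the commit): add_argument_group only affects how --help is laid out; parsing itself is unchanged, and it is required=True that makes --batch_size mandatory. A self-contained sketch with illustrative values:

import argparse

parser = argparse.ArgumentParser()
required = parser.add_argument_group("required arguments")
required.add_argument("--batch_size", type=int, required=True)
optional = parser.add_argument_group("optional arguments")
optional.add_argument("--fast_dev_run", action="store_true", default=False)
optional.add_argument("--num_workers", type=int, default=4)  # default is illustrative

args = parser.parse_args(["--batch_size", "128"])
print(vars(args))  # {'batch_size': 128, 'fast_dev_run': False, 'num_workers': 4}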
21 changes: 6 additions & 15 deletions main_train.py
@@ -22,7 +22,7 @@ def main():
     parser = argparse.ArgumentParser()
     required = parser.add_argument_group("required arguments")
     required.add_argument("--max_epochs", type=int, required=True)
-    required.add_argument("--save_top_k", type=int, required=True)
+    required.add_argument("--batch_size", type=int, required=True)
     required.add_argument("--save_on_train_epoch_end", type=int, required=True)
     required.add_argument("--every_n_epochs", type=int, required=True)
     required.add_argument("--log_every_n_steps", type=int, required=True)
@@ -63,17 +63,11 @@ def main():
         num_workers=config["num_workers"],
     )
     # --- instantiate the transformer to train --- #
-    transformer = Transformer(
-        config["hidden_size"],
-        config["ffn_size"],
-        tokenizer.get_vocab_size(),  # vocab_size
-        config["max_length"],
-        tokenizer.pad_token_id,  # noqa
-        config["heads"],
-        config["depth"],
-        config["dropout"],
-        config["lr"],
-    )
+    config.update({
+        "vocab_size": tokenizer.get_vocab_size(),
+        "pad_token_id": tokenizer.pad_token_id  # noqa
+    })
+    transformer = Transformer(**config)
     # --- start wandb context --- #
     with wandb.init(project="cleanformer", config=config, tags=[__file__]):
         # --- prepare a logger (wandb) and a trainer to use --- #
@@ -91,9 +85,6 @@ def main():
             callbacks=[
                 ModelCheckpoint(
                     verbose=True,
-                    monitor=config["monitor"],
-                    mode=config["mode"],
-                    save_top_k=config["save_top_k"],
                     every_n_epochs=config["every_n_epochs"],
                     save_on_train_epoch_end=config["save_on_train_epoch_end"],
                 ),
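Note (not part of the commit): the positional call to Transformer(...) is replaced by keyword expansion, which works because __init__ now takes **kwargs and calls save_hyperparameters() without ignore=, so every config key (including the runtime-derived vocab_size and pad_token_id) lands in self.hparams. A minimal sketch of that pattern, with a made-up module and illustrative values, assuming pytorch_lightning is installed:

from pytorch_lightning import LightningModule

class Toy(LightningModule):
    def __init__(self, lr: float, **kwargs):  # extra config keys are absorbed by **kwargs
        super().__init__()
        self.save_hyperparameters()  # records lr and everything passed through kwargs

config = {"lr": 0.1, "patience": 10, "monitor": "Train/Loss_epoch"}
config.update({"vocab_size": 32000, "pad_token_id": 0})  # runtime-derived, as in main_train.py
toy = Toy(**config)
print(toy.hparams["lr"], toy.hparams["patience"], toy.hparams["vocab_size"])  # 0.1 10 32000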
