Skip to content

Commit

Permalink
[#22] typo fixed (LogCallback)
Browse files Browse the repository at this point in the history
  • Loading branch information
eubinecto committed Jun 6, 2022
1 parent 454921f commit 185bd21
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 4 additions & 1 deletion cleanformer/logger.py → cleanformer/logcallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cleanformer.models.transformer import Transformer


class Logger(Callback):
class LogCallback(Callback):
"""
For logging loss, perplexity, accuracy, BLEU along with qualitative results.
"""
Expand Down Expand Up @@ -42,6 +42,7 @@ def on_train_batch_end(
transformer: Transformer,
out: dict,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
*args,
**kwargs
) -> None:
src, tgt_r, tgt_ids = batch
Expand All @@ -59,6 +60,7 @@ def on_validation_batch_end(
transformer: Transformer,
out: dict,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
*args,
**kwargs
) -> None:
# logging validation metrics for each batch is unnecessary
Expand All @@ -77,6 +79,7 @@ def on_test_batch_end(
transformer: Transformer,
out: dict,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
*args,
**kwargs
) -> None:
src, tgt_r, tgt_ids = batch
Expand Down
4 changes: 2 additions & 2 deletions main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset # noqa
from cleanformer.logger import Logger
from cleanformer.logger import LogCallback
from cleanformer.fetchers import fetch_kor2eng, fetch_config, fetch_tokenizer, fetch_transformer
from cleanformer import preprocess as P # noqa

Expand Down Expand Up @@ -48,7 +48,7 @@
fast_dev_run=config["fast_dev_run"],
gpus=torch.cuda.device_count(),
logger=logger,
callbacks=[Logger(tokenizer)],
callbacks=[LogCallback(tokenizer)],
)
# start testing here
trainer.test(model=transformer, dataloaders=test_dataloader)
Expand Down
4 changes: 2 additions & 2 deletions main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data import DataLoader, TensorDataset # noqa
from cleanformer import preprocess as P # noqa
from cleanformer.fetchers import fetch_tokenizer, fetch_config, fetch_kor2eng
from cleanformer.logger import Logger
from cleanformer.logger import LogCallback
from cleanformer.models.transformer import Transformer
from cleanformer.paths import WANDB_DIR

Expand Down Expand Up @@ -94,7 +94,7 @@
save_on_train_epoch_end=config["save_on_train_epoch_end"],
),
LearningRateMonitor(logging_interval="epoch"),
Logger(tokenizer)
LogCallback(tokenizer)
],
)
# --- start training --- #
Expand Down

0 comments on commit 185bd21

Please sign in to comment.