From a9ee9798e6c34f86627bef79253eedfe680f5aa0 Mon Sep 17 00:00:00 2001
From: Eu-Bin KIM
Date: Wed, 8 Jun 2022 13:57:33 +0100
Subject: [PATCH] [#22] logcallback.py simplified; decoding is done only at
 the end of an epoch to reduce computation overhead during training

---
 cleanformer/logcallback.py                    | 122 +++++-------------
 cleanformer/models/transformer.py             |  24 ++--
 cleanformer/translator.py                     |   3 +-
 ...xplore_torch_cross_entropy_reduce_none.py} |   0
 main_train.py                                 |   2 +-
 5 files changed, 46 insertions(+), 105 deletions(-)
 rename explore/{explore_torch._cross_entropy_reduce_none.py => explore_torch_cross_entropy_reduce_none.py} (100%)

diff --git a/cleanformer/logcallback.py b/cleanformer/logcallback.py
index 36cc896..182a610 100644
--- a/cleanformer/logcallback.py
+++ b/cleanformer/logcallback.py
@@ -1,9 +1,11 @@
-from typing import Tuple, List
+from typing import Tuple
 import torch  # noqa
 import wandb
 from pytorch_lightning import Callback, Trainer
 from tokenizers import Tokenizer  # noqa
+from torch.utils.data import DataLoader  # noqa
 from torchmetrics import functional as metricsF  # noqa
+from torch.nn import functional as torchF  # noqa
 
 from cleanformer.models.transformer import Transformer
 
@@ -14,39 +16,6 @@ class LogCallback(Callback):
 
     def __init__(self, tokenizer: Tokenizer):
         self.tokenizer = tokenizer
-        self.cache = {"train": dict(), "validation": dict(), "test": dict()}
-
-        self.cache["train"].clear()
-
-    def on_train_epoch_start(self, *args, **kwargs) -> None:
-        self.cache["train"].clear()
-
-    def on_validation_epoch_start(self, *args, **kwargs) -> None:
-        self.cache["validation"].clear()
-
-    def on_test_epoch_start(self, *args, **kwargs) -> None:
-        self.cache["test"].clear()
-
-    def on_any_batch_end(
-        self,
-        key: str,
-        transformer: Transformer,
-        src: torch.Tensor,
-        tgt_r: torch.Tensor,
-        tgt_ids: torch.Tensor,
-        losses: List[float],
-    ) -> Tuple[List[List[List[str]]], List[List[str]]]:
-        inputs = self.tokenizer.decode_batch(src[:, 0].cpu().tolist())
-        answers = self.tokenizer.decode_batch(tgt_ids.cpu().tolist())
-        predictions = self.tokenizer.decode_batch(transformer.infer(src, tgt_r).cpu().tolist())
-        self.cache[key]["inputs"] = self.cache[key].get("inputs", list()) + inputs
-        self.cache[key]["answers"] = self.cache[key].get("answers", list()) + answers
-        self.cache[key]["predictions"] = self.cache[key].get("predictions", list()) + predictions
-        self.cache[key]["losses"] = self.cache[key].get("losses", list()) + losses
-        # to make them compatible with torchmetrics.functional.bleu_score()
-        answers = [[answer.split()] for answer in answers]
-        predictions = [prediction.split() for prediction in predictions]
-        return answers, predictions
 
     @torch.no_grad()
     def on_train_batch_end(
@@ -67,20 +36,6 @@ def on_train_batch_end(
             on_step=True,
             on_epoch=True,
         )
-        answers, predictions = self.on_any_batch_end(
-            "train", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
-        )
-        transformer.log(
-            "train/bleu",
-            metricsF.bleu_score(
-                answers,
-                predictions,
-                n_gram=transformer.hparams["n_gram"],
-                smooth=transformer.hparams["smooth"],
-            ),
-            on_step=True,
-            on_epoch=True,
-        )
 
     @torch.no_grad()
     def on_validation_batch_end(
@@ -101,19 +56,6 @@ def on_validation_batch_end(
             metricsF.accuracy(out["logits"], tgt_ids, ignore_index=transformer.hparams["pad_token_id"]),
             on_epoch=True,
         )
-        answers, predictions = self.on_any_batch_end(
-            "validation", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
-        )
-        transformer.log(
-            "validation/bleu_epoch",
-            metricsF.bleu_score(
-                answers,
-                predictions,
-                n_gram=transformer.hparams["n_gram"],
-                smooth=transformer.hparams["smooth"],
-            ),
-            on_epoch=True,
-        )
 
     @torch.no_grad()
     def on_test_batch_end(
@@ -133,44 +75,50 @@ def on_test_batch_end(
             metricsF.accuracy(out["logits"], tgt_ids, ignore_index=transformer.hparams["pad_token_id"]),
             on_epoch=True,
         )
-        answers, predictions = self.on_any_batch_end(
-            "test", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
-        )
-        transformer.log(
-            "test/bleu_epoch",
-            metricsF.bleu_score(
-                answers,
-                predictions,
-                n_gram=transformer.hparams["n_gram"],
-                smooth=transformer.hparams["smooth"],
-            ),
-            on_epoch=True,
-        )
 
     # --- for logging on epoch end --- #
-    @torch.no_grad()
-    def on_any_epoch_end(self, key: str):
+    def on_any_epoch_end(self, key: str, dataloader: DataLoader, transformer: Transformer):
         """
        log BLEU scores, along with qualitative infos
         """
-        inputs = self.cache[key]["inputs"]
-        predictions = self.cache[key]["predictions"]
-        answers = self.cache[key]["answers"]
-        losses = self.cache[key]["losses"]
+        inputs = list()
+        answers = list()
+        predictions = list()
+        losses = list()
+        for batch in dataloader:
+            src, tgt_r, tgt_ids = batch
+            inputs += self.tokenizer.decode_batch(src[:, 0].cpu().tolist())
+            answers += self.tokenizer.decode_batch(tgt_ids.cpu().tolist())
+            tgt_hat_ids, logits = transformer.infer(src, tgt_r)
+            predictions += self.tokenizer.decode_batch(tgt_hat_ids.cpu().tolist())
+            losses += (
+                torchF.cross_entropy(
+                    logits, tgt_ids, ignore_index=transformer.hparams["pad_token_id"], reduction="none"
+                )
+                .mean(dim=-1)
+                .cpu()
+                .tolist()
+            )  # (N, L) -> (N,) -> list
         wandb.log(
             {
                 f"{key}/examples": wandb.Table(
                     columns=["input", "prediction", "answer", "losses"],
                     data=list(zip(inputs, predictions, answers, losses)),
-                )
+                ),
+                f"{key}/bleu_epoch": metricsF.bleu_score(
+                    answers,
+                    predictions,
+                    n_gram=transformer.hparams["n_gram"],
+                    smooth=transformer.hparams["smooth"],
+                ),
             }
         )
 
-    def on_train_epoch_end(self, *args, **kwargs) -> None:
-        self.on_any_epoch_end("train")  # noqa
+    def on_train_epoch_end(self, trainer: Trainer, transformer: Transformer):
+        self.on_any_epoch_end("train", trainer.train_dataloader, transformer)
 
-    def on_validation_epoch_end(self, *args, **kwargs):
-        self.on_any_epoch_end("validation")  # noqa
+    def on_validation_epoch_end(self, trainer: Trainer, transformer: Transformer):
+        self.on_any_epoch_end("validation", trainer.val_dataloaders, transformer)
 
-    def on_test_epoch_end(self, *args, **kwargs):
-        self.on_any_epoch_end("test")  # noqa
+    def on_test_epoch_end(self, trainer: Trainer, transformer: Transformer):
+        self.on_any_epoch_end("test", trainer.test_dataloaders, transformer)
diff --git a/cleanformer/models/transformer.py b/cleanformer/models/transformer.py
index a120684..0b7d093 100755
--- a/cleanformer/models/transformer.py
+++ b/cleanformer/models/transformer.py
@@ -77,16 +77,13 @@ def step(
         )  # ... -> (N, L, H)
         cls = self.token_embeddings.weight  # (|V|, H) - reuse the embeddings as the classifier
         logits = torch.einsum("...lh,vh->...vl", hidden, cls)  # (N, |V|, L)
-        losses = torchF.cross_entropy(
-            logits,
-            tgt_ids,
-            ignore_index=self.hparams["pad_token_id"],
-            reduction="none",  # so that we can explore each instance by loss
-        )  # (N, |V|, L), (N, L) -> (N, L)
-        return losses, logits
+        loss = torchF.cross_entropy(
+            logits, tgt_ids, ignore_index=self.hparams["pad_token_id"]
+        )  # (N, |V|, L), (N, L) -> (1,)
+        return loss, logits
 
     @torch.no_grad()
-    def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> torch.Tensor:
+    def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         An implementation of autoregressive inference
         """
@@ -115,7 +112,7 @@ def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> torch.Tensor:
             tgt_r_key_padding_mask[:, t] = torch.where(  # noqa
                 tgt_r_ids[:, t] == self.hparams["eos_token_id"], 0, 1
             )
-        return tgt_r_ids
+        return tgt_r_ids, logits  # noqa
 
     def on_train_start(self):
         """
@@ -128,13 +125,8 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs) -> dict:
         src, tgt_r, tgt_ids = batch
-        losses, logits = self.step(src, tgt_r, tgt_ids)
-        return {
-            "loss": losses.mean(dim=-1).mean(dim=-1),  # (N, L) -> (N,) -> (1,)
-            # --- for logging purposes --- #
-            "losses": losses.mean(dim=-1).detach(),  # (N, L) -> (N,)
-            "logits": logits.detach(),  # (N, |V|, L)
-        }
+        loss, logits = self.step(src, tgt_r, tgt_ids)
+        return {"loss": loss, "logits": logits.detach()}  # (1,), (N, |V|, L)
 
     @torch.no_grad()
     def validation_step(
diff --git a/cleanformer/translator.py b/cleanformer/translator.py
index f3ad0ba..a60873d 100644
--- a/cleanformer/translator.py
+++ b/cleanformer/translator.py
@@ -17,7 +17,8 @@ def __call__(self, sentences: List[str]) -> Tuple[List[str], List[str]]:
         x2y = [(sent, "") for sent in sentences]
         src = P.to_src(self.tokenizer, self.transformer.hparams["max_length"], x2y)
         tgt_r = P.to_tgt_r(self.tokenizer, self.transformer.hparams["max_length"], x2y)
-        tgt_hat_ids = self.transformer.infer(src, tgt_r).tolist()  # (N, L) -> list
+        tgt_hat_ids, _ = self.transformer.infer(src, tgt_r)
+        tgt_hat_ids = tgt_hat_ids.tolist()  # (N, L) -> list
         src_ids = src[:, 0].tolist()  # (N, 2, L) -> (N, L) -> list
         inputs = self.tokenizer.decode_batch(src_ids, skip_special_tokens=True)
         predictions = self.tokenizer.decode_batch(tgt_hat_ids, skip_special_tokens=True)
diff --git a/explore/explore_torch._cross_entropy_reduce_none.py b/explore/explore_torch_cross_entropy_reduce_none.py
similarity index 100%
rename from explore/explore_torch._cross_entropy_reduce_none.py
rename to explore/explore_torch_cross_entropy_reduce_none.py
diff --git a/main_train.py b/main_train.py
index 642fc2b..ae420b2 100755
--- a/main_train.py
+++ b/main_train.py
@@ -95,7 +95,7 @@
             verbose=config["verbose"],
             every_n_epochs=config["every_n_epochs"],
             save_on_train_epoch_end=config["save_on_train_epoch_end"],
-            save_last=config['save_last'],
+            save_last=config["save_last"],
         ),
         LearningRateMonitor(logging_interval="epoch"),
         LogCallback(tokenizer),
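
Reviewer note (not part of the patch): below is a minimal, self-contained
sketch of the pattern this commit adopts, namely logging cheap scalar metrics
on every batch while running the expensive decoding pass only once per epoch.
All names in the sketch (ToyModel, greedy_decode, run_epoch) are hypothetical
stand-ins for illustration, not cleanformer's API.

from typing import List, Tuple

import torch
from torch import nn


class ToyModel(nn.Module):
    """A stand-in for the Transformer: (N, L) token ids -> (N, |V|, L) logits."""

    def __init__(self, vocab_size: int = 8, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(src)).transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)


def greedy_decode(logits: torch.Tensor) -> List[List[int]]:
    """(N, |V|, L) -> argmax token ids per example; stands in for the expensive decode."""
    return logits.argmax(dim=1).tolist()  # (N, L) -> list


def run_epoch(model: ToyModel, batches: List[Tuple[torch.Tensor, torch.Tensor]]) -> None:
    # Per batch: log only the scalar loss, which the training step computes anyway.
    for step, (src, tgt_ids) in enumerate(batches):
        loss = nn.functional.cross_entropy(model(src), tgt_ids)
        print(f"step {step}: loss={loss.item():.4f}")
    # Per epoch: one extra pass decodes everything for qualitative logging,
    # instead of decoding inside every on_*_batch_end hook.
    with torch.no_grad():
        for src, _ in batches:
            for pred in greedy_decode(model(src)):
                print("decoded ids:", pred)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToyModel()
    batches = [(torch.randint(0, 8, (2, 5)), torch.randint(0, 8, (2, 5))) for _ in range(3)]
    run_epoch(model, batches)

Relative to the pre-patch design, this trades the per-batch decode and the
per-batch cache bookkeeping for a single extra pass over the dataloader at
epoch end.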