[#22] logcallback.py simplified; decoding is done only at the end of epoch to reduce computation overhead during training

eubinecto committed Jun 8, 2022
1 parent ff805f0 commit a9ee979
Showing 5 changed files with 46 additions and 105 deletions.
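
In short: the old callback decoded predictions with transformer.infer() on every train/validation/test batch and cached the strings for epoch-end logging; the new one drops the cache and the per-batch BLEU, and instead re-runs inference over the whole dataloader once per epoch. A minimal sketch of the new pattern, mirroring the diff below (EpochEndDecoder is an illustrative name, not from the repo; only the identifiers visible in the diff are real):

import torch
from pytorch_lightning import Callback, Trainer


class EpochEndDecoder(Callback):  # hypothetical, trimmed-down callback
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @torch.no_grad()
    def on_validation_epoch_end(self, trainer: Trainer, transformer) -> None:
        predictions = []
        for src, tgt_r, tgt_ids in trainer.val_dataloaders:
            # autoregressive decoding now runs once per epoch, not per batch
            pred_ids, _ = transformer.infer(src, tgt_r)
            predictions += self.tokenizer.decode_batch(pred_ids.cpu().tolist())
        # log the decoded predictions (e.g. a wandb table + BLEU) here
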
122 changes: 35 additions & 87 deletions cleanformer/logcallback.py
@@ -1,9 +1,11 @@
from typing import Tuple, List
from typing import Tuple
import torch # noqa
import wandb
from pytorch_lightning import Callback, Trainer
from tokenizers import Tokenizer # noqa
from torch.utils.data import DataLoader # noqa
from torchmetrics import functional as metricsF # noqa
from torch.nn import functional as torchF # noqa
from cleanformer.models.transformer import Transformer


@@ -14,39 +16,6 @@ class LogCallback(Callback):

def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.cache = {"train": dict(), "validation": dict(), "test": dict()}

self.cache["train"].clear()

def on_train_epoch_start(self, *args, **kwargs) -> None:
self.cache["train"].clear()

def on_validation_epoch_start(self, *args, **kwargs) -> None:
self.cache["validation"].clear()

def on_test_epoch_start(self, *args, **kwargs) -> None:
self.cache["test"].clear()

def on_any_batch_end(
self,
key: str,
transformer: Transformer,
src: torch.Tensor,
tgt_r: torch.Tensor,
tgt_ids: torch.Tensor,
losses: List[float],
) -> Tuple[List[List[List[str]]], List[List[str]]]:
inputs = self.tokenizer.decode_batch(src[:, 0].cpu().tolist())
answers = self.tokenizer.decode_batch(tgt_ids.cpu().tolist())
predictions = self.tokenizer.decode_batch(transformer.infer(src, tgt_r).cpu().tolist())
self.cache[key]["inputs"] = self.cache[key].get("inputs", list()) + inputs
self.cache[key]["answers"] = self.cache[key].get("answers", list()) + answers
self.cache[key]["predictions"] = self.cache[key].get("predictions", list()) + predictions
self.cache[key]["losses"] = self.cache[key].get("losses", list()) + losses
# to make them compatible with torchmetrics.functional.bleu_score()
answers = [[answer.split()] for answer in answers]
predictions = [prediction.split() for prediction in predictions]
return answers, predictions

@torch.no_grad()
def on_train_batch_end(
@@ -67,20 +36,6 @@ def on_train_batch_end(
on_step=True,
on_epoch=True,
)
answers, predictions = self.on_any_batch_end(
"train", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
)
transformer.log(
"train/bleu",
metricsF.bleu_score(
answers,
predictions,
n_gram=transformer.hparams["n_gram"],
smooth=transformer.hparams["smooth"],
),
on_step=True,
on_epoch=True,
)

@torch.no_grad()
def on_validation_batch_end(
@@ -101,19 +56,6 @@ def on_validation_batch_end(
metricsF.accuracy(out["logits"], tgt_ids, ignore_index=transformer.hparams["pad_token_id"]),
on_epoch=True,
)
answers, predictions = self.on_any_batch_end(
"validation", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
)
transformer.log(
"validation/bleu_epoch",
metricsF.bleu_score(
answers,
predictions,
n_gram=transformer.hparams["n_gram"],
smooth=transformer.hparams["smooth"],
),
on_epoch=True,
)

@torch.no_grad()
def on_test_batch_end(
@@ -133,44 +75,50 @@ def on_test_batch_end(
metricsF.accuracy(out["logits"], tgt_ids, ignore_index=transformer.hparams["pad_token_id"]),
on_epoch=True,
)
answers, predictions = self.on_any_batch_end(
"test", transformer, src, tgt_r, tgt_ids, out["losses"].cpu().tolist()
)
transformer.log(
"test/bleu_epoch",
metricsF.bleu_score(
answers,
predictions,
n_gram=transformer.hparams["n_gram"],
smooth=transformer.hparams["smooth"],
),
on_epoch=True,
)

# --- for logging on epoch end --- #
@torch.no_grad()
def on_any_epoch_end(self, key: str):
def on_any_epoch_end(self, key: str, dataloader: DataLoader, transformer: Transformer):
"""
log BLEU scores, along with qualitative info
"""
inputs = self.cache[key]["inputs"]
predictions = self.cache[key]["predictions"]
answers = self.cache[key]["answers"]
losses = self.cache[key]["losses"]
inputs = list()
answers = list()
predictions = list()
losses = list()
for batch in dataloader:
src, tgt_r, tgt_ids = batch
inputs += self.tokenizer.decode_batch(src[:, 0].cpu().tolist())
answers += self.tokenizer.decode_batch(tgt_ids.cpu().tolist())
tgt_hat_ids, logits = transformer.infer(src, tgt_r)
predictions += self.tokenizer.decode_batch(tgt_hat_ids.cpu().tolist())
losses += (
torchF.cross_entropy(
logits, tgt_ids, ignore_index=transformer.hparams["pad_token_id"], reduction="none"
)
.mean(dim=-1)
.cpu()
.tolist()
) # (N, L) -> (N,) -> list
wandb.log(
{
f"{key}/examples": wandb.Table(
columns=["input", "prediction", "answer", "losses"],
data=list(zip(inputs, predictions, answers, losses)),
)
),
f"{key}/bleu_epoch": metricsF.bleu_score(
answers,
predictions,
n_gram=transformer.hparams["n_gram"],
smooth=transformer.hparams["smooth"],
),
}
)

def on_train_epoch_end(self, *args, **kwargs) -> None:
self.on_any_epoch_end("train") # noqa
def on_train_epoch_end(self, trainer: Trainer, transformer: Transformer):
self.on_any_epoch_end("train", trainer.train_dataloader, transformer)

def on_validation_epoch_end(self, *args, **kwargs):
self.on_any_epoch_end("validation") # noqa
def on_validation_epoch_end(self, trainer: Trainer, transformer: Transformer):
self.on_any_epoch_end("validation", trainer.val_dataloaders, transformer)

def on_test_epoch_end(self, *args, **kwargs):
self.on_any_epoch_end("test") # noqa
def on_test_epoch_end(self, trainer: Trainer, transformer: Transformer):
self.on_any_epoch_end("test", trainer.test_dataloaders, transformer)
24 changes: 8 additions & 16 deletions cleanformer/models/transformer.py
@@ -77,16 +77,13 @@ def step(
) # ... -> (N, L, H)
cls = self.token_embeddings.weight # (|V|, H) - reuse the embeddings as the classifier
logits = torch.einsum("...lh,vh->...vl", hidden, cls) # (N, |V|, L)
losses = torchF.cross_entropy(
logits,
tgt_ids,
ignore_index=self.hparams["pad_token_id"],
reduction="none", # so that we can explore each instance by loss
) # (N, |V|, L), (N, L) -> (N, L)
return losses, logits
loss = torchF.cross_entropy(
logits, tgt_ids, ignore_index=self.hparams["pad_token_id"]
) # (N, |V|, L), (N, L) -> (1,)
return loss, logits

@torch.no_grad()
def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> torch.Tensor:
def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
An implementation of autoregressive inference
"""
@@ -115,7 +112,7 @@ def infer(self, src: torch.Tensor, tgt_r: torch.Tensor) -> torch.Tensor:
tgt_r_key_padding_mask[:, t] = torch.where( # noqa
tgt_r_ids[:, t] == self.hparams["eos_token_id"], 0, 1
)
return tgt_r_ids
return tgt_r_ids, logits # noqa

def on_train_start(self):
"""
@@ -128,13 +125,8 @@ def on_train_start(self):

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs) -> dict:
src, tgt_r, tgt_ids = batch
losses, logits = self.step(src, tgt_r, tgt_ids)
return {
"loss": losses.mean(dim=-1).mean(dim=-1), # (N, L) -> (N,) -> (1,)
# --- for logging purposes --- #
"losses": losses.mean(dim=-1).detach(), # (N, L) -> (N,)
"logits": logits.detach(), # (N, |V|, L)
}
loss, logits = self.step(src, tgt_r, tgt_ids)
return {"loss": loss, "logits": logits.detach()} # (N, L) -> (N,) -> (1,) # (N, |V|, L)

@torch.no_grad()
def validation_step(
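
The step() change only touches the loss reduction; the weight-tied classifier is unchanged: the token-embedding matrix doubles as the output projection via einsum, putting the vocabulary on the channel axis, which is exactly the layout cross_entropy expects. A self-contained sketch with made-up sizes:

import torch

N, L, H, V = 2, 5, 8, 10
hidden = torch.randn(N, L, H)  # decoder output, (N, L, H)
emb = torch.nn.Embedding(V, H)
cls = emb.weight               # (|V|, H), shared with the input embeddings
logits = torch.einsum("...lh,vh->...vl", hidden, cls)  # (N, |V|, L)
assert logits.shape == (N, V, L)
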
3 changes: 2 additions & 1 deletion cleanformer/translator.py
@@ -17,7 +17,8 @@ def __call__(self, sentences: List[str]) -> Tuple[List[str], List[str]]:
x2y = [(sent, "") for sent in sentences]
src = P.to_src(self.tokenizer, self.transformer.hparams["max_length"], x2y)
tgt_r = P.to_tgt_r(self.tokenizer, self.transformer.hparams["max_length"], x2y)
tgt_hat_ids = self.transformer.infer(src, tgt_r).tolist() # (N, L) -> list
tgt_hat_ids, _ = self.transformer.infer(src, tgt_r)
tgt_hat_ids = tgt_hat_ids.tolist() # (N, L) -> list
src_ids = src[:, 0].tolist() # (N, 2, L) -> (N, L) -> list
inputs = self.tokenizer.decode_batch(src_ids, skip_special_tokens=True)
predictions = self.tokenizer.decode_batch(tgt_hat_ids, skip_special_tokens=True)
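
Since infer() now returns a (token_ids, logits) tuple, every call site has to unpack it, which is the whole translator.py change. A minimal self-contained illustration (MockTransformer is a stand-in for this sketch, not the repo's class):

import torch


class MockTransformer:
    def infer(self, src, tgt_r):
        ids = torch.zeros(src.size(0), 4, dtype=torch.long)  # (N, L)
        logits = torch.zeros(src.size(0), 10, 4)             # (N, |V|, L)
        return ids, logits


src = torch.zeros(2, 2, 4, dtype=torch.long)   # (N, 2, L), as in translator.py
tgt_r = torch.zeros(2, 2, 4, dtype=torch.long)
tgt_hat_ids, _ = MockTransformer().infer(src, tgt_r)  # discard logits when only ids matter
print(tgt_hat_ids.tolist())
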
2 changes: 1 addition & 1 deletion main_train.py
@@ -95,7 +95,7 @@
verbose=config["verbose"],
every_n_epochs=config["every_n_epochs"],
save_on_train_epoch_end=config["save_on_train_epoch_end"],
save_last=config['save_last'],
save_last=config["save_last"],
),
LearningRateMonitor(logging_interval="epoch"),
LogCallback(tokenizer),
