[#22] logging qualitative data for training set as well
eubinecto committed Jun 4, 2022
1 parent 90ce15b commit e1f95cc
Showing 4 changed files with 70 additions and 45 deletions.
17 changes: 9 additions & 8 deletions Makefile
@@ -14,12 +14,7 @@ train:
--log_every_n_steps=2 \
--check_val_every_n_epoch=1

eval:
python3 main_eval.py \
--max_epochs=1

# pseudo-tests
test_train:
train_check:
python3 main_train.py \
--fast_dev_run \
--max_epochs=60 \
@@ -29,5 +24,11 @@ test_train:
--log_every_n_steps=2 \
--check_val_every_n_epoch=1

test_eval:
python3 main_eval.py --fast_dev_run

test:
python3 main_test.py \
--max_epochs=1


test_check:
python3 main_test.py --fast_dev_run
88 changes: 59 additions & 29 deletions cleanformer/callbacks.py
@@ -38,6 +38,7 @@ def on_validation_batch_end(
unused: int = 0,
) -> None:
_, _, tgt = batch
# logging validation metrics for each batch is unnecessary
pl_module.log("Validation/Loss_epoch", out["loss"], on_epoch=True)
pl_module.log("Validation/Perplexity_epoch", torch.exp(out["loss"]), on_epoch=True)
pl_module.log("Validation/Accuracy_epoch", F.accuracy(out["logits"], tgt), on_epoch=True)
@@ -54,12 +55,43 @@ def __init__(self, logger: WandbLogger, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.cache = dict()

def on_validation_epoch_start(self, trainer: Trainer, pl_module: Transformer) -> None:
self.cache.clear()
def on_train_epoch_start(self, *args, **kwargs):
self.cache.pop("Train", None)
self.cache["Train"] = dict()

def on_test_epoch_start(self, trainer: Trainer, pl_module: Transformer) -> None:
self.cache.clear()
def on_validation_epoch_start(self, *args, **kwargs):
self.cache.pop("Validation", None)
self.cache["Validation"] = dict()

def on_test_epoch_start(self, *args, **kwargs):
self.cache.pop("Test", None)
self.cache["Test"] = dict()

def on_any_batch_end(self, key: str, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
losses: torch.Tensor,
transformer: Transformer):
"""
cache any data needed for logging
"""
src, tgt_r, tgt_ids = batch
tgt_hat_ids = transformer.infer(src, tgt_r)
self.cache[key]["src_ids"] = self.cache[key].get("src_ids", list()) + src[:, 0].cpu().tolist()
self.cache[key]["tgt_ids"] = self.cache[key].get("tgt_ids", list()) + tgt_ids.cpu().tolist()
self.cache[key]["tgt_hat_ids"] = self.cache[key].get("tgt_hat_ids", list()) + tgt_hat_ids.cpu().tolist()
self.cache[key]["losses"] = self.cache[key].get("losses", list()) + losses.cpu().tolist()

@torch.no_grad()
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: Transformer,
out: dict,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
batch_idx: int,
unused: int = 0,
) -> None:
self.on_any_batch_end("Train", batch, out['losses'].sum(dim=1), pl_module)

@torch.no_grad()
def on_validation_batch_end(
self,
@@ -70,13 +102,7 @@ def on_validation_batch_end(
batch_idx: int,
unused: int = 0,
) -> None:
src, tgt_r, tgt = batch # noqa
tgt_hat = pl_module.infer(src, tgt_r)
self.cache["tgt"] = self.cache.get("tgt", list()) + tgt.cpu().tolist()
self.cache["tgt_hat"] = self.cache.get("tgt_hat", list()) + tgt_hat.cpu().tolist()
self.cache["losses"] = (
self.cache.get("losses", list()) + out["losses"].sum(dim=1).cpu().tolist()
) # (N, L) -> (N,)
self.on_any_batch_end("Validation", batch, out['losses'].sum(dim=1), pl_module)

@torch.no_grad()
def on_test_batch_end(
@@ -88,24 +114,28 @@ def on_test_batch_end(
batch_idx: int,
unused: int = 0,
) -> None:
src, tgt_r, tgt = batch # noqa
self.cache["tgt"] = self.cache.get("tgt", list()) + tgt.cpu().tolist()
self.cache["tgt_hat"] = self.cache.get("tgt_hat", list()) + out["tgt_hat"].cpu().tolist()

def on_validation_epoch_end(self, trainer: Trainer, pl_module: Transformer) -> None:
predictions = self.tokenizer.decode_batch(self.cache["tgt_hat"]) # (N, L) -> (N,) = list
answers = self.tokenizer.decode_batch(self.cache["tgt"]) # (N, L) -> (N,) = list
self.on_any_batch_end("Test", batch, out['losses'].sum(dim=1), pl_module)

def on_any_epoch_end(self, key: str):
"""
log BLEU scores, along with qualitative samples (inputs, predictions, answers, per-sequence losses)
"""
inputs = self.tokenizer.decode_batch(self.cache[key]['src_ids'])
predictions = self.tokenizer.decode_batch(self.cache[key]['tgt_hat_ids'])
answers = self.tokenizer.decode_batch(self.cache[key]['tgt_ids'])
losses = self.cache[key]['losses']
self.logger.log_text(
f"Validation/samples",
columns=["prediction", "answer", "losses"], # noqa
data=list(zip(predictions, answers, self.cache["losses"])),
f"{key}/Samples",
columns=["input", "prediction", "answer", "losses"],
data=list(zip(inputs, predictions, answers, losses)),
)
self.logger.log_metrics({"Validation/BLEU": float(F.bleu_score(answers, predictions))})
self.logger.log_metrics({"Train/BLEU": float(F.bleu_score(answers, predictions))})

def on_test_epoch_end(self, trainer: Trainer, pl_module: Transformer) -> None:
predictions = self.tokenizer.decode_batch(self.cache["tgt_hat"]) # (N, L) -> (N,) = list
answers = self.tokenizer.decode_batch(self.cache["tgt"]) # (N, L) -> (N,) = list
self.logger.log_text(
f"Test/samples", columns=["prediction", "answer"], data=list(zip(predictions, answers)) # noqa
)
self.logger.log_metrics({"Test/BLEU": float(F.bleu_score(answers, predictions))})
def on_train_epoch_end(self, *args, **kwargs):
self.on_any_epoch_end("Train")

def on_validation_epoch_end(self, *args, **kwargs):
self.on_any_epoch_end("Validation")

def on_test_epoch_end(self, *args, **kwargs):
self.on_any_epoch_end("Test")
10 changes: 2 additions & 8 deletions cleanformer/models/transformer.py
@@ -120,8 +120,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
src, tgt_r, tgt = batch
losses, logits = self.step(src, tgt_r, tgt)
return {
"loss": losses.sum(), # (N, L) -> (1,)
# for logging purposes
"loss": losses.sum(), # (N, L) -> (1)
"losses": losses.detach(), # (N, L)
"logits": logits.detach(), # (N, L)
}
@@ -134,12 +133,7 @@ def validation_step(

@torch.no_grad()
def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs) -> dict:
src, tgt_r, _ = batch
tgt_hat = self.infer(src, tgt_r) # ... -> (N, L)
return {
# for logging purposes
"tgt_hat": tgt_hat
}
return self.training_step(batch, *args, **kwargs)

def configure_optimizers(self):
optimizer = torch.optim.Adam(
File renamed without changes.
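
With test_step now returning the same dictionary as training_step, all three splits hand the callback "loss", "losses", and "logits", so a single caching path serves train, validation, and test. Below is a small, self-contained sketch of that per-split caching pattern with dummy tensors; the shapes and names are illustrative assumptions, not repo code.

# standalone illustration of the per-split cache; shapes and names are assumptions
import torch

cache = {"Train": dict(), "Validation": dict(), "Test": dict()}

def accumulate(key: str, ids: torch.Tensor, losses: torch.Tensor) -> None:
    # append this batch's rows to the lists kept for the given split
    cache[key]["ids"] = cache[key].get("ids", list()) + ids.cpu().tolist()
    cache[key]["losses"] = cache[key].get("losses", list()) + losses.cpu().tolist()

for _ in range(2):  # two fake batches, each of shape (N=2, L=3)
    batch_ids = torch.randint(0, 100, (2, 3))
    batch_losses = torch.rand(2, 3).sum(dim=1)  # (N, L) -> (N,), as in the callback
    accumulate("Train", batch_ids, batch_losses)

assert len(cache["Train"]["ids"]) == 4     # 2 batches x 2 sequences
assert len(cache["Train"]["losses"]) == 4  # one summed loss per sequence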
