Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Custom metrics from compute_loss in TGA (2nd try) #4913

Merged
merged 2 commits into from Jan 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions parlai/core/torch_generator_agent.py
Expand Up @@ -717,6 +717,12 @@ def _encoder_input(self, batch):
"""
return self._model_input(batch)

def record_per_token_metrics(self, batch, loss_per_token):
"""
Override this method for custom loss values that require loss_per_token.
"""
pass

def compute_loss(self, batch, return_output=False):
"""
Compute and return the loss for the given batch.
Expand Down Expand Up @@ -752,6 +758,7 @@ def compute_loss(self, batch, return_output=False):
self.record_local_metric(
'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens)
)
self.record_per_token_metrics(batch, loss_per_token)

# actually do backwards loss
loss = loss_per_token.sum(dim=1)
Expand Down