Skip to content

Commit

Permalink
[s2s] Delete useless method, log tokens_per_batch (#6081)
Browse files Browse the repository at this point in the history
  • Loading branch information
sshleifer committed Jul 28, 2020
1 parent dc4755c commit dafa296
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
23 changes: 14 additions & 9 deletions examples/seq2seq/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,16 @@ def _step(self, batch: dict) -> Tuple:
)
return (loss,)

@property
def pad(self) -> int:
return self.tokenizer.pad_token_id

def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)

logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
return {"loss": loss_tensors[0], "log": logs}

def validation_step(self, batch, batch_idx) -> Dict:
Expand All @@ -172,7 +179,7 @@ def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
rouges.update({k: v.item() for k, v in losses.items()})
losses.update(rouges)
Expand All @@ -190,23 +197,21 @@ def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_rouge(preds, target)

def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(
input_ids=source_ids,
attention_mask=source_mask,
batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
gen_time = (time.time() - t0) / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
return base_metrics

def test_step(self, batch, batch_idx):
Expand Down
6 changes: 0 additions & 6 deletions examples/seq2seq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,6 @@ def __getitem__(self, index) -> Dict[str, torch.Tensor]:
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]

@staticmethod
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
return source_ids, source_mask, y

def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
Expand Down

0 comments on commit dafa296

Please sign in to comment.