Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

fix dialogpt dual usage of END_IDX #3256

Merged
merged 8 commits into from Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions parlai/agents/hugging_face/dialogpt.py
Expand Up @@ -26,6 +26,25 @@ class DialoGPTDecoder(GPT2Decoder):
This decoder is initialized with the pretrained model from Hugging Face.
"""

def __init__(self, opt, dict):
super().__init__(opt, dict)
self.NULL_IDX, self.START_IDX, self.END_IDX = self._get_special_tokens(
opt, dict
)

@staticmethod
def _get_special_tokens(opt, dict):
null_idx = dict.null_idx
if (
opt.get('batchsize', 1) == 1
and not opt['add_special_tokens']
and null_idx == dict.end_idx
):
# get around the dual usage of end_idx that would otherwise mask endtoken during forward pass.
null_idx = -1
warn_once("WARNING: null_idx is set to -1 otherwise null_idx = end_idx")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the warning? IDTS, no?

return null_idx, dict.start_idx, dict.end_idx

def _init_from_pretrained(self, opt):
# load model
model_sz = opt['gpt2_size']
Expand All @@ -38,6 +57,10 @@ class DialoGPTModel(HFGPT2Model):
Hugging Face DialoGPT Model.
"""

def _get_special_tokens(self, opt, dict):
# keep it consistent between DialoGPTModel and DialoGPTDecoder on start_idx, end_idx, null_idx
return DialoGPTDecoder._get_special_tokens(opt, dict)

def _get_decoder(self, opt, dict):
return DialoGPTDecoder(opt, dict)

Expand Down
4 changes: 3 additions & 1 deletion parlai/agents/hugging_face/gpt2.py
Expand Up @@ -85,7 +85,8 @@ def forward(self, input, encoder_state, incr_state=None):
and int(input[0][0]) == self.START_IDX
):
# generating: ignore the start token
model_input = encoder_state
# without deep copy, the padding_idx (-1) in encoder_state can be reset to 0 with clamp_ inplace operation
model_input = encoder_state.clone()
else:
# forced decoding: concatenate the context
# with the labels
Expand All @@ -108,6 +109,7 @@ def forward(self, input, encoder_state, incr_state=None):
model_input = input[:, -1:]
attention_mask = torch.cat([encoder_state, input], dim=-1) != self.NULL_IDX

model_input = model_input.clamp_(min=0)
transformer_outputs = self.transformer(
model_input,
past=incr_state,
Expand Down
105 changes: 103 additions & 2 deletions tests/nightly/gpu/test_dialogpt.py
Expand Up @@ -6,18 +6,119 @@

import unittest
import parlai.utils.testing as testing_utils
from parlai.core.agents import create_agent


@testing_utils.skipUnlessGPU
class TestDialogptModel(unittest.TestCase):
"""
Test of DialoGPT model.

Checks that DialoGPT gets a certain performance on the integration test task.
"""

def _test_batchsize(self, batchsize, add_start_token):
utterances = [
'How is your day so far?',
'I hope you you have a good day.',
"Nice to meet you. My name is John. ",
"I've got a feeling we're not in Kansas anymore.",
]
opt = {
'model': 'hugging_face/dialogpt',
'gpt2_size': 'small',
'text_truncate': 100,
'label_truncate': 20,
'beam_min_length': 1,
'inference': 'beam',
'beam_size': 1,
'add_special_tokens': True,
'batchsize': batchsize,
'add_start_token': add_start_token,
}
dialogpt = create_agent(opt)

results_single = []
agents = [dialogpt.clone() for _ in utterances]
for u, a in zip(utterances, agents):
a.observe({'text': u, 'episode_done': True})
generation = a.act()['text']
results_single.append(generation)

results_batched = []
for idx in range(len(utterances) // batchsize):
agents = [dialogpt.clone() for _ in range(batchsize)]
batch = utterances[idx * batchsize : (idx + 1) * batchsize]
obs = []
for i, a in enumerate(agents):
obs.append(a.observe({'text': batch[i], 'episode_done': True}))
generations = [x['text'] for x in dialogpt.batch_act(obs)]
results_batched += generations

assert results_single == results_batched

def test_batchsize(self):
"""
Ensures dialogpt provides the same generation results regardless of batchsize.
"""
# Test throwing the RuntimeError with add_special_tokens = False and batchsize > 1
with self.assertRaises(RuntimeError):
create_agent(
{
'model': 'hugging_face/dialogpt',
'add_special_tokens': False,
'batchsize': 2,
}
)

for batchsize in [1, 2, 4]:
for add_start_token in [True, False]:
with self.subTest(
f'test_batchsize with bs={batchsize} and add_start_token={add_start_token}'
):
self._test_batchsize(batchsize, add_start_token)

def test_start_token(self):
"""
Test RuntimeError is thrown when add_start_token = True and yet add_special_tokens = False
"""
with self.assertRaises(RuntimeError):
create_agent(
{
'model': 'hugging_face/dialogpt',
'add_special_tokens': False,
'add_start_token': True,
}
)

def test_nospecialtok(self):
"""
Test generation consistency for off-the-shelf dialogpt models.
"""
test_cases = [
("What a nice weather!", "I'm in the UK and it's raining here."),
("Nice to meet you!", "Hello! I'm from the future!"),
]
opt = {
'model': 'hugging_face/dialogpt',
'gpt2_size': 'small',
'text_truncate': 100,
'label_truncate': 20,
'beam_min_length': 1,
'inference': 'beam',
'beam_size': 1,
'add_special_tokens': False,
'batchsize': 1,
}
dialogpt = create_agent(opt)
for text, label in test_cases:
dialogpt.observe({'text': text, 'episode_done': True})
response = dialogpt.act()
assert response['text'] == label

@testing_utils.retry(ntries=3, log_retry=True)
def test_dialogpt(self):
"""
Checks that DialoGPT gets a certain performance on the integration test task.
"""
valid, test = testing_utils.train_model(
dict(
task='integration_tests:overfit',
Expand Down
5 changes: 5 additions & 0 deletions tests/nightly/gpu/test_gpt2.py
Expand Up @@ -12,9 +12,14 @@
import parlai.scripts.build_dict as build_dict
import os
import copy
from parlai.utils.misc import warn_once


class TestGpt2(unittest.TestCase):
warn_once(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a comment, not a warning plz

'WARNING: DID YOU IMPLEMENT A TEST FOR DIALOGPT TOO? YOU KNOW YOU NEED TO TEST IT TOO!'
)

def _test_batchsize(self, batchsize, add_start_token):
utterances = [
'Just keep swimming -',
Expand Down