This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

fix dialogpt dual usage of END_IDX #3256

Merged (8 commits) on Dec 16, 2020
22 changes: 22 additions & 0 deletions parlai/agents/hugging_face/dialogpt.py
@@ -26,6 +26,24 @@ class DialoGPTDecoder(GPT2Decoder):
    This decoder is initialized with the pretrained model from Hugging Face.
    """

    def __init__(self, opt, dict):
        super().__init__(opt, dict)
        self.NULL_IDX, self.START_IDX, self.END_IDX = self._get_special_tokens(
            opt, dict
        )

    @staticmethod
    def _get_special_tokens(opt, dict):
        null_idx = dict.null_idx
        if (
            opt.get('batchsize', 1) == 1
            and not opt['add_special_tokens']
            and null_idx == dict.end_idx
        ):
            # work around the dual usage of end_idx, which would otherwise mask the end token during the forward pass
            null_idx = -1
        return null_idx, dict.start_idx, dict.end_idx

    def _init_from_pretrained(self, opt):
        # load model
        model_sz = opt['gpt2_size']
@@ -38,6 +56,10 @@ class DialoGPTModel(HFGPT2Model):
    Hugging Face DialoGPT Model.
    """

    def _get_special_tokens(self, opt, dict):
        # keep start_idx, end_idx, and null_idx consistent between DialoGPTModel and DialoGPTDecoder
        return DialoGPTDecoder._get_special_tokens(opt, dict)

    def _get_decoder(self, opt, dict):
        return DialoGPTDecoder(opt, dict)
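For context, a minimal sketch (not ParlAI code) of the collision this hunk works around: the off-the-shelf DialoGPT vocabulary reuses <|endoftext|> as both the padding token and the end token, so a mask built with != null_idx would also hide real end tokens, while the -1 sentinel cannot collide with any vocabulary id. The token ids below are illustrative GPT-2 values.

import torch

end_idx = 50256                        # <|endoftext|>, doubling as the pad id
seq = torch.tensor([[15496, 50256]])   # "Hello" followed by the end token

mask_shared = seq != end_idx           # tensor([[True, False]]) -- end token wrongly masked
null_idx = -1                          # sentinel outside the vocabulary
mask_sentinel = seq != null_idx        # tensor([[True, True]]) -- end token kept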
4 changes: 3 additions & 1 deletion parlai/agents/hugging_face/gpt2.py
@@ -85,7 +85,8 @@ def forward(self, input, encoder_state, incr_state=None):
                and int(input[0][0]) == self.START_IDX
            ):
                # generating: ignore the start token
                model_input = encoder_state
                # without a copy, the in-place clamp_ below would reset the -1 padding entries in encoder_state to 0
                model_input = encoder_state.clone()
            else:
                # forced decoding: concatenate the context
                # with the labels
@@ -108,6 +109,7 @@
            model_input = input[:, -1:]
            attention_mask = torch.cat([encoder_state, input], dim=-1) != self.NULL_IDX

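        # -1 is only a masking sentinel, not a valid embedding index; map it to 0 before the lookup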
        model_input = model_input.clamp_(min=0)
        transformer_outputs = self.transformer(
            model_input,
            past=incr_state,
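A minimal sketch, outside ParlAI, of why the clone matters here: without it, model_input and encoder_state alias the same storage, so the in-place clamp_ would silently rewrite the -1 padding entries that the NULL_IDX mask still depends on.

import torch

encoder_state = torch.tensor([[15496, -1, -1]])  # context padded with the -1 sentinel

model_input = encoder_state.clone()              # detach the storage before mutating
model_input.clamp_(min=0)                        # now safe for the embedding lookup

attention_mask = encoder_state != -1             # tensor([[True, False, False]]) -- padding still visible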
105 changes: 103 additions & 2 deletions tests/nightly/gpu/test_dialogpt.py
@@ -6,18 +6,119 @@

import unittest
import parlai.utils.testing as testing_utils
from parlai.core.agents import create_agent


@testing_utils.skipUnlessGPU
class TestDialogptModel(unittest.TestCase):
"""
Test of DialoGPT model.

Checks that DialoGPT gets a certain performance on the integration test task.
"""

    def _test_batchsize(self, batchsize, add_start_token):
        utterances = [
            'How is your day so far?',
            'I hope you you have a good day.',
            "Nice to meet you. My name is John. ",
            "I've got a feeling we're not in Kansas anymore.",
        ]
        opt = {
            'model': 'hugging_face/dialogpt',
            'gpt2_size': 'small',
            'text_truncate': 100,
            'label_truncate': 20,
            'beam_min_length': 1,
            'inference': 'beam',
            'beam_size': 1,
            'add_special_tokens': True,
            'batchsize': batchsize,
            'add_start_token': add_start_token,
        }
        dialogpt = create_agent(opt)

        results_single = []
        agents = [dialogpt.clone() for _ in utterances]
        for u, a in zip(utterances, agents):
            a.observe({'text': u, 'episode_done': True})
            generation = a.act()['text']
            results_single.append(generation)

        results_batched = []
        for idx in range(len(utterances) // batchsize):
            agents = [dialogpt.clone() for _ in range(batchsize)]
            batch = utterances[idx * batchsize : (idx + 1) * batchsize]
            obs = []
            for i, a in enumerate(agents):
                obs.append(a.observe({'text': batch[i], 'episode_done': True}))
            generations = [x['text'] for x in dialogpt.batch_act(obs)]
            results_batched += generations

        assert results_single == results_batched

    def test_batchsize(self):
        """
        Ensures dialogpt provides the same generation results regardless of batchsize.
        """
        # add_special_tokens=False with batchsize > 1 should raise a RuntimeError
        with self.assertRaises(RuntimeError):
            create_agent(
                {
                    'model': 'hugging_face/dialogpt',
                    'add_special_tokens': False,
                    'batchsize': 2,
                }
            )

        for batchsize in [1, 2, 4]:
            for add_start_token in [True, False]:
                with self.subTest(
                    f'test_batchsize with bs={batchsize} and add_start_token={add_start_token}'
                ):
                    self._test_batchsize(batchsize, add_start_token)

    def test_start_token(self):
        """
        Test that a RuntimeError is raised when add_start_token is True but add_special_tokens is False.
        """
        with self.assertRaises(RuntimeError):
            create_agent(
                {
                    'model': 'hugging_face/dialogpt',
                    'add_special_tokens': False,
                    'add_start_token': True,
                }
            )

    def test_nospecialtok(self):
        """
        Test generation consistency for off-the-shelf dialogpt models.
        """
        test_cases = [
            ("What a nice weather!", "I'm in the UK and it's raining here."),
            ("Nice to meet you!", "Hello! I'm from the future!"),
        ]
        opt = {
            'model': 'hugging_face/dialogpt',
            'gpt2_size': 'small',
            'text_truncate': 100,
            'label_truncate': 20,
            'beam_min_length': 1,
            'inference': 'beam',
            'beam_size': 1,
            'add_special_tokens': False,
            'batchsize': 1,
        }
        dialogpt = create_agent(opt)
        for text, label in test_cases:
            dialogpt.observe({'text': text, 'episode_done': True})
            response = dialogpt.act()
            assert response['text'] == label

    @testing_utils.retry(ntries=3, log_retry=True)
    def test_dialogpt(self):
        """
        Checks that DialoGPT gets a certain performance on the integration test task.
        """
        valid, test = testing_utils.train_model(
            dict(
                task='integration_tests:overfit',
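As a usage note, a sketch of how the new consistency checks could be run directly with unittest, assuming a GPU machine with ParlAI installed (the skipUnlessGPU decorator skips them otherwise); the dotted test name is inferred from the file path above.

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    'tests.nightly.gpu.test_dialogpt.TestDialogptModel.test_batchsize'
)
unittest.TextTestRunner(verbosity=2).run(suite)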
2 changes: 2 additions & 0 deletions tests/nightly/gpu/test_gpt2.py
@@ -36,6 +36,8 @@ def test_custom_special_tokens(self):


class TestGpt2(unittest.TestCase):
    # If your changes affect DialoGPT too, did you implement a matching test for it?

    def _test_batchsize(self, batchsize, add_start_token):
        utterances = [
            'Just keep swimming -',