This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

fix dialogpt dual usage of END_IDX #3256

Merged
Merged 8 commits on Dec 16, 2020
Changes from 3 commits
6 changes: 6 additions & 0 deletions parlai/agents/hugging_face/dialogpt.py
@@ -26,6 +26,12 @@ class DialoGPTDecoder(GPT2Decoder):
    This decoder is initialized with the pretrained model from Hugging Face.
    """

    def __init__(self, opt, dict):
        super().__init__(opt, dict)
        if opt.get('batchsize', 1) == 1 and self.END_IDX == self.NULL_IDX:
Contributor:

When is the latter condition not going to be true? If you are inheriting from this model but changing things?

Contributor (author):

When -bs 1 --add_special_tokens True? Basically, I only want to override NULL_IDX if it's the same as END_IDX.

            # Get around the dual usage of END_IDX, which would otherwise mask the end token during the forward pass.
            self.NULL_IDX = -1
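Note: a minimal sketch of the dual-usage problem worked around here, assuming GPT-2's <|endoftext|> id (50256), which serves as both END_IDX and NULL_IDX when special tokens are not added:

    import torch

    END_IDX = NULL_IDX = 50256  # assumed: eos doubles as padding without special tokens
    tokens = torch.tensor([[15496, 995, 50256]])  # "Hello world <|endoftext|>"

    # The attention mask built from NULL_IDX also hides the end token:
    print(tokens != NULL_IDX)  # tensor([[ True,  True, False]])

    # With NULL_IDX overridden to -1, the end token stays visible:
    print(tokens != -1)        # tensor([[True, True, True]])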

    def _init_from_pretrained(self, opt):
        # load model
        model_sz = opt['gpt2_size']
1 change: 0 additions & 1 deletion parlai/agents/hugging_face/dict.py
@@ -20,7 +20,6 @@
)

SPECIAL_TOKENS = {"bos_token": "<bos>", "eos_token": "<eos>", "pad_token": "<pad>"}

NO_OP = "x"


1 change: 1 addition & 0 deletions parlai/agents/hugging_face/gpt2.py
@@ -108,6 +108,7 @@ def forward(self, input, encoder_state, incr_state=None):
            model_input = input[:, -1:]
            attention_mask = torch.cat([encoder_state, input], dim=-1) != self.NULL_IDX

        model_input = model_input.clamp_(min=0)
        transformer_outputs = self.transformer(
            model_input,
            past=incr_state,
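The added clamp presumably guards the embedding lookup: with NULL_IDX now -1, any position holding that sentinel would be an invalid embedding index, and since the attention mask already hides such positions, clamping them to 0 is harmless. A rough sketch with toy sizes, not the real model:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)      # toy vocabulary of 10 ids
    ids = torch.tensor([[3, -1]])  # -1 stands in for a masked-out position

    # emb(ids) would raise an IndexError on the -1 entry;
    # clamping first makes the lookup valid:
    out = emb(ids.clamp(min=0))
    print(out.shape)               # torch.Size([1, 2, 4])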
64 changes: 64 additions & 0 deletions tests/nightly/gpu/test_dialogpt.py
@@ -6,6 +6,12 @@

import unittest
import parlai.utils.testing as testing_utils
from parlai.core.agents import create_agent
import sys
import warnings

if not sys.warnoptions:
    warnings.simplefilter("ignore")
stephenroller marked this conversation as resolved.


@testing_utils.skipUnlessGPU
@@ -16,6 +22,64 @@ class TestDialogptModel(unittest.TestCase):
    Checks that DialoGPT gets a certain performance on the integration test task.
"""

    def _test_batchsize(self, batchsize, add_special_tokens):
        utterances = [
            'How is your day so far?',
            'I hope you you have a good day.',
            "Nice to meet you. My name is John. ",
            "I've got a feeling we're not in Kansas anymore.",
        ]
        opt = {
            'model': 'hugging_face/dialogpt',
            'gpt2_size': 'small',
            'text_truncate': 100,
            'label_truncate': 20,
            'beam_min_length': 1,
            'inference': 'beam',
            'beam_size': 1,
            'add_special_tokens': add_special_tokens,
            'batchsize': batchsize,
            'add_start_token': False,
        }
        dialogpt = create_agent(opt)

        results_single = []
        agents = [dialogpt.clone() for _ in utterances]
        for u, a in zip(utterances, agents):
            a.observe({'text': u, 'episode_done': True})
            generation = a.act()['text']
            results_single.append(generation)

        results_batched = []
        for idx in range(len(utterances) // batchsize):
            agents = [dialogpt.clone() for _ in range(batchsize)]
            batch = utterances[idx * batchsize : (idx + 1) * batchsize]
            obs = []
            for i, a in enumerate(agents):
                obs.append(a.observe({'text': batch[i], 'episode_done': True}))
            generations = [x['text'] for x in dialogpt.batch_act(obs)]
            results_batched += generations

        print(f'results_single = {results_single}')
        print(f'results_batched = {results_batched}')
        assert results_single == results_batched

    def test_batchsize(self):
        """
        Ensures dialogpt provides the same generation results regardless of batchsize.
        """
        for batchsize in [2, 2, 4, 2]:
Contributor:

  • I don't understand why batch size 2 is repeated several times.
  • Since you have 4 utterances, I think it's not a bad idea to also test with a batch size where the last batch ends up smaller than the batch size (for example, 3).

Contributor (author):

Ah, this is me testing generation consistency on randomized initialization. The PR is a work in progress.

            for add_special_tokens in [True]:
                if batchsize > 1 and not add_special_tokens:
                    continue
                with self.subTest(
                    f'test_batchsize with bs={batchsize} and add_special_token={add_special_tokens}'
                ):
                    print(
                        f'_____________test_batchsize with bs={batchsize} and add_special_token={add_special_tokens}'
                    )
                    self._test_batchsize(batchsize, add_special_tokens)

    @testing_utils.retry(ntries=3, log_retry=True)
    def test_dialogpt(self):
        valid, test = testing_utils.train_model(