This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Decoder-Only Transformer #4329

Merged
merged 29 commits on May 4, 2022
Commits (29)
ee1dd65
quick and dirty decoder-only implementation
spencerp Jan 19, 2022
5388bfa
Merge branch 'main' into decoder-only
spencerp Jan 19, 2022
8ac8529
fix decoder_only incremental decoding
spencerp Jan 25, 2022
814b99c
remove unused code, add some comments, propagate func signature change
spencerp Jan 27, 2022
279bb51
consolidate code in decoder.py
spencerp Jan 27, 2022
f36c4be
unify encoder_state
spencerp Jan 27, 2022
5116656
export PassThroughEncoder
spencerp Jan 27, 2022
834bd2a
add missing build_ functions
spencerp Feb 1, 2022
0cfc9c5
defaults in TransformerDecoderLayer __init__
spencerp Feb 1, 2022
11cfc7e
Merge branch 'main' into decoder-only
spencerp Feb 8, 2022
01a46eb
comments, consolidating more logic, simplified forward_layers args
spencerp Feb 8, 2022
7be772e
resize token embeddings and unit test
spencerp Feb 8, 2022
0adf6d8
attempt to suppress some unused import warnings
spencerp Feb 8, 2022
911a513
Merge branch 'main' into decoder-only
spencerp Mar 1, 2022
fbdccd4
Merge branch 'main' into decoder-only
spencerp Mar 3, 2022
b4c2a62
padded_tensor fp16 friendly
spencerp Mar 3, 2022
9251c67
autoformat
spencerp Mar 3, 2022
58e2289
decoder_only -> decoder
spencerp Mar 3, 2022
661f1d3
more documentation
spencerp Mar 3, 2022
e29219e
update name in test
spencerp Mar 3, 2022
1afc309
Merge branch 'main' into decoder-only
spencerp Mar 5, 2022
652123e
add missing dict args
spencerp Mar 5, 2022
c19cb44
more argument massaging
spencerp Mar 5, 2022
d4f5660
Merge branch 'main' into decoder-only
klshuster Mar 25, 2022
cadd468
Merge branch 'main' into decoder-only
spencerp Apr 19, 2022
7685ff2
Merge branch 'main' into decoder-only
spencerp May 3, 2022
17b179e
update TestBartDistillation::test_narrow_distillation_losses numbers
spencerp May 3, 2022
996b9bb
update TestTransformerDistillation::test_narrow_distillation_losses n…
spencerp May 3, 2022
d399a62
fix _pad_tensor in seeker
spencerp May 3, 2022
2 changes: 1 addition & 1 deletion parlai/agents/hugging_face/gpt2.py
@@ -318,7 +318,7 @@ def build_model(self, states=None):
def _encoder_input(self, batch):
return (batch.text_vec,)

def _pad_tensor(self, items):
def _pad_tensor(self, items, is_label=False):
"""
Override to always set fp16friendly to False and left_pad to True.
"""
9 changes: 7 additions & 2 deletions parlai/agents/rag/modules.py
@@ -89,7 +89,10 @@ def __init__(self, opt, dictionary, retriever_shared=None):
padding_idx=self.pad_idx,
)
self.seq2seq_decoder = self.build_decoder(
opt, embedding=self.embeddings, padding_idx=self.pad_idx
opt,
embedding=self.embeddings,
dictionary=dictionary,
padding_idx=self.pad_idx,
)

@classmethod
@@ -121,7 +124,9 @@ def build_decoder(
**kwargs,
):
if decoder_class is None:
return RagDecoder(opt=opt, embedding=embedding, n_positions=n_positions)
return RagDecoder(
opt=opt, embedding=embedding, n_positions=n_positions, **kwargs
)
else:
return decoder_class(opt, *args, **kwargs)

95 changes: 95 additions & 0 deletions parlai/agents/transformer/decoder.py
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from parlai.agents.transformer.transformer import add_common_cmdline_args
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.utils.logging import logging
from parlai.utils.misc import recursive_getattr
from parlai.utils.torch import padded_tensor

from .modules import (
PassThroughEncoder,
TransformerDecoderOnly,
TransformerGeneratorModel,
)


class DecoderAgent(TorchGeneratorAgent):
"""
DecoderOnlyAgent.

Implementation of TorchGeneratorAgent, where the model is a Decoder-Only
Transformer.
"""

@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
"""
Add command-line arguments specifically for this agent.
"""
agent = parser.add_argument_group('Decoder-Only Transformer Arguments')
add_common_cmdline_args(agent)
cls.dictionary_class().add_cmdline_args(parser, partial_opt=partial_opt)

super().add_cmdline_args(parser, partial_opt=partial_opt)
return agent

def build_model(self, states=None):
"""
Override of ``TorchAgent.build_model``.
"""
assert (
self.opt['n_encoder_layers'] == -1
), "Decoder-only model cannot have encoder layers."
wrapped_class = TransformerGeneratorModel.with_components(
encoder=PassThroughEncoder, decoder=TransformerDecoderOnly
)
return wrapped_class(self.opt, self.dict)

def _pad_tensor(self, items, is_label=False):
"""
Override of ``TorchAgent._pad_tensor``.

Pads context tensor on the left and label tensor on the right, such that when
they are concatenated the example meets in the middle to form a continuous
sequence.
"""
return padded_tensor(
items,
pad_idx=self.NULL_IDX,
left_padded=(not is_label),
fp16friendly=self.fp16,
)

def _resize_token_embeddings(self, state_dict, msg=None):
"""
Resize the token embeddings when adding extra special tokens.
"""
# map extra special tokens carefully
new_size = self.model.embeddings.weight.size()[0]
orig_size = state_dict['embeddings.weight'].size()[0]
logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
if new_size <= orig_size:
# new size should be greater than original size,
# as we are adding special tokens
raise RuntimeError(msg)

for emb_weights in ['embeddings.weight', 'decoder.embeddings.weight']:
# get new_embs
old_embs = state_dict[emb_weights]
new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
# copy over old weights
new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
# reset in state dict
state_dict[emb_weights] = new_embs

return state_dict
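To make the padding convention in `_pad_tensor` concrete, here is a minimal, self-contained sketch (not part of the PR; the pad id, token ids, and the `pad_batch` helper are invented for illustration). Left-padding the context batch and right-padding the label batch pushes all pad tokens to the outer edges, so concatenating the two along the time dimension leaves each example's real tokens as one contiguous run:

```python
# Illustrative sketch only: why contexts are left-padded and labels right-padded
# for a decoder-only model. The PAD id and token ids here are arbitrary.
import torch

PAD = 0

def pad_batch(seqs, left_padded):
    """Pad a list of 1-D tensors to a common length, on the left or the right."""
    max_len = max(len(s) for s in seqs)
    out = torch.full((len(seqs), max_len), PAD, dtype=torch.long)
    for i, s in enumerate(seqs):
        if left_padded:
            out[i, max_len - len(s):] = s  # pad tokens go on the left
        else:
            out[i, :len(s)] = s            # pad tokens go on the right
    return out

contexts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
labels = [torch.tensor([10, 11]), torch.tensor([12, 13, 14])]

ctx = pad_batch(contexts, left_padded=True)   # [[5, 6, 7], [0, 8, 9]]
lab = pad_batch(labels, left_padded=False)    # [[10, 11, 0], [12, 13, 14]]

# Concatenating along the time dimension keeps each example contiguous:
# [[ 5,  6,  7, 10, 11,  0],
#  [ 0,  8,  9, 12, 13, 14]]
full = torch.cat([ctx, lab], dim=1)
print(full)
```

This mirrors what `padded_tensor(items, left_padded=(not is_label), ...)` achieves in the override above, though the real helper also handles the fp16-friendly length rounding.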
13 changes: 11 additions & 2 deletions parlai/agents/transformer/modules/__init__.py
@@ -11,8 +11,17 @@
)
from .attention import BasicAttention, MultiHeadAttention # noqa: F401
from .ffn import TransformerFFN # noqa: F401
from .encoder import TransformerEncoder, TransformerEncoderLayer # noqa: F401
from .decoder import TransformerDecoder, TransformerDecoderLayer # noqa: F401
from .encoder import ( # noqa: F401
PassThroughEncoder,
TransformerEncoder,
TransformerEncoderLayer,
)
from .decoder import ( # noqa: F401
TransformerDecoder,
TransformerDecoderLayer,
TransformerDecoderOnly,
TransformerDecoderOnlyLayer,
)
from .generator import TransformerGeneratorModel # noqa: F401
from .wrappers import TransformerLinearWrapper, TransformerResponseWrapper # noqa: F401
from .mem_net import TransformerMemNetModel # noqa: F401
21 changes: 10 additions & 11 deletions parlai/agents/transformer/modules/attention.py
@@ -15,6 +15,7 @@
import torch.nn.functional as F

from parlai.core.opt import Opt
from parlai.core.params import default
from parlai.utils.torch import neginf


@@ -98,14 +99,8 @@ def __init__(
):
super(MultiHeadAttention, self).__init__()

def _default(val, default):
"""
shorthand for explicit None check for optional arguments.
"""
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
dim = _default(dim, opt['embedding_size'])
n_heads = default(n_heads, opt['n_heads'])
dim = default(dim, opt['embedding_size'])

self.n_heads = n_heads
self.dim = dim
@@ -224,9 +219,13 @@ def prepare_head(tensor):
if static_kv:
mask = incr_state['prev_mask']
else:
mask = torch.cat([incr_state['prev_mask'], mask], dim=2)
# Prepend along the key_len dimension (analogous to
# incr_state['prev_key'])
# Mask will be of size (B x query_len x key_len)
# During incremental decoding the query will only represent the next token,
# whereas the key/value will represent the entire sequence thus far.
# In such a case, we only want to look at the last element of the mask in the query dimension.
prev_mask = incr_state['prev_mask'][:, -query_len:, :]
mask = torch.cat([prev_mask, mask], dim=2)
# Prepend along the key_len dimension (analogous to incr_state['prev_key'])

# Save new incremental states. We reshape to allow for reordering along batch
# dimension.
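As a quick shape check on the new incremental-decoding comment in `MultiHeadAttention` (a minimal sketch with made-up batch and sequence sizes, not the ParlAI attention code itself): during incremental decoding the query covers only the newly generated token, so only the last `query_len` rows of the cached mask are kept before the new mask is appended along the key dimension.

```python
# Illustrative sketch only: the shape bookkeeping behind
# prev_mask[:, -query_len:, :] during incremental decoding.
import torch

B = 2
prev_len, new_len = 4, 1  # 4 tokens decoded so far, 1 new token this step

# Cached mask from earlier steps: (B x prev_query_len x prev_key_len).
prev_mask = torch.ones(B, prev_len, prev_len, dtype=torch.bool)

# Mask for the incoming step: the query (and its keys) cover only the new token.
new_mask = torch.ones(B, new_len, new_len, dtype=torch.bool)

# Keep only the last query_len rows of the cached mask, then concatenate along
# the key dimension so the mask covers the whole sequence decoded so far.
query_len = new_mask.size(1)
mask = torch.cat([prev_mask[:, -query_len:, :], new_mask], dim=2)
assert mask.shape == (B, new_len, prev_len + new_len)  # (2, 1, 5)
```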