Fully quantize Fairseq transformer (#1993)
Summary:
Pull Request resolved: #1993

F.linear -> nn.Linear so the FBGEMM backend can quantize the linear projection. We observed a 3x+ speedup.

Add backward compatibility code.
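
For context, PyTorch's dynamic quantization (FBGEMM backend on x86) works by swapping nn.Linear modules for int8 equivalents, so a projection implemented as a bare F.linear call is never picked up. A minimal sketch; the toy module, dimensions, and names below are illustrative, not fairseq code:

import torch
import torch.nn as nn

class TinyDecoderHead(nn.Module):
    """Toy stand-in for a decoder output layer (illustrative only)."""

    def __init__(self, embed_dim=512, vocab_size=1000):
        super().__init__()
        # Because this is an nn.Linear *module*, quantize_dynamic can find and
        # replace it; a bare F.linear call inside forward() would be invisible.
        self.output_projection = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, features):
        return self.output_projection(features)

model = TinyDecoderHead()
# Dynamically quantize all nn.Linear modules to int8.
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized.output_projection)  # prints a DynamicQuantizedLinear module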

Reviewed By: jhcross

Differential Revision: D20967830

fbshipit-source-id: 11d2c98dd5c1965691d6df433e8428499c9c4dc0
cndn authored and facebook-github-bot committed Apr 19, 2020
1 parent 57526c6 commit 6379573
Showing 1 changed file with 23 additions and 9 deletions: fairseq/models/transformer.py
@@ -669,11 +669,6 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
                 factor=args.adaptive_softmax_factor,
                 tie_proj=args.tie_adaptive_proj,
             )
-        elif not self.share_input_output_embed:
-            self.embed_out = nn.Parameter(
-                torch.Tensor(len(dictionary), self.output_embed_dim)
-            )
-            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

         if args.decoder_normalize_before and not getattr(
             args, "no_decoder_final_norm", False
@@ -686,6 +681,16 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         else:
             self.layernorm_embedding = None

+        if self.share_input_output_embed:
+            self.output_projection = nn.Linear(
+                self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False
+            )
+        else:
+            self.output_projection = nn.Linear(
+                self.output_embed_dim, len(dictionary), bias=False
+            )
+            nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5)
+
     def build_decoder_layer(self, args, no_encoder_attn=False):
         return TransformerDecoderLayer(args, no_encoder_attn)
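
A quick shape check on the shared-embedding branch above: nn.Linear(in_features, out_features) stores its weight as (out_features, in_features), so building the projection from the embedding's dimensions yields a weight with exactly the shape of embed_tokens.weight. The sizes below are made up for illustration:

import torch.nn as nn

vocab_size, embed_dim = 1000, 512            # illustrative sizes, not fairseq defaults
embed_tokens = nn.Embedding(vocab_size, embed_dim)

# in_features = embed_dim, out_features = vocab_size, so the stored weight has
# shape (vocab_size, embed_dim) -- the same shape as embed_tokens.weight.
output_projection = nn.Linear(
    embed_tokens.weight.shape[1], embed_tokens.weight.shape[0], bias=False
)
assert output_projection.weight.shape == embed_tokens.weight.shape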

@@ -852,10 +857,7 @@ def output_layer(self, features):
         """Project features to the vocabulary size."""
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
-            if self.share_input_output_embed:
-                return F.linear(features, self.embed_tokens.weight)
-            else:
-                return F.linear(features, self.embed_out)
+            return self.output_projection(features)
         else:
             return features
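
The change in output_layer is behaviour-preserving: a bias-free nn.Linear carrying weight W computes exactly F.linear(x, W). A small self-contained check, with illustrative shapes:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim = 1000, 512
weight = torch.randn(vocab_size, embed_dim)
features = torch.randn(2, 7, embed_dim)      # (batch, time, embed_dim)

old_logits = F.linear(features, weight)      # the pre-change functional call

proj = nn.Linear(embed_dim, vocab_size, bias=False)
with torch.no_grad():
    proj.weight.copy_(weight)                # give the module the same weight
new_logits = proj(features)                  # the post-change module call

assert torch.allclose(old_logits, new_logits)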

@@ -890,6 +892,18 @@ def upgrade_state_dict_named(self, state_dict, name):
                 "{}.embed_positions._float_tensor".format(name)
             ] = torch.FloatTensor(1)

+        embed_tokens_weights_key = f"{name}.embed_tokens.weights"
+        embed_out_key = f"{name}.embed_out"
+        if embed_tokens_weights_key in state_dict:
+            state_dict[f"{name}.output_projection.weight"] = state_dict[
+                embed_tokens_weights_key
+            ]
+        if embed_out_key in state_dict:
+            state_dict[f"{name}.output_projection.weight"] = state_dict[
+                embed_out_key
+            ]
+            del state_dict[embed_out_key]
+
         for i in range(self.num_layers):
             # update layer norms
             layer_norm_map = {
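For illustration, the effect of the backward-compatibility block on an old checkpoint that stored the separate embed_out parameter; the checkpoint below is a made-up stand-in, not a real fairseq state dict:

import torch

# Stand-in for torch.load() on a checkpoint written before this change.
state_dict = {
    "decoder.embed_tokens.weight": torch.randn(1000, 512),
    "decoder.embed_out": torch.randn(1000, 512),   # old untied output matrix
}

name = "decoder"
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
    # Reuse the old output weights under the key the new nn.Linear expects,
    # then drop the stale entry so strict loading sees no unexpected key.
    state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key]
    del state_dict[embed_out_key]

print(sorted(state_dict))
# ['decoder.embed_tokens.weight', 'decoder.output_projection.weight']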
