Convert matmuls to quantizable nn.Linear modules (#1304)
Summary:
Pull Request resolved: #1304

Pull Request resolved: pytorch/translate#657

Pull Request resolved: facebookresearch/pytext#1065

Pull Request resolved: fairinternal/fairseq-py#889

This diff converts the matmuls in multihead attention into quantizable nn.Linear modules. As a first step, we profile after the diff to see how the low-level operations change.

Reviewed By: jmp84, edunov, lly-zero-one, jhcross

Differential Revision: D17964796

fbshipit-source-id: 3ddd3ff81fa1ea5864dded98e993f4fe3b71fe5e
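
Note: below is a minimal usage sketch of the motivation, not part of this diff. It assumes fairseq is installed, and the constructor arguments are illustrative. PyTorch dynamic quantization only swaps recognized module types such as nn.Linear, so projections implemented as raw Parameters fed to F.linear/matmul would be left unquantized:

```python
import torch
import torch.nn as nn

from fairseq.modules import MultiheadAttention  # assumes fairseq is installed

# Illustrative sizes; any embed_dim / num_heads would do.
mha = MultiheadAttention(embed_dim=512, num_heads=8)

# quantize_dynamic replaces every nn.Linear submodule (q_proj, k_proj, v_proj,
# out_proj after this diff) with a dynamically quantized int8 linear module.
quantized_mha = torch.quantization.quantize_dynamic(
    mha, {nn.Linear}, dtype=torch.qint8
)
print(quantized_mha)
```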
halilakin authored and facebook-github-bot committed Oct 25, 2019
1 parent fdf4c3e commit c07362c
Showing 1 changed file with 42 additions and 53 deletions.

fairseq/modules/multihead_attention.py
@@ -39,14 +39,9 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
             'value to be of the same size'

-        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
-
-        if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
-        else:
-            self.register_parameter('in_proj_bias', None)
+        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

@@ -71,25 +66,30 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
     @property
     def in_proj_weight(self):
         # TODO: Remove this backward compatibility code (in_proj_weight)
-        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+        return torch.cat((self.q_proj.weight, self.k_proj.weight, self.v_proj.weight))

+    @property
+    def in_proj_bias(self):
+        # TODO: Remove this backward compatibility code (in_proj_bias)
+        return torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))
+
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
         else:
-            nn.init.xavier_uniform_(self.k_proj_weight)
-            nn.init.xavier_uniform_(self.v_proj_weight)
-            nn.init.xavier_uniform_(self.q_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)

         nn.init.xavier_uniform_(self.out_proj.weight)
-        if self.in_proj_bias is not None:
-            nn.init.constant_(self.in_proj_bias, 0.)
-            nn.init.constant_(self.out_proj.bias, 0.)
+        nn.init.constant_(self.out_proj.bias, 0.)
         if self.bias_k is not None:
             nn.init.xavier_normal_(self.bias_k)
         if self.bias_v is not None:
@@ -139,9 +139,9 @@ def forward(
                 self.out_proj.weight, self.out_proj.bias,
                 self.training, key_padding_mask, need_weights,
                 attn_mask, use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj_weight,
-                k_proj_weight=self.k_proj_weight,
-                v_proj_weight=self.v_proj_weight)
+                q_proj_weight=self.q_proj.weight,
+                k_proj_weight=self.k_proj.weight,
+                v_proj_weight=self.v_proj.weight)

         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -155,23 +155,23 @@ def forward(
             saved_state = None

         if self.self_attention:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(query)
-            v = self.in_proj_v(query)
+            q = self.q_proj(query)
+            k = self.k_proj(query)
+            v = self.v_proj(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
-            q = self.in_proj_q(query)
+            q = self.q_proj(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                k = self.in_proj_k(key)
-                v = self.in_proj_v(key)
+                k = self.k_proj(key)
+                v = self.v_proj(key)

         else:
-            q = self.in_proj_q(query)
-            k = self.in_proj_k(key)
-            v = self.in_proj_v(value)
+            q = self.q_proj(query)
+            k = self.k_proj(key)
+            v = self.v_proj(value)
         q *= self.scaling

         if self.bias_k is not None:
@@ -284,26 +284,6 @@ def forward(

         return attn, attn_weights

-    def in_proj_q(self, query):
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[:self.embed_dim]
-        return F.linear(query, self.q_proj_weight, bias)
-
-    def in_proj_k(self, key):
-        weight = self.k_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[self.embed_dim:2 * self.embed_dim]
-        return F.linear(key, weight, bias)
-
-    def in_proj_v(self, value):
-        weight = self.v_proj_weight
-        bias = self.in_proj_bias
-        if bias is not None:
-            bias = bias[2 * self.embed_dim:]
-        return F.linear(value, weight, bias)
-
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
@@ -341,12 +321,21 @@ def upgrade_state_dict_named(self, state_dict, name):
             if k.endswith(prefix + 'in_proj_weight'):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
-                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
-                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
-                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+                items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2*dim:]

                 keys_to_remove.append(k)

+                k_bias = prefix + 'in_proj_bias'
+                if k_bias in state_dict.keys():
+                    dim = int(state_dict[k].shape[0] / 3)
+                    items_to_add[prefix + 'q_proj.bias'] = state_dict[k_bias][:dim]
+                    items_to_add[prefix + 'k_proj.bias'] = state_dict[k_bias][dim:2*dim]
+                    items_to_add[prefix + 'v_proj.bias'] = state_dict[k_bias][2*dim:]
+
+                    keys_to_remove.append(prefix + 'in_proj_bias')
+
         for k in keys_to_remove:
             del state_dict[k]
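
As a sanity check on the backward-compatibility path above (the in_proj_weight / in_proj_bias properties and the state-dict upgrade), here is a small standalone sketch, not fairseq code and with illustrative sizes, showing that slicing the concatenated projection reproduces the separate nn.Linear projections:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim = 8
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# Rebuild the legacy combined parameters the same way the in_proj_weight /
# in_proj_bias properties do.
in_proj_weight = torch.cat((q_proj.weight, k_proj.weight, v_proj.weight))
in_proj_bias = torch.cat((q_proj.bias, k_proj.bias, v_proj.bias))

x = torch.randn(3, embed_dim)
# Old in_proj_q path: F.linear with the first embed_dim rows of the combined
# weight and the first embed_dim entries of the combined bias.
q_old = F.linear(x, in_proj_weight[:embed_dim], in_proj_bias[:embed_dim])
q_new = q_proj(x)
assert torch.allclose(q_old, q_new)
```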
