Enable quantization for BART
katalinic-gc committed Aug 3, 2023
1 parent ba0816f commit df0dd8b
Showing 1 changed file with 30 additions and 1 deletion.
optimum/graphcore/models/bart/modeling_bart.py: 30 additions & 1 deletion

@@ -656,6 +656,33 @@ def change_decoder_positional_embedding(self, restore: bool):
             else IPUBartLearnedPositionalEmbedding.from_model(position_embedding)
         )

+    def quantize_linear_layers(self, restore: bool, num_groups: int = 16):
+        if restore:
+            return
+
+        from ...quantization.group_quantize import GroupQuantLinear
+
+        logger.info("Group quantizing linear layers")
+        for module in self.encoder.layers:
+            module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
+            module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
+            module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
+            module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
+            module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
+            module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
+
+        for module in self.decoder.layers:
+            module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
+            module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
+            module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
+            module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
+            module.encoder_attn.q_proj = GroupQuantLinear.from_model(module.encoder_attn.q_proj, num_groups)
+            module.encoder_attn.k_proj = GroupQuantLinear.from_model(module.encoder_attn.k_proj, num_groups)
+            module.encoder_attn.v_proj = GroupQuantLinear.from_model(module.encoder_attn.v_proj, num_groups)
+            module.encoder_attn.out_proj = GroupQuantLinear.from_model(module.encoder_attn.out_proj, num_groups)
+            module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
+            module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
+
     def forward(
         self,
         input_ids=None,
@@ -790,6 +817,7 @@ def parallelize(self, for_generation=False, use_cache=False, **kwargs):
         self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
         self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
         self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
+        self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)

         self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
         self.model.encoder.embed_positions = poptorch.BeginBlock(
@@ -943,7 +971,7 @@ def forward(

 @register(BartForSequenceClassification)
 class PipelinedBartForSequenceClassification(BartForSequenceClassification, PipelineMixin):
-    def parallelize(self):
+    def parallelize(self, **kwargs):
         """
         Transform the model to run in an IPU pipeline.
         - Adds pipeline stages to the model
@@ -960,6 +988,7 @@ def parallelize(self):
         self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
         self.model.change_bart_encoder_and_decoder_classes(restore=False)
         self.model.change_bart_attention_class(restore=False)
+        self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)

         logger.info("-------------------- Device Allocation --------------------")
         logger.info("Embedding --> IPU 0")
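The changes to both parallelize() methods above make group quantization opt-in through a new keyword argument. A minimal usage sketch, assuming a pipelined BART model has already been constructed (how that is done is outside this commit):

# Sketch only: enabling the option added in this commit.
# `pipelined_model` is assumed to be an existing PipelinedBartForConditionalGeneration
# (or PipelinedBartForSequenceClassification) instance; building it is not shown here.
pipelined_model.parallelize(
    for_generation=True,
    use_cache=True,
    use_group_quantized_linears=True,  # read via kwargs.get(...) in the diff above
)
# Equivalent direct call on the underlying BartModel:
# pipelined_model.model.quantize_linear_layers(restore=False, num_groups=16)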

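Per the relative import in the diff, GroupQuantLinear lives in optimum/graphcore/quantization/group_quantize.py and is not part of this change. To make the intent of num_groups=16 concrete, here is a minimal, self-contained sketch of group-wise weight quantization for a torch.nn.Linear; it is an illustration only, not the optimum-graphcore implementation (which targets IPU execution and may use a different bit width, scheme, and layout):

# NOT the library's GroupQuantLinear: a naive per-group int8 illustration.
# num_groups=16 means each weight row is split into 16 groups, and each group
# gets its own scale so that the int8 codes cover that group's value range.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveGroupQuantLinear(nn.Module):
    """Stores weights as per-group int8 codes plus float16 scales (illustrative only)."""

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor, num_groups: int):
        super().__init__()
        out_features, in_features = weight.shape
        assert in_features % num_groups == 0, "in_features must divide evenly into groups"
        w = weight.reshape(out_features, num_groups, in_features // num_groups)
        # Symmetric per-group scale: the largest |w| in each group maps to 127.
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("qweight", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale.to(torch.float16))
        self.register_buffer("bias", bias.detach().clone() if bias is not None else None)
        self.shape = (out_features, in_features)

    @classmethod
    def from_model(cls, linear: nn.Linear, num_groups: int) -> "NaiveGroupQuantLinear":
        # Mirrors the from_model(module, num_groups) call pattern used in the diff.
        return cls(linear.weight.detach(), linear.bias, num_groups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly, then run a normal matmul.
        w = (self.qweight.float() * self.scale.float()).reshape(self.shape)
        return F.linear(x, w, self.bias)


if __name__ == "__main__":
    lin = nn.Linear(768, 768)
    qlin = NaiveGroupQuantLinear.from_model(lin, num_groups=16)
    x = torch.randn(2, 768)
    print((lin(x) - qlin(x)).abs().max())  # small quantization error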