This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Refactor model rewrite in modular way #1660

Open · wants to merge 1 commit into base: main
24 changes: 10 additions & 14 deletions pytext/task/accelerator_lowering.py
@@ -243,20 +243,6 @@ def forward(
return rep, (new_hidden, new_cell)


# Swap a transformer for only RoBERTaEncoder encoders
def swap_modules_for_accelerator(model):
if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
old_transformer = model.encoder.encoder.transformer
model.encoder.encoder.transformer = AcceleratorTransformer(old_transformer)
return model
elif hasattr(model, "representation") and isinstance(model.representation, BiLSTM):
old_biLSTM = model.representation
model.representation = AcceleratorBiLSTM(old_biLSTM)
return model
else:
return model


def lower_modules_to_accelerator(
model: nn.Module, trace, export_options: ExportConfig, throughput_optimize=False
):
@@ -321,3 +307,13 @@ def lower_modules_to_accelerator(
return trace
else:
return trace


def nnpi_rewrite_roberta_transformer(model):
model.encoder.encoder.transformer = AcceleratorTransformer(
model.encoder.encoder.transformer
)


def nnpi_rewrite_bilstm(model):
model.representation = AcceleratorBiLSTM(model.representation)
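For context, a minimal usage sketch (not part of the diff) of the new NNPI rewriters; it assumes `model` is a RoBERTa-based PyText model whose `encoder` attribute is a `RoBERTaEncoder`:

```python
# Old style: the removed swap_modules_for_accelerator returned the
# (possibly unchanged) model.
# model = swap_modules_for_accelerator(model)

# New style: each rewriter mutates the model in place and returns nothing;
# type-based dispatch happens in new_task.py via MODULE_TO_REWRITER["nnpi"].
nnpi_rewrite_roberta_transformer(model)
assert isinstance(model.encoder.encoder.transformer, AcceleratorTransformer)
```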
12 changes: 4 additions & 8 deletions pytext/task/cuda_lowering.py
@@ -189,11 +189,7 @@ def forward(self, tokens: Tensor) -> List[Tensor]:
return states


# Swap a transformer for only RoBERTaEncoder encoders
def swap_modules_for_faster_transformer(model):
if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
old_transformer = model.encoder.encoder.transformer
model.encoder.encoder.transformer = NVFasterTransformerEncoder(old_transformer)
return model
else:
return model
def cuda_rewrite_roberta_transformer(model):
model.encoder.encoder.transformer = NVFasterTransformerEncoder(
model.encoder.encoder.transformer
)
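A similar hedged sketch for the CUDA path (not part of the diff); it relies on the `MODULE_TO_REWRITER` table added in `new_task.py` below and assumes `model.encoder` is a `RoBERTaEncoder`:

```python
# Equivalent in effect to the removed swap_modules_for_faster_transformer(model),
# but the rewriter is looked up by module type and applied in place.
rewriter = MODULE_TO_REWRITER["cuda"][RoBERTaEncoder]
rewriter(model)  # i.e. cuda_rewrite_roberta_transformer(model)
assert isinstance(model.encoder.encoder.transformer, NVFasterTransformerEncoder)
```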
62 changes: 58 additions & 4 deletions pytext/task/new_task.py
@@ -13,6 +13,8 @@
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters import MetricReporter
from pytext.models.model import BaseModel
from pytext.models.representations.bilstm import BiLSTM
from pytext.models.roberta import RoBERTaEncoder
from pytext.trainers import TaskTrainer, TrainingState
from pytext.utils import cuda, onnx, precision
from pytext.utils.file_io import PathManager
@@ -25,10 +27,11 @@

from .accelerator_lowering import (
lower_modules_to_accelerator,
swap_modules_for_accelerator,
nnpi_rewrite_roberta_transformer,
nnpi_rewrite_bilstm,
)
from .cuda_lowering import (
swap_modules_for_faster_transformer,
cuda_rewrite_roberta_transformer,
)
from .quantize import (
quantize_statically,
@@ -44,6 +47,57 @@
)


MODULE_TO_REWRITER = {
"nnpi": {
RoBERTaEncoder: nnpi_rewrite_roberta_transformer,
BiLSTM: nnpi_rewrite_bilstm,
},
"cuda": {
RoBERTaEncoder: cuda_rewrite_roberta_transformer,
},
}


def find_module_instances(model, module_type, cur_path):
"""
Finds all module instances of the specified type and returns the paths to get to each of
those instances
"""
if isinstance(model, module_type):
yield list(cur_path) # copy the list since cur_path is a shared list
for attr in dir(model):
if (
attr[0] == "_" or len(cur_path) > 4
): # avoids infinite recursion and exploring unnecessary paths
continue
cur_path.append(attr)
# recursively yield
yield from find_module_instances(getattr(model, attr), module_type, cur_path)
cur_path.pop()


def rewrite_transformer(model, module_path, rewriter):
"""
    Descends the model hierarchy along all but the last element of module_path and
    calls the rewriter on the resulting parent module
"""
for prefix in module_path[:-1]:
model = getattr(model, prefix)
rewriter(model)


def swap_modules(model, module_to_rewriter):
"""
Finds modules within a model that can be rewritten and rewrites them with predefined
rewrite functions
"""
for module in module_to_rewriter:
instance_paths = find_module_instances(model, module, [])
rewriter = module_to_rewriter[module]
for path in instance_paths:
rewrite_transformer(model, path, rewriter)
return model


def create_schema(
tensorizers: Dict[str, Tensorizer], extra_schema: Optional[Dict[str, Type]] = None
) -> Schema:
@@ -346,7 +400,7 @@ def torchscript_export(
optimizer.pre_export(model)

if use_nnpi or use_fx_quantize:
model = swap_modules_for_accelerator(model)
model = swap_modules(model, MODULE_TO_REWRITER["nnpi"])

# Trace needs eval mode, to disable dropout etc
model.eval()
@@ -392,7 +446,7 @@ def torchscript_export(
precision.FP16_ENABLED = True
cuda.CUDA_ENABLED = True

model = swap_modules_for_faster_transformer(model)
model = swap_modules(model, MODULE_TO_REWRITER["cuda"])
model.eval()
model.half().cuda()
# obtain new inputs with cuda/fp16 enabled.
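To show how the new helpers fit together, here is a self-contained toy sketch (not part of the PR) of the discover-and-rewrite pattern implemented by `find_module_instances`, `rewrite_transformer`, and `swap_modules`. The `Toy*` classes, `WrappedTransformer`, and `toy_rewriter` are hypothetical stand-ins for `RoBERTaEncoder`, the accelerator wrappers, and the `nnpi_*`/`cuda_*` rewriters:

```python
class ToyTransformer:
    pass


class WrappedTransformer:
    """Stand-in for AcceleratorTransformer / NVFasterTransformerEncoder."""

    def __init__(self, wrapped):
        self.wrapped = wrapped


class ToyEncoder:
    """Stand-in for RoBERTaEncoder: holds the transformer to be swapped."""

    def __init__(self):
        self.transformer = ToyTransformer()


class ToyModel:
    def __init__(self):
        self.encoder = ToyEncoder()


def toy_rewriter(parent):
    # Mirrors nnpi_rewrite_roberta_transformer: mutate in place, return nothing.
    parent.encoder.transformer = WrappedTransformer(parent.encoder.transformer)


def find_module_instances(model, module_type, cur_path):
    # Same traversal as the PR: yield the attribute path to every matching instance.
    if isinstance(model, module_type):
        yield list(cur_path)
    for attr in dir(model):
        if attr[0] == "_" or len(cur_path) > 4:
            continue
        cur_path.append(attr)
        yield from find_module_instances(getattr(model, attr), module_type, cur_path)
        cur_path.pop()


def swap_modules(model, module_to_rewriter):
    for module_type, rewriter in module_to_rewriter.items():
        for path in find_module_instances(model, module_type, []):
            parent = model
            for prefix in path[:-1]:  # descend to the parent of the match
                parent = getattr(parent, prefix)
            rewriter(parent)
    return model


model = swap_modules(ToyModel(), {ToyEncoder: toy_rewriter})
print(type(model.encoder.transformer).__name__)  # WrappedTransformer
```

As in the PR, each rewriter receives the parent of the matched module (the path minus its last attribute) and mutates it in place, while `swap_modules` returns the same model object so call sites like `model = swap_modules(model, MODULE_TO_REWRITER["nnpi"])` keep working.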