Skip to content

Commit

Permalink
[dynamo] support enable_tf32 in byteir_backend (#375)
Browse files Browse the repository at this point in the history
as title
  • Loading branch information
qingyunqu committed Jun 26, 2024
1 parent e1c6a15 commit d6f21e2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion compiler/lib/Pipelines/ByreTensorOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
pm.addNestedPass<func::FuncOp>(
createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
pm.addNestedPass<func::FuncOp>(
createConvertHloToByreTensorPass(appendArgTypes));
createConvertHloToByreTensorPass(appendArgTypes, enableTF32));
pm.addPass(createCanonicalizerPass());
}
} // namespace
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

byteir_enable_tf32 = os.environ.get("BYTEIR_ENABLE_TF32") == "1"

byteir_not_use_cache = os.environ.get("BYTEIR_NOT_USE_CACHE") == "1"

# TODO. default not save fx graph.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
log = logging.getLogger(__name__)
g_graph_counter = count(0)

BACKEND_LEGAL_OPS = ["aten.max.dim"]


#@dynamo_utils.dynamo_timed(phase_name="byteir_compile")
def inner_compile(gm: torch.fx.GraphModule,
Expand Down Expand Up @@ -95,14 +93,15 @@ def inner_compile(gm: torch.fx.GraphModule,
module = torch_frontend.compile_dynamo_model(
gm,
output_type="stablehlo",
backend_legal_ops=BACKEND_LEGAL_OPS)
backend_legal_ops=[])
with open(stablehlo_file, "w") as f:
print(module.operation.get_asm(), file=f)
if not os.path.exists(byre_file):
byteir.compile(stablehlo_file,
byre_file,
verbose=False,
target="cuda")
target="cuda",
enable_tf32=config.byteir_enable_tf32)
#byteir.compile(stablehlo_file, byre_file, verbose=False, target="cuda_with_ait")

byre_session = brt.Session(alloc_func=caching_allocator_alloc,
Expand Down

0 comments on commit d6f21e2

Please sign in to comment.