Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fixbug] Fix a bug in IRModule.update_function #124

Merged
merged 2 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 7 additions & 13 deletions gallery/getting-started/quick-start.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,15 @@
)
model = model.cuda().eval()

# we should register the hidet backend for pytorch dynamo
# only need to do this if you import hidet before torch. Otherwise, it is done automatically
hidet.torch.register_dynamo_backends()
# optimize the model with 'hidet' backend
model_opt = torch.compile(model, backend='hidet')

# currently, hidet only support inference
with torch.no_grad():
# optimize the model with 'hidet' backend
model_opt = torch.compile(model, backend='hidet')
# run the optimized model
y1 = model_opt(x)
y2 = model(x)

# run the optimized model
y1 = model_opt(x)
y2 = model(x)

# check the correctness
torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)
# check the correctness
torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)


# benchmark the performance
Expand Down
2 changes: 0 additions & 2 deletions gallery/tutorials/optimize-pytorch-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@
import torch.backends.cudnn
import hidet

hidet.torch.register_dynamo_backends() # register hidet backend to torch dynamo

x = torch.randn(1, 3, 224, 224).cuda()
model = torch.hub.load(
'pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False
Expand Down
1 change: 0 additions & 1 deletion python/hidet/cli/bench/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repea
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

hidet.torch.register_dynamo_backends()
torch.backends.cudnn.allow_tf32 = self.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = self.allow_tf32

Expand Down
8 changes: 7 additions & 1 deletion python/hidet/ir/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,15 @@ def lookup_var(self, name):
return self.global_vars[name]

def update_function(self, func: Function):
from hidet.ir.tools import rewrite

self.functions[func.name] = func
if func.name in self.global_vars:
self.global_vars[func.name].type = func.name, FuncType.from_func(func)
old_var = self.global_vars[func.name]
new_var = Var(func.name, FuncType.from_func(func))
self.global_vars[func.name] = new_var
for name, f in self.functions.items():
self.functions[name] = rewrite(f, {old_var: new_var})

def add(self, name, func: Function):
if name in self.functions:
Expand Down
12 changes: 2 additions & 10 deletions python/hidet/ir/tools/util_functors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,7 @@
class MapBasedRewriter(IRRewriter):
def __init__(self, rmap):
super().__init__()
self.rmap = rmap

def visit(self, node):
if node not in self.memo:
if node in self.rmap:
self.memo[node] = self.rmap[node]
else:
self.memo[node] = super().visit(node)
return self.memo[node]
self.memo.update(rmap)


class IRCollector(IRVisitor):
Expand Down Expand Up @@ -102,7 +94,7 @@ def visit_Let(self, e: Let):
return Let(v, self(e.value), self(e.body))


def rewrite(node: Union[Expr, Stmt, tuple], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]]):
def rewrite(node: Union[Function, Expr, Stmt, tuple], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]]):
assert isinstance(rewrite_map, dict)
rewriter = MapBasedRewriter(rewrite_map)
return rewriter.rewrite(node)
Expand Down