[Fixbug] Fix a bug in IRModule.update_function (#124)
* fix a bug in IRModule.update_function

* lint
yaoyaoding committed Mar 2, 2023
1 parent d9d6ebc commit 60dabb7
Showing 5 changed files with 16 additions and 27 deletions.
20 changes: 7 additions & 13 deletions gallery/getting-started/quick-start.py
@@ -40,21 +40,15 @@
 )
 model = model.cuda().eval()

-# we should register the hidet backend for pytorch dynamo
-# only need to do this if you import hidet before torch. Otherwise, it is done automatically
-hidet.torch.register_dynamo_backends()
+# optimize the model with 'hidet' backend
+model_opt = torch.compile(model, backend='hidet')

-# currently, hidet only support inference
-with torch.no_grad():
-    # optimize the model with 'hidet' backend
-    model_opt = torch.compile(model, backend='hidet')
-
-    # run the optimized model
-    y1 = model_opt(x)
-    y2 = model(x)

+# run the optimized model
+y1 = model_opt(x)
+y2 = model(x)

-    # check the correctness
-    torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)
+# check the correctness
+torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)


 # benchmark the performance
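All three Python-source diffs in this commit drop the explicit hidet.torch.register_dynamo_backends() call; judging from the deleted comment above ("Otherwise, it is done automatically"), registering the 'hidet' dynamo backend is now handled when hidet is imported. A sketch of the streamlined flow after this commit, assembled from the quick-start and tutorial diffs (needs a CUDA device and network access for torch.hub; treat it as an illustration, not the exact gallery script):

import torch
import hidet  # per the deleted comment, backend registration is handled automatically

# input and model as in the tutorial diff below
x = torch.randn(1, 3, 224, 224).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.cuda().eval()

# no register_dynamo_backends() call and no torch.no_grad() wrapper anymore
model_opt = torch.compile(model, backend='hidet')
y1 = model_opt(x)
y2 = model(x)
torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)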
2 changes: 0 additions & 2 deletions gallery/tutorials/optimize-pytorch-model.py
@@ -68,8 +68,6 @@
 import torch.backends.cudnn
 import hidet

-hidet.torch.register_dynamo_backends()  # register hidet backend to torch dynamo
-
 x = torch.randn(1, 3, 224, 224).cuda()
 model = torch.hub.load(
     'pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False
1 change: 0 additions & 1 deletion python/hidet/cli/bench/model.py
@@ -87,7 +87,6 @@ def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repea
             raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
         import torch._dynamo as dynamo

-        hidet.torch.register_dynamo_backends()
         torch.backends.cudnn.allow_tf32 = self.allow_tf32
         torch.backends.cuda.matmul.allow_tf32 = self.allow_tf32
8 changes: 7 additions & 1 deletion python/hidet/ir/func.py
@@ -146,9 +146,15 @@ def lookup_var(self, name):
         return self.global_vars[name]

     def update_function(self, func: Function):
+        from hidet.ir.tools import rewrite
+
         self.functions[func.name] = func
         if func.name in self.global_vars:
-            self.global_vars[func.name].type = func.name, FuncType.from_func(func)
+            old_var = self.global_vars[func.name]
+            new_var = Var(func.name, FuncType.from_func(func))
+            self.global_vars[func.name] = new_var
+            for name, f in self.functions.items():
+                self.functions[name] = rewrite(f, {old_var: new_var})

     def add(self, name, func: Function):
         if name in self.functions:
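This hunk is the actual bugfix. The deleted line mutated the cached global Var in place and assigned it the tuple (func.name, FuncType.from_func(func)) rather than a FuncType; the replacement allocates a fresh Var with the correct FuncType and then rewrites every function in the module so no stale reference to the old Var survives. A toy sketch of that replace-and-rewrite pattern, using simplified stand-ins (these Var/Function/rewrite definitions are illustrative, not hidet's real IR):

class Var:
    def __init__(self, name, type_):
        self.name, self.type = name, type_

class Function:
    def __init__(self, name, callees):
        self.name, self.callees = name, callees  # callees: Vars this body references

def rewrite(func, rmap):
    # map-based rewrite: swap any Var found in rmap for its replacement
    return Function(func.name, [rmap.get(v, v) for v in func.callees])

kernel_var = Var('kernel', 'FuncType(old)')
functions = {'main': Function('main', [kernel_var])}
global_vars = {'kernel': kernel_var}

# update 'kernel' with a new signature: allocate a fresh Var ...
new_var = Var('kernel', 'FuncType(new)')
global_vars['kernel'] = new_var
# ... then rewrite every function so stale references are replaced
for name, f in functions.items():
    functions[name] = rewrite(f, {kernel_var: new_var})

assert functions['main'].callees[0] is new_var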
12 changes: 2 additions & 10 deletions python/hidet/ir/tools/util_functors.py
@@ -19,15 +19,7 @@
 class MapBasedRewriter(IRRewriter):
     def __init__(self, rmap):
         super().__init__()
-        self.rmap = rmap
-
-    def visit(self, node):
-        if node not in self.memo:
-            if node in self.rmap:
-                self.memo[node] = self.rmap[node]
-            else:
-                self.memo[node] = super().visit(node)
-        return self.memo[node]
+        self.memo.update(rmap)


 class IRCollector(IRVisitor):
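The simplification works because IRRewriter memoizes: visit consults self.memo before doing any work, so pre-seeding the memo with the replacement map short-circuits mapped nodes exactly as the deleted visit() override did. A toy illustration of that memo-seeding trick (MemoRewriter is a hypothetical stand-in for hidet's IRRewriter):

class MemoRewriter:
    def __init__(self):
        self.memo = {}

    def visit(self, node):
        # memoization: consult the cache before visiting
        if node not in self.memo:
            self.memo[node] = self.do_visit(node)
        return self.memo[node]

    def do_visit(self, node):
        # recurse into tuples; everything else is a leaf kept as-is
        if isinstance(node, tuple):
            return tuple(self.visit(child) for child in node)
        return node

class MapBasedRewriter(MemoRewriter):
    def __init__(self, rmap):
        super().__init__()
        # pre-seeding the memo replaces the old visit() override:
        # a mapped node counts as "already visited", with its replacement as the result
        self.memo.update(rmap)

r = MapBasedRewriter({'x': 'y'})
assert r.visit(('x', ('x', 'z'))) == ('y', ('y', 'z'))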
@@ -102,7 +94,7 @@ def visit_Let(self, e: Let):
         return Let(v, self(e.value), self(e.body))


-def rewrite(node: Union[Expr, Stmt, tuple], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]]):
+def rewrite(node: Union[Function, Expr, Stmt, tuple], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]]):
     assert isinstance(rewrite_map, dict)
     rewriter = MapBasedRewriter(rewrite_map)
     return rewriter.rewrite(node)
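Widening the accepted node types to include Function is what lets update_function (in the func.py hunk above) feed whole function bodies through the map-based rewriter; the call shape there is simply:

new_func = rewrite(func, {old_var: new_var})  # func is a hidet.ir.Function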
