[Torch] Steal Pytorch weights (#310)
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
hjjq and yaoyaoding committed Jul 13, 2023
1 parent 9d51c74 commit 692192c
Showing 3 changed files with 20 additions and 2 deletions.
python/hidet/graph/frontend/torch/dynamo_config.py (10 additions, 0 deletions)

@@ -24,6 +24,7 @@ def __init__(self):
         self._print_input_graph: bool = False
         self._dump_graph_ir: Optional[str] = None
         self._correctness_report: bool = False
+        self._steal_weights: bool = False
 
     def __getitem__(self, item: str):
         assert isinstance(item, str)
@@ -43,6 +44,7 @@ def reset(self):
         self._print_input_graph: bool = False
         self._dump_graph_ir: Optional[str] = None
         self._correctness_report: bool = False
+        self._steal_weights: bool = False
 
     def search_space(self, level: int = 2):
         """
@@ -145,5 +147,13 @@ def correctness_report(self, flag=True):
         self._correctness_report = flag
         return self
 
+    def steal_weights(self, flag=True):
+        """
+        Whether to clear PyTorch weights in certain layers after
+        converting them to Hidet tensors. This saves some GPU memory.
+        """
+        self._steal_weights = flag
+        return self
+
 
 dynamo_config = DynamoConfig()
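
For orientation, here is a minimal sketch of toggling the new flag from user code. It assumes only what this diff shows: the module path above, a __getitem__ that reads the same flags (register_modules.py below relies on it), and setters that return self so they chain like the existing ones.

    from hidet.graph.frontend.torch.dynamo_config import dynamo_config

    dynamo_config.steal_weights(True)          # enable weight stealing
    assert dynamo_config['steal_weights']      # __getitem__ exposes the same flag
    dynamo_config.steal_weights(False).search_space(2)   # setters chain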
python/hidet/graph/frontend/torch/interpreter.py (6 additions, 1 deletion)

@@ -191,7 +191,7 @@ def _compute_weight_norm(self, name: str) -> Tensor:
         hook = self._get_weight_norm_hook(name)
         return hook.compute_weight(self.mod)
 
-    def param(self, name: str, optional=False) -> Optional[Tensor]:
+    def param(self, name: str, optional=False, steal=False) -> Optional[Tensor]:
         if name not in self.torch_params:
             # see https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
             # to learn more about weight norm.
@@ -208,7 +208,12 @@ def param(self, name: str, optional=False) -> Optional[Tensor]:
                 self.hidet_params[name] = None
             else:
                 torch_param: torch.Tensor = self.torch_params[name]
+                if steal:
+                    del self.torch_params[name]
+                    setattr(self.mod, name, None)
                 self.hidet_params[name] = tensor_from_torch(torch_param)
+                del torch_param
+                torch.cuda.empty_cache()
         return self.hidet_params[name]
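
The hunk above is the core of the change: once the Hidet tensor exists, every torch-side reference to the weight is dropped, so only the converted tensor keeps the storage alive. Below is a standalone sketch of that pattern; the steal_param helper is hypothetical, standing in for HidetModule.param, and the caller-supplied convert stands in for tensor_from_torch.

    import torch

    def steal_param(mod: torch.nn.Module, name: str, converted: dict, convert):
        torch_param: torch.Tensor = getattr(mod, name)
        setattr(mod, name, None)        # the module no longer keeps the weight alive
        converted[name] = convert(torch_param)
        del torch_param                 # drop the last torch-side reference
        torch.cuda.empty_cache()        # release unused cached blocks back to the driver
        return converted[name]

The torch.cuda.empty_cache() call matters because PyTorch's caching allocator would otherwise keep the freed blocks reserved, leaving the saving invisible to anything else sharing the GPU.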
python/hidet/graph/frontend/torch/register_modules.py (4 additions, 1 deletion)

@@ -14,6 +14,7 @@
 from hidet.graph.tensor import Tensor
 from .interpreter import HidetModule, register_module
 from . import register_functions as regs
+from .dynamo_config import dynamo_config
 
 
 @register_module(torch.nn.Conv1d)
@@ -159,7 +160,9 @@ def __init__(self, torch_module: torch.nn.Module):
         super().__init__(torch_module)
         from hidet import ops
 
-        self.transposed_weight = ops.transpose(self.param('weight'), [1, 0])
+        steal = dynamo_config['steal_weights']
+
+        self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])
 
     def __call__(self, x: Tensor) -> Tensor:
         assert isinstance(self.mod, torch.nn.Linear)
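
Taken together: the flag is set once before compilation, and when Dynamo traces an nn.Linear the wrapper above converts the weight and frees the torch copy. A usage sketch, assuming (per the hidet docs, not shown in this diff) that hidet registers a 'hidet' backend with torch.compile and re-exports this config object as hidet.torch.dynamo_config:

    import torch
    import hidet

    hidet.torch.dynamo_config.steal_weights(True)   # set before compiling

    model = torch.nn.Linear(4096, 4096).cuda().eval()
    model_opt = torch.compile(model, backend='hidet')

    with torch.no_grad():
        y = model_opt(torch.randn(8, 4096, device='cuda'))

The trade-off follows directly from the diff: after tracing, the original module's weight has been set to None, so the eager model can no longer run on its own.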
