[Torch] Steal Pytorch weights #310

Merged: 3 commits, Jul 13, 2023
10 changes: 10 additions & 0 deletions python/hidet/graph/frontend/torch/dynamo_config.py
@@ -24,6 +24,7 @@ def __init__(self):
         self._print_input_graph: bool = False
         self._dump_graph_ir: Optional[str] = None
         self._correctness_report: bool = False
+        self._steal_weights: bool = False
 
     def __getitem__(self, item: str):
         assert isinstance(item, str)
@@ -43,6 +44,7 @@ def reset(self):
         self._print_input_graph: bool = False
         self._dump_graph_ir: Optional[str] = None
         self._correctness_report: bool = False
+        self._steal_weights: bool = False
 
     def search_space(self, level: int = 2):
         """
@@ -145,5 +147,13 @@ def correctness_report(self, flag=True):
         self._correctness_report = flag
         return self
 
+    def steal_weights(self, flag=True):
+        """
+        Whether to clear the PyTorch weights of certain layers after
+        converting them to Hidet tensors. This saves some GPU memory.
+        """
+        self._steal_weights = flag
+        return self
+
 
 dynamo_config = DynamoConfig()
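
A minimal usage sketch of the new option (not part of this diff; it assumes the 'hidet' dynamo backend is registered by importing hidet and uses the dynamo_config import path from this file, which may be re-exported elsewhere in other versions):

    import torch
    import hidet  # importing hidet makes the 'hidet' backend available to torch.compile

    from hidet.graph.frontend.torch.dynamo_config import dynamo_config

    # Enable weight stealing before the model is traced; the setter returns
    # the config object (return self), so calls can be chained.
    dynamo_config.steal_weights(True)

    model = torch.nn.Linear(4096, 4096).cuda().eval()
    compiled = torch.compile(model, backend='hidet')
    y = compiled(torch.randn(8, 4096, device='cuda'))  # conversion (and stealing) happens on first call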
7 changes: 6 additions & 1 deletion python/hidet/graph/frontend/torch/interpreter.py
@@ -191,7 +191,7 @@ def _compute_weight_norm(self, name: str) -> Tensor:
         hook = self._get_weight_norm_hook(name)
         return hook.compute_weight(self.mod)
 
-    def param(self, name: str, optional=False) -> Optional[Tensor]:
+    def param(self, name: str, optional=False, steal=False) -> Optional[Tensor]:
         if name not in self.torch_params:
             # see https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
             # to learn more about weight norm.
@@ -208,7 +208,12 @@ def param(self, name: str, optional=False) -> Optional[Tensor]:
                 self.hidet_params[name] = None
             else:
                 torch_param: torch.Tensor = self.torch_params[name]
+                if steal:
+                    del self.torch_params[name]
+                    setattr(self.mod, name, None)
                 self.hidet_params[name] = tensor_from_torch(torch_param)
+                del torch_param
+                torch.cuda.empty_cache()
         return self.hidet_params[name]
 
 
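The change above works by dropping every live torch-side reference (the torch_params entry, the module attribute, and the local variable) before calling torch.cuda.empty_cache(), since the caching allocator can only release blocks that nothing references. A standalone sketch of the same pattern outside the interpreter (hypothetical helper, not part of this PR; it assumes hidet.from_torch for the conversion, and whether memory is actually returned depends on whether that conversion copies or shares the underlying storage):

    import torch
    import hidet

    def steal_param(mod: torch.nn.Module, name: str) -> hidet.Tensor:
        torch_param = getattr(mod, name)
        # detach() so the conversion does not see a tensor that requires grad
        hidet_param = hidet.from_torch(torch_param.detach())
        # Drop the torch-side references so the original allocation can be
        # reclaimed once nothing else points to it.
        setattr(mod, name, None)
        del torch_param
        torch.cuda.empty_cache()
        return hidet_param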
5 changes: 4 additions & 1 deletion python/hidet/graph/frontend/torch/register_modules.py
@@ -14,6 +14,7 @@
 from hidet.graph.tensor import Tensor
 from .interpreter import HidetModule, register_module
 from . import register_functions as regs
+from .dynamo_config import dynamo_config
 
 
 @register_module(torch.nn.Conv1d)
@@ -159,7 +160,9 @@ def __init__(self, torch_module: torch.nn.Module):
         super().__init__(torch_module)
         from hidet import ops
 
-        self.transposed_weight = ops.transpose(self.param('weight'), [1, 0])
+        steal = dynamo_config['steal_weights']
+
+        self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])
 
     def __call__(self, x: Tensor) -> Tensor:
         assert isinstance(self.mod, torch.nn.Linear)
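In this PR only the Linear wrapper reads the flag; other registered modules convert their weights through the same HidetModule.param call, so they could adopt the option with an equally small change. A purely illustrative sketch of that pattern (hypothetical class, not code from this PR; the __call__ method and the @register_module registration are omitted, since hidet already registers a Conv2d wrapper):

    import torch
    from hidet.graph.tensor import Tensor
    from hidet.graph.frontend.torch.interpreter import HidetModule
    from hidet.graph.frontend.torch.dynamo_config import dynamo_config


    class StealingConv2d(HidetModule):
        """Mirrors the built-in Conv2d wrapper, but threads the steal flag
        through param() exactly as HidetLinear does above."""

        def __init__(self, torch_module: torch.nn.Module):
            super().__init__(torch_module)
            steal = dynamo_config['steal_weights']
            # Convert (and optionally steal) the kernel once, at wrap time.
            self.weight: Tensor = self.param('weight', steal=steal)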