-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
convert actgan params to pure scalars for model creation
* convert actgan params to pure scalars for model creation * add modified torch unpickling to actgan directly GitOrigin-RevId: ed4687ef4682200a311ea9d2563c8e0f42174fc2
- Loading branch information
1 parent
805da13
commit cc933ea
Showing
6 changed files
with
56 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
Utils for PyTorch | ||
""" | ||
import pickle | ||
|
||
from io import BytesIO | ||
from typing import BinaryIO | ||
|
||
import torch | ||
|
||
|
||
def determine_device() -> str: | ||
""" | ||
Returns device on which generation should run. | ||
""" | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
return device | ||
|
||
|
||
def patched_torch_unpickle(file_handle: BinaryIO, device: str) -> object: | ||
# https://github.com/pytorch/pytorch/issues/16797#issuecomment-777059657 | ||
|
||
unpickler = _PyTorchPatchedUnpickler(file_handle, map_location=device) | ||
return unpickler.load() | ||
|
||
|
||
class _PyTorchPatchedUnpickler(pickle.Unpickler): | ||
def __init__(self, *args, map_location: str, **kwargs): | ||
self._map_location = map_location | ||
super().__init__(*args, **kwargs) | ||
|
||
def find_class(self, module, name): | ||
if module == "torch.storage" and name == "_load_from_bytes": | ||
return _load_with_map_location(self._map_location) | ||
else: | ||
return super().find_class(module, name) | ||
|
||
|
||
def _load_with_map_location(map_location: str) -> callable: | ||
return lambda b: torch.load(BytesIO(b), map_location=map_location) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
faker==15.3.3 | ||
faker==4.1.1 | ||
flake8==4.0.1 | ||
numpy>=1.18.0 | ||
pandas>=1.1.0 | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters