/
_pytorch_doc.py
46 lines (40 loc) · 1.54 KB
/
_pytorch_doc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99_pytorch_doc.ipynb.
# %% ../nbs/99_pytorch_doc.ipynb 5
from __future__ import annotations
from types import ModuleType
# %% auto 0
# Names exported by `from <this module> import *`.
__all__ = [
    'PYTORCH_URL',
    'pytorch_doc_link',
]
# %% ../nbs/99_pytorch_doc.ipynb 6
# Root of the stable PyTorch documentation. Page names such as 'nn.html' are
# appended to it directly, so the trailing slash is required.
PYTORCH_URL = 'https://pytorch.org/docs/stable/'
# %% ../nbs/99_pytorch_doc.ipynb 7
def _mod2page(
mod:ModuleType, # A PyTorch module
) -> str:
"Get the webpage name for a PyTorch module"
if mod == Tensor: return 'tensors.html'
name = mod.__name__
name = name.replace('torch.', '').replace('utils.', '')
if name.startswith('nn.modules'): return 'nn.html'
return f'{name}.html'
# %% ../nbs/99_pytorch_doc.ipynb 9
import importlib
# %% ../nbs/99_pytorch_doc.ipynb 10
def pytorch_doc_link(
name:str # Name of a PyTorch module, class or function
) -> (str, None):
"Get the URL to the documentation of a PyTorch module, class or function"
if name.startswith('F'): name = 'torch.nn.functional' + name[1:]
if not name.startswith('torch.'): name = 'torch.' + name
if name == 'torch.Tensor': return f'{PYTORCH_URL}tensors.html'
try:
mod = importlib.import_module(name)
return f'{PYTORCH_URL}{_mod2page(mod)}'
except: pass
splits = name.split('.')
mod_name,fname = '.'.join(splits[:-1]),splits[-1]
if mod_name == 'torch.Tensor': return f'{PYTORCH_URL}tensors.html#{name}'
try:
mod = importlib.import_module(mod_name)
page = _mod2page(mod)
return f'{PYTORCH_URL}{page}#{name}'
except: return None