
ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv #979

Closed
AlphaBetaGamma96 opened this issue Jul 22, 2022 · 24 comments



AlphaBetaGamma96 commented Jul 22, 2022

Hi All,

I was running an older version of PyTorch (built from source) with functorch (built from source), and somehow I've broken the older version of functorch. When I import functorch I get the following error:

import functorch
#returns ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

The version of functorch I had was 0.2.0a0+9d6ee76; is there a way to re-install to fix this ImportError? I do have the latest version of PyTorch/functorch in a separate conda environment, but I wanted to check how it compares to the older version. In this 'older' conda environment, PyTorch/functorch were versions 1.12.0a0+git7c2103a and 0.2.0a0+9d6ee76 respectively.

Is there a way to download a specific version of functorch with https://github.com/pytorch/functorch.git ? Or another way to fix this issue?

vfdev-5 (Contributor) commented Jul 22, 2022

@AlphaBetaGamma96 you can check the date for your functorch commit and install pytorch nightly for the previous day (if it exists) and it should be working.
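The suggestion above works because the `+hash` suffix in these version strings is a PEP 440 local version identifier, so the commit hash can be pulled out mechanically. A minimal sketch (`commit_from_version` is my own helper name, not a functorch API):

```python
def commit_from_version(version: str) -> str:
    """Return the local-build commit hash from a version such as
    '0.2.0a0+9d6ee76' (empty string if there is no '+hash' suffix)."""
    return version.split("+", 1)[1] if "+" in version else ""

print(commit_from_version("0.2.0a0+9d6ee76"))  # 9d6ee76
```

The commit date can then be looked up in a functorch checkout, e.g. with `git show -s --format=%ci 9d6ee76`, and a PyTorch nightly from around that date installed.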

AlphaBetaGamma96 (Author) commented:

Ah, I forgot to mention above that the PyTorch version I have is built from source because it had a particular formula for slogdet_backward, which has since been fixed in the nightly. So, if I want to keep my current PyTorch install and install an older version of functorch, is that not possible? Apologies for the silly question, I don't 100% follow! 😅


vfdev-5 commented Jul 22, 2022

Thanks for confirming. I started thinking about that as I reread your issue a few more times. I think it could be complicated to compile and run an old functorch against a recent PyTorch...


AlphaBetaGamma96 commented Jul 22, 2022

The funny thing is, it was working fine yesterday and I've somehow broken it. I assume it's only possible to install the most recent version of functorch, and not previous versions like with PyTorch?

EDIT: I've got a fresh conda environment installed with the latest of both PyTorch (nightly) and functorch. I did just notice that my code now runs significantly slower (~5x - 10x slower) and I wanted to confirm this, but the old conda env is somehow broken. I did try re-installing functorch via pip, but it returns an invalid version error.

ERROR: Could not find a version that satisfies the requirement torch>=1.13.0.dev (from functorch) (from versions: 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0)
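The error above is pip's resolver at work: functorch's metadata asked for `torch>=1.13.0.dev`, but without `--pre` pip only considers stable wheels, the newest of which was 1.12.0. A toy illustration of the comparison (real pip follows PEP 440 exactly; `release_tuple` is a made-up helper for this sketch):

```python
def release_tuple(version: str) -> tuple:
    """Numeric release prefix of a version string: '1.13.0.dev' -> (1, 13, 0)."""
    parts = []
    for piece in version.split("."):
        if piece.isdigit():
            parts.append(int(piece))
        else:
            break  # stop at pre-release/dev markers
    return tuple(parts)

# The newest stable wheel on PyPI was torch 1.12.0, which sits below the
# 1.13 development stream that functorch's requirement asked for.
print(release_tuple("1.12.0") < release_tuple("1.13.0.dev"))  # True
```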


vfdev-5 commented Jul 22, 2022

I've got a fresh conda environment installed with the latest of both PyTorch (nightly) and FuncTorch. I did just notice that my code now runs significantly slower (~5x - 10x slower)

Oh, that's bad if it's confirmed.

I did try re-installing functorch via pip but it returns a invalid version error.

How exactly do you install PyTorch? Maybe --pre is missing for pip? You probably have to install PyTorch first and then functorch, so that its requirements are satisfied...

I assume it's only possible to install the most recent version of FuncTorch and not previous versions like with PyTorch?

Maybe you can manually revert the commit that does old -> new in functorch, and then test it with your locally updated recent PyTorch?


AlphaBetaGamma96 commented Jul 22, 2022

oh, that's bad if it confirms.

I do get quite a lot of spam from PyTorch/functorch saying batching rules don't exist for these functions: aten::_linalg_slogdet, aten::_linalg_solve_ex, and aten::linalg_solve (which might explain the slowdown?). Also, torch.lu throws a deprecation warning, which I assume comes from torch.linalg.slogdet as that's solved via an LU decomposition.

~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/functional.py:1643: UserWarning: torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be removed in a future PyTorch release.
LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at /opt/conda/conda-bld/pytorch_1658387591415/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2091.)
  return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))

How do you exactly install pytorch ? Maybe, --pre is missing for pip ? Probably, you have to install pytorch first and next functorch such requirements are satisfied...

So, checking how I installed it, it was as follows (from a clean conda env):

conda create --name pytorch_nightly
conda activate pytorch_nightly
conda install pytorch torchvision torchaudio cuda=11.6 -c pytorch-nightly -c nvidia
pip install ninja
pip install --user "git+https://github.com/pytorch/functorch.git"

So perhaps missing the --pre is problematic? What exactly does that do?

Maybe, you can manually revert the commit that does old -> new in Functorch and thus you can tests it with your locally updated recent pytorch ?

The thing is, I'm not 100% sure what I did to cause the break, because I only changed the definition of slogdet_backward in FunctionsManual.cpp and then did python setup.py install. I assume the error message is stating that something from a future version of PyTorch is missing in functorch?

EDIT: I could always do a new conda environment with the --pre flag and see if that changes anything?


vfdev-5 commented Jul 22, 2022

So perhaps missing the --pre is problematic? What exactly does that do?

--pre is for pip to install pytorch nightly: pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116

Can you please give steps to repro your error ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv ?
So, I assume I have to install pytorch 1.12.0a0+git7c2103a and functorch 0.2.0a0+9d6ee76 ?

AlphaBetaGamma96 (Author) commented:

--pre is for pip to install pytorch nightly: pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116

Ah ok, I followed the nightly install for conda from the getting started page. Is that ok? Or should it all be installed via pip?

Can you please give steps to repro your error ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

I get that error when I import functorch.

Python 3.9.7 (default, Sep 16 2021, 13:09:58) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import functorch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.local/lib/python3.9/site-packages/functorch/__init__.py", line 7, in <module>
    from . import _C
ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

So, I assume I have to install pytorch 1.12.0a0+git7c2103a and functorch 0.2.0a0+9d6ee76 ?

My versions are as follows:
torch: '1.13.0.dev20220721'
functorch: '0.3.0a0+b567f78'


vfdev-5 commented Jul 22, 2022

I installed pytorch and functorch following your commands above and

(pytorch_nightly) root@user:/tmp# conda list | grep torch
# packages in environment at /opt/conda/envs/pytorch_nightly:
functorch                 0.3.0a0+e8a68f4          pypi_0    pypi
pytorch                   1.13.0.dev20220722 py3.9_cuda11.6_cudnn8.3.2_0    pytorch-nightly
pytorch-mutex             1.0                        cuda    pytorch-nightly
torchvision               0.14.0.dev20220722      py39_cu116    pytorch-nightly
(pytorch_nightly) root@user:/tmp# 
(pytorch_nightly) root@user:/tmp# 
(pytorch_nightly) root@user:/tmp# python -c "import torch; import functorch"
(pytorch_nightly) root@user:/tmp# 

Looks like it works with dev20220722

AlphaBetaGamma96 (Author) commented:

Ah sorry, give me one sec, I've got confused between the two conda environments. What I stated above is for the new nightly conda env, which runs (but is 10x slower); let me find out the versions for the other PyTorch install, which fails for functorch. Sorry!


AlphaBetaGamma96 commented Jul 22, 2022

@vfdev-5
Here were the versions for torch and functorch, but functorch still has the aforementioned issue!

PyTorch version:    1.12.0a0+git7c2103a
CUDA version:       11.6
FuncTorch version:  0.2.0a0+9d6ee76

The installation was done from source, and I installed on the 4th of March from what I can find.

EDIT: It seems that the new version of PyTorch/functorch is 5x to 10x slower than the older version I had. I have some old data files which have the walltime per epoch, and they do seem to confirm my initial thought that it's significantly slower, by about a factor of 5x to 10x.


vfdev-5 commented Jul 22, 2022

It seems that the new version of PyTorch/FuncTorch is 5x to 10x slower than an older version I had. I have some old datafile which have the walltime per epoch and it does seem to confirm my initial thoughts that its significantly slower by about a factor of 5x to 10x.

Let's open a new issue for that. Can you point out exactly which ops are slower? The LU-related ones?

As for PyTorch version 1.12.0a0+git7c2103a and functorch version 0.2.0a0+9d6ee76, they seem to be within a few days of each other: torch is Mar 19 and functorch is Mar 25.

EDIT:
I confirm that pytorch 1.12.0a0+git7c2103a and functorch 0.2.0a0+9d6ee76 are working:

root@qgpu1:/tmp# pip list | grep torch
functorch          0.2.0a0+9d6ee76     /tmp/functorch
torch              1.12.0a0+git7c2103a /pytorch      
torchvision        0.13.0a0+01b0a00    /tmp/vision   
:/tmp# python -c "import torch; import functorch"
[W OperatorEntry.cpp:127] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)
    registered at aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: FuncTorchBatched
  previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10310
       new kernel: registered at /tmp/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)

So, you said that you updated PyTorch locally (slogdet_backward), and after that functorch stopped working?


AlphaBetaGamma96 commented Jul 22, 2022

I modified slogdet_backward; what I changed is listed here: https://github.com/AlphaBetaGamma96/pytorch/blob/custom_slogdet_backward_attempt/torch/csrc/autograd/FunctionsManual.cpp#L3433-3497. It was a naive attempt by myself to get around the .item() issue for vmap.
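For context, `.item()` is fundamentally at odds with vmap: under vmap a tensor secretly carries one value per batched sample, so there is no single Python scalar to return. A conceptual sketch, not functorch's actual implementation (`BatchedTensor` here is a stand-in class invented for illustration):

```python
class BatchedTensor:
    """Toy stand-in for a vmapped tensor: one value per sample."""

    def __init__(self, values):
        self.values = list(values)

    def item(self):
        # A real scalar exists only if there is exactly one sample.
        if len(self.values) != 1:
            raise RuntimeError("vmap: .item() is ambiguous on a batched tensor")
        return self.values[0]

print(BatchedTensor([3.0]).item())  # 3.0
```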

The thing is, PyTorch and functorch worked with my modified slogdet_backward (that's how I got the walltime for the older version), but I can't figure out what I've done to break functorch like this.

EDIT: Would it be possible to re-install functorch from version 0.2.0a0+9d6ee76 ? Perhaps a simple uninstall and re-install might fix this issue?


vfdev-5 commented Jul 22, 2022

Would it be possible to re-install functorch from version 0.2.0a0+9d6ee76 ? Perhaps a simple uninstall and re-install might fix this issue?

Yes, maybe; please try and keep us updated.

AlphaBetaGamma96 (Author) commented:

Silly question, but how can I install a specific version of functorch?

The basic install states

pip install --user "git+https://github.com/pytorch/functorch.git"

I assume it's similar, except functorch.git changes to functorch0.2.0a0+9d6ee76.git? (or something similar)


vfdev-5 commented Jul 22, 2022

You can clone the repository and fetch the commit and then python setup.py develop

If using pip: pip install git+https://github.com/pytorch/functorch.git@9d6ee76


AlphaBetaGamma96 commented Jul 23, 2022

OK, I uninstalled functorch and re-installed functorch.git@9d6ee76. However, I noticed that the version I uninstalled was Found existing installation: functorch 0.3.0a0+b567f78. So perhaps I installed a newer version and forgot to change conda envs?

When installing the older version, that error goes away. However, I get a new error that I haven't seen before.

Traceback (most recent call last):
  File "~/workdir/run_main.py", line 263, in <module>
    output_dict = load_model(model_path, device, net, optim, sampler)
  File "~/workdir/utils.py", line 56, in load_model
    state_dict = torch.load(f=model_path, map_location=device)
  File "~/anaconda3/envs/pytorch_from_source/lib/python3.9/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "~/anaconda3/envs/pytorch_from_source/lib/python3.9/site-packages/torch/serialization.py", line 1046, in _load
    result = unpickler.load()
  File "~/anaconda3/envs/pytorch_from_source/lib/python3.9/site-packages/torch/serialization.py", line 1039, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute '_rebuild_parameter_v2' on <module 'torch._utils' from '~/anaconda3/envs/pytorch_from_source/lib/python3.9/site-packages/torch/_utils.py'>

The function called is as follows. It just loads from a checkpoint and returns a few nn.Module objects. But the main issue seems to be torch.load.

import os

import torch
from torch import nn

def load_model(model_path: str, device: torch.device, net: nn.Module,
               optim: torch.optim.Optimizer, sampler: nn.Module) -> dict:
  r"""Load an object saved with `torch.save` if the file already exists.
  Returns a dict with the (possibly restored) training state.
  """
  if os.path.isfile(model_path):
    print("Model already exists %s - transferring" % (model_path))
    state_dict = torch.load(f=model_path, map_location=device)

    start = state_dict['epoch'] + 1  # start at next epoch
    net.load_state_dict(state_dict['model_state_dict'])
    optim.load_state_dict(state_dict['optim_state_dict'])
    optim._steps = start             # update epoch in optim too!
    loss = state_dict['loss']
    sampler.chains = state_dict['chains']
    print("Model resuming at epoch %6i with energy %6.4f MeV" % (start, loss))
  else:
    print("Saving model to %s - new" % (model_path))
    start = 0
  return {'start': start, 'device': device, 'net': net, 'optim': optim, 'sampler': sampler}

EDIT: I've resolved this issue. It happens if a save file from one version of pytorch is loaded in another. I'll run the code for both versions and see how the walltime compares between versions!
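The `_rebuild_parameter_v2` failure above is generic pickle behaviour rather than anything torch-specific: a checkpoint stores functions by module and attribute name, so loading fails when the target module no longer exposes that name. A torch-free reproduction (the module and function names below are invented for the demo):

```python
import pickle
import sys
import types

# Fake "newer library" module that owns the rebuild helper at save time.
mod = types.ModuleType("fake_torch_utils")

def _rebuild_parameter_v2(value):
    return value

_rebuild_parameter_v2.__module__ = "fake_torch_utils"
mod._rebuild_parameter_v2 = _rebuild_parameter_v2
sys.modules["fake_torch_utils"] = mod

blob = pickle.dumps(_rebuild_parameter_v2)  # stored as (module, name), not code

# Simulate loading in an "older" environment that lacks the helper.
del mod._rebuild_parameter_v2
try:
    pickle.loads(blob)
except AttributeError as exc:
    print(type(exc).__name__)  # AttributeError
```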

@AlphaBetaGamma96
Copy link
Author

AlphaBetaGamma96 commented Jul 25, 2022

So, I've got both conda env working fine and I've compared running my code in both environments.

The "source" version of pytorch/functorch is
pytorch 1.13.0.dev20220721
functorch 0.3.0a0+e8a68f4

The "nightly" version of pytorch/functorch is
pytorch 1.12.0a0+git7c2103a
functorch 0.2.0a0+9d6ee76

Here are some results below. N here is the number of input nodes, so I have something like an N x N matrix that I take the determinant of and calculate its Hessian. I'll try to put together a minimal reproducible example for this issue, as I have an idea of what's causing it.

N: 2 | source: 0.7472s | nightly: 4.3208s
N: 3 | source: 0.8020s | nightly: 5.4278s
N: 4 | source: 0.8688s | nightly: 5.7715s
N: 5 | source: 0.9874s | nightly: 6.7180s
N: 6 | source: 1.0704s | nightly: 8.1426s

Also, when I went to check the functorch version for the "source" version of PyTorch/functorch, I got the same error again:

>>> import functorch; functorch.__version__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.local/lib/python3.9/site-packages/functorch/__init__.py", line 7, in <module>
    from . import _C
ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

I did install functorch via pip install git+https://github.com/pytorch/functorch.git@9d6ee76; however, when checking my version of functorch via pip list, I see functorch is version 0.3.0a0+e8a68f4, which is odd because I specified an older version. I'll uninstall it again, re-try, and see if the issue persists.

EDIT: I've uninstalled functorch version 0.3.0a0+e8a68f4. However, before I went to install version 0.2.0, somehow the older functorch install was already there. It seems that when I uninstalled 0.3.0, it instantly got replaced with 0.2.0. That sounds a bit strange to me?


vfdev-5 commented Jul 25, 2022

N: 2 | source: 0.7472s | nightly: 4.3208s
N: 3 | source: 0.8020s | nightly: 5.4278s
N: 4 | source: 0.8688s | nightly: 5.7715s
N: 5 | source: 0.9874s | nightly: 6.7180s
N: 6 | source: 1.0704s | nightly: 8.1426s

According to the numbers and labels, "source" is faster than "nightly", and the label "source" refers to a more recent functorch/pytorch. So, I'd say it is good that newer versions are faster than older ones. Am I missing something?


vfdev-5 commented Jul 25, 2022

When using pip to install a package, please ensure that the pip command works with the appropriate conda env (pip --version should show the path to the correct python in the correct env), and use --upgrade to reinstall the new/old package: pip install --upgrade git+https://github.com/pytorch/functorch.git@9d6ee76
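One way to be sure which environment pip targets is to ask the current interpreter for its own pip, which sidesteps a stale `pip` on PATH entirely. A small sketch:

```python
import subprocess
import sys

# `python -m pip` is guaranteed to use this interpreter's site-packages,
# unlike a bare `pip` that may resolve to the base conda env's pip.
result = subprocess.run(
    [sys.executable, "-m", "pip", "--version"],
    capture_output=True, text=True, check=True,
)
print(result.stdout.strip())  # e.g. "pip 22.x from /path/to/env/... (python 3.x)"
```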


AlphaBetaGamma96 commented Jul 25, 2022

According to the numbers and labels, "source" is faster then "nightly". Label "source" refers to a more recent functorch/pytorch. So, I'd say it is good that newer versions are faster then older. Am I missing something ?

"source" is the older version, and "nightly" is the newer version. I probably should've labelled them better, but I have an older version of PyTorch built from source and the latest version of the PyTorch nightly.

EDIT: So, the latest version of PyTorch is significantly slower.

AlphaBetaGamma96 (Author) commented:

So, I ran the profiler on a single epoch of my code, and I've found why the later version is so slow. I ran my code within this context manager,

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:

and printed the results below via print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)).

This may be more appropriate in a different thread, if so I'll open a new one.

Here is the source older PyTorch/FuncTorch version

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::linalg_norm         0.01%     252.000us        10.15%     191.767ms       7.990ms       0.000us         0.00%     118.833ms       4.951ms            24  
                              Optimizer.step#Optim.step        24.08%     454.978ms        38.64%     730.137ms     730.137ms       0.000us         0.00%     114.920ms     114.920ms             1  
                                              aten::bmm         0.34%       6.442ms         1.39%      26.357ms      59.363us      68.411ms        32.24%     105.901ms     238.516us           444  
                                           aten::matmul         0.20%       3.782ms        25.20%     476.263ms       1.234ms       0.000us         0.00%      81.857ms     212.065us           386  
                                     aten::nuclear_norm         0.01%     190.000us         7.45%     140.846ms       7.825ms       0.000us         0.00%      73.612ms       4.090ms            18  
                                   aten::linalg_svdvals         0.00%      40.000us         7.39%     139.638ms       7.758ms       0.000us         0.00%      73.494ms       4.083ms            18  
                                      aten::_linalg_svd         0.12%       2.202ms         7.39%     139.598ms       7.755ms      73.428ms        34.60%      73.494ms       4.083ms            18  
                                   volta_dgemm_64x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      62.215ms        29.32%      62.215ms     609.951us           102  
void gesvdbj_batch_32x16<double, double>(int, int co...         0.00%       0.000us         0.00%       0.000us       0.000us      52.741ms        24.86%      52.741ms     390.674us           135  
                                               aten::mm         0.20%       3.867ms        47.45%     896.618ms       3.915ms      12.993ms         6.12%      32.014ms     139.799us           229  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.890s
Self CUDA time total: 212.189ms

Here is the new nightly PyTorch/FuncTorch results

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 aten::_linalg_solve_ex         7.25%        1.252s        74.24%       12.810s     191.276us       0.000us         0.00%        3.294s      49.189us         66973  
                                     aten::linalg_solve         0.00%     759.000us        54.82%        9.460s     118.255ms       0.000us         0.00%        2.382s      29.780ms            80  
                                  aten::linalg_solve_ex         0.00%     146.000us        38.19%        6.591s     106.304ms       0.000us         0.00%        1.698s      27.379ms            62  
                                  aten::_linalg_slogdet         2.78%     479.166ms        35.95%        6.204s     477.729us       0.000us         0.00%        1.065s      82.026us         12987  
autograd::engine::evaluate_function: LinalgSolveExBa...         0.00%     347.000us        20.58%        3.552s     111.000ms       0.000us         0.00%     895.913ms      27.997ms            32  
                                 LinalgSolveExBackward0         0.00%     215.000us        20.58%        3.551s     110.976ms       0.000us         0.00%     895.885ms      27.996ms            32  
                              aten::linalg_lu_factor_ex         3.86%     666.437ms        13.07%        2.255s      30.562us     749.043ms        30.36%     895.281ms      12.135us         73779  
                                  aten::linalg_lu_solve         4.75%     819.420ms        19.41%        3.350s      54.511us     601.140ms        24.36%     853.543ms      13.890us         61451  
autograd::engine::evaluate_function: LinalgSlogdetBa...         0.00%     620.000us        17.79%        3.070s      99.029ms       0.000us         0.00%     811.436ms      26.175ms            31  
                                 LinalgSlogdetBackward0         0.00%     354.000us        17.78%        3.068s      98.961ms       0.000us         0.00%     811.393ms      26.174ms            31  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 17.256s
Self CUDA time total: 2.467s

As can be seen, the bottleneck is these calls: aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet. They are extremely slow in comparison to the other function calls. For some reason, they even seem to hit the CPU, even though I run them exclusively on the GPU? I assume some commands sync with the CPU; is there a way to find which functions do that?

TL;DR - no batching rule for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet results in a significant slowdown for per-sample gradients with torch.linalg.slogdet.

Also, I think I figured out why I had different versions of functorch: although I was in a particular conda env, I installed functorch via pip install ..., which picked up a different pip. When I instead installed it via python3 -m pip install ..., each functorch install was kept in its respective conda env. If that makes sense?


vfdev-5 commented Jul 25, 2022

Thanks for the detailed benchmark @AlphaBetaGamma96! Let's create a separate issue for that. Please provide repro code if possible. Also, please make sure to provide the correct versions for each benchmark and not mix them up.

TL;DR - no batching rule for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet results in significantly slowdown for per-sample gradients with torch.linalg.slogdet

I agree that if the issue is a missing batching rule, this can explain such a slowdown. So, you see a warning about the missing batching rule, and functorch will fall back to a for-loop.
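To make that fallback concrete: without a batching rule, vmap effectively calls the op once per sample, losing any batched-kernel speedup. A pure-Python caricature (these function names are illustrative, not functorch internals):

```python
def vmap_fallback(op, batch):
    """What a missing batching rule costs: one op call per sample."""
    return [op(sample) for sample in batch]

def vmap_batched(op_batched, batch):
    """With a batching rule: a single call over the whole batch."""
    return op_batched(batch)

print(vmap_fallback(lambda x: x * x, [1, 2, 3]))                # [1, 4, 9]
print(vmap_batched(lambda xs: [x * x for x in xs], [1, 2, 3]))  # [1, 4, 9]
```

Both produce the same result; the fallback just pays per-sample dispatch and kernel-launch overhead, which matches the profiler tables above showing tens of thousands of calls to the un-batched linalg ops.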

Also, I think I figured out why I had different versions of functorch, it's because although I was in a particular conda env I installed functorch via pip install ... However when I installed it via python3 -m pip install ... it seems to have kept each functorch install to its respective conda env. If that makes sense?

It can happen in a terminal that just calling pip calls the base pip and not the one from the env, so installing a package with it will install the package into the system and not into the env.

AlphaBetaGamma96 (Author) commented:

Thanks for detailed benchmark @AlphaBetaGamma96 ! Let's create a separate issue for that. Please provide a repro code if possible. Also, please make sure to provide correct versions for each benchmark and not mix them up.

I'll begin writing up a minimal reproducible example for this issue, and make sure to pay attention to the correct versions of PyTorch/FuncTorch.

I can happen in a terminal that just calling pip will call base pip and not the one from the env, so installing a package with it will install the package into the system and not in the env.

That makes sense, I'll make sure to pay attention to that!
