
LayerThinningModifier #623

Closed
vjsrinivas opened this issue Mar 18, 2022 · 10 comments

@vjsrinivas

vjsrinivas commented Mar 18, 2022

Reference:

Hey @vjsrinivas, we do have this enabled for models like ResNet-50 to automatically thin the network. Specifically, we have a LayerThinningModifier that can be used.

I've attached an example implementation of that for ResNet-50. Let us know if you need any support or run into any issues (the dependency graph generation can be tricky). Generally, we've seen 40% filter pruning being at the upper limit for ResNet-50 before it starts degrading in accuracy.

resnet50-structured.yaml.zip

Originally posted by @markurtz in #489 (comment)


Since the other issue is more about TensorRT, I don't want to pollute it with structured-pruning-related discussion. I ran the YAML for ResNet-50 with LayerThinningModifier, but I get the following error:

Traceback (most recent call last):
  File "classfication_channel.py", line 340, in <module>
    prune_train(model, input_size)
  File "classfication_channel.py", line 103, in prune_train
    model, train_loader, criterion, device, train=True, optimizer=optimizer
  File "classfication_channel.py", line 42, in run_model_one_epoch
    loss.backward()
  File "/home/vijay/anaconda3/envs/radar/lib/python3.7/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/vijay/anaconda3/envs/radar/lib/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: Function CudnnConvolutionBackward returned an invalid gradient at index 1 - got [2048, 332, 1, 1] but expected shape compatible with [2048, 512, 1, 1]

It seems like the zeroed weights are removed but the layer objects themselves are still expecting the old weight and bias shapes. I changed the YAML file to run LayerThinningModifier at the end of training, and it successfully trained and removed the weights. I had to recreate the network when loading the pruned weights though.

@vjsrinivas
Author

Here is the file that will reproduce the error: simple_classification_structured.zip
It's mostly copy-pasted from the classification.ipynb tutorial.

Here are my relevant specs:

  • Python 3.7.10
  • PyTorch 1.8
  • CUDA 10.2

@bfineran
Member

Thanks for filing the issue @vjsrinivas. We're looking into potential causes for this. In the meantime, using our image classification training script (integrations/pytorch/train.py), I was able to run the recipe successfully on resnet50 with imagenette with no errors. Could you try running the recipe there?

python integrations/pytorch/train.py \
  --recipe-path resnet50-structured.yaml \
  --arch-key resnet50 \
  --dataset imagenette \
  --dataset-path <Path to dataset>  \
  --pretrained True

@vjsrinivas
Author

Hi @bfineran, I was able to run the recipe with train.py successfully, but I had some trouble with it initially. I ran python integrations/pytorch/train.py --recipe-path ../resnet50-structured.yaml --arch-key resnet50 --dataset imagenette --dataset-path ../../imagenette2 --pretrained True --train-batch-size 16 --test-batch-size 16, and it would train correctly until the last epoch, where it would error out with:

AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f5fabd0cd40>
Traceback (most recent call last):
  File "/home/vijay/anaconda3/envs/radar/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/home/vijay/anaconda3/envs/radar/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1279, in _shutdown_workers
    self._pin_memory_thread.join()
  File "/home/vijay/anaconda3/envs/radar/lib/python3.7/threading.py", line 1041, in join
    raise RuntimeError("cannot join current thread")
RuntimeError: cannot join current thread

From my searches, it seems related to num_workers in the DataLoaders of _create_train_dataset_and_loader and _create_val_dataset_and_loader. I set num_workers to zero and it worked all the way through.
I get the following accuracy at the end of the last epoch:

Saving model for epoch 9 and top-1 accuracy 94.5999984741211 to /home/vijay/Documents/devmk4/radar-cnn/misc/pruning_example/resnet50/test/sparseml/pytorch_vision/resnet50_imagenette__02 for model

But the layer sparsity stats are all zero:

2022-03-21 21:58:08 __main__     INFO     layer sparsities:
2022-03-21 21:58:08 __main__     INFO     input.conv.weight: 0.0000
2022-03-21 21:58:08 __main__     INFO     sections.0.0.conv1.weight: 0.0000
2022-03-21 21:58:08 __main__     INFO     sections.0.0.conv2.weight: 0.0000
2022-03-21 21:58:08 __main__     INFO     sections.0.0.conv3.weight: 0.0000
2022-03-21 21:58:08 __main__     INFO     sections.0.0.identity.conv.weight: 0.0000
...

Also, I checked the .pth file output, and it's the same size as the original, which I'm guessing means it did not actually remove any weights?

bfineran self-assigned this Mar 22, 2022
@bfineran
Member

Hi @vjsrinivas, glad to hear you were able to get the run working.

  1. Setting a non-zero number of DataLoader workers may cause issues when PyTorch doesn't detect enough shared memory for the worker processes to communicate. Could you check the amount of shared memory available on your system? df -h | grep shm
  2. The recipe you are using just applies structured pruning and then thins the model to remove the now-defunct channels. After those channels are removed, you will have a smaller dense model, which is why the final sparsity is detected to be 0. It's definitely up for debate whether we should add additional logging at the end for structurally thinned models.
  3. What's the file size? And are you looking at the torch checkpoint? The torch checkpoint will contain additional data such as optimizer state, which will essentially double the memory required for saving. You can try grabbing just the model weights and re-saving with the following:
import torch

PATH_TO_CKPT = "..."
NEW_CKPT_PATH = "..."

checkpoint = torch.load(PATH_TO_CKPT)
torch.save(checkpoint["state_dict"], NEW_CKPT_PATH)

The new checkpoint should just contain the model weights.
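
As a quick sanity check, you can compare the two file sizes directly (the paths below are placeholders):

import os

ORIGINAL_CKPT = "..."
WEIGHTS_ONLY_CKPT = "..."

# sizes in MB; the weights-only file should be roughly half the original
print(os.path.getsize(ORIGINAL_CKPT) / 1e6, "MB")
print(os.path.getsize(WEIGHTS_ONLY_CKPT) / 1e6, "MB")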

@vjsrinivas
Author

re: 1, the shared memory is 16GB: tmpfs 16G 210M 16G 2% /dev/shm

re: 2, I would be strongly in favor of showing the maximum sparsity over the whole training cycle rather than just the final epoch.

re: 3, yes, I was talking about the PyTorch weights file. You were right; the weights are pruned.

Now, when loading the pruned weights into the original model, they will not fit because the filter and output sizes are different for each layer (expected). Currently, I'm manually editing the model structure to match the new weight shapes (gist). It works but is very messy. Is there an easier way?

Thank you for the help, btw!

@bfineran
Member

Hm, ok, maybe something else is causing the DataLoader issue. What batch size are you running at, and how many workers were you trying?

That's great news. For re-loading the weights, we're currently in the process of rolling out "phased recipes" that will track changes between runs and apply them before the weights are loaded. Until then, you can apply the structural changes yourself using the old recipe with the following:

from sparseml.pytorch.optim import ScheduledModifierManager

# model definition
model = ...

# load old recipe and apply to make structural changes
OLD_RECIPE_PATH = ...
ScheduledModifierManager.from_yaml(OLD_RECIPE_PATH).apply(model)
# parameter sizes should now line up and weights can be loaded
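
# for example, the pruned checkpoint can then be loaded (assuming the weights
# are stored under a "state_dict" key as in the re-saving snippet above;
# PRUNED_CKPT_PATH is a placeholder)
import torch
PRUNED_CKPT_PATH = "..."
model.load_state_dict(torch.load(PRUNED_CKPT_PATH)["state_dict"])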

@vjsrinivas
Author

Batch size was 16. I did not set a value for num_workers, but the default value was 4. I'll make sure to use --loader-num-workers next time.
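
For example, rerunning with something like the following should pin the worker count (based on the flags used above; the dataset path is a placeholder):

python integrations/pytorch/train.py \
  --recipe-path resnet50-structured.yaml \
  --arch-key resnet50 \
  --dataset imagenette \
  --dataset-path <Path to dataset> \
  --pretrained True \
  --train-batch-size 16 \
  --test-batch-size 16 \
  --loader-num-workers 0
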
The code snippet works perfectly, thanks!

I had a couple of questions about the YAML recipe:

  • The docstring says that param_groups is to group similar layers for pruning (like residual blocks), but it seems like it's grouped by corresponding layers between each residual block for each stage. I guess I'm confused about the reasoning behind param_groups.
  • From NeuralMagic's blog post about GMP, it was advised to plot the average absolute value of each layer's weights to determine which layers to prune. This makes sense for unstructured pruning, but if you're grouping layers together for structured pruning, how would you figure out which groups to prune if the average magnitude differs nontrivially for each layer in the group?
  • Is there a way to generate param_groups and param_group_dependency_map without exporting to ONNX?

@bfineran
Member

No problem, happy to support.

  • A param group is a set of parameters that share overlapping dependent parameters for layer thinning. They must be pruned along the same channels (accounting for strides) so that there are no conflicts in connections when the corresponding channels are removed from the dependent parameters (i.e. in a residual block, the same channels need to be pruned so they can all feed into the same channels of the first layer after that block).
  • Currently we use the L2 norm of magnitudes across channels to account for differing channel sizes in a group; feel free to experiment with different methods, though. The logic is contained in the structured mask creator class (a rough sketch of the idea follows this list).
  • No, currently the best supported method is going through sparseml.onnx.optim.get_param_structured_pruning_group_dependencies with an ONNX export. This is due to PyTorch graphs not being static; we could in theory trace the graph with torch.jit, but the ONNX export essentially does this and gives it to us in an easy-to-parse format.
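
As a rough sketch of the second point (not our exact implementation, which lives in the structured mask creator class), per-channel L2 scores can be aggregated across a group like this; the tensors and the 40% ratio below are made up for illustration:

import torch

# two parameters in the same hypothetical param group: they share the same
# output channels, so they must be pruned along the same channel indices
conv1_weight = torch.randn(64, 32, 3, 3)  # [out_channels, in_channels, kH, kW]
conv2_weight = torch.randn(64, 16, 3, 3)  # same out_channels, different fan-in

def channel_l2(weight):
    # L2 norm of each output channel, flattened over the remaining dims
    return weight.reshape(weight.shape[0], -1).norm(p=2, dim=1)

# aggregate per-channel scores across the group so that a layer with larger
# fan-in or larger overall magnitude does not dominate the ranking on its own
group_scores = channel_l2(conv1_weight) + channel_l2(conv2_weight)

# prune the lowest-scoring 40% of channels across the whole group
num_prune = int(0.4 * group_scores.numel())
prune_idx = torch.argsort(group_scores)[:num_prune]
print(prune_idx)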

@vjsrinivas
Author

@bfineran Thank you for your help! If you have one, do you mind making a YOLOv3 structured pruning recipe available?
I'll close this issue since the main questions have been addressed.

@vjsrinivas
Author

I don't want to resuscitate an old issue, but this error: RuntimeError: Function CudnnConvolutionBackward returned an invalid gradient at index 1 - got [2048, 332, 1, 1] but expected shape compatible with [2048, 512, 1, 1] was fixed by re-initializing the optimizer to account for the newly thinned model's parameters. Not sure how it'd work with an LR scheduler.
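
A minimal sketch of that workaround, assuming the model has already been thinned (SGD and the learning-rate value below are placeholders for the original hyperparameters):

import torch

model = ...  # the already-thinned model
LEARNING_RATE = 0.01  # placeholder; reuse your original hyperparameters

# re-create the optimizer so it tracks the new (smaller) parameter tensors
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# any LR scheduler holding a reference to the old optimizer would need to be
# re-created against the new one as well, e.g.:
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)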
