Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ __This branch corresponds to the ongoing 2024 course. If you want to see full ma
- [__Week 8:__](./week08_inference_software) __LLM inference optimizations and software__
- Lecture: Inference speed metrics. KV caching, batch inference, continuous batching. FlashAttention with its modifications and PagedAttention. Overview of popular LLM serving frameworks.
- Seminar: Basics of the Triton language. Layer fusion in PyTorch and Triton. Implementation of KV caching, FlashAttention in practice.
- __Week 9:__ __Efficient model inference__
- [__Week 9:__](./week09_compression) __Efficient model inference__
- Lecture: Hardware utilization metrics for deep learning. Knowledge distillation, quantization, LLM.int8(), SmoothQuant, GPTQ. Efficient model architectures. Speculative decoding.
- Seminar: Measuring Memory Bandwidth Utilization in practice. Data-free quantization, GPTq, and SmoothQuant in PyTorch.
- __Week 10:__ __Guest lecture__

## Grading
Expand Down
58 changes: 58 additions & 0 deletions week09_compression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Week 9: Efficient model inference

* Lecture: [slides](./lecture.pdf)
* Seminar: [notebook](./practice.ipynb)
* Homework: see [homework/README.md](homework/README.md)

### Setup for the seminar notebook
You can use [conda](https://docs.anaconda.com/free/miniconda/), [mamba](https://mamba.readthedocs.io/en/latest/user_guide/mamba.html) or [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) to create the environment.

```
conda create -n inference \
python=3.10 \
pytorch=2.2.1 \
torchvision=0.17.1 \
torchaudio=2.2.1 \
pytorch-cuda=11.8 \
matplotlib=3.8.0 \
seaborn=0.12.2 \
numpy=1.26.4 \
ipywidgets=8.1.2 \
jupyterlab_widgets=3.0.10 \
jupyterlab=4.0.11 \
tqdm=4.65.0 \
-c pytorch -c nvidia -y

conda activate inference

# To run part with auto-gptq
pip install auto-gptq==0.7.1 accelerate==0.28.0
pip install --upgrade git+https://github.com/huggingface/transformers.git

# To run part with Smoothquant
cd ~
git clone git@github.com:mit-han-lab/smoothquant.git
cd smoothquant
python setup.py install
cd path/to/efficient-dl-systems/week09_compression

# Finally, running notebook
jupyter lab --no-browser
```

## Further reading

### Knowledge distillation
* https://arxiv.org/abs/2106.05237
* https://arxiv.org/abs/1910.01108
* https://arxiv.org/abs/1909.10351

### Pruning
* https://arxiv.org/abs/2302.04089
* https://arxiv.org/abs/2301.00774

### Quantization
* https://arxiv.org/abs/2206.09557
* https://arxiv.org/abs/2208.07339
* https://huggingface.co/blog/hf-bitsandbytes-integration
* https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
51 changes: 51 additions & 0 deletions week09_compression/homework/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Week 9 home assignment

## Submission format
Implement models, training procedures and benchmarks in `.py` files, run all code in a Jupyter notebook and convert it to the PDF format.
Include your implementations and the report file into a `.zip` archive and submit it.


## Task 1: knowledge distillation for image classification (6 points)

0. Finetune ResNet101 on CIFAR10: change only the classification linear layer [*] and don't freeze other weights (**0 points**)

Then take untrained ResNet101 model, remove the `layer3` (except one conv block that creates correct number of channels for the 4-th layer) block out of it and implement 3 training setups:
1. Train the model on input data only (**1 point**)
2. Train the model on data and add soft cross-entropy between the student (truncated ResNet101) and the teacher (finetuned full ResNet101) (**2 points**)
3. Train the model as in the previous subtask, but also add the MSE loss between corresponding `layer1`, `layer2` and `layer4` features of the student and the teacher (**3 points**)
4. Report test accuracy for each of the models

[\*] Vanilla ResNet is not very well suited for CIFAR: it downsamples the image by x32, while images in CIFAR are 32x32 pixels. So you can:
- upsample images (easiest to implement, but you will perform more computations)
- slightly change the first layers (e.g. make `model.conv1` a 3x3 convolution with stride 1 and remove `model.maxpool`)

Feel free to use dataset and model implementation from PyTorch.
For losses in 2nd and 3rd subtasks use the simple average of all inputs.
For the 3rd subtask, you will need to return not only the model's outputs but also intermediate feature maps.

### Training setup
- Use the standard Adam optimizer without scheduler.
- Use any suitable batch size from 128 to 512.
- Training stopping criterion: accuracy (measured from 0 to 1) stabilizes in the second digit after decimal during at least 2 epochs on test set.
That means that you must satisfy condition `torch.abs(acc - acc_prev) < 0.01` for at least two epochs in a row.

## Task 2: use `deepsparse` to prune & quantize your model (4 points)

0. Please read the whole task description before starting it.
1. Install `deepsparse==1.7.0` and `sparseml==1.7.0`. Note: they might not work smoothly with last PyTorch versions. If so, you can downgrade to `torch==1.12.1`.
2. Take your best trained model from subtasks 1.1-1.3 and run pruning + quantization-aware-training, adapting the following [example](./example_train_sparse_and_quantize.py). You will need to change/implement what is marked by #TODO and report the test accuracy of both models. (**3 points**)
3. Take `onnx` baseline (best trained model from subtask 1.1 - 1.3) and pruned-quantized version and benchmark both models on the CPU using `deepsparse.benchmark` at batch sizes 1 and 32. (**1 point**)

For task 2.3, you may find [this page](https://web.archive.org/web/20240319095504/https://docs.neuralmagic.com/user-guides/deepsparse-engine/benchmarking/) helpful.

You should not use training stopping criterion in this part, since the sparsification recipe relies on having certain amount of epochs.

### Tips:
- Debug your code with resnet18 to iterate faster
- Don't forget `model.eval()` before onnx export
- Don't forget `convert_qat=True` in `sparseml.pytorch.utils.export_onnx` after you trained the model with quantization
- To visualize ONNX models, you can use [netron](https://netron.app/)
- Explicitly set the amount of cores in `deepsparse.benchmark`
- If you are desperate and don't have time to train bigger models, submit this part with resnet18

Good luck and have 59 funs!
93 changes: 93 additions & 0 deletions week09_compression/homework/example_train_sparse_and_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from pathlib import Path
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.models import resnet18, ResNet18_Weights
from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import export_onnx

def save_onnx(model, export_path, convert_qat):
# It is important to call torch_model.eval() or torch_model.train(False) before exporting the model, to turn the model to inference mode.
# This is required since operators like dropout or batchnorm behave differently in inference and training mode.
model.eval()
sample_batch = torch.randn((1, 3, 224, 224))
export_onnx(model, sample_batch, export_path, convert_qat=convert_qat)


def main():
# TODO: add argparse/hydra/... to manage hyperparameters like batch_size, path to pretrained model, etc

# Sparsification recipe -- yaml file with instructions on how to sparsify the model
recipe_path = "recipe.yaml"
assert Path(recipe_path).exists(), "Didn't find sparsification recipe!"

checkpoints_path = Path("checkpoints")
checkpoints_path.mkdir(exist_ok=True)

# Model creation
# TODO: change to your best model from subtasks 1.1 - 1.3
NUM_CLASSES = 10 # number of Imagenette classes
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

save_onnx(model, checkpoints_path / "baseline_resnet.onnx", convert_qat=False)

# Dataset creation
# TODO: change to CIFAR10, add test dataset
batch_size = 64
train_dataset = ImagenetteDataset(train=True, dataset_size=ImagenetteSize.s320, image_size=224)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8)

# Device setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

# Loss setup
criterion = nn.CrossEntropyLoss()
# Note that learning rate is being modified in `recipe.yaml`
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# SparseML Integration
manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

# Training Loop
model.train()

# TODO: implement `train_one_epoch` function to structure the code better
pbar = tqdm(range(manager.max_epochs), desc="epoch")
for epoch in pbar:
running_loss = 0.0
running_corrects = 0.0
for inputs, labels in train_loader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()

with torch.set_grad_enabled(True):
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
loss.backward()
optimizer.step()

running_loss += loss * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)

epoch_loss = running_loss.item() / len(train_loader.dataset)
epoch_acc = running_corrects.double().item() / len(train_loader.dataset)
pbar.set_description(f"Training loss: {epoch_loss:.3f} Accuracy: {epoch_acc:.3f}")

# TODO: implement `evaluate` function to measure accuracy on the test set

manager.finalize(model)

# Saving model
save_onnx(model, checkpoints_path / "pruned_quantized_resnet.onnx", convert_qat=True)

if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions week09_compression/homework/recipe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
modifiers:
- !GlobalMagnitudePruningModifier
init_sparsity: 0.05
final_sparsity: 0.8
start_epoch: 0.0
end_epoch: 30.0
update_frequency: 1.0
params: __ALL_PRUNABLE__

- !SetLearningRateModifier
start_epoch: 0.0
learning_rate: 0.05

- !LearningRateFunctionModifier
start_epoch: 30.0
end_epoch: 50.0
lr_func: cosine
init_lr: 0.05
final_lr: 0.001

- !QuantizationModifier
start_epoch: 50.0
freeze_bn_stats_epoch: 53.0

- !SetLearningRateModifier
start_epoch: 50.0
learning_rate: 10e-6

- !EpochRangeModifier
start_epoch: 0.0
end_epoch: 55.0
Binary file added week09_compression/lecture.pdf
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading