diff --git a/README.md b/README.md
index 8170249..f944689 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,9 @@ __This branch corresponds to the ongoing 2024 course. If you want to see full ma
 - [__Week 8:__](./week08_inference_software) __LLM inference optimizations and software__
   - Lecture: Inference speed metrics. KV caching, batch inference, continuous batching. FlashAttention with its modifications and PagedAttention. Overview of popular LLM serving frameworks.
   - Seminar: Basics of the Triton language. Layer fusion in PyTorch and Triton. Implementation of KV caching, FlashAttention in practice.
-- __Week 9:__ __Efficient model inference__
+- [__Week 9:__](./week09_compression) __Efficient model inference__
+  - Lecture: Hardware utilization metrics for deep learning. Knowledge distillation, quantization, LLM.int8(), SmoothQuant, GPTQ. Efficient model architectures. Speculative decoding.
+  - Seminar: Measuring Memory Bandwidth Utilization in practice. Data-free quantization, GPTQ, and SmoothQuant in PyTorch.
 - __Week 10:__ __Guest lecture__
 
 ## Grading
diff --git a/week09_compression/README.md b/week09_compression/README.md
new file mode 100644
index 0000000..f67a5ae
--- /dev/null
+++ b/week09_compression/README.md
@@ -0,0 +1,58 @@
+# Week 9: Efficient model inference
+
+* Lecture: [slides](./lecture.pdf)
+* Seminar: [notebook](./practice.ipynb)
+* Homework: see [homework/README.md](homework/README.md)
+
+### Setup for the seminar notebook
+You can use [conda](https://docs.anaconda.com/free/miniconda/), [mamba](https://mamba.readthedocs.io/en/latest/user_guide/mamba.html) or [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) to create the environment.
+
+```
+conda create -n inference \
+    python=3.10 \
+    pytorch=2.2.1 \
+    torchvision=0.17.1 \
+    torchaudio=2.2.1 \
+    pytorch-cuda=11.8 \
+    matplotlib=3.8.0 \
+    seaborn=0.12.2 \
+    numpy=1.26.4 \
+    ipywidgets=8.1.2 \
+    jupyterlab_widgets=3.0.10 \
+    jupyterlab=4.0.11 \
+    tqdm=4.65.0 \
+    -c pytorch -c nvidia -y
+
+conda activate inference
+
+# To run the part with auto-gptq
+pip install auto-gptq==0.7.1 accelerate==0.28.0
+pip install --upgrade git+https://github.com/huggingface/transformers.git
+
+# To run the part with SmoothQuant
+cd ~
+git clone git@github.com:mit-han-lab/smoothquant.git
+cd smoothquant
+python setup.py install
+cd path/to/efficient-dl-systems/week09_compression
+
+# Finally, run the notebook
+jupyter lab --no-browser
+```
+
+## Further reading
+
+### Knowledge distillation
+* https://arxiv.org/abs/2106.05237
+* https://arxiv.org/abs/1910.01108
+* https://arxiv.org/abs/1909.10351
+
+### Pruning
+* https://arxiv.org/abs/2302.04089
+* https://arxiv.org/abs/2301.00774
+
+### Quantization
+* https://arxiv.org/abs/2206.09557
+* https://arxiv.org/abs/2208.07339
+* https://huggingface.co/blog/hf-bitsandbytes-integration
+* https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
\ No newline at end of file
diff --git a/week09_compression/homework/README.md b/week09_compression/homework/README.md
new file mode 100644
index 0000000..42342db
--- /dev/null
+++ b/week09_compression/homework/README.md
@@ -0,0 +1,51 @@
+# Week 9 home assignment
+
+## Submission format
+Implement the models, training procedures and benchmarks in `.py` files, run all code in a Jupyter notebook and convert it to PDF.
+Include your implementations and the report in a `.zip` archive and submit it.
+
+
+## Task 1: knowledge distillation for image classification (6 points)
+
+0. Finetune ResNet101 on CIFAR10: change only the classification linear layer [*] and don't freeze the other weights (**0 points**)
+
+Then take an untrained ResNet101 model, remove the `layer3` block from it (except for one conv block that produces the correct number of channels for `layer4`) and implement 3 training setups:
+1. Train the model on the labeled data only, without a teacher (**1 point**)
+2. Train the model on the data and add a soft cross-entropy loss between the student (truncated ResNet101) and the teacher (finetuned full ResNet101); see the loss sketch in the appendix below (**2 points**)
+3. Train the model as in the previous subtask, but also add an MSE loss between the corresponding `layer1`, `layer2` and `layer4` features of the student and the teacher (**3 points**)
+4. Report test accuracy for each of the models
+
+[\*] Vanilla ResNet is not very well suited for CIFAR: it downsamples the input by 32x, while CIFAR images are only 32x32 pixels. So you can:
+- upsample the images (easiest to implement, but you will perform more computations)
+- slightly change the first layers (e.g. make `model.conv1` a 3x3 convolution with stride 1 and remove `model.maxpool`)
+
+Feel free to use the dataset and model implementations from PyTorch.
+For the losses in subtasks 2 and 3, use the simple (unweighted) average of all loss terms.
+For the 3rd subtask, you will need to return not only the model's outputs but also the intermediate feature maps.
+
+### Training setup
+- Use the standard Adam optimizer without a scheduler.
+- Use any suitable batch size from 128 to 512.
+- Stopping criterion: test accuracy (measured from 0 to 1) stabilizes in the second decimal digit for at least 2 epochs.
+That means that the condition `torch.abs(acc - acc_prev) < 0.01` must hold for at least two epochs in a row.
+
+## Task 2: use `deepsparse` to prune & quantize your model (4 points)
+
+0. Please read the whole task description before starting it.
+1. Install `deepsparse==1.7.0` and `sparseml==1.7.0`. Note: they might not work smoothly with the latest PyTorch versions. If so, you can downgrade to `torch==1.12.1`.
+2. Take your best trained model from subtasks 1.1-1.3 and run pruning + quantization-aware training, adapting the following [example](./example_train_sparse_and_quantize.py). You will need to change/implement what is marked by #TODO and report the test accuracy of both models. (**3 points**)
+3. Take the ONNX baseline (your best trained model from subtasks 1.1-1.3) and the pruned-quantized version, and benchmark both models on the CPU using `deepsparse.benchmark` at batch sizes 1 and 32. (**1 point**)
+
+For task 2.3, you may find [this page](https://web.archive.org/web/20240319095504/https://docs.neuralmagic.com/user-guides/deepsparse-engine/benchmarking/) helpful.
+
+You should not use the stopping criterion in this part, since the sparsification recipe relies on a fixed number of epochs.
+
+### Tips:
+- Debug your code with resnet18 to iterate faster
+- Don't forget `model.eval()` before the ONNX export
+- Don't forget `convert_qat=True` in `sparseml.pytorch.utils.export_onnx` after you trained the model with quantization
+- To visualize ONNX models, you can use [netron](https://netron.app/)
+- Explicitly set the number of cores in `deepsparse.benchmark`
+- If you are desperate and don't have time to train bigger models, submit this part with resnet18
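+
+### Appendix: distillation loss sketch
+A minimal sketch of the subtask 1.2 loss, in case the formulation above is unclear. The temperature `T` and the helper name `distillation_loss` are illustration choices, not part of the assignment (use `T=1` for the plain soft cross-entropy from the task statement):
+
+```python
+import torch.nn.functional as F
+
+def distillation_loss(student_logits, teacher_logits, labels, T=1.0):
+    # hard loss: usual cross-entropy with the ground-truth labels
+    hard = F.cross_entropy(student_logits, labels)
+    # soft loss: cross-entropy between the teacher and student distributions
+    # (T > 1 softens both distributions; T = 1 matches the task statement)
+    teacher_probs = F.softmax(teacher_logits / T, dim=-1)
+    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
+    soft = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
+    # simple average of all loss terms, as required by the task
+    return (hard + soft) / 2
+```
+
+For subtask 1.3, extend the average with the three feature-map MSE terms in the same way.
+
+Good luck and have 59 funs!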
diff --git a/week09_compression/homework/example_train_sparse_and_quantize.py b/week09_compression/homework/example_train_sparse_and_quantize.py
new file mode 100644
index 0000000..5994cc4
--- /dev/null
+++ b/week09_compression/homework/example_train_sparse_and_quantize.py
@@ -0,0 +1,93 @@
+from pathlib import Path
+from tqdm.auto import tqdm
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from torchvision.models import resnet18, ResNet18_Weights
+from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize
+from sparseml.pytorch.optim import ScheduledModifierManager
+from sparseml.pytorch.utils import export_onnx
+
+def save_onnx(model, export_path, convert_qat):
+    # It is important to call torch_model.eval() or torch_model.train(False) before exporting the model, to switch it to inference mode.
+    # This is required since operators like dropout or batchnorm behave differently in inference and training mode.
+    model.eval()
+    sample_batch = torch.randn((1, 3, 224, 224))
+    export_onnx(model, sample_batch, export_path, convert_qat=convert_qat)
+
+
+def main():
+    # TODO: add argparse/hydra/... to manage hyperparameters like batch_size, path to pretrained model, etc
+
+    # Sparsification recipe -- yaml file with instructions on how to sparsify the model
+    recipe_path = "recipe.yaml"
+    assert Path(recipe_path).exists(), "Didn't find sparsification recipe!"
+
+    checkpoints_path = Path("checkpoints")
+    checkpoints_path.mkdir(exist_ok=True)
+
+    # Model creation
+    # TODO: change to your best model from subtasks 1.1 - 1.3
+    NUM_CLASSES = 10  # number of Imagenette classes
+    model = resnet18(weights=ResNet18_Weights.DEFAULT)
+    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
+
+    save_onnx(model, checkpoints_path / "baseline_resnet.onnx", convert_qat=False)
+
+    # Dataset creation
+    # TODO: change to CIFAR10, add test dataset
+    batch_size = 64
+    train_dataset = ImagenetteDataset(train=True, dataset_size=ImagenetteSize.s320, image_size=224)
+    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8)
+
+    # Device setup
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    # Loss setup
+    criterion = nn.CrossEntropyLoss()
+    # Note that the learning rate is being modified in `recipe.yaml`
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+
+    # SparseML Integration
+    manager = ScheduledModifierManager.from_yaml(recipe_path)
+    optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
+
+    # Training Loop
+    model.train()
+
+    # TODO: implement `train_one_epoch` function to structure the code better
+    pbar = tqdm(range(manager.max_epochs), desc="epoch")
+    for epoch in pbar:
+        running_loss = 0.0
+        running_corrects = 0
+        for inputs, labels in train_loader:
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            optimizer.zero_grad()
+
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            _, preds = torch.max(outputs, 1)
+            loss.backward()
+            optimizer.step()
+
+            # accumulate with .item(): accumulating the loss tensor itself would
+            # keep the autograd graph of every batch alive and leak GPU memory
+            running_loss += loss.item() * inputs.size(0)
+            running_corrects += torch.sum(preds == labels.data).item()
+
+        epoch_loss = running_loss / len(train_loader.dataset)
+        epoch_acc = running_corrects / len(train_loader.dataset)
+        pbar.set_description(f"Training loss: {epoch_loss:.3f} Accuracy: {epoch_acc:.3f}")
+
+    # TODO: implement `evaluate` function to measure accuracy on the test set
+
+    manager.finalize(model)
+
+    # Saving model
+    save_onnx(model, 
checkpoints_path / "pruned_quantized_resnet.onnx", convert_qat=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/week09_compression/homework/recipe.yaml b/week09_compression/homework/recipe.yaml new file mode 100644 index 0000000..130c0ef --- /dev/null +++ b/week09_compression/homework/recipe.yaml @@ -0,0 +1,31 @@ +modifiers: + - !GlobalMagnitudePruningModifier + init_sparsity: 0.05 + final_sparsity: 0.8 + start_epoch: 0.0 + end_epoch: 30.0 + update_frequency: 1.0 + params: __ALL_PRUNABLE__ + + - !SetLearningRateModifier + start_epoch: 0.0 + learning_rate: 0.05 + + - !LearningRateFunctionModifier + start_epoch: 30.0 + end_epoch: 50.0 + lr_func: cosine + init_lr: 0.05 + final_lr: 0.001 + + - !QuantizationModifier + start_epoch: 50.0 + freeze_bn_stats_epoch: 53.0 + + - !SetLearningRateModifier + start_epoch: 50.0 + learning_rate: 10e-6 + + - !EpochRangeModifier + start_epoch: 0.0 + end_epoch: 55.0 \ No newline at end of file diff --git a/week09_compression/lecture.pdf b/week09_compression/lecture.pdf new file mode 100644 index 0000000..dfbbe8c Binary files /dev/null and b/week09_compression/lecture.pdf differ diff --git a/week09_compression/memory_bandwidth_utilization.jpg b/week09_compression/memory_bandwidth_utilization.jpg new file mode 100644 index 0000000..cdab8f4 Binary files /dev/null and b/week09_compression/memory_bandwidth_utilization.jpg differ diff --git a/week09_compression/practice.ipynb b/week09_compression/practice.ipynb new file mode 100644 index 0000000..f4d695b --- /dev/null +++ b/week09_compression/practice.ipynb @@ -0,0 +1,1026 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2341fbbf", + "metadata": {}, + "source": [ + "# Week 9: Efficient model inference\n", + "\n", + "As we now know from the lecture, there are many ways to make inference more efficient:\n", + "- Distillation\n", + "- Quantization\n", + "- Changing architecture (e.g. encoder-decoder vs decoder)\n", + "- Speculative decoding\n", + "\n", + "In the seminar we will talk about different kinds of **post-training quantization**.\n", + "\n", + "For more info about quantization, a good starting point is [\"A Survey of Quantization Methods for Efficient Neural Network Inference\"](https://arxiv.org/abs/2103.13630), 2021.\n", + "\n", + "### Plan:\n", + "\n", + "1. Some notes about Memory Bandwidth Utilization\n", + "2. Data-free quantization with T5\n", + "3. Weight-only Quantization with calibration (GPTq)\n", + "4. Weight & Activation Quantization (SmoothQuant)" + ] + }, + { + "cell_type": "markdown", + "id": "c569d5d7-8373-474e-ba7c-579b836758ec", + "metadata": {}, + "source": [ + "## 1: Memory Bandwidth Utilization (MBU)" + ] + }, + { + "cell_type": "markdown", + "id": "d4770d8e-a6bd-49c7-99d4-d5162fd090e5", + "metadata": {}, + "source": [ + "Let's read the following passage from [this post](https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices) by Databricks." + ] + }, + { + "cell_type": "markdown", + "id": "45c9a148-0204-4a7e-9e92-d4c998be3d84", + "metadata": {}, + "source": [ + "> So, how exactly should we think about inference speed?\n", + "Our _[Databricks]_ team uses four key metrics for LLM serving:\n", + "> 1. **Time To First Token (TTFT)**: How quickly users start seeing the model's output after entering their query. Low waiting times for a response are essential in real-time interactions, but less important in offline workloads. 
This metric is driven by the time required to process the prompt and then generate the first output token.\n",
+    "> 2. **Time Per Output Token (TPOT)**: Time to generate an output token for each user that is querying our system. This metric corresponds with how each user will perceive the \"speed\" of the model. For example, a TPOT of 100 milliseconds/tok would be 10 tokens per second per user, or ~450 words per minute, which is faster than a typical person can read.\n",
+    "> 3. **Latency**: The overall time it takes for the model to generate the full response for a user. Overall response latency can be calculated using the previous two metrics: latency = (TTFT) + (TPOT) * (the number of tokens to be generated)\n",
+    "> 4. **Throughput**: The number of output tokens per second an inference server can generate across all users and requests."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7793f603-ab3b-4a5a-9930-f699ad1b3fcc",
+   "metadata": {},
+   "source": [
+    "> To measure the underlying hardware's utilization, we introduce a new metric called Model Bandwidth Utilization (MBU). \n",
+    "> MBU is defined as \n",
+    "\n",
+    "$$\\frac{\\text{achieved memory bandwidth}}{\\text{peak memory bandwidth}}$$\n",
+    "\n",
+    "> where \n",
+    "\n",
+    "$$\n",
+    "\\text{achieved memory bandwidth} = \\frac{\\text{total model parameter size + KV cache size}}{\\text{TPOT}}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "524efadd-f13f-42ab-babc-228180e270de",
+   "metadata": {},
+   "source": [
+    "![](memory_bandwidth_utilization.jpg)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c7d7506a-e99e-4f6f-8612-1e1331179de6",
+   "metadata": {},
+   "source": [
+    "### Example: how to estimate MBU\n",
+    "\n",
+    "- For example, if a 7B-parameter model running in 16-bit precision has a TPOT of 14 ms, then it moves 14 GB of parameters in 14 ms, translating to 1 TB/s of bandwidth usage.\n",
+    "- An A100 can handle up to ~2 TB/s.\n",
+    "- So, we are running at an MBU of 50%.\n",
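+    "\n",
+    "The same arithmetic as a quick code sketch (it ignores the KV cache term for brevity):\n",
+    "\n",
+    "```python\n",
+    "n_params = 7e9\n",
+    "bytes_per_param = 2  # 16-bit precision\n",
+    "tpot_s = 0.014       # 14 ms per output token\n",
+    "\n",
+    "achieved_bandwidth = n_params * bytes_per_param / tpot_s  # 1e12 B/s = 1 TB/s\n",
+    "mbu = achieved_bandwidth / 2e12                           # A100 peak: ~2 TB/s\n",
+    "print(f\"MBU = {mbu:.0%}\")                                # MBU = 50%\n",
+    "```"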
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f1c00c6a-40a4-4a52-b757-bd3f7ac61a77",
+   "metadata": {},
+   "source": [
+    "## 2: Data-free quantization with T5\n",
+    "\n",
+    "First let's try data-free quantization, using the 8-bit scheme proposed in [\"LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale\"](https://arxiv.org/abs/2208.07339).\n",
+    "\n",
+    "(Section is based on this [post](https://huggingface.co/blog/hf-bitsandbytes-integration).)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "42d883ed-b571-4ca7-86e8-c8f0c638880d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "# autogptq can be very slow if you don't restrict the number of CPU cores it is using\n",
+    "max_cpu_threads = \"16\"\n",
+    "os.environ[\"OMP_NUM_THREADS\"] = max_cpu_threads\n",
+    "os.environ[\"OPENBLAS_NUM_THREADS\"] = max_cpu_threads\n",
+    "os.environ[\"MKL_NUM_THREADS\"] = max_cpu_threads\n",
+    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = max_cpu_threads\n",
+    "os.environ[\"NUMEXPR_NUM_THREADS\"] = max_cpu_threads\n",
+    "os.environ[\"NUMEXPR_MAX_THREADS\"] = max_cpu_threads"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c6c26c41",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2d07f8ec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_name = \"t5-3b-sharded\"  # @param [\"t5-11b-sharded\", \"t5-3b-sharded\"]\n",
+    "\n",
+    "# T5-3b and T5-11B are supported!\n",
+    "# We need sharded weights, otherwise we get CPU OOM errors\n",
+    "model_id = f\"ybelkada/{model_name}\"\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "model_8bit = AutoModelForSeq2SeqLM.from_pretrained(\n",
+    "    model_id,\n",
+    "    quantization_config=BitsAndBytesConfig(\n",
+    "        load_in_8bit=True\n",
+    "    ),\n",
+    "    device_map=\"auto\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "87b15c0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_8bit.get_memory_footprint() / 1e9"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cd59d870",
+   "metadata": {},
+   "source": [
+    "For t5-3b, the int8 model takes about 5.3 GB, whereas the original model takes 11 GB.\n",
+    "\n",
+    "For t5-11b, it is about 11 GB vs 42 GB for the original model.\n",
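+    "\n",
+    "Where do the savings come from? Each weight is stored in a single byte plus a shared scale. A toy sketch of the underlying absmax scheme (LLM.int8() additionally keeps outlier feature dimensions in fp16, which we omit here):\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "w = torch.randn(8)\n",
+    "scale = w.abs().max() / 127\n",
+    "w_int8 = torch.round(w / scale).to(torch.int8)  # 1 byte per value\n",
+    "w_dequant = w_int8.float() * scale              # used at matmul time\n",
+    "print((w - w_dequant).abs().max())              # small rounding error\n",
+    "```\n",
+    "\n",
+    "Now let's generate and see the qualitative results of the 8-bit model!"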
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a192201c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "max_new_tokens = 50\n",
+    "\n",
+    "input_ids = tokenizer(\n",
+    "    \"translate English to German: Hello my name is Younes and I am a Machine Learning Engineer at Hugging Face\",\n",
+    "    return_tensors=\"pt\",\n",
+    ").input_ids.to(\"cuda:0\")\n",
+    "\n",
+    "outputs = model_8bit.generate(input_ids, max_new_tokens=max_new_tokens)\n",
+    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2b6e63e7-352b-4d26-9c6f-5fd4f3ae7408",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.cuda.max_memory_allocated() / 1e9"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "94ffe96b-c10a-4fcd-94a3-64678ee89f5d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "del model_8bit, tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc32dd19-d725-420f-9cda-b0f6907d2073",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# https://stackoverflow.com/questions/57858433/how-to-clear-gpu-memory-after-pytorch-model-training-without-restarting-kernel\n",
+    "import gc\n",
+    "torch.cuda.empty_cache()\n",
+    "gc.collect()\n",
+    "\n",
+    "torch.cuda.reset_peak_memory_stats()\n",
+    "torch.cuda.max_memory_allocated() / 1e9"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e3243af-0599-470a-9f5b-578926e3a5d8",
+   "metadata": {},
+   "source": [
+    "## 3: Weight-only quantization with a calibration dataset (GPTq)\n",
+    "\n",
+    "Data-free quantization usually does something like\n",
+    "$$\n",
+    "\\arg\\min \\|W - W_{\\text{quantized}}\\|_{F}\n",
+    "$$\n",
+    "It is simple and easy to use. However, this does not account for the fact that we apply our models to a specific distribution of data.\n",
+    "\n",
+    "Let $X$ be the activations coming from the previous layer. Then we can formulate the quantization objective as\n",
+    "$$\n",
+    "\\arg\\min \\|X \\cdot W - X \\cdot W_{\\text{quantized}}\\|_{F}\n",
+    "$$\n",
+    "The intuition is that we want to preserve _the way layer $W$ transforms the inputs_, not its literal weights.\n",
+    "This is one of the core ideas of the GPTq algorithm.\n",
+    "\n",
+    "(Based on the [AutoGPTq tutorial](https://github.com/AutoGPTQ/AutoGPTQ/tree/main))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1f3deac8-78ce-46a5-b9ac-57ddfa0ab479",
+   "metadata": {},
+   "source": [
+    "### Setting up"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5fe6a780-0f0c-4b6e-9d73-62e55c9eef0a",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+    "\n",
+    "model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+    "model = AutoModelForCausalLM.from_pretrained(\n",
+    "    model_name,\n",
+    "    torch_dtype=torch.float16,\n",
+    "    local_files_only=True,\n",
+    "    low_cpu_mem_usage=True,  # speeds up loading, if `accelerate` is installed\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "167bd7fa-d5ae-489e-a4b4-e9b4edede93e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4b9d03ec-a3c8-4191-b01e-10b810ac943f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def count_params(model):\n",
+    "    return sum(p.numel() for p in model.parameters())\n",
+    "\n",
+    "print(f\"{count_params(model) // 1e6:4.0f} M parameters\")\n",
+    "print(f\"{count_params(model.model.embed_tokens) // 1e6:4.0f} M parameters in embedding block\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1ca501cb-ca81-46f0-9c60-86dca670200e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = torch.device(\"cuda:0\")\n",
+    "model = model.to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e26d9020-61d5-44b0-bfd1-1ecce08b3682",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@torch.inference_mode()\n",
+    "def generate(model, tokenizer, prefix, max_length, device=\"cuda:0\") -> str:\n",
+    "    inputs = tokenizer(prefix, return_tensors=\"pt\").to(device)\n",
+    "    outputs = model.generate(\n",
+    "        **inputs,\n",
+    "        max_length=max_length,\n",
+    "        pad_token_id=tokenizer.eos_token_id,\n",
+    "        do_sample=True,\n",
+    "        repetition_penalty=1.1,\n",
+    "    )\n",
+    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14de5efa-23ab-4541-bf7b-96fac9fc253d",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "from tqdm.auto import tqdm\n",
+    "\n",
+    "prompts = [f\"You will never believe this wild conspiracy theory about {topic}:\"\n",
+    "           for topic in (\"bananas\", \"grizzly bears\", \"gummy bears\", \"Python language\", \"Yann LeCun\")]\n",
+    "\n",
+    "max_length = 384\n",
+    "\n",
+    "start = time.perf_counter()\n",
+    "answers = [generate(model, tokenizer, prompt, max_length) for prompt in tqdm(prompts)]\n",
+    "generation_time = time.perf_counter() - start"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "98ecfc9f-190c-4850-ac28-217b5db202de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(answers[4])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9884670f-538d-40d8-8219-76c113ab8e30",
+   "metadata": {},
"source": [ + "Let's calculate MBU for this model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66617acf-03ac-4297-88cc-a108946b2872", + "metadata": {}, + "outputs": [], + "source": [ + "n_generated_tokens_total = sum([len(answer) - len(prompt)\n", + " for answer, prompt in zip(tokenizer(answers).input_ids, tokenizer(prompts).input_ids)])\n", + "n_generated_tokens_total" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64b8d16e-e669-4348-bded-5e3ad2a5366e", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Generation speed: {n_generated_tokens_total / generation_time:.1f} tokens/sec\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02277a09-a8ff-48a1-9efd-cda94808ae3c", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_model_size_mb(model):\n", + " model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6\n", + " model_size_mb += sum(b.numel() * b.element_size() for b in model.buffers()) / 1e6\n", + " return model_size_mb\n", + "\n", + "def compute_memory_bandwidth_utilization(model_and_kv_cache_size_mb, max_bandwidth_mb, time_per_output_token):\n", + " return (model_and_kv_cache_size_mb / time_per_output_token) / max_bandwidth_mb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2149797d-7f79-4e0e-a4d2-805d5a607a0a", + "metadata": {}, + "outputs": [], + "source": [ + "model_size_mb = compute_model_size_mb(model)\n", + "\n", + "# 2 * batch_size * sequence_length * n_layers * (n_heads * d_head) * precision\n", + "kv_cache_size_mb = 2 * 1 * max_length * model.config.num_hidden_layers * model.config.hidden_size * 2 / 1e6\n", + "\n", + "a100_max_bandwidth_mb = 2e6\n", + "\n", + "mbu = compute_memory_bandwidth_utilization(\n", + " model_size_mb + kv_cache_size_mb,\n", + " a100_max_bandwidth_mb, \n", + " generation_time / n_generated_tokens_total\n", + ")\n", + "\n", + "print(f\"Memory Bandwidth Utilization is {mbu * 100:.2f} %\") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17351c00-b791-4839-b798-566e8bd99f45", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Model size: {model_size_mb:.0f} Mb\")\n", + "print(f\"KV cache size: {kv_cache_size_mb:.0f} Mb\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e12ba245-769c-4f47-8935-f42e3fcde5d5", + "metadata": {}, + "outputs": [], + "source": [ + "del model" + ] + }, + { + "cell_type": "markdown", + "id": "11542d1d-f4b3-418f-a147-e4b67d40df89", + "metadata": {}, + "source": [ + "### Run AutoGPTq" + ] + }, + { + "cell_type": "markdown", + "id": "ce7c02d5-9347-4e15-85c9-5946a515876c", + "metadata": {}, + "source": [ + "Let's prepare a calibration dataset." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "50dc659d-c98d-4d31-bf6b-ea5871f779f7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "n_samples = 128\n",
+    "dataset = load_dataset(\"wikitext\", \"wikitext-2-v1\", split=\"test\")\n",
+    "\n",
+    "calibration_set = dataset.filter(lambda example: len(example[\"text\"]) > 100)\n",
+    "calibration_set = calibration_set.shuffle(seed=59)[:n_samples][\"text\"]\n",
+    "\n",
+    "len(calibration_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "df343800-bef9-4813-81e2-d3ccbb0a682f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "calibration_set[:2]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "30ebb973-92a0-455f-951d-19ee63eaecf1",
+   "metadata": {},
+   "source": [
+    "Now we can run GPTq."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8d000862-7f33-4ea5-9e2f-b98b3a36192e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoTokenizer, TextGenerationPipeline\n",
+    "from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig\n",
+    "import logging"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ac6db93c-c13b-42c7-b07b-181b6ce03a88",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logging.basicConfig(\n",
+    "    format=\"%(asctime)s %(levelname)s [%(name)s] %(message)s\", level=logging.INFO, datefmt=\"%Y-%m-%d %H:%M:%S\"\n",
+    ")\n",
+    "\n",
+    "quantized_model_dir = model_name + \"_4bit\"\n",
+    "\n",
+    "quantize_config = BaseQuantizeConfig(\n",
+    "    bits=4,  # quantize model to 4-bit\n",
+    "    group_size=128,  # it is recommended to set the value to 128\n",
+    "    desc_act=False,  # False can significantly speed up inference, but perplexity may be slightly worse\n",
+    ")\n",
+    "\n",
+    "examples = [tokenizer(sample, return_tensors=\"pt\").to(device) for sample in calibration_set]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bfe94a01-2277-4e8c-af2c-915ab1c08093",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# load the unquantized model; by default, the model is always loaded into CPU memory\n",
+    "model = AutoGPTQForCausalLM.from_pretrained(\n",
+    "    model_name,\n",
+    "    quantize_config,\n",
+    "    local_files_only=True,\n",
+    "    low_cpu_mem_usage=True,\n",
+    ")\n",
+    "model.to(device)\n",
+    "\n",
+    "# quantize the model; examples should be a list of dicts whose only keys are \"input_ids\" and \"attention_mask\"\n",
+    "model.quantize(examples)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bfdbef3c-9c88-4c9d-b257-a46113f4a35b",
+   "metadata": {},
+   "source": [
+    "### Save quantized model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dbb99f45-e6d7-417c-a0f0-afa299653085",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save the quantized model using safetensors\n",
+    "model.save_quantized(quantized_model_dir)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d32c4214-f315-4faf-803c-d0fa51d95f96",
+   "metadata": {},
+   "source": [
+    "### Check how the quantized model generates"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e896be99-b7ad-4b0d-9e88-6814f55cf17c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load the quantized model to the first GPU\n",
+    "model = AutoGPTQForCausalLM.from_quantized(\n",
+    "    quantized_model_dir,\n",
+    "    low_cpu_mem_usage=True,\n",
+    "    device=device,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1cb69e00-3633-462a-9d4c-79b35fd20c49",
"1cb69e00-3633-462a-9d4c-79b35fd20c49", + "metadata": {}, + "source": [ + "What size we should expect before and after quantization?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b1a4b46-e516-476a-9904-365476a91a1b", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Before quantization: {model_size_mb:.0f} Mb\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "000b0cbd-2f1e-4a24-9156-4735a4a01260", + "metadata": {}, + "outputs": [], + "source": [ + "model_size_mb = compute_model_size_mb(model)\n", + "print(f\"After quantization: {model_size_mb:.0f} Mb\")" + ] + }, + { + "cell_type": "markdown", + "id": "646df63b-6bcb-42f5-87e3-768fd7d893ca", + "metadata": {}, + "source": [ + "Quantized model has more than x3 smaller memory footprint. You can almost run it on a toaster now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c28f3b0-d6cb-45be-8d72-4ea840b1b124", + "metadata": {}, + "outputs": [], + "source": [ + "start = time.perf_counter()\n", + "answers = [generate(model, tokenizer, prompt, max_length) for prompt in tqdm(prompts)]\n", + "generation_time = time.perf_counter() - start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01e84f2f-8649-42b9-856f-c97ca72559f7", + "metadata": {}, + "outputs": [], + "source": [ + "print(answers[4])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "112dadd4-a894-4b93-9557-305dc5ff0a8c", + "metadata": {}, + "outputs": [], + "source": [ + "n_generated_tokens_total = sum([len(answer) - len(prompt)\n", + " for answer, prompt in zip(tokenizer(answers).input_ids, tokenizer(prompts).input_ids)])\n", + "n_generated_tokens_total" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5389206-cad5-4180-b4f5-822b38ac1fcb", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Generation speed: {n_generated_tokens_total / generation_time:.1f} tokens/sec\")" + ] + }, + { + "cell_type": "markdown", + "id": "1303dfe9", + "metadata": {}, + "source": [ + "Having compressed the model, we might have hoped for speedup. However, memory transfers are not the only bottleneck, and there might be some inefficiencies in implementation, which slow us down.\n", + "\n", + "GPTq still can noticeably drive the memory footprint down, and this is often vital when you work on a small GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90ed9a0e-8357-4019-8126-e9f7dee6640d", + "metadata": {}, + "outputs": [], + "source": [ + "mbu = compute_memory_bandwidth_utilization(\n", + " model_size_mb + kv_cache_size_mb,\n", + " a100_max_bandwidth_mb,\n", + " generation_time / n_generated_tokens_total\n", + ")\n", + "\n", + "print(f\"Memory Bandwidth Utilization is {mbu * 100:.2f} %\") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcafd052-1e88-48c6-9f68-34ba5cf5146c", + "metadata": {}, + "outputs": [], + "source": [ + "del model, examples, tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "a2c32ea6-e2d3-4791-a307-7333c1ef12e4", + "metadata": {}, + "source": [ + "## 4: Weight & Activation Quantization (SmoothQuant)\n", + "\n", + "Weight-only quantization helps to improve Memory Bandwidth Utilization. 
+    "\n",
+    "To make models faster when you have large batch sizes or don't have to autoregressively generate responses, you can use weight and activation quantization.\n",
+    "\n",
+    "By converting weights and activations e.g. from fp16 to int8, we can utilize efficient `GEMM` and `BMM` kernels and theoretically double the throughput.\n",
+    "\n",
+    "This part is a copy of this [example](https://github.com/mit-han-lab/smoothquant/blob/main/examples/smoothquant_llama_demo.ipynb)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ed673b79-3ef3-412e-8e35-9d417dde49cc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b9e81c07-8434-40bb-a811-1b086803222e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from transformers.models.llama.modeling_llama import (\n",
+    "    LlamaAttention,\n",
+    "    LlamaDecoderLayer,\n",
+    "    LlamaForCausalLM,\n",
+    "    LlamaMLP,\n",
+    ")\n",
+    "from transformers import LlamaTokenizer\n",
+    "import smoothquant\n",
+    "from smoothquant.smooth import smooth_lm\n",
+    "from smoothquant.fake_quant import quantize_llama_like\n",
+    "import tqdm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b15fed9a-01c5-4e8d-aaa4-eb13d7954a4a",
+   "metadata": {},
+   "source": [
+    "> The following is an evaluator to see the performance of the model. We use a toy dataset (the first 40 examples in the test set of the Wikitext-2 dataset) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same.\n",
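+    "\n",
+    "In other words, the evaluator below computes perplexity over 40 non-overlapping 2048-token windows of the test set: $\\mathrm{PPL} = \\exp\\left(\\frac{1}{N}\\sum_{i} \\mathrm{nll}_i\\right)$, where $N$ is the total number of scored tokens. Lower is better."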
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "486f004a-d58f-4156-a480-fc0224ce48be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Evaluator:\n",
+    "    def __init__(self, dataset, tokenizer, device, n_samples=40):\n",
+    "        self.dataset = dataset\n",
+    "        self.tokenizer = tokenizer\n",
+    "        self.device = device\n",
+    "\n",
+    "        # concatenate the whole test set and tokenize it once\n",
+    "        self.dataset = tokenizer(\n",
+    "            \"\\n\\n\".join(dataset[\"text\"]), return_tensors=\"pt\"\n",
+    "        ).input_ids.to(device)\n",
+    "\n",
+    "        self.n_samples = n_samples\n",
+    "\n",
+    "    @torch.no_grad()\n",
+    "    def evaluate(self, model):\n",
+    "        model.eval()\n",
+    "        nlls = []\n",
+    "        for i in tqdm.tqdm(range(self.n_samples), desc=\"Evaluating...\"):\n",
+    "            batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)\n",
+    "            with torch.no_grad():\n",
+    "                lm_logits = model(batch).logits\n",
+    "            shift_logits = lm_logits[:, :-1, :].contiguous().float()\n",
+    "            shift_labels = self.dataset[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]\n",
+    "            loss_fct = nn.CrossEntropyLoss()\n",
+    "            loss = loss_fct(\n",
+    "                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n",
+    "            )\n",
+    "            neg_log_likelihood = loss.float() * 2048\n",
+    "            nlls.append(neg_log_likelihood)\n",
+    "\n",
+    "        # perplexity = exp(total NLL / total number of tokens)\n",
+    "        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * 2048))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0dbc8b9c-18ff-452e-a183-58ee02ad27a4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "device = \"cuda:0\"\n",
+    "\n",
+    "tokenizer = LlamaTokenizer.from_pretrained(model_name)\n",
+    "dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')\n",
+    "evaluator = Evaluator(dataset, tokenizer, device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "45150888-e55d-4805-b475-31edd0dac0b5",
+   "metadata": {},
+   "source": [
+    "**FP16 Model Perplexity**\n",
+    "\n",
+    "> Let's first check the performance of the original FP16 model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0eeae86a-924f-4d93-82c5-ee1edf25c1dd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_fp16 = LlamaForCausalLM.from_pretrained(\n",
+    "    model_name, torch_dtype=torch.float16, device_map=\"auto\", local_files_only=True, low_cpu_mem_usage=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "68cc2761-c36d-455f-bd6c-c24a28e690c6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ppl_fp16 = evaluator.evaluate(model_fp16)\n",
+    "print(f\"Original model (fp16) perplexity: {ppl_fp16}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "781760d8-ce52-4a40-8d90-9ee90be3c92a",
+   "metadata": {},
+   "source": [
+    "> We then quantize the model to W8A8 and check the performance.\n",
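+    "\n",
+    "Here \"W8A8\" means that both weights and activations are quantized to 8 bits. For reference, the core SmoothQuant idea (applied a bit further below) is to migrate quantization difficulty from activations to weights with a per-channel rescaling $Y = (X \\operatorname{diag}(s)^{-1}) \\cdot (\\operatorname{diag}(s) W)$, where $s_j = \\max(|X_j|)^{\\alpha} / \\max(|W_j|)^{1-\\alpha}$; we use the smoothing strength $\\alpha = 0.85$ below."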
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "473e2d45-440d-4d1a-8c20-2d41e1f43005",
+   "metadata": {},
+   "source": [
+    "**Naive W8A8 Quantized Model Perplexity**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "299cc074-4272-4a42-9325-c53e79d5778a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "model_w8a8 = quantize_llama_like(model_fp16)\n",
+    "print(model_w8a8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d0d7097a-8b59-468e-992b-a7a17319e3d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ppl_w8a8 = evaluator.evaluate(model_w8a8)\n",
+    "print(f\"Naive W8A8 quantized model perplexity: {ppl_w8a8}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "65e771b0-8422-4f62-b970-db526103ae52",
+   "metadata": {},
+   "source": [
+    "> We can see there is a perplexity increase. We then use SmoothQuant to quantize the model and check the performance."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "95a244da-e0a4-4000-9f15-88f5d5f20d76",
+   "metadata": {},
+   "source": [
+    "**SmoothQuant W8A8 Quantized Model Perplexity**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "00e31a64-825a-4e03-a022-b0e3affb97d4",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# We have to load the corresponding activation scales:\n",
+    "#!wget https://huggingface.co/mit-han-lab/smoothquant-scales/resolve/main/llama-2-7b.pt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2a31717f-0383-48aa-b368-fc3df010eef5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = LlamaForCausalLM.from_pretrained(\n",
+    "    model_name, torch_dtype=torch.float16, device_map=\"auto\"\n",
+    ")\n",
+    "act_scales = torch.load(\"llama-2-7b.pt\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3681c508-9144-40e2-8e4a-928cc9479fa3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "smooth_lm(model, act_scales, 0.85)\n",
+    "model_smoothquant_w8a8 = quantize_llama_like(model)\n",
+    "print(model_smoothquant_w8a8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0b7b30c2-cd34-43ae-8e03-bd706d55146d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ppl_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)\n",
+    "print(f\"SmoothQuant W8A8 quantized model perplexity: {ppl_smoothquant_w8a8}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b4730f11-c8fb-4649-9201-ed6b5e721d1e",
+   "metadata": {},
+   "source": [
+    "> We can see the smoothed model has a lower perplexity which is close to the FP16 model's. This is because SmoothQuant smooths the outliers in activations and balances the quantization difficulty of activations and weights."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6a38ccb4-f051-437b-9bc6-cbb9b053ce20",
+   "metadata": {},
+   "source": [
+    "## Summary\n",
+    "\n",
+    "- Data-free quantization methods are very fast, and you can often grid-search the optimal quantization hyperparameters on your laptop.\n",
+    "- Weight-only quantization methods mainly address memory bottlenecks (which mostly occur at low batch sizes).\n",
+    "- Weight & activation quantization methods can deal with both memory and computation bottlenecks, achieving speedups e.g. due to efficient int8 matrix multiplication kernels, but they might have slightly inferior quality compared to weight-only methods.\n",
+    "- That said, the points above are quite general: there is no silver bullet, and the only way to know whether a quantization method fits your application is to actually try it."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}