

***Project_Title :- Quantization-Aware Training for LLM Pruning***

In [3]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

The QuantizationLayer class implements a dynamic quantization mechanism that adjusts the precision of neurons based on their energy levels. Neurons with energy above a specified threshold are assigned a higher precision (high_precision), while those below the threshold are assigned a lower precision (base_precision). During the forward pass, each neuron's precision is determined dynamically, and its activations are quantized using a scaling factor derived from the assigned precision. This selective quantization ensures that critical neurons retain finer-grained representations, while less active neurons are compressed, optimizing both accuracy and resource utilization.

This aligns with the broader concepts of Adaptive Precision Heterogeneity (APH) and Thermal-Analog Quantization (TAQ). APH dynamically assigns precision across layers and operations based on their performance sensitivity, with critical components like attention layers receiving higher precision and less sensitive layers using lower precision. TAQ, inspired by thermal states, extends this idea to neurons, assigning precision based on their "energy" states. High-energy neurons critical for specific tasks retain higher precision, while low-energy neurons are quantized more coarsely, mimicking energy efficiency principles from physics. Together, these concepts enhance model efficiency while preserving critical task performance

In [4]:
class QuantizationLayer(nn.Module):
    def __init__(self, base_precision, high_precision, threshold):
        super(QuantizationLayer, self).__init__()
        self.base_precision = base_precision
        self.high_precision = high_precision
        self.threshold = threshold

    def quantize(self, x, precision):
        scale_factor = 2 ** (precision - 1)
        return torch.round(x * scale_factor) / scale_factor

    def forward(self, x, neuron_energy):
        quantized_output = torch.zeros_like(x)
        for i in range(x.size(1)):
            precision = self.high_precision if neuron_energy[i] > self.threshold else self.base_precision
            quantized_output[:, i] = self.quantize(x[:, i], precision)
        return quantized_output

Here, the logic used by us in the PruningLayer is based on an energy threshold mechanism, where neurons with low energy values are deactivated to optimize model efficiency. Neuron energy, typically derived from gradient-based metrics such as the mean absolute gradient, is compared against a predefined pruning threshold. Neurons with energy below this threshold are pruned by setting their corresponding values in the input tensor to zero, effectively deactivating them while preserving the tensor's structure. This dynamic pruning process selectively retains only the most significant neurons, reducing computational overhead and potentially enhancing generalization by removing redundant or less impactful neurons during runtime.

In [5]:
class PruningLayer(nn.Module):
    def __init__(self, pruning_threshold):
        super(PruningLayer, self).__init__()
        self.pruning_threshold = pruning_threshold

    def forward(self, x, neuron_energy):
        """
        Prunes neurons based on energy: if the neuron energy is below the threshold, set the neuron to zero.
        """
        # Prune neurons (set to zero) based on neuron energy
        pruned_output = x.clone()
        for i in range(x.size(1)):
            if neuron_energy[i] < self.pruning_threshold:
                pruned_output[:, i] = 0
        return pruned_output


HEGQ combines layer-level adaptability with neuron-specific precision scaling. First, Adaptive Precision Heterogeneity (APH) is applied to set a base precision level for each layer according to its role and sensitivity. Then, within each layer, Thermal-Analog Quantization (TAQ) adjusts the precision further based on the "energy" or activation level of individual neurons during specific tasks. This approach creates a multi-tiered precision framework that maximizes efficiency by allocating computational resources dynamically, optimizing for both critical and non-critical layers and neurons.
HEGQ leverages both coarse-grained (layer) and fine-grained (neuron) quantization, ensuring that high-sensitivity layers and neurons retain high precision, while less critical areas use lower precision, conserving resources without compromising performance

The HEGQModel employs a structured pruning methodology focused on energy-based pruning to enhance computational efficiency while preserving model performance. The model evaluates neuron importance dynamically using gradient-based energy calculations, where neuron energy is determined as the mean absolute value of gradients across samples. Neurons with energy values below a configurable pruning threshold are pruned by setting their activations to zero, effectively removing them from subsequent computations. .

In [6]:

class HEGQModel(nn.Module):
    def __init__(self, model_name, num_classes, base_precision=4, high_precision=8, pruning_threshold=0.1):
        super(HEGQModel, self).__init__()
        self.base_model = AutoModel.from_pretrained(model_name)
        self.num_classes = num_classes
        
        # Number of layers in the transformer model
        num_layers = self.base_model.config.num_hidden_layers

        # Dynamic precision levels and thresholds for quantization
        layer_precisions = [min(base_precision + i, high_precision) for i in range(num_layers)]
        neuron_energy_thresholds = [0.2 + (0.6 * i / (num_layers - 1)) for i in range(num_layers)]

        # Quantization and Pruning layers
        self.quant_layers = nn.ModuleList([
            QuantizationLayer(
                base_precision=base_precision,
                high_precision=layer_precisions[i],
                threshold=neuron_energy_thresholds[i]
            )
            for i in range(num_layers)
        ])
        
        self.pruning_layer = PruningLayer(pruning_threshold)

        # Classification layer
        self.classifier = nn.Linear(self.base_model.config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        # Forward pass through the base model
        outputs = self.base_model(**x)
        last_hidden_state = outputs.last_hidden_state

        # Calculate classification logits
        logits = self.classifier(last_hidden_state[:, 0, :])

        # Compute gradients through the classification loss if labels are provided
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            # Calculate gradients manually
            grads = torch.autograd.grad(loss, last_hidden_state, retain_graph=True)[0]

            # Quantize based on neuron energy and prune neurons
            for i, quant_layer in enumerate(self.quant_layers):
                quantized_output = torch.zeros_like(last_hidden_state)
                for batch_idx in range(last_hidden_state.size(0)):
                    neuron_energy = grads[batch_idx].abs().mean(dim=0)
                    # Apply quantization
                    quantized_output[batch_idx] = quant_layer(last_hidden_state[batch_idx], neuron_energy)
                last_hidden_state = quantized_output

            # Apply pruning
            pruned_output = self.pruning_layer(last_hidden_state, neuron_energy)
            last_hidden_state = pruned_output

            return loss, logits

        return logits

This code demonstrates a classification setup using the HEGQModel applied to the Falcon-7B model. It begins by initializing the tokenizer and setting up the HEGQModel with 10 output classes. If the tokenizer lacks a padding token, the end-of-sequence token is assigned as the padding token to ensure compatibility. The input text is tokenized into a format suitable for model input, with padding and truncation enabled to maintain consistency across different input lengths. A sample label (1) is provided to simulate a classification task. The tokenized inputs and label are passed through the HEGQModel in a forward pass, where the model computes the classification loss and outputs logits for the input text. Finally, the computed loss and logits are printed to evaluate the model's performance for the given input. This workflow highlights the model's ability to handle text classification tasks dynamically using quantization and pruning strategies

In [7]:
model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HEGQModel(model_name, num_classes=10)

# Assign a padding token if it's not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prepare tokenized input
input_text = "Sample text for classification."
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

# Test with tokenized input and label
labels = torch.tensor([1])  # Example label for classification

# Run forward pass
loss, logits = model(inputs, labels=labels)
print("Loss:", loss.item())
print("Logits:", logits)

tokenizer_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.73M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/281 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/17.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.48G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loss: 2.196079730987549
Logits: tensor([[ 1.0465,  1.5487,  1.4816,  3.1135, -0.9330, -1.4105,  1.9201, -2.2978,
         -1.4990, -2.7751]], grad_fn=<AddmmBackward0>)
