In [1]:
import time
import math
import random
import torch
import torch.nn as nn
import transformers
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

# ====================================================
# Constants and Configuration
# ====================================================
# Everything below is a module-level knob read by the functions further down;
# edit the values here rather than at the call sites.

# Model and dataset configuration
MODEL_NAME = "facebook/opt-1.3b"  # Model id handed to from_pretrained for both model and tokenizer
DATASET_NAME = "wikitext2"  # Calibration dataset ("wikitext2" or "ptb-new")

# Quantization parameters
SEED = 0  # Random seed for reproducibility
NUM_SAMPLES = 128  # Number of calibration samples
PERCENT_DAMPENING = 0.01  # Fraction of the mean Hessian diagonal added to the diagonal before inversion
BITS = 4  # Number of bits for quantization (the script skips quantization when this is >= 16)
GROUP_SIZE = -1  # Group size for quantization (-1 means no grouping)
USE_SYMMETRIC = True  # Use symmetric quantization
USE_ACT_ORDER = False  # Process columns in order of decreasing Hessian diagonal
USE_STATIC_GROUPS = False  # Compute all group parameters up front instead of on the fly

# Device configuration: first CUDA device when available, otherwise CPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Debugging flag
DEBUG = False  # Set to True for debugging output

# Disable TensorFloat-32 in CUDA matmul and cuDNN kernels so that float32 math
# is not run at reduced precision (this is a precision setting, not a
# determinism guarantee).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# ====================================================
# Quantization Functions and Classes
# ====================================================

def quantize_tensor(tensor, scale, zero_point, max_quant):
    """
    Round `tensor` onto a quantization grid and map it back to real values.

    Args:
        tensor: Values to quantize.
        scale: Step size of the grid (broadcast against `tensor`).
        zero_point: Integer offset of the grid (broadcast against `tensor`).
        max_quant: Largest integer level; a negative value selects the
            ternary mode, where `scale` and `zero_point` are used as the
            positive and negative output values respectively.

    Returns:
        The fake-quantized tensor (quantized, then de-quantized).
    """
    if max_quant < 0:
        # Ternary mode: values above scale/2 snap to `scale`, values below
        # zero_point/2 snap to `zero_point`, everything else becomes 0.
        positive_part = (tensor > scale / 2).float() * scale
        negative_part = (tensor < zero_point / 2).float() * zero_point
        return positive_part + negative_part

    # Uniform grid: shift by the zero point, clip to the valid integer range,
    # then undo the shift and rescale.
    levels = torch.round(tensor / scale) + zero_point
    levels = torch.clamp(levels, 0, max_quant)
    return scale * (levels - zero_point)

class Quantizer(nn.Module):
    """
    Holds the grid parameters (scale, zero point, number of levels) for
    uniform or ternary quantization and applies them to tensors.
    """

    def __init__(self, shape=1):
        super().__init__()
        # Buffers start at zero; `ready()` reports False until find_params
        # has produced a non-zero scale.
        self.register_buffer('max_quant', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero_point', torch.zeros(shape))

    def configure(self, bits, per_channel=False, symmetric=True, use_mse=False,
                  error_norm=2.4, grid_size=100, max_shrink=0.8, use_ternary=False):
        """
        Store the quantization settings.

        Args:
            bits: Bit width; the grid has 2**bits levels.
            per_channel: Compute one (scale, zero point) pair per row/channel.
            symmetric: Use a range symmetric around zero.
            use_mse: Search shrunken ranges for the lowest reconstruction error.
            error_norm: Exponent applied to the absolute error in that search.
            grid_size: Number of steps the search divides the range into.
            max_shrink: Fraction of `grid_size` steps actually tried.
            use_ternary: Use the ternary mode (marked by max_quant == -1).
        """
        top_level = 2 ** bits - 1
        self.max_quant = torch.tensor(-1 if use_ternary else top_level)
        self.per_channel = per_channel
        self.symmetric = symmetric
        self.use_mse = use_mse
        self.error_norm = error_norm
        self.grid_size = grid_size
        self.max_shrink = max_shrink

    def _as_rows(self, tensor, is_weight):
        """Return a 2-D view whose rows each share one set of parameters."""
        if not self.per_channel:
            # One global row covering every element.
            return tensor.flatten().unsqueeze(0)
        if is_weight:
            return tensor.flatten(1)
        rank = len(tensor.shape)
        if rank == 4:
            return tensor.permute(1, 0, 2, 3).flatten(1)
        if rank == 3:
            return tensor.reshape(-1, tensor.shape[-1]).t()
        if rank == 2:
            return tensor.t()
        return tensor

    def find_params(self, tensor, is_weight=False):
        """
        Derive `scale` and `zero_point` from the value range of `tensor` and
        store them reshaped so they broadcast against tensors of that layout.
        """
        device = tensor.device
        self.max_quant = self.max_quant.to(device)

        shape = tensor.shape
        rows = self._as_rows(tensor, is_weight)

        # The range always includes zero.
        origin = torch.zeros(rows.shape[0], device=device)
        lo = torch.minimum(rows.min(dim=1)[0], origin)
        hi = torch.maximum(rows.max(dim=1)[0], origin)

        if self.symmetric:
            hi = torch.maximum(torch.abs(lo), hi)
            has_negative = lo < 0
            if torch.any(has_negative):
                lo = torch.where(has_negative, -hi, lo)

        # All-zero rows get a dummy [-1, 1] range so the scale is non-zero.
        all_zero = (lo == 0) & (hi == 0)
        lo[all_zero] = -1
        hi[all_zero] = 1

        if self.max_quant < 0:
            # Ternary mode stores the range ends directly.
            self.scale = hi
            self.zero_point = lo
        else:
            self.scale = (hi - lo) / self.max_quant
            if self.symmetric:
                self.zero_point = torch.full_like(self.scale, (self.max_quant + 1) / 2)
            else:
                self.zero_point = torch.round(-lo / self.scale)

        if self.use_mse:
            # Try progressively shrunken ranges and keep, per row, the one
            # with the smallest error.
            lowest_error = torch.full([rows.shape[0]], float('inf'), device=device)
            num_steps = int(self.max_shrink * self.grid_size)
            for step in range(num_steps):
                shrink = 1 - step / self.grid_size
                lo_try = shrink * lo
                hi_try = shrink * hi
                scale_try = (hi_try - lo_try) / self.max_quant
                if self.symmetric:
                    zero_try = self.zero_point
                else:
                    zero_try = torch.round(-lo_try / scale_try)
                candidate = quantize_tensor(rows, scale_try.unsqueeze(1), zero_try.unsqueeze(1), self.max_quant)
                deviation = torch.abs(candidate - rows)
                error = torch.pow(deviation, self.error_norm).sum(dim=1)
                improved = error < lowest_error
                if torch.any(improved):
                    lowest_error[improved] = error[improved]
                    self.scale[improved] = scale_try[improved]
                    self.zero_point[improved] = zero_try[improved]

        if not self.per_channel:
            # Replicate the single global pair along the channel dimension.
            if is_weight:
                copies = shape[0]
            elif len(shape) != 3:
                copies = shape[1]
            else:
                copies = shape[2]
            self.scale = self.scale.repeat(copies)
            self.zero_point = self.zero_point.repeat(copies)

        rank = len(shape)
        if is_weight:
            # One entry per output row, broadcast over the remaining dims.
            weight_view = [-1] + [1] * (rank - 1)
            self.scale = self.scale.reshape(weight_view)
            self.zero_point = self.zero_point.reshape(weight_view)
        elif rank == 4:
            self.scale = self.scale.reshape(1, -1, 1, 1)
            self.zero_point = self.zero_point.reshape(1, -1, 1, 1)
        elif rank == 3:
            self.scale = self.scale.reshape(1, 1, -1)
            self.zero_point = self.zero_point.reshape(1, 1, -1)
        elif rank == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero_point = self.zero_point.unsqueeze(0)

    def quantize(self, tensor):
        """
        Fake-quantize `tensor` with the stored parameters; return it
        unchanged while the quantizer is not ready.
        """
        if not self.ready():
            return tensor
        return quantize_tensor(tensor, self.scale, self.zero_point, self.max_quant)

    def enabled(self):
        """
        Report whether a positive number of levels is configured (max_quant > 0).
        """
        return self.max_quant > 0

    def ready(self):
        """
        Report whether every stored scale entry is non-zero.
        """
        return torch.all(self.scale != 0)

# ====================================================
# GPTQ Quantization Class
# ====================================================

class GPTQuantizer:
    """
    Quantize the weight matrix of a single layer with the GPTQ procedure.

    Usage: construct with the layer, feed calibration activations through
    `add_batch` (typically from a forward hook), call `perform_quantization`
    to overwrite the layer's weights in place, then `free`.
    """
    def __init__(self, layer):
        """
        Args:
            layer: An nn.Linear, nn.Conv2d or transformers.Conv1D whose
                `weight` will be quantized.
        """
        self.layer = layer
        self.device = self.layer.weight.device
        # Only the 2-D shape is needed here; flatten/t return views, so no
        # copy of the weights is made.
        weight_data = layer.weight.data
        if isinstance(self.layer, nn.Conv2d):
            weight_data = weight_data.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            weight_data = weight_data.t()
        self.num_rows, self.num_columns = weight_data.shape
        # NOTE: despite the attribute name, this accumulates the Hessian
        # approximation itself (a running average of 2 * X @ X^T); it is only
        # inverted inside perform_quantization.
        self.hessian_inv = torch.zeros((self.num_columns, self.num_columns), device=self.device)
        self.num_samples = 0  # Number of samples collected
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        """
        Fold one batch of layer inputs into the running Hessian approximation.

        Args:
            inp: Input activations seen by the layer.
            out: Matching layer outputs; only stored when DEBUG is set.
        """
        if DEBUG:
            self.debug_input = inp
            self.debug_output = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        if isinstance(self.layer, (nn.Linear, transformers.Conv1D)):
            if len(inp.shape) == 3:
                inp = inp.reshape(-1, inp.shape[-1])
            # Columns become individual input vectors.
            inp = inp.t()
        elif isinstance(self.layer, nn.Conv2d):
            # Unfold into patches so the convolution looks like a matmul.
            unfold = nn.Unfold(
                kernel_size=self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute(1, 0, 2).flatten(1)
        # Running average: down-weight what was accumulated so far, then add
        # the new batch scaled by the updated sample count.
        self.hessian_inv *= self.num_samples / (self.num_samples + batch_size)
        self.num_samples += batch_size
        inp = math.sqrt(2 / self.num_samples) * inp.float()
        self.hessian_inv += inp @ inp.t()

    def perform_quantization(self, block_size=128, percent_dampening=0.01, group_size=-1,
                             use_activation_order=False, use_static_groups=False):
        """
        Quantize the layer's weights column by column, compensating the
        remaining columns for each column's rounding error, and write the
        result back into `self.layer.weight`.

        Args:
            block_size: Number of columns processed together before the error
                is propagated to the rest of the matrix.
            percent_dampening: Fraction of the mean Hessian diagonal added to
                the diagonal before inversion.
            group_size: Columns sharing one set of quantization parameters
                (-1 uses a single set for the whole matrix).
            use_activation_order: Process columns by decreasing Hessian
                diagonal instead of in storage order.
            use_static_groups: Compute every group's parameters up front from
                the unmodified weights.

        Side effects:
            Consumes (deletes) the accumulated Hessian, prints timing and the
            summed loss, and replaces `self.layer.weight.data`.
        """
        weight_data = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            weight_data = weight_data.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            weight_data = weight_data.t()
        weight_data = weight_data.float()

        start_time = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(weight_data, is_weight=True)

        H = self.hessian_inv
        del self.hessian_inv
        # Input features that were always zero: give them a unit diagonal so
        # the matrix stays invertible, and zero the matching weights.
        zero_diagonal = torch.diag(H) == 0
        H[zero_diagonal, zero_diagonal] = 1
        weight_data[:, zero_diagonal] = 0

        if use_static_groups:
            import copy
            groups = []
            for i in range(0, self.num_columns, group_size):
                quantizer_copy = copy.deepcopy(self.quantizer)
                quantizer_copy.find_params(weight_data[:, i:i+group_size], is_weight=True)
                groups.append(quantizer_copy)

        if use_activation_order:
            perm = torch.argsort(torch.diag(H), descending=True)
            weight_data = weight_data[:, perm]
            H = H[perm][:, perm]
            inv_perm = torch.argsort(perm)

        losses = torch.zeros_like(weight_data)
        quantized_weight = torch.zeros_like(weight_data)

        dampening = percent_dampening * torch.mean(torch.diag(H))
        diag_indices = torch.arange(self.num_columns, device=self.device)
        H[diag_indices, diag_indices] += dampening
        # Invert via Cholesky, then keep the upper Cholesky factor of the
        # inverse; its rows drive the error propagation below.
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        H_inv = H

        for start_idx in range(0, self.num_columns, block_size):
            end_idx = min(start_idx + block_size, self.num_columns)
            block_count = end_idx - start_idx

            W_block = weight_data[:, start_idx:end_idx].clone()
            Q_block = torch.zeros_like(W_block)
            Error_block = torch.zeros_like(W_block)
            Losses_block = torch.zeros_like(W_block)
            H_inv_block = H_inv[start_idx:end_idx, start_idx:end_idx]

            for i in range(block_count):
                w = W_block[:, i]
                d = H_inv_block[i, i]

                if group_size != -1:
                    if not use_static_groups:
                        # Refresh the parameters at each group boundary from
                        # the (already error-compensated) weights.
                        if (start_idx + i) % group_size == 0:
                            self.quantizer.find_params(weight_data[:, (start_idx + i):(start_idx + i + group_size)], is_weight=True)
                    else:
                        # Look up the precomputed group of the original
                        # (un-permuted) column.
                        idx = start_idx + i
                        if use_activation_order:
                            idx = perm[idx]
                        self.quantizer = groups[idx // group_size]

                q = quantize_tensor(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero_point, self.quantizer.max_quant).flatten()
                Q_block[:, i] = q
                Losses_block[:, i] = ((w - q) ** 2) / (d ** 2) / 2

                # Spread this column's error over the not-yet-quantized
                # columns of the block.
                err = (w - q) / d
                W_block[:, i:] -= err.unsqueeze(1) @ H_inv_block[i, i:].unsqueeze(0)
                Error_block[:, i] = err

            quantized_weight[:, start_idx:end_idx] = Q_block
            losses[:, start_idx:end_idx] = Losses_block

            # Propagate the block's accumulated error to all later columns.
            weight_data[:, end_idx:] -= Error_block @ H_inv[start_idx:end_idx, end_idx:]

            if DEBUG:
                self.layer.weight.data[:, :end_idx] = quantized_weight[:, :end_idx]
                self.layer.weight.data[:, end_idx:] = weight_data[:, end_idx:]
                print(torch.sum((self.layer(self.debug_input) - self.debug_output) ** 2))
                print(torch.sum(losses))

        # Wait for queued GPU work so the timing is meaningful; calling
        # synchronize without CUDA raises, so skip it on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"Time for quantization: {time.time() - start_time:.2f} seconds")
        print(f"Total quantization error: {torch.sum(losses).item()}")

        if use_activation_order:
            # Restore the original column order.
            quantized_weight = quantized_weight[:, inv_perm]

        if isinstance(self.layer, transformers.Conv1D):
            quantized_weight = quantized_weight.t()
        self.layer.weight.data = quantized_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if DEBUG:
            print(torch.sum((self.layer(self.debug_input) - self.debug_output) ** 2))

    def free(self):
        """
        Drop references to large tensors and release cached GPU memory.
        """
        if DEBUG:
            self.debug_input = None
            self.debug_output = None
        self.hessian_inv = None
        torch.cuda.empty_cache()

# ====================================================
# Data Loader Functions
# ====================================================

def set_random_seed(seed):
    """
    Seed the NumPy, PyTorch and Python `random` generators with `seed`.
    """
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)

def load_wikitext2(nsamples, seed, sequence_length, model_name):
    """
    Build calibration samples and test encodings from WikiText-2 (raw).

    Args:
        nsamples: Number of calibration windows to draw.
        seed: Seed applied right before the windows are drawn.
        sequence_length: Number of tokens per calibration window.
        model_name: Model id used to load the tokenizer.

    Returns:
        (calibration_data, test_encodings): a list of (input_ids, labels)
        pairs cut from the training split, and the tokenizer output for the
        whole test split.
    """
    from datasets import load_dataset

    corpus = {
        split: load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        for split in ("train", "test")
    }

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token

    def encode(split):
        # Each split is tokenized as one long document.
        return tokenizer("\n\n".join(corpus[split]["text"]), return_tensors="pt")

    train_encodings = encode("train")
    test_encodings = encode("test")

    set_random_seed(seed)

    train_ids = train_encodings.input_ids
    last_start = train_ids.shape[1] - sequence_length - 1

    def draw_window():
        begin = random.randint(0, last_start)
        window = train_ids[:, begin:begin + sequence_length]
        targets = window.clone()
        targets[:, :-1] = -100  # Only the final position keeps a real label
        return window, targets

    calibration_data = [draw_window() for _ in range(nsamples)]
    return calibration_data, test_encodings

def load_ptb_new(nsamples, seed, sequence_length, model_name):
    """
    Build calibration samples and test encodings from Penn Treebank.

    Args:
        nsamples: Number of calibration windows to draw.
        seed: Seed applied right before the windows are drawn.
        sequence_length: Number of tokens per calibration window.
        model_name: Model id used to load the tokenizer.

    Returns:
        (calibration_data, test_encodings): a list of (input_ids, labels)
        pairs cut from the training split, and the tokenizer output for the
        whole test split.
    """
    from datasets import load_dataset

    train_split = load_dataset("ptb_text_only", "penn_treebank", split="train")
    test_split = load_dataset("ptb_text_only", "penn_treebank", split="test")

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token

    # Sentences of each split are joined into one long document (train first,
    # then test) before tokenization.
    train_encodings, test_encodings = (
        tokenizer(" ".join(split["sentence"]), return_tensors="pt")
        for split in (train_split, test_split)
    )

    set_random_seed(seed)

    token_ids = train_encodings.input_ids
    calibration_data = []
    for _ in range(nsamples):
        offset = random.randint(0, token_ids.shape[1] - sequence_length - 1)
        sample = token_ids[:, offset:offset + sequence_length]
        sample_labels = sample.clone()
        sample_labels[:, :-1] = -100  # Only the final position keeps a real label
        calibration_data.append((sample, sample_labels))

    return calibration_data, test_encodings

def get_data_loaders(dataset_name, nsamples, seed, sequence_length, model_name):
    """
    Dispatch to the loader for `dataset_name` ("wikitext2" or "ptb-new").

    Returns whatever the selected loader returns; raises ValueError for any
    other dataset name.
    """
    if dataset_name not in ("wikitext2", "ptb-new"):
        raise ValueError(f"Dataset {dataset_name} not supported.")

    loader = load_wikitext2 if dataset_name == "wikitext2" else load_ptb_new
    return loader(nsamples, seed, sequence_length, model_name)

# ====================================================
# Model Utility Functions
# ====================================================

def find_target_layers(module, target_layers=None, name=''):
    """
    Recursively collect sub-modules whose exact type is one of `target_layers`.

    Args:
        module: Root module to search.
        target_layers: Collection of layer classes to look for. Defaults to
            nn.Conv2d, nn.Linear and transformers.Conv1D.
        name: Dotted prefix for the names in the result (used by the
            recursion; callers normally leave it empty).

    Returns:
        A dict mapping dotted module names (relative to `module`) to the
        matching layers. If `module` itself matches, the dict contains just
        `{name: module}` and its children are not searched.
    """
    if target_layers is None:
        # Resolved here rather than in the signature to avoid a shared
        # mutable default argument.
        target_layers = (nn.Conv2d, nn.Linear, transformers.Conv1D)

    # Exact type match on purpose: subclasses of the target types are skipped.
    if type(module) in target_layers:
        return {name: module}

    found_layers = {}
    for child_name, child_module in module.named_children():
        qualified_name = f"{name}.{child_name}" if name else child_name
        found_layers.update(
            find_target_layers(child_module, target_layers=target_layers, name=qualified_name)
        )
    return found_layers

def load_opt_model(model_name):
    """
    Load a causal LM for quantization and record its maximum sequence length.

    The default weight initializers are replaced with no-ops only while the
    model is being constructed (the pretrained weights overwrite them anyway)
    and are restored afterwards, so modules created later in the same process
    are initialized normally.

    Args:
        model_name: Model id or path passed to `from_pretrained`.

    Returns:
        The loaded model with an extra `sequence_length` attribute set to
        `config.max_position_embeddings`.
    """
    def skip_weight_init(*args, **kwargs):
        pass

    patched_names = ("kaiming_uniform_", "uniform_", "normal_")
    original_inits = {attr: getattr(torch.nn.init, attr) for attr in patched_names}
    for attr in patched_names:
        setattr(torch.nn.init, attr, skip_weight_init)

    try:
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto')
    finally:
        # Undo the monkey-patch even if loading fails.
        for attr, init_fn in original_inits.items():
            setattr(torch.nn.init, attr, init_fn)

    model.sequence_length = model.config.max_position_embeddings
    return model

# ====================================================
# Quantization and Evaluation Functions
# ====================================================

@torch.no_grad()
def quantize_model_sequentially(model, calibration_data, device):
    """
    Quantize the decoder layers of an OPT model one at a time, in place.

    The inputs to the first decoder layer are captured for every calibration
    sample. Each layer is then moved to `device`, its target sub-modules are
    quantized with GPTQ using those inputs, and the outputs of the quantized
    layer become the inputs of the next layer.

    Args:
        model: Model exposing `model.decoder` (with `layers`, `embed_tokens`,
            `embed_positions`), `config` and a `sequence_length` attribute.
        calibration_data: Iterable of (input_ids, labels) pairs; only the
            input ids are used.
        device: Device on which the layers are processed.

    Returns:
        The same model object, with quantized weights.
    """
    print("Starting quantization...")

    # Disable cache to prevent unexpected behaviors
    use_cache = model.config.use_cache
    model.config.use_cache = False

    decoder = model.model.decoder
    layers = decoder.layers

    # Size the activation buffers from the data actually supplied instead of
    # relying on a global sample count.
    calibration_data = list(calibration_data)
    num_samples = len(calibration_data)

    # Everything that runs before the first decoder layer must be on `device`.
    # `project_in` exists only on some model sizes.
    has_project_in = getattr(decoder, "project_in", None) is not None
    decoder.embed_tokens = decoder.embed_tokens.to(device)
    decoder.embed_positions = decoder.embed_positions.to(device)
    if has_project_in:
        decoder.project_in = decoder.project_in.to(device)

    dtype = next(iter(model.parameters())).dtype
    inputs = torch.zeros((num_samples, model.sequence_length, model.config.hidden_size), dtype=dtype, device=device)
    cache = {"index": 0, "attention_mask": None}

    class _InputCaptured(Exception):
        """Raised by the recorder to abort the forward pass after capture."""

    # Stand-in for the first layer: stores its input, then aborts the forward.
    class InputRecorder(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inputs[cache["index"]] = inp
            cache["index"] += 1
            cache["attention_mask"] = kwargs.get("attention_mask")
            raise _InputCaptured

    layers[0] = InputRecorder(layers[0])
    try:
        for batch, _ in calibration_data:
            try:
                model(batch.to(device))
            except _InputCaptured:
                # Expected. A dedicated exception type keeps genuine errors
                # from the model from being silently swallowed.
                pass
    finally:
        # Always put the real first layer back, even if capture failed.
        layers[0] = layers[0].module

    # Move layers back to CPU to free up GPU memory
    layers[0] = layers[0].cpu()
    decoder.embed_tokens = decoder.embed_tokens.cpu()
    decoder.embed_positions = decoder.embed_positions.cpu()
    if has_project_in:
        decoder.project_in = decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outputs = torch.zeros_like(inputs)
    attention_mask = cache["attention_mask"]

    print("Quantizing layers...")
    for layer_index in range(len(layers)):
        layer = layers[layer_index].to(device)

        # One GPTQ helper per target sub-module of this layer
        target_layers = find_target_layers(layer)
        quantizers = {}
        for name in target_layers:
            quantizers[name] = GPTQuantizer(target_layers[name])
            quantizers[name].quantizer.configure(
                bits=BITS, per_channel=True, symmetric=USE_SYMMETRIC, use_mse=False
            )

        # Forward hooks feed each sub-module's inputs into its helper
        def collect_activations(name):
            def hook(module, inp, out):
                quantizers[name].add_batch(inp[0].data, out.data)
            return hook

        hooks = []
        for name in target_layers:
            hooks.append(target_layers[name].register_forward_hook(collect_activations(name)))

        # Run the layer once purely to trigger the hooks
        for sample_index in range(num_samples):
            outputs[sample_index] = layer(inputs[sample_index].unsqueeze(0), attention_mask=attention_mask)[0]

        for hook in hooks:
            hook.remove()

        for name in target_layers:
            print(f"Quantizing layer {layer_index}, module {name}")
            # perform_quantization writes the quantized weights into the
            # layer itself. They must NOT be passed through
            # quantizer.quantize() again afterwards: with grouping enabled
            # the quantizer only holds the last group's parameters, so a
            # second pass would re-round every other group on the wrong grid.
            quantizers[name].perform_quantization(
                percent_dampening=PERCENT_DAMPENING,
                group_size=GROUP_SIZE,
                use_activation_order=USE_ACT_ORDER,
                use_static_groups=USE_STATIC_GROUPS
            )
            quantizers[name].free()

        # Recompute the outputs with the quantized layer; these feed the
        # next layer.
        for sample_index in range(num_samples):
            outputs[sample_index] = layer(inputs[sample_index].unsqueeze(0), attention_mask=attention_mask)[0]

        # Move the quantized layer back to CPU
        layers[layer_index] = layer.cpu()
        del layer
        del quantizers
        torch.cuda.empty_cache()

        # Swap inputs and outputs for the next iteration
        inputs, outputs = outputs, inputs

    model.config.use_cache = use_cache
    print("Quantization complete.")
    return model  # Return the quantized model

@torch.no_grad()
def evaluate_model(model, test_encodings, device):
    """
    Compute and print the perplexity of `model` on tokenized test data.

    The token sequence is cut into consecutive, non-overlapping windows of
    `model.sequence_length` tokens; each window is scored with next-token
    cross-entropy and the losses are pooled over all scored tokens.

    Args:
        model: Causal LM with a `sequence_length` attribute whose forward
            output exposes `.logits`. It is switched to eval mode and moved
            to `device`.
        test_encodings: Object with an `input_ids` tensor of shape
            (batch, tokens).
        device: Device on which evaluation runs.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If no window contains at least two tokens, so there is
            nothing to score.
    """
    print("Evaluating...")

    model.eval()
    model.to(device)

    input_ids = test_encodings.input_ids
    window = model.sequence_length
    # Built once; the loss module does not depend on the window.
    loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
    total_tokens = 0
    total_loss = 0.0

    for start in range(0, input_ids.size(1), window):
        batch = input_ids[:, start:start + window].to(device)
        if batch.size(1) < 2:
            continue  # A single token has no next-token target
        logits = model(batch).logits

        # Position t predicts token t + 1. The summed loss is computed in
        # float32 so half-precision models neither lose precision nor
        # overflow when many token losses are added up.
        shift_logits = logits[:, :-1, :].contiguous().float()
        shift_labels = batch[:, 1:].contiguous()

        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        total_loss += loss.item()
        total_tokens += (shift_labels != -100).sum().item()

    if total_tokens == 0:
        raise ValueError("Cannot compute perplexity: the test data contains no scorable tokens.")

    perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
    print(f"Perplexity: {perplexity.item():.2f}")
    return perplexity.item()

# ====================================================
# Main Execution
# ====================================================

# Set random seed
set_random_seed(SEED)

# Load the model
model = load_opt_model(MODEL_NAME)
model.eval()

# Prepare data loaders
# Only calibration_data is used from this call: test_encodings is reassigned
# inside the evaluation loop below before it is read.
calibration_data, test_encodings = get_data_loaders(
    DATASET_NAME,
    nsamples=NUM_SAMPLES,
    seed=SEED,
    sequence_length=model.sequence_length,
    model_name=MODEL_NAME,
)

# Perform quantization if required (bit widths of 16 or more skip it)
if BITS < 16:
    start_time = time.time()
    model = quantize_model_sequentially(model, calibration_data, DEVICE)
    print(f"Quantization time: {time.time() - start_time:.2f} seconds")

# Evaluate the model
# Each dataset is loaded and tokenized again here; the calibration half of the
# loader's return value is discarded.
datasets = ["wikitext2", "ptb-new"]
for dataset in datasets:
    print(f"Evaluating on {dataset} dataset...")
    _, test_encodings = get_data_loaders(
        dataset, nsamples=NUM_SAMPLES, seed=SEED, sequence_length=model.sequence_length, model_name=MODEL_NAME
    )
    perplexity = evaluate_model(model, test_encodings, DEVICE)
    print(f"Perplexity on {dataset}: {perplexity:.2f}")

# Save the quantized model
# NOTE: this runs unconditionally, so with BITS >= 16 the unquantized model is
# written to the same "quantized_opt_model" directory.
model.save_pretrained("quantized_opt_model")
print("Quantized model saved.")



Starting quantization...
Quantizing layers...
Quantizing layer 0, module self_attn.k_proj
Time for quantization: 0.31 seconds
Total quantization error: 22786.833984375
Quantizing layer 0, module self_attn.v_proj
Time for quantization: 0.24 seconds
Total quantization error: 2735.7451171875
Quantizing layer 0, module self_attn.q_proj
Time for quantization: 0.23 seconds
Total quantization error: 20060.662109375
Quantizing layer 0, module self_attn.out_proj
Time for quantization: 0.21 seconds
Total quantization error: 19.52503204345703
Quantizing layer 0, module fc1
Time for quantization: 0.22 seconds
Total quantization error: 15380.837890625
Quantizing layer 0, module fc2
Time for quantization: 0.98 seconds
Total quantization error: 290.2220764160156
Quantizing layer 1, module self_attn.k_proj
Time for quantization: 0.25 seconds
Total quantization error: 8428.318359375
Quantizing layer 1, module self_attn.v_proj
Time for quantization: 0.24 seconds
Total quantization error: 892.70837402343

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Evaluating...
Perplexity: 22.39
Perplexity on ptb-new: 22.39
Quantized model saved.
