In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch

In [2]:
import numpy as np
from copy import deepcopy

In [3]:
import os

# Directory used as `cache_dir` for model and dataset downloads.
# Override with the HF_CACHE_DIR environment variable; the previous path is the default.
cache_directory = os.environ.get('HF_CACHE_DIR', '/scratch/user/nehajm')

# Hugging Face access token, read from the environment so the secret is never
# stored in the notebook. It is None when HF_TOKEN is unset.
# NOTE(review): the token that used to be hardcoded here has been exposed in the
# notebook source and should be revoked on the Hugging Face account.
api_token = os.environ.get('HF_TOKEN')

In [4]:
# Hub id of the checkpoint passed to every `from_pretrained` call in this notebook.
model_llama_name = 'meta-llama/Llama-2-7b-chat-hf'  # alternative tried earlier: 'gpt2'
# model_gpt_name = 'gpt'   (unused leftover; kept commented out)

In [19]:
# Load the unquantized reference model; `device_map='auto'` leaves device
# placement to the library.
# NOTE(review): `use_auth_token` is reportedly deprecated in newer transformers
# releases in favour of `token` — confirm against the installed version.
model = AutoModelForCausalLM.from_pretrained(model_llama_name, cache_dir=cache_directory,
  use_auth_token=api_token, device_map='auto')




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [14]:
# Tokenizer for the same checkpoint. The helper functions in this notebook read
# `tokenizer` as a global, so this cell must run before any of them is called.
tokenizer = AutoTokenizer.from_pretrained(model_llama_name, cache_dir=cache_directory, use_auth_token=api_token)



In [20]:
# Snapshot of the original model weights: one cloned tensor per parameter.
# NOTE(review): `weights` is not read by any other visible cell, and the clones
# duplicate every parameter in memory — consider dropping this cell.
weights = [param.data.clone() for param in model.parameters()]

In [6]:
# Restrict the C4 download to the English validation shards matched by the glob below.
data_files = {"validation": "en/c4-validation.*.json.gz"}
c4_validation = load_dataset("allenai/c4", data_files=data_files, split="validation", cache_dir=cache_directory)

Found cached dataset json (/scratch/user/nehajm/allenai___json/allenai--c4-181ebbe6122ca37f/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [7]:
# Text column of the validation split; `generate_tokens_c4` iterates over it as a global.
validation_data = c4_validation['text']

In [8]:
# NOTE(review): no visible cell reads this module-level value; `generate_tokens_c4`
# works with its own budget of 512, so this constant looks like a leftover.
MAX_TOTAL_TOKENS =128

ABSMAX QUANTIZATION - for tensors whose values are distributed roughly symmetrically around zero (ranging from negative to positive)

In [21]:
def absmax_quantize(X):
    """Symmetric (absmax) per-tensor quantization to int8.

    The tensor is scaled so that its largest absolute value maps to 127,
    rounded to the nearest integer, and then scaled back.

    Args:
        X: floating-point tensor to quantize (any shape).

    Returns:
        A tuple ``(X_quant, X_dequant)``: the values as ``torch.int8`` in
        [-127, 127], and the dequantized approximation of ``X`` in the
        original dtype.
    """
    # Empty tensors have no maximum (torch.max would raise) and nothing to quantize.
    if X.numel() == 0:
        return X.to(torch.int8), X.clone()

    abs_max = torch.max(torch.abs(X))

    # An all-zero tensor would give scale = 127 / 0 = inf, and inf * 0 = NaN.
    # Zeros are exactly representable, so return them directly.
    if abs_max == 0:
        return torch.zeros_like(X, dtype=torch.int8), torch.zeros_like(X)

    # Calculate scale
    scale = 127 / abs_max

    # Quantize: gives integer-valued numbers in the range [-127, 127]
    X_quant = (scale * X).round()

    # Dequantize
    X_dequant = X_quant / scale

    return X_quant.to(torch.int8), X_dequant

In [22]:
# Load a second, independent copy of the checkpoint to hold the absmax-quantized weights.
model_abs = AutoModelForCausalLM.from_pretrained(model_llama_name, cache_dir=cache_directory,
  use_auth_token=api_token, device_map='auto')

# Replace every parameter with its quantize-then-dequantize approximation.
# One scale is computed per parameter tensor, and the int8 tensor itself is discarded.
# `weights_abs` collects the same tensor objects that are assigned to `param.data`.
weights_abs = []
for param in model_abs.parameters():
    _, dequantized = absmax_quantize(param.data)
    param.data = dequantized
    weights_abs.append(dequantized)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

ZEROPOINT QUANTIZATION - for asymmetric value distributions, such as the output of a ReLU function

In [20]:
def zeropoint_quantize(X):
    """Asymmetric (zero-point) per-tensor quantization to int8.

    The tensor's [min, max] interval is stretched over 255 integer steps and
    shifted by a zero-point so that the result lands in [-128, 127].

    Args:
        X: floating-point tensor to quantize.

    Returns:
        A tuple ``(X_quant, X_dequant)``: the quantized values as
        ``torch.int8`` and the dequantized approximation of ``X``.
    """
    lowest = torch.min(X)

    # Width of the value interval; a constant tensor has width 0, in which
    # case 1 is used so that the division below stays defined.
    span = torch.max(X) - lowest
    if span == 0:
        span = 1

    scale = 255 / span

    # Offset that places the smallest input value at -128.
    zeropoint = torch.round(-scale * lowest - 128)

    # Scale, shift, round, and clamp into the int8 range.
    shifted = X * scale + zeropoint
    X_quant = torch.clip(torch.round(shifted), -128, 127)

    # Undo the shift and the scaling to recover an approximation of X.
    X_dequant = (X_quant - zeropoint) / scale

    return X_quant.to(torch.int8), X_dequant

In [21]:
# Create the model to quantize as a deep copy of the reference `model`
# (the absmax variant reloads the checkpoint from disk instead).
model_zp = deepcopy(model)

# Replace every parameter with its zero-point quantize-then-dequantize approximation.
# One scale and zero-point are computed per parameter tensor.
# `weights_zp` collects the same tensor objects that are assigned to `param.data`.
weights_zp = []
for param in model_zp.parameters():
    _, dequantized = zeropoint_quantize(param.data)
    param.data = dequantized
    weights_zp.append(dequantized)

INT-8 QUANTIZATION

In [5]:
# Load the checkpoint with the library's built-in 8-bit loading (`load_in_8bit=True`).
# NOTE(review): unlike the other loads, no auth token is passed here, so this
# presumably relies on the files already being in `cache_directory` — confirm.
model_int8 = AutoModelForCausalLM.from_pretrained(model_llama_name,
                                                  cache_dir=cache_directory,
                                             device_map='auto',
                                             load_in_8bit=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

INT-4 QUANTIZATION

In [12]:
# Load the checkpoint with the library's built-in 4-bit loading (`load_in_4bit=True`).
# NOTE(review): unlike the other loads, no auth token is passed here, so this
# presumably relies on the files already being in `cache_directory` — confirm.
model_int4 = AutoModelForCausalLM.from_pretrained(model_llama_name,
                                                  cache_dir=cache_directory,
                                             device_map='auto',
                                             load_in_4bit=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Generating TEXT with the original and quantized models 

In [23]:
def generate_text(model, input_text, max_length=50):
    """Sample a continuation of ``input_text`` from ``model``.

    Uses the notebook-level ``tokenizer``. Sampling is stochastic
    (``do_sample=True`` with ``top_k=30``) and no seed is set here, so
    repeated calls return different text.

    Args:
        model: causal language model exposing ``generate``.
        input_text: prompt string.
        max_length: value forwarded to ``generate`` as ``max_length``.

    Returns:
        The decoded output sequence as a string, special tokens removed.
    """
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # The tokenizer returns CPU tensors, while the models in this notebook are
    # placed with device_map='auto'. Move the prompt to the model's device so
    # generation does not depend on implicit device handling.
    device = getattr(model, 'device', None)
    if device is not None:
        input_ids = input_ids.to(device)

    output = model.generate(inputs=input_ids,
                            max_length=max_length,
                            do_sample=True,
                            top_k=30,
                            pad_token_id=tokenizer.eos_token_id,
                            # Single unpadded prompt: every position is a real token.
                            attention_mask=input_ids.new_ones(input_ids.shape))
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [25]:
# Prompt shared by all the text-generation cells.
input_text = "explain deep learning"

In [31]:
# Generate text with the original (unquantized) model; the quantized variants
# follow in the next cells. Output is sampled, so it changes from run to run.

original_text = generate_text(model, input_text)
print(f"Original model:\n{original_text}")
print("-" * 50)


Original model:
explain deep learning and its applications

Deep learning (also known as deep structured learning) is part of a broader family of machine learning methods based on artificial neural networks with representation learning
--------------------------------------------------


In [32]:
# Sampled continuation from the absmax-quantized model.
absmax_text   = generate_text(model_abs, input_text)
print(f"Absmax model:\n{absmax_text}")
print("-" * 50)

Absmax model:
explain deep learning to a 5 year old
 nobody understands deep learning
Deep Learning for Kids
Deep Learning Explained in 5 Minutes
Deep Learning Explained in Simple Terms
Deep Learning
--------------------------------------------------


In [26]:
# Sampled continuation from the zero-point-quantized model.
zp_text = generate_text(model_zp, input_text)
print(f"Zeropoint model:\n{zp_text}")

Zeropoint model:
explain deep learning in simple terms
➖ 1. What is deep learning?
Deep learning is a type of machine learning that uses artificial neural networks to model and solve complex problems. These networks are designed to mimic the structure and


In [29]:
# Sampled continuation from the model loaded with load_in_8bit=True.
int8_text = generate_text(model_int8, input_text)
print(f"int8 model:\n{int8_text}")

int8 model:
explain deep learning for computer vision

Deep learning for computer vision is a subfield of machine learning that focuses on developing algorithms and models that can be used to analyze and understand visual data from images and videos. The goal of deep learning


In [31]:
# Sampled continuation from the model loaded with load_in_4bit=True.
int4_text= generate_text(model_int4, input_text)
print(f"int4 model:\n{int4_text}")

int4 model:
explain deep learning in simple terms
 everybody can understand

Deep learning is a type of machine learning that uses artificial neural networks to analyze and learn from data.

Think of a neural network like a map of a city. Each


Generating PERPLEXITY for the Generated Text 

In [27]:
def calculate_perplexity(model, text):
    """Compute the perplexity that ``model`` assigns to ``text``.

    The text is tokenized with the notebook-level ``tokenizer`` and passed to
    the model with the input ids as labels; perplexity is ``exp`` of the loss
    the model reports.

    Args:
        model: causal language model returning an output with a ``loss`` field
            when called with ``labels``.
        text: string to score.

    Returns:
        A 0-dimensional tensor holding the perplexity.
    """
    # Encode the text
    encodings = tokenizer(text, return_tensors='pt')

    # The tokenizer returns CPU tensors, while the models in this notebook are
    # placed with device_map='auto'. Move the ids to the model's device so the
    # forward pass does not depend on implicit device handling.
    input_ids = encodings.input_ids
    device = getattr(model, 'device', None)
    if device is not None:
        input_ids = input_ids.to(device)

    # The labels are the inputs themselves.
    target_ids = input_ids.clone()

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

    # Loss calculation
    neg_log_likelihood = outputs.loss

    # Perplexity calculation
    ppl = torch.exp(neg_log_likelihood)

    return ppl

In [37]:
# Perplexity of the original model on the text it generated above.
# NOTE(review): each model in this section is scored on its own sampled text,
# so the numbers are not measured on a common input.
ppl     = calculate_perplexity(model, original_text)
print(f"Original perplexity:  {ppl.item():.2f}")

Original perplexity:  1.67


In [38]:
# Perplexity of the absmax-quantized model on its own generated text.
ppl_abs = calculate_perplexity(model_abs, absmax_text)
print(f"Absmax perplexity:    {ppl_abs.item():.2f}")

Absmax perplexity:    2.65


In [28]:
# Perplexity of the zero-point-quantized model on its own generated text.
ppl_zp = calculate_perplexity(model_zp, zp_text)
print(f"Zeropoint perplexity:    {ppl_zp.item():.2f}")

Zeropoint perplexity:    2.21


In [30]:
# Perplexity of the 8-bit model on its own generated text.
ppl_int8 = calculate_perplexity(model_int8, int8_text)
print(f"Int8 perplexity:    {ppl_int8.item():.2f}")

Int8 perplexity:    1.90


In [32]:
# Perplexity of the 4-bit model on its own generated text.
ppl_int4 = calculate_perplexity(model_int4, int4_text)
print(f"Int4 perplexity:    {ppl_int4.item():.2f}")

Int4 perplexity:    2.15


Generate C4 DATASET PERPLEXITIES FOR ALL THE MODELS 

In [10]:
# Generate text from the C4 validation data
def generate_tokens_c4(model, max_total_tokens=512):
    """Run ``model.generate`` over C4 validation texts until a budget is filled.

    Iterates over the notebook-level ``validation_data``, tokenizes each text
    with the notebook-level ``tokenizer`` (truncated to the remaining budget),
    generates with that same value as ``max_length``, and collects the decoded
    output split on whitespace.

    Note that the budget is counted in whitespace-separated words of the
    decoded output, while the same number is passed to the tokenizer and to
    ``generate`` as a length in tokens.

    Args:
        model: causal language model exposing ``generate``.
        max_total_tokens: budget at which collection stops. Defaults to 512,
            the value that was previously hardcoded.

    Returns:
        The collected words joined by single spaces.
    """
    generated_tokens = []

    if max_total_tokens <= 0:
        return ""

    # padding=True below requires a pad token. Set one only when the tokenizer
    # has none, so an existing pad token on the shared tokenizer is not overwritten.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The tokenizer returns CPU tensors, while the models in this notebook are
    # placed with device_map='auto'.
    device = getattr(model, 'device', None)

    for text in validation_data:
        # Remaining budget. It is always positive here because the loop
        # breaks as soon as the budget is reached.
        max_length = max_total_tokens - len(generated_tokens)

        encodings = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        input_ids = encodings.input_ids
        attention_mask = encodings.attention_mask
        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

        # Pass the mask and pad id explicitly rather than relying on defaults.
        generated_ids = model.generate(input_ids,
                                       attention_mask=attention_mask,
                                       max_length=max_length,
                                       pad_token_id=tokenizer.pad_token_id)
        generated_tokens.extend(tokenizer.decode(generated_ids[0], skip_special_tokens=True).split())

        # Stop once the budget has been reached
        if len(generated_tokens) >= max_total_tokens:
            break

    # Combine the collected words into a single string
    return " ".join(generated_tokens)

In [40]:
# Generate C4-based text with the original model.
original_gen_text = generate_tokens_c4(model)



In [41]:
# Generate C4-based text with the absmax-quantized model.
absmax_gen_text = generate_tokens_c4(model_abs)



In [33]:
# Generate C4-based text with the zero-point-quantized model.
zp_gen_text = generate_tokens_c4(model_zp)



In [35]:
# Generate C4-based text with the 8-bit model.
int8_gen_text = generate_tokens_c4(model_int8)



In [36]:
# Generate C4-based text with the 4-bit model.
int4_gen_text = generate_tokens_c4(model_int4)



Calculating PERPLEXITIES for the c4 dataset for all the models below 

In [11]:
def calculate_perplexity_from_c4(model, generated_text):
    """Compute the perplexity that ``model`` assigns to ``generated_text``.

    The text is tokenized with the notebook-level ``tokenizer`` and passed to
    the model with the input ids as labels; perplexity is ``exp`` of the loss
    the model reports.

    Args:
        model: causal language model returning an output with a ``loss`` field
            when called with ``labels``.
        generated_text: string to score (here, the output of
            ``generate_tokens_c4``).

    Returns:
        A 0-dimensional tensor holding the perplexity.
    """
    with torch.no_grad():
        # Tokenize the generated text for the loss calculation
        generated_input_ids = tokenizer(generated_text, return_tensors="pt").input_ids

        # The tokenizer returns CPU tensors, while the models in this notebook
        # are placed with device_map='auto'. Move the ids to the model's device.
        device = getattr(model, 'device', None)
        if device is not None:
            generated_input_ids = generated_input_ids.to(device)

        # Use the input ids themselves as labels
        loss = model(generated_input_ids, labels=generated_input_ids).loss
        perplexity = torch.exp(loss)
    return perplexity

In [43]:
# Perplexity of the original model on the text it produced from C4 prompts.
# NOTE(review): the variable name is misspelled ("orginal"); it is left
# unchanged here because later cells may reference it.
orginal_gen_text_ppl = calculate_perplexity_from_c4(model, original_gen_text)
print(f"Original model:  {orginal_gen_text_ppl.item():.2f}")

Original model:  13.67


In [44]:
# Perplexity of the absmax-quantized model on the text it produced from C4 prompts.
absmax_gen_text_ppl = calculate_perplexity_from_c4(model_abs, absmax_gen_text)
print(f"Absmax model:  {absmax_gen_text_ppl.item():.2f}")

Absmax model:  11.90


In [34]:
# Perplexity of the zero-point-quantized model on the text it produced from C4 prompts.
zp_gen_text_ppl = calculate_perplexity_from_c4(model_zp, zp_gen_text)
print(f"Zeropoint model:  {zp_gen_text_ppl.item():.2f}")

Zeropoint model:  12.00


In [37]:
# Perplexity of the 8-bit model on the text it produced from C4 prompts.
int8_gen_text_ppl = calculate_perplexity_from_c4(model_int8, int8_gen_text)
print(f"Int 8 model:  {int8_gen_text_ppl.item():.2f}")

Int 8 model:  12.79


In [38]:
# Perplexity of the 4-bit model on the text it produced from C4 prompts.
int4_gen_text_ppl = calculate_perplexity_from_c4(model_int4, int4_gen_text)
print(f"Int 4 model:  {int4_gen_text_ppl.item():.2f}")

Int 4 model:  12.41
