<a href="https://colab.research.google.com/github/kumarutkarsh99/AdaptivFloat-Implementation/blob/main/script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from functools import partial

# Import our custom files
import quantizer
import utils

# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU. Every model and batch below is moved to this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# A simple CNN model designed for MNIST (1 input channel)
class MnistCNN(nn.Module):
    """Small two-convolution classifier for single-channel digit images.

    Architecture: two 3x3 convolutions (1 -> 32 -> 64 channels), a 2x2
    max-pool, then two fully connected layers (9216 -> 128 -> 10). The
    forward pass returns log-probabilities over the 10 classes.

    The first linear layer expects 9216 flattened features, which is what
    a 28x28 input yields: 28 -> 26 -> 24 after the two unpadded
    convolutions, 12 after pooling, and 64 * 12 * 12 = 9216.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (stride 1, no padding).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout applied after pooling and after the hidden layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head ending in one output per digit class (0-9).
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Flatten everything except the batch dimension before the head.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [None]:
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing shared by both splits: convert each image to a tensor, then
# normalise it with the standard MNIST mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Training split (downloaded into ./data when missing) and the test split.
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, transform=transform)

# Training batches are shuffled; the test set is read in fixed order with
# larger batches.
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=1000, shuffle=False)

In [None]:
import copy
from tqdm import tqdm

# Train the FP32 baseline that the quantization experiments start from.
model_fp32 = MnistCNN().to(device)
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
# MnistCNN.forward already returns log-probabilities (log_softmax), so the
# matching criterion is NLLLoss. CrossEntropyLoss would apply a second,
# redundant log-softmax on top of them (same value, wasted work, and
# misleading to readers who expect it to receive raw logits).
criterion = nn.NLLLoss()
num_epochs = 10 # More epochs for high accuracy

print("Training a new FP32 baseline model on MNIST...")
model_fp32.train()  # enable dropout while optimising
for epoch in range(num_epochs):
    for data, target in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model_fp32(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} complete.")

# Leave training mode before the model is evaluated or deep-copied for
# quantization: otherwise dropout stays active and accuracy numbers become
# noisy unless the evaluation helper switches modes itself.
model_fp32.eval()
print("Finished Training.")

In [None]:
from functools import partial

# Step 1: accuracy of the freshly trained FP32 network on the test loader.
# `utils.evaluate_resnet` is reused as-is here, even though the model being
# scored is MnistCNN rather than a ResNet.
fp32_accuracy = utils.evaluate_resnet(model_fp32, testloader, device)

# Step 2: work on an independent copy so the FP32 baseline object itself is
# not the one handed to the quantizer.
model_af_8bit = copy.deepcopy(model_fp32)
model_af_8bit = model_af_8bit.to(device)

# Step 3: bind the AdaptivFloat quantizer to 8 total bits / 3 exponent bits.
quant_8bit_af_func = partial(
    quantizer.quantize_to_adaptivfloat, total_bits=8, exponent_bits=3
)
# A functools.partial object has no __name__ of its own, so one is attached
# explicitly (presumably read by the utils helper for reporting -- confirm).
quant_8bit_af_func.__name__ = "AdaptivFloat_8bit"
utils.apply_quantization_to_model(model_af_8bit, quant_8bit_af_func)

# Step 4: score the quantized copy with the same helper and data.
af_8bit_accuracy = utils.evaluate_resnet(model_af_8bit, testloader, device)

# Step 5: side-by-side summary of both accuracies and their difference.
print("\n--- MNIST Sanity Check Complete ---")
print(f"Baseline FP32 Accuracy (MnistCNN):   {fp32_accuracy:.2f}%")
print(f"AdaptivFloat 8-bit Accuracy (MnistCNN): {af_8bit_accuracy:.2f}%")
print(f"Accuracy Drop: {fp32_accuracy - af_8bit_accuracy:.4f}%")

In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Probe sentence for the masked-LM experiments; the expected fill-in for
# the [MASK] position is "paris".
sentence = "The capital of France is [MASK]."

# Checkpoint identifier used for every BERT model in the following cells,
# and the tokenizer that belongs to it.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Unquantized reference: pretrained masked-LM BERT placed on the device.
model_bert_fp32 = AutoModelForMaskedLM.from_pretrained(model_name)
model_bert_fp32 = model_bert_fp32.to(device)

# Have the reference model fill in the masked token of the probe sentence.
predicted_token = utils.evaluate_bert(model_bert_fp32, tokenizer, sentence, device)
print(f"\n[FP32 Baseline BERT] Prediction for '[MASK]': '{predicted_token}'")

In [None]:
# Start from a new pretrained copy so earlier experiments cannot leak in.
model_bert_int8 = AutoModelForMaskedLM.from_pretrained(model_name)
model_bert_int8 = model_bert_int8.to(device)

# Naive 8-bit integer quantizer, used as the deliberately weak comparison
# point for AdaptivFloat.
quant_int8_func = partial(utils.simple_int8_quantizer, total_bits=8)
# partial objects have no __name__; attach a readable label.
quant_int8_func.__name__ = "Simple_INT8"

# Quantize the model with it (the helper's return value is not used).
utils.apply_quantization_to_model(model_bert_int8, quant_int8_func)

# Fill in the masked token with the INT8-quantized model.
predicted_token_int8 = utils.evaluate_bert(model_bert_int8, tokenizer, sentence, device)
print(f"\n[Simple INT8 BERT] Prediction for '[MASK]': '{predicted_token_int8}'")

In [None]:
# Start from a new pretrained copy of BERT on the selected device.
model_bert_af8 = AutoModelForMaskedLM.from_pretrained(model_name)
model_bert_af8 = model_bert_af8.to(device)

# Quantize with the 8-bit AdaptivFloat function (8 total bits, 3 exponent
# bits) that was defined earlier for the MNIST sanity check.
utils.apply_quantization_to_model(model_bert_af8, quant_8bit_af_func)

# Fill in the masked token with the AdaptivFloat-quantized model.
predicted_token_af8 = utils.evaluate_bert(model_bert_af8, tokenizer, sentence, device)
print(f"\n[AdaptivFloat 8-bit BERT] Prediction for '[MASK]': '{predicted_token_af8}'")

In [None]:
def _quantize_bert_and_predict(total_bits, exponent_bits, quantizer_name):
    """Load BERT, quantize it with AdaptivFloat and predict the masked token.

    Args:
        total_bits: Total bit-width passed to the AdaptivFloat quantizer.
        exponent_bits: Number of exponent bits passed to the quantizer.
        quantizer_name: Label stored on the quantizer as ``__name__``.

    Returns:
        A ``(model, quantizer_function, predicted_token)`` tuple.
    """
    bert = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
    quant_func = partial(
        quantizer.quantize_to_adaptivfloat,
        total_bits=total_bits,
        exponent_bits=exponent_bits,
    )
    # partial objects have no __name__; attach the label explicitly.
    quant_func.__name__ = quantizer_name
    utils.apply_quantization_to_model(bert, quant_func)
    token = utils.evaluate_bert(bert, tokenizer, sentence, device)
    return bert, quant_func, token


print("\n--- Advanced Test: Going to lower bit-widths ---")

# 6 total bits, 3 of them exponent bits.
model_bert_af6, quant_6bit_af_func, predicted_token_af6 = _quantize_bert_and_predict(
    6, 3, "AdaptivFloat_6bit_3exp"
)
print(f"[AdaptivFloat 6-bit BERT] Prediction for '[MASK]': '{predicted_token_af6}'")

# 4 total bits, 2 of them exponent bits.
model_bert_af4, quant_4bit_af_func, predicted_token_af4 = _quantize_bert_and_predict(
    4, 2, "AdaptivFloat_4bit_2exp"
)
print(f"[AdaptivFloat 4-bit BERT] Prediction for '[MASK]': '{predicted_token_af4}'")