In [1]:
# First install the dependencies
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
!pip install evaluate
!pip install xformers flash-attn
!pip install -q apex


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for peft (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for accelerate (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [2]:
!pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [45]:
import torch
import triton
import triton.language as tl

# Triton kernel for Flash Attention
@triton.jit
def flash_attention_kernel(
    Q, K, V, sm_scale, attention_mask,
    Out, seq_len, head_dim,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_idx = tl.program_id(1)

    q_ptr = Q + row_idx * seq_len * head_dim
    k_ptr = K + col_idx * seq_len * head_dim
    v_ptr = V + col_idx * seq_len * head_dim

    # Make head_dim a compile-time constant within the kernel
    acc = tl.zeros((BLOCK_SIZE, tl.constexpr(head_dim)), dtype=tl.float32)
    for i in range(0, seq_len, BLOCK_SIZE):
        k = tl.load(k_ptr + i * head_dim, mask=i + tl.arange(0, BLOCK_SIZE) < seq_len, other=0.0)
        v = tl.load(v_ptr + i * head_dim, mask=i + tl.arange(0, BLOCK_SIZE) < seq_len, other=0.0)

        # Compute scaled dot-product attention scores
        q = tl.load(q_ptr, mask=True, other=0.0)
        attn = tl.dot(q, k.transpose(0, 1)) * sm_scale

        if attention_mask:
            attn = attn + attention_mask

        # Compute softmax over scores
        max_attn = tl.max(attn, axis=1)
        attn = tl.exp(attn - max_attn[:, None])
        attn_sum = tl.sum(attn, axis=1)
        attn = attn / attn_sum[:, None]

        # Aggregate attention-weighted values
        acc += tl.dot(attn, v)

    # Store results back to Out
    out_ptr = Out + row_idx * seq_len * head_dim
    tl.store(out_ptr, acc)

class TritonFlashAttention(torch.nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

    def forward(self, query, key, value, attention_mask=None):
        batch_size, seq_len, hidden_dim = query.shape
        BLOCK_SIZE = 128

        # Flatten for Triton kernel compatibility
        query = query.view(batch_size * self.num_heads, seq_len, self.head_dim)
        key = key.view(batch_size * self.num_heads, seq_len, self.head_dim)
        value = value.view(batch_size * self.num_heads, seq_len, self.head_dim)

        # Allocate output tensor
        output = torch.empty_like(query)

        # Launch Triton kernel
        grid = (batch_size * self.num_heads, seq_len // BLOCK_SIZE + 1)
        flash_attention_kernel[grid](
            query.to(torch.float16),
            key.to(torch.float16),
            value.to(torch.float16),
            self.scale,
            attention_mask,
            output,
            seq_len,
            self.head_dim,
            BLOCK_SIZE=BLOCK_SIZE
        )
        return output.view(batch_size, seq_len, hidden_dim)



In [46]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from peft import PeftModel


class CustomModelWithTritonAttention(AutoModelForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.flash_attention = TritonFlashAttention(self.num_heads, self.head_dim)

    def forward(self, **inputs):
        # Forward inputs through the model
        outputs = super().forward(**inputs)

        # Replace attention logic with TritonFlashAttention
        hidden_states = outputs.hidden_states
        attention_mask = inputs.get("attention_mask", None)

        for i, layer in enumerate(self.bert.encoder.layer):
            hidden_states = layer.attention.self.forward(
                hidden_states,
                attention_mask=attention_mask,
                flash_attention=self.flash_attention,
            )

        return outputs


In [7]:
# Load GLUE dataset (MRPC task)
from datasets import load_dataset  # Import the load_dataset function
import evaluate


task = "mrpc"  # Set the task to "mrpc"
dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Downloading builder script:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

In [9]:
# run this cell if you saved model files as zip
!unzip fine-tuned-model.zip -d fine-tuned-model

Archive:  fine-tuned-model.zip
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/tokenizer_config.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/tokenizer.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/special_tokens_map.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/vocab.txt  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/training_args.bin  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/trainer_state.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/adapter_config.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/adapter_model.safetensors  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/README.md  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/rng_state.pth  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/scheduler.pt  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/optimizer.pt  


In [10]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig
# Assuming your model is already loaded and fine-tuned
model = AutoModelForSequenceClassification.from_pretrained("fine-tuned-model/model-files")

# Save model weights, config, and tokenizer to a directory
save_directory = "./saved_model"  # Choose the directory where you want to save
model.save_pretrained(save_directory)

# Save tokenizer as well
tokenizer = AutoTokenizer.from_pretrained("fine-tuned-model/model-files")
tokenizer.save_pretrained(save_directory)

config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/62.7M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


('./saved_model/tokenizer_config.json',
 './saved_model/special_tokens_map.json',
 './saved_model/vocab.txt',
 './saved_model/added_tokens.json',
 './saved_model/tokenizer.json')

In [16]:
from datasets import load_dataset

test_dataset = load_dataset("glue", "mrpc", split="test")

In [17]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./saved_model", padding=True, truncation=True)

def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length")

tokenized_dataset = test_dataset.map(tokenize_function, batched=True)

In [18]:
tokenized_dataset = tokenized_dataset.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
tokenized_dataset.set_format("torch")

In [47]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)



In [48]:
def measure_latency(model, dataloader, warmup_runs=10, measure_runs=50):
    model.eval()
    latencies = []

    # Warmup phase
    with torch.no_grad():
        for _ in range(warmup_runs):
            for batch in dataloader:
                batch = {k: v.to("cuda") for k, v in batch.items()}
                _ = model(**batch)
                torch.cuda.synchronize()  # Ensure all GPU operations are complete

    # Measurement phase
    with torch.no_grad():
        for _ in range(measure_runs):
            for batch in dataloader:
                batch = {k: v.to("cuda") for k, v in batch.items()}

                # Time a single batch pass
                start_time = time.time()
                _ = model(**batch)
                torch.cuda.synchronize()
                end_time = time.time()

                latencies.append(end_time - start_time)

    # Calculate average latency
    average_latency = sum(latencies) / len(latencies)
    return average_latency * 1000  # Convert to milliseconds

In [49]:
from transformers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Wrap the base model with Triton attention
custom_model = CustomModelWithTritonAttention.from_pretrained(
    "./saved_model",
    quantization_config=double_quant_config
)
tokenizer = AutoTokenizer.from_pretrained("./saved_model")
custom_model.to("cuda")
custom_model.eval()


`low_cpu_mem_usage` was None, now default to True since model is quantized.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=312, out_features=312, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=312, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(i

In [55]:
batch_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]

In [50]:
for batch_size in batch_sizes:
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )
    latency = measure_latency(custom_model, dataloader)
    print(f"Average inference latency (Batch Size {batch_size}): {latency:.2f} ms")

Average inference latency (Batch Size 16): 13.74 ms
Average inference latency (Batch Size 32): 14.10 ms
Average inference latency (Batch Size 64): 14.94 ms
Average inference latency (Batch Size 128): 14.16 ms
Average inference latency (Batch Size 256): 22.20 ms
Average inference latency (Batch Size 512): 34.49 ms
Average inference latency (Batch Size 1024): 64.08 ms


In [51]:
import time

# Wrap the base model with Triton attention
noFA_model = AutoModelForSequenceClassification.from_pretrained(
    "./saved_model",
    quantization_config=double_quant_config
)
tokenizer = AutoTokenizer.from_pretrained("./saved_model")
noFA_model.to("cuda")
noFA_model.eval()


`low_cpu_mem_usage` was None, now default to True since model is quantized.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=312, out_features=312, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=312, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(i

In [52]:
batch_sizes = [16, 32, 64, 128, 256, 512, 1024]
for batch_size in batch_sizes:
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )
    latency = measure_latency(noFA_model, dataloader)
    print(f"Average inference latency (Batch Size {batch_size}): {latency:.2f} ms")

Average inference latency (Batch Size 16): 14.20 ms
Average inference latency (Batch Size 32): 13.72 ms
Average inference latency (Batch Size 64): 13.90 ms
Average inference latency (Batch Size 128): 14.35 ms
Average inference latency (Batch Size 256): 22.64 ms
Average inference latency (Batch Size 512): 36.84 ms
Average inference latency (Batch Size 1024): 66.12 ms


In [53]:
import time

# Wrap the base model with Triton attention
original_model = AutoModelForSequenceClassification.from_pretrained(
    "huawei-noah/TinyBERT_General_4L_312D",

)
tokenizer = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")
original_model.to("cuda")
original_model.eval()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=312, out_features=312, bias=True)
              (key): Linear(in_features=312, out_features=312, bias=True)
              (value): Linear(in_features=312, out_features=312, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=312, out_features=312, bias=True)
              (LayerNorm): LayerNorm((312,), eps=1e-1

In [54]:
batch_sizes = [16, 32, 64, 128, 256, 512, 1024]
for batch_size in batch_sizes:
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )
    latency = measure_latency(original_model, dataloader)
    print(f"Average inference latency (Batch Size {batch_size}): {latency:.2f} ms")

Average inference latency (Batch Size 16): 4.86 ms
Average inference latency (Batch Size 32): 5.13 ms
Average inference latency (Batch Size 64): 8.29 ms
Average inference latency (Batch Size 128): 14.60 ms
Average inference latency (Batch Size 256): 26.27 ms
Average inference latency (Batch Size 512): 40.60 ms
Average inference latency (Batch Size 1024): 79.27 ms
