## 1. Preparation

In [13]:
long_ner_text = """
The Old Zurich War was a conflict between the canton of Zurich and the other seven cantons of the Old Swiss Confederacy over the succession to the Count of Toggenburg.
In 1436, Count Friedrich VII of Toggenburg died, leaving neither heir nor will.
The canton of Zurich, led by burgomaster Rudolf Stüssi, claimed the Toggenburg lands; the cantons of Schwyz and Glarus made counter-claims, backed by the other cantons.
In 1438, Zurich occupied the disputed area and cut off grain supplies to Schwyz and Glarus. In 1440, the other cantons expelled Zurich from the confederation and declared war.
Zurich retaliated by making an alliance with Frederick III, Holy Roman Emperor of the house of Habsburg.
"""

small_ner_text = "John Doe was here."

In [14]:
%pip install fvcore torch

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

def bert_input(token_embeddings):
    """
    Helper function to turn tokenizer output into positional arguments for BERT.

    The model's forward method takes, in order:
    input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, ...

    Args:
        token_embeddings: Tokenizer output (dict-like) containing 'input_ids', 'attention_mask', and optionally 'token_type_ids'.

    Returns:
        Tuple of tensors (input_ids, attention_mask), extended with token_type_ids when present.
    """
    input_ids = token_embeddings["input_ids"]
    attention_mask = token_embeddings["attention_mask"]
    token_type_ids = token_embeddings.get("token_type_ids", None)

    # Build the positional arguments in the order of the model.forward signature.
    # token_type_ids is appended exactly once; a second copy would land in the position_ids slot.
    args = (input_ids, attention_mask)
    if token_type_ids is not None:
        args += (token_type_ids,)

    return args

def measure_flops(model, inputs):
    """
    Helper function to count FLOPS for one forward pass of a model.

    Args:
        model: The model to evaluate
        inputs: A tensor or tuple of tensors passed positionally to the model's forward method

    Returns:
        FlopCountAnalysis: An object containing the FLOPS count and other details
    """

    # Use fvcore to count FLOPS
    with torch.no_grad():
        flops = FlopCountAnalysis(model, inputs)
    return flops

## 1.1. Evaluating average token lengths for the OntoNotes test split

In [29]:
%pip install datasets

Note: you may need to restart the kernel to use updated packages.


In [None]:
from datasets import load_dataset

# Load the English portion of OntoNotes 5.0
ontonotes = load_dataset(
    "conll2012_ontonotesv5",
    "english_v12",
    cache_dir="./dataset/ontonotes",
)


In [43]:
# Evaluating number of prefixes and average token length per prefix
def evaluate_prefixes(dataset):
    """
    Evaluates the number of prefixes and average token length per prefix in the dataset.
    
    Args:
        dataset: List of documents (dictionary) containing 'sentences' (list of dictionaries) containing 'words' (list of strings).

    Returns:
        Tuple of (number of prefixes, average token length)
    """
    num_prefixes = 0
    total_prefix_token_length = 0
    for doc in dataset:
        for sent in doc['sentences']:
            num_tokens = len(sent['words'])
            num_prefixes += num_tokens
            total_prefix_token_length += num_tokens * (num_tokens + 1) // 2
    avg_prefix_length = total_prefix_token_length / num_prefixes if num_prefixes > 0 else 0
    return num_prefixes, avg_prefix_length

# Evaluate prefixes in the OntoNotes dataset
num_prefixes, avg_prefix_length = evaluate_prefixes(ontonotes['test'])
print(f"Number of prefixes: {num_prefixes:,}; Average token length: {avg_prefix_length:.2f}")


Number of prefixes: 230,118; Average token length: 14.64
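
The closed form in `evaluate_prefixes` relies on the fact that a sentence of n words has n prefixes of lengths 1 to n, whose lengths sum to n(n+1)/2. A minimal sketch that checks this against explicit enumeration, using made-up sentence lengths (`toy_lengths` is hypothetical and not taken from OntoNotes):

```python
def prefix_stats_closed_form(sentence_lengths):
    """Number of prefixes and mean prefix length via n * (n + 1) / 2."""
    num_prefixes = sum(sentence_lengths)
    total_length = sum(n * (n + 1) // 2 for n in sentence_lengths)
    avg = total_length / num_prefixes if num_prefixes else 0.0
    return num_prefixes, avg


def prefix_stats_brute_force(sentence_lengths):
    """Same statistics, listing every prefix length explicitly."""
    lengths = [k for n in sentence_lengths for k in range(1, n + 1)]
    avg = sum(lengths) / len(lengths) if lengths else 0.0
    return len(lengths), avg


toy_lengths = [3, 5, 1]  # hypothetical sentence lengths, in words
print(prefix_stats_closed_form(toy_lengths))
print(prefix_stats_brute_force(toy_lengths))
```

Lengths here are counted in words, as in `evaluate_prefixes`, so the word-piece length seen by BERT may be larger.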


In [44]:
# Evaluate number of windows and average window length
# See precompute_fixed_spans_and_labels.ipynb for the code that counts the windows
num_windows, avg_window_length = 224128, 6
print(f"Number of windows: {num_windows:,}; Average window length: {avg_window_length:.2f}")

Number of windows: 224,128; Average window length: 6.00


## 2. Evaluating FLOPS of the baseline BERT model

In [16]:
%pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [21]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Load pre-trained NER model
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.eval()  # Set to evaluation mode
print(model.forward.__code__.co_varnames)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('self', 'input_ids', 'attention_mask', 'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds', 'labels', 'output_attentions', 'output_hidden_states', 'return_dict', 'outputs', 'sequence_output', 'logits', 'loss', 'loss_fct', 'output')


In [53]:
# Tokenize both texts
long_tokens = tokenizer(long_ner_text, return_tensors="pt")
small_tokens = tokenizer(small_ner_text, return_tensors="pt")
print(f"Long text tokens: {long_tokens['input_ids'].shape}")
print(f"Small text tokens: {small_tokens['input_ids'].shape}")

# Evaluate FLOPS for long text given its input_ids and attention_mask
long_flops = measure_flops(model, bert_input(long_tokens))

# Evaluate FLOPS for small text
small_flops = measure_flops(model, bert_input(small_tokens))

print(f"\nFLOPS Results:")
print(f"Long text FLOPS: {long_flops.total():,}")
print(f"Long text FLOPS/token: {long_flops.total() / long_tokens['input_ids'].numel():,.2f}")
print("="*60)
print(f"Small text FLOPS: {small_flops.total():,}")
print(f"Small text FLOPS/token: {small_flops.total() / small_tokens['input_ids'].numel():,.2f}")
print("="*60)

Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bert.encoder.layer.0.attention.self.dropout, bert.encoder.layer.1.attention.self.dropout, bert.encoder.layer.10.attention.self.dropout, bert.encoder.layer.11.attention.self.dropout, bert.encoder.layer.2.attention.self.dropout, bert.encoder.layer.3.attention.self.dropout, bert.encoder.layer.4.attention.self.dropout, bert.e

Long text tokens: torch.Size([1, 167])
Small text tokens: torch.Size([1, 8])

FLOPS Results:
Long text FLOPS: 14,201,273,856
Long text FLOPS/token: 85,037,568.00


Small text FLOPS: 680,300,544
Small text FLOPS/token: 85,037,568.00
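
Both texts report exactly the same FLOPS/token, which suggests the counted cost is linear in sequence length. That would be consistent with the `aten::scaled_dot_product_attention` warning above: an unsupported operator is skipped, so the attention-score computation likely contributes nothing to these totals. The sketch below rebuilds the per-token figure from the BERT-base dimensions. It assumes one multiply-accumulate counts as one FLOP, a label set of 9 classes, and a cost of 5 per element for each LayerNorm; these constants are assumptions chosen to match the reported number, not values read from the notebook.

```python
# Per-token operation counts for a BERT-base token classifier.
HIDDEN = 768         # BERT-base hidden size
INTERMEDIATE = 3072  # BERT-base feed-forward size
LAYERS = 12
NUM_LABELS = 9       # assumption: size of the classification head
NORM_COST = 5        # assumption: cost charged per element of an affine LayerNorm

attention_proj = 4 * HIDDEN * HIDDEN                 # query, key, value and output projections
feed_forward = 2 * HIDDEN * INTERMEDIATE             # up- and down-projection
layer_norms = (2 * LAYERS + 1) * HIDDEN * NORM_COST  # two per layer plus the embedding norm
classifier = HIDDEN * NUM_LABELS

per_token = LAYERS * (attention_proj + feed_forward) + layer_norms + classifier
print(f"Per-token FLOPS: {per_token:,}")
print(f"167 tokens: {per_token * 167:,}")
print(f"8 tokens: {per_token * 8:,}")
```

If the attention operator were counted, a term growing with the square of the sequence length would appear on top of this linear part.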


In [26]:
import numpy as np
data = np.load('data/ontonotes_embeddings_test.npz')
X = data['X']

# Inspect the shape of the embedding matrix X: (number of embeddings, hidden size)
print(f"Shape of embeddings in X: {X.shape}")

Shape of embeddings in X: (72860, 768)


In [None]:
# Evaluate a simple regression model for the flops
lengths_to_sample = [4, 8, 16, 32, 64, 128, 256]
model_1_flops_per_length = {}

for L in lengths_to_sample:
    tokens = tokenizer(" ".join(["hello"] * L), return_tensors="pt")
    flops = FlopCountAnalysis(model, bert_input(tokens)).total()
    model_1_flops_per_length[L] = flops

# Convert to arrays
X = torch.tensor(list(model_1_flops_per_length.keys()), dtype=torch.float32).unsqueeze(1)
y = torch.tensor(list(model_1_flops_per_length.values()), dtype=torch.float32)

# Fit FLOPs = a * L^2 + b
model_1_X_squared = X ** 2
model_1_X_poly = torch.cat([model_1_X_squared, torch.ones_like(X)], dim=1)
model_1_coeffs = torch.linalg.lstsq(model_1_X_poly, y).solution  # returns [a, b]
print(f"FLOPS = {model_1_coeffs[0].item():.2f} * L^2 + {model_1_coeffs[1].item():.2f}")

FLOPS = 312630.66 * L^2 + 2439028736.00
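
The large constant term hints that the model `a * L^2 + b` may be misspecified for these counts. As a hedged check, the sketch below fits the same form to purely linear data, `85,037,568 * (L + 2)`, and lands very close to the coefficients printed above. It assumes that "hello" maps to a single word piece and that the tokenizer adds two special tokens; `fit_quadratic_plus_constant` is a hypothetical pure-Python stand-in for the `torch.linalg.lstsq` call.

```python
def fit_quadratic_plus_constant(lengths, flops):
    """Least-squares fit of flops = a * L**2 + b (simple regression on x = L**2)."""
    xs = [length ** 2 for length in lengths]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(flops) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, flops))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    a = sxy / sxx
    b = mean_y - a * mean_x
    return a, b


PER_TOKEN = 85_037_568  # FLOPS/token reported for both texts in this section
SPECIAL_TOKENS = 2      # assumption: [CLS] and [SEP] are added to every input

lengths = [4, 8, 16, 32, 64, 128, 256]
linear_flops = [PER_TOKEN * (length + SPECIAL_TOKENS) for length in lengths]

a, b = fit_quadratic_plus_constant(lengths, linear_flops)
print(f"FLOPS = {a:.2f} * L^2 + {b:.2f}")
```

Under these assumptions a fit of the form `a * L + b` would describe the counted FLOPS exactly, so the quadratic coefficient above should not be read as the cost of attention.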


## 3. Evaluating FLOPS for our prefix-evaluating MLP (approach 1)

In [49]:
import sys
sys.path.append("./src/")
from confidence_model import confidence_model
from transformers import AutoModel

In [50]:
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")
bert_model.eval()
# Note: this rebinds the name `confidence_model` from the imported class to an instance
confidence_model = confidence_model()
confidence_model.eval()

confidence_model(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [52]:
# Evaluate FLOPS for long text given its input_ids and attention_mask
model_1_long_flops_prepare = measure_flops(bert_model, bert_input(long_tokens))
model_1_long_tokens_cls = bert_model(**long_tokens).last_hidden_state[:, 0, :]
model_1_long_flops_run = measure_flops(confidence_model, model_1_long_tokens_cls)
model_1_long_flops = model_1_long_flops_prepare.total() + model_1_long_flops_run.total()

# Evaluate FLOPS for small text
model_1_small_flops_prepare = measure_flops(bert_model, bert_input(small_tokens))
model_1_small_tokens_cls = bert_model(**small_tokens).last_hidden_state[:, 0, :]
model_1_small_flops_run = measure_flops(confidence_model, model_1_small_tokens_cls)
model_1_small_flops = model_1_small_flops_prepare.total() + model_1_small_flops_run.total()

# These shapes are those of the [CLS] embeddings fed to the MLP, not of the token IDs
print(f"Long text tokens: {model_1_long_tokens_cls.shape}")
print(f"Small text tokens: {model_1_small_tokens_cls.shape}")

print(f"\nFLOPS Results:")
print(f"Long text FLOPS CLS: {model_1_long_flops_prepare.total():,}")
print(f"Long text FLOPS Model: {model_1_long_flops_run.total():,}")
print(f"Long text FLOPS Total: {model_1_long_flops:,}")
# Note: the denominator is the size of the [CLS] embedding (768), not the number of input tokens
print(f"Long text FLOPS/token: {model_1_long_flops / model_1_long_tokens_cls.numel():,.2f}")
print("="*60)
print(f"Small text FLOPS CLS: {model_1_small_flops_prepare.total():,}")
print(f"Small text FLOPS Model: {model_1_small_flops_run.total():,}")
print(f"Small text FLOPS Total: {model_1_small_flops:,}")
# Note: as above, this divides by 768 rather than by the token count
print(f"Small text FLOPS/token: {model_1_small_flops / model_1_small_tokens_cls.numel():,.2f}")
print("="*60)

Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::tanh encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
encoder.layer.0.attention.self.dropout, encoder.layer.1.attention.self.dropout, encoder.layer.10.attention.self.dropout, encoder.layer.11.attention.self.dropout, encoder.layer.2.attention.self.dropout, encoder.layer.3.attention.self.dropout, encoder.layer.4.attention.s

Long text tokens: torch.Size([1, 768])
Small text tokens: torch.Size([1, 768])

FLOPS Results:
Long text FLOPS CLS: 14,200,709,376
Long text FLOPS Model: 196,864
Long text FLOPS Total: 14,200,906,240
Long text FLOPS/token: 18,490,763.33
Small text FLOPS CLS: 680,835,072
Small text FLOPS Model: 196,864
Small text FLOPS Total: 681,031,936
Small text FLOPS/token: 886,760.33
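
The split between the encoder pass and the MLP head can be reproduced by hand, which makes it easier to see where the cost goes. The sketch below starts from the FLOPS/token measured in section 2, removes the token-classification head (assumed to have 9 labels), adds one 768 x 768 pooler layer (assumed to run once per call in `AutoModel`), and counts the two `Linear` layers of the printed `confidence_model`. `encoder_only_flops` is a hypothetical helper, not part of the notebook.

```python
HIDDEN = 768
MLP_HIDDEN = 256                  # hidden width shown in the printed confidence_model
NUM_LABELS = 9                    # assumption: label count of the token-classification head
TOKEN_CLS_PER_TOKEN = 85_037_568  # FLOPS/token measured in section 2


def encoder_only_flops(num_tokens):
    """Token-classifier count minus its head, plus the pooler dense layer."""
    without_head = (TOKEN_CLS_PER_TOKEN - HIDDEN * NUM_LABELS) * num_tokens
    return without_head + HIDDEN * HIDDEN


# Two Linear layers (768 -> 256 -> 1), one FLOP per multiply-accumulate
mlp_flops = HIDDEN * MLP_HIDDEN + MLP_HIDDEN * 1

for num_tokens in (167, 8):
    encoder = encoder_only_flops(num_tokens)
    share = mlp_flops / (encoder + mlp_flops)
    print(f"{num_tokens} tokens: encoder {encoder:,}, MLP {mlp_flops:,}, MLP share {share:.4%}")
```

Under these assumptions the MLP accounts for well under 0.1% of the total even for the 8-token input, so the cost of approach 1 is dominated by the BERT forward pass.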


In [55]:
# Evaluate a simple regression model for the flops
lengths_to_sample = [4, 8, 16, 32, 64, 128, 256]
model_1_flops_per_length = {}

for L in lengths_to_sample:
    tokens = tokenizer(" ".join(["hello"] * L), return_tensors="pt")
    cls_token = bert_model(**tokens).last_hidden_state[:, 0, :]
    flops_prepare = measure_flops(bert_model, bert_input(tokens)).total()
    flops_run = measure_flops(confidence_model, cls_token).total()
    flops = flops_prepare + flops_run
    model_1_flops_per_length[L] = flops

# Convert to arrays
X = torch.tensor(list(model_1_flops_per_length.keys()), dtype=torch.float32).unsqueeze(1)
y = torch.tensor(list(model_1_flops_per_length.values()), dtype=torch.float32)

# Fit FLOPs = a * L^2 + b
model_1_X_squared = X ** 2
model_1_X_poly = torch.cat([model_1_X_squared, torch.ones_like(X)], dim=1)
model_1_coeffs = torch.linalg.lstsq(model_1_X_poly, y).solution  # returns [a, b]
print(f"FLOPS = {model_1_coeffs[0].item():.2f} * L^2 + {model_1_coeffs[1].item():.2f}")

FLOPS = 312605.25 * L^2 + 2439616512.00


## 4. Evaluating FLOPS for our sliding-window MLP (approach 2)

In [57]:
from window_slide_model import WindowSlideModel
window_model = WindowSlideModel(input_dim=768)

In [58]:
# Use the first six whitespace-separated words of the long text as one window
model_2_text = " ".join(long_ner_text.split(" ")[0:6])
model_2_tokens = tokenizer(model_2_text, return_tensors="pt")
model_2_long_flops_prepare = measure_flops(bert_model, bert_input(model_2_tokens))
model_2_long_tokens_cls = bert_model(**model_2_tokens).last_hidden_state[:, 0, :]
model_2_long_flops_run = measure_flops(window_model, model_2_long_tokens_cls)
model_2_long_flops = model_2_long_flops_prepare.total() + model_2_long_flops_run.total()

# This shape is that of the [CLS] embedding fed to the window model, not of the token IDs
print(f"Long text tokens: {model_2_long_tokens_cls.shape}")

print(f"\nFLOPS Results:")
print(f"Model 2 FLOPS CLS: {model_2_long_flops_prepare.total():,}")
print(f"Model 2 FLOPS Model: {model_2_long_flops_run.total():,}")
print(f"Model 2 FLOPS Total: {model_2_long_flops:,}")
# Note: the denominator is the size of the [CLS] embedding (768), not the number of input tokens
print(f"Model 2 FLOPS/token: {model_2_long_flops / model_2_long_tokens_cls.numel():,.2f}")

Long text tokens: torch.Size([1, 768])

FLOPS Results:
Model 2 FLOPS CLS: 680,835,072
Model 2 FLOPS Model: 98,432
Model 2 FLOPS Total: 680,933,504
Model 2 FLOPS/token: 886,632.17
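
One way to combine these measurements with the prefix and window statistics from section 1.1 is a rough estimate of total cost over the OntoNotes test split. The sketch below is only an estimate under stated assumptions: the counted cost is treated as linear in sequence length, lengths are in words with two special tokens added and no word-piece splitting, the classifier head is assumed to have 9 labels, and the pooler is assumed to run once per call. `total_flops` is a hypothetical helper.

```python
HIDDEN = 768
ENCODER_PER_TOKEN = 85_037_568 - HIDDEN * 9  # section 2 figure minus the assumed 9-label head
POOLER = HIDDEN * HIDDEN                     # assumption: pooler dense layer, run once per call
SPECIAL_TOKENS = 2                           # assumption: [CLS] and [SEP], no word-piece splitting


def total_flops(num_calls, avg_words, head_flops):
    """Total FLOPS for num_calls encoder passes of avg_words words each, plus an MLP head."""
    per_call = ENCODER_PER_TOKEN * (avg_words + SPECIAL_TOKENS) + POOLER + head_flops
    return num_calls * per_call


# Counts and average lengths from section 1.1; head costs from sections 3 and 4
prefix_total = total_flops(230_118, 14.64, 196_864)  # approach 1
window_total = total_flops(224_128, 6, 98_432)       # approach 2

print(f"Approach 1 (prefixes): {prefix_total:.3e} FLOPS")
print(f"Approach 2 (windows):  {window_total:.3e} FLOPS")
print(f"Ratio: {prefix_total / window_total:.2f}x")
```

Because the assumed cost is linear in length, multiplying by the average length gives the same total as summing over individual prefixes or windows. For a single six-word window the helper returns the "Model 2 FLOPS Total" printed above.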
