## 1. Preparation

In [1]:
long_ner_text = """
The Old Zurich War was a conflict between the canton of Zurich and the other seven cantons of the Old Swiss Confederacy over the succession to the Count of Toggenburg.
In 1436, Count Friedrich VII of Toggenburg died, leaving neither heir nor will.
The canton of Zurich, led by burgomaster Rudolf Stüssi, claimed the Toggenburg lands; the cantons of Schwyz and Glarus made counter-claims, backed by the other cantons.
In 1438, Zurich occupied the disputed area and cut off grain supplies to Schwyz and Glarus. In 1440, the other cantons expelled Zurich from the confederation and declared war.
Zurich retaliated by making an alliance with Frederick III, Holy Roman Emperor of the house of Habsburg.

The forces of Zurich were defeated in the Battle of St. Jakob an der Sihl on 22 July 1443 and Zurich was besieged.
Frederick appealed to Charles VII of France to attack the confederates and the latter sent a force of about 30,000 Armagnac mercenaries under the command of the Dauphin via Basel to relieve the city.
In the Battle of St. Jakob an der Birs near Basel on 26 August 1444, a blocking force of roughly 1,600 Swiss confederates was defeated, but inflicted such heavy losses on the French (2,000 killed) that the Dauphin decided to retreat.
The confederacy and the Dauphin concluded a peace in October 1444, and his mercenary army withdrew from the war altogether.

In May 1444, the confederacy laid siege to Greifensee, and captured the town after four weeks, on May 27, beheading all but two of the 64 defenders the next day, including their leader, Wildhans von Breitenlandenberg, the so-called Murder of Greifensee.
Even in this time of war, such a mass execution was widely considered a cruel and unjust deed.
"""

small_ner_text = "John Doe was here."

In [2]:
%pip install fvcore torch

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

def bert_input(token_embeddings):
    """
    Helper function to turn tokenizer output into positional BERT inputs for fvcore.
    
    Args:
        token_embeddings: Tokenizer output (BatchEncoding or dict) with 'input_ids', 'attention_mask', and optionally 'token_type_ids'.
    
    Returns:
        Tuple (input_ids, attention_mask, token_type_ids), in the order of BERT's forward() parameters.
    """
    # Convert BatchEncoding to tuple of tensors
    input_ids = token_embeddings["input_ids"]
    attention_mask = token_embeddings["attention_mask"]
    
    # Check if token_type_ids exists, if not create zeros
    if "token_type_ids" in token_embeddings and token_embeddings["token_type_ids"] is not None:
        token_type_ids = token_embeddings["token_type_ids"]
    else:
        token_type_ids = torch.zeros_like(input_ids)
    
    return (input_ids, attention_mask, token_type_ids)

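def measure_flops(model, inputs):
    """
    Helper function to measure the FLOP count of a single forward pass with fvcore.
    
    Note: fvcore counts one fused multiply-add as one FLOP, and operators without a
    registered handle (reported as "Unsupported operator") are counted as zero.
    
    Args:
        model: The model to evaluate (an nn.Module)
        inputs: The positional inputs for the model (a tensor or tuple of tensors)

    Returns:
        FlopCountAnalysis: An object containing the FLOP count and other details
    """
    # FlopCountAnalysis is lazy: the model is traced on the first call to .total()
    # (or .by_module(), etc.), which in this notebook happens outside this no_grad block.
    with torch.no_grad():
        flops = FlopCountAnalysis(model, inputs)
    return flops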
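The "Unsupported operator aten::scaled_dot_product_attention" warnings printed by the measurements in this notebook mean that fvcore assigns zero FLOPs to the fused attention kernel, so the Q·Kᵀ and weights·V products never enter the totals. fvcore lets you register extra operator handles with `FlopCountAnalysis.set_op_handle`. Below is a minimal sketch of such a handle; `sdpa_macs` and `make_sdpa_handle` are hypothetical names, and the sketch assumes query/key/value tensors shaped (batch, heads, seq_len, head_dim), counts one multiply-add as one FLOP to match fvcore's convention, and ignores the softmax and scaling.

```python
from math import prod


def sdpa_macs(q_shape, k_shape, v_shape):
    """Multiply-adds for softmax(Q K^T) V, ignoring softmax and scaling.

    Assumes Q is (..., L_q, d), K is (..., L_k, d) and V is (..., L_k, d_v),
    where the leading "..." batch/head dimensions are shared by all three.
    """
    *batch_dims, l_q, d = q_shape
    l_k = k_shape[-2]
    d_v = v_shape[-1]
    batch = prod(batch_dims)            # prod([]) == 1 for unbatched inputs
    scores = batch * l_q * l_k * d      # Q @ K^T
    weighted = batch * l_q * l_k * d_v  # attention weights @ V
    return scores + weighted


def make_sdpa_handle():
    """Build a (hypothetical) fvcore handle for aten::scaled_dot_product_attention."""
    # Imported lazily so that sdpa_macs itself has no third-party dependency.
    from fvcore.nn.jit_handles import get_shape

    def handle(inputs, outputs):
        # The first three JIT inputs of the op are query, key and value.
        q, k, v = (get_shape(x) for x in inputs[:3])
        return sdpa_macs(q, k, v)

    return handle


# Usage inside the notebook (not executed here):
# flops = FlopCountAnalysis(model, bert_input(long_tokens))
# flops.set_op_handle("aten::scaled_dot_product_attention", make_sdpa_handle())
# flops.total()
```

For BERT-base (12 layers, 12 heads, head dimension 64) this would add 2 · 12 · 64 · L² = 1,536 L² per layer, or 18,432 L² over the whole encoder, which is the quadratic term that the polynomial fits in this notebook cannot see without such a handle.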

## 1.1. Evaluating average token lengths for the OntoNotes test split

In [4]:
%pip install datasets

Note: you may need to restart the kernel to use updated packages.


In [5]:
from datasets import load_dataset

# Load the English portion of OntoNotes 5.0
ontonotes = load_dataset(
    "conll2012_ontonotesv5",
    "english_v12",
    cache_dir="./dataset/ontonotes",
)



In [6]:
# Evaluating number of prefixes and average token length per prefix
def evaluate_prefixes(dataset):
    """
    Evaluates the number of prefixes and average token length per prefix in the dataset.
    
    Args:
        dataset: List of documents (dictionary) containing 'sentences' (list of dictionaries) containing 'words' (list of strings).

    Returns:
        Tuple of (number of prefixes, average token length)
    """
    num_prefixes = 0
    total_prefix_token_length = 0
    for doc in dataset:
        for sent in doc['sentences']:
            num_tokens = len(sent['words'])
            num_prefixes += num_tokens
            total_prefix_token_length += num_tokens * (num_tokens + 1) // 2
    avg_prefix_length = total_prefix_token_length / num_prefixes if num_prefixes > 0 else 0
    return num_prefixes, avg_prefix_length

# Evaluate prefixes in the OntoNotes dataset
num_prefixes, avg_prefix_length = evaluate_prefixes(ontonotes['test'])
print(f"Number of prefixes: {num_prefixes:,}; Average token length: {avg_prefix_length:.2f}")


Number of prefixes: 230,118; Average token length: 14.64
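`evaluate_prefixes` relies on the fact that a sentence of n words yields n prefixes of lengths 1, 2, ..., n, whose lengths sum to n(n+1)/2. A small self-contained check of that closed form against brute-force enumeration, using made-up sentence lengths rather than OntoNotes data (both function names are hypothetical):

```python
def prefix_stats_bruteforce(sentence_lengths):
    """Enumerate every prefix explicitly and average their lengths."""
    prefix_lengths = [k for n in sentence_lengths for k in range(1, n + 1)]
    if not prefix_lengths:
        return 0, 0
    return len(prefix_lengths), sum(prefix_lengths) / len(prefix_lengths)


def prefix_stats_closed_form(sentence_lengths):
    """Same statistics using n prefixes per sentence and the triangular-number sum."""
    num_prefixes = sum(sentence_lengths)
    total_length = sum(n * (n + 1) // 2 for n in sentence_lengths)
    return num_prefixes, (total_length / num_prefixes if num_prefixes else 0)


toy_lengths = [3, 5, 1]  # hypothetical sentence lengths, in words
print(prefix_stats_bruteforce(toy_lengths))
print(prefix_stats_closed_form(toy_lengths))
```

Every prefix is weighted equally, so long sentences dominate the average: a sentence of n words contributes n prefixes with a mean length of (n+1)/2.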


In [7]:
# Evaluate number of windows and average window length
# See precompute_fixed_spans_and_labels.ipynb for the code that counts the windows
num_windows, avg_window_length = 224128, 6
print(f"Number of windows: {num_windows:,}; Average window length: {avg_window_length:.2f}")

Number of windows: 224,128; Average window length: 6.00
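Multiplying each count by its average length gives a rough total of word positions that each approach feeds to the encoder. A back-of-the-envelope sketch using the rounded values printed above, assuming every prefix and window is encoded independently (no reuse across prefixes) and ignoring special tokens and WordPiece splitting; the variable names are new so they do not overwrite `num_prefixes` or `num_windows`:

```python
# Values copied from the two outputs above (the prefix average is rounded to 2 decimals).
n_prefixes, mean_prefix_len = 230_118, 14.64
n_windows, mean_window_len = 224_128, 6

prefix_positions = n_prefixes * mean_prefix_len  # approx. total words encoded over all prefixes
window_positions = n_windows * mean_window_len   # total words encoded over all windows
ratio = prefix_positions / window_positions

print(f"Prefix positions: {prefix_positions:,.0f}")
print(f"Window positions: {window_positions:,.0f}")
print(f"Ratio (prefixes / windows): {ratio:.2f}")
```

If encoder cost were linear in length, the prefix approach would process roughly 2.5 times as many positions; a quadratic attention cost would widen that gap, since prefixes are longer on average.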


## 2. Evaluating FLOPS of the baseline BERT model

In [8]:
%pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [9]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Load pre-trained NER model
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.eval()  # Set to evaluation mode
# Inspect forward()'s argument order (co_varnames also lists its local variables)
print(model.forward.__code__.co_varnames)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('self', 'input_ids', 'attention_mask', 'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds', 'labels', 'output_attentions', 'output_hidden_states', 'return_dict', 'outputs', 'sequence_output', 'logits', 'loss', 'loss_fct', 'output')
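`FlopCountAnalysis` passes the input tuple to `forward()` positionally, so `bert_input` only works because `input_ids`, `attention_mask` and `token_type_ids` are the first three parameters, as the tuple above shows; note that `co_varnames` also lists locals such as `outputs`, `sequence_output` and `logits`. A hypothetical helper (not part of transformers or fvcore) that reads only the parameters via `inspect.signature`; unlike `bert_input`, it does not synthesize zero `token_type_ids`:

```python
import inspect


def to_positional_inputs(forward, batch, default=None):
    """Order the entries of `batch` to match the parameters of `forward` (hypothetical helper).

    Parameters are taken in signature order up to the last one present in `batch`;
    any gap before that point is filled with `default`.
    """
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    names = [
        name
        for name, param in inspect.signature(forward).parameters.items()
        if name != "self" and param.kind in positional_kinds
    ]
    present = [i for i, name in enumerate(names) if batch.get(name) is not None]
    if not present:
        return ()
    return tuple(batch.get(name, default) for name in names[: present[-1] + 1])


# Stand-in with the same leading parameters as the BERT forward() shown above.
def dummy_forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None):
    pass


print(to_positional_inputs(dummy_forward, {"attention_mask": "mask", "input_ids": "ids"}))
```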


In [10]:
# Tokenize both texts
long_tokens = tokenizer(long_ner_text, return_tensors="pt", padding=False, truncation=False)
small_tokens = tokenizer(small_ner_text, return_tensors="pt", padding=False, truncation=False)
print(f"Long text tokens: {long_tokens['input_ids'].shape}")
print(f"Small text tokens: {small_tokens['input_ids'].shape}")

# Evaluate FLOPS for the long text given its input_ids and attention_mask
long_flops = measure_flops(model, bert_input(long_tokens))

# Evaluate FLOPS for small text
small_flops = measure_flops(model, bert_input(small_tokens))

print(f"\nFLOPS Results:")
print(f"Long text FLOPS: {long_flops.total():,}")
print(f"Long text FLOPS/token: {long_flops.total() / long_tokens['input_ids'].numel():,.2f}")
print("="*60)
print(f"Small text FLOPS: {small_flops.total():,}")
print(f"Small text FLOPS/token: {small_flops.total() / small_tokens['input_ids'].numel():,.2f}")
print("="*60)

Long text tokens: torch.Size([1, 418])
Small text tokens: torch.Size([1, 8])

FLOPS Results:


Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bert.encoder.layer.0.attention.self.dropout, bert.encoder.layer.1.attention.self.dropout, bert.encoder.layer.10.attention.self.dropout, bert.encoder.layer.11.attention.self.dropout, bert.encoder.layer.2.attention.self.dropout, bert.encoder.layer.3.attention.self.dropout, bert.encoder.layer.4.attention.self.dropout, bert.e

Long text FLOPS: 35,545,703,424
Long text FLOPS/token: 85,037,568.00
Small text FLOPS: 680,300,544
Small text FLOPS/token: 85,037,568.00
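The FLOPS/token figure is identical for both texts, which suggests that every operation fvcore counted here is per-token. The sketch below reconstructs 85,037,568 from BERT-base dimensions under three assumptions: fvcore counts one multiply-add as one FLOP for linear layers, it estimates an affine LayerNorm at 5 FLOPs per element, and dslim/bert-base-NER has 9 output labels. No attention-score term appears, consistent with the unsupported `aten::scaled_dot_product_attention` warnings; embedding lookups, GELU and additions are likewise uncounted.

```python
# BERT-base dimensions (assumed from the standard bert-base configuration).
hidden, intermediate, layers, num_labels = 768, 3072, 12, 9

# Multiply-adds per token for one encoder layer's linear projections.
attention_proj = 4 * hidden * hidden      # Q, K, V and attention-output projections
feed_forward = 2 * hidden * intermediate  # feed-forward up- and down-projection
per_layer = attention_proj + feed_forward

# Assumed fvcore LayerNorm estimate: 5 FLOPs per element when an affine weight is present.
layer_norms = (1 + 2 * layers) * 5 * hidden  # embedding LayerNorm + two per encoder layer

classifier = hidden * num_labels  # token-classification head

per_token = layers * per_layer + layer_norms + classifier
print(f"{per_token:,}")
```

This evaluates to 85,037,568, and multiplying it by 418 and by 8 tokens reproduces the two recorded totals exactly, so these counts are effectively a lower bound on BERT's actual cost.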


In [11]:
import numpy as np
data = np.load('data/ontonotes_embeddings_test.npz')
# One 768-dimensional embedding per row. Named ontonotes_X because the next cell
# reassigns bert_X to the sampled sequence lengths, which would overwrite it.
ontonotes_X = data['X']

# Report the shape (num_embeddings, hidden_size) of the embedding matrix
print(f"Shape of embeddings in X: {ontonotes_X.shape}")

Shape of embeddings in X: (72860, 768)


In [12]:
# Evaluate a more comprehensive regression model for the flops
bert_lengths_to_sample = [4, 8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256]
bert_tokens = tokenizer(long_ner_text, return_tensors="pt", padding=False, truncation=False)
bert_flops_per_length = {}

print("Measuring BERT FLOPS for different sequence lengths...")
for L in bert_lengths_to_sample:
    # Create tokens of exact length L
    tokens = {k: v[:, :L] if v.shape[1] > L else v for k, v in bert_tokens.items()}
    
    # Pad if necessary
    if tokens["input_ids"].shape[1] < L:
        pad_length = L - tokens["input_ids"].shape[1]
        tokens["input_ids"] = torch.cat([tokens["input_ids"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
        tokens["attention_mask"] = torch.cat([tokens["attention_mask"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
        if "token_type_ids" in tokens:
            tokens["token_type_ids"] = torch.cat([tokens["token_type_ids"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
    
    assert tokens["input_ids"].shape[1] == L, f"Expected {L} tokens, got {tokens['input_ids'].shape[1]}"
    
    flops = FlopCountAnalysis(model, bert_input(tokens)).total()
    bert_flops_per_length[L] = flops
    print(f"L={L}: {flops:,} FLOPS")

# Convert to arrays
bert_X = torch.tensor(list(bert_flops_per_length.keys()), dtype=torch.float32).unsqueeze(1)
bert_y = torch.tensor(list(bert_flops_per_length.values()), dtype=torch.float32)

# Fit more comprehensive polynomial: FLOPs = a * L^2 + b * L + c
bert_X_poly = torch.cat([bert_X**2, bert_X, torch.ones_like(bert_X)], dim=1)
bert_coeffs = torch.linalg.lstsq(bert_X_poly, bert_y).solution  # returns [a, b, c]
print(f"\nBERT FLOPS = {bert_coeffs[0].item():.2f} * L^2 + {bert_coeffs[1].item():.2f} * L + {bert_coeffs[2].item():.2f}")

# Calculate R-squared for validation
bert_y_pred = bert_X_poly @ bert_coeffs
bert_ss_res = torch.sum((bert_y - bert_y_pred) ** 2)
bert_ss_tot = torch.sum((bert_y - torch.mean(bert_y)) ** 2)
bert_r_squared = 1 - bert_ss_res / bert_ss_tot
print(f"R-squared: {bert_r_squared.item():.4f}")

Measuring BERT FLOPS for different sequence lengths...


Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bert.encoder.layer.0.attention.self.dropout, bert.encoder.layer.1.attention.self.dropout, bert.encoder.layer.10.attention.self.dropout, bert.encoder.layer.11.attention.self.dropout, bert.encoder.layer.2.attention.self.dropout, bert.encoder.layer.3.attention.self.dropout, bert.encoder.layer.4.attention.self.dropout, bert.e

L=4: 340,150,272 FLOPS
L=8: 680,300,544 FLOPS
L=12: 1,020,450,816 FLOPS
L=16: 1,360,601,088 FLOPS
L=24: 2,040,901,632 FLOPS
L=32: 2,721,202,176 FLOPS
L=48: 4,081,803,264 FLOPS
L=64: 5,442,404,352 FLOPS
L=96: 8,163,606,528 FLOPS
L=128: 10,884,808,704 FLOPS
L=192: 16,327,213,056 FLOPS
L=256: 21,769,617,408 FLOPS

BERT FLOPS = -0.03 * L^2 + 85037576.00 * L + 16.60
R-squared: 1.0000
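Every recorded value above equals exactly 85,037,568 × L (for example 340,150,272 = 85,037,568 × 4), so the -0.03 and 16.60 coefficients are most likely float32 round-off from `torch.linalg.lstsq` rather than real structure, and the near-zero L² term reflects the uncounted attention operator rather than BERT's actual scaling. A sketch that solves the same least-squares problem in exact rational arithmetic with Python's `fractions` module (a swapped-in technique, not what the cell above uses), on the recorded measurements:

```python
from fractions import Fraction

# Recorded fvcore totals from the cell above (sequence length -> FLOPS).
measured = {
    4: 340_150_272, 8: 680_300_544, 12: 1_020_450_816, 16: 1_360_601_088,
    24: 2_040_901_632, 32: 2_721_202_176, 48: 4_081_803_264, 64: 5_442_404_352,
    96: 8_163_606_528, 128: 10_884_808_704, 192: 16_327_213_056, 256: 21_769_617_408,
}


def exact_quadratic_fit(points):
    """Least-squares fit of y = a*L^2 + b*L + c via the normal equations, in exact arithmetic."""
    rows = [([Fraction(L) ** 2, Fraction(L), Fraction(1)], Fraction(y)) for L, y in points.items()]
    # Augmented normal-equation matrix [X^T X | X^T y].
    m = [
        [sum(x[i] * x[j] for x, _ in rows) for j in range(3)] + [sum(x[i] * y for x, y in rows)]
        for i in range(3)
    ]
    # Gauss-Jordan elimination; X^T X is nonsingular with at least 3 distinct lengths.
    for col in range(3):
        pivot = next(r for r in range(col, 3) if m[r][col] != 0)
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [v / p for v in m[col]]
        for r in range(3):
            if r != col:
                factor = m[r][col]
                m[r] = [a - factor * b for a, b in zip(m[r], m[col])]
    return tuple(m[i][3] for i in range(3))


a, b, c = exact_quadratic_fit(measured)
print(f"FLOPS = {a} * L^2 + {b} * L + {c}")
```

Within the notebook itself, building `bert_X` and `bert_y` with `dtype=torch.float64` would likely remove most of the round-off as well.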


## 3. Evaluating FLOPS for our prefix-evaluation MLP (approach 1)

In [13]:
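import sys
sys.path.append("./src/")
# Alias the class so the instance created in the next cell does not shadow it
# (re-running that cell would otherwise call the instance instead of the class).
from confidence_model import confidence_model as ConfidenceModel
from transformers import AutoModel

In [14]:
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")
bert_model.eval()
confidence_model = ConfidenceModel()
confidence_model.eval()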

confidence_model(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [15]:
# Evaluate FLOPS for the long text given its input_ids and attention_mask
moodel_1_long_flops_prepare = measure_flops(bert_model, bert_input(long_tokens))
with torch.no_grad():  # inference only; avoid building an autograd graph
    moodel_1_long_tokens_cls = bert_model(**long_tokens).last_hidden_state[:, 0, :]
moodel_1_long_flops_run = measure_flops(confidence_model, moodel_1_long_tokens_cls)
moodel_1_long_flops = moodel_1_long_flops_prepare.total() + moodel_1_long_flops_run.total()

# Evaluate FLOPS for the small text
moodel_1_small_flops_prepare = measure_flops(bert_model, bert_input(small_tokens))
with torch.no_grad():
    moodel_1_small_tokens_cls = bert_model(**small_tokens).last_hidden_state[:, 0, :]
moodel_1_small_flops_run = measure_flops(confidence_model, moodel_1_small_tokens_cls)
moodel_1_small_flops = moodel_1_small_flops_prepare.total() + moodel_1_small_flops_run.total()

print(f"Long text tokens: {moodel_1_long_tokens_cls.shape}")
print(f"Small text tokens: {moodel_1_small_tokens_cls.shape}")

print(f"\nFLOPS Results:")
print(f"Long text FLOPS CLS: {moodel_1_long_flops_prepare.total():,}")
print(f"Long text FLOPS Model: {moodel_1_long_flops_run.total():,}")
print(f"Long text FLOPS Total: {moodel_1_long_flops:,}")
print(f"Long text FLOPS/token: {moodel_1_long_flops / moodel_1_long_tokens_cls.numel():,.2f}")
print("="*60)
print(f"Small text FLOPS CLS: {moodel_1_small_flops_prepare.total():,}")
print(f"Small text FLOPS Model: {moodel_1_small_flops_run.total():,}")
print(f"Small text FLOPS Total: {moodel_1_small_flops:,}")
print(f"Small text FLOPS/token: {moodel_1_small_flops / moodel_1_small_tokens_cls.numel():,.2f}")
print("="*60)

Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::tanh encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
encoder.layer.0.attention.self.dropout, encoder.layer.1.attention.self.dropout, encoder.layer.10.attention.self.dropout, encoder.layer.11.attention.self.dropout, encoder.layer.2.attention.self.dropout, encoder.layer.3.attention.self.dropout, encoder.layer.4.attention.s

Long text tokens: torch.Size([1, 768])
Small text tokens: torch.Size([1, 768])

FLOPS Results:
Long text FLOPS CLS: 35,543,404,032
Long text FLOPS Model: 196,864
Long text FLOPS Total: 35,543,600,896
Long text FLOPS/token: 46,280,730.33
Small text FLOPS CLS: 680,835,072
Small text FLOPS Model: 196,864
Small text FLOPS Total: 681,031,936
Small text FLOPS/token: 886,760.33
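The "FLOPS Model" figure of 196,864 is exactly the multiply-add count of the two `Linear` layers in the `confidence_model` repr (768 to 256, then 256 to 1), with bias additions, ReLU and Dropout not counted. The BERT part at 8 tokens (680,835,072) also differs from Section 2's 680,300,544 by 534,528, which matches trading the 9-label token-classification head applied to 8 tokens for `BertModel`'s 768×768 pooler applied to the [CLS] vector; that pooler output is not used by this pipeline. A sketch of both checks, where `linear_macs` is a hypothetical helper and the counting convention is assumed to match fvcore's:

```python
def linear_macs(in_features, out_features, rows=1):
    """Multiply-adds for a dense layer applied to `rows` vectors (bias adds not counted)."""
    return rows * in_features * out_features


hidden = 768
mlp = linear_macs(hidden, 256) + linear_macs(256, 1)  # confidence_model's two Linear layers

# BertModel vs. BertForTokenClassification on the same 8-token input:
pooler = linear_macs(hidden, hidden)       # pooler dense layer on the [CLS] vector only
ner_head = linear_macs(hidden, 9, rows=8)  # 9-label classifier on all 8 tokens
print(mlp, pooler - ner_head)
```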


In [16]:
# Evaluate a more comprehensive regression model for Model 1
model_1_lengths_to_sample = [4, 8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256]
model_1_tokens = tokenizer(long_ner_text, return_tensors="pt", padding=False, truncation=False)
model_1_flops_per_length = {}

print("Measuring Model 1 FLOPS for different sequence lengths...")
for L in model_1_lengths_to_sample:
    # Create tokens of exact length L
    tokens = {k: v[:, :L] if v.shape[1] > L else v for k, v in model_1_tokens.items()}
    
    # Pad if necessary
    if tokens["input_ids"].shape[1] < L:
        pad_length = L - tokens["input_ids"].shape[1]
        tokens["input_ids"] = torch.cat([tokens["input_ids"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
        tokens["attention_mask"] = torch.cat([tokens["attention_mask"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
        if "token_type_ids" in tokens:
            tokens["token_type_ids"] = torch.cat([tokens["token_type_ids"], torch.zeros(1, pad_length, dtype=torch.long)], dim=1)
    
    # Measure BERT FLOPS
    with torch.no_grad():
        cls_token = bert_model(**tokens).last_hidden_state[:, 0, :]
    flops_prepare = measure_flops(bert_model, bert_input(tokens)).total()
    flops_run = measure_flops(confidence_model, cls_token).total()
    flops = flops_prepare + flops_run
    model_1_flops_per_length[L] = flops
    print(f"L={L}: {flops:,} FLOPS (BERT: {flops_prepare:,}, MLP: {flops_run:,})")

# Convert to arrays
model_1_X = torch.tensor(list(model_1_flops_per_length.keys()), dtype=torch.float32).unsqueeze(1)
model_1_y = torch.tensor(list(model_1_flops_per_length.values()), dtype=torch.float32)

# Fit comprehensive polynomial: FLOPs = a * L^2 + b * L + c
model_1_X_poly = torch.cat([model_1_X**2, model_1_X, torch.ones_like(model_1_X)], dim=1)
model_1_coeffs = torch.linalg.lstsq(model_1_X_poly, model_1_y).solution  # returns [a, b, c]
print(f"\nModel 1 FLOPS = {model_1_coeffs[0].item():.2f} * L^2 + {model_1_coeffs[1].item():.2f} * L + {model_1_coeffs[2].item():.2f}")

# Calculate R-squared for validation
model_1_y_pred = model_1_X_poly @ model_1_coeffs
model_1_ss_res = torch.sum((model_1_y - model_1_y_pred) ** 2)
model_1_ss_tot = torch.sum((model_1_y - torch.mean(model_1_y)) ** 2)
model_1_r_squared = 1 - model_1_ss_res / model_1_ss_tot
print(f"R-squared: {model_1_r_squared.item():.4f}")

Unsupported operator aten::add encountered 26 time(s)
Unsupported operator aten::embedding encountered 3 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::rsub encountered 1 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::tanh encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
encoder.layer.0.attention.self.dropout, encoder.layer.1.attention.self.dropout, encoder.layer.10.attention.self.dropout, encoder.layer.11.attention.self.dropout, encoder.layer.2.attention.self.dropout, encoder.layer.3.attention.self.dropout, encoder.layer.4.attention.s

Measuring Model 1 FLOPS for different sequence lengths...
L=4: 340,909,312 FLOPS (BERT: 340,712,448, MLP: 196,864)
L=8: 681,031,936 FLOPS (BERT: 680,835,072, MLP: 196,864)
L=12: 1,021,154,560 FLOPS (BERT: 1,020,957,696, MLP: 196,864)
L=16: 1,361,277,184 FLOPS (BERT: 1,361,080,320, MLP: 196,864)
L=24: 2,041,522,432 FLOPS (BERT: 2,041,325,568, MLP: 196,864)
L=32: 2,721,767,680 FLOPS (BERT: 2,721,570,816, MLP: 196,864)
L=48: 4,082,258,176 FLOPS (BERT: 4,082,061,312, MLP: 196,864)
L=64: 5,442,748,672 FLOPS (BERT: 5,442,551,808, MLP: 196,864)
L=96: 8,163,729,664 FLOPS (BERT: 8,163,532,800, MLP: 196,864)
L=128: 10,884,710,656 FLOPS (BERT: 10,884,513,792, MLP: 196,864)
L=192: 16,326,672,640 FLOPS (BERT: 16,326,475,776, MLP: 196,864)
L=256: 21,768,634,624 FLOPS (BERT: 21,768,437,760, MLP: 196,864)

Model 1 FLOPS = 0.03 * L^2 + 85030656.00 * L + 786886.00
R-squared: 1.0000
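The fitted coefficients can be cross-checked by hand. The sketch below assumes the encoder is BERT-base (12 layers, hidden size 768, FFN size 3072, 25 LayerNorms counted at 5 FLOPs per element, a 768×768 pooler on the [CLS] vector) and that the MLP head is 768→256→1; these dimensions are inferred from the printed numbers, not read from a model config. fvcore counts one fused multiply-add as one FLOP and, as the warnings above show, does not count `aten::scaled_dot_product_attention`, so the reported BERT cost is exactly linear in L and the 0.03·L² term in the fit is numerical noise rather than attention cost. `uncounted_attention_macs` estimates the attention matmuls that are left out.

```python
# Sketch: reproduce fvcore's BERT-base FLOP count analytically.
# Assumed dimensions (inferred from the measured numbers, not from a config file).
LAYERS = 12
HIDDEN = 768
FFN = 3072
NUM_LAYERNORMS = 2 * LAYERS + 1   # two per encoder layer plus the embedding LayerNorm
LN_FLOPS_PER_ELEM = 5             # fvcore's cost per element for an affine LayerNorm
MLP_HIDDEN = 256                  # hypothetical head size: 768 -> 256 -> 1


def bert_flops(seq_len: int) -> int:
    """Matmul + LayerNorm FLOPs fvcore reports for BERT-base (attention not counted)."""
    per_layer = 4 * HIDDEN * HIDDEN + 2 * HIDDEN * FFN   # Q, K, V, output projection + FFN
    per_token = LAYERS * per_layer + NUM_LAYERNORMS * HIDDEN * LN_FLOPS_PER_ELEM
    pooler = HIDDEN * HIDDEN                             # dense layer applied to [CLS] only
    return per_token * seq_len + pooler


def mlp_flops() -> int:
    """FLOPs of a 768 -> 256 -> 1 head applied to one [CLS] vector."""
    return HIDDEN * MLP_HIDDEN + MLP_HIDDEN * 1


def uncounted_attention_macs(seq_len: int) -> int:
    """QK^T and attention @ V multiply-adds that fvcore skips (softmax excluded)."""
    return 2 * LAYERS * seq_len * seq_len * HIDDEN


for L in (4, 64, 256):
    print(f"L={L}: BERT {bert_flops(L):,}, Model 1 {bert_flops(L) + mlp_flops():,}, "
          f"uncounted attention ~{uncounted_attention_macs(L):,}")
```

Under these assumptions, at L = 256 the omitted attention work (about 1.2 billion multiply-adds) is roughly 5.5% of the 21.8 billion FLOPS reported, so the fitted polynomial undercounts long sequences by a growing margin.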


## 4. Evaluating FLOPS for the sliding-window evaluation MLP (approach 2)

In [17]:
from window_slide_model import WindowSlideModel
window_model = WindowSlideModel(input_dim=768)

In [18]:
# Use a window size of 6 tokens, as per the problem statement: the first 4 words
# plus the [CLS] and [SEP] special tokens (checked by the assertion below)
model_2_text = " ".join(long_ner_text.split()[0:4])
model_2_tokens = tokenizer(model_2_text, return_tensors="pt", padding=False, truncation=False)
assert model_2_tokens["input_ids"].shape[1] == 6, f"Expected 6 tokens, got {model_2_tokens['input_ids'].shape[1]}"
# Step 1: FLOPS of the BERT pass that produces the window's [CLS] embedding
model_2_long_flops_prepare = measure_flops(bert_model, bert_input(model_2_tokens))
with torch.no_grad():
    model_2_long_tokens_cls = bert_model(**model_2_tokens).last_hidden_state[:, 0, :]
# Step 2: FLOPS of the window MLP applied to that embedding
model_2_long_flops_run = measure_flops(window_model, model_2_long_tokens_cls)
model_2_long_flops = model_2_long_flops_prepare.total() + model_2_long_flops_run.total()
model_2_long_flops_per_token = model_2_long_flops / model_2_tokens["input_ids"].numel()
print(f"Long text tokens: {model_2_long_tokens_cls.shape}")

print("\nFLOPS Results:")
print(f"Model 2 FLOPS CLS: {model_2_long_flops_prepare.total():,}")
print(f"Model 2 FLOPS Model: {model_2_long_flops_run.total():,}")
print(f"Model 2 FLOPS Total: {model_2_long_flops:,}")
print(f"Model 2 FLOPS/token: {model_2_long_flops_per_token:,.2f}")

Long text tokens: torch.Size([1, 768])

FLOPS Results:
Model 2 FLOPS CLS: 510,773,760
Model 2 FLOPS Model: 98,432
Model 2 FLOPS Total: 510,872,192
Model 2 FLOPS/token: 85,145,365.33
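The Model 2 figures decompose the same way as Model 1: one BERT pass over the 6-token window (including the pooler) plus one MLP call on the [CLS] vector, divided by the window length. A minimal sketch, assuming a BERT-base encoder as counted by fvcore (85,030,656 FLOPS per token plus 589,824 for the 768×768 pooler) and a hypothetical 768→128→1 window MLP, inferred from the 98,432 FLOPS above; the actual layers of `WindowSlideModel` are not shown in this notebook.

```python
# Sketch: decompose Model 2's per-token cost for one sliding window.
# Assumptions: BERT-base cost as counted by fvcore and a hypothetical
# 768 -> 128 -> 1 window MLP (sizes inferred, not read from WindowSlideModel).
BERT_PER_TOKEN = 85_030_656
POOLER = 768 * 768
WINDOW_MLP = 768 * 128 + 128 * 1


def window_flops(window_len: int) -> int:
    """FLOPS for encoding one window and scoring its [CLS] embedding."""
    return BERT_PER_TOKEN * window_len + POOLER + WINDOW_MLP


def window_flops_per_token(window_len: int) -> float:
    """Window cost spread evenly over the tokens in the window."""
    return window_flops(window_len) / window_len


print(f"{window_flops(6):,} total, {window_flops_per_token(6):,.2f} per token")
# prints "510,872,192 total, 85,145,365.33 per token"
```

Because the pooler and MLP are a fixed cost per window, the per-token figure is specific to the 6-token window; multiplying it by a different window length would not reproduce `window_flops` for that length.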


## 5. Verification

In [22]:
import os
import pickle

# Save the fitted coefficients and the Model 2 per-token cost for later use
os.makedirs('baselines', exist_ok=True)
with open('baselines/flops_coefficients.pkl', 'wb') as f:
    pickle.dump({
        'bert_coeffs': bert_coeffs,
        'model_1_coeffs': model_1_coeffs,
        'model_2_flops_per_token': model_2_long_flops_per_token
    }, f)

In [25]:
from src.flops_calculator import FlopsCalculator
flops_calculator = FlopsCalculator("baselines/flops_coefficients.pkl")
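`src.flops_calculator` is not shown in this notebook. As a sketch of what a calculator could do with the pickle written above, here is a hypothetical `SimpleFlopsCalculator`. It assumes the coefficient arrays are in `np.polyfit` order (highest degree first), maps `"ner"` to `bert_coeffs`, and treats `"model_2"` as per-token cost times window length; these mappings are assumptions, not the real module's behavior. The usage example writes exact linear coefficients back-calculated from the measured counts in this notebook to a temporary file.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


class SimpleFlopsCalculator:
    """Hypothetical stand-in for src.flops_calculator.FlopsCalculator."""

    def __init__(self, path):
        with open(path, "rb") as f:
            data = pickle.load(f)
        # Assumed: polynomial coefficients ordered highest degree first (np.polyfit order)
        self._poly = {
            "ner": np.asarray(data["bert_coeffs"], dtype=float),
            "model_1": np.asarray(data["model_1_coeffs"], dtype=float),
        }
        self._model_2_per_token = float(data["model_2_flops_per_token"])

    def calculate_flops(self, model: str, seq_len: int) -> int:
        if model in self._poly:
            return int(round(np.polyval(self._poly[model], seq_len)))
        if model == "model_2":
            # Per-token cost of the sliding window times the window length
            return int(round(self._model_2_per_token * seq_len))
        raise ValueError(f"unknown model: {model!r}")


# Usage: coefficients back-calculated from the measured counts, stored in a temp file
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "flops_coefficients.pkl"
    with open(path, "wb") as f:
        pickle.dump({
            "bert_coeffs": [0.0, 85_037_568.0, 0.0],
            "model_1_coeffs": [0.0, 85_030_656.0, 786_688.0],
            "model_2_flops_per_token": 510_872_192 / 6,
        }, f)
    calc = SimpleFlopsCalculator(path)  # data is loaded into memory here

print(f"{calc.calculate_flops('model_1', 418):,}")
```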

In [27]:
long_text_len = long_tokens["input_ids"].shape[1]
long_text_window_len = model_2_tokens["input_ids"].shape[1]

# Get actual measured values for comparison
actual_ner_flops = long_flops.total()
actual_model1_flops = moodel_1_long_flops
actual_model2_flops = model_2_long_flops

print(f"=== FLOPS Comparison ===")
print(f"Sequence length: {long_text_len} tokens")
print(f"Window length: {long_text_window_len} tokens")
print()

estimated_ner = flops_calculator.calculate_flops("ner", long_text_len)
print(f"NER - Estimated: {estimated_ner:,}")
print(f"NER - Actual: {actual_ner_flops:,}")
print(f"NER - Error: {abs(estimated_ner - actual_ner_flops) / actual_ner_flops * 100:.1f}%")
print()

estimated_model1 = flops_calculator.calculate_flops("model_1", long_text_len)
print(f"Model 1 - Estimated: {estimated_model1:,}")
print(f"Model 1 - Actual: {actual_model1_flops:,}")
print(f"Model 1 - Error: {abs(estimated_model1 - actual_model1_flops) / actual_model1_flops * 100:.1f}%")
print()

estimated_model2 = flops_calculator.calculate_flops("model_2", long_text_window_len)
print(f"Model 2 - Estimated: {estimated_model2:,}")
print(f"Model 2 - Actual: {actual_model2_flops:,}")
print(f"Model 2 - Error: {abs(estimated_model2 - actual_model2_flops) / actual_model2_flops * 100:.1f}%")

=== FLOPS Comparison ===
Sequence length: 418 tokens
Window length: 6 tokens

NER - Estimated: 35,545,702,400
NER - Actual: 35,545,703,424
NER - Error: 0.0%

Model 1 - Estimated: 35,543,605,248
Model 1 - Actual: 35,543,600,896
Model 1 - Error: 0.0%

Model 2 - Estimated: 510,872,192
Model 2 - Actual: 510,872,192
Model 2 - Error: 0.0%
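The `:.1f` percent format rounds every error above to 0.0%, which hides how close the estimates actually are. A small helper (hypothetical name `relative_error_ppm`) reports the same comparison in parts per million, using the estimated and actual values printed above.

```python
def relative_error_ppm(estimated: int, actual: int) -> float:
    """Relative error |estimated - actual| / actual, in parts per million."""
    return abs(estimated - actual) / actual * 1e6


# (estimated, actual) pairs copied from the long-text output above
results = {
    "NER": (35_545_702_400, 35_545_703_424),
    "Model 1": (35_543_605_248, 35_543_600_896),
    "Model 2": (510_872_192, 510_872_192),
}
for name, (est, act) in results.items():
    print(f"{name}: off by {abs(est - act):,} FLOPS, {relative_error_ppm(est, act):.3f} ppm")
```

For this 418-token input all three estimates land within 1 ppm of fvcore's count.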


In [28]:
small_text_len = small_tokens["input_ids"].shape[1]

# Get actual measured values for comparison
actual_ner_flops = small_flops.total()
actual_model1_flops = moodel_1_small_flops

print(f"=== FLOPS Comparison (Small Text) ===")
print(f"Sequence length: {small_text_len} tokens")
print()

estimated_ner = flops_calculator.calculate_flops("ner", small_text_len)
print(f"NER - Estimated: {estimated_ner:,}")
print(f"NER - Actual: {actual_ner_flops:,}")
print(f"NER - Error: {abs(estimated_ner - actual_ner_flops) / actual_ner_flops * 100:.1f}%")
print()

estimated_model1 = flops_calculator.calculate_flops("model_1", small_text_len)
print(f"Model 1 - Estimated: {estimated_model1:,}")
print(f"Model 1 - Actual: {actual_model1_flops:,}")
print(f"Model 1 - Error: {abs(estimated_model1 - actual_model1_flops) / actual_model1_flops * 100:.1f}%")
print()

=== FLOPS Comparison (Small Text) ===
Sequence length: 8 tokens

NER - Estimated: 680,300,608
NER - Actual: 680,300,544
NER - Error: 0.0%

Model 1 - Estimated: 681,032,128
Model 1 - Actual: 681,031,936
Model 1 - Error: 0.0%
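The two comparisons also show NER and Model 1 swapping order: at 8 tokens Model 1 is slightly more expensive, while at 418 tokens NER is. Assuming both counts are exactly linear in L (fvcore skips the attention matmuls, per the warnings earlier), the per-token slope and fixed overhead of each model can be back-solved from the two measured points, and the crossover length follows. A sketch using exact rational arithmetic:

```python
from fractions import Fraction

# Measured fvcore counts from this notebook: (seq_len, flops)
ner = [(8, 680_300_544), (418, 35_545_703_424)]
model_1 = [(8, 681_031_936), (418, 35_543_600_896)]


def fit_line(points):
    """Exact slope and intercept of the line through two (L, flops) points."""
    (l0, f0), (l1, f1) = points
    slope = Fraction(f1 - f0, l1 - l0)
    return slope, f0 - slope * l0


ner_slope, ner_const = fit_line(ner)
m1_slope, m1_const = fit_line(model_1)
# Model 1 has the larger fixed cost, NER the larger per-token cost
crossover = (m1_const - ner_const) / (ner_slope - m1_slope)
print(f"NER: {ner_slope} per token + {ner_const}")
print(f"Model 1: {m1_slope} per token + {m1_const}")
print(f"Model 1 is cheaper than NER for L > {float(crossover):.1f} tokens")
```

Under the linearity assumption, Model 1 pays a fixed overhead once per sequence (pooler plus MLP), while NER pays a slightly higher cost per token, so Model 1 becomes the cheaper option for inputs longer than about 114 tokens.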

