In [11]:
# Every dependency of this notebook is imported here, in the first cell, so that
# "Restart Kernel -> Run All" works and a reader can see all inputs at a glance.
from transformers import AutoTokenizer

from evaluation.dataset_loaders.agnews import load_agnews
from evaluation.dataset_loaders.cnn_dm import load_cnndm
from evaluation.dataset_loaders.squad import load_squad
from evaluation.dataset_loaders.sst2 import load_sst2
from evaluation.dataset_loaders.wmt_en_fr import load_wmt_enfr
from evaluation.evaluator import EarlyExitEvaluator
from models.gpt2_wrapper import GPT2WithEarlyExit
from strategies.confidence_exit import ConfidenceExit
# The package module is spelled "continous" (sic); keep this path matching the package.
from strategies.continous_confidence_exit import ContinuousConfidenceExit

# One tokenizer shared by every model/evaluator below. The checkpoint name must match
# the one passed to GPT2WithEarlyExit ("gpt2" throughout this notebook).
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [6]:
# Strategy 1: ConfidenceExit with a 0.8 confidence threshold and allowed exit layers
# 3, 6 and 9 (the exact exit rule lives in strategies.confidence_exit).
strategy = ConfidenceExit(
    threshold=0.8,
    allowed_layers=[3, 6, 9],
)
model = GPT2WithEarlyExit("gpt2", strategy, tokenizer)

In [7]:
import warnings

# Smoke-test Strategy 1 on a 10% slice of SST-2 before running the full task sweep.
SST2_FRACTION = 0.10

dataset = load_sst2(fraction=SST2_FRACTION)

evaluator = EarlyExitEvaluator(tokenizer)

result = evaluator.evaluate(
    model=model,
    strategy=strategy,
    dataset=dataset,
    task_type="classification",
)

# Exactly 0.0 accuracy on a binary sentiment task is far below chance (~0.5). That most
# likely means predictions and labels never match (e.g. a label-format mismatch inside
# the evaluator), not that the model is this bad. Warn instead of letting it pass silently.
if result["metric"] == "accuracy" and result["score"] == 0.0:
    warnings.warn(
        f"sst2: accuracy is exactly 0.0 over {result['num_samples']} samples; "
        "predictions and labels are likely never matching - check the evaluator's label mapping."
    )

# NOTE(review): in the recorded output tokens_per_sec == 1 / avg_latency_sec, so it appears
# to count samples (not generated tokens) per second - confirm in EarlyExitEvaluator.
print(result)

Evaluating: 100%|███████████████████████████████| 87/87 [00:01<00:00, 54.04it/s]

{'metric': 'accuracy', 'score': np.float64(0.0), 'avg_latency_sec': np.float64(0.018404911304342336), 'tokens_per_sec': 54.33332350610494, 'avg_layers_used': np.float64(5.114942528735632), 'num_samples': 87}





In [9]:
from evaluation.dataset_loaders.agnews import load_agnews
from evaluation.dataset_loaders.cnn_dm import load_cnndm
from evaluation.dataset_loaders.squad import load_squad
from evaluation.dataset_loaders.wmt_en_fr import load_wmt_enfr

import warnings

# Sweep Strategy 1 over every task, each loader sliced to the same fraction.
# NOTE(review): sst2 is re-evaluated here with the same settings as the SST-2 cell above.
# NOTE(review): the recorded run warns that CNN/DM inputs exceed GPT-2's 1024-token limit
# (1156 > 1024), and the WMT loader downloads all 30 train shards to score 300 samples;
# both look like loader/evaluator issues worth fixing (truncation, split-only loading).
SWEEP_FRACTION = 0.10

datasets = [
    ("sst2", load_sst2, "classification"),
    ("agnews", load_agnews, "classification"),
    ("cnn_dm", load_cnndm, "summarization"),
    ("wmt14_enfr", load_wmt_enfr, "translation"),
    ("squad", load_squad, "qa"),
]

# Keep every result (keyed by dataset name) so strategies can be compared later without
# re-running this multi-minute sweep (CNN/DM alone took ~3m40s in the recorded run).
strategy1_results = {}

for name, loader, task in datasets:
    print(f"Testing {name}...")

    dataset = loader(fraction=SWEEP_FRACTION)

    result = evaluator.evaluate(
        model=model,
        strategy=strategy,
        dataset=dataset,
        task_type=task,
    )
    strategy1_results[name] = result

    # A score of exactly 0.0 accuracy is below chance and most likely an evaluation bug.
    if result["metric"] == "accuracy" and result["score"] == 0.0:
        warnings.warn(
            f"{name}: accuracy is exactly 0.0 over {result['num_samples']} samples; "
            "predictions and labels are likely never matching - check the evaluator's label mapping."
        )

    print(name, result)

Testing sst2...


Evaluating: 100%|███████████████████████████████| 87/87 [00:01<00:00, 53.20it/s]


sst2 {'metric': 'accuracy', 'score': np.float64(0.0), 'avg_latency_sec': np.float64(0.018698522414284192), 'tokens_per_sec': 53.48016157876085, 'avg_layers_used': np.float64(5.114942528735632), 'num_samples': 87}
Testing agnews...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/760 [00:00<?, ? examples/s]

Evaluating: 100%|█████████████████████████████| 760/760 [00:18<00:00, 41.48it/s]


agnews {'metric': 'accuracy', 'score': np.float64(0.0), 'avg_latency_sec': np.float64(0.02400514608935306), 'tokens_per_sec': 41.65773439902236, 'avg_layers_used': np.float64(5.859210526315789), 'num_samples': 760}
Testing cnn_dm...


README.md: 0.00B [00:00, ?B/s]

3.0.0/train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

3.0.0/validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

3.0.0/test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Map:   0%|          | 0/1336 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1156 > 1024). Running this sequence through the model will result in indexing errors
Evaluating: 100%|███████████████████████████| 1336/1336 [03:40<00:00,  6.05it/s]


cnn_dm {'metric': 'rougeL', 'score': np.float64(0.025857347737891605), 'avg_latency_sec': np.float64(0.1649550244122922), 'tokens_per_sec': 6.062258506904149, 'avg_layers_used': np.float64(6.017964071856287), 'num_samples': 1336}
Testing wmt14_enfr...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/30 [00:00<?, ?files/s]

fr-en/train-00000-of-00030.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

fr-en/train-00001-of-00030.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

fr-en/train-00002-of-00030.parquet:   0%|          | 0.00/243M [00:00<?, ?B/s]

fr-en/train-00003-of-00030.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

fr-en/train-00004-of-00030.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

fr-en/train-00005-of-00030.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

fr-en/train-00006-of-00030.parquet:   0%|          | 0.00/240M [00:00<?, ?B/s]

fr-en/train-00007-of-00030.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

fr-en/train-00008-of-00030.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

fr-en/train-00009-of-00030.parquet:   0%|          | 0.00/239M [00:00<?, ?B/s]

fr-en/train-00010-of-00030.parquet:   0%|          | 0.00/239M [00:00<?, ?B/s]

fr-en/train-00011-of-00030.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

fr-en/train-00012-of-00030.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

fr-en/train-00013-of-00030.parquet:   0%|          | 0.00/230M [00:00<?, ?B/s]

fr-en/train-00014-of-00030.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

fr-en/train-00015-of-00030.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

fr-en/train-00016-of-00030.parquet:   0%|          | 0.00/227M [00:00<?, ?B/s]

fr-en/train-00017-of-00030.parquet:   0%|          | 0.00/226M [00:00<?, ?B/s]

fr-en/train-00018-of-00030.parquet:   0%|          | 0.00/261M [00:00<?, ?B/s]

fr-en/train-00019-of-00030.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

fr-en/train-00020-of-00030.parquet:   0%|          | 0.00/261M [00:00<?, ?B/s]

fr-en/train-00021-of-00030.parquet:   0%|          | 0.00/264M [00:00<?, ?B/s]

fr-en/train-00022-of-00030.parquet:   0%|          | 0.00/267M [00:00<?, ?B/s]

fr-en/train-00023-of-00030.parquet:   0%|          | 0.00/270M [00:00<?, ?B/s]

fr-en/train-00024-of-00030.parquet:   0%|          | 0.00/274M [00:00<?, ?B/s]

fr-en/train-00025-of-00030.parquet:   0%|          | 0.00/278M [00:00<?, ?B/s]

fr-en/train-00026-of-00030.parquet:   0%|          | 0.00/365M [00:00<?, ?B/s]

fr-en/train-00027-of-00030.parquet:   0%|          | 0.00/322M [00:00<?, ?B/s]

fr-en/train-00028-of-00030.parquet:   0%|          | 0.00/370M [00:00<?, ?B/s]

fr-en/train-00029-of-00030.parquet:   0%|          | 0.00/311M [00:00<?, ?B/s]

fr-en/validation-00000-of-00001.parquet:   0%|          | 0.00/475k [00:00<?, ?B/s]

fr-en/test-00000-of-00001.parquet:   0%|          | 0.00/536k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40836715 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3003 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/30 [00:00<?, ?it/s]

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Evaluating: 100%|█████████████████████████████| 300/300 [00:07<00:00, 38.39it/s]


wmt14_enfr {'metric': 'bleu', 'score': np.float64(1.4693051022486856e-06), 'avg_latency_sec': np.float64(0.025930917263031004), 'tokens_per_sec': 38.564004113563406, 'avg_layers_used': np.float64(6.303333333333334), 'num_samples': 300}
Testing squad...


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Map:   0%|          | 0/1057 [00:00<?, ? examples/s]

Evaluating: 100%|███████████████████████████| 1057/1057 [00:51<00:00, 20.64it/s]

squad {'metric': 'token_f1', 'score': np.float64(0.016823466137562637), 'avg_latency_sec': np.float64(0.04826024555325396), 'tokens_per_sec': 20.72098864264015, 'avg_layers_used': np.float64(6.811731315042573), 'num_samples': 1057}





In [None]:
### Strategy 2 — Confidence threshold must be met on consecutive layers

In [12]:
from strategies.continous_confidence_exit import ContinuousConfidenceExit

# Strategy 2: ContinuousConfidenceExit - per the section heading, confidence (0.75) has to
# hold on `required_consecutive` allowed layers in a row; the exact rule lives in the module.
# NOTE(review): this rebinds `strategy` and `model`; re-running the Strategy 1 cells after
# this one would silently evaluate Strategy 2 instead.
# NOTE(review): the recorded runs below report avg_layers_used == 12.0 on every dataset,
# i.e. this configuration never saved compute - check whether it can trigger before the end.
strategy = ContinuousConfidenceExit(
    allowed_layers=[3, 6, 9, 11],
    threshold=0.75,
    required_consecutive=2,
)

model = GPT2WithEarlyExit("gpt2", strategy, tokenizer)
evaluator = EarlyExitEvaluator(tokenizer)

In [None]:
import warnings

# Sweep Strategy 2 over the same tasks as Strategy 1.
# NOTE(review): this uses 2% of each split while the Strategy 1 sweep used 10%
# (e.g. 17 vs 87 SST-2 samples), so scores and latencies are computed on different numbers
# of samples and are not directly comparable - align the fractions before drawing conclusions.
STRATEGY2_FRACTION = 0.02

datasets = [
    ("sst2", load_sst2, "classification"),
    ("agnews", load_agnews, "classification"),
    ("cnn_dm", load_cnndm, "summarization"),
    ("wmt14_enfr", load_wmt_enfr, "translation"),
    ("squad", load_squad, "qa"),
]

# Keep every result (keyed by dataset name) for later comparison against Strategy 1.
strategy2_results = {}

for name, loader, task in datasets:
    print(f"\n========== Testing {name.upper()} ==========\n")

    dataset = loader(fraction=STRATEGY2_FRACTION)

    result = evaluator.evaluate(
        model=model,
        strategy=strategy,
        dataset=dataset,
        task_type=task,
    )
    strategy2_results[name] = result

    # A score of exactly 0.0 accuracy is below chance and most likely an evaluation bug.
    if result["metric"] == "accuracy" and result["score"] == 0.0:
        warnings.warn(
            f"{name}: accuracy is exactly 0.0 over {result['num_samples']} samples; "
            "predictions and labels are likely never matching - check the evaluator's label mapping."
        )

    print(name, result)





Map:   0%|          | 0/17 [00:00<?, ? examples/s]

Evaluating: 100%|███████████████████████████████| 17/17 [00:00<00:00, 21.62it/s]


sst2 {'metric': 'accuracy', 'score': np.float64(0.0), 'avg_latency_sec': np.float64(0.04608050514669979), 'tokens_per_sec': 21.70115099251724, 'avg_layers_used': np.float64(12.0), 'num_samples': 17}




Map:   0%|          | 0/152 [00:00<?, ? examples/s]

Evaluating: 100%|█████████████████████████████| 152/152 [00:08<00:00, 18.96it/s]


agnews {'metric': 'accuracy', 'score': np.float64(0.0), 'avg_latency_sec': np.float64(0.052560009454426015), 'tokens_per_sec': 19.025871767909113, 'avg_layers_used': np.float64(12.0), 'num_samples': 152}




Map:   0%|          | 0/267 [00:00<?, ? examples/s]

Evaluating: 100%|█████████████████████████████| 267/267 [01:19<00:00,  3.34it/s]


cnn_dm {'metric': 'rougeL', 'score': np.float64(0.029517060847254528), 'avg_latency_sec': np.float64(0.2987214992108863), 'tokens_per_sec': 3.3475996961773316, 'avg_layers_used': np.float64(12.0), 'num_samples': 267}




Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/30 [00:00<?, ?it/s]