In [None]:
import json

import torch
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU when available. fp16 halves memory on CUDA; on CPU many half-precision
# kernels are slow or unimplemented, so fall back to fp32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_dtype = torch.float16 if device == "cuda" else torch.float32

model_name = "EleutherAI/gpt-neo-1.3B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token (presumably because the tokenizer defines none) so any
# padding-aware tokenizer call has a pad id to use.
tokenizer.pad_token = tokenizer.eos_token

# `torch_dtype` is deprecated in recent transformers releases in favour of `dtype`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=model_dtype,
).to(device)
model.eval()  # inference mode: disables dropout

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


In [None]:
def forward_with_skipped_layers(model, input_ids, attention_mask, skip_layers):
    """
    Run a GPT-Neo forward pass that bypasses selected transformer blocks.

    Token and position embeddings, the final layer norm and the LM head are applied
    as usual. Every block in ``model.transformer.h`` whose index is in
    ``skip_layers`` is not called at all, so its input hidden states flow unchanged
    into the next block.

    All tensors are created on ``input_ids.device``, so the function does not depend
    on any notebook-global device setting.

    Args:
        model: causal LM exposing ``transformer.wte``, ``transformer.wpe``,
            ``transformer.h``, ``transformer.ln_f`` and ``lm_head``.
        input_ids: token ids, indexed as (batch, seq_len).
        attention_mask: (batch, seq_len); 1 marks tokens that may be attended to,
            0 marks padding. Positions are numbered 0..seq_len-1, which assumes
            right-padded or unpadded input.
        skip_layers: iterable of block indices to bypass.

    Returns:
        The output of ``model.lm_head`` for every input position.
    """
    skip = set(skip_layers)  # O(1) membership test in the layer loop
    dev = input_ids.device
    batch_size, seq_len = input_ids.shape

    # Embeddings
    position_ids = torch.arange(seq_len, device=dev).unsqueeze(0)
    hidden_states = model.transformer.wte(input_ids) + model.transformer.wpe(position_ids)

    # Additive mask: 0 where attention is allowed (causal AND not padding), a large
    # negative value elsewhere. Presumably the convention HF GPT-Neo attention expects
    # for a 4-D attention_mask; confirm for the installed transformers version.
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=dev)).view(1, 1, seq_len, seq_len)
    padding_mask = attention_mask.to(dev).view(batch_size, 1, 1, seq_len)
    allowed = causal_mask * padding_mask
    additive_mask = (1.0 - allowed) * torch.finfo(torch.float16).min

    # Transformer layers (skipped ones are simply not called)
    for idx, block in enumerate(model.transformer.h):
        if idx in skip:
            continue
        hidden_states = block(hidden_states, attention_mask=additive_mask)[0]

    # Final layer norm + LM head
    return model.lm_head(model.transformer.ln_f(hidden_states))


def winogrande_prompt(sentence, option):
    """
    Build a filled-in Winogrande sentence for one candidate answer.

    Every ``_`` in ``sentence`` (the blank marker) is replaced by ``option``.

    Args:
        sentence: Winogrande sentence containing the ``_`` blank.
        option: candidate text to put in the blank.

    Returns:
        The sentence with the blank filled in.
    """
    pieces = sentence.split("_")
    return option.join(pieces)


def evaluate_winogrande(model, tokenizer, skip_layers, sample_size=256):
    """
    Evaluate a (possibly layer-pruned) GPT-Neo on Winogrande by loss comparison.

    For each example, both options are substituted into the blank. Each filled-in
    sentence is scored by its mean next-token cross-entropy under
    ``forward_with_skipped_layers``, and the option with the lower loss is the
    prediction.

    Args:
        model: causal LM accepted by ``forward_with_skipped_layers``; inputs are
            moved to ``model.device``.
        tokenizer: tokenizer matching ``model``.
        skip_layers: indices of transformer blocks to bypass.
        sample_size: number of validation examples to score, taken from the start
            of the split. Clamped to the split size; ``None`` scores the whole split.

    Returns:
        Fraction of scored examples predicted correctly.

    Raises:
        ValueError: if no examples would be scored (e.g. ``sample_size <= 0``).
    """
    dataset = load_dataset("winogrande", "winogrande_xl", split="validation")
    if sample_size is not None:
        dataset = dataset.select(range(min(sample_size, len(dataset))))
    num_examples = len(dataset)
    if num_examples == 0:
        raise ValueError("sample_size must be positive")

    def sentence_loss(text):
        """Mean next-token negative log-likelihood of ``text`` under the pruned model."""
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            logits = forward_with_skipped_layers(
                model, inputs.input_ids, inputs.attention_mask, skip_layers
            )
        # Score in fp32: an fp16 loss keeps only ~3 significant digits, so close
        # options would often tie, and ties silently resolve to option "2".
        shift_logits = logits[:, :-1].float()
        shift_labels = inputs.input_ids[:, 1:]
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="mean",
        ).item()

    correct = 0
    for example in dataset:
        sentence = example["sentence"]
        loss1 = sentence_loss(winogrande_prompt(sentence, example["option1"]))
        loss2 = sentence_loss(winogrande_prompt(sentence, example["option2"]))

        # Lower loss wins; ties go to option "2". The label is the string "1" or "2".
        pred = "1" if loss1 < loss2 else "2"
        if pred == example["answer"]:
            correct += 1

    return correct / num_examples

In [None]:
# Greedy layer pruning on Winogrande: each iteration tries skipping every remaining
# transformer block, permanently removes the one whose removal gives the highest
# accuracy, and repeats up to `max_layers_to_remove` times.
# NOTE(review): layers are selected on the same `num_samples` examples the reported
# accuracies are measured on, so post-pruning accuracies are optimistically biased
# by the selection itself; confirm any gain on held-out examples.
num_layers = len(model.transformer.h)
max_layers_to_remove = 6
removed_layers = set()
removal_history = []  # one summary record per iteration; iteration 0 is the baseline

full_test_log = []  # every candidate tried in every iteration; written to JSON below
num_samples = 128

baseline_acc = evaluate_winogrande(model, tokenizer, skip_layers=[], sample_size=num_samples)
print(f"Baseline accuracy: {baseline_acc:.2%}")

removal_history.append({
    'iteration': 0,
    'removed_layers': [],
    'accuracy': baseline_acc
})

for iteration in range(1, max_layers_to_remove + 1):
    candidate_layers = [idx for idx in range(num_layers) if idx not in removed_layers]
    if not candidate_layers:
        # Every layer is already removed. Without this guard best_layer would stay
        # None, None would enter removed_layers, and sorted() would raise TypeError.
        print(f"\nNo layers left to remove after iteration {iteration - 1}; stopping.")
        break

    print(f"\nIteration {iteration}: finding best layer to remove...")
    best_acc = -1.0
    best_layer = None
    iteration_log = []

    for layer in tqdm(candidate_layers):
        test_skip_layers = sorted(removed_layers | {layer})
        acc = evaluate_winogrande(model, tokenizer, skip_layers=test_skip_layers, sample_size=num_samples)

        # Log every layer tested
        iteration_log.append({
            'tested_layer': layer,
            'skip_layers': test_skip_layers,
            'accuracy': acc
        })

        # Keep the removal that hurts accuracy least. `>=` means ties go to the
        # later candidate, i.e. the higher layer index.
        if acc >= best_acc:
            best_acc = acc
            best_layer = layer

    # Permanently remove the best layer
    removed_layers.add(best_layer)
    removal_history.append({
        'iteration': iteration,
        'layer_removed': best_layer,
        'removed_layers': sorted(removed_layers),
        'accuracy': best_acc
    })

    # Append iteration log to full log
    full_test_log.append({
        'iteration': iteration,
        'tested_candidates': iteration_log,
        'selected_layer': best_layer,
        'accuracy_after_removal': best_acc
    })

    print(f"Removed layer {best_layer}. New removed set: {sorted(removed_layers)}")
    print(f"Accuracy after removal: {best_acc:.2%}")

# Print final summary
print("\nGreedy Layer Pruning Summary:")
for record in removal_history:
    if record['iteration'] == 0:
        print(f"Baseline: Accuracy={record['accuracy']:.2%}")
    else:
        print(f"After removing {record['iteration']} layer(s) ({record['removed_layers']}): Accuracy={record['accuracy']:.2%}")

with open("winogrande_pruning_log.json", "w", encoding="utf-8") as f:
    json.dump(full_test_log, f, indent=4)


Baseline accuracy: 55.47%

Iteration 1: finding best layer to remove...


100%|██████████| 24/24 [03:11<00:00,  7.96s/it]


Removed layer 22. New removed set: [22]
Accuracy after removal: 58.59%

Iteration 2: finding best layer to remove...


100%|██████████| 23/23 [02:56<00:00,  7.68s/it]


Removed layer 18. New removed set: [18, 22]
Accuracy after removal: 62.50%

Iteration 3: finding best layer to remove...


100%|██████████| 22/22 [02:42<00:00,  7.37s/it]


Removed layer 4. New removed set: [4, 18, 22]
Accuracy after removal: 60.94%

Iteration 4: finding best layer to remove...


100%|██████████| 21/21 [02:29<00:00,  7.10s/it]


Removed layer 19. New removed set: [4, 18, 19, 22]
Accuracy after removal: 57.81%

Iteration 5: finding best layer to remove...


100%|██████████| 20/20 [02:17<00:00,  6.86s/it]


Removed layer 16. New removed set: [4, 16, 18, 19, 22]
Accuracy after removal: 58.59%

Iteration 6: finding best layer to remove...


100%|██████████| 19/19 [02:05<00:00,  6.58s/it]

Removed layer 20. New removed set: [4, 16, 18, 19, 20, 22]
Accuracy after removal: 57.81%

Greedy Layer Pruning Summary:
Baseline: Accuracy=55.47%
After removing 1 layer(s) ([22]): Accuracy=58.59%
After removing 2 layer(s) ([18, 22]): Accuracy=62.50%
After removing 3 layer(s) ([4, 18, 22]): Accuracy=60.94%
After removing 4 layer(s) ([4, 18, 19, 22]): Accuracy=57.81%
After removing 5 layer(s) ([4, 16, 18, 19, 22]): Accuracy=58.59%
After removing 6 layer(s) ([4, 16, 18, 19, 20, 22]): Accuracy=57.81%





In [None]:
def hellaswag_prompt(context, ending):
    """
    Join a HellaSwag context and one candidate ending into a single text.

    The two parts are separated by exactly one space.

    Args:
        context: the example's context string.
        ending: one candidate continuation.

    Returns:
        ``context`` and ``ending`` joined by a single space.
    """
    return " ".join((context, ending))


def evaluate_hellaswag(model, tokenizer, skip_layers, sample_size=256):
    """
    Evaluate a (possibly layer-pruned) GPT-Neo on HellaSwag by loss comparison.

    Each candidate ending is appended to the example's context. The full text is
    scored by mean next-token cross-entropy under ``forward_with_skipped_layers``,
    and the ending with the lowest loss is the prediction.

    NOTE(review): the mean runs over context *and* ending tokens. Up to tokenization
    at the boundary, the context terms are the same for every ending, so they dilute
    the ending's signal and make the comparison depend on ending length. Common
    practice (e.g. EleutherAI lm-evaluation-harness) scores only the ending tokens.
    The original scoring is kept here so results stay comparable.

    Args:
        model: causal LM accepted by ``forward_with_skipped_layers``; inputs are
            moved to ``model.device``.
        tokenizer: tokenizer matching ``model``.
        skip_layers: indices of transformer blocks to bypass.
        sample_size: number of validation examples to score, taken from the start
            of the split. Clamped to the split size; ``None`` scores the whole split.

    Returns:
        Fraction of scored examples predicted correctly.

    Raises:
        ValueError: if no examples would be scored (e.g. ``sample_size <= 0``).
    """
    dataset = load_dataset("hellaswag", split="validation")
    if sample_size is not None:
        dataset = dataset.select(range(min(sample_size, len(dataset))))
    num_examples = len(dataset)
    if num_examples == 0:
        raise ValueError("sample_size must be positive")

    def sequence_loss(text):
        """Mean next-token negative log-likelihood of ``text`` under the pruned model."""
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            logits = forward_with_skipped_layers(
                model, inputs.input_ids, inputs.attention_mask, skip_layers
            )
        # Score in fp32: an fp16 loss keeps only ~3 significant digits, so close
        # endings would often tie, and argmin would then pick the earliest one.
        shift_logits = logits[:, :-1].float()
        shift_labels = inputs.input_ids[:, 1:]
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="mean",
        ).item()

    correct = 0
    for example in dataset:
        context = example["ctx"]
        label = int(example["label"])  # index of the correct ending
        losses = [sequence_loss(hellaswag_prompt(context, ending)) for ending in example["endings"]]

        # Lowest loss wins; torch.argmin returns the first index on ties.
        pred = int(torch.argmin(torch.tensor(losses)))
        if pred == label:
            correct += 1

    return correct / num_examples

In [None]:
# Greedy layer pruning on HellaSwag: each iteration tries skipping every remaining
# transformer block, permanently removes the one whose removal gives the highest
# accuracy, and repeats up to `max_layers_to_remove` times.
# NOTE(review): this cell mirrors the Winogrande pruning cell; consider extracting a
# shared function parameterized by the evaluation function and log path.
# NOTE(review): layers are selected on the same `num_samples` examples the reported
# accuracies are measured on, so post-pruning accuracies are optimistically biased
# by the selection itself; with only 64 samples each example is worth ~1.6 points.
num_layers = len(model.transformer.h)
max_layers_to_remove = 6
removed_layers = set()
removal_history = []  # one summary record per iteration; iteration 0 is the baseline

full_test_log = []  # every candidate tried in every iteration; written to JSON below
num_samples = 64

baseline_acc = evaluate_hellaswag(model, tokenizer, skip_layers=[], sample_size=num_samples)
print(f"Baseline accuracy: {baseline_acc:.2%}")

removal_history.append({
    'iteration': 0,
    'removed_layers': [],
    'accuracy': baseline_acc
})

for iteration in range(1, max_layers_to_remove + 1):
    candidate_layers = [idx for idx in range(num_layers) if idx not in removed_layers]
    if not candidate_layers:
        # Every layer is already removed. Without this guard best_layer would stay
        # None, None would enter removed_layers, and sorted() would raise TypeError.
        print(f"\nNo layers left to remove after iteration {iteration - 1}; stopping.")
        break

    print(f"\nIteration {iteration}: finding best layer to remove...")
    best_acc = -1.0
    best_layer = None
    iteration_log = []

    for layer in tqdm(candidate_layers):
        test_skip_layers = sorted(removed_layers | {layer})
        acc = evaluate_hellaswag(model, tokenizer, skip_layers=test_skip_layers, sample_size=num_samples)

        # Log every layer tested
        iteration_log.append({
            'tested_layer': layer,
            'skip_layers': test_skip_layers,
            'accuracy': acc
        })

        # Keep the removal that hurts accuracy least. `>=` means ties go to the
        # later candidate, i.e. the higher layer index.
        if acc >= best_acc:
            best_acc = acc
            best_layer = layer

    # Permanently remove the best layer
    removed_layers.add(best_layer)
    removal_history.append({
        'iteration': iteration,
        'layer_removed': best_layer,
        'removed_layers': sorted(removed_layers),
        'accuracy': best_acc
    })

    # Append iteration log to full log
    full_test_log.append({
        'iteration': iteration,
        'tested_candidates': iteration_log,
        'selected_layer': best_layer,
        'accuracy_after_removal': best_acc
    })

    print(f"Removed layer {best_layer}. New removed set: {sorted(removed_layers)}")
    print(f"Accuracy after removal: {best_acc:.2%}")

# Print final summary
print("\nGreedy Layer Pruning Summary:")
for record in removal_history:
    if record['iteration'] == 0:
        print(f"Baseline: Accuracy={record['accuracy']:.2%}")
    else:
        print(f"After removing {record['iteration']} layer(s) ({record['removed_layers']}): Accuracy={record['accuracy']:.2%}")

with open("hellaswag_pruning_log.json", "w", encoding="utf-8") as f:
    json.dump(full_test_log, f, indent=4)


Baseline accuracy: 34.38%

Iteration 1: finding best layer to remove...


100%|██████████| 24/24 [03:14<00:00,  8.09s/it]


Removed layer 20. New removed set: [20]
Accuracy after removal: 37.50%

Iteration 2: finding best layer to remove...


100%|██████████| 23/23 [02:59<00:00,  7.79s/it]


Removed layer 10. New removed set: [10, 20]
Accuracy after removal: 39.06%

Iteration 3: finding best layer to remove...


100%|██████████| 22/22 [02:45<00:00,  7.54s/it]


Removed layer 14. New removed set: [10, 14, 20]
Accuracy after removal: 42.19%

Iteration 4: finding best layer to remove...


100%|██████████| 21/21 [02:32<00:00,  7.25s/it]


Removed layer 22. New removed set: [10, 14, 20, 22]
Accuracy after removal: 42.19%

Iteration 5: finding best layer to remove...


100%|██████████| 20/20 [02:20<00:00,  7.01s/it]


Removed layer 16. New removed set: [10, 14, 16, 20, 22]
Accuracy after removal: 39.06%

Iteration 6: finding best layer to remove...


100%|██████████| 19/19 [02:07<00:00,  6.70s/it]

Removed layer 19. New removed set: [10, 14, 16, 19, 20, 22]
Accuracy after removal: 37.50%

Greedy Layer Pruning Summary:
Baseline: Accuracy=34.38%
After removing 1 layer(s) ([20]): Accuracy=37.50%
After removing 2 layer(s) ([10, 20]): Accuracy=39.06%
After removing 3 layer(s) ([10, 14, 20]): Accuracy=42.19%
After removing 4 layer(s) ([10, 14, 20, 22]): Accuracy=42.19%
After removing 5 layer(s) ([10, 14, 16, 20, 22]): Accuracy=39.06%
After removing 6 layer(s) ([10, 14, 16, 19, 20, 22]): Accuracy=37.50%



