<a href="https://colab.research.google.com/github/edenlum/DynamicTransformer/blob/main/Dynamic_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Hypothesis 1:

All layers work on roughly the same space, that is, their inputs and outputs are tensors from the same distribution.

If that is true, we can change the order of the layers and they might still make sense. We can also skip some layers and that might make sense.

## Hypothesis 2:

Not all layers are used on all inputs. In other words, there are inputs for which we can skip some of the layers, and the output will not change by much. This is supported by the "circuits" theory where on some tasks you can find a circuit inside the transformer that is made out of a subset of the transformer layers.

In [1]:
!pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 16.1.0 w

In [21]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = DistilBertForSequenceClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [22]:
from datasets import load_dataset

dataset = load_dataset("glue", "sst2", split="validation")
texts = dataset["sentence"]
labels = dataset["label"]

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [23]:
import torch
import random
from torch.utils.data import DataLoader, TensorDataset
import copy

def encode_texts(texts, tokenizer, max_length=512):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=max_length)
    return inputs.input_ids, inputs.attention_mask

def prepare_dataloader(texts, labels, tokenizer, batch_size=32):
    input_ids, attention_mask = encode_texts(texts, tokenizer)
    dataset = TensorDataset(input_ids, attention_mask, torch.tensor(labels))
    return DataLoader(dataset, batch_size=batch_size)

def evaluate_sample(model, input_id, attention_mask, device):
    model.to(device)
    model.eval()
    with torch.no_grad():
        output = model(input_id.unsqueeze(0).to(device), attention_mask=attention_mask.unsqueeze(0).to(device))
    return torch.argmax(output.logits, dim=-1).item()

In [24]:
import torch.nn as nn

def remove_layer(model, layers_to_remove):
    """Removes the specified layers from the model."""
    if not isinstance(layers_to_remove, list):
        layers_to_remove = [layers_to_remove]
    modified_model = copy.deepcopy(model)
    modified_model.distilbert.transformer.layer = nn.ModuleList(
        [layer for i, layer in enumerate(modified_model.distilbert.transformer.layer) if i not in layers_to_remove]
    )
    return modified_model

In [48]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = prepare_dataloader(texts, labels, tokenizer)

def test_hypothesis(model, dataloader, device, layer_to_remove):
    stable_count = 0
    total_count = 0
    correct_count = 0
    correct_count_modified = 0
    unstable_indices = []

    for batch_idx, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        for i in range(input_ids.size(0)):
            original_output = evaluate_sample(model, input_ids[i], attention_mask[i], device)
            if original_output == labels[i].item():
                correct_count += 1

            modified_model = remove_layer(model, layer_to_remove)
            modified_output = evaluate_sample(modified_model, input_ids[i], attention_mask[i], device)
            if modified_output == labels[i].item():
                correct_count_modified += 1

            if original_output == modified_output:
                stable_count += 1
            else:
                unstable_indices.append(batch_idx * dataloader.batch_size + i)

            total_count += 1

    print(f"Stable samples: {stable_count}")
    print(f"Total samples: {total_count}")
    print(f"Stability rate: {stable_count / total_count:.4f}")
    print(f"Correct samples: {correct_count}")
    print(f"Correct samples modified: {correct_count_modified}")

    return unstable_indices


In [49]:
layer_wrong_idx_dict = {}
for layer_to_remove in range(len(model.distilbert.transformer.layer)):
    print(f"Removing layer {layer_to_remove}")
    layer_wrong_idx_dict[layer_to_remove] = test_hypothesis(model, dataloader, device, [layer_to_remove])
    print(layer_wrong_idx_dict[layer_to_remove])

Removing layer 0
Stable samples: 751
Total samples: 872
Stability rate: 0.8612
Correct samples: 794
Correct samples modified: 725
[33, 39, 46, 62, 64, 69, 76, 83, 84, 88, 97, 98, 102, 107, 115, 118, 122, 142, 149, 154, 157, 161, 172, 183, 184, 186, 193, 195, 196, 203, 215, 220, 221, 224, 235, 236, 243, 249, 272, 274, 276, 279, 285, 292, 315, 322, 323, 332, 354, 356, 394, 395, 400, 411, 420, 422, 428, 434, 435, 445, 447, 448, 454, 456, 462, 467, 477, 485, 490, 517, 519, 520, 524, 528, 554, 558, 579, 604, 606, 612, 617, 624, 626, 632, 633, 634, 643, 652, 667, 671, 678, 684, 691, 735, 741, 742, 753, 756, 760, 765, 766, 770, 771, 782, 784, 787, 790, 793, 801, 812, 823, 824, 830, 831, 832, 843, 847, 850, 862, 863, 864]
Removing layer 1
Stable samples: 800
Total samples: 872
Stability rate: 0.9174
Correct samples: 794
Correct samples modified: 768
[13, 22, 33, 44, 62, 64, 66, 73, 83, 106, 118, 135, 139, 172, 184, 186, 192, 194, 196, 200, 201, 205, 213, 218, 219, 243, 249, 267, 323, 326, 354,

In [56]:
# find indices that in the intersection of all lists
indices_to_remove = set(layer_wrong_idx_dict[0])
for layer_wrong_idx in list(layer_wrong_idx_dict.values()):
    indices_to_remove &= set(layer_wrong_idx)
len(indices_to_remove)

4

set()

In [27]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 2)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [2, 0]
Stable samples: 738
Total samples: 872
Stability rate: 0.8463302752293578
Correct samples: 794
Correct samples modified: 716
Removing layers [2, 3]
Stable samples: 707
Total samples: 872
Stability rate: 0.8107798165137615
Correct samples: 794
Correct samples modified: 675
Removing layers [5, 4]
Stable samples: 675
Total samples: 872
Stability rate: 0.7740825688073395
Correct samples: 794
Correct samples modified: 679
Removing layers [2, 3]
Stable samples: 707
Total samples: 872
Stability rate: 0.8107798165137615
Correct samples: 794
Correct samples modified: 675
Removing layers [4, 2]
Stable samples: 790
Total samples: 872
Stability rate: 0.9059633027522935
Correct samples: 794
Correct samples modified: 752


In [28]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 3)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [4, 3, 0]
Stable samples: 650
Total samples: 872
Stability rate: 0.7454128440366973
Correct samples: 794
Correct samples modified: 646
Removing layers [2, 0, 4]
Stable samples: 726
Total samples: 872
Stability rate: 0.8325688073394495
Correct samples: 794
Correct samples modified: 698
Removing layers [5, 2, 0]
Stable samples: 716
Total samples: 872
Stability rate: 0.8211009174311926
Correct samples: 794
Correct samples modified: 704
Removing layers [2, 5, 3]
Stable samples: 643
Total samples: 872
Stability rate: 0.7373853211009175
Correct samples: 794
Correct samples modified: 619
Removing layers [3, 1, 0]
Stable samples: 695
Total samples: 872
Stability rate: 0.7970183486238532
Correct samples: 794
Correct samples modified: 677


In [29]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 6)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [3, 5, 0, 4, 2, 1]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275229357798165
Correct samples: 794
Correct samples modified: 444
Removing layers [5, 4, 2, 0, 1, 3]


KeyboardInterrupt: 

In [30]:
remove_layer(model, random.sample(range(len(model.distilbert.transformer.layer)), 6))

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList()
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [31]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 