<a href="https://colab.research.google.com/github/edenlum/DynamicTransformer/blob/main/Dynamic_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Hypothesis 1:

All layers work on roughly the same space, that is, their inputs and outputs are tensors from the same distribution.

If that is true, we can change the order of the layers and they might still make sense. We can also skip some layers and that might make sense.

## Hypothesis 2:

Not all layers are used on all inputs. In other words, there are inputs for which we can skip some of the layers, and the output will not change by much. This is supported by the "circuits" theory where on some tasks you can find a circuit inside the transformer that is made out of a subset of the transformer layers.

In [1]:
!pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 w

In [2]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = DistilBertForSequenceClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [3]:
from datasets import load_dataset

dataset = load_dataset("glue", "sst2", split="validation")
texts = dataset["sentence"]
labels = dataset["label"]

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [4]:
import torch
import random
from torch.utils.data import DataLoader, TensorDataset
import copy

def encode_texts(texts, tokenizer, max_length=512):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=max_length)
    return inputs.input_ids, inputs.attention_mask

def prepare_dataloader(texts, labels, tokenizer, batch_size=32):
    input_ids, attention_mask = encode_texts(texts, tokenizer)
    dataset = TensorDataset(input_ids, attention_mask, torch.tensor(labels))
    return DataLoader(dataset, batch_size=batch_size)

def evaluate_sample(model, input_id, attention_mask, device):
    model.to(device)
    model.eval()
    with torch.no_grad():
        output = model(input_id.unsqueeze(0).to(device), attention_mask=attention_mask.unsqueeze(0).to(device))
    return torch.argmax(output.logits, dim=-1).item()

In [5]:
import torch.nn as nn

def remove_layer(model, layers_to_remove):
    """Removes the specified layers from the model."""
    if not isinstance(layers_to_remove, list):
        layers_to_remove = [layers_to_remove]
    modified_model = copy.deepcopy(model)
    modified_model.distilbert.transformer.layer = nn.ModuleList(
        [layer for i, layer in enumerate(modified_model.distilbert.transformer.layer) if i not in layers_to_remove]
    )
    return modified_model

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = prepare_dataloader(texts, labels, tokenizer)

def test_hypothesis(model, dataloader, device, layer_to_remove):
    stable_count = 0
    total_count = 0
    correct_count = 0
    correct_count_modified = 0
    unstable_indices = []

    for batch_idx, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        for i in range(input_ids.size(0)):
            original_output = evaluate_sample(model, input_ids[i], attention_mask[i], device)
            if original_output == labels[i].item():
                correct_count += 1

            modified_model = remove_layer(model, layer_to_remove)
            modified_output = evaluate_sample(modified_model, input_ids[i], attention_mask[i], device)
            if modified_output == labels[i].item():
                correct_count_modified += 1

            if original_output == modified_output:
                stable_count += 1
            else:
                unstable_indices.append(batch_idx * dataloader.batch_size + i)

            total_count += 1

    print(f"Stable samples: {stable_count}")
    print(f"Total samples: {total_count}")
    print(f"Stability rate: {stable_count / total_count:.4f}")
    print(f"Correct samples: {correct_count}")
    print(f"Correct samples modified: {correct_count_modified}")

    return unstable_indices


In [7]:
layer_wrong_idx_dict = {}
for layer_to_remove in range(len(model.distilbert.transformer.layer)):
    print(f"Removing layer {layer_to_remove}")
    layer_wrong_idx_dict[layer_to_remove] = test_hypothesis(model, dataloader, device, [layer_to_remove])
    print(layer_wrong_idx_dict[layer_to_remove])

Removing layer 0
Stable samples: 751
Total samples: 872
Stability rate: 0.8612
Correct samples: 794
Correct samples modified: 725
[33, 39, 46, 62, 64, 69, 76, 83, 84, 88, 97, 98, 102, 107, 115, 118, 122, 142, 149, 154, 157, 161, 172, 183, 184, 186, 193, 195, 196, 203, 215, 220, 221, 224, 235, 236, 243, 249, 272, 274, 276, 279, 285, 292, 315, 322, 323, 332, 354, 356, 394, 395, 400, 411, 420, 422, 428, 434, 435, 445, 447, 448, 454, 456, 462, 467, 477, 485, 490, 517, 519, 520, 524, 528, 554, 558, 579, 604, 606, 612, 617, 624, 626, 632, 633, 634, 643, 652, 667, 671, 678, 684, 691, 735, 741, 742, 753, 756, 760, 765, 766, 770, 771, 782, 784, 787, 790, 793, 801, 812, 823, 824, 830, 831, 832, 843, 847, 850, 862, 863, 864]
Removing layer 1
Stable samples: 800
Total samples: 872
Stability rate: 0.9174
Correct samples: 794
Correct samples modified: 768
[13, 22, 33, 44, 62, 64, 66, 73, 83, 106, 118, 135, 139, 172, 184, 186, 192, 194, 196, 200, 201, 205, 213, 218, 219, 243, 249, 267, 323, 326, 354,

## More tests

In [8]:
# find indices that in the intersection of all lists
indices_to_remove = set(layer_wrong_idx_dict[0])
for layer_wrong_idx in list(layer_wrong_idx_dict.values()):
    indices_to_remove &= set(layer_wrong_idx)
len(indices_to_remove)

4

In [9]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 2)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [1, 4]
Stable samples: 782
Total samples: 872
Stability rate: 0.8968
Correct samples: 794
Correct samples modified: 746
Removing layers [3, 5]
Stable samples: 734
Total samples: 872
Stability rate: 0.8417
Correct samples: 794
Correct samples modified: 724
Removing layers [2, 1]
Stable samples: 754
Total samples: 872
Stability rate: 0.8647
Correct samples: 794
Correct samples modified: 736
Removing layers [4, 3]
Stable samples: 747
Total samples: 872
Stability rate: 0.8567
Correct samples: 794
Correct samples modified: 735
Removing layers [1, 0]
Stable samples: 719
Total samples: 872
Stability rate: 0.8245
Correct samples: 794
Correct samples modified: 707


In [10]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 3)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [5, 4, 3]
Stable samples: 645
Total samples: 872
Stability rate: 0.7397
Correct samples: 794
Correct samples modified: 645
Removing layers [0, 4, 1]
Stable samples: 726
Total samples: 872
Stability rate: 0.8326
Correct samples: 794
Correct samples modified: 708
Removing layers [4, 3, 2]
Stable samples: 692
Total samples: 872
Stability rate: 0.7936
Correct samples: 794
Correct samples modified: 668
Removing layers [5, 4, 1]
Stable samples: 678
Total samples: 872
Stability rate: 0.7775
Correct samples: 794
Correct samples modified: 654
Removing layers [4, 5, 0]
Stable samples: 559
Total samples: 872
Stability rate: 0.6411
Correct samples: 794
Correct samples modified: 571


In [11]:
for _ in range(5):
    layers_to_remove = random.sample(range(len(model.distilbert.transformer.layer)), 6)
    print(f"Removing layers {layers_to_remove}")
    test_hypothesis(model, dataloader, device, layers_to_remove)

Removing layers [1, 2, 4, 3, 5, 0]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275
Correct samples: 794
Correct samples modified: 444
Removing layers [2, 3, 5, 4, 1, 0]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275
Correct samples: 794
Correct samples modified: 444
Removing layers [4, 3, 2, 0, 5, 1]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275
Correct samples: 794
Correct samples modified: 444
Removing layers [5, 3, 2, 1, 0, 4]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275
Correct samples: 794
Correct samples modified: 444
Removing layers [5, 2, 0, 3, 4, 1]
Stable samples: 460
Total samples: 872
Stability rate: 0.5275
Correct samples: 794
Correct samples modified: 444


In [12]:
remove_layer(model, random.sample(range(len(model.distilbert.transformer.layer)), 6))

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList()
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [13]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

## Training Router


In [35]:
# Modify the network and add the router to choose order of layers
# Define the router model

class Router(nn.Module):
    def __init__(self, input_dim, num_layers):
        super(Router, self).__init__()
        self.fc = nn.Linear(input_dim, num_layers)

    def forward(self, x):
        logits = self.fc(x)
        probs = torch.softmax(logits, dim=-1)
        return probs

# Define the custom block
class RouterBlock(nn.Module):
    def __init__(self, original_layers, input_dim, num_layers):
        super(RouterBlock, self).__init__()
        self.router = Router(input_dim, num_layers)
        self.original_layers = original_layers
        self.num_layers = num_layers

        for layer in self.original_layers:
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, hidden_states, attention_mask, *args, **kwargs):
        router_probs = self.router(hidden_states[:, 0, :])  # Take [CLS] token representation
        max_prob, selected_layer = torch.max(router_probs, dim=-1)

        batch_size = hidden_states.size(0)
        new_hidden_states = torch.zeros_like(hidden_states)

        for i in range(batch_size):
            layer_output = self.original_layers[selected_layer[i]](hidden_states[i].unsqueeze(0), attention_mask[i].unsqueeze(0))[0]
            new_hidden_states[i] = max_prob[i].unsqueeze(0).unsqueeze(0) * layer_output

        return (new_hidden_states,)

# add a router to the model
def add_router(model):
    modified_model = copy.deepcopy(model)
    num_layers = len(modified_model.distilbert.transformer.layer)
    input_dim = modified_model.distilbert.transformer.layer[0].attention.q_lin.in_features
    router_block = RouterBlock(modified_model.distilbert.transformer.layer, input_dim, num_layers)
    modified_model.distilbert.transformer.layer = nn.ModuleList([router_block] * num_layers)
    return modified_model

routed_model = add_router(model)

In [36]:
from torch.utils.data import DataLoader, TensorDataset

# Load the training and validation datasets
train_dataset = load_dataset("glue", "sst2", split="train")
validation_dataset = load_dataset("glue", "sst2", split="validation")

# Encode the texts
train_texts = train_dataset["sentence"]
train_labels = train_dataset["label"]
validation_texts = validation_dataset["sentence"]
validation_labels = validation_dataset["label"]

encoded_train_texts = encode_texts(train_texts, tokenizer)
encoded_validation_texts = encode_texts(validation_texts, tokenizer)

# Create TensorDatasets
train_labels_tensor = torch.tensor(train_labels)
validation_labels_tensor = torch.tensor(validation_labels)

train_dataset = TensorDataset(encoded_train_texts[0], encoded_train_texts[1], train_labels_tensor)
validation_dataset = TensorDataset(encoded_validation_texts[0], encoded_validation_texts[1], validation_labels_tensor)

# Create DataLoaders
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [37]:
from tqdm import tqdm

def eval(model, validatoin_dataloader, epoch):
    # Evaluate on the validation set
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch in tqdm(validation_dataloader):
            input_ids, attention_mask, labels = [x.to(device) for x in batch]
            outputs = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs.logits, dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_accuracy = val_correct / val_total
    print(f"Validation Accuracy after epoch {epoch + 1}: {val_accuracy:.4f}")


# Define the training function
def train_router_model(model, train_dataloader, validation_dataloader, epochs=3, lr=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for batch in tqdm(train_dataloader):
            input_ids, attention_mask, labels = [x.to(device) for x in batch]

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.logits, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        avg_loss = total_loss / len(train_dataloader)

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        eval(model, validation_dataloader, epoch)
        model.train()

# Train the router model
train_router_model(routed_model, train_dataloader, validation_dataloader, epochs=3, lr=1e-4)

  4%|▍         | 172/4210 [00:51<19:58,  3.37it/s]


KeyboardInterrupt: 

In [None]:
eval(routed_model, validation_dataloader, 1)
eval(model, validation_dataloader, 1)