<a href="https://colab.research.google.com/github/melodyjansen/AML-local-attention/blob/main/local_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers datasets --quiet

In [2]:
import torch
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, BertConfig
from torch import nn
from transformers.models.bert.modeling_bert import BertSelfAttention, BertLayer, BertEncoder, BertModel, BertAttention
from torch.utils.data import DataLoader, TensorDataset
from sklearn.utils import shuffle
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

In [3]:
# Load pretrained BERT and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [4]:
# Step 1: Custom attention class
class CustomBertSelfAttention(BertSelfAttention):
    def transpose_for_scores(self, x):
      new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
      x = x.view(new_x_shape)
      return x.permute(0, 2, 1, 3)
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        custom_scores_mask=None,
    ):
        # Standard BERT attention
        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        # Inject custom attention scores mask
        if custom_scores_mask is not None:
            attention_scores += custom_scores_mask

        if attention_mask is not None:
            # Expand attention_mask: (batch_size, 1, 1, seq_len)
            extended_mask = attention_mask[:, None, None, :]  # add head and query dims
            extended_mask = (1.0 - extended_mask) * -10000.0  # 0 --> 0, 1 --> -10000
            attention_scores = attention_scores + extended_mask


        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask


        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context = context_layer.view(context_layer.size(0), -1, self.all_head_size)

        return (new_context,)


In [5]:
class CustomBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = CustomBertSelfAttention(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        custom_scores_mask=None,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            custom_scores_mask=custom_scores_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        return attention_output


In [6]:
# Step 2: Patch model to use custom attention

class CustomBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = CustomBertAttention(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        custom_scores_mask=None,
    ):
        attention_output = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            custom_scores_mask=custom_scores_mask
        )
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


from transformers.models.bert.modeling_bert import BertEncoder

class CustomBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([CustomBertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        custom_scores_mask=None,
    ):
        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                custom_scores_mask=custom_scores_mask
            )
        return hidden_states

class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = CustomBertEncoder(config)

    def forward(self, input_ids, attention_mask=None, custom_scores_mask=None):
        embedding_output = self.embeddings(input_ids=input_ids)
        encoder_output = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )
        return encoder_output


# Train for sentiment analysis

In [7]:
class CustomBertForSequenceClassification(nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        self.bert = CustomBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, custom_scores_mask=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )
        # Grab [CLS] token (first token's output)
        cls_output = outputs[:, 0]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return loss, logits
        return logits


# Dataset stuff

In [8]:
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xzf aclImdb_v1.tar.gz


--2025-08-11 18:27:30--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’


2025-08-11 18:28:02 (2.50 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]



In [9]:
import os

def load_imdb_data(data_dir):
    texts = []
    labels = []
    for label_type in ['pos', 'neg']:
        dir_name = os.path.join(data_dir, label_type)
        for fname in os.listdir(dir_name):
            if fname.endswith('.txt'):
                with open(os.path.join(dir_name, fname), encoding='utf-8') as f:
                    texts.append(f.read())
                labels.append(1 if label_type == 'pos' else 0)
    return texts, labels

train_texts, train_labels = load_imdb_data('aclImdb/train')
test_texts, test_labels = load_imdb_data('aclImdb/test')

print(f'Training samples: {len(train_texts)}')
print(f'Testing samples: {len(test_texts)}')
print(train_texts[0], train_labels[0])


Training samples: 25000
Testing samples: 25000
If Dick Tracy was in black and white, the pope wouldn't be religious. Giving a new sense to the concept of color in a movie, we are offered an unique experience throughout a comic-strip world, and it's one of the few movies which succeeded in doing so, thanks to a serious script, good direction, great performances (Al Pacino is astonishing) and most importantly a powerful mix of cinematography, art direction and costume design. Using only primary colors, the experience is quite different from anything we have seen before. And there is also a quite successful hommage to all the gangster-movie genre, pratically extinct from modern cinema. Overall, I see this movie as a fresh attempt and a touch of originality to a cinema which relies more and more on the old and already-seen formulas. 7 out of 10. 1


In [10]:
def encode_texts(texts, labels, tokenizer, max_length=64):
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
    input_ids = torch.tensor(encodings['input_ids'])
    attention_mask = torch.tensor(encodings['attention_mask'])
    labels = torch.tensor(labels)
    return TensorDataset(input_ids, attention_mask, labels)

# Shuffle before taking subset
train_texts, train_labels = shuffle(train_texts, train_labels, random_state=42)
test_texts, test_labels = shuffle(test_texts, test_labels, random_state=42)

# Subset
train_texts = train_texts[:1000]
train_labels = train_labels[:1000]
test_texts = test_texts[:200]
test_labels = test_labels[:200]

train_dataset = encode_texts(train_texts, train_labels, tokenizer)
test_dataset = encode_texts(test_texts, test_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)


In [11]:
def create_local_attention_mask(seq_len, window_size, heads):
    mask = torch.full((1, heads, seq_len, seq_len), float('-inf'))
    for i in range(seq_len):
        mask[0, :, i, max(0, i-window_size):i+window_size+1] = 0
    return mask

## Window size = 3

In [12]:
# instantiate custom model with local attention
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = BertConfig.from_pretrained(model_name)
model = CustomBertForSequenceClassification(config, num_labels=2)  # binary classification
model.to(device)

# Load pretrained BERT weights
pretrained_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Copy weights to custom model
model = CustomBertForSequenceClassification(config, num_labels=2)
model.bert.embeddings.load_state_dict(pretrained_model.bert.embeddings.state_dict())
model.classifier.load_state_dict(pretrained_model.classifier.state_dict())

optimizer = optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

# Parameters
epochs = 3
batch_size = 8
window_size = 3  # for local attention
num_heads = 12

model.train()
for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        # Generate local attention mask per batch
        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask,
            labels=labels,
        )

        loss, logits = outputs
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} finished. Average loss: {avg_loss:.4f}")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 125/125 [13:20<00:00,  6.40s/it]


Epoch 1 finished. Average loss: 0.7849


100%|██████████| 125/125 [13:08<00:00,  6.31s/it]


Epoch 2 finished. Average loss: 0.6828


100%|██████████| 125/125 [12:52<00:00,  6.18s/it]

Epoch 3 finished. Average loss: 0.6498





In [14]:
import numpy as np
print("Train labels distribution:", np.bincount(train_labels))
print("Test labels distribution:", np.bincount(test_labels))


Train labels distribution: [489 511]
Test labels distribution: [ 95 105]


In [15]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )

        logits = outputs
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Metrics
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=["Negative", "Positive"]))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))


100%|██████████| 25/25 [01:03<00:00,  2.52s/it]


Classification Report:
              precision    recall  f1-score   support

    Negative       0.48      0.96      0.64        95
    Positive       0.64      0.07      0.12       105

    accuracy                           0.49       200
   macro avg       0.56      0.51      0.38       200
weighted avg       0.56      0.49      0.37       200

Confusion Matrix:
[[91  4]
 [98  7]]





## Window size = 15

In [16]:
# instantiate custom model with local attention
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = BertConfig.from_pretrained(model_name)
model = CustomBertForSequenceClassification(config, num_labels=2)  # binary classification
model.to(device)

# Load pretrained BERT weights
pretrained_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Copy weights to custom model
model = CustomBertForSequenceClassification(config, num_labels=2)
model.bert.embeddings.load_state_dict(pretrained_model.bert.embeddings.state_dict())
model.classifier.load_state_dict(pretrained_model.classifier.state_dict())

optimizer = optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

# Parameters
epochs = 3
batch_size = 8
window_size = 15  # for local attention
num_heads = 12

model.train()
for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        # Generate local attention mask per batch
        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask,
            labels=labels,
        )

        loss, logits = outputs
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} finished. Average loss: {avg_loss:.4f}")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 125/125 [12:32<00:00,  6.02s/it]


Epoch 1 finished. Average loss: 0.7649


100%|██████████| 125/125 [12:21<00:00,  5.93s/it]


Epoch 2 finished. Average loss: 0.6820


100%|██████████| 125/125 [12:21<00:00,  5.93s/it]

Epoch 3 finished. Average loss: 0.5464





In [17]:
print("Train labels distribution:", np.bincount(train_labels))
print("Test labels distribution:", np.bincount(test_labels))


Train labels distribution: [489 511]
Test labels distribution: [ 95 105]


In [18]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )

        logits = outputs
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Metrics
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=["Negative", "Positive"]))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Classification Report:
              precision    recall  f1-score   support

    Negative       0.49      0.96      0.65        95
    Positive       0.73      0.10      0.18       105

    accuracy                           0.51       200
   macro avg       0.61      0.53      0.42       200
weighted avg       0.62      0.51      0.41       200

Confusion Matrix:
[[91  4]
 [94 11]]





## Window size = 30

In [19]:
# instantiate custom model with local attention
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = BertConfig.from_pretrained(model_name)
model = CustomBertForSequenceClassification(config, num_labels=2)  # binary classification
model.to(device)

# Load pretrained BERT weights
pretrained_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Copy weights to custom model
model = CustomBertForSequenceClassification(config, num_labels=2)
model.bert.embeddings.load_state_dict(pretrained_model.bert.embeddings.state_dict())
model.classifier.load_state_dict(pretrained_model.classifier.state_dict())

optimizer = optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

# Parameters
epochs = 3
batch_size = 8
window_size = 30  # for local attention
num_heads = 12

model.train()
for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        # Generate local attention mask per batch
        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask,
            labels=labels,
        )

        loss, logits = outputs
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} finished. Average loss: {avg_loss:.4f}")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 125/125 [12:26<00:00,  5.97s/it]


Epoch 1 finished. Average loss: 0.8093


100%|██████████| 125/125 [12:16<00:00,  5.90s/it]


Epoch 2 finished. Average loss: 0.7093


100%|██████████| 125/125 [12:17<00:00,  5.90s/it]

Epoch 3 finished. Average loss: 0.6180





In [20]:
print("Train labels distribution:", np.bincount(train_labels))
print("Test labels distribution:", np.bincount(test_labels))


Train labels distribution: [489 511]
Test labels distribution: [ 95 105]


In [21]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )

        logits = outputs
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Metrics
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=["Negative", "Positive"]))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Classification Report:
              precision    recall  f1-score   support

    Negative       0.74      0.45      0.56        95
    Positive       0.63      0.86      0.73       105

    accuracy                           0.67       200
   macro avg       0.69      0.65      0.65       200
weighted avg       0.68      0.67      0.65       200

Confusion Matrix:
[[43 52]
 [15 90]]



