<a href="https://colab.research.google.com/github/melodyjansen/AML-local-attention/blob/main/local_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
!pip install transformers datasets --quiet

In [17]:
import torch
from transformers import BertModel, BertTokenizer

In [18]:
# Load pretrained BERT and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

In [19]:
from torch import nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertSelfAttention, BertLayer, BertEncoder, BertModel

# Step 1: Custom attention class
class CustomBertSelfAttention(BertSelfAttention):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        custom_scores_mask=None,
    ):
        # Standard BERT attention
        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        # Inject custom attention scores mask
        if custom_scores_mask is not None:
            attention_scores += custom_scores_mask

        if attention_mask is not None:
            # Expand attention_mask: (batch_size, 1, 1, seq_len)
            extended_mask = attention_mask[:, None, None, :]  # add head and query dims
            extended_mask = (1.0 - extended_mask) * -10000.0  # 0 --> 0, 1 --> -10000
            attention_scores = attention_scores + extended_mask


        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask


        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context = context_layer.view(context_layer.size(0), -1, self.all_head_size)

        return (new_context,)


In [20]:
from transformers.models.bert.modeling_bert import BertAttention

class CustomBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = CustomBertSelfAttention(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        custom_scores_mask=None,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            custom_scores_mask=custom_scores_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        return attention_output


In [21]:

# Step 2: Patch model to use custom attention
from transformers.models.bert.modeling_bert import BertLayer

class CustomBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = CustomBertAttention(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        custom_scores_mask=None,
    ):
        attention_output = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            custom_scores_mask=custom_scores_mask
        )
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


from transformers.models.bert.modeling_bert import BertEncoder

class CustomBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([CustomBertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        custom_scores_mask=None,
    ):
        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                custom_scores_mask=custom_scores_mask
            )
        return hidden_states


from transformers.models.bert.modeling_bert import BertModel

class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = CustomBertEncoder(config)

    def forward(self, input_ids, attention_mask=None, custom_scores_mask=None):
        embedding_output = self.embeddings(input_ids=input_ids)
        encoder_output = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )
        return encoder_output


# Train for sentiment analysis

In [23]:
import torch.nn as nn

class CustomBertForSequenceClassification(nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        self.bert = CustomBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, custom_scores_mask=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )
        # Grab [CLS] token (first token's output)
        cls_output = outputs[:, 0]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return loss, logits
        return logits


# Dataset stuff

In [24]:
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xzf aclImdb_v1.tar.gz


--2025-07-06 13:52:45--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz.1’


2025-07-06 13:52:53 (10.3 MB/s) - ‘aclImdb_v1.tar.gz.1’ saved [84125825/84125825]



In [25]:
import os

def load_imdb_data(data_dir):
    texts = []
    labels = []
    for label_type in ['pos', 'neg']:
        dir_name = os.path.join(data_dir, label_type)
        for fname in os.listdir(dir_name):
            if fname.endswith('.txt'):
                with open(os.path.join(dir_name, fname), encoding='utf-8') as f:
                    texts.append(f.read())
                labels.append(1 if label_type == 'pos' else 0)
    return texts, labels

train_texts, train_labels = load_imdb_data('aclImdb/train')
test_texts, test_labels = load_imdb_data('aclImdb/test')

print(f'Training samples: {len(train_texts)}')
print(f'Testing samples: {len(test_texts)}')
print(train_texts[0], train_labels[0])


Training samples: 25000
Testing samples: 25000
There are enough sad stories about women and their oppression by religious, political and societal means. Not to diminish the films and stories about genital mutilation and reproductive rights, as well as wage inequality, and marginalization in society, all in the name of Allah or God or some other ridiculous justification, but sometimes it is helpful to just take another approach and shed some light on the subject.<br /><br />The setting is the 2006 match between Iran and Bahrain to qualify for the World Cup. Passions are high and several women try to disguise themselves as men to get into the match.<br /><br />The women who were caught (Played by Sima Mobarak-Shahi, Shayesteh Irani, Ayda Sadeqi, Golnaz Farmani, and Mahnaz Zabihi) and detained for prosecution provided a funny and illuminating glimpse into the customs of this country and, most likely, all Muslim countries. Their interaction with the Iranian soldiers who were guarding and t

In [27]:
from torch.utils.data import DataLoader, TensorDataset
import torch
from sklearn.utils import shuffle

def encode_texts(texts, labels, tokenizer, max_length=64):
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
    input_ids = torch.tensor(encodings['input_ids'])
    attention_mask = torch.tensor(encodings['attention_mask'])
    labels = torch.tensor(labels)
    return TensorDataset(input_ids, attention_mask, labels)

# Shuffle before taking subset
train_texts, train_labels = shuffle(train_texts, train_labels, random_state=42)
test_texts, test_labels = shuffle(test_texts, test_labels, random_state=42)

# Subset
train_texts = train_texts[:1000]
train_labels = train_labels[:1000]
test_texts = test_texts[:200]
test_labels = test_labels[:200]

train_dataset = encode_texts(train_texts, train_labels, tokenizer)
test_dataset = encode_texts(test_texts, test_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)


In [28]:
def create_local_attention_mask(seq_len, window_size, heads):
    mask = torch.full((1, heads, seq_len, seq_len), float('-inf'))
    for i in range(seq_len):
        mask[0, :, i, max(0, i-window_size):i+window_size+1] = 0
    return mask

In [30]:
import torch.optim as optim
import torch.nn as nn

# instantiate custom model with local attention
config = BertConfig.from_pretrained(model_name)
model = CustomBertForSequenceClassification(config, num_labels=2)  # binary classification
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer
from tqdm import tqdm
import torch

# Parameters
epochs = 3
batch_size = 8
window_size = 3  # for local attention
num_heads = 12

model.train()
for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        # Generate local attention mask per batch
        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask,
            labels=labels,
        )

        loss, logits = outputs
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} finished. Average loss: {avg_loss:.4f}")



100%|██████████| 125/125 [12:43<00:00,  6.11s/it]


Epoch 1 finished. Average loss: 0.7759


100%|██████████| 125/125 [12:25<00:00,  5.97s/it]


Epoch 2 finished. Average loss: 0.7318


100%|██████████| 125/125 [12:15<00:00,  5.88s/it]

Epoch 3 finished. Average loss: 0.7074





In [31]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        seq_len = input_ids.size(1)
        custom_scores_mask = create_local_attention_mask(seq_len, window_size, num_heads).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            custom_scores_mask=custom_scores_mask
        )

        logits = outputs
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Metrics
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=["Negative", "Positive"]))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Classification Report:
              precision    recall  f1-score   support

    Negative       0.00      0.00      0.00        95
    Positive       0.53      1.00      0.69       105

    accuracy                           0.53       200
   macro avg       0.26      0.50      0.34       200
weighted avg       0.28      0.53      0.36       200

Confusion Matrix:
[[  0  95]
 [  0 105]]



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [32]:
print("Unique labels:", np.unique(all_labels))
print("Unique predictions:", np.unique(all_preds))


Unique labels: [0 1]
Unique predictions: [1]


In [33]:
print("Train labels distribution:", np.bincount(train_labels))
print("Test labels distribution:", np.bincount(test_labels))


Train labels distribution: [489 511]
Test labels distribution: [ 95 105]
