<a href="https://colab.research.google.com/github/digwit678/Can-Language-Models-Follow-Discussions/blob/main/Probing_POS_Example_with_Training_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers



In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
import spacy

In [None]:
# Pre-trained uncased BERT-base and its matching WordPiece tokenizer
# (downloaded from the Hugging Face hub, or read from the local cache).
tokenizer, model = (
    BertTokenizer.from_pretrained('bert-base-uncased'),  # text -> input_ids / attention_mask
    BertModel.from_pretrained('bert-base-uncased'),      # encoder whose hidden states are probed
)

# Training

In [None]:
# spaCy's small English pipeline supplies the reference POS tags for the probe.
nlp = spacy.load("en_core_web_sm")


def extract_pos_labels(text):
    """Tag ``text`` with spaCy and return one coarse POS label per spaCy token.

    Args:
        text: raw input string.

    Returns:
        list of ``Token.pos_`` strings (e.g. ``'DET'``, ``'NOUN'``) in token order.
    """
    return [token.pos_ for token in nlp(text)]

# Basic linear probing classifier, trained on the POS task over BERT features
class PosClassifier(torch.nn.Module):
    """Single linear layer mapping a hidden vector to POS-tag logits.

    Args:
        input_size: dimensionality of the input features (BERT's hidden size here).
        num_labels: number of POS classes.
    """

    def __init__(self, input_size, num_labels):
        super(PosClassifier, self).__init__()
        # Keep the attribute name `fc`: it appears in state_dict keys and the module repr.
        self.fc = torch.nn.Linear(in_features=input_size, out_features=num_labels)

    def forward(self, x):
        """Return unnormalised logits with last dimension ``num_labels``."""
        logits = self.fc(x)
        return logits

# Universal POS tags reported by spaCy's `Token.pos_` (plus spaCy's SPACE tag).
# A fixed inventory gives every `prepare_data` call the same tag -> id mapping,
# so loaders built from different sentence sets (train vs. test) agree on ids.
UPOS_TAGS = (
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
    "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X", "SPACE",
)


# prepare data for training / testing the probing classifier
def prepare_data(sentences, labels, batch_size, label2id=None, tokenize_fn=None):
    """Build a shuffled DataLoader of ``(input_ids, attention_mask, pos_label_ids)``.

    Args:
        sentences: list of raw sentence strings.
        labels: one list of POS tag strings per sentence (e.g. from
            ``extract_pos_labels``); must have the same length as ``sentences``.
        batch_size: DataLoader batch size.
        label2id: optional explicit tag -> id mapping (ids expected to be
            ``0 .. len(label2id) - 1``). When omitted, ids follow ``UPOS_TAGS``,
            with any tag outside it appended in sorted order, so the mapping does
            not depend on which tags happen to occur in ``labels``.
        tokenize_fn: optional tokenizer callable with the Hugging Face call
            signature; defaults to the global BERT ``tokenizer``.

    Returns:
        ``(data_loader, num_labels)`` where ``num_labels == len(mapping)``.

    Raises:
        ValueError: on empty input, mismatched lengths, or a tag missing from an
            explicit ``label2id``.

    Note:
        The label tensor is word-level (one id per spaCy token, padded with -1);
        it is NOT aligned to BERT's word-piece positions.
    """
    if len(sentences) == 0:
        raise ValueError("prepare_data needs at least one sentence")
    if len(sentences) != len(labels):
        raise ValueError(
            f"got {len(sentences)} sentences but {len(labels)} label sequences")

    # Convert sentences to BERT input format
    tok = tokenizer if tokenize_fn is None else tokenize_fn
    inputs = tok(sentences, return_tensors='pt', padding=True, truncation=True)

    # Convert part-of-speech labels to ids with a data-independent mapping
    seen_tags = {tag for seq in labels for tag in seq}
    if label2id is None:
        extra = sorted(seen_tags - set(UPOS_TAGS))
        label2id = {tag: i for i, tag in enumerate((*UPOS_TAGS, *extra))}
    else:
        unknown = sorted(seen_tags - set(label2id))
        if unknown:
            raise ValueError(f"tags missing from label2id: {unknown}")

    # dtype=torch.long keeps empty sequences integer-typed for pad_sequence/cross_entropy
    label_ids = [torch.tensor([label2id[tag] for tag in seq], dtype=torch.long)
                 for seq in labels]
    # -1 marks padding; the training loss uses ignore_index=-1
    padded = torch.nn.utils.rnn.pad_sequence(label_ids, batch_first=True, padding_value=-1)

    # Create data loader
    data = list(zip(inputs['input_ids'], inputs['attention_mask'], padded))
    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

    return data_loader, len(label2id)

# training loop for the probing classifier
def train_probing_classifier(probing_classifier, data_loader, optimizer, num_epochs, encoder=None):
    """Fit ``probing_classifier`` on encoder features computed without gradients.

    Each batch is ``(input_ids, attention_mask, labels)`` as produced by
    ``prepare_data``. The encoder runs under ``torch.no_grad()``, so only the
    parameters held by ``optimizer`` are updated.

    NOTE(review): only the hidden state at position 0 (the [CLS] token for BERT
    inputs) and the label of the first word (``labels[:, 0]``) are used, so this
    probes the POS of the sentence-initial word from the [CLS] vector -- not
    token-level POS tagging.

    Args:
        probing_classifier: module mapping hidden vectors to POS logits.
        data_loader: iterable of ``(input_ids, attention_mask, labels)`` batches.
        optimizer: optimizer over the probe's parameters.
        num_epochs: number of passes over ``data_loader``.
        encoder: callable ``encoder(input_ids, attention_mask)`` returning an
            object with ``.last_hidden_state``; defaults to the global BERT ``model``.

    Returns:
        List with the mean training loss of each epoch (``nan`` for an epoch in
        which every batch was skipped).
    """
    encoder = model if encoder is None else encoder
    probing_classifier.train()
    epoch_losses = []
    for _ in range(num_epochs):
        batch_losses = []
        for input_ids, attention_mask, labels in data_loader:
            targets = labels[:, 0]
            # If every target equals ignore_index the mean loss is 0/0 = nan, and
            # stepping on it would corrupt the probe's weights: skip such batches.
            if (targets == -1).all():
                continue

            # Encoder features without gradients: BERT itself is not trained
            with torch.no_grad():
                features = encoder(input_ids, attention_mask).last_hidden_state[:, 0]

            # Probing classifier predictions and loss
            logits = probing_classifier(features)
            loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

            # Update POS-classifier
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        epoch_losses.append(
            sum(batch_losses) / len(batch_losses) if batch_losses else float("nan"))
    return epoch_losses

# The training example runs in its own cell below, so that this cell only
# defines the data-preparation, model and training utilities.


In [None]:
# Training Example: tag two sentences with spaCy and fit the probe on BERT features.
torch.manual_seed(0)  # reproducible probe initialisation and DataLoader shuffling

sentences = ["The quick brown fox jumps over the lazy dog.", "I love eating pizza."]
pos_labels = [extract_pos_labels(sentence) for sentence in sentences]
data_loader, num_labels = prepare_data(sentences, pos_labels, batch_size=2)
probing_classifier = PosClassifier(model.config.hidden_size, num_labels)
optimizer = torch.optim.Adam(probing_classifier.parameters())
train_probing_classifier(probing_classifier, data_loader, optimizer, num_epochs=10)

print("Starting Sentence: \n", sentences)
# These are spaCy's reference POS tags, not BERT embeddings.
print("spaCy POS labels: \n", pos_labels)

Starting Sentence: 
 ['The quick brown fox jumps over the lazy dog.', 'I love eating pizza.']
Sentence embedded: 
 [['DET', 'ADJ', 'ADJ', 'NOUN', 'VERB', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT'], ['PRON', 'VERB', 'VERB', 'NOUN', 'PUNCT']]


In [None]:
print(repr(probing_classifier))  # module repr lists the probe's layers and their sizes

PosClassifier(
  (fc): Linear(in_features=768, out_features=7, bias=True)
)


In [None]:
print(repr(optimizer))  # shows the optimizer's hyper-parameters per parameter group

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)


# Testing

In [None]:
# Define a test-loop for the probing classifier: can he predict the POS labels of this test data if he uses BERTs learned embeddings ?

def test_probing_classifier(probing_classifier, test_data_loader):
  # input is the classifier trained on bert and the True values in the test data loader
    correct = 0
    total = 0
    with torch.no_grad(): # no learning for probing
        for batch in test_data_loader:
            input_ids, attention_mask, labels = batch
            encoder_output = model(input_ids, attention_mask).last_hidden_state[:, 0] # convert data to BERTs encoders representation
            logits = probing_classifier(encoder_output) # forward pass of the encoded data through the trained classifier
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0) # count of all labelled sentences
            correct += (predicted == labels[:, 0]).sum().item() # POS labels the classifier correctly identified with BERTs encoder

    accuracy = 100 * correct / total
    return accuracy

# Two unseen sentences; their reference ("true") POS tags come from spaCy.
test_sentences = ["The cat sat on the mat.", "She enjoys reading books."]
test_pos_labels = list(map(extract_pos_labels, test_sentences))
print("labels extracted for testing", test_pos_labels)
test_data_loader, _ = prepare_data(test_sentences, test_pos_labels, batch_size=2)

# Score the trained probe on BERT's encoder features
accuracy = test_probing_classifier(probing_classifier, test_data_loader)

# Report the verdict; the 80% threshold is an ad-hoc, task-dependent choice.
print(
    ("The model has learned to recognize POS labels with an accuracy of {:.2f}%."
     if accuracy > 80
     else "The model has not learned to recognize POS labels effectively. Accuracy: {:.2f}%.")
    .format(accuracy)
)


labels extracted for testing [['DET', 'NOUN', 'VERB', 'ADP', 'DET', 'NOUN', 'PUNCT'], ['PRON', 'VERB', 'VERB', 'NOUN', 'PUNCT']]
The model has learned to recognize POS labels with an accuracy of 100.00%.


# Whole Process

In [None]:
# End-to-end run in one cell: every loader is rebuilt here from this cell's own
# sentences, so nothing depends on loaders left over from earlier cells.

# Load pre-trained BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
torch.manual_seed(0)  # reproducible probe initialisation and DataLoader shuffling

# Prepare sentences and their corresponding POS labels for training
sentences = ["The quick brown fox jumps over the lazy dog.", "I love eating pizza."]
pos_labels = [extract_pos_labels(sentence) for sentence in sentences]
data_loader, num_labels = prepare_data(sentences, pos_labels, batch_size=2)

# Initialize the POS classifier and train it
probing_classifier = PosClassifier(model.config.hidden_size, num_labels)
optimizer = torch.optim.Adam(probing_classifier.parameters())
train_probing_classifier(probing_classifier, data_loader, optimizer, num_epochs=10)

# Prepare test sentences and their corresponding POS labels.
# NOTE(review): the test loader must encode tags with the same tag -> id mapping
# as the training loader, otherwise the accuracy below is meaningless.
test_sentences = ["The cat sat on the mat.", "She enjoys reading books."]
test_pos_labels = [extract_pos_labels(sentence) for sentence in test_sentences]
test_data_loader, _ = prepare_data(test_sentences, test_pos_labels, batch_size=2)

# Probe the classifier and report the result
accuracy = test_probing_classifier(probing_classifier, test_data_loader)
print("Probing accuracy on the test sentences: {:.2f}%".format(accuracy))
