## Probing Bert part-of-speech tagging

In [5]:
import urllib.request

TRAIN_FILE = "en_ewt-ud-train.conllu"
EVAL_FILE = "en_ewt-ud-dev.conllu"
TEST_FILE = "en_ewt-ud-test.conllu"

for filename in [ TRAIN_FILE, EVAL_FILE, TEST_FILE ]:
  urllib.request.urlretrieve('https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/master/' + filename, filename)

In [6]:
import conllu

def load_conllu(filename):
  with open(filename, encoding="utf-8") as fp:
    data = conllu.parse(fp.read())
  sentences = [[token['form'] for token in sentence] for sentence in data]
  taggings = [[token['xpos'] for token in sentence] for sentence in data]
  return sentences, taggings

train_sentences, train_labels = load_conllu(TRAIN_FILE)
eval_sentences, eval_labels = load_conllu(EVAL_FILE)
test_sentences, test_labels = load_conllu(TEST_FILE)
print(list(zip(train_sentences[0], train_labels[0])))

[('Al', 'NNP'), ('-', 'HYPH'), ('Zaman', 'NNP'), (':', ':'), ('American', 'JJ'), ('forces', 'NNS'), ('killed', 'VBD'), ('Shaikh', 'NNP'), ('Abdullah', 'NNP'), ('al', 'NNP'), ('-', 'HYPH'), ('Ani', 'NNP'), (',', ','), ('the', 'DT'), ('preacher', 'NN'), ('at', 'IN'), ('the', 'DT'), ('mosque', 'NN'), ('in', 'IN'), ('the', 'DT'), ('town', 'NN'), ('of', 'IN'), ('Qaim', 'NNP'), (',', ','), ('near', 'IN'), ('the', 'DT'), ('Syrian', 'JJ'), ('border', 'NN'), ('.', '.')]


In [7]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
tokenizer.tokenize("This is interesting for suuuuuuuuuuuuuuuuuure")

['This',
 'is',
 'interesting',
 'for',
 'su',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##u',
 '##ure']

In [8]:
import re

def align_to_bert_tokenization(sentences, labels):
    tokenized_sentences = []
    aligned_labels = []

    for s, l in zip(sentences, labels):
        tokenized_sentence = tokenizer.tokenize(' '.join(s))
        aligned_label = []
        current_word = ''
        i = 0
        for token in tokenized_sentence:
            current_word += re.sub(r'^##', '', token)
            s[i] = s[i].replace('\xad', '')
            
            assert token == '[UNK]' or s[i].startswith(current_word)

            if token == '[UNK]' or s[i] == current_word:
                current_word = ''
                aligned_label.append(l[i])
                i += 1
            else:
                aligned_label.append('<pad>')
        
        assert len(tokenized_sentence) == len(aligned_label)

        tokenized_sentences.append(tokenized_sentence)
        aligned_labels.append(aligned_label)
    
    return tokenized_sentences, aligned_labels


train_bert_tokenized_sentences, train_aligned_taggings = align_to_bert_tokenization(train_sentences, train_labels)
eval_bert_tokenized_sentences, valid_aligned_taggings = align_to_bert_tokenization(eval_sentences, eval_labels)
test_bert_tokenized_sentences, test_aligned_taggings = align_to_bert_tokenization(test_sentences, test_labels)

print(train_bert_tokenized_sentences[42])
print(train_aligned_taggings[42])

['There', 'has', 'been', 'talk', 'that', 'the', 'night', 'cu', '##rf', '##ew', 'might', 'be', 'implemented', 'again', '.']
['EX', 'VBZ', 'VBN', 'NN', 'IN', 'DT', 'NN', '<pad>', '<pad>', 'NN', 'MD', 'VB', 'VBN', 'RB', '.']


In [10]:
import torch
device = torch.device('cpu')

import collections

label_vocab = collections.defaultdict(lambda: len(label_vocab))
label_vocab['<pad>'] = 0

def convert_to_ids(sentences, taggings):
  sentences_ids = []
  taggings_ids = []
  for sentence, tagging in zip(sentences, taggings):
    sentence_tensor = torch.tensor(tokenizer.convert_tokens_to_ids(['[CLS]'] + sentence + ['SEP'])).long()
    tagging_tensor = torch.tensor([0] + [label_vocab[tag] for tag in tagging] + [0]).long()

    sentences_ids.append(sentence_tensor.to(device))
    taggings_ids.append(tagging_tensor.to(device))
  return sentences_ids, taggings_ids

train_sentences_ids, train_taggings_ids = convert_to_ids(train_bert_tokenized_sentences, train_aligned_taggings)
eval_sentences_ids, eval_taggings_ids = convert_to_ids(eval_bert_tokenized_sentences, valid_aligned_taggings)
test_sentences_ids, test_taggings_ids = convert_to_ids(test_bert_tokenized_sentences, test_aligned_taggings)

print(train_sentences_ids[42])
print(train_taggings_ids[42])
print('num labels:', len(label_vocab))

tensor([  101,  1247,  1144,  1151,  2037,  1115,  1103,  1480, 16408, 11931,
         5773,  1547,  1129,  7042,  1254,   119,   100])
tensor([ 0, 30, 22, 19,  9, 10,  8,  9,  0,  0,  9, 13, 14, 19, 23, 11,  0])
num labels: 52


In [11]:
from torch.utils.data import Dataset

class PosTaggingDataset(Dataset):
  def __init__(self, sentences, taggings):
    assert len(sentences) == len(taggings)
    self.sentences = sentences
    self.taggings = taggings

  def __getitem__(self, i):
    return self.sentences[i], self.taggings[i]

  def __len__(self):
    return len(self.sentences)

In [12]:
from torch.utils.data import DataLoader

def collate_fn(items):
  max_len = max(len(item[0]) for item in items)

  sentences = torch.zeros((len(items), max_len), device=items[0][0].device).long().to(device)
  taggings = torch.zeros((len(items), max_len)).long().to(device)

  for i, (sentence, tagging) in enumerate(items):
    sentences[i][0:len(sentence)] = sentence
    taggings[i][0:len(tagging)] = tagging

  return sentences, taggings


batch_size = 64
train_loader = DataLoader(PosTaggingDataset(train_sentences_ids, train_taggings_ids), batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
eval_loader = DataLoader(PosTaggingDataset(eval_sentences_ids, eval_taggings_ids), batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(PosTaggingDataset(test_sentences_ids, test_taggings_ids), batch_size=batch_size, collate_fn=collate_fn)

In [13]:
import torch.nn as nn
import torch.nn.functional as F

class LinearProbeRandom(nn.Module):
  def __init__(self, num_labels):
    super().__init__()
    self.embedding = nn.Embedding(tokenizer.vocab_size, 768)
    self.probe = nn.Linear(768, num_labels)
    self.to(device)

  def parameters(self):
    return self.probe.parameters()
  
  def forward(self, sentences):
    with torch.no_grad():
      word_rep = self.embedding(sentences)
    return self.probe(word_rep)

random_model = LinearProbeRandom(len(label_vocab))
with torch.no_grad():
    y = random_model(torch.tensor([[0,1,2],[3,4,5]]).to(device))
print(y.shape)

torch.Size([2, 3, 52])


In [14]:
class LinearProbeBert(nn.Module):
  def __init__(self, num_labels):
    super().__init__()
    self.bert = AutoModel.from_pretrained('bert-base-cased')
    self.probe = nn.Linear(self.bert.config.hidden_size, num_labels)
    self.to(device)

  def parameters(self):
    return self.probe.parameters()
  
  def forward(self, sentences):
    with torch.no_grad():
      word_rep, sentence_rep = self.bert(sentences, return_dict=False)
    return self.probe(word_rep)

bert_model = LinearProbeBert(len(label_vocab))
y = bert_model(torch.tensor([[0, 1, 2], [3, 4, 5]]).to(device))
print(y.shape)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([2, 3, 52])


In [22]:
import torch.optim as optim

def perf(model, loader):
  criterion = nn.CrossEntropyLoss()
  model.eval() # do not apply training-specific steps such as dropout
  total_loss = correct = num_loss = num_perf = 0
  for x, y in loader:
    with torch.no_grad(): # no need to store computation graph for gradients
      # perform inference and compute loss
      y_scores = model(x)
      print(y_scores.shape)
      print(y.shape)
      one = y_scores.view(-1, len(label_vocab))
      two = y.view(-1)
      print(one)
      print(two)
      loss = criterion(y_scores.view(-1, len(label_vocab)), y.view(-1)) # requires tensors of shape (num-instances, num-labels) and (num-instances)

      # gather loss statistics
      total_loss += loss.item()
      num_loss += 1

      # gather accuracy statistics
      y_pred = torch.max(y_scores, 2)[1] # compute highest-scoring tag
      mask = (y != 0) # ignore <pad> tags
      correct += torch.sum((y_pred == y) * mask) # compute number of correct predictions
      num_perf += torch.sum(mask).item()
  return total_loss / num_loss, correct.item() / num_perf

def fit(model, epochs):
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=1e-2)
  for epoch in range(epochs):
    model.train()
    total_loss = num = 0
    for x, y in train_loader:
      optimizer.zero_grad() # start accumulating gradients
      y_scores = model(x)
      loss = criterion(y_scores.view(-1, len(label_vocab)), y.view(-1))
      loss.backward() # compute gradients though computation graph
      optimizer.step() # modify model parameters
      total_loss += loss.item()
      num += 1
    print(1 + epoch, total_loss / num, *perf(model, eval_loader))


fit(random_model, 5)
#fit(bert_model, 5)

torch.Size([64, 72, 52])
torch.Size([64, 72])
tensor([[  5.3464,  -1.3886,  -7.3429,  ..., -10.1562,  -8.8742, -12.8704],
        [  3.7534,   1.7354,  -3.7966,  ...,  -2.3750,  -2.5712,  -8.7468],
        [  4.8144,   1.6469,  -4.1465,  ...,  -1.2046,  -6.7123, -11.4561],
        ...,
        [  5.8665,  -3.7160, -11.7380,  ..., -16.1012, -10.4523, -19.0601],
        [  5.8665,  -3.7160, -11.7380,  ..., -16.1012, -10.4523, -19.0601],
        [  5.8665,  -3.7160, -11.7380,  ..., -16.1012, -10.4523, -19.0601]])
tensor([ 0, 10,  8,  ...,  0,  0,  0])


KeyboardInterrupt: 