In [1]:
from transformers import BertTokenizer, BertLMHeadModel, AdamW

In [2]:
from torch.utils.data import Dataset, DataLoader

In [3]:
import torch

In [25]:
import torch.nn.functional as F

In [13]:
import sys
sys.path.append("..")

In [46]:
import os

In [14]:
from stats import AverageMeterSet, StatTracker

### Create Dataset

In [4]:
class CHILDESDataset(Dataset):
    def __init__(self, file_path):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
        
        self.sentences = []
        with open(file_path, "r") as f:
            for line in f:
                text = line.strip()
                self.sentences.append(text)
        encoded_data = self.tokenizer(self.sentences, return_tensors='pt', padding=True, truncation=True)
        
        self.input_ids = encoded_data['input_ids']
        self.token_type_ids = encoded_data['token_type_ids']
        self.attention_mask = encoded_data['attention_mask']

        
    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        batch_dict = {'input_ids': self.input_ids[index], 
                      'token_type_ids': self.token_type_ids[index], 
                      'attention_mask': self.attention_mask[index],
                      'labels': self.input_ids[index]}

        return batch_dict

### Create Dataloaders

In [5]:
file_path = "../../../Data/model-sets/toy_train.txt"
batch_size = 10

In [6]:
train_dataset = CHILDESDataset(file_path)
train_dl = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)

### Train function

In [61]:
def finetune_model(model, train_dataloader, val_dataloader, device, stat_tracker, n_epochs=30, lr=5e-5):
    optimizer = AdamW(model.parameters(), lr)
    for epoch in range(n_epochs):
        model.train()
        epoch_stats = AverageMeterSet()
        for batch in train_dataloader:
            optimizer.zero_grad()
            for key in batch:
                batch[key] = batch[key].to(device)
            outputs = model(**batch)
            loss = outputs[0]
            loss.backward()
            optimizer.step()
            epoch_stats.update('loss', loss, n=1)
        val_accuracy = test_finetuned_model(model, val_dataloader, device, stat_tracker, epoch, prefix="val")
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix="train"))
        print("val acc: "+ str(val_accuracy))
        print("loss :"+ str(epoch_stats.avgs['loss']))
    return model

### Test function

In [44]:
def test_finetuned_model(model, dataloader, device, stat_tracker, epoch=1, prefix='test'):
    model.eval()
    test_stats = AverageMeterSet()
    batch_size = dataloader.batch_size
    for batch in dataloader:
        for key in batch:
            batch[key] = batch[key].to(device)
        labels = batch['labels']
        outputs = model(**batch)
        distributions = F.log_softmax(outputs.logits, -1)
        predictions = torch.argmax(distributions, dim = -1)
        n_matches = torch.eq(predictions, labels).int()
        max_sequence_len = list(n_matches.size())[1]
        avg_sequence_accs = torch.sum(n_matches, 1)/max_sequence_len
        batch_accuracy = (torch.sum(avg_sequence_accs, 0)/avg_sequence_accs.size()[0]).item()
        test_stats.update('accuracy', batch_accuracy, n=1)
    stat_tracker.record_stats(test_stats.averages(epoch, prefix=prefix))
    
    return test_stats.avgs['accuracy']


### Finetune model

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
model = BertLMHeadModel.from_pretrained("bert-base-multilingual-uncased", return_dict=True, is_decoder = True)

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
model.to(device)

BertLMHeadModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(105879, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [55]:
stat_tracker = StatTracker(log_dir=os.path.join(".","tensorboard-log"))

log_dir: ./tensorboard-log


In [56]:
n_epochs = 3

In [63]:
lr = 1e-4

In [60]:
model = finetune_model(model, train_dl, train_dl, device, stat_tracker, n_epochs)

val acc: 0.6761111497879029
val acc: 0.6761111497879029
loss :1.8540790557861329
val acc: 0.6752020418643951
val acc: 0.6752020418643951
loss :1.8576019287109375
val acc: 0.6755050837993621
val acc: 0.6755050837993621
loss :1.8463512420654298


### Extract surprisal of words

In [None]:
class TokensDataset(Dataset):
    def __init__(self, data_file_path, words_file_path):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
        self.ids_to_words = {}
        self.ids_to_sentences = []
        with open(file_path, "r") as f:
            for line in f:
                text = line.strip()
                self.sentences.append(text)
        encoded_data = self.tokenizer(self.sentences, return_tensors='pt', padding=True, truncation=True)
        
        self.input_ids = encoded_data['input_ids']
        self.token_type_ids = encoded_data['token_type_ids']
        self.attention_mask = encoded_data['attention_mask']

        
    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        batch_dict = {'input_ids': self.input_ids[index], 
                      'token_type_ids': self.token_type_ids[index], 
                      'attention_mask': self.attention_mask[index],
                      'labels': self.input_ids[index]}

        return batch_dict