In [1]:
import json
import os
import subprocess

import senteval
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# One-time setup: clone SentEval next to this directory if it is missing.
# NOTE(review): `import senteval` in the import lines above runs before this clone, so on
# a fresh machine senteval must already be importable (e.g. pip-installed) — TODO confirm.
SENTEVAL_REPO = "https://github.com/facebookresearch/SentEval.git"  # HTTPS: no SSH key needed

fs = os.listdir("..")
if "SentEval" not in fs:
    # Clone straight into ../SentEval instead of %cd-ing around. This no longer assumes
    # the notebook lives in a directory called "notebooks". check=True makes a failed
    # clone raise, where `!git clone` would let the notebook continue silently.
    subprocess.run(["git", "clone", SENTEVAL_REPO, "../SentEval"], check=True)
    os.makedirs("data", exist_ok=True)

PATH_TO_DATA = "../SentEval/data/probing/"

# Keep the original choice of the second GPU, but fall back to the first GPU or the
# CPU so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# BERT: SentEval probing tasks on mean-pooled hidden states

In [2]:
def batch_to_device(d, device):
    """Return a new dict with every value of ``d`` moved to ``device``.

    ``d`` may be any mapping whose values expose ``.to(device)``, such as the
    tokenizer output used by ``Bert.batcher``. Keys keep their iteration order.
    """
    moved = {}
    for key, tensor in d.items():
        moved[key] = tensor.to(device)
    return moved
    
class MeanPooling(nn.Module):
    """Masked mean over the sequence axis of token embeddings.

    ``starting_state`` is not used by ``forward``. It is only stored so the owning
    encoder can read which hidden-state index to feed into this pooler.
    """

    def __init__(self, starting_state):
        super().__init__()
        self.starting_state = starting_state

    def forward(self, x, attention_mask):
        """Average ``x`` over dim 1, counting only positions where the mask is non-zero.

        Args:
            x: token embeddings, assumed ``(batch, seq, hidden)``. The mask gets a
                trailing axis and is expanded to ``x``'s shape.
            attention_mask: ``(batch, seq)`` mask, 1 for real tokens and 0 for padding.

        Returns:
            Tensor with dim 1 of ``x`` reduced away. A row whose mask is all zero
            yields zeros, because the token count is floored at 1e-9 instead of 0.
        """
        weights = attention_mask[..., None].float().expand_as(x)
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        return (x * weights).sum(dim=1) / token_count

    
class Bert:
    """SentEval encoder producing mean-pooled BERT hidden states.

    Args:
        starting_state: index into the model's ``hidden_states`` output whose
            token vectors are mean-pooled. Following the Hugging Face
            ``output_hidden_states`` convention, index 0 is the embedding output and
            12 is the last layer of a 12-layer model.
        path: optional path to a full model object saved with ``torch.save``.
            When ``None``, the pretrained ``model_name`` weights are loaded.
        device: torch device used for inference. Defaults to the notebook-level
            ``DEVICE`` constant at construction time.
        model_name: Hugging Face checkpoint name used for the tokenizer and, when
            ``path`` is ``None``, for the model.
    """

    def __init__(self, starting_state=12, path=None, device=None, model_name="bert-base-cased"):
        # Resolve the device once, so the model and the batches always agree even if
        # the global DEVICE is reassigned later.
        self.device = DEVICE if device is None else device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if path is None:
            model = AutoModel.from_pretrained(model_name)
        else:
            # torch.load unpickles arbitrary objects: only load trusted checkpoints.
            # weights_only=False is required to restore a full nn.Module on torch >= 2.6,
            # where the default became True. map_location="cpu" avoids failing when the
            # checkpoint was saved on a GPU index that does not exist on this machine.
            model = torch.load(path, map_location="cpu", weights_only=False)
        # A pickled module keeps the train/eval mode it was saved in. Force eval mode
        # so dropout is disabled and the extracted features are deterministic.
        self.model = model.to(self.device).eval()
        self.pooling = MeanPooling(starting_state)

    def prepare(self, params, samples):
        """SentEval per-task hook. Nothing needs to be precomputed for this encoder."""
        pass

    @torch.no_grad()
    def batcher(self, params, batch):
        """Encode one SentEval batch into sentence embeddings.

        Args:
            params: SentEval parameter object (unused here).
            batch: list of sentences, each already split into words. This is why
                the tokenizer is called with ``is_split_into_words=True``.

        Returns:
            CPU tensor with one mean-pooled embedding row per sentence.
        """
        # Mirror SentEval's reference encoders: an empty sentence becomes ['.'], so
        # the tokenizer always sees at least one word.
        batch = [words if words else ["."] for words in batch]
        tokenized_batch = self.tokenizer(
            batch, truncation=True, padding=True, return_tensors="pt", is_split_into_words=True
        )
        batch_device = batch_to_device(tokenized_batch, self.device)
        out = self.model(
            **batch_device, output_hidden_states=True
        ).hidden_states[self.pooling.starting_state]
        out_mean = self.pooling(out, batch_device["attention_mask"])
        return out_mean.cpu()

# SentEval configuration: linear probe (nhid=0, no hidden layer) trained with Adam.
params = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
    'classifier': {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, 'epoch_size': 4},
}

# The ten SentEval probing tasks evaluated in this notebook.
transfer_tasks = [
    'Length', 'WordContent', 'Depth', 'TopConstituents', 'BigramShift',
    'Tense', 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion',
]

# Encoder under test: fine-tuned checkpoint, mean-pooling hidden state 12.
bert = Bert(starting_state=12, path="../output/bert-base-cased/mean/12_to_13/model_2024_01_01_03_16.pkl")
se = senteval.engine.SE(params, bert.batcher, bert.prepare)

results = se.eval(transfer_tasks)
print(results)


2024-01-09 00:18:30.347945: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-09 00:18:30.347977: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-09 00:18:30.348005: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-09 00:18:30.353796: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


{'Length': {'devacc': 69.92, 'acc': 70.21, 'ndev': 9996, 'ntest': 9996}, 'WordContent': {'devacc': 57.66, 'acc': 57.61, 'ndev': 10000, 'ntest': 10000}, 'Depth': {'devacc': 31.21, 'acc': 31.33, 'ndev': 10000, 'ntest': 10000}, 'TopConstituents': {'devacc': 63.47, 'acc': 63.54, 'ndev': 10000, 'ntest': 10000}, 'BigramShift': {'devacc': 83.46, 'acc': 82.65, 'ndev': 10000, 'ntest': 10000}, 'Tense': {'devacc': 88.26, 'acc': 86.16, 'ndev': 10000, 'ntest': 10000}, 'SubjNumber': {'devacc': 82.22, 'acc': 80.76, 'ndev': 10000, 'ntest': 10000}, 'ObjNumber': {'devacc': 77.1, 'acc': 78.45, 'ndev': 10000, 'ntest': 10000}, 'OddManOut': {'devacc': 64.27, 'acc': 64.33, 'ndev': 10000, 'ntest': 10000}, 'CoordinationInversion': {'devacc': 66.7, 'acc': 67.01, 'ndev': 10002, 'ntest': 10002}}


# Fine-tuned model (hidden state 12): test accuracy per task

In [4]:
# Test accuracy per probing task for the fine-tuned model.
print(json.dumps({task + "_acc": scores["acc"] for task, scores in results.items()}, indent=4))

{
    "Length_acc": 70.21,
    "WordContent_acc": 57.61,
    "Depth_acc": 31.33,
    "TopConstituents_acc": 63.54,
    "BigramShift_acc": 82.65,
    "Tense_acc": 86.16,
    "SubjNumber_acc": 80.76,
    "ObjNumber_acc": 78.45,
    "OddManOut_acc": 64.33,
    "CoordinationInversion_acc": 67.01
}


# Pretrained bert-base-cased (no fine-tuning), hidden state 12

In [11]:
# Same SentEval configuration as the fine-tuned run, so the two results are comparable.
params = dict(
    task_path=PATH_TO_DATA,
    usepytorch=True,
    kfold=10,
    classifier=dict(nhid=0, optim='adam', batch_size=64, tenacity=5, epoch_size=4),
)

# Baseline encoder: pretrained weights (path=None), same hidden state 12.
bert2 = Bert(starting_state=12, path=None)
se2 = senteval.engine.SE(params, bert2.batcher, bert2.prepare)

results2 = se2.eval(transfer_tasks)
print(results2)

{'Length': {'devacc': 81.7, 'acc': 82.87, 'ndev': 9996, 'ntest': 9996}, 'WordContent': {'devacc': 59.79, 'acc': 59.3, 'ndev': 10000, 'ntest': 10000}, 'Depth': {'devacc': 37.26, 'acc': 37.9, 'ndev': 10000, 'ntest': 10000}, 'TopConstituents': {'devacc': 74.85, 'acc': 74.71, 'ndev': 10000, 'ntest': 10000}, 'BigramShift': {'devacc': 89.0, 'acc': 88.58, 'ndev': 10000, 'ntest': 10000}, 'Tense': {'devacc': 90.42, 'acc': 88.54, 'ndev': 10000, 'ntest': 10000}, 'SubjNumber': {'devacc': 85.61, 'acc': 84.56, 'ndev': 10000, 'ntest': 10000}, 'ObjNumber': {'devacc': 80.98, 'acc': 82.41, 'ndev': 10000, 'ntest': 10000}, 'OddManOut': {'devacc': 66.17, 'acc': 65.69, 'ndev': 10000, 'ntest': 10000}, 'CoordinationInversion': {'devacc': 69.43, 'acc': 68.98, 'ndev': 10002, 'ntest': 10002}}


In [12]:
# Test accuracy per probing task for the pretrained baseline.
print(json.dumps({task + "_acc": scores["acc"] for task, scores in results2.items()}, indent=4))

{
    "Length_acc": 82.87,
    "WordContent_acc": 59.3,
    "Depth_acc": 37.9,
    "TopConstituents_acc": 74.71,
    "BigramShift_acc": 88.58,
    "Tense_acc": 88.54,
    "SubjNumber_acc": 84.56,
    "ObjNumber_acc": 82.41,
    "OddManOut_acc": 65.69,
    "CoordinationInversion_acc": 68.98
}
