In [1]:
import json
import os
import subprocess

import senteval
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Clone SentEval into ../SentEval on first run.
# Pure Python instead of %cd / !git: the kernel's working directory is never
# changed (the old `%cd notebooks` assumed this directory's name), and the
# HTTPS URL works without a GitHub SSH key.
SENTEVAL_DIR = os.path.join("..", "SentEval")
if not os.path.isdir(SENTEVAL_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/facebookresearch/SentEval.git", SENTEVAL_DIR],
        check=True,  # surface a failed clone instead of continuing silently
    )
    # ./data next to the notebook, as the original setup created it.
    os.makedirs("data", exist_ok=True)

PATH_TO_DATA = "../SentEval/data/probing/"
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

Cloning into 'SentEval'...
remote: Enumerating objects: 691, done.[K
remote: Counting objects: 100% (2/2), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 691 (delta 0), reused 2 (delta 0), pack-reused 689[K
Receiving objects: 100% (691/691), 33.25 MiB | 11.32 MiB/s, done.
Resolving deltas: 100% (434/434), done.


# ELECTRA-base on the SentEval probing tasks

In [2]:
def batch_to_device(d, device):
    """Return a new dict with every tensor in ``d`` moved to ``device``.

    Accepts any mapping exposing ``.items()`` (here: the tokenizer's batch
    output). The input mapping itself is not modified.
    """
    moved = {}
    for name, tensor in d.items():
        moved[name] = tensor.to(device)
    return moved
    
class MeanPooling(nn.Module):
    """Masked mean over dimension 1 of a batch of token vectors.

    Also stores ``starting_state``, the hidden-state index that ``Electra.batcher``
    reads back to decide which layer to pool; ``forward`` itself never uses it.
    """

    def __init__(self, starting_state):
        super().__init__()
        self.starting_state = starting_state

    def forward(self, x, attention_mask):
        """Average the rows of ``x`` whose mask is non-zero.

        ``attention_mask`` is unsqueezed on its last axis and expanded to
        ``x.size()``, so it must match ``x`` on every axis but the last.
        Reduction is over dim 1, so for ``x`` of shape (batch, seq, hidden) the
        result is (batch, hidden). A fully masked row divides by 1e-9 instead of
        0, yielding zeros rather than NaN.
        """
        weights = attention_mask.unsqueeze(-1).expand(x.size()).float()
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        weighted_total = (x * weights).sum(dim=1)
        return weighted_total / token_count

class Electra:
    """SentEval wrapper that embeds sentences with an ELECTRA-base model.

    Sentences are tokenized with the ``google/electra-base-discriminator``
    tokenizer. One entry of the model's ``hidden_states`` (index
    ``starting_state``) is taken and mean-pooled over the attention mask.
    """

    def __init__(self, starting_state=12, path=None):
        """
        Args:
            starting_state: index into the model's ``hidden_states`` tuple to pool.
                NOTE(review): in transformers, ``hidden_states[0]`` is the
                embedding output, so 12 is the last layer of a 12-layer base
                model; confirm for custom checkpoints.
            path: optional path to a whole model object saved with ``torch.save``.
                When ``None``, the pretrained Hub checkpoint is loaded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
        if path is None:
            self.model = AutoModel.from_pretrained("google/electra-base-discriminator")
        else:
            # torch.load unpickles arbitrary objects: only load trusted checkpoints.
            # map_location lets a checkpoint saved on another device load here.
            # NOTE(review): torch>=2.6 defaults to weights_only=True, which
            # rejects whole pickled modules; pass weights_only=False there.
            self.model = torch.load(path, map_location=DEVICE)
        # A checkpoint pickled mid-training keeps train mode (dropout on), which
        # would make embeddings noisy and nondeterministic; always force eval.
        self.model = self.model.to(DEVICE).eval()
        self.pooling = MeanPooling(starting_state)

    def prepare(self, params, samples):
        """SentEval per-task hook; this encoder needs no preparation."""
        pass

    @torch.no_grad()
    def batcher(self, params, batch):
        """Embed one SentEval batch and return the pooled vectors on CPU.

        ``is_split_into_words=True`` means each element of ``batch`` is treated
        as an already-split sequence of words (presumably SentEval's token
        lists; confirm). Padding and truncation are applied per batch.
        """
        tokenized_batch = self.tokenizer(
            batch, truncation=True, padding=True, return_tensors="pt", is_split_into_words=True
        )
        batch_device = batch_to_device(tokenized_batch, DEVICE)
        out = self.model(
            **batch_device, output_hidden_states=True
        ).hidden_states[self.pooling.starting_state]
        # Mean over non-padding positions only.
        out_mean = self.pooling(out, batch_device["attention_mask"])
        return out_mean.cpu()

# Probe the fine-tuned checkpoint (hidden state 12) on the ten probing tasks below.
FINETUNED_CKPT = "../output/google-electra-base-discriminator/mean/12_to_13/model_2024_01_01_03_16.pkl"

params = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
    'classifier': {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, 'epoch_size': 4},
}

electra = Electra(starting_state=12, path=FINETUNED_CKPT)
se = senteval.engine.SE(params, electra.batcher, electra.prepare)

transfer_tasks = [
    'Length', 'WordContent', 'Depth', 'TopConstituents', 'BigramShift',
    'Tense', 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion',
]

results = se.eval(transfer_tasks)
print(results)


2024-01-08 23:59:35.905490: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-08 23:59:35.905525: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-08 23:59:35.905553: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-08 23:59:35.911319: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


{'Length': {'devacc': 84.24, 'acc': 82.69, 'ndev': 9996, 'ntest': 9996}, 'WordContent': {'devacc': 29.11, 'acc': 28.59, 'ndev': 10000, 'ntest': 10000}, 'Depth': {'devacc': 36.23, 'acc': 36.04, 'ndev': 10000, 'ntest': 10000}, 'TopConstituents': {'devacc': 70.32, 'acc': 69.59, 'ndev': 10000, 'ntest': 10000}, 'BigramShift': {'devacc': 92.9, 'acc': 92.41, 'ndev': 10000, 'ntest': 10000}, 'Tense': {'devacc': 86.67, 'acc': 85.26, 'ndev': 10000, 'ntest': 10000}, 'SubjNumber': {'devacc': 81.76, 'acc': 81.41, 'ndev': 10000, 'ntest': 10000}, 'ObjNumber': {'devacc': 79.34, 'acc': 80.09, 'ndev': 10000, 'ntest': 10000}, 'OddManOut': {'devacc': 72.79, 'acc': 72.28, 'ndev': 10000, 'ntest': 10000}, 'CoordinationInversion': {'devacc': 74.98, 'acc': 74.19, 'ndev': 10002, 'ntest': 10002}}


# Fine-tuned model (hidden state 12): accuracy per probing task

In [16]:
# Per-task "acc" field of the fine-tuned run, one key per probing task.
finetuned_acc = {f"{task}_acc": scores["acc"] for task, scores in results.items()}
print(json.dumps(finetuned_acc, indent=4))

{
    "Length_acc": 82.69,
    "WordContent_acc": 28.59,
    "Depth_acc": 36.04,
    "TopConstituents_acc": 69.59,
    "BigramShift_acc": 92.41,
    "Tense_acc": 85.26,
    "SubjNumber_acc": 81.41,
    "ObjNumber_acc": 80.09,
    "OddManOut_acc": 72.28,
    "CoordinationInversion_acc": 74.19
}


# Pretrained baseline (no fine-tuning)

In [20]:
# Baseline: the stock Hub checkpoint (path=None) with the same layer and probe settings.
params = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
    'classifier': {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, 'epoch_size': 4},
}

electra2 = Electra(starting_state=12, path=None)
se2 = senteval.engine.SE(params, electra2.batcher, electra2.prepare)

results2 = se2.eval(transfer_tasks)
print(results2)

{'Length': {'devacc': 88.48, 'acc': 88.16, 'ndev': 9996, 'ntest': 9996}, 'WordContent': {'devacc': 29.67, 'acc': 30.49, 'ndev': 10000, 'ntest': 10000}, 'Depth': {'devacc': 41.52, 'acc': 41.3, 'ndev': 10000, 'ntest': 10000}, 'TopConstituents': {'devacc': 78.03, 'acc': 77.56, 'ndev': 10000, 'ntest': 10000}, 'BigramShift': {'devacc': 95.91, 'acc': 95.65, 'ndev': 10000, 'ntest': 10000}, 'Tense': {'devacc': 89.69, 'acc': 88.04, 'ndev': 10000, 'ntest': 10000}, 'SubjNumber': {'devacc': 83.41, 'acc': 82.18, 'ndev': 10000, 'ntest': 10000}, 'ObjNumber': {'devacc': 81.42, 'acc': 81.43, 'ndev': 10000, 'ntest': 10000}, 'OddManOut': {'devacc': 76.23, 'acc': 75.37, 'ndev': 10000, 'ntest': 10000}, 'CoordinationInversion': {'devacc': 78.81, 'acc': 78.27, 'ndev': 10002, 'ntest': 10002}}


In [21]:
# Per-task "acc" field of the pretrained baseline, same key format as above.
pretrained_acc = {f"{task}_acc": scores["acc"] for task, scores in results2.items()}
print(json.dumps(pretrained_acc, indent=4))

{
    "Length_acc": 88.16,
    "WordContent_acc": 30.49,
    "Depth_acc": 41.3,
    "TopConstituents_acc": 77.56,
    "BigramShift_acc": 95.65,
    "Tense_acc": 88.04,
    "SubjNumber_acc": 82.18,
    "ObjNumber_acc": 81.43,
    "OddManOut_acc": 75.37,
    "CoordinationInversion_acc": 78.27
}
