In [16]:
import os
import pdb
print('working dir:', os.getcwd())

from tqdm.notebook import tqdm
import datasets
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import numpy as np
import evaluate
from sklearn import metrics

working dir: /work/fairness-privacy/notebooks


In [17]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classifier weights are read from the local checkpoint directory and then
# moved onto the selected device.
model_path = "../models/roberta-no-priv-epochs_1"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model = model.to(device)

# The tokenizer is loaded from the hub checkpoint named below, not from
# `model_path` (presumably the checkpoint the classifier was fine-tuned
# from -- confirm against the training notebook).
base_model = "FacebookAI/roberta-base"
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [18]:
MAXLEN = 128  # default `max_length` handed to the tokenizer


def tokenize(batch, tokenizer, maxlen=MAXLEN):
    """Tokenize the ``'text'`` entry of ``batch`` to a fixed length.

    Args:
        batch: Mapping with a ``'text'`` entry, which is passed to ``tokenizer``.
        tokenizer: Callable invoked with the text plus ``truncation=True``,
            ``padding="max_length"`` and ``max_length=maxlen``.
        maxlen: Value forwarded as ``max_length``; defaults to ``MAXLEN``.

    Returns:
        A new plain ``dict`` with the same keys and values as the tokenizer's
        output.
    """
    encoded = tokenizer(
        batch['text'],
        truncation=True,
        padding="max_length",
        max_length=maxlen,
    )
    # Copy into a plain dict so the caller gets an ordinary mapping.
    return {field: encoded[field] for field in encoded.keys()}

In [19]:
BATCH_SIZE = 64
data_path = "../twitteraae-sentiment-data-split/"

# Load the saved dataset, tokenize its validation split in batches across
# three worker processes, and have rows come back in torch format.
dataset = datasets.load_from_disk(data_path)
val_data_all = (
    dataset['validation']
    .map(tokenize, num_proc=3, batched=True, fn_kwargs={"tokenizer": tokenizer})
    .with_format("torch")
)
val_dataloader = DataLoader(
    val_data_all,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Per-dialect views of the same tokenized validation data: rows tagged 'AAE' ...
val_data_aae = val_data_all.filter(lambda p: p['dialect'] == 'AAE')
aae_dataloader = DataLoader(
    val_data_aae,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
print(f"AAE validation points: {len(aae_dataloader.dataset):,}")

# ... and rows tagged 'SAE'.
val_data_sae = val_data_all.filter(lambda p: p['dialect'] == 'SAE')
sae_dataloader = DataLoader(
    val_data_sae,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
print(f"SAE validation points: {len(sae_dataloader.dataset):,}")

Loading cached processed dataset at /work/fairness-privacy/twitteraae-sentiment-data-split/validation/cache-5733705bd941f10c_*_of_00003.arrow
  table = cls._concat_blocks(blocks, axis=0)
Loading cached processed dataset at /work/fairness-privacy/twitteraae-sentiment-data-split/validation/cache-b61fb64ecbaa2dbd.arrow
Loading cached processed dataset at /work/fairness-privacy/twitteraae-sentiment-data-split/validation/cache-684c0bb6a74dc08f.arrow


AAE validation points: 10,588
SAE validation points: 192,886


In [32]:
def evaluate_accuracy(model, dataloader, batch_size=BATCH_SIZE):
    """Compute the accuracy of ``model`` over every batch in ``dataloader``.

    Args:
        model: Model exposing ``eval()`` and called as
            ``model(input_ids=..., attention_mask=...)``; its output must
            have a ``.logits`` attribute.
        dataloader: Iterable of batches, each with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` entries.
        batch_size: Unused. Kept in the signature so existing calls that pass
            it keep working; batching is decided by ``dataloader``.

    Returns:
        Whatever the ``evaluate`` 'accuracy' metric's ``compute()`` returns
        (callers in this notebook index it with ``['accuracy']``).

    Note:
        The model is switched to eval mode and is left in eval mode.
    """
    metric = evaluate.load('accuracy')

    # Send inputs to the device the model's weights actually live on instead
    # of trusting the notebook-level `device`; the two can disagree when a
    # different model is passed in, which would raise a device-mismatch error.
    try:
        target_device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        # No parameters to inspect: fall back to the notebook-level device.
        target_device = device

    model.eval()  # switch to eval mode
    print('Evaluating...')
    for batch in tqdm(dataloader):
        batch_topass = {
            'input_ids': batch['input_ids'].to(target_device),
            'attention_mask': batch['attention_mask'].to(target_device)
        }
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**batch_topass)
        logits = outputs.logits
        # Predicted class is the index of the largest logit in each row.
        predictions = torch.argmax(logits, dim=1)
        metric.add_batch(predictions=predictions, references=batch['label'])

    return metric.compute()

In [27]:
# Score each dialect subset separately so the two accuracies can be compared.
aae_accuracy = evaluate_accuracy(model=model, dataloader=aae_dataloader)
sae_accuracy = evaluate_accuracy(model=model, dataloader=sae_dataloader)

Evaluating...


  0%|          | 0/166 [00:00<?, ?it/s]

Evaluating...


  0%|          | 0/3014 [00:00<?, ?it/s]

In [33]:
# Report both per-dialect accuracies on one line (3 significant digits each).
accuracy_report = f"Accuracy---SAE: {sae_accuracy['accuracy']:.3}; AAE: {aae_accuracy['accuracy']:.3}"
print(accuracy_report)

Accuracy---SAE: 0.865; AAE: 0.747


In [34]:
# Accuracy over the full tokenized validation split.
overall_accuracy = evaluate_accuracy(model=model, dataloader=val_dataloader)
overall_report = f"Overall accuracy: {overall_accuracy['accuracy']:.3}"
print(overall_report)

Evaluating...


  0%|          | 0/3180 [00:00<?, ?it/s]

Overall accuracy: 0.859
