In [2]:
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer
import format_yelp
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
import evaluate
from torch.optim import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm

In [3]:
# Notebook-wide settings. Of these, only `batch_size` is read by the cells
# visible in this notebook (it sizes the evaluation DataLoader).
# NOTE(review): the remaining values look like training-time settings kept in
# sync with a companion training notebook -- confirm before pruning them.
random_seed = 42
labelled_size = 10000
valid_size = 1000
batch_size = 8  # examples per batch fed to the model during evaluation
num_epochs = 20
nstep_eval = 100
neval_early_stop = 100

In [4]:
# Load every split of the Yelp full-review dataset from the Hugging Face hub.
dataset = load_dataset("yelp_review_full")
# Number of training examples. Taking len() of the split itself avoids
# materializing the whole 'label' column just to count its rows.
N = len(dataset['train'])

In [5]:
# Tokenizer for the same "bert-base-cased" checkpoint the classifier below is
# built from; it is used as a module-level global by tokenize_dataset().
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [6]:
def tokenize_dataset(dataset):
    """Tokenize the ``"text"`` field of a batch of reviews.

    Intended as the callback for ``Dataset.map(..., batched=True)``.

    Args:
        dataset: Mapping with a ``"text"`` entry holding the raw reviews
            (the batch passed in by ``map``).

    Returns:
        The output of the global ``tokenizer``, with every sequence padded to
        the tokenizer's maximum length and longer sequences truncated.
    """
    # NOTE(review): format_yelp.format_dataset is a project helper whose exact
    # text transformation is not visible from this notebook.
    formatted_reviews = format_yelp.format_dataset(dataset["text"])
    return tokenizer(
        formatted_reviews,
        padding="max_length",
        truncation=True,
    )

In [7]:
def prepare_dataset(dataset):
    """Turn a raw dataset split into model-ready torch tensors.

    The split is tokenized in batches with ``tokenize_dataset``, the raw
    ``"text"`` column is dropped, ``"label"`` is renamed to ``"labels"`` and
    the output format is switched to ``"torch"``.

    Args:
        dataset: A dataset split exposing ``map``, ``remove_columns``,
            ``rename_column`` and ``set_format`` with ``"text"`` and
            ``"label"`` columns.

    Returns:
        The tokenized, torch-formatted split.
    """
    tokenized = dataset.map(tokenize_dataset, batched=True)
    prepared = tokenized.remove_columns(["text"]).rename_column("label", "labels")
    # set_format works in place, so it is kept as a separate statement.
    prepared.set_format("torch")
    return prepared

In [8]:
# Tokenize the held-out test split once, then serve it in fixed-size batches.
tk_test_dataset = prepare_dataset(dataset['test'])
# No shuffling is requested, so batches follow the dataset order; the later
# cells rely on this when comparing predictions with dataset['test']['label'].
test_dataloader = DataLoader(
    tk_test_dataset,
    batch_size=batch_size,
)

In [9]:
# Build a BERT sequence classifier with a 5-way output head (one class per
# label value 0-4 used below). The freshly initialized head reported in the
# warning is overwritten when the checkpoint is loaded in the next cell.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Evaluate the weights stored in ./distill_checkpoint on the test split,
# keeping per-batch predictions and raw logits for the metrics computed in
# the following cells.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# map_location remaps tensors that were saved from a GPU run onto the device
# chosen above, so the CPU fallback also works on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('./distill_checkpoint', map_location=device), strict = True)
model.to(device)
model.eval()  # inference behaviour: turns off training-only layers such as dropout

metric = evaluate.load("accuracy")
model_predictions = []  # one array of predicted class ids per batch
model_logits = []       # one array of raw logits per batch
for batch in tqdm(test_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():  # no gradients are needed for evaluation
        # Labels are deliberately not passed: only the logits are wanted.
        outputs = model(input_ids = batch['input_ids'], token_type_ids = batch['token_type_ids'], attention_mask = batch['attention_mask'])

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        # Move results to host memory so they can be stacked with numpy later.
        model_predictions.append(predictions.cpu().numpy())
        model_logits.append(logits.cpu().numpy())
        metric.add_batch(predictions=predictions, references=batch["labels"])

# Exact-match accuracy over the 5 classes.
acc = metric.compute()['accuracy']

  0%|          | 0/6250 [00:00<?, ?it/s]

In [11]:
# 5-class exact-match accuracy computed by the evaluation loop above.
acc

0.60012

In [12]:
def binary_accuracy(logits, labels, neutral_label=2):
    """Accuracy of the positive/negative decision implied by class logits.

    Examples whose true label equals ``neutral_label`` are ignored. For the
    rest, the neutral class is removed from the distribution and the example
    is predicted positive when the classes above ``neutral_label`` together
    outweigh the classes below it (i.e. their renormalized probability mass
    exceeds 0.5).

    The comparison is done in log space with ``np.logaddexp`` instead of
    exponentiating the logits, so very large logits cannot overflow to
    ``inf``/``nan`` and silently flip a prediction.

    Args:
        logits: Array-like of shape ``(n_examples, n_classes)`` with raw
            (unnormalized) class scores.
        labels: Array-like of ``n_examples`` integer class labels.
        neutral_label: Class index treated as neutral and excluded; classes
            with a larger index count as positive, smaller ones as negative.

    Returns:
        Fraction of non-neutral examples whose binary prediction matches the
        binary label. ``nan`` (with a NumPy warning) when no non-neutral
        example is present.
    """
    logit_array = np.asarray(logits, dtype=float)
    label_array = np.asarray(labels)

    keep = label_array != neutral_label
    kept_logits = logit_array[keep]

    # log(sum(exp(.))) over each side; equivalent to comparing the summed
    # probabilities after dropping the neutral class and renormalizing.
    negative_score = np.logaddexp.reduce(kept_logits[:, :neutral_label], axis=1)
    positive_score = np.logaddexp.reduce(kept_logits[:, neutral_label + 1:], axis=1)

    binary_pred = positive_score > negative_score
    binary_labels = label_array[keep] > neutral_label
    return np.mean(binary_pred == binary_labels)

In [13]:
# Stack the per-batch logits row-wise and score the positive/negative decision
# against the test labels (neutral examples are skipped by binary_accuracy).
binary_accuracy(np.concatenate(model_logits, axis=0), dataset['test']['label'])

0.949375

In [14]:
def _to_polar(stars):
    """Collapse class ids to three buckets: <2 -> 0, exactly 2 -> 2, >2 -> 4."""
    stars = np.asarray(stars)
    return np.where(stars > 2, 4, np.where(stars < 2, 0, stars))


# Flatten the per-batch predictions into one vector of predicted class ids.
mp = np.hstack(model_predictions)
# Three-way (negative / neutral / positive) view of predictions and labels.
polar_mp = _to_polar(mp)
polar_label = _to_polar(dataset['test']['label'])

In [15]:
# Accuracy of the three-way negative / neutral / positive prediction.
(polar_mp == polar_label).mean()

0.79732

In [16]:
# Mean absolute distance between predicted and true class ids.
np.abs(mp - np.asarray(dataset['test']['label'])).mean()

0.4501