In [1]:
import pickle

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import get_scheduler

import format_yelp

In [2]:
# Experiment configuration shared by the cells below.
# Seed passed to the shuffle of the training split.
random_seed = 42
# Number of shuffled training examples kept as the labelled set.
labelled_size = 10000
# Number of examples taken from the end of the shuffled split for validation.
valid_size = 1000
# Batch size of the training and validation DataLoaders.
batch_size = 16
# Number of passes over the labelled set during fine-tuning.
num_epochs = 20
# Validation interval in optimizer steps (the training loop evaluates every 100 steps).
nstep_eval = 100

In [3]:
# Load the dataset by its hub id; the cells below use its 'train' and 'test' splits.
dataset = load_dataset("yelp_review_full")

In [4]:
# Total number of training examples. len() on the split reads its row count
# directly instead of materialising the whole 'label' column just to count it.
N = len(dataset['train'])

In [5]:
# Shuffle once with a fixed seed, then carve the shuffled split into three
# disjoint, contiguous index ranges:
#   [0, labelled_size)              -> labelled training set
#   [labelled_size, N - valid_size) -> unlabelled pool
#   [N - valid_size, N)             -> validation set
shuffled_dataset = dataset['train'].shuffle(seed=random_seed)
labelled_dataset, unlabelled_dataset, valid_dataset = (
    shuffled_dataset.select(range(start, stop))
    for start, stop in (
        (0, labelled_size),
        (labelled_size, N - valid_size),
        (N - valid_size, N),
    )
)

In [6]:
# Tokenizer and classifier are loaded from the same pretrained checkpoint;
# the classifier is built with a 5-way output head.
checkpoint_name = "prajjwal1/bert-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name, num_labels=5)
# Sequence length used for padding/truncation: the model's number of position embeddings.
max_len = model.config.max_position_embeddings

def tokenize_dataset(dataset, max_len):
    """Tokenize the ``"text"`` entries of ``dataset``.

    The texts are first passed through ``format_yelp.format_dataset`` (defined
    outside this notebook) and then encoded with the module-level
    ``tokenizer``. Every sequence is padded to, and truncated at, ``max_len``.

    Args:
        dataset: Mapping with a ``"text"`` entry; a batch of examples when
            used through ``Dataset.map(..., batched=True)``.
        max_len: Target sequence length for padding and truncation.

    Returns:
        The tokenizer's output for the formatted texts.
    """
    formatted_texts = format_yelp.format_dataset(dataset["text"])
    return tokenizer(
        formatted_texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
    )

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def prepare_dataset(dataset, max_len):
    """Turn a raw split into a tokenized dataset that yields torch tensors.

    Steps: tokenize in batches with ``tokenize_dataset``, drop the raw
    ``"text"`` column, rename ``"label"`` to ``"labels"``, and switch the
    output format to ``"torch"``.

    Args:
        dataset: Dataset with ``"text"`` and ``"label"`` columns.
        max_len: Sequence length forwarded to ``tokenize_dataset``.

    Returns:
        The tokenized dataset with its format set to ``"torch"``.
    """
    tokenized = dataset.map(
        tokenize_dataset,
        batched=True,
        fn_kwargs={"max_len": max_len},
    )
    prepared = tokenized.remove_columns(["text"]).rename_column("label", "labels")
    prepared.set_format("torch")
    return prepared

In [8]:
# Tokenize the labelled training subset and the validation subset.
tk_labelled_dataset = prepare_dataset(labelled_dataset, max_len)
tk_valid_dataset = prepare_dataset(valid_dataset, max_len)

In [9]:
# Training batches are shuffled; validation batches keep the dataset order.
train_dataloader = DataLoader(tk_labelled_dataset, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(tk_valid_dataset, batch_size=batch_size)

In [10]:
optimizer = AdamW(model.parameters(), lr=2e-5)

# Linear schedule over the whole run, with the first tenth of the steps as warm-up.
num_training_steps = num_epochs * len(train_dataloader)
num_warmup_steps = num_training_steps // 10
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, e

In [11]:
# Fine-tune the classifier on the labelled set. Every `nstep_eval` optimizer
# steps the model is scored on the validation set; each score is appended to
# `accs`, and the weights are checkpointed whenever the score improves.
progress_bar = tqdm(range(num_training_steps))

best_acc = 0
accs = []
step = 0  # optimizer steps since the last validation pass
metric = evaluate.load("accuracy")
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        step += 1
        # Use the configured interval rather than a hard-coded literal.
        if step == nstep_eval:
            model.eval()

            # Distinct names keep the validation pass from rebinding the
            # training loop's `batch` / `outputs`.
            for eval_batch in eval_dataloader:
                eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
                with torch.no_grad():
                    eval_outputs = model(**eval_batch)

                predictions = torch.argmax(eval_outputs.logits, dim=-1)
                metric.add_batch(predictions=predictions, references=eval_batch["labels"])

            acc = metric.compute()['accuracy']
            accs.append(acc)
            if acc > best_acc:
                # New best validation accuracy: overwrite the checkpoint on disk.
                best_acc = acc
                torch.save(model.state_dict(), './distill_checkpoint_small_from_scratch',_use_new_zipfile_serialization=False)
                print(best_acc)
            # Back to training mode and restart the interval counter.
            model.train()
            step = 0

  0%|          | 0/12500 [00:00<?, ?it/s]

0.22
0.335
0.382
0.417
0.465
0.513
0.536
0.543
0.553
0.556
0.581
0.588


In [12]:
# Tokenize the test split; it is read in batches 16x the training batch size
# (16 * batch_size examples per batch).
tk_test_dataset = prepare_dataset(dataset['test'], max_len = max_len)
test_dataloader = DataLoader(tk_test_dataset, batch_size=16*batch_size)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [13]:
# Score the best checkpoint on the test split, keeping per-batch predictions
# and logits for the analyses in the following cells.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# map_location lets a checkpoint written on a GPU be restored on whatever
# device this cell selected (e.g. a CPU-only machine).
model.load_state_dict(
    torch.load('./distill_checkpoint_small_from_scratch', map_location=device),
    strict=True,
)
model.to(device)
model.eval()
metric = evaluate.load("accuracy")
model_predictions = []
model_logits = []
for batch in tqdm(test_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        # Only the input tensors are passed; the labels are used just for the metric.
        outputs = model(
            input_ids=batch['input_ids'],
            token_type_ids=batch['token_type_ids'],
            attention_mask=batch['attention_mask'],
        )

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        model_predictions.append(predictions.cpu().numpy())
        model_logits.append(logits.cpu().numpy())
        metric.add_batch(predictions=predictions, references=batch["labels"])

acc = metric.compute()['accuracy']

  0%|          | 0/196 [00:00<?, ?it/s]

In [14]:
# Display the accuracy computed over the test split in the previous cell.
acc

0.59

In [15]:
def binary_accuracy(logits, labels):
    """Accuracy of the 5-class logits when read as a negative/positive classifier.

    Examples whose label is 2 are dropped. For the remaining examples, class 2
    is removed from the distribution, the other four classes are renormalised
    with a softmax, and the example is predicted positive when the combined
    probability of classes 3 and 4 exceeds 0.5. The true label is positive when
    it is greater than 2.

    The softmax is computed on max-shifted logits so that large logits cannot
    overflow ``np.exp`` and turn the probabilities into NaN.

    Args:
        logits: Array-like of shape (n_examples, 5) with one row of class
            logits per example.
        labels: Sequence of n_examples integer labels in 0..4.

    Returns:
        Fraction of the kept examples whose binary prediction matches the
        binary label (NaN when no example is kept).
    """
    label_array = np.asarray(labels)
    keep = label_array != 2
    # Boolean indexing returns a copy, so the caller's logits are not modified.
    kept_logits = np.asarray(logits, dtype=float)[keep]
    # A logit of -inf gives class 2 exactly zero probability.
    kept_logits[:, 2] = -np.inf
    # Shifting by the row maximum leaves the softmax unchanged but keeps exp() finite.
    kept_logits -= kept_logits.max(axis=1, keepdims=True)
    probs = np.exp(kept_logits)
    probs /= probs.sum(axis=1, keepdims=True)
    binary_pred = (probs[:, 3] + probs[:, 4]) > 0.5
    binary_labels = label_array[keep] > 2
    return np.mean(binary_pred == binary_labels)

In [16]:
# Negative-vs-positive accuracy on the test split; examples labelled 2 are excluded.
binary_accuracy(np.vstack(model_logits), dataset['test']['label'])

0.934825

In [17]:
# Collapse the five classes into three buckets, for predictions and labels
# alike: values above 2 become 4, values below 2 become 0, and 2 stays 2.
mp = np.hstack(model_predictions)
polar_mp = np.where(mp > 2, 4, np.where(mp < 2, 0, mp))
polar_label = np.array(dataset['test']['label'])
polar_label = np.where(polar_label > 2, 4, np.where(polar_label < 2, 0, polar_label))

In [18]:
# Fraction of test examples whose predicted bucket equals the label's bucket.
np.mean(polar_mp == polar_label)

0.77964

In [19]:
# Mean absolute difference between predicted and true class indices on the test split.
np.mean(np.abs(mp-dataset['test']['label']))

0.48936

In [20]:
# Persist the validation-accuracy history recorded during training.
# (`pickle` is imported with the other modules in the notebook's first cell.)
with open('bert_small_from_scratch_accs.pickle', 'wb') as f:
    pickle.dump(accs, f)