In [None]:
!pip install transformers==4.6.1

In [None]:
!curl https://s3.amazonaws.com/realworldnlpbook/data/stanfordSentimentTreebank/trees/dev.txt --output dev.txt

In [None]:
!curl https://s3.amazonaws.com/realworldnlpbook/data/stanfordSentimentTreebank/trees/train.txt --output train.txt

In [None]:
import re

import torch
from torch import nn, optim
from transformers import AutoTokenizer, AutoModel, AdamW, get_cosine_schedule_with_warmup

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Seed PyTorch's RNGs before the model is built so the classifier-head
# initialisation, dropout and the training run are repeatable across
# Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

BERT_MODEL = 'bert-base-cased'

In [None]:
# Load the tokenizer that matches the pretrained checkpoint named in BERT_MODEL.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BERT_MODEL)

In [None]:
class BertClassifier(nn.Module):
    """Sequence classifier: a pretrained encoder plus a linear head on its pooled output.

    Args:
        model_name: name or path passed to ``AutoModel.from_pretrained``.
        num_labels: number of output classes (output size of the linear head).
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(model_name)

        # Map the encoder's pooled representation to one logit per class.
        self.linear = nn.Linear(self.bert_model.config.hidden_size, num_labels)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
        """Compute class logits and, when ``label`` is given, the cross-entropy loss.

        ``attention_mask`` and ``token_type_ids`` are optional so that batches
        from tokenizers that do not emit them (or callers that omit them) still
        work; ``None`` is forwarded to the encoder unchanged.

        Args:
            input_ids: token id tensor fed to the encoder.
            attention_mask: optional attention mask for the encoder.
            token_type_ids: optional segment ids for the encoder.
            label: optional class-index targets; enables the loss computation.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` if ``label`` is ``None``.
        """
        bert_out = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)

        logits = self.linear(bert_out.pooler_output)

        loss = None
        if label is not None:
            loss = self.loss_function(logits, label)

        return loss, logits

In [None]:
# Convert a single sentence into a list of token ids.
token_ids = tokenizer.encode('The best movie ever!')

In [None]:
# Display the raw ids.
token_ids

In [None]:
# Map the ids back to text to see how the tokenizer split the sentence
# (including any special tokens encode() added).
tokenizer.decode(token_ids)

In [None]:
# Tokenize a batch of two sentences into PyTorch tensors.
# padding='max_length' replaces the deprecated pad_to_max_length=True: each
# sentence is padded (or truncated) to exactly max_length tokens.
result = tokenizer(
    ['The best movie ever!', 'Aweful movie'],
    max_length=10,
    padding='max_length',
    truncation=True,
    return_tensors='pt')

In [None]:
# Full tokenizer output for the two-sentence batch.
result

In [None]:
# Token ids, padded/truncated to max_length=10 per sentence.
result['input_ids']

In [None]:
# Segment (token type) ids; presumably all zeros here since each input is a single sentence.
result['token_type_ids']

In [None]:
# Attention mask; expected to be 1 for real tokens and 0 for padding positions.
result['attention_mask']

In [None]:
def read_dataset(file_path, batch_size, tokenizer, max_length):
    """Read an SST tree file and tokenize it into batches.

    Each non-blank line is a bracketed tree such as ``(3 (2 It) (4 works))``.
    The label is the digit right after the first ``(``. The sentence is
    recovered by:

    * removing every ``(N `` prefix;
    * removing closing parentheses;
    * restoring ``-LRB-``/``-RRB-`` to literal parentheses.

    Args:
        file_path: path to the tree file.
        batch_size: number of sentences per batch (the last batch may be smaller).
        tokenizer: Hugging Face tokenizer (any callable accepting the same keywords).
        max_length: every sequence is padded/truncated to this many tokens.

    Returns:
        A list of tokenizer outputs, each with an extra ``'label'`` tensor
        holding the integer labels of its sentences.
    """
    def _encode(texts, labels):
        # Tokenize one batch of sentences and attach their labels.
        # padding='max_length' replaces the deprecated pad_to_max_length=True.
        batch = tokenizer(
            texts,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt')
        batch['label'] = torch.tensor(labels)
        return batch

    batches = []
    texts = []
    labels = []
    with open(file_path) as f:
        for line in f:
            tree = line.strip()
            if not tree:
                # Skip blank lines (e.g. a trailing newline) instead of crashing on tree[1].
                continue

            label = int(tree[1])
            # Raw strings: '\(' and '\)' are invalid escapes in ordinary literals.
            text = re.sub(r'\)+', '', re.sub(r'\(\d ', '', tree))
            text = text.replace('-LRB-', '(').replace('-RRB-', ')')

            texts.append(text)
            labels.append(label)

            if len(texts) == batch_size:
                batches.append(_encode(texts, labels))
                texts = []
                labels = []

    # Flush the final, possibly smaller, batch.
    if texts:
        batches.append(_encode(texts, labels))

    return batches

In [None]:
# Batching/tokenization settings shared by both splits.
BATCH_SIZE = 32
MAX_LENGTH = 128  # tokens per sentence after padding/truncation

train_data = read_dataset('train.txt', batch_size=BATCH_SIZE, tokenizer=tokenizer, max_length=MAX_LENGTH)
dev_data = read_dataset('dev.txt', batch_size=BATCH_SIZE, tokenizer=tokenizer, max_length=MAX_LENGTH)

In [None]:
# Number of batches in each split (read_dataset returns a list of batches).
len(train_data), len(dev_data)

In [None]:
def move_to(batch, device):
    """Move every tensor in ``batch`` to ``device``.

    The mapping is updated in place (each value is replaced by its
    ``.to(device)`` copy); nothing is returned.
    """
    for name, tensor in list(batch.items()):
        batch[name] = tensor.to(device)

In [None]:
# Five output classes, one per sentiment label digit read by read_dataset
# (assumes labels are 0-4 — TODO confirm against the data).
model = BertClassifier(model_name=BERT_MODEL, num_labels=5)
model = model.to(device)

In [None]:
# Sanity-check one dev batch end-to-end. no_grad avoids building an autograd
# graph for a forward pass that is never backpropagated; the result is bound
# to a name so it is still displayed as the cell's last expression.
move_to(dev_data[0], device)
with torch.no_grad():
    sanity_output = model(**dev_data[0])
sanity_output

In [None]:
epochs = 30
learning_rate = 1e-5
warmup_steps = 1000

# torch.optim.AdamW replaces transformers.AdamW (deprecated in later
# transformers releases). eps and weight_decay are set explicitly because
# torch's defaults differ from transformers.AdamW's (1e-6 and 0.0).
# NOTE(review): updates should closely, not bit-for-bit, match the old optimizer.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

# The scheduler is stepped once per training batch, so the schedule spans
# (batches per epoch) * epochs steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps,
    num_training_steps=len(train_data) * epochs)

In [None]:
def run_epoch(model, batches, device, optimizer=None, scheduler=None):
    """Run one pass over ``batches`` and return ``(mean_batch_loss, accuracy)``.

    If ``optimizer`` is given, this is a training pass:

    * the model is put in ``train()`` mode;
    * gradients are enabled;
    * the optimizer (and ``scheduler``, if given) is stepped once per batch.

    Otherwise it is an evaluation pass: ``eval()`` mode, gradients off.

    Losses are accumulated as Python floats via ``.item()``. As a result, no
    loss tensor (or its autograd graph) is kept alive for the whole epoch, and
    the reported losses are plain numbers rather than tensor reprs.

    Returns ``(nan, nan)`` when ``batches`` is empty.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_sum = 0.0
    total_instances = 0
    correct_instances = 0
    with torch.set_grad_enabled(is_train):
        for batch in batches:
            # NOTE: move_to replaces the stored batch's tensors in place.
            move_to(batch, device)

            if is_train:
                optimizer.zero_grad()

            loss, logits = model(**batch)

            if is_train:
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            loss_sum += loss.item()
            total_instances += batch['input_ids'].size(0)
            correct_instances += (torch.argmax(logits, dim=-1) == batch['label']).sum().item()

    if total_instances == 0:
        # Nothing to average over; avoid ZeroDivisionError.
        return float('nan'), float('nan')
    # Mean of per-batch losses, matching the original per-epoch average.
    return loss_sum / len(batches), correct_instances / total_instances


for epoch in range(epochs):
    print(f'epoch = {epoch}')

    train_loss, train_accuracy = run_epoch(model, train_data, device, optimizer, scheduler)
    print(f'train loss = {train_loss}, accuracy = {train_accuracy}')

    dev_loss, dev_accuracy = run_epoch(model, dev_data, device)
    print(f'dev loss = {dev_loss}, accuracy = {dev_accuracy}')