# SST Transformers #

In [1]:
%matplotlib inline
import os
import pandas as pd
import torch
from transformers import AdamW, BertForSequenceClassification, BertTokenizer, WarmupLinearSchedule

import sst

### SST specific setup

In [2]:
SST_HOME = os.path.join('data', 'trees')

In [3]:
def text_from_tree(tree):
    sentence = ' '.join(tree.leaves())
    replacements = [
        (" 's", "'s"),
        (' .', '.'),
        (' ,', ','),
        ("`` ", "'"),
        (" ''", "'"),
        (" 'm", "'m"),
        (" 've", "'ve"),
        (" 't", "'t"),
        (" 're", "'re")
    ]
    
    for from_, to in replacements:
        sentence = sentence.replace(from_, to)

    return sentence

##### end SST specific setup

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

'Tesla K80'

In [5]:
# Build the train and dev datasets
def get_texts_and_labels(reader):
    """
    Args:
    - reader: sst.dev_reader, sst.train_reader, or sst.test_reader
    Returns a pair:
    * texts: a list of strings 
    * labels: a list of strings
    """
    d = sst.build_dataset(SST_HOME, reader, phi=lambda x: None, class_func=sst.ternary_class_func, vectorize=False)
    texts = [text_from_tree(tree) for tree in d['raw_examples']]
    labels = d['y']
    return texts, labels

def _bert_tokenize(text, max_length=128):
    tokenized = bert_tokenizer.encode_plus(
        text,
        max_length=max_length,
        add_special_tokens=True
    )
    
    token_ids = tokenized['input_ids']
    special_tokens_mask = tokenized['token_type_ids']
    token_type_ids = tokenized['token_type_ids']
    while len(token_ids) < max_length:
        token_ids.append(0)
    return token_ids


_y_mapper = {'negative': 0, 'neutral': 1, 'positive': 2}

def label_stoi(label_string):
    return _y_mapper[label_string]

def label_itos(label_int):
    for label_string, label_int_ in _y_mapper.items():
        if label_int == label_int_:
            return label_string
    assert False


def make_attention_masks(token_ids_list):
    attention_masks = []
    for token_ids in token_ids_list:
        mask = [float(token_id > 0) for token_id in token_ids]
        attention_masks.append(mask)
    
    return torch.tensor(attention_masks)
    

def build_data_loader(reader, batch_size):
    texts, labels = get_texts_and_labels(reader)
    
    labels_vector = torch.tensor([label_stoi(label) for label in labels])
    token_ids = torch.tensor([_bert_tokenize(text) for text in texts])
    attention_masks = make_attention_masks(token_ids)

    dataset = torch.utils.data.TensorDataset(token_ids, attention_masks, labels_vector)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)


In [35]:
AdamW?

In [36]:

def train_bert(model, train_loader, epochs=1):
    lr = 1e-3
    max_grad_norm = 1.0
    num_training_steps = 1000
    num_warmup_steps = 100
    warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 0.1

    # To reproduce BertAdam specific behavior set correct_bias=False
    optimizer = AdamW(bert_classifier.parameters()) 

    # PyTorch scheduler
#     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps)
    
    model.train()
    for epoch in range(epochs):
        print(f'Epoch {epoch} of {epochs}')
        
        epoch_loss = 0

        num_batches = len(train_loader)
        for batch_number, batch in enumerate(train_loader):
            batch_token_ids, batch_attention_mask, batch_labels = batch
            batch_token_ids = batch_token_ids.to(device)
            batch_attention_mask = batch_attention_mask.to(device)
            batch_labels = batch_labels.to(device)
            model.train()
            loss, logits = model(batch_token_ids, attention_mask=batch_attention_mask, labels=batch_labels)
            epoch_loss += loss
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
            optimizer.step()
#             scheduler.step()
            optimizer.zero_grad()
    
            if batch_number % 10 == 0:
                print(f'    Batch {batch_number} of {num_batches}. Loss: {loss}')

        print('Epoch loss:', epoch_loss)
        
transformer_model_name = 'bert-base-uncased'
bert_classifier = BertForSequenceClassification.from_pretrained(transformer_model_name, num_labels=3).to(device)
bert_tokenizer = BertTokenizer.from_pretrained(transformer_model_name)

In [7]:
BATCH_SIZE = 32

In [29]:
train_loader = build_data_loader(sst.train_reader, BATCH_SIZE)

In [37]:
%%time

train_bert(bert_classifier, train_loader)

Epoch 0 of 1


RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 11.17 GiB total capacity; 10.71 GiB already allocated; 8.81 MiB free; 139.52 MiB cached)

In [10]:
dev_loader = build_data_loader(sst.dev_reader, BATCH_SIZE)

In [23]:
def eval_bert(bert_classifier, dev_loader):
    bert_classifier.eval()
    total = correct = 0
    for batch in dev_loader:
        batch_token_ids, batch_attention_mask, batch_labels = batch
        batch_token_ids = batch_token_ids.to(device)
        batch_attention_mask = batch_attention_mask.to(device)
        batch_labels = batch_labels.to(device)
        with torch.no_grad():
            logits, = bert_classifier(batch_token_ids, attention_mask=batch_attention_mask)
        predictions = logits.argmax(axis=1)
        is_correct = batch_labels == predictions
        total += len(is_correct)
        correct += sum(is_correct)
        return batch_labels.cpu().numpy(), logits.cpu().numpy()
    
    print('{} of {} correct ({:.2f}%)'.format(correct, total, correct / total * 100))
    return correct, total

In [14]:
eval_bert(bert_classifier, dev_loader)

229 of 1101 correct (0.00%)


(tensor(229, device='cuda:0'), 1101)

In [25]:
labels, logits = eval_bert(bert_classifier, dev_loader)

In [26]:
labels

array([0, 2, 2, 1, 2, 1, 2, 1, 1, 0, 0, 1, 2, 2, 2, 0, 1, 0, 1, 2, 0, 2,
       2, 0, 0, 1, 0, 2, 0, 0, 0, 1])

In [27]:
logits

array([[ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.2562312 ],
       [ 0.30612442,  0.36633143, -0.256

## Problem: the outputs for every example are the same