In [None]:
!pip install pytreebank
!pip install loguru
!pip install transformers



In [None]:
"""This module defines a configurable SSTDataset class."""

import pytreebank
import torch
from loguru import logger
from transformers import BertTokenizer
from torch.utils.data import Dataset
# Shared module-level resources used by the dataset code in this cell:
# release unused cached CUDA memory (helps when the cell is re-run in a live
# kernel), then load the BERT tokenizer and the SST treebank.
torch.cuda.empty_cache()

logger.info("Loading the tokenizer")
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-uncased",
)

logger.info("Loading SST")
sst = pytreebank.load_sst()


def rpad(array, n=70):
    """Right-pad with zeros, or truncate, so the result has exactly ``n`` items.

    Args:
        array: sequence of token ids (list or tuple); it is not modified.
        n: target length.

    Returns:
        A new list holding the first ``n`` items of ``array``, followed by
        ``0`` padding when ``array`` was shorter than ``n``.
    """
    # Truncate to n (not n - 1): the old off-by-one made long inputs one item
    # shorter than every other example, which broke batching downstream.
    padded = list(array[:n])
    padded.extend([0] * (n - len(padded)))
    return padded


def get_binary_label(label):
    """Collapse a fine-grained label into a binary one.

    Labels below 2 map to 0 and labels above 2 map to 1. The middle value 2
    (the neutral class of the 5-way scale) has no binary counterpart, so it
    raises ``ValueError``; the same happens for any value that is neither
    below nor above 2.
    """
    is_negative = label < 2
    is_positive = label > 2
    if not (is_negative or is_positive):
        raise ValueError("Invalid label")
    return 0 if is_negative else 1


class SSTDataset(Dataset):
    """Configurable SST Dataset.

    Things we can configure:
        - split (train / dev / test)
        - root / all nodes
        - binary / fine-grained
        - padded sequence length

    Each item is ``(token_ids, label)``: ``token_ids`` is a tensor built from
    the tokenizer output passed through ``rpad(..., n=max_len)``, and
    ``label`` is an int.

    NOTE(review): padding uses id 0 and no attention mask is returned, so a
    model fed only ``token_ids`` attends to padding positions. Consider also
    returning a mask (requires updating the training loop).
    """

    def __init__(self, split="train", root=True, binary=True, max_len=66):
        """Initializes the dataset with given configuration.

        Args:
            split: str
                Key into the loaded SST splits; the training code in this
                notebook uses "train", "dev" and "test".
            root: bool
                If true, only use root nodes (whole sentences). Else, use
                every labelled node of every tree.
            binary: bool
                If true, use binary labels and drop examples labelled 2.
                Else, keep the fine-grained labels.
            max_len: int
                Length each token-id list is padded / truncated to.
        """
        logger.info(f"Loading SST {split} set")
        self.sst = sst[split]

        logger.info("Tokenizing")
        examples = self._labeled_texts(root)
        if binary:
            # Label 2 has no binary counterpart; drop it before mapping.
            examples = (
                (get_binary_label(label), text)
                for label, text in examples
                if label != 2
            )
        self.data = [
            (self._encode(text, max_len), label) for label, text in examples
        ]

    def _labeled_texts(self, root):
        """Yield ``(label, text)`` for each tree's root, or for all its nodes."""
        for tree in self.sst:
            if root:
                yield tree.label, tree.to_lines()[0]
            else:
                yield from tree.to_labeled_lines()

    @staticmethod
    def _encode(text, max_len):
        """Tokenize ``text`` wrapped in [CLS]/[SEP] and pad/truncate it."""
        # Let the tokenizer add [CLS]/[SEP] itself. The previous version both
        # wrote "[CLS] ... [SEP]" into the text and called encode() with
        # add_special_tokens left at its default (True in transformers 2.x+),
        # yielding [CLS] [CLS] ... [SEP] [SEP].
        ids = tokenizer.encode(text, add_special_tokens=True)
        return rpad(ids, n=max_len)

    def __len__(self):
        """Number of (token_ids, label) examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(token_ids_tensor, label)`` for example ``index``."""
        X, y = self.data[index]
        X = torch.tensor(X)
        return X, y


2020-05-08 04:47:06.386 | INFO     | __main__:<module>:10 - Loading the tokenizer
2020-05-08 04:47:06.614 | INFO     | __main__:<module>:13 - Loading SST


In [None]:
import os

import torch
from loguru import logger
from transformers import BertConfig, BertForSequenceClassification
from tqdm import tqdm

# Force synchronous CUDA kernel launches so device errors are reported at the
# call that caused them. This is a debugging aid that slows training; drop it
# once runs are stable.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def train_one_epoch(model, lossfn, optimizer, dataset, batch_size=8, device=None):
    """Run one shuffled optimisation pass over ``dataset``.

    Args:
        model: module for which ``model(batch)[0]`` is the logits tensor.
        lossfn: criterion ``lossfn(logits, labels)`` returning a batch-mean loss.
        optimizer: optimizer over ``model``'s parameters.
        dataset: map-style dataset yielding ``(token_ids, label)`` pairs.
        batch_size: mini-batch size.
        device: device each batch is moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(loss, accuracy)`` averaged per sample over the whole dataset.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("train_one_epoch: dataset is empty")
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )
    model.train()
    total_loss, n_correct = 0.0, 0
    for batch, labels in tqdm(loader):
        batch, labels = batch.to(device), labels.to(device)
        optimizer.zero_grad()
        # Index instead of tuple-unpacking: transformers >= 4 returns a
        # dict-like ModelOutput, and unpacking that yields its keys.
        logits = model(batch)[0]
        loss = lossfn(logits, labels)
        loss.backward()
        optimizer.step()

        # lossfn gives a batch mean; weight it by batch size so the epoch
        # value is a true per-sample mean (the last batch may be smaller).
        total_loss += loss.item() * labels.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
    return total_loss / n_samples, n_correct / n_samples


def evaluate_one_epoch(model, lossfn, optimizer, dataset, batch_size=8, device=None):
    """Compute loss and accuracy of ``model`` on ``dataset`` without updating it.

    Args:
        model: module for which ``model(batch)[0]`` is the logits tensor.
        lossfn: criterion ``lossfn(logits, labels)`` returning a batch-mean loss.
        optimizer: unused; kept so the signature mirrors ``train_one_epoch``.
        dataset: map-style dataset yielding ``(token_ids, label)`` pairs.
        batch_size: mini-batch size.
        device: device each batch is moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(loss, accuracy)`` averaged per sample over the whole dataset.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("evaluate_one_epoch: dataset is empty")
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    # Example order does not affect the metrics, so there is no need to shuffle.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False
    )
    model.eval()
    total_loss, n_correct = 0.0, 0
    with torch.no_grad():
        for batch, labels in tqdm(loader):
            batch, labels = batch.to(device), labels.to(device)
            logits = model(batch)[0]
            # Weight the batch-mean loss by batch size for a per-sample mean.
            total_loss += lossfn(logits, labels).item() * labels.size(0)
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
    return total_loss / n_samples, n_correct / n_samples


def train(
    root=True,
    binary=False,
    bert="bert-base-uncased",
    epochs=30,
    batch_size=8,
    save=False,
):
    """Fine-tune a BERT sequence classifier on SST and record per-epoch metrics.

    Args:
        root: use root nodes only (True) or all nodes (False).
        binary: binary labels (True) or 5-way fine-grained labels (False).
        bert: pretrained model name passed to ``from_pretrained``.
        epochs: number of training epochs to run.
        batch_size: mini-batch size for training and evaluation.
        save: if True, pickle the whole model (``torch.save``) every 10 epochs.

    Returns:
        ``(train_losses, val_losses, test_losses, train_accuracies,
        val_accuracies, test_accuracies)``, each a list with one entry per epoch.
    """
    seq_len = 66  # padded length produced by SSTDataset
    trainset = SSTDataset("train", root=root, binary=binary)
    devset = SSTDataset("dev", root=root, binary=binary)
    testset = SSTDataset("test", root=root, binary=binary)

    # REMOVE BAD TRAINING DATA: drop examples that are not exactly seq_len
    # tokens long, so every batch can be stacked. Filter in a single pass.
    # Calling list.remove() while iterating the same list skips the element
    # after each removal, which is why the old loop had to be repeated.
    for split_set in (trainset, devset, testset):
        split_set.data = [ex for ex in split_set.data if len(ex[0]) == seq_len]

    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    config = BertConfig.from_pretrained(bert)
    if not binary:
        config.num_labels = 5
    model = BertForSequenceClassification.from_pretrained(bert, config=config)

    model = model.to(device)
    lossfn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    label = "binary" if binary else "fine"
    nodes = "root" if root else "all"

    # Epochs are numbered 1..epochs inclusive; range(1, epochs) ran one short.
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, lossfn, optimizer, trainset, batch_size=batch_size
        )
        val_loss, val_acc = evaluate_one_epoch(
            model, lossfn, optimizer, devset, batch_size=batch_size
        )
        test_loss, test_acc = evaluate_one_epoch(
            model, lossfn, optimizer, testset, batch_size=batch_size
        )
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)

        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

        logger.info(f"epoch={epoch}")
        logger.info(
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, test_loss={test_loss:.4f}"
        )
        logger.info(
            f"train_acc={train_acc:.3f}, val_acc={val_acc:.3f}, test_acc={test_acc:.3f}"
        )
        # With 1-based epochs, "every 10th epoch" is epoch % 10 == 0
        # (the old `== 9` test is the 0-based form and skipped the last epoch).
        if save and epoch % 10 == 0:
            torch.save(model, f"{bert}__{nodes}__{label}__e{epoch}.pickle")

    logger.success("Done!")
    return train_losses, val_losses, test_losses, train_accuracies, val_accuracies, test_accuracies


In [None]:
# Fine-grained (5-class) training on root nodes, checkpointing as we go.
bert_type = "bert-base-uncased"
(
    train_losses,
    val_losses,
    test_losses,
    train_accuracies,
    val_accuracies,
    test_accuracies,
) = train(root=True, binary=False, bert=bert_type, save=True)

2020-05-08 04:47:11.674 | INFO     | __main__:__init__:55 - Loading SST train set
2020-05-08 04:47:11.676 | INFO     | __main__:__init__:58 - Tokenizing
2020-05-08 04:47:16.101 | INFO     | __main__:__init__:55 - Loading SST dev set
2020-05-08 04:47:16.101 | INFO     | __main__:__init__:58 - Tokenizing
2020-05-08 04:47:16.681 | INFO     | __main__:__init__:55 - Loading SST test set
2020-05-08 04:47:16.681 | INFO     | __main__:__init__:58 - Tokenizing


HBox(children=(IntProgress(value=0, description='Downloading', max=433, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=440473133, style=ProgressStyle(description_…




100%|██████████| 1068/1068 [03:21<00:00,  5.30it/s]
100%|██████████| 138/138 [00:05<00:00, 24.44it/s]
100%|██████████| 277/277 [00:11<00:00, 24.62it/s]
2020-05-08 04:51:38.650 | INFO     | __main__:train:116 - epoch=1
2020-05-08 04:51:38.651 | INFO     | __main__:train:118 - train_loss=0.1712, val_loss=0.1438, test_loss=0.1375
2020-05-08 04:51:38.652 | INFO     | __main__:train:121 - train_acc=0.382, val_acc=0.486, test_acc=0.508
100%|██████████| 1068/1068 [03:20<00:00,  5.32it/s]
100%|██████████| 138/138 [00:05<00:00, 24.65it/s]
100%|██████████| 277/277 [00:11<00:00, 24.54it/s]
2020-05-08 04:55:16.197 | INFO     | __main__:train:116 - epoch=2
2020-05-08 04:55:16.198 | INFO     | __main__:train:118 - train_loss=0.1289, val_loss=0.1370, test_loss=0.1332
2020-05-08 04:55:16.199 | INFO     | __main__:train:121 - train_acc=0.546, val_acc=0.510, test_acc=0.528
100%|██████████| 1068/1068 [03:22<00:00,  5.26it/s]
100%|██████████| 138/138 [00:05<00:00, 24.98it/s]
100%|██████████| 277/277 [00:1

In [None]:
import pandas as pd

def save_model_results(train_losses, val_losses, test_losses, train_acc, val_acc, test_acc, bert_type):
    """Write each per-epoch metric history to its own single-column CSV.

    Files are named ``f"{bert_type}_{metric}.csv"`` for metric in
    train_loss, val_loss, test_loss, train_acc, val_acc, test_acc.

    Args:
        train_losses, val_losses, test_losses: per-epoch loss lists.
        train_acc, val_acc, test_acc: per-epoch accuracy lists.
        bert_type: filename prefix (e.g. the model name).
    """
    metrics = {
        "train_loss": train_losses,
        "val_loss": val_losses,
        "test_loss": test_losses,
        # Previously written as "<bert_type>train_acc.csv" (no underscore),
        # unlike the other five files.
        "train_acc": train_acc,
        "val_acc": val_acc,
        "test_acc": test_acc,
    }
    for metric, values in metrics.items():
        pd.DataFrame(values).to_csv(f"{bert_type}_{metric}.csv")

    print('finished')

# Persist the metric histories produced by the training run above.
save_model_results(
    train_losses,
    val_losses,
    test_losses,
    train_accuracies,
    val_accuracies,
    test_accuracies,
    bert_type,
)

finished
