In [1]:
import torch
from torch import nn, optim
from transformers import *
from torch.utils.data import Dataset, DataLoader
from time import time
import numpy as np
import torch
from torch import nn
import math
from pprint import pprint
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from IPython.display import clear_output

%matplotlib inline

Set up the position embedding matrix such that `pos_embed[i]` is the embedding for position `i`.

In [2]:
position_encoding_size = 256

def position_embeddings(max_pos, size):
    embeddings = np.zeros((max_pos, size))
    w = 1 / (10000 ** (2*np.arange(size // 2 )/size))
    for pos in range(max_pos):
        embeddings[pos,0::2] = np.sin(w*pos)
        embeddings[pos,1::2] = np.cos(w*pos)
    return torch.Tensor(embeddings)
    
pos_embed = position_embeddings(10000, position_encoding_size)
pos_embed.shape

torch.Size([10000, 256])

The model used on top of BERT

In [3]:
class Classifier(nn.Module):
    def __init__(self, bert_size, position_size):
        super().__init__()

        # Calculates the attention value
        self.attention = nn.Linear(bert_size + position_size, 1)
        self.softmax = nn.Softmax(1)

        # Makes the prediction
        self.prediction = nn.Sequential(
            nn.Linear(bert_size, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 6),
            nn.Sigmoid()
        )

    '''
    embeddings: shape (segment count, 512, bert_hidden_size)
        The output of BERT
    position_encodings:  shape (segment count, 512, position_encoding_size)
        The position encodings of the tokens
    comment_bounds: Array of tuples of the form [(start, end)]. 
        comment_bounds[i] = (a, b) indicates that comment i's embeddings can be extracted as embeddings[a:b]
        if not provided, all embeddings and positions are assumed to be a single comment
    '''
    def forward(self, embeddings, position_encodings, comment_bounds = None):
        # Concatenate each BERT output with its position encoding
        attention_input = torch.cat([embeddings, position_encodings], dim=2) # (batch, 512, position_size + bert_hidden_size)

        # Get the attention weights for each concatenated vector
        attentions = self.attention(attention_input) #  (batch size, 512, 1)
        
        # If no bounds are probided, assume the input is all one comment
        if comment_bounds is None:
            # Softmax over attentions
            attentions = self.softmax(attentions) # (batch, 512, 1)
            
            # Calculate the total embedding as a weighted sum of embeddings without the positional encodings
            vecs = torch.sum(attentions * embeddings, dim=1) # (batch, bert_hidden_size)
            return self.prediction(vecs) # (batch, 1)

        # Otherwise, get the outputs for each comment
        vecs = []
        for (a,b) in comment_bounds:
            # Get the embeddings and attentions for the current comment
            comment_embeddings = embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            comment_attentions = attentions[a:b] # (segment_count, 512, 1)
            
            # softmax over the attentions for the comment
            attention_weights = self.softmax(comment_attentions) # (segment_count, 512, 1)
            
            # Calculate the total embedding as a weighted sum over the embeddings of the comment
            weighted_embeddings = attention_weights * embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            vec = torch.sum(weighted_embeddings.view(-1, weighted_embeddings.shape[-1]), dim=0, keepdim=True) # (segment_count, bert_hidden_size)
            
            vecs.append(vec)
            
        # Stack the total embedding vectors and give them to the prediction network to calculate the output
        return self.prediction(torch.cat(vecs))

Dataset that does the data normalization

In [4]:
class MyDataset(Dataset):
    # file_format is a Python format string with a single variable to be inserted. It's used to get the paths of the files
    # normalize indicates whether to perform data normalization procedure
    def __init__(self, file_format, normalize):
        super().__init__()

        # Load the data from files
        input_ids = torch.load(file_format.format("input_ids"))
        positions = torch.load(file_format.format("positions"))
        comment_ids = torch.load(file_format.format("ids"))
        targets = torch.load(file_format.format("targets"))

        # Treat the targets as binary to separate the possible outputs
        target_ids = torch.sum(torch.Tensor([32, 16, 8, 4, 2, 1]) * targets, axis=1)

        # Store the data according to the target. Useful for normalization
        self.data = [[] for i in range(64)]

        # Load the data into the array
        start_index = 0
        end_index = 0
        n_comments = comment_ids.shape[0]
        # Group the items by which comment they're part of
        while start_index < n_comments:
            # Get the current comment id
            curr_id = comment_ids[start_index]
            
            # Find end_index such that input_ids[end_index-1] is the last segment of the comment
            while end_index < n_comments and comment_ids[end_index] == comment_ids[start_index]:
                end_index += 1
            
            # Get the number with a binary representation that's the same as the comment's true labels
            target_id = int(target_ids[curr_id].item())
            
            # Get the comment as a tuple containing (token ids, positions, true labels)
            data = (input_ids[start_index:end_index], positions[start_index:end_index], targets[curr_id])
            self.data[target_id].append(data)

            start_index = end_index

        # Remove the empty arrays from the data
        self.data = [data for data in self.data if data]
        
        # Calculate how many comments are nontoxic
        n_nontoxic = len(self.data[0])
        
        # Calculate how many comments there need to be with each combination of labels
        # The goal is for there to be as many toxic comments as nontoxic comments
        # Also, each combination of labels should be present an equal number of times
        n_of_each = n_nontoxic // (len(self.data)-1)
        
        # Calculate how many copies need to be made from each combination of labels
        n_copies = np.array([1]+[n_of_each // len(self.data[i]) for i in range(1,len(self.data))])
        
        # If normalization shouldn't be done, just replace n_copies with a bunch of ones so there is one copy of each comment
        if not normalize:
            n_copies = np.ones_like(n_copies)
        
        # The number of comments with each combination of labels
        self.data_length = np.array([len(data) for data in self.data])

        # The data is organized into segments each of which contains some number of copies of the comments with a specific combination of labels
        segment_lengths = n_copies*self.data_length

        # The total length of the normalized dataset
        self.length = int(np.sum(segment_lengths))

        # The indices bounding the segments
        self.boundaries = np.zeros(segment_lengths.shape[0]+1)
        self.boundaries[1:] = np.cumsum(segment_lengths)


    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Find the segment that the index is in
        for i in range(self.boundaries.shape[0] - 1):
            if index >= self.boundaries[i] and index < self.boundaries[i+1]:
                # index - self.boundaries[i] calculates the index into the segment
                # The segment is a bunch of copies of the same data, but it's inefficient to actually copy the data
                # Therefore, "% self.data.length[i]" is used to convert the index into the segment into the index into the data
                inner_index = int((index - self.boundaries[i]) % self.data_length[i])
                return self.data[i][inner_index]
        print("Index: ", index)

Load the data. 

The test data isn't normalized to give a realistic view of the model's performance

In [5]:
train_dataset = MyDataset("train_{}.pt", True)
test_dataset = MyDataset("test_{}.pt", False)
len(train_dataset), len(test_dataset)

(275845, 153164)

Put the data into DataLoaders

In [6]:
# Needed to format the data in a way that the model can use
# batch is a list of tuples (tokenized_comment, positions, true_labels)
def collate_samples(batch):
    # Split the tuples into three variables
    split_comments, positions, targets = zip(*batch)
    
    # An array of tuples (a,b) so that input_ids[a:b] are all part of a single comment
    comment_bounds = np.zeros(len(split_comments)+1, dtype=np.int32)
    comment_bounds[1:] = np.cumsum(list(map(len, split_comments)))
    comment_bounds = np.array([comment_bounds[:-1], comment_bounds[1:]], dtype=np.int32).transpose()
    
    # For parallelism, stack the inputs into single tensors
    input_ids = torch.cat(split_comments, dim=0)
    
    # Stack the position embeddings as well so they can be easily concatenated with the BERT output
    encoded_positions = torch.cat([
                          # Use the position array as indices into the position embedding
                          pos_embed[position_arr]
                          # For each comment in the batch
                          for position_arr in positions                     
                      ])
  
    # Stack the true labels
    targets = torch.stack(targets)
    return input_ids, encoded_positions, comment_bounds, targets

batch_size = 4

train = DataLoader(train_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

test = DataLoader(test_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

In [7]:
for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
    print(input_ids.shape, encoded_position.shape, comment_bounds.shape, target.shape)
    break

torch.Size([4, 512]) torch.Size([4, 512, 256]) (4, 2) torch.Size([4, 6])


Use Kaiming normal weight initialization when possibe and when not, just use normal initialization

In [8]:
def init_weights(model, stdv):
    for param in model.parameters():
        if len(param.shape) >= 2:
            nn.init.kaiming_normal_(param.data)
        else:
            param.data.normal_(0.0, stdv)

Function to calculate AUC given predicted and true labels

In [9]:
def calc_auc(pred, target):
    result = []
    for i in range(pred.shape[1]):
        # If the true labels only has one value in a column, add a fake item to make AUC a valid operation
        if len(np.unique(target[:,i])) == 2:
            result.append(roc_auc_score(target[:,i], pred[:,i], labels=[0,1]))
        else:
            extra = np.array([1-target[0,i]])
            target_i = np.concatenate((target[:,i], extra))
            pred_i = np.concatenate((pred[:,i], extra))
            result.append(roc_auc_score(target_i, pred_i, labels=[0,1]))
            
    
    return result

Functions to save and load the model. Because ADAM has its own internal state, it's saved as well

In [10]:
def load_model(tag):
    model.load_state_dict(torch.load(model_path.format(tag)))
    optimizer.load_state_dict(torch.load(adam_path.format(tag)))
    
def save_model(tag):
    torch.save(model.state_dict(), model_path.format(tag))
    torch.save(optimizer.state_dict(), adam_path.format(tag))
    
def load_bert(tag):
    load_model(tag)
    bert.load_state_dict(torch.load(bert_path.format(tag)))
    
def save_bert(tag):
    save_model(tag)
    torch.save(bert.state_dict(), bert_path.format(tag))

The format strings used to generate the model paths

In [11]:
model_path = "./model/model_{}.pt"
bert_path = "./bert/bert_{}.pt"
adam_path = "./adam/adam_{}.pt"
train_loss_path = "./train_loss_unfrozen.pt"

Get the pretrained BERT model

In [12]:
device = torch.device("cuda:0")

bert = BertModel.from_pretrained('bert-base-cased')

bert_learning_rate = 1e-5

bert_hidden_size = 768

# Put BERT on the GPU
bert = bert.to(device)

bert_optimizer = optim.Adam(bert.parameters(), lr=bert_learning_rate)

Set up the model

In [13]:
# The epoch to start at
first_epoch = 30

# The last epoch number
epochs = 60
learning_rate = 1e-4

model = Classifier(bert_hidden_size, position_encoding_size)
init_weights(model, 0.2)
model = model.to(device)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Used to load saved models

In [14]:
load = "epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(29, 100, 0.0629)

if load is not None:
    load_model(load)

The training prints too many lines, so this class is used to write them to a file after each epoch

In [15]:
class EpochLogger:
    def __init__(self, path):
        self.file = open(path, "w")
    
    def log(self, string):
        print(string)
        self.file.write(string)
        self.file.write("\n")
    
    def close(self):
        self.file.close()

The training loop

In [16]:
# How often to print
log_frequency = 1

# How many batches to run between saving the model
save_frequency = 10
bert_save_frequency = 500

# How many epochs to run before clearing the output
clear_frequency_batches = 300
clear_frequency_epochs = 1

# Maximum batches per epoch
batches_per_train_epoch = 1000
batches_per_test_epoch = 10

batches_per_train_epoch = min(batches_per_train_epoch, len(train))
batches_per_test_epoch = min(batches_per_test_epoch, len(test))

# The format string used to generate paths for the log files
log_path = "./logs/epoch_{}.txt"

# Store the times to help forecast how long training will take
epoch_times = []
batch_times = []

# Store the training loss after each batch
train_losses = []
for epoch in range(first_epoch, epochs):
    # initialize the logger object
    logger = EpochLogger(log_path.format(epoch))
    
    # Record the time at the start of the epoch
    epoch_start = time()
    
    # Predictions and true labels used to calculate the AUC
    predicted = None
    true = None
    
    # Recording how many were correct and the total number of predictions to calculate the accuracy
    n_correct = 0
    n_processed = 0
    print("Training")
    for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
        batch_start = time()
        # Get the BERT output
        encoded_comments = bert(input_ids.to(device))[0]
    
        # Get the outputs from the network
        output = model(encoded_comments, encoded_position.to(device), comment_bounds)

        # Gradient descent
        optimizer.zero_grad()
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        bert_optimizer.step()
        
        # Store how long this batch took to run
        batch_times.append(time() - batch_start)
        
        # Only use the last 100 batches to estimate the time remaining
        batch_times = batch_times[-100:]
        
        # Calculate the predicted labels by rounding to zero or 1
        pred = torch.round(output.cpu().detach())
        
        # Add the predicted and true labels to the arrays
        if predicted is None:
            predicted = pred.clone()
            true = target.clone()
        else:
            predicted = torch.cat((predicted, pred), dim=0)
            true = torch.cat((true, target), dim=0)

        # Calculate the number that were correct
        n_correct += torch.sum(pred == target, axis=0).numpy()
        n_processed += pred.shape[0]

        train_losses.append(loss.item())
    
        # Print debugging information
        if i % log_frequency == 0:
            logger.log("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_train_epoch, np.mean(batch_times)*(batches_per_train_epoch - i), loss.item()))
            auc = calc_auc(predicted.numpy(), true.numpy())
            logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
            acc = n_correct / n_processed
            logger.log("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
        # Save the model
        if i % save_frequency == 0:
            save_model("epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(epoch, i, loss.item()))
            torch.save(torch.Tensor(train_losses), train_loss_path)
        if i % bert_save_frequency == 0:
            save_bert("epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(epoch, i, loss.item()))
            torch.save(torch.Tensor(train_losses), train_loss_path)
        if i % clear_frequency_batches == 0 and i != 0:
            clear_output()
        # Break early 
        if i % batches_per_train_epoch == 0 and i != 0:
            break
    
    # Make sure that the model is saved at the end of the epoch
    save_model("epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(epoch, i, loss.item()))
    
    # Save the training losses
    torch.save(torch.Tensor(train_losses), train_loss_path)
    
    # Test the model
    with torch.no_grad():
        print("Testing")
        predicted = None
        for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(test):
            batch_start = time()
            # Get the BERT output
            encoded_comments = bert(input_ids.to(device))[0]

            # Get the outputs from the network
            output = model(encoded_comments, encoded_position.to(device), comment_bounds)

            #print(output[0], target[0])
            # Gradient descent
            pred = torch.round(output.cpu().detach())
            if predicted is None:
                predicted = pred.clone()
                true = target.clone()
            else:
                predicted = torch.cat((predicted, pred), dim=0)
                true = torch.cat((true, target), dim=0)

            n_correct += torch.sum(pred == target, axis=0).numpy()
            n_processed += pred.shape[0]

            predicted = predicted[-1000:]
            true = true[-1000:]

            if i % log_frequency == 0:
                logger.log("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_test_epoch, np.mean(batch_times)*(batches_per_test_epoch - i), loss.item()))
                auc = calc_auc(predicted.numpy(), true.numpy())
                logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
                acc = n_correct / n_processed
                logger.log("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
            if i % batches_per_test_epoch == 0 and i != 0:
                break
                
        auc = calc_auc(predicted.numpy(), true.numpy())
        acc = n_correct / n_processed
        epoch_time = time() - epoch_start
        logger.log("Epoch {} took {:7.01f}s. Test Values:".format(epoch, epoch_time))
        
        logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
        
        logger.log("\tAccuracy:  Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
    logger.close()
    if epoch % clear_frequency_epochs == 0 and epoch != 0:
        clear_output()
    epoch_times.append(epoch_time)

Training
Epoch 30: 00000/1000 (  965.1s remaining)	 BCE Loss: 0.1778
	AUC: Toxic: 0.5000 Severe Toxic: 1.0000 Obscene: 1.0000 Threat: 1.0000 Insult: 1.0000 Identity-based Hate: 1.0000
	Accuracy: Toxic: 0.7500 Severe Toxic: 1.0000 Obscene: 1.0000 Threat: 1.0000 Insult: 1.0000 Identity-based Hate: 1.0000
Epoch 30: 00001/1000 (  846.0s remaining)	 BCE Loss: 0.0001
	AUC: Toxic: 0.5000 Severe Toxic: 1.0000 Obscene: 1.0000 Threat: 1.0000 Insult: 1.0000 Identity-based Hate: 1.0000
	Accuracy: Toxic: 0.8750 Severe Toxic: 1.0000 Obscene: 1.0000 Threat: 1.0000 Insult: 1.0000 Identity-based Hate: 1.0000
Epoch 30: 00002/1000 (  805.7s remaining)	 BCE Loss: 0.8617
	AUC: Toxic: 0.5000 Severe Toxic: 0.5000 Obscene: 0.8333 Threat: 1.0000 Insult: 0.7500 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.8333 Severe Toxic: 0.9167 Obscene: 0.9167 Threat: 1.0000 Insult: 0.9167 Identity-based Hate: 0.9167
Epoch 30: 00003/1000 (  783.5s remaining)	 BCE Loss: 0.1222
	AUC: Toxic: 0.6282 Severe Toxic: 0.7500 Obsce

RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 11.17 GiB total capacity; 10.60 GiB already allocated; 27.44 MiB free; 10.86 GiB reserved in total by PyTorch)