In [1]:
import torch
from torch import nn, optim
from transformers import *
from torch.utils.data import Dataset, DataLoader
from time import time
import numpy as np
import torch
from torch import nn
import math
from pprint import pprint
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from IPython.display import clear_output

%matplotlib inline

In [2]:
position_encoding_size = 256

def position_embeddings(max_pos, size):
    embeddings = np.zeros((max_pos, size))
    w = 1 / (10000 ** (2*np.arange(size // 2 )/size))
    for pos in range(max_pos):
        embeddings[pos,0::2] = np.sin(w*pos)
        embeddings[pos,1::2] = np.cos(w*pos)
    return torch.Tensor(embeddings)
    
pos_embed = position_embeddings(10000, position_encoding_size)
pos_embed.shape

torch.Size([10000, 256])

In [3]:
bert = BertModel.from_pretrained('bert-base-cased')

bert_hidden_size = 768

In [4]:
class Classifier(nn.Module):
    def __init__(self, bert_size, position_size):
        super().__init__()

        self.attention = nn.Linear(bert_size + position_size, 1)
        self.softmax = nn.Softmax(1)

        self.prediction = nn.Sequential(
            nn.Linear(bert_size, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 6),
            nn.Sigmoid()
        )

    '''
    embeddings: shape (segment count, 512, bert_hidden_size)
    position_encodings:  shape (segment count, 512, position_encoding_size)
    comment_bounds: Array of tuples of the form [(start, end)]. comment_bounds[i] = (a, b) indicates that comment i's embeddings can be extracted as embeddings[a:b]
    '''
    def forward(self, embeddings, position_encodings, comment_bounds = None):
        attention_input = torch.cat([embeddings, position_encodings], dim=2) # (batch, 512, position_size + bert_hidden_size)

        # (batch, 512, 1)
        attentions = self.attention(attention_input)
        if comment_bounds is None:
            attentions = self.softmax(attentions) # (batch, 512, 1)
            vecs = torch.sum(attentions * embeddings, dim=1) # (batch, bert_hidden_size)
            return self.prediction(vecs) # (batch, 1)

        vecs = []
        for (a,b) in comment_bounds:
            comment_embeddings = embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            comment_attentions = attentions[a:b] # (segment_count, 512, 1)
            attention_weights = self.softmax(comment_attentions) # (segment_count, 512, 1)
            weighted_embeddings = attention_weights * embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            vec = torch.sum(weighted_embeddings.view(-1, weighted_embeddings.shape[-1]), dim=0, keepdim=True) # (segment_count, bert_hidden_size)
            vecs.append(vec)
        return self.prediction(torch.cat(vecs))

In [5]:
class MyDataset(Dataset):
    def __init__(self, file_format, normalize):
        super().__init__()

        # Load the data from files
        input_ids = torch.load(file_format.format("input_ids"))
        positions = torch.load(file_format.format("positions"))
        comment_ids = torch.load(file_format.format("ids"))
        targets = torch.load(file_format.format("targets"))

        # Treat the targets as binary to separate the possible outputs
        target_ids = torch.sum(torch.Tensor([32, 16, 8, 4, 2, 1]) * targets, axis=1)

        # Store the data according to the target. Useful for normalization
        self.data = [[] for i in range(64)]

        # Load the data into the array
        curr_id = 0
        start_index = 0
        for i in range(comment_ids.shape[0]):
            if comment_ids[i] != curr_id:
                target_id = int(target_ids[curr_id].item())
                data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
                self.data[target_id].append(data)

                curr_id = comment_ids[i]
                start_index = i

        target_id = int(target_ids[curr_id].item())
        data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
        self.data[target_id].append(data)

        n_nontoxic = len(self.data[0])

        n_of_each = n_nontoxic // (len(self.data)-1)

        # Remove the empty arrays from the data
        self.data = [data for data in self.data if data]

        n_copies = np.array([1]+[n_of_each // len(self.data[i]) for i in range(1,len(self.data))])
        
        if not normalize:
            n_copies = np.ones_like(n_copies)
        
        self.data_length = np.array([len(data) for data in self.data])

        segment_lengths = n_copies*self.data_length

        self.length = int(np.sum(segment_lengths))

        self.boundaries = np.zeros(segment_lengths.shape[0]+1)
        self.boundaries[1:] = np.cumsum(segment_lengths)


    def __len__(self):
        return self.length

    def __getitem__(self, index):
        for i in range(self.boundaries.shape[0] - 1):
            if index >= self.boundaries[i] and index < self.boundaries[i+1]:
                inner_index = int((index - self.boundaries[i]) % self.data_length[i])
                return self.data[i][inner_index]
        print("Index: ", index)

In [6]:
train_dataset = MyDataset("train_{}.pt", True)
test_dataset = MyDataset("test_{}.pt", False)
len(train_dataset), len(test_dataset)

(226493, 153164)

In [7]:
def iterable(obj):
    try:
        iter(obj)
    except Exception:
        return False
    else:
        return True

def collate_samples(batch):
    for i, item in enumerate(batch):
        if not iterable(item):
            print(i, item)
    split_comments, positions, targets = zip(*batch)
    input_ids = []
    comment_bounds = []
    start = 0
    for comment in split_comments:
        input_ids += comment
        comment_bounds.append((start, start+len(comment)))
        start += len(comment)
    input_ids = torch.stack(input_ids, dim=0)
    encoded_positions = torch.cat([
                          # Use the position array as indices into the position embedding
                          pos_embed[position_arr]
                          # For each comment in the batch
                          for position_arr in positions                     
                      ])
  
    targets = [target.view(1, -1) for target in targets]
    targets = torch.cat(targets, dim=0)
    comment_bounds = np.array(comment_bounds, dtype=np.int32)
    return input_ids, encoded_positions, comment_bounds, targets

batch_size = 72

train = DataLoader(train_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

test = DataLoader(test_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

In [8]:
for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
    print(input_ids.shape, encoded_position.shape, comment_bounds.shape, target.shape)
    break

torch.Size([74, 512]) torch.Size([74, 512, 256]) (72, 2) torch.Size([72, 6])


In [9]:
def init_weights(model, stdv):
    for param in model.parameters():
        if len(param.shape) >= 2:
            nn.init.kaiming_normal_(param.data)
        else:
            param.data.normal_(0.0, stdv)

In [10]:
model_path = "./model/model_{}.pt"
adam_path = "./adam/adam_{}.pt"
train_loss_path = "./train_loss.pt"
test_loss_path = "./train_loss.pt"

In [11]:
def calc_auc(pred, target):
    result = []
    for i in range(pred.shape[1]):
        if len(np.unique(target[:,i])) == 2:
            result.append(roc_auc_score(target[:,i], pred[:,i], labels=[0,1]))
        else:
            extra = np.array([1-target[0,i]])
            target_i = np.concatenate((target[:,i], extra))
            pred_i = np.concatenate((pred[:,i], extra))
            result.append(roc_auc_score(target_i, pred_i, labels=[0,1]))
            
    
    return result

In [12]:
def load_model(tag):
    model.load_state_dict(torch.load(model_path.format(tag)))
    optimizer.load_state_dict(torch.load(adam_path.format(tag)))
    
def save_model(tag):
    torch.save(model.state_dict(), model_path.format(tag))
    torch.save(optimizer.state_dict(), adam_path.format(tag))

In [13]:
bert.eval()
device = torch.device("cuda:0")
bert = bert.to(device)

In [14]:
first_epoch = 0
epochs = 30
learning_rate = 1e-4

model = Classifier(bert_hidden_size, position_encoding_size)
init_weights(model, 0.2)
model = model.to(device)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [15]:

load = None #"epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(5, 100, 0.3493)

if load is not None:
    load_model(load)

In [16]:
class EpochLogger:
    def __init__(self, path):
        self.file = open(path, "w")
    
    def log(self, string):
        print(string)
        self.file.write(string)
        self.file.write("\n")
    
    def close(self):
        self.file.close()

In [17]:
log_frequency = 1
save_frequency = 10

clear_frequency = 3

# Maximum batches per epoch
batches_per_train_epoch = 100
batches_per_test_epoch = 10


batches_per_train_epoch = min(batches_per_train_epoch, len(train))
batches_per_test_epoch = min(batches_per_test_epoch, len(test))

log_path = "./logs/epoch_{}.txt"

epoch_times = []
batch_times = []
train_losses = []
test_losses = []
for epoch in range(first_epoch, epochs):
    logger = EpochLogger(log_path.format(epoch))
    epoch_start = time()
    predicted = None
    true = None
    n_correct = 0
    n_processed = 0
    print("Training")
    for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
        batch_start = time()
        # Get the BERT output
        with torch.no_grad():
            encoded_comments = bert(input_ids.to(device))[0]
    
        # Get the outputs from the network
        output = model(encoded_comments, encoded_position.to(device), comment_bounds)

        #print(output.shape, target.shape)
        # Gradient descent
        optimizer.zero_grad()
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        batch_times.append(time() - batch_start)
        batch_times = batch_times[-100:]
        pred = torch.round(output.cpu().detach())
        if predicted is None:
            predicted = pred.clone()
            true = target.clone()
        else:
            predicted = torch.cat((predicted, pred), dim=0)
            true = torch.cat((true, target), dim=0)

        n_correct += torch.sum(pred == target, axis=0).numpy()
        n_processed += pred.shape[0]

        predicted = predicted[-1000:]
        true = true[-1000:]
    
        train_losses.append(loss.item())
    
        if i % log_frequency == 0:
            logger.log("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_train_epoch, np.mean(batch_times)*(batches_per_train_epoch - i), loss.item()))
            auc = calc_auc(predicted.numpy(), true.numpy())
            logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
            acc = n_correct / n_processed
            logger.log("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
        if i % save_frequency == 0:
            save_model("epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(epoch, i, loss.item()))
            torch.save(torch.Tensor(train_losses), train_loss_path)
        if i % batches_per_train_epoch == 0 and i != 0:
            break
            
    save_model("epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(epoch, i, loss.item()))
    torch.save(torch.Tensor(train_losses), train_loss_path)
        
    with torch.no_grad():
        print("Testing")
        predicted = None
        for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(test):
            batch_start = time()
            # Get the BERT output
            encoded_comments = bert(input_ids.to(device))[0]

            # Get the outputs from the network
            output = model(encoded_comments, encoded_position.to(device), comment_bounds)

            #print(output[0], target[0])
            # Gradient descent
            pred = torch.round(output.cpu().detach())
            if predicted is None:
                predicted = pred.clone()
                true = target.clone()
            else:
                predicted = torch.cat((predicted, pred), dim=0)
                true = torch.cat((true, target), dim=0)

            n_correct += torch.sum(pred == target, axis=0).numpy()
            n_processed += pred.shape[0]

            predicted = predicted[-1000:]
            true = true[-1000:]

            if i % log_frequency == 0:
                logger.log("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_test_epoch, np.mean(batch_times)*(batches_per_test_epoch - i), loss.item()))
                auc = calc_auc(predicted.numpy(), true.numpy())
                logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
                acc = n_correct / n_processed
                logger.log("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
            if i % batches_per_test_epoch == 0 and i != 0:
                break
                
        auc = calc_auc(predicted.numpy(), true.numpy())
        acc = n_correct / n_processed
        epoch_time = time() - epoch_start
        logger.log("Epoch {} took {:7.01f}s. Test Values:".format(epoch, epoch_time))
        
        logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
        
        logger.log("\tAccuracy:  Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
    logger.close()
    if epoch % clear_frequency == 0 and epoch != 0:
        clear_output()
    epoch_times.append(epoch_time)

Training
Epoch 28: 00000/100 (  532.5s remaining)	 BCE Loss: 0.0467
	AUC: Toxic: 0.9750 Severe Toxic: 1.0000 Obscene: 0.9911 Threat: 1.0000 Insult: 0.9911 Identity-based Hate: 0.9158
	Accuracy: Toxic: 0.9861 Severe Toxic: 1.0000 Obscene: 0.9861 Threat: 1.0000 Insult: 0.9861 Identity-based Hate: 0.9444
Epoch 28: 00001/100 (  527.4s remaining)	 BCE Loss: 0.3318
	AUC: Toxic: 0.9771 Severe Toxic: 0.9917 Obscene: 0.9958 Threat: 0.9917 Insult: 0.9654 Identity-based Hate: 0.9500
	Accuracy: Toxic: 0.9792 Severe Toxic: 0.9861 Obscene: 0.9931 Threat: 0.9861 Insult: 0.9653 Identity-based Hate: 0.9583
Epoch 28: 00002/100 (  523.1s remaining)	 BCE Loss: 0.0723
	AUC: Toxic: 0.9767 Severe Toxic: 0.9806 Obscene: 0.9944 Threat: 0.9916 Insult: 0.9509 Identity-based Hate: 0.9350
	Accuracy: Toxic: 0.9815 Severe Toxic: 0.9861 Obscene: 0.9907 Threat: 0.9861 Insult: 0.9583 Identity-based Hate: 0.9583
Epoch 28: 00003/100 (  518.2s remaining)	 BCE Loss: 0.0638
	AUC: Toxic: 0.9686 Severe Toxic: 0.9741 Obscene: 

Epoch 28: 00028/100 (  380.3s remaining)	 BCE Loss: 0.1311
	AUC: Toxic: 0.9564 Severe Toxic: 0.9484 Obscene: 0.9360 Threat: 0.9742 Insult: 0.9322 Identity-based Hate: 0.9710
	Accuracy: Toxic: 0.9708 Severe Toxic: 0.9784 Obscene: 0.9761 Threat: 0.9842 Insult: 0.9598 Identity-based Hate: 0.9737
Epoch 28: 00029/100 (  375.2s remaining)	 BCE Loss: 0.0415
	AUC: Toxic: 0.9593 Severe Toxic: 0.9562 Obscene: 0.9410 Threat: 0.9732 Insult: 0.9343 Identity-based Hate: 0.9726
	Accuracy: Toxic: 0.9718 Severe Toxic: 0.9787 Obscene: 0.9764 Threat: 0.9843 Insult: 0.9602 Identity-based Hate: 0.9745
Epoch 28: 00030/100 (  369.0s remaining)	 BCE Loss: 0.0958
	AUC: Toxic: 0.9604 Severe Toxic: 0.9510 Obscene: 0.9434 Threat: 0.9666 Insult: 0.9323 Identity-based Hate: 0.9714
	Accuracy: Toxic: 0.9722 Severe Toxic: 0.9776 Obscene: 0.9763 Threat: 0.9834 Insult: 0.9601 Identity-based Hate: 0.9745
Epoch 28: 00031/100 (  363.4s remaining)	 BCE Loss: 0.0759
	AUC: Toxic: 0.9597 Severe Toxic: 0.9480 Obscene: 0.9458 Th

Epoch 28: 00056/100 (  230.9s remaining)	 BCE Loss: 0.1069
	AUC: Toxic: 0.9621 Severe Toxic: 0.9671 Obscene: 0.9638 Threat: 0.9821 Insult: 0.9279 Identity-based Hate: 0.9540
	Accuracy: Toxic: 0.9730 Severe Toxic: 0.9776 Obscene: 0.9759 Threat: 0.9859 Insult: 0.9600 Identity-based Hate: 0.9754
Epoch 28: 00057/100 (  225.7s remaining)	 BCE Loss: 0.0741
	AUC: Toxic: 0.9626 Severe Toxic: 0.9648 Obscene: 0.9635 Threat: 0.9892 Insult: 0.9286 Identity-based Hate: 0.9462
	Accuracy: Toxic: 0.9729 Severe Toxic: 0.9773 Obscene: 0.9758 Threat: 0.9859 Insult: 0.9600 Identity-based Hate: 0.9751
Epoch 28: 00058/100 (  221.1s remaining)	 BCE Loss: 0.0932
	AUC: Toxic: 0.9614 Severe Toxic: 0.9632 Obscene: 0.9597 Threat: 0.9870 Insult: 0.9279 Identity-based Hate: 0.9433
	Accuracy: Toxic: 0.9729 Severe Toxic: 0.9772 Obscene: 0.9755 Threat: 0.9859 Insult: 0.9597 Identity-based Hate: 0.9746
Epoch 28: 00059/100 (  216.1s remaining)	 BCE Loss: 0.0747
	AUC: Toxic: 0.9628 Severe Toxic: 0.9609 Obscene: 0.9611 Th

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

