In [1]:
import torch
from torch import nn, optim
from transformers import *
from torch.utils.data import Dataset, DataLoader
from time import time
import numpy as np
import torch
from torch import nn
import math
from pprint import pprint
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

%matplotlib inline

In [2]:
position_encoding_size = 256

def position_embeddings(max_pos, size):
    embeddings = np.zeros((max_pos, size))
    w = 1 / (10000 ** (2*np.arange(size // 2 )/size))
    for pos in range(max_pos):
        embeddings[pos,0::2] = np.sin(w*pos)
        embeddings[pos,1::2] = np.cos(w*pos)
    return torch.Tensor(embeddings)
    
pos_embed = position_embeddings(10000, position_encoding_size)
pos_embed.shape

torch.Size([10000, 256])

In [3]:
bert = BertModel.from_pretrained('bert-base-cased')

bert_hidden_size = 768

In [4]:
class Classifier(nn.Module):
    def __init__(self, bert_size, position_size):
        super().__init__()

        self.attention = nn.Linear(bert_size + position_size, 1)
        self.softmax = nn.Softmax(1)

        self.prediction = nn.Sequential(
            nn.Linear(bert_size, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 6),
            nn.Sigmoid()
        )

    '''
    embeddings: shape (segment count, 512, bert_hidden_size)
    position_encodings:  shape (segment count, 512, position_encoding_size)
    comment_bounds: Array of tuples of the form [(start, end)]. comment_bounds[i] = (a, b) indicates that comment i's embeddings can be extracted as embeddings[a:b]
    '''
    def forward(self, embeddings, position_encodings, comment_bounds = None):
        attention_input = torch.cat([embeddings, position_encodings], dim=2) # (batch, 512, position_size + bert_hidden_size)

        # (batch, 512, 1)
        attentions = self.attention(attention_input)
        if comment_bounds is None:
            attentions = self.softmax(attentions) # (batch, 512, 1)
            vecs = torch.sum(attentions * embeddings, dim=1) # (batch, bert_hidden_size)
            return self.prediction(vecs) # (batch, 1)

        vecs = []
        for (a,b) in comment_bounds:
            comment_embeddings = embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            comment_attentions = attentions[a:b] # (segment_count, 512, 1)
            attention_weights = self.softmax(comment_attentions) # (segment_count, 512, 1)
            weighted_embeddings = attention_weights * embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            vec = torch.sum(weighted_embeddings.view(-1, weighted_embeddings.shape[-1]), dim=0, keepdim=True) # (segment_count, bert_hidden_size)
            vecs.append(vec)
        return self.prediction(torch.cat(vecs))

In [5]:
class MyDataset(Dataset):
    def __init__(self, file_format, normalize):
        super().__init__()

        # Load the data from files
        input_ids = torch.load(file_format.format("input_ids"))
        positions = torch.load(file_format.format("positions"))
        comment_ids = torch.load(file_format.format("ids"))
        targets = torch.load(file_format.format("targets"))

        # Treat the targets as binary to separate the possible outputs
        target_ids = torch.sum(torch.Tensor([32, 16, 8, 4, 2, 1]) * targets, axis=1)

        # Store the data according to the target. Useful for normalization
        self.data = [[] for i in range(64)]

        # Load the data into the array
        curr_id = 0
        start_index = 0
        for i in range(comment_ids.shape[0]):
            if comment_ids[i] != curr_id:
                target_id = int(target_ids[curr_id].item())
                data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
                self.data[target_id].append(data)

                curr_id = comment_ids[i]
                start_index = i

        target_id = int(target_ids[curr_id].item())
        data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
        self.data[target_id].append(data)

        n_nontoxic = len(self.data[0])

        n_of_each = n_nontoxic // (len(self.data)-1)

        # Remove the empty arrays from the data
        self.data = [data for data in self.data if data]

        n_copies = np.array([1]+[n_of_each // len(self.data[i]) for i in range(1,len(self.data))])
        
        if not normalize:
            n_copies = np.ones_like(n_copies)
        
        self.data_length = np.array([len(data) for data in self.data])

        segment_lengths = n_copies*self.data_length

        self.length = int(np.sum(segment_lengths))

        self.boundaries = np.zeros(segment_lengths.shape[0]+1)
        self.boundaries[1:] = np.cumsum(segment_lengths)


    def __len__(self):
        return self.length

    def __getitem__(self, index):
        for i in range(self.boundaries.shape[0] - 1):
            if index >= self.boundaries[i] and index < self.boundaries[i+1]:
                inner_index = int((index - self.boundaries[i]) % self.data_length[i])
                return self.data[i][inner_index]
        print("Index: ", index)

In [6]:
train_dataset = MyDataset("train_{}.pt", True)
test_dataset = MyDataset("test_{}.pt", False)
len(train_dataset), len(test_dataset)

(226493, 153164)

In [7]:
def iterable(obj):
    try:
        iter(obj)
    except Exception:
        return False
    else:
        return True

def collate_samples(batch):
    for i, item in enumerate(batch):
        if not iterable(item):
            print(i, item)
    split_comments, positions, targets = zip(*batch)
    input_ids = []
    comment_bounds = []
    start = 0
    for comment in split_comments:
        input_ids += comment
        comment_bounds.append((start, start+len(comment)))
        start += len(comment)
    input_ids = torch.stack(input_ids, dim=0)
    encoded_positions = torch.cat([
                          # Use the position array as indices into the position embedding
                          pos_embed[position_arr]
                          # For each comment in the batch
                          for position_arr in positions                     
                      ])
  
    targets = [target.view(1, -1) for target in targets]
    targets = torch.cat(targets, dim=0)
    comment_bounds = np.array(comment_bounds, dtype=np.int32)
    return input_ids, encoded_positions, comment_bounds, targets

batch_size = 72

train = DataLoader(train_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

test = DataLoader(test_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

In [8]:
for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
    print(input_ids.shape, encoded_position.shape, comment_bounds.shape, target.shape)
    break

torch.Size([74, 512]) torch.Size([74, 512, 256]) (72, 2) torch.Size([72, 6])


In [9]:
bert.eval()
device = torch.device("cuda:0")
bert = bert.to(device)

In [10]:
epochs = 10
learning_rate = 3e-5

model = Classifier(bert_hidden_size, position_encoding_size)
model = model.to(device)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [35]:
def init_weights(model, stdv):
    for param in model.parameters():
        if len(param.shape) >= 2:
            nn.init.kaiming_normal_(param.data)
        else:
            param.data.normal_(0.0, stdv)

init_weights(model, 0.02)

In [12]:
model_path = "./model/model_{}.pt"
adam_path = "./adam/adam_{}.pt"
train_loss_path = "./train_loss.pt"
test_loss_path = "./train_loss.pt"

In [13]:
load = None

def load_model(tag):
    model.load_state_dict(torch.load(model_path.format(tag)))
    optimizer.load_state_dict(torch.load(adam_path.format(tag)))
    
def save_model(tag):
    torch.save(model.state_dict(), model_path.format(tag))
    torch.save(optimizer.state_dict(), adam_path.format(tag))

if load is not None:
    load_model(load)

In [31]:
def calc_auc(pred, target):
    result = []
    for i in range(pred.shape[1]):
        if len(np.unique(target[:,i])) == 2:
            result.append(roc_auc_score(target[:,i], pred[:,i], labels=[0,1]))
        else:
            extra = np.array([1-target[0,i]])
            target_i = np.concatenate((target[:,i], extra))
            pred_i = np.concatenate((pred[:,i], extra))
            result.append(roc_auc_score(target_i, pred_i, labels=[0,1]))
            
    
    return result

In [None]:
log_frequency = 1
save_frequency = 10

# Maximum batches per epoch
batches_per_train_epoch = 100
batches_per_test_epoch = 10


batches_per_train_epoch = min(batches_per_train_epoch, len(train))
batches_per_test_epoch = min(batches_per_test_epoch, len(test))

epoch_times = []
batch_times = []
train_losses = []
test_losses = []
for epoch in range(epochs):
    epoch_start = time()
    predicted = None
    true = None
    n_correct = 0
    n_processed = 0
    print("Training")
    for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(train):
        batch_start = time()
        # Get the BERT output
        with torch.no_grad():
            encoded_comments = bert(input_ids.to(device))[0]
    
        # Get the outputs from the network
        output = model(encoded_comments, encoded_position.to(device), comment_bounds)

        #print(output.shape, target.shape)
        # Gradient descent
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        batch_times.append(time() - batch_start)
        batch_times = batch_times[-100:]
        pred = torch.round(output.cpu().detach())
        if predicted is None:
            predicted = pred.clone()
            true = target.clone()
        else:
            predicted = torch.cat((predicted, pred), dim=0)
            true = torch.cat((true, target), dim=0)

        n_correct += torch.sum(pred == target, axis=0).numpy()
        n_processed += pred.shape[0]

        predicted = predicted[-1000:]
        true = true[-1000:]
    
        test_losses.append(loss.item())
    
        if i % log_frequency == 0:
            print("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_train_epoch, np.mean(batch_times)*(batches_per_train_epoch - i), loss.item()))
            auc = calc_auc(predicted.numpy(), true.numpy())
            print("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
            acc = n_correct / n_processed
            print("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
        if i % save_frequency == 0:
            save_model("epoch{}_batch{}_bce{:.04f}".format(epoch, i, loss.item()))
            torch.save(torch.Tensor(test_losses), train_loss_path)
        if i % batches_per_train_epoch == 0 and i != 0:
            break
        
    with torch.no_grad():
        print("Testing")
        predicted = None
        for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(test):
            batch_start = time()
            # Get the BERT output
            encoded_comments = bert(input_ids.to(device))[0]

            # Get the outputs from the network
            output = model(encoded_comments, encoded_position.to(device), comment_bounds)

            #print(output[0], target[0])
            # Gradient descent
            pred = torch.round(output.cpu().detach())
            if predicted is None:
                predicted = pred.clone()
                true = target.clone()
            else:
                predicted = torch.cat((predicted, pred), dim=0)
                true = torch.cat((true, target), dim=0)

            n_correct += torch.sum(pred == target, axis=0).numpy()
            n_processed += pred.shape[0]

            predicted = predicted[-1000:]
            true = true[-1000:]

            if i % log_frequency == 0:
                print("Epoch {}: {:05d}/{} ({:7.01f}s remaining)\t BCE Loss: {:.04f}".format(epoch, i, batches_per_test_epoch, np.mean(batch_times)*(batches_per_test_epoch - i), loss.item()))
                auc = calc_auc(predicted.numpy(), true.numpy())
                print("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
                acc = n_correct / n_processed
                print("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
            if i % batches_per_test_epoch == 0:
                break
        print("Epoch {} Test Values:".format(epoch))
        auc = calc_auc(predicted.numpy(), true.numpy())
        print("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
        acc = n_correct / n_processed
        print("\tAccuracy:  Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
            
    epoch_times.append(time() - epoch)

Training
Epoch 0: 00000/100 (  484.3s remaining)	 BCE Loss: 0.6299
	AUC: Toxic: 0.7037 Severe Toxic: 0.4750 Obscene: 0.5000 Threat: 0.4844 Insult: 0.4754 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.8056 Severe Toxic: 0.6806 Obscene: 0.8333 Threat: 0.8611 Insult: 0.2083 Identity-based Hate: 0.8889
Epoch 0: 00001/100 (  464.6s remaining)	 BCE Loss: 0.6322
	AUC: Toxic: 0.6797 Severe Toxic: 0.4643 Obscene: 0.5000 Threat: 0.4877 Insult: 0.4829 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.8056 Severe Toxic: 0.7292 Obscene: 0.8333 Threat: 0.8264 Insult: 0.2153 Identity-based Hate: 0.8542
Epoch 0: 00002/100 (  479.1s remaining)	 BCE Loss: 0.6446
	AUC: Toxic: 0.6750 Severe Toxic: 0.5132 Obscene: 0.5000 Threat: 0.4890 Insult: 0.4920 Identity-based Hate: 0.5303
	Accuracy: Toxic: 0.8056 Severe Toxic: 0.7593 Obscene: 0.8102 Threat: 0.8194 Insult: 0.2222 Identity-based Hate: 0.8565
Epoch 0: 00003/100 (  480.9s remaining)	 BCE Loss: 0.6191
	AUC: Toxic: 0.6510 Severe Toxic: 0.5126 Obscene: 0.49

Epoch 0: 00028/100 (  390.5s remaining)	 BCE Loss: 0.5511
	AUC: Toxic: 0.5296 Severe Toxic: 0.5069 Obscene: 0.4983 Threat: 0.4975 Insult: 0.5025 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.7476 Severe Toxic: 0.8372 Obscene: 0.8070 Threat: 0.8209 Insult: 0.6844 Identity-based Hate: 0.8290
Epoch 0: 00029/100 (  384.6s remaining)	 BCE Loss: 0.5082
	AUC: Toxic: 0.5295 Severe Toxic: 0.5071 Obscene: 0.4990 Threat: 0.4982 Insult: 0.5000 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.7491 Severe Toxic: 0.8389 Obscene: 0.8083 Threat: 0.8222 Insult: 0.6894 Identity-based Hate: 0.8296
Epoch 0: 00030/100 (  386.3s remaining)	 BCE Loss: 0.6144
	AUC: Toxic: 0.5364 Severe Toxic: 0.5168 Obscene: 0.5021 Threat: 0.4988 Insult: 0.5000 Identity-based Hate: 0.5000
	Accuracy: Toxic: 0.7487 Severe Toxic: 0.8387 Obscene: 0.8091 Threat: 0.8226 Insult: 0.6931 Identity-based Hate: 0.8306
Epoch 0: 00031/100 (  380.6s remaining)	 BCE Loss: 0.6499
	AUC: Toxic: 0.5423 Severe Toxic: 0.5230 Obscene: 0.5070 Threat

In [None]:
import os
print(os.listdir("./"))