In [16]:
import torch
from torch import nn, optim
from transformers import *
from torch.utils.data import Dataset, DataLoader
from time import time
import numpy as np
import torch
from torch import nn
import math
from pprint import pprint
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from IPython.display import clear_output

%matplotlib inline

In [17]:
position_encoding_size = 256

def position_embeddings(max_pos, size):
    embeddings = np.zeros((max_pos, size))
    w = 1 / (10000 ** (2*np.arange(size // 2 )/size))
    for pos in range(max_pos):
        embeddings[pos,0::2] = np.sin(w*pos)
        embeddings[pos,1::2] = np.cos(w*pos)
    return torch.Tensor(embeddings)
    
pos_embed = position_embeddings(10000, position_encoding_size)
pos_embed.shape

torch.Size([10000, 256])

In [18]:
class Classifier(nn.Module):
    def __init__(self, bert_size, position_size):
        super().__init__()

        self.attention = nn.Linear(bert_size + position_size, 1)
        self.softmax = nn.Softmax(1)

        self.prediction = nn.Sequential(
            nn.Linear(bert_size, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 6),
            nn.Sigmoid()
        )

    '''
    embeddings: shape (segment count, 512, bert_hidden_size)
    position_encodings:  shape (segment count, 512, position_encoding_size)
    comment_bounds: Array of tuples of the form [(start, end)]. comment_bounds[i] = (a, b) indicates that comment i's embeddings can be extracted as embeddings[a:b]
    '''
    def forward(self, embeddings, position_encodings, comment_bounds = None):
        attention_input = torch.cat([embeddings, position_encodings], dim=2) # (batch, 512, position_size + bert_hidden_size)

        # (batch, 512, 1)
        attentions = self.attention(attention_input)
        if comment_bounds is None:
            attentions = self.softmax(attentions) # (batch, 512, 1)
            vecs = torch.sum(attentions * embeddings, dim=1) # (batch, bert_hidden_size)
            return self.prediction(vecs) # (batch, 1)

        vecs = []
        for (a,b) in comment_bounds:
            comment_embeddings = embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            comment_attentions = attentions[a:b] # (segment_count, 512, 1)
            attention_weights = self.softmax(comment_attentions) # (segment_count, 512, 1)
            weighted_embeddings = attention_weights * embeddings[a:b] # (segment_count, 512, bert_hidden_size)
            vec = torch.sum(weighted_embeddings.view(-1, weighted_embeddings.shape[-1]), dim=0, keepdim=True) # (segment_count, bert_hidden_size)
            vecs.append(vec)
        return self.prediction(torch.cat(vecs))

In [19]:
class MyDataset(Dataset):
    def __init__(self, file_format, normalize):
        super().__init__()

        # Load the data from files
        input_ids = torch.load(file_format.format("input_ids"))
        positions = torch.load(file_format.format("positions"))
        comment_ids = torch.load(file_format.format("ids"))
        targets = torch.load(file_format.format("targets"))

        # Treat the targets as binary to separate the possible outputs
        target_ids = torch.sum(torch.Tensor([32, 16, 8, 4, 2, 1]) * targets, axis=1)

        # Store the data according to the target. Useful for normalization
        self.data = [[] for i in range(64)]

        # Load the data into the array
        curr_id = 0
        start_index = 0
        for i in range(comment_ids.shape[0]):
            if comment_ids[i] != curr_id:
                target_id = int(target_ids[curr_id].item())
                data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
                self.data[target_id].append(data)

                curr_id = comment_ids[i]
                start_index = i

        target_id = int(target_ids[curr_id].item())
        data = (input_ids[start_index:i], positions[start_index:i], targets[curr_id])
        self.data[target_id].append(data)

        n_nontoxic = len(self.data[0])

        n_of_each = n_nontoxic // (len(self.data)-1)

        # Remove the empty arrays from the data
        self.data = [data for data in self.data if data]

        n_copies = np.array([1]+[n_of_each // len(self.data[i]) for i in range(1,len(self.data))])
        
        if not normalize:
            n_copies = np.ones_like(n_copies)
        
        self.data_length = np.array([len(data) for data in self.data])

        segment_lengths = n_copies*self.data_length

        self.length = int(np.sum(segment_lengths))

        self.boundaries = np.zeros(segment_lengths.shape[0]+1)
        self.boundaries[1:] = np.cumsum(segment_lengths)


    def __len__(self):
        return self.length

    def __getitem__(self, index):
        for i in range(self.boundaries.shape[0] - 1):
            if index >= self.boundaries[i] and index < self.boundaries[i+1]:
                inner_index = int((index - self.boundaries[i]) % self.data_length[i])
                return self.data[i][inner_index]
        print("Index: ", index)

In [20]:
test_dataset = MyDataset("test_{}.pt", False)
len(test_dataset)

63978

In [21]:
def iterable(obj):
    try:
        iter(obj)
    except Exception:
        return False
    else:
        return True

def collate_samples(batch):
    for i, item in enumerate(batch):
        if not iterable(item):
            print(i, item)
    split_comments, positions, targets = zip(*batch)
    input_ids = []
    comment_bounds = []
    start = 0
    for comment in split_comments:
        input_ids += comment
        comment_bounds.append((start, start+len(comment)))
        start += len(comment)
    input_ids = torch.stack(input_ids, dim=0)
    encoded_positions = torch.cat([
                          # Use the position array as indices into the position embedding
                          pos_embed[position_arr]
                          # For each comment in the batch
                          for position_arr in positions                     
                      ])
  
    targets = [target.view(1, -1) for target in targets]
    targets = torch.cat(targets, dim=0)
    comment_bounds = np.array(comment_bounds, dtype=np.int32)
    return input_ids, encoded_positions, comment_bounds, targets

batch_size = 128

test = DataLoader(test_dataset, 
                   batch_size = batch_size,
                   shuffle=True,
                   collate_fn = collate_samples)

In [22]:
for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(test):
    print(input_ids.shape, encoded_position.shape, comment_bounds.shape, target.shape)
    break

torch.Size([129, 512]) torch.Size([129, 512, 256]) (128, 2) torch.Size([128, 6])


In [23]:
def calc_auc(pred, target):
    result = []
    for i in range(pred.shape[1]):
        if len(np.unique(target[:,i])) == 2:
            result.append(roc_auc_score(target[:,i], pred[:,i], labels=[0,1]))
        else:
            extra = np.array([1-target[0,i]])
            target_i = np.concatenate((target[:,i], extra))
            pred_i = np.concatenate((pred[:,i], extra))
            result.append(roc_auc_score(target_i, pred_i, labels=[0,1]))
            
    
    return result

In [117]:
model_path = "./model/model_{}.pt"
bert_path = "./bert/bert_{}.pt"

In [25]:
bert = BertModel.from_pretrained('bert-base-cased')

bert_hidden_size = 768

In [26]:
bert.eval()
device = torch.device("cuda:0")
bert = bert.to(device)

In [27]:
model = Classifier(bert_hidden_size, position_encoding_size)
model = model.to(device)

model.eval()

Classifier(
  (attention): Linear(in_features=1024, out_features=1, bias=True)
  (softmax): Softmax(dim=1)
  (prediction): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=1024, out_features=6, bias=True)
    (5): Sigmoid()
  )
)

In [118]:
load = "epoch_{:04d}_batch_{:04d}_bce_{:.04f}".format(44, 1200, 0.0141)
load_bert = True

if load is not None:
    model.load_state_dict(torch.load(model_path.format(load)))
    if load_bert:
        bert.load_state_dict(torch.load(bert_path.format(load)))

In [29]:
class EpochLogger:
    def __init__(self, path):
        self.file = open(path, "w")
    
    def log(self, string):
        print(string)
        self.file.write(string)
        self.file.write("\n")
    
    def close(self):
        self.file.close()

In [120]:
log_frequency = 1

clear_frequency = 100

log_path = "./logs/val_output_unfrozen.txt"

epoch_times = []
batch_times = []
train_losses = []
test_losses = []
logger = EpochLogger(log_path)
predicted = None
true = None
n_correct = 0
n_processed = 0

with torch.no_grad():
    print("Testing")
    predicted = None
    for i, (input_ids, encoded_position, comment_bounds, target) in enumerate(test):
        batch_start = time()
        # Get the BERT output
        encoded_comments = bert(input_ids.to(device))[0]

        # Get the outputs from the network
        output = model(encoded_comments, encoded_position.to(device), comment_bounds)

        #print(output[0], target[0])
        # Gradient descent
        pred = output.cpu().detach()
        if predicted is None:
            predicted = pred.clone()
            true = target.clone()
        else:
            predicted = torch.cat((predicted, pred), dim=0)
            true = torch.cat((true, target), dim=0)

        n_correct += torch.sum(torch.round(pred) == target, axis=0).numpy()
        n_processed += pred.shape[0]
        
        batch_times.append(time() - batch_start)

        if i % log_frequency == 0:
            logger.log("Batch {:05d}/{} ({:7.01f}s remaining)\t ".format(i, len(test), np.mean(batch_times)*(len(test) - i)))
            auc = calc_auc(predicted.numpy(), true.numpy())
            logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))
            acc = n_correct / n_processed
            logger.log("\tAccuracy: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))

        if i % clear_frequency == 0 and i != 0:
            clear_output()

    auc = calc_auc(predicted.numpy(), true.numpy())
    acc = n_correct / n_processed
    logger.log("Test Values:")

    logger.log("\tAUC: Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*auc))

    logger.log("\tAccuracy:  Toxic: {:.04f} Severe Toxic: {:.04f} Obscene: {:.04f} Threat: {:.04f} Insult: {:.04f} Identity-based Hate: {:.04f}".format(*acc))
logger.close()

Batch 00401/500 (  890.1s remaining)	 
	AUC: Toxic: 0.9463 Severe Toxic: 0.9686 Obscene: 0.9531 Threat: 0.9797 Insult: 0.9430 Identity-based Hate: 0.9641
	Accuracy: Toxic: 0.9258 Severe Toxic: 0.9798 Obscene: 0.9370 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9676
Batch 00402/500 (  881.0s remaining)	 
	AUC: Toxic: 0.9462 Severe Toxic: 0.9687 Obscene: 0.9531 Threat: 0.9798 Insult: 0.9429 Identity-based Hate: 0.9641
	Accuracy: Toxic: 0.9257 Severe Toxic: 0.9798 Obscene: 0.9370 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9675
Batch 00403/500 (  872.1s remaining)	 
	AUC: Toxic: 0.9461 Severe Toxic: 0.9687 Obscene: 0.9530 Threat: 0.9798 Insult: 0.9429 Identity-based Hate: 0.9637
	Accuracy: Toxic: 0.9258 Severe Toxic: 0.9798 Obscene: 0.9370 Threat: 0.9936 Insult: 0.9223 Identity-based Hate: 0.9675
Batch 00404/500 (  863.1s remaining)	 
	AUC: Toxic: 0.9461 Severe Toxic: 0.9677 Obscene: 0.9531 Threat: 0.9798 Insult: 0.9429 Identity-based Hate: 0.9636
	Accuracy: Toxic: 0.9257

Batch 00431/500 (  620.1s remaining)	 
	AUC: Toxic: 0.9463 Severe Toxic: 0.9690 Obscene: 0.9528 Threat: 0.9803 Insult: 0.9432 Identity-based Hate: 0.9635
	Accuracy: Toxic: 0.9254 Severe Toxic: 0.9797 Obscene: 0.9369 Threat: 0.9937 Insult: 0.9223 Identity-based Hate: 0.9673
Batch 00432/500 (  611.1s remaining)	 
	AUC: Toxic: 0.9463 Severe Toxic: 0.9690 Obscene: 0.9529 Threat: 0.9803 Insult: 0.9432 Identity-based Hate: 0.9635
	Accuracy: Toxic: 0.9254 Severe Toxic: 0.9797 Obscene: 0.9369 Threat: 0.9936 Insult: 0.9223 Identity-based Hate: 0.9673
Batch 00433/500 (  602.1s remaining)	 
	AUC: Toxic: 0.9464 Severe Toxic: 0.9690 Obscene: 0.9528 Threat: 0.9802 Insult: 0.9432 Identity-based Hate: 0.9635
	Accuracy: Toxic: 0.9255 Severe Toxic: 0.9797 Obscene: 0.9368 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9673
Batch 00434/500 (  593.1s remaining)	 
	AUC: Toxic: 0.9463 Severe Toxic: 0.9690 Obscene: 0.9528 Threat: 0.9802 Insult: 0.9432 Identity-based Hate: 0.9634
	Accuracy: Toxic: 0.9255

Batch 00461/500 (  350.4s remaining)	 
	AUC: Toxic: 0.9467 Severe Toxic: 0.9697 Obscene: 0.9524 Threat: 0.9755 Insult: 0.9436 Identity-based Hate: 0.9639
	Accuracy: Toxic: 0.9255 Severe Toxic: 0.9798 Obscene: 0.9366 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9672
Batch 00462/500 (  341.4s remaining)	 
	AUC: Toxic: 0.9467 Severe Toxic: 0.9697 Obscene: 0.9524 Threat: 0.9755 Insult: 0.9437 Identity-based Hate: 0.9640
	Accuracy: Toxic: 0.9255 Severe Toxic: 0.9798 Obscene: 0.9367 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9672
Batch 00463/500 (  332.4s remaining)	 
	AUC: Toxic: 0.9468 Severe Toxic: 0.9697 Obscene: 0.9523 Threat: 0.9755 Insult: 0.9437 Identity-based Hate: 0.9641
	Accuracy: Toxic: 0.9255 Severe Toxic: 0.9798 Obscene: 0.9367 Threat: 0.9936 Insult: 0.9222 Identity-based Hate: 0.9672
Batch 00464/500 (  323.4s remaining)	 
	AUC: Toxic: 0.9468 Severe Toxic: 0.9698 Obscene: 0.9523 Threat: 0.9756 Insult: 0.9437 Identity-based Hate: 0.9637
	Accuracy: Toxic: 0.9256

Batch 00491/500 (   80.9s remaining)	 
	AUC: Toxic: 0.9471 Severe Toxic: 0.9703 Obscene: 0.9523 Threat: 0.9764 Insult: 0.9438 Identity-based Hate: 0.9639
	Accuracy: Toxic: 0.9257 Severe Toxic: 0.9796 Obscene: 0.9368 Threat: 0.9937 Insult: 0.9221 Identity-based Hate: 0.9671
Batch 00492/500 (   71.9s remaining)	 
	AUC: Toxic: 0.9472 Severe Toxic: 0.9704 Obscene: 0.9523 Threat: 0.9763 Insult: 0.9438 Identity-based Hate: 0.9640
	Accuracy: Toxic: 0.9258 Severe Toxic: 0.9797 Obscene: 0.9369 Threat: 0.9937 Insult: 0.9221 Identity-based Hate: 0.9670
Batch 00493/500 (   62.9s remaining)	 
	AUC: Toxic: 0.9472 Severe Toxic: 0.9704 Obscene: 0.9524 Threat: 0.9764 Insult: 0.9439 Identity-based Hate: 0.9640
	Accuracy: Toxic: 0.9258 Severe Toxic: 0.9796 Obscene: 0.9369 Threat: 0.9937 Insult: 0.9221 Identity-based Hate: 0.9671
Batch 00494/500 (   53.9s remaining)	 
	AUC: Toxic: 0.9472 Severe Toxic: 0.9704 Obscene: 0.9525 Threat: 0.9764 Insult: 0.9439 Identity-based Hate: 0.9639
	Accuracy: Toxic: 0.9258

In [110]:
test_labels = pd.read_csv("./test_labels.csv")
test_labels

Unnamed: 0,id,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,00001cee341fdb12,-1,-1,-1,-1,-1,-1
1,0000247867823ef7,-1,-1,-1,-1,-1,-1
2,00013b17ad220c46,-1,-1,-1,-1,-1,-1
3,00017563c3f7919a,-1,-1,-1,-1,-1,-1
4,00017695ad8997eb,-1,-1,-1,-1,-1,-1
...,...,...,...,...,...,...,...
153159,fffcd0960ee309b5,-1,-1,-1,-1,-1,-1
153160,fffd7a9a6eb32c16,-1,-1,-1,-1,-1,-1
153161,fffda9e8d6fafa9e,-1,-1,-1,-1,-1,-1
153162,fffe8f1340a79fc2,-1,-1,-1,-1,-1,-1


In [111]:
score_indices = test_labels.toxic != -1
columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

In [112]:
masked_pred = np.zeros((len(test_labels), 6))
masked_pred[score_indices] = predicted.numpy()

In [113]:
test_labels[columns] = masked_pred

In [114]:
test_labels[score_indices]

Unnamed: 0,id,toxic,severe_toxic,obscene,threat,insult,identity_hate
5,0001ea8717f6de06,0.000009,0.000017,0.000079,6.557048e-08,0.000125,0.000001
7,000247e83dcc1211,0.032808,0.001187,0.012107,4.137385e-04,0.039888,0.050892
11,0002f87b16116a7f,0.947678,0.061841,0.965423,9.469392e-01,0.108079,0.023580
13,0003e1cccfd5a40a,0.996707,0.152949,0.773939,2.561683e-03,0.731999,0.999454
14,00059ace3e3e9a53,0.000248,0.000021,0.003412,5.981569e-04,0.002161,0.000008
...,...,...,...,...,...,...,...
153150,fff8f64043129fa2,0.000081,0.000007,0.002478,1.547969e-04,0.002450,0.000035
153151,fff9d70fe0722906,0.010101,0.000959,0.014651,2.412718e-06,0.092024,0.007202
153154,fffa8a11c4378854,0.000020,0.000019,0.000014,7.742472e-07,0.000042,0.000005
153155,fffac2a094c8e0e2,0.000020,0.000002,0.000737,1.342758e-05,0.000976,0.000012


In [115]:
test_labels.to_csv("results_frozen.csv", index=False)