In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import torch
from torch.nn import BCEWithLogitsLoss, BCELoss
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report, confusion_matrix, multilabel_confusion_matrix, f1_score, accuracy_score
import pickle
from transformers import *
from tqdm import tqdm, trange

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
if torch.cuda.is_available():
    # get_device_name(0) raises when no CUDA device exists, so only query it
    # (and release cached GPU memory) when a GPU is actually present.
    print('Using GPU: ', torch.cuda.get_device_name(0))
    torch.cuda.empty_cache()

# Load and Preprocess Training Data

The dataset will be tokenized and then split into training and validation sets. The validation set will be used to monitor training. For testing, a separate test set will be loaded for analysis.

In [3]:
# Read the labelled training comments and preview the first three rows
train_csv_path = 'data/toxic_comments/train.csv'
train_df = pd.read_csv(train_csv_path)
train_df.head(n=3)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0


In [6]:
# Shuffle the rows. The fixed seed (same value as the train/validation split
# further down) makes the row order, and everything derived from it, repeatable.
df = train_df.sample(frac=1, random_state=2020).reset_index(drop=True)

In [7]:
# The first two columns are `id` and `comment_text`; every column after them is a label
cols = df.columns
label_cols = [col_name for col_name in cols[2:]]
num_labels = len(label_cols)
print('Label columns: ', label_cols)

Label columns:  ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


In [8]:
# Store each row's label vector in a single column, then pull the label
# vectors and the raw comment strings out as plain Python lists
df['one_hot_labels'] = list(df[label_cols].values)
labels = df['one_hot_labels'].tolist()
comments = df['comment_text'].tolist()

Load the pretrained tokenizer that corresponds to your choice of model, e.g.:

```
BERT:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 

XLNet:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=False) 

RoBERTa:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=False)
```


In [9]:
# Sequence length (in tokens) that every comment is padded to
max_length = 100

# Tokenizer belonging to the uncased BERT checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Encode all comments in one call
encodings = tokenizer.batch_encode_plus(
    comments,
    max_length=max_length,
    pad_to_max_length=True,
)
print('tokenizer outputs: ', encodings.keys())

tokenizer outputs:  dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


In [10]:
# Pull the three parallel per-comment sequences out of the encoding:
# token ids, token type ids and attention masks
input_ids, token_type_ids, attention_masks = (
    encodings[key] for key in ('input_ids', 'token_type_ids', 'attention_mask')
)

In [11]:
# Locate rows whose label combination occurs exactly once. A stratified split
# cannot place a single-member class on both sides, so these rows are handled
# separately later.
label_strs = df.one_hot_labels.astype(str)
label_counts = label_strs.value_counts()
one_freq = label_counts[label_counts==1].keys()
is_singleton = label_strs.isin(one_freq)
# Highest index first, so removing rows one by one does not shift the
# positions of the rows still to be removed
one_freq_idxs = sorted(list(df[is_singleton].index), reverse=True)
print('df label indices with only one instance: ', one_freq_idxs)

df label indices with only one instance:  [28818, 18252]


In [12]:
def _pop_singletons(rows):
    """Remove the rows at `one_freq_idxs` from `rows` in place and return them in that order."""
    return [rows.pop(idx) for idx in one_freq_idxs]


# Take the single-occurrence rows out of every parallel list; they are forced
# into the training set after the stratified split.
# NOTE: this mutates the lists, so re-running the cell removes further rows.
one_freq_input_ids = _pop_singletons(input_ids)
one_freq_token_types = _pop_singletons(token_type_ids)
one_freq_attention_masks = _pop_singletons(attention_masks)
one_freq_labels = _pop_singletons(labels)

In [13]:
# Stratified 90/10 split into training and validation data. The four parallel
# lists are passed together so their rows stay aligned.
(train_inputs, validation_inputs,
 train_labels, validation_labels,
 train_token_types, validation_token_types,
 train_masks, validation_masks) = train_test_split(
    input_ids, labels, token_type_ids, attention_masks,
    random_state=2020, test_size=0.10, stratify=labels)

# Put the single-occurrence rows into the training set
train_inputs += one_freq_input_ids
train_labels += one_freq_labels
train_masks += one_freq_attention_masks
train_token_types += one_freq_token_types

# Convert everything to torch tensors, the datatype the model consumes
train_inputs, train_labels, train_masks, train_token_types = (
    torch.tensor(rows)
    for rows in (train_inputs, train_labels, train_masks, train_token_types)
)

validation_inputs, validation_labels, validation_masks, validation_token_types = (
    torch.tensor(rows)
    for rows in (validation_inputs, validation_labels, validation_masks, validation_token_types)
)

In [14]:
# Mini-batch size used by the DataLoaders
batch_size = 16

# The tensors are already fully in memory; the DataLoader only serves them in
# mini-batches. Each item is (input ids, attention mask, labels, token type ids).

# Training batches are drawn in random order
train_data = TensorDataset(train_inputs, train_masks, train_labels, train_token_types)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=batch_size,
)

# Validation batches are served in dataset order
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels, validation_token_types)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(
    validation_data,
    sampler=validation_sampler,
    batch_size=batch_size,
)

In [15]:
# Persist both DataLoader objects (validation first, then training).
# The `models/` directory must already exist.
for loader, loader_path in (
    (validation_dataloader, 'models/validation_data_loader'),
    (train_dataloader, 'models/train_data_loader'),
):
    torch.save(loader, loader_path)

# Load Model & Set Params

Load the appropriate model below. Each model already contains a single dense layer for classification on top.



```
BERT:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)

XLNet:
model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=num_labels)

RoBERTa:
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=num_labels)
```



In [16]:
# Load the pretrained model; it comes with a single linear classification layer
# on top that outputs `num_labels` logits.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
# Move to the device selected at the top of the notebook instead of calling
# .cuda() unconditionally, so the notebook also runs on a CPU-only machine.
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

Setting custom optimization parameters for the AdamW optimizer 

https://huggingface.co/transformers/main_classes/optimizer_schedules.html

In [17]:
# Custom parameter groups for AdamW. You may implement a scheduler here as well.
#
# Two groups are built from the model's named parameters:
#   * everything except biases and LayerNorm weights -> weight decay 0.01
#   * biases and LayerNorm weights                   -> no weight decay
#
# The per-group key must be `weight_decay`: that is the name AdamW reads. Any
# other key (such as `weight_decay_rate`) is ignored, which leaves every
# parameter at the optimizer's default decay.
# The LayerNorm parameters of this model are named `LayerNorm.weight` /
# `LayerNorm.bias` (see the module printout above), so the name filter matches
# on those; `bias` already covers `LayerNorm.bias`.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]

In [18]:
# AdamW over the custom parameter groups, with bias correction enabled
optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=2e-5,
    correct_bias=True,
)
# Default optimization, without parameter groups, would be:
# optimizer = AdamW(model.parameters(),lr=2e-5)

# Train Model

In [19]:
# Per-step training losses, kept for plotting
train_loss_set = []

# Number of training epochs (authors recommend between 2 and 4)
epochs = 3

# Multilabel objective: sigmoid + binary cross-entropy applied to the raw logits.
# Built once, outside the loops, because it does not depend on the batch.
loss_func = BCEWithLogitsLoss()

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc='Epochs'):

    # Set our model to training mode (as opposed to evaluation mode)
    model.train()

    # Tracking variables for this epoch
    tr_loss = 0 #running loss
    nb_tr_examples, nb_tr_steps = 0, 0

    # Train the data for one epoch
    for step, batch in enumerate(train_dataloader):
        # Add batch to the selected device
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels, b_token_types = batch
        # Clear out the gradients (by default they accumulate)
        optimizer.zero_grad()

        # Forward pass for multilabel classification. Labels are not passed to the
        # model; the loss is computed here from the logits instead.
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
        logits = outputs[0]
        # Labels are cast to the logits' dtype (float) for the loss calculation
        loss = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels))

        # Read the scalar loss once and reuse it for both the history and the running total
        step_loss = loss.item()
        train_loss_set.append(step_loss)

        # Backward pass
        loss.backward()
        # Update parameters and take a step using the computed gradient
        optimizer.step()
        # If a learning-rate scheduler is added, step it here.

        # Update tracking variables
        tr_loss += step_loss
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        # Report the running mean loss every 100 batches
        if nb_tr_steps%100 == 0:
            print("Batch number: {} loss: {} ".format(nb_tr_steps, tr_loss/nb_tr_steps))

    print("Train loss: {}".format(tr_loss/nb_tr_steps))

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)


Batch number: 1
Batch number: 2
Batch number: 3
Batch number: 4
Batch number: 5
Batch number: 6
Batch number: 7
Batch number: 8
Batch number: 9
Batch number: 10
Batch number: 11
Batch number: 12
Batch number: 13
Batch number: 14
Batch number: 15
Batch number: 16
Batch number: 17
Batch number: 18
Batch number: 19
Batch number: 20
Batch number: 21
Batch number: 22
Batch number: 23
Batch number: 24
Batch number: 25
Batch number: 26
Batch number: 27
Batch number: 28
Batch number: 29
Batch number: 30
Batch number: 31
Batch number: 32
Batch number: 33
Batch number: 34
Batch number: 35
Batch number: 36
Batch number: 37
Batch number: 38
Batch number: 39
Batch number: 40
Batch number: 41
Batch number: 42
Batch number: 43
Batch number: 44
Batch number: 45
Batch number: 46
Batch number: 47
Batch number: 48
Batch number: 49
Batch number: 50
Batch number: 51
Batch number: 52
Batch number: 53
Batch number: 54
Batch number: 55
Batch number: 56
Batch number: 57
Batch number: 58
Batch number: 59
Batch 

Batch number: 463
Batch number: 464
Batch number: 465
Batch number: 466
Batch number: 467
Batch number: 468
Batch number: 469
Batch number: 470
Batch number: 471
Batch number: 472
Batch number: 473
Batch number: 474
Batch number: 475
Batch number: 476
Batch number: 477
Batch number: 478
Batch number: 479
Batch number: 480
Batch number: 481
Batch number: 482
Batch number: 483
Batch number: 484
Batch number: 485
Batch number: 486
Batch number: 487
Batch number: 488
Batch number: 489
Batch number: 490
Batch number: 491
Batch number: 492
Batch number: 493
Batch number: 494
Batch number: 495
Batch number: 496
Batch number: 497
Batch number: 498
Batch number: 499
Batch number: 500
Batch number: 501
Batch number: 502
Batch number: 503
Batch number: 504
Batch number: 505
Batch number: 506
Batch number: 507
Batch number: 508
Batch number: 509
Batch number: 510
Batch number: 511
Batch number: 512
Batch number: 513
Batch number: 514
Batch number: 515
Batch number: 516
Batch number: 517
Batch numb

Batch number: 919
Batch number: 920
Batch number: 921
Batch number: 922
Batch number: 923
Batch number: 924
Batch number: 925
Batch number: 926
Batch number: 927
Batch number: 928
Batch number: 929
Batch number: 930
Batch number: 931
Batch number: 932
Batch number: 933
Batch number: 934
Batch number: 935
Batch number: 936
Batch number: 937
Batch number: 938
Batch number: 939
Batch number: 940
Batch number: 941
Batch number: 942
Batch number: 943
Batch number: 944
Batch number: 945
Batch number: 946
Batch number: 947
Batch number: 948
Batch number: 949
Batch number: 950
Batch number: 951
Batch number: 952
Batch number: 953
Batch number: 954
Batch number: 955
Batch number: 956
Batch number: 957
Batch number: 958
Batch number: 959
Batch number: 960
Batch number: 961
Batch number: 962
Batch number: 963
Batch number: 964
Batch number: 965
Batch number: 966
Batch number: 967
Batch number: 968
Batch number: 969
Batch number: 970
Batch number: 971
Batch number: 972
Batch number: 973
Batch numb

Batch number: 1355
Batch number: 1356
Batch number: 1357
Batch number: 1358
Batch number: 1359
Batch number: 1360
Batch number: 1361
Batch number: 1362
Batch number: 1363
Batch number: 1364
Batch number: 1365
Batch number: 1366
Batch number: 1367
Batch number: 1368
Batch number: 1369
Batch number: 1370
Batch number: 1371
Batch number: 1372
Batch number: 1373
Batch number: 1374
Batch number: 1375
Batch number: 1376
Batch number: 1377
Batch number: 1378
Batch number: 1379
Batch number: 1380
Batch number: 1381
Batch number: 1382
Batch number: 1383
Batch number: 1384
Batch number: 1385
Batch number: 1386
Batch number: 1387
Batch number: 1388
Batch number: 1389
Batch number: 1390
Batch number: 1391
Batch number: 1392
Batch number: 1393
Batch number: 1394
Batch number: 1395
Batch number: 1396
Batch number: 1397
Batch number: 1398
Batch number: 1399
Batch number: 1400
Batch number: 1401
Batch number: 1402
Batch number: 1403
Batch number: 1404
Batch number: 1405
Batch number: 1406
Batch number

Batch number: 1787
Batch number: 1788
Batch number: 1789
Batch number: 1790
Batch number: 1791
Batch number: 1792
Batch number: 1793
Batch number: 1794
Batch number: 1795
Batch number: 1796
Batch number: 1797
Batch number: 1798
Batch number: 1799
Batch number: 1800
Batch number: 1801
Batch number: 1802
Batch number: 1803
Batch number: 1804
Batch number: 1805
Batch number: 1806
Batch number: 1807
Batch number: 1808
Batch number: 1809
Batch number: 1810
Batch number: 1811
Batch number: 1812
Batch number: 1813
Batch number: 1814
Batch number: 1815
Batch number: 1816
Batch number: 1817
Batch number: 1818
Batch number: 1819
Batch number: 1820
Batch number: 1821
Batch number: 1822
Batch number: 1823
Batch number: 1824
Batch number: 1825
Batch number: 1826
Batch number: 1827
Batch number: 1828
Batch number: 1829
Batch number: 1830
Batch number: 1831
Batch number: 1832
Batch number: 1833
Batch number: 1834
Batch number: 1835
Batch number: 1836
Batch number: 1837
Batch number: 1838
Batch number

Batch number: 2219
Batch number: 2220
Batch number: 2221
Batch number: 2222
Batch number: 2223
Batch number: 2224
Batch number: 2225
Batch number: 2226
Batch number: 2227
Batch number: 2228
Batch number: 2229
Batch number: 2230
Batch number: 2231
Batch number: 2232
Batch number: 2233
Batch number: 2234
Batch number: 2235
Batch number: 2236
Batch number: 2237
Batch number: 2238
Batch number: 2239
Batch number: 2240
Batch number: 2241
Batch number: 2242
Batch number: 2243
Batch number: 2244
Batch number: 2245
Batch number: 2246
Batch number: 2247
Batch number: 2248
Batch number: 2249
Batch number: 2250
Batch number: 2251
Batch number: 2252
Batch number: 2253
Batch number: 2254
Batch number: 2255
Batch number: 2256
Batch number: 2257
Batch number: 2258
Batch number: 2259
Batch number: 2260
Batch number: 2261
Batch number: 2262
Batch number: 2263
Batch number: 2264
Batch number: 2265
Batch number: 2266
Batch number: 2267
Batch number: 2268
Batch number: 2269
Batch number: 2270
Batch number

Batch number: 2651
Batch number: 2652
Batch number: 2653
Batch number: 2654
Batch number: 2655
Batch number: 2656
Batch number: 2657
Batch number: 2658
Batch number: 2659
Batch number: 2660
Batch number: 2661
Batch number: 2662
Batch number: 2663
Batch number: 2664
Batch number: 2665
Batch number: 2666
Batch number: 2667
Batch number: 2668
Batch number: 2669
Batch number: 2670
Batch number: 2671
Batch number: 2672
Batch number: 2673
Batch number: 2674
Batch number: 2675
Batch number: 2676
Batch number: 2677
Batch number: 2678
Batch number: 2679
Batch number: 2680
Batch number: 2681
Batch number: 2682
Batch number: 2683
Batch number: 2684
Batch number: 2685
Batch number: 2686
Batch number: 2687
Batch number: 2688
Batch number: 2689
Batch number: 2690
Batch number: 2691
Batch number: 2692
Batch number: 2693
Batch number: 2694
Batch number: 2695
Batch number: 2696
Batch number: 2697
Batch number: 2698
Batch number: 2699
Batch number: 2700
Batch number: 2701
Batch number: 2702
Batch number

Batch number: 3083
Batch number: 3084
Batch number: 3085
Batch number: 3086
Batch number: 3087
Batch number: 3088
Batch number: 3089
Batch number: 3090
Batch number: 3091
Batch number: 3092
Batch number: 3093
Batch number: 3094
Batch number: 3095
Batch number: 3096
Batch number: 3097
Batch number: 3098
Batch number: 3099
Batch number: 3100
Batch number: 3101
Batch number: 3102
Batch number: 3103
Batch number: 3104
Batch number: 3105
Batch number: 3106
Batch number: 3107
Batch number: 3108
Batch number: 3109
Batch number: 3110
Batch number: 3111
Batch number: 3112
Batch number: 3113
Batch number: 3114
Batch number: 3115
Batch number: 3116
Batch number: 3117
Batch number: 3118
Batch number: 3119
Batch number: 3120
Batch number: 3121
Batch number: 3122
Batch number: 3123
Batch number: 3124
Batch number: 3125
Batch number: 3126
Batch number: 3127
Batch number: 3128
Batch number: 3129
Batch number: 3130
Batch number: 3131
Batch number: 3132
Batch number: 3133
Batch number: 3134
Batch number

Batch number: 3515
Batch number: 3516
Batch number: 3517
Batch number: 3518
Batch number: 3519
Batch number: 3520
Batch number: 3521
Batch number: 3522
Batch number: 3523
Batch number: 3524
Batch number: 3525
Batch number: 3526
Batch number: 3527
Batch number: 3528
Batch number: 3529
Batch number: 3530
Batch number: 3531
Batch number: 3532
Batch number: 3533
Batch number: 3534
Batch number: 3535
Batch number: 3536
Batch number: 3537
Batch number: 3538
Batch number: 3539
Batch number: 3540
Batch number: 3541
Batch number: 3542
Batch number: 3543
Batch number: 3544
Batch number: 3545
Batch number: 3546
Batch number: 3547
Batch number: 3548
Batch number: 3549
Batch number: 3550
Batch number: 3551
Batch number: 3552
Batch number: 3553
Batch number: 3554
Batch number: 3555
Batch number: 3556
Batch number: 3557
Batch number: 3558
Batch number: 3559
Batch number: 3560
Batch number: 3561
Batch number: 3562
Batch number: 3563
Batch number: 3564
Batch number: 3565
Batch number: 3566
Batch number

Batch number: 3947
Batch number: 3948
Batch number: 3949
Batch number: 3950
Batch number: 3951
Batch number: 3952
Batch number: 3953
Batch number: 3954
Batch number: 3955
Batch number: 3956
Batch number: 3957
Batch number: 3958
Batch number: 3959
Batch number: 3960
Batch number: 3961
Batch number: 3962
Batch number: 3963
Batch number: 3964
Batch number: 3965
Batch number: 3966
Batch number: 3967
Batch number: 3968
Batch number: 3969
Batch number: 3970
Batch number: 3971
Batch number: 3972
Batch number: 3973
Batch number: 3974
Batch number: 3975
Batch number: 3976
Batch number: 3977
Batch number: 3978
Batch number: 3979
Batch number: 3980
Batch number: 3981
Batch number: 3982
Batch number: 3983
Batch number: 3984
Batch number: 3985
Batch number: 3986
Batch number: 3987
Batch number: 3988
Batch number: 3989
Batch number: 3990
Batch number: 3991
Batch number: 3992
Batch number: 3993
Batch number: 3994
Batch number: 3995
Batch number: 3996
Batch number: 3997
Batch number: 3998
Batch number

Batch number: 4379
Batch number: 4380
Batch number: 4381
Batch number: 4382
Batch number: 4383
Batch number: 4384
Batch number: 4385
Batch number: 4386
Batch number: 4387
Batch number: 4388
Batch number: 4389
Batch number: 4390
Batch number: 4391
Batch number: 4392
Batch number: 4393
Batch number: 4394
Batch number: 4395
Batch number: 4396
Batch number: 4397
Batch number: 4398
Batch number: 4399
Batch number: 4400
Batch number: 4401
Batch number: 4402
Batch number: 4403
Batch number: 4404
Batch number: 4405
Batch number: 4406
Batch number: 4407
Batch number: 4408
Batch number: 4409
Batch number: 4410
Batch number: 4411
Batch number: 4412
Batch number: 4413
Batch number: 4414
Batch number: 4415
Batch number: 4416
Batch number: 4417
Batch number: 4418
Batch number: 4419
Batch number: 4420
Batch number: 4421
Batch number: 4422
Batch number: 4423
Batch number: 4424
Batch number: 4425
Batch number: 4426
Batch number: 4427
Batch number: 4428
Batch number: 4429
Batch number: 4430
Batch number

Batch number: 4811
Batch number: 4812
Batch number: 4813
Batch number: 4814
Batch number: 4815
Batch number: 4816
Batch number: 4817
Batch number: 4818
Batch number: 4819
Batch number: 4820
Batch number: 4821
Batch number: 4822
Batch number: 4823
Batch number: 4824
Batch number: 4825
Batch number: 4826
Batch number: 4827
Batch number: 4828
Batch number: 4829
Batch number: 4830
Batch number: 4831
Batch number: 4832
Batch number: 4833
Batch number: 4834
Batch number: 4835
Batch number: 4836
Batch number: 4837
Batch number: 4838
Batch number: 4839
Batch number: 4840
Batch number: 4841
Batch number: 4842
Batch number: 4843
Batch number: 4844
Batch number: 4845
Batch number: 4846
Batch number: 4847
Batch number: 4848
Batch number: 4849
Batch number: 4850
Batch number: 4851
Batch number: 4852
Batch number: 4853
Batch number: 4854
Batch number: 4855
Batch number: 4856
Batch number: 4857
Batch number: 4858
Batch number: 4859
Batch number: 4860
Batch number: 4861
Batch number: 4862
Batch number

Batch number: 5243
Batch number: 5244
Batch number: 5245
Batch number: 5246
Batch number: 5247
Batch number: 5248
Batch number: 5249
Batch number: 5250
Batch number: 5251
Batch number: 5252
Batch number: 5253
Batch number: 5254
Batch number: 5255
Batch number: 5256
Batch number: 5257
Batch number: 5258
Batch number: 5259
Batch number: 5260
Batch number: 5261
Batch number: 5262
Batch number: 5263
Batch number: 5264
Batch number: 5265
Batch number: 5266
Batch number: 5267
Batch number: 5268
Batch number: 5269
Batch number: 5270
Batch number: 5271
Batch number: 5272
Batch number: 5273
Batch number: 5274
Batch number: 5275
Batch number: 5276
Batch number: 5277
Batch number: 5278
Batch number: 5279
Batch number: 5280
Batch number: 5281
Batch number: 5282
Batch number: 5283
Batch number: 5284
Batch number: 5285
Batch number: 5286
Batch number: 5287
Batch number: 5288
Batch number: 5289
Batch number: 5290
Batch number: 5291
Batch number: 5292
Batch number: 5293
Batch number: 5294
Batch number

Batch number: 5675
Batch number: 5676
Batch number: 5677
Batch number: 5678
Batch number: 5679
Batch number: 5680
Batch number: 5681
Batch number: 5682
Batch number: 5683
Batch number: 5684
Batch number: 5685
Batch number: 5686
Batch number: 5687
Batch number: 5688
Batch number: 5689
Batch number: 5690
Batch number: 5691
Batch number: 5692
Batch number: 5693
Batch number: 5694
Batch number: 5695
Batch number: 5696
Batch number: 5697
Batch number: 5698
Batch number: 5699
Batch number: 5700
Batch number: 5701
Batch number: 5702
Batch number: 5703
Batch number: 5704
Batch number: 5705
Batch number: 5706
Batch number: 5707
Batch number: 5708
Batch number: 5709
Batch number: 5710
Batch number: 5711
Batch number: 5712
Batch number: 5713
Batch number: 5714
Batch number: 5715
Batch number: 5716
Batch number: 5717
Batch number: 5718
Batch number: 5719
Batch number: 5720
Batch number: 5721
Batch number: 5722
Batch number: 5723
Batch number: 5724
Batch number: 5725
Batch number: 5726
Batch number

Epochs:   0%|          | 0/3 [52:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:
###############################################################################

# Validation
#
# This cell runs at top level, so its statements start at column 0; code
# indented under nothing raises IndentationError.

# Put model in evaluation mode to evaluate on the validation set
model.eval()

# Variables to gather full output
logit_preds,true_labels,pred_labels,tokenized_texts = [],[],[],[]

# Predict
for i, batch in enumerate(validation_dataloader):
  batch = tuple(t.to(device) for t in batch)
  # Unpack the inputs from our dataloader
  b_input_ids, b_input_mask, b_labels, b_token_types = batch
  with torch.no_grad():
    # Forward pass
    outs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    b_logit_pred = outs[0]
    # Per-label probabilities
    pred_label = torch.sigmoid(b_logit_pred)

    # Move results to host memory as numpy arrays
    b_logit_pred = b_logit_pred.detach().cpu().numpy()
    pred_label = pred_label.to('cpu').numpy()
    b_labels = b_labels.to('cpu').numpy()

  tokenized_texts.append(b_input_ids)
  logit_preds.append(b_logit_pred)
  true_labels.append(b_labels)
  pred_labels.append(pred_label)

# Flatten outputs from per-batch arrays to one entry per example
pred_labels = [item for sublist in pred_labels for item in sublist]
true_labels = [item for sublist in true_labels for item in sublist]

# Calculate Accuracy: a label counts as predicted when its probability exceeds the threshold
threshold = 0.50
pred_bools = [pl>threshold for pl in pred_labels]
true_bools = [tl==1 for tl in true_labels]
val_f1_accuracy = f1_score(true_bools,pred_bools,average='micro')*100
val_flat_accuracy = accuracy_score(true_bools, pred_bools)*100

print('F1 Validation Accuracy: ', val_f1_accuracy)
print('Flat Validation Accuracy: ', val_flat_accuracy)

In [None]:
# Save only the model's weights (its state dict)
model_weights = model.state_dict()
torch.save(model_weights, 'bert_model_toxic')

# Load and Preprocess Test Data

In [14]:
# The test comments and their labels come in separate files; join them on `id`
test_labels_df = pd.read_csv('data/toxic_comments/test_labels.csv')
test_df = pd.read_csv('data/toxic_comments/test.csv').merge(test_labels_df, on='id', how='left')
test_label_cols = list(test_df.columns[2:])

# Sanity checks
has_nulls = test_df.isnull().values.any()
print('Null values: ', has_nulls) #should not be any null sentences or labels
same_label_cols = label_cols == test_label_cols
print('Same columns between train and test: ', same_label_cols) #columns should be the same
test_df.head()

Null values:  False
Same columns between train and test:  True


Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,00001cee341fdb12,Yo bitch Ja Rule is more succesful then you'll...,-1,-1,-1,-1,-1,-1
1,0000247867823ef7,== From RfC == \n\n The title is fine as it is...,-1,-1,-1,-1,-1,-1
2,00013b17ad220c46,""" \n\n == Sources == \n\n * Zawe Ashton on Lap...",-1,-1,-1,-1,-1,-1
3,00017563c3f7919a,":If you have a look back at the source, the in...",-1,-1,-1,-1,-1,-1
4,00017695ad8997eb,I don't anonymously edit articles at all.,-1,-1,-1,-1,-1,-1


In [None]:
# Drop rows where any label is -1 (irrelevant rows/comments with -1 values)
has_unlabelled = test_df[test_label_cols].eq(-1).any(axis=1)
# .copy() makes the filtered frame independent of the original, so the column
# assignment below is unambiguous and does not trigger SettingWithCopyWarning
test_df = test_df[~has_unlabelled].copy()
test_df['one_hot_labels'] = list(test_df[test_label_cols].values)
test_df.head()

In [None]:
# Pull the test label vectors and raw comment strings out as plain Python lists
test_labels = test_df['one_hot_labels'].tolist()
test_comments = test_df['comment_text'].tolist()

In [None]:
# Encode the test comments with the same tokenizer and max_length as the training data
test_encodings = tokenizer.batch_encode_plus(
    test_comments,
    max_length=max_length,
    pad_to_max_length=True,
)
test_input_ids, test_token_type_ids, test_attention_masks = (
    test_encodings[key] for key in ('input_ids', 'token_type_ids', 'attention_mask')
)

In [None]:
# Convert the encoded test data to tensors
test_inputs, test_labels, test_masks, test_token_types = (
    torch.tensor(rows)
    for rows in (test_input_ids, test_labels, test_attention_masks, test_token_type_ids)
)

# Test batches are served in dataset order; each item is
# (input ids, attention mask, labels, token type ids)
test_data = TensorDataset(test_inputs, test_masks, test_labels, test_token_types)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(
    test_data,
    sampler=test_sampler,
    batch_size=batch_size,
)

# Persist the test DataLoader
torch.save(test_dataloader, 'test_data_loader')

# Prediction and Metrics

In [None]:
# Test

# Put model in evaluation mode for inference on the test set
model.eval()

# Accumulators: per-batch logits, plus one entry per example for the labels,
# predicted probabilities and input ids
logit_preds,true_labels,pred_labels,tokenized_texts = [],[],[],[]

# Predict
for i, batch in enumerate(test_dataloader):
  batch = tuple(t.to(device) for t in batch)
  # Unpack the inputs from our dataloader
  b_input_ids, b_input_mask, b_labels, b_token_types = batch
  with torch.no_grad():
    # Forward pass
    outs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    b_logit_pred = outs[0]
    # Per-label probabilities
    pred_label = torch.sigmoid(b_logit_pred)

    # Move results to host memory as numpy arrays
    b_logit_pred = b_logit_pred.detach().cpu().numpy()
    pred_label = pred_label.to('cpu').numpy()
    b_labels = b_labels.to('cpu').numpy()

  logit_preds.append(b_logit_pred)
  # extend() adds one row per example, so these three lists are built flat
  # instead of being flattened after the loop
  tokenized_texts.extend(b_input_ids)
  true_labels.extend(b_labels)
  pred_labels.extend(pred_label)

# Converting binary label vectors to boolean vectors
true_bools = [tl==1 for tl in true_labels]

In [None]:
pred_bools = [pl>0.60 for pl in pred_labels] #boolean output after thresholding

# Print and save classification report
print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools,average='micro'))
print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools),'\n')
clf_report = classification_report(true_bools,pred_bools,target_names=test_label_cols)
# The report is stored as a pickle (despite the .txt name), so read it back with
# pickle.load. The `with` block closes the file handle once the dump is done.
with open('classification_report.txt','wb') as report_file:
  pickle.dump(clf_report, report_file) #save report
print(clf_report)

# Output dataframe

In [None]:
# Map each label's position in the one-hot vector to its column name.
# enumerate() follows the length of label_cols instead of a hard-coded count.
idx2label = dict(enumerate(label_cols))
print(idx2label)

In [None]:
# For every example, list the positions where its boolean label vector is True,
# so idx2label can turn them into label names
true_label_idxs = [np.where(flags)[0].flatten().tolist() for flags in true_bools]
pred_label_idxs = [np.where(flags)[0].flatten().tolist() for flags in pred_bools]

In [None]:
def _idxs_to_names(idxs):
    """Translate a list of label positions into label names via idx2label; an empty list is returned as is."""
    return [idx2label[idx] for idx in idxs] if idxs else idxs


# Label names for every example, for the true and the predicted labels
true_label_texts = [_idxs_to_names(idxs) for idxs in true_label_idxs]
pred_label_texts = [_idxs_to_names(idxs) for idxs in pred_label_idxs]

In [None]:
# Turn each example's input ids back into text, dropping special tokens
decode_options = {'skip_special_tokens': True, 'clean_up_tokenization_spaces': False}
comment_texts = [tokenizer.decode(ids, **decode_options) for ids in tokenized_texts]

In [None]:
# One row per test comment: decoded text, true label names, predicted label names
comparison_columns = {
    'comment_text': comment_texts,
    'true_labels': true_label_texts,
    'pred_labels': pred_label_texts,
}
comparisons_df = pd.DataFrame(comparison_columns)
comparisons_df.to_csv('comparisons.csv')
comparisons_df.head()

# Bonus - Optimizing threshold value for micro F1 score

In [None]:
# Calculate Accuracy - maximize micro F1 by tuning the decision threshold.
# First a coarse search over 'macro_thresholds' in steps of 10^-1, then a fine
# search over 'micro_thresholds' in steps of 10^-2 starting at the best coarse value.
# NOTE(review): the threshold is tuned on the same predictions that are then
# reported, so the optimized scores are optimistic.

def _score_thresholds(thresholds, y_true, y_prob):
  """Score every candidate threshold.

  thresholds: iterable of cut-off values.
  y_true: per-example boolean label vectors.
  y_prob: per-example predicted probability vectors.

  Returns two lists aligned with `thresholds`: micro F1 scores and flat accuracies.
  """
  f1_scores, flat_accs = [], []
  for th in thresholds:
    # Local predictions only; the notebook-level pred_bools is left untouched
    th_bools = [pl>th for pl in y_prob]
    f1_scores.append(f1_score(y_true, th_bools, average='micro'))
    flat_accs.append(accuracy_score(y_true, th_bools))
  return f1_scores, flat_accs


# Coarse pass: 0.1, 0.2, ..., 0.9
macro_thresholds = np.array(range(1,10))/10
f1_results, flat_acc_results = _score_thresholds(macro_thresholds, true_bools, pred_labels)
best_macro_th = macro_thresholds[np.argmax(f1_results)] #best macro threshold value

# Fine pass: best_macro_th, best_macro_th+0.01, ..., best_macro_th+0.09
micro_thresholds = (np.array(range(10))/100)+best_macro_th #calculating micro threshold values
f1_results, flat_acc_results = _score_thresholds(micro_thresholds, true_bools, pred_labels)
best_f1_idx = np.argmax(f1_results) #best threshold value

# Printing and saving classification report
print('Best Threshold: ', micro_thresholds[best_f1_idx])
print('Test F1 Accuracy: ', f1_results[best_f1_idx])
print('Test Flat Accuracy: ', flat_acc_results[best_f1_idx], '\n')

best_pred_bools = [pl>micro_thresholds[best_f1_idx] for pl in pred_labels]
clf_report_optimized = classification_report(true_bools,best_pred_bools, target_names=label_cols)
# Stored as a pickle (despite the .txt name); the `with` block closes the file handle
with open('classification_report_optimized.txt','wb') as report_file:
  pickle.dump(clf_report_optimized, report_file)
print(clf_report_optimized)