In [1]:
import torch
import torch.nn as nn
import pandas
import re, string
import emoji
from collections import Counter, defaultdict
import torchtext

from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, SequentialSampler, random_split)
import tqdm
import numpy as np

from torch import cuda
RANDOM_STATE = 42
device = 'cuda' if cuda.is_available() else 'cpu'
device

2022-11-29 00:26:54.818184: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


'cuda'

In [2]:
data = pandas.read_csv('cyberbullying_tweets.csv')

In [3]:
data

Unnamed: 0,tweet_text,cyberbullying_type
0,"In other words #katandandre, your food was cra...",not_cyberbullying
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying
...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity
47688,Turner did not withhold his disappointment. Tu...,ethnicity
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity


In [4]:
# data.drop(data[data['cyberbullying_type'] == 'other_cyberbullying'].index, inplace=True)

In [5]:
labels = list(data.cyberbullying_type.unique())
labels2id = {label: i for i, label in enumerate(labels)}
print(labels2id)

{'not_cyberbullying': 0, 'gender': 1, 'religion': 2, 'other_cyberbullying': 3, 'age': 4, 'ethnicity': 5}


In [6]:
# map labels to ints
data['label'] = data.cyberbullying_type.map(labels2id)

In [7]:
data

Unnamed: 0,tweet_text,cyberbullying_type,label
0,"In other words #katandandre, your food was cra...",not_cyberbullying,0
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying,0
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying,0
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying,0
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying,0
...,...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity,5
47688,Turner did not withhold his disappointment. Tu...,ethnicity,5
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity,5
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity,5


In [8]:
#Clean emojis from text
def strip_emoji(text):
    return emoji.replace_emoji(text)
    # return re.sub(emoji.get_emoji_regexp(), r"", text) #remove emoji
    
def strip_all_entities(text): 
    text = text.replace('\r', '').replace('\n', ' ').lower() #remove \n and \r and lowercase
    text = re.sub(r"(?:\@|https?\://)\S+", "", text) #remove links and mentions
    text = re.sub(r'[^\x00-\x7f]',r'', text) #remove non utf8/ascii characters such as '\x9a\x91\x97\x9a\x97'
    banned_list= string.punctuation
    table = str.maketrans('', '', banned_list)
    text = text.translate(table)
    text = [word for word in text.split()]
    text = ' '.join(text)
    return text

#remove contractions
def decontract(text):
    text = re.sub(r"can\'t", "can not", text)
    text = re.sub(r"won\'t", "will not", text)
    text = re.sub(r"n\'t", " not", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'s", " is", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'t", " not", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'m", " am", text)
    return text

def clean_hashtags(tweet):
    # new_tweet = " ".join(word.strip() for word in re.split('#(?!(?:hashtag)\b)[\w-]+(?=(?:\s+#[\w-]+)*\s*$)', tweet)) #remove last hashtags
    new_tweet2 = " ".join(word.strip() for word in re.split('#|_', tweet)) #remove hashtags symbol from words in the middle of the sentence
    return new_tweet2

#Filter special characters such as "&" and "$" present in some words
def filter_chars(a):
    sent = []
    for word in a.split(' '):
        if ('$' in word) | ('&' in word):
            sent.append('')
        else:
            sent.append(word)
    return ' '.join(sent)

#Remove multiple sequential spaces
def remove_mult_spaces(text):
    return re.sub("\s\s+" , " ", text)


#Then we apply all the defined functions in the following order
def deep_clean(text):
    text = strip_emoji(text)
    text = decontract(text)
    text = strip_all_entities(text)
    text = clean_hashtags(text)
    text = filter_chars(text)
    text = remove_mult_spaces(text)
    return text

In [9]:
PADDING_VALUE = 256
data['cleaned_text'] =  [deep_clean(x).split()[:256] for x in data.tweet_text]

In [10]:
MIN_FREQ = 5
word_counts = Counter([item for sublist in data['cleaned_text'] for item in sublist])
vocab = torchtext.vocab.vocab(word_counts,
                              min_freq = MIN_FREQ,
                              specials = ['<pad>', '<unk>', ])
vocab.set_default_index(1)

VOCAB_SIZE = len(vocab)

In [11]:
try: 
    with open('trainvaltest_idx.npy', 'rb') as f:
        train_idx = np.load(f)
        val_idx = np.load(f)
        test_idx = np.load(f)
except:
    # stratified 80/10/10 train val test split
    TRAIN_SIZE=0.8
    indicies = np.arange(0,len(labels),1)

    train_idx, valtest_idx = train_test_split(indicies, train_size=TRAIN_SIZE, random_state=RANDOM_STATE, shuffle=True, stratify=labels)
    val_idx, test_idx = train_test_split(valtest_idx, train_size=0.5, random_state=RANDOM_STATE, shuffle=True, stratify=labels[valtest_idx])
    with open('trainvaltest_idx.npy', 'wb') as f:
        np.save(f, train_idx)
        np.save(f, val_idx)
        np.save(f, test_idx)

# y train-val-test
labels = torch.tensor(data.label, dtype = torch.int64)
ytrain = labels[train_idx]
yval = labels[val_idx]
ytest = labels[test_idx]

# remove other cyberbullying
train_mask = ytrain != labels2id['other_cyberbullying']
val_mask = yval != labels2id['other_cyberbullying']
test_mask = ytest != labels2id['other_cyberbullying']
ytrain = ytrain[train_mask]
yval = yval[val_mask]
ytest = ytest[test_mask]

tokenized_data = [torch.tensor(vocab(x), dtype = torch.int64) for x in data.cleaned_text]

# x train-val-test for cnn
# pad tokenized_data with 0s to same length as the longest sequence
tokenized_data = nn.utils.rnn.pad_sequence(tokenized_data, batch_first=True)

xtrain = tokenized_data[train_idx][train_mask]
xval = tokenized_data[val_idx][val_mask]
xtest = tokenized_data[test_idx][test_mask]

train = torch.utils.data.TensorDataset(xtrain, ytrain)
val = torch.utils.data.TensorDataset(xval, yval)
test = torch.utils.data.TensorDataset(xtest, ytest)
print(len(train), len(val), len(test))

# x train-val-test for bag of words (frequency of token)
bow_tokenized_data = torch.zeros(tokenized_data.shape[0], VOCAB_SIZE)
for i in range(tokenized_data.shape[0]):
  tweet = tokenized_data[i,tokenized_data[i].nonzero().flatten()]
  bow_tokenized_data[i].put_(tweet, torch.ones_like(tweet).float(), accumulate=True)

bow_tokenized_data = nn.utils.rnn.pad_sequence(bow_tokenized_data, batch_first=True)

bow_xtrain = bow_tokenized_data[train_idx][train_mask]
bow_xval = bow_tokenized_data[val_idx][val_mask]
bow_xtest = bow_tokenized_data[test_idx][test_mask]

bow_train = torch.utils.data.TensorDataset(bow_xtrain, ytrain)
bow_val = torch.utils.data.TensorDataset(bow_xval, yval)
bow_test = torch.utils.data.TensorDataset(bow_xtest, ytest)
print(len(bow_train), len(bow_val), len(bow_test))

31895 3986 3988
31895 3986 3988


In [12]:
BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train, BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val, BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test, BATCH_SIZE)

bow_train_loader = torch.utils.data.DataLoader(bow_train, BATCH_SIZE)
bow_val_loader = torch.utils.data.DataLoader(bow_val, BATCH_SIZE)
bow_test_loader = torch.utils.data.DataLoader(bow_test, BATCH_SIZE)

In [13]:
# augment data


# BOW

In [22]:
class MLP(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs):
    super(MLP, self).__init__()

    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    
    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)

    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 499 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [23]:
hidden_dim = 128
dropout = 0.
learning_rate = 1e-04
epochs = 15
outputs = len(labels2id.items())

mlp = MLP(hidden_dim, dropout, device, outputs)
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

for i in range(epochs):
  mlp.train_model(bow_train_loader, bow_val_loader, bow_test_loader, optimizer)



  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.8426994480682388
0.8426994480682388
              precision    recall  f1-score   support

           0     0.4887    0.7984    0.6063       486
           1     0.8130    0.8416    0.8271       770
           2     0.9700    0.7992    0.8763       971
           4     0.9762    0.8351    0.9002       934
           5     0.9636    0.9297    0.9463       825

    accuracy                         0.8427      3986
   macro avg     0.8423    0.8408    0.8312      3986
weighted avg     0.8811    0.8427    0.8540      3986

########## TEST ###########
0.8400200601805417
0.8400200601805417
              precision    recall  f1-score   support

           0     0.4755    0.8363    0.6063       452
           1     0.8195    0.8185    0.8190       799
           2     0.9738    0.7837    0.8685       994
           4     0.9787    0.8454    0.9072       925
           5     0.9510    0.9254    0.9380       818

    accuracy                         0.8400      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9004014049172102
0.9004014049172102
              precision    recall  f1-score   support

           0     0.7997    0.7678    0.7835       827
           1     0.7992    0.9381    0.8631       679
           2     0.9525    0.9159    0.9338       832
           4     0.9775    0.9221    0.9490       847
           5     0.9724    0.9663    0.9693       801

    accuracy                         0.9004      3986
   macro avg     0.9003    0.9020    0.8997      3986
weighted avg     0.9040    0.9004    0.9009      3986

########## TEST ###########
0.9077231695085256
0.9077231695085256
              precision    recall  f1-score   support

           0     0.8239    0.7873    0.8052       832
           1     0.8008    0.9207    0.8566       694
           2     0.9600    0.9264    0.9429       829
           4     0.9825    0.9435    0.9626       832
           5     0.9711    0.9650    0.9681       801

    accuracy                         0.9077      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9162067235323633
0.9162067235323633
              precision    recall  f1-score   support

           0     0.8841    0.7723    0.8244       909
           1     0.8043    0.9553    0.8733       671
           2     0.9450    0.9450    0.9450       800
           4     0.9762    0.9559    0.9659       816
           5     0.9711    0.9785    0.9748       790

    accuracy                         0.9162      3986
   macro avg     0.9161    0.9214    0.9167      3986
weighted avg     0.9190    0.9162    0.9156      3986

########## TEST ###########
0.9157472417251755
0.9157472417251755
              precision    recall  f1-score   support

           0     0.8818    0.7746    0.8247       905
           1     0.7920    0.9322    0.8564       678
           2     0.9500    0.9536    0.9518       797
           4     0.9800    0.9619    0.9709       814
           5     0.9749    0.9773    0.9761       794

    accuracy                         0.9157      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9199698946312093
0.9199698946312093
              precision    recall  f1-score   support

           0     0.8992    0.7702    0.8298       927
           1     0.8118    0.9585    0.8791       675
           2     0.9400    0.9567    0.9483       786
           4     0.9750    0.9665    0.9707       806
           5     0.9736    0.9785    0.9761       792

    accuracy                         0.9200      3986
   macro avg     0.9199    0.9261    0.9208      3986
weighted avg     0.9226    0.9200    0.9191      3986

########## TEST ###########
0.9197592778335005
0.9197592778335005
              precision    recall  f1-score   support

           0     0.8969    0.7733    0.8305       922
           1     0.8008    0.9383    0.8641       681
           2     0.9437    0.9618    0.9527       785
           4     0.9812    0.9667    0.9739       811
           5     0.9761    0.9848    0.9804       789

    accuracy                         0.9198      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.924234821876568
0.9242348218765679
              precision    recall  f1-score   support

           0     0.9131    0.7746    0.8382       936
           1     0.8231    0.9591    0.8859       684
           2     0.9375    0.9628    0.9500       779
           4     0.9750    0.9738    0.9744       800
           5     0.9724    0.9835    0.9779       787

    accuracy                         0.9242      3986
   macro avg     0.9242    0.9307    0.9253      3986
weighted avg     0.9265    0.9242    0.9231      3986

########## TEST ###########
0.9220160481444333
0.9220160481444333
              precision    recall  f1-score   support

           0     0.9107    0.7719    0.8355       938
           1     0.8045    0.9400    0.8670       683
           2     0.9400    0.9703    0.9549       775
           4     0.9787    0.9726    0.9757       804
           5     0.9761    0.9860    0.9811       788

    accuracy                         0.9220      3988
 

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9264927245358756
0.9264927245358756
              precision    recall  f1-score   support

           0     0.9093    0.7839    0.8420       921
           1     0.8294    0.9524    0.8867       694
           2     0.9463    0.9582    0.9522       790
           4     0.9725    0.9786    0.9755       794
           5     0.9749    0.9860    0.9804       787

    accuracy                         0.9265      3986
   macro avg     0.9265    0.9318    0.9274      3986
weighted avg     0.9282    0.9265    0.9255      3986

########## TEST ###########
0.9252758274824473
0.9252758274824473
              precision    recall  f1-score   support

           0     0.9157    0.7803    0.8426       933
           1     0.8120    0.9391    0.8710       690
           2     0.9425    0.9704    0.9562       777
           4     0.9800    0.9788    0.9794       800
           5     0.9761    0.9860    0.9811       788

    accuracy                         0.9253      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9272453587556447
0.9272453587556447
              precision    recall  f1-score   support

           0     0.9144    0.7815    0.8427       929
           1     0.8306    0.9539    0.8880       694
           2     0.9437    0.9606    0.9521       786
           4     0.9725    0.9811    0.9767       792
           5     0.9749    0.9885    0.9817       785

    accuracy                         0.9272      3986
   macro avg     0.9272    0.9331    0.9282      3986
weighted avg     0.9290    0.9272    0.9262      3986

########## TEST ###########
0.9277833500501504
0.9277833500501504
              precision    recall  f1-score   support

           0     0.9170    0.7839    0.8452       930
           1     0.8258    0.9374    0.8781       703
           2     0.9400    0.9741    0.9567       772
           4     0.9800    0.9812    0.9806       798
           5     0.9761    0.9898    0.9829       785

    accuracy                         0.9278      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9300050175614651
0.9300050175614651
              precision    recall  f1-score   support

           0     0.9194    0.7875    0.8483       927
           1     0.8381    0.9570    0.8936       698
           2     0.9450    0.9618    0.9533       786
           4     0.9725    0.9823    0.9774       791
           5     0.9749    0.9898    0.9823       784

    accuracy                         0.9300      3986
   macro avg     0.9300    0.9357    0.9310      3986
weighted avg     0.9317    0.9300    0.9289      3986

########## TEST ###########
0.9267803410230692
0.9267803410230692
              precision    recall  f1-score   support

           0     0.9195    0.7785    0.8431       939
           1     0.8246    0.9400    0.8785       700
           2     0.9387    0.9741    0.9561       771
           4     0.9762    0.9824    0.9793       794
           5     0.9749    0.9898    0.9823       784

    accuracy                         0.9268      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9290015052684395
0.9290015052684395
              precision    recall  f1-score   support

           0     0.9131    0.7898    0.8470       918
           1     0.8419    0.9518    0.8935       705
           2     0.9437    0.9569    0.9503       789
           4     0.9725    0.9835    0.9780       790
           5     0.9736    0.9885    0.9810       784

    accuracy                         0.9290      3986
   macro avg     0.9290    0.9341    0.9299      3986
weighted avg     0.9302    0.9290    0.9280      3986

########## TEST ###########
0.9292878635907723
0.9292878635907723
              precision    recall  f1-score   support

           0     0.9157    0.7896    0.8480       922
           1     0.8371    0.9356    0.8836       714
           2     0.9425    0.9729    0.9575       775
           4     0.9775    0.9836    0.9805       794
           5     0.9736    0.9898    0.9816       783

    accuracy                         0.9293      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9297541394882087
0.9297541394882087
              precision    recall  f1-score   support

           0     0.9106    0.7945    0.8486       910
           1     0.8457    0.9480    0.8939       711
           2     0.9463    0.9570    0.9516       791
           4     0.9712    0.9835    0.9773       789
           5     0.9749    0.9885    0.9817       785

    accuracy                         0.9298      3986
   macro avg     0.9297    0.9343    0.9306      3986
weighted avg     0.9307    0.9298    0.9288      3986

########## TEST ###########
0.9292878635907723
0.9292878635907723
              precision    recall  f1-score   support

           0     0.9182    0.7892    0.8488       925
           1     0.8358    0.9368    0.8834       712
           2     0.9413    0.9729    0.9568       774
           4     0.9775    0.9836    0.9805       794
           5     0.9736    0.9898    0.9816       783

    accuracy                         0.9293      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9310085298544907
0.9310085298544907
              precision    recall  f1-score   support

           0     0.9144    0.7961    0.8511       912
           1     0.8482    0.9428    0.8930       717
           2     0.9475    0.9619    0.9547       788
           4     0.9712    0.9848    0.9779       788
           5     0.9736    0.9923    0.9829       781

    accuracy                         0.9310      3986
   macro avg     0.9310    0.9356    0.9319      3986
weighted avg     0.9319    0.9310    0.9300      3986

########## TEST ###########
0.9292878635907723
0.9292878635907723
              precision    recall  f1-score   support

           0     0.9157    0.7904    0.8485       921
           1     0.8383    0.9344    0.8838       716
           2     0.9413    0.9716    0.9562       775
           4     0.9775    0.9836    0.9805       794
           5     0.9736    0.9910    0.9823       782

    accuracy                         0.9293      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9310085298544907
0.9310085298544907
              precision    recall  f1-score   support

           0     0.9081    0.8011    0.8512       900
           1     0.8494    0.9416    0.8931       719
           2     0.9525    0.9585    0.9555       795
           4     0.9712    0.9823    0.9767       790
           5     0.9736    0.9910    0.9823       782

    accuracy                         0.9310      3986
   macro avg     0.9310    0.9349    0.9318      3986
weighted avg     0.9317    0.9310    0.9302      3986

########## TEST ###########
0.9292878635907723
0.9292878635907723
              precision    recall  f1-score   support

           0     0.9082    0.7969    0.8489       906
           1     0.8421    0.9282    0.8830       724
           2     0.9425    0.9692    0.9556       778
           4     0.9787    0.9824    0.9806       796
           5     0.9749    0.9898    0.9823       784

    accuracy                         0.9293      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9302558956347216
0.9302558956347216
              precision    recall  f1-score   support

           0     0.9106    0.7989    0.8511       905
           1     0.8469    0.9427    0.8923       716
           2     0.9500    0.9596    0.9548       792
           4     0.9700    0.9823    0.9761       789
           5     0.9736    0.9885    0.9810       784

    accuracy                         0.9303      3986
   macro avg     0.9302    0.9344    0.9310      3986
weighted avg     0.9311    0.9303    0.9294      3986

########## TEST ###########
0.9310431293881645
0.9310431293881645
              precision    recall  f1-score   support

           0     0.9170    0.7967    0.8526       915
           1     0.8409    0.9398    0.8876       714
           2     0.9400    0.9691    0.9543       776
           4     0.9787    0.9824    0.9806       796
           5     0.9786    0.9898    0.9842       787

    accuracy                         0.9310      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.930506773707978
0.930506773707978
              precision    recall  f1-score   support

           0     0.9043    0.8040    0.8512       893
           1     0.8507    0.9378    0.8921       723
           2     0.9525    0.9561    0.9543       797
           4     0.9700    0.9835    0.9767       788
           5     0.9749    0.9885    0.9817       785

    accuracy                         0.9305      3986
   macro avg     0.9305    0.9340    0.9312      3986
weighted avg     0.9311    0.9305    0.9297      3986

########## TEST ###########
0.9295386158475426
0.9295386158475426
              precision    recall  f1-score   support

           0     0.9057    0.8000    0.8496       900
           1     0.8459    0.9310    0.8864       725
           2     0.9400    0.9666    0.9531       778
           4     0.9775    0.9799    0.9787       797
           5     0.9786    0.9886    0.9836       788

    accuracy                         0.9295      3988
  

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9310085298544907
0.9310085298544907
              precision    recall  f1-score   support

           0     0.9043    0.8049    0.8517       892
           1     0.8570    0.9343    0.8940       731
           2     0.9487    0.9595    0.9541       791
           4     0.9700    0.9835    0.9767       788
           5     0.9749    0.9898    0.9823       784

    accuracy                         0.9310      3986
   macro avg     0.9310    0.9344    0.9318      3986
weighted avg     0.9313    0.9310    0.9302      3986

########## TEST ###########
0.9292878635907723
0.9292878635907723
              precision    recall  f1-score   support

           0     0.9019    0.8011    0.8485       895
           1     0.8484    0.9261    0.8855       731
           2     0.9413    0.9666    0.9538       779
           4     0.9762    0.9811    0.9787       795
           5     0.9786    0.9886    0.9836       788

    accuracy                         0.9293      3988


# CNN

In [26]:
from torch.nn.modules.conv import Conv1d
class Conv1dBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=2, padding=0, max_pool=2, dropout=0.1):
    super(Conv1dBlock, self).__init__()

    self.max_pool = max_pool
    self.l1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding = padding)
    self.l2 = nn.ReLU()
    self.l3 = nn.MaxPool1d(kernel_size=self.max_pool, stride=1, padding=0)
    self.l4 = nn.Dropout(dropout)
  
  def __call__(self, x):
    x = self.l1(x)
    x = self.l2(x)
    if self.max_pool > 0:
      x = self.l3(x)
    x = self.l4(x)
    return x

class ConvNet(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=2, max_pool=2):
    super(ConvNet, self).__init__()

    self.conv_dim = conv_dim
    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    self.embedding_dim = embedding_dim
    self.max_pool = max_pool
    self.kernel_size = kernel_size

    self.e = nn.Embedding(num_embeddings = VOCAB_SIZE, embedding_dim=self.embedding_dim).to(self.device)

    self.c1 = Conv1dBlock(in_channels=self.embedding_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c2 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c3 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)

    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    embeddings = self.e(x)
    C, L, I = embeddings.shape
    x = embeddings.reshape(C, I, L)
    x = self.c1(x).to(self.device)
    x = self.c2(x).to(self.device)
    x = self.c3(x).to(self.device)
    x = x.flatten(1)
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)
    # x = x[:,0,:]
    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()


        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 499 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [27]:

conv_dim = 32
hidden_dim = 128
dropout = 0.1
learning_rate = 5e-04
epochs = 15
outputs = len(labels2id.items())
embedding_dim = 128

convnet = ConvNet(hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=8, max_pool=2)
optimizer = torch.optim.Adam(convnet.parameters(), lr=learning_rate)

for i in range(epochs):
  convnet.train_model(train_loader, val_loader, test_loader, optimizer)



  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.7837431008529855
0.7837431008529855
              precision    recall  f1-score   support

           0     0.8249    0.6244    0.7108      1049
           1     0.6110    0.8455    0.7094       576
           2     0.8125    0.7303    0.7692       890
           4     0.9074    0.8984    0.9029       807
           5     0.7626    0.9142    0.8315       664

    accuracy                         0.7837      3986
   macro avg     0.7837    0.8026    0.7848      3986
weighted avg     0.7976    0.7837    0.7826      3986

########## TEST ###########
0.7941323971915747
0.7941323971915747
              precision    recall  f1-score   support

           0     0.8491    0.6416    0.7309      1052
           1     0.6291    0.8367    0.7182       600
           2     0.8087    0.7463    0.7762       867
           4     0.9099    0.9145    0.9122       795
           5     0.7739    0.9139    0.8381       674

    accuracy                         0.7941      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.8502257902659307
0.8502257902659307
              precision    recall  f1-score   support

           0     0.8526    0.6663    0.7481      1016
           1     0.6675    0.9366    0.7795       568
           2     0.8750    0.8717    0.8734       803
           4     0.9449    0.9379    0.9414       805
           5     0.9108    0.9131    0.9119       794

    accuracy                         0.8502      3986
   macro avg     0.8502    0.8651    0.8509      3986
weighted avg     0.8610    0.8502    0.8495      3986

########## TEST ###########
0.8565697091273822
0.8565697091273822
              precision    recall  f1-score   support

           0     0.8830    0.6862    0.7723      1023
           1     0.6554    0.9339    0.7703       560
           2     0.8838    0.8772    0.8804       806
           4     0.9512    0.9394    0.9453       809
           5     0.9095    0.9165    0.9130       790

    accuracy                         0.8566      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.8662819869543402
0.8662819869543402
              precision    recall  f1-score   support

           0     0.8741    0.6878    0.7698      1009
           1     0.6826    0.9628    0.7988       565
           2     0.9175    0.8645    0.8902       849
           4     0.9324    0.9790    0.9551       761
           5     0.9246    0.9177    0.9212       802

    accuracy                         0.8663      3986
   macro avg     0.8662    0.8824    0.8670      3986
weighted avg     0.8775    0.8663    0.8654      3986

########## TEST ###########
0.8713640922768305
0.8713640922768305
              precision    recall  f1-score   support

           0     0.8918    0.7076    0.7891      1002
           1     0.6704    0.9469    0.7850       565
           2     0.9275    0.8689    0.8972       854
           4     0.9349    0.9752    0.9546       766
           5     0.9322    0.9263    0.9292       801

    accuracy                         0.8714      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.8823381836427496
0.8823381836427496
              precision    recall  f1-score   support

           0     0.8678    0.7038    0.7772       979
           1     0.7340    0.9574    0.8310       611
           2     0.9200    0.9020    0.9109       816
           4     0.9374    0.9778    0.9572       766
           5     0.9523    0.9312    0.9416       814

    accuracy                         0.8823      3986
   macro avg     0.8823    0.8944    0.8836      3986
weighted avg     0.8886    0.8823    0.8810      3986

########## TEST ###########
0.8876629889669007
0.8876629889669007
              precision    recall  f1-score   support

           0     0.9057    0.7122    0.7973      1011
           1     0.7018    0.9459    0.8058       592
           2     0.9337    0.9154    0.9245       816
           4     0.9387    0.9868    0.9622       760
           5     0.9585    0.9431    0.9508       809

    accuracy                         0.8877      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.8936276969392875
0.8936276969392875
              precision    recall  f1-score   support

           0     0.8804    0.7274    0.7966       961
           1     0.7880    0.9088    0.8441       691
           2     0.9050    0.9330    0.9188       776
           4     0.9499    0.9718    0.9608       781
           5     0.9447    0.9678    0.9561       777

    accuracy                         0.8936      3986
   macro avg     0.8936    0.9018    0.8953      3986
weighted avg     0.8953    0.8936    0.8919      3986

########## TEST ###########
0.8984453360080241
0.8984453360080241
              precision    recall  f1-score   support

           0     0.8994    0.7402    0.8120       966
           1     0.7707    0.8913    0.8266       690
           2     0.9213    0.9389    0.9300       785
           4     0.9587    0.9808    0.9696       781
           5     0.9422    0.9791    0.9603       766

    accuracy                         0.8984      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.903411941796287
0.903411941796287
              precision    recall  f1-score   support

           0     0.8879    0.7484    0.8122       942
           1     0.8055    0.9277    0.8623       692
           2     0.9175    0.9279    0.9227       791
           4     0.9587    0.9696    0.9641       790
           5     0.9472    0.9780    0.9623       771

    accuracy                         0.9034      3986
   macro avg     0.9034    0.9103    0.9047      3986
weighted avg     0.9050    0.9034    0.9020      3986

########## TEST ###########
0.906469408224674
0.906469408224674
              precision    recall  f1-score   support

           0     0.9019    0.7500    0.8190       956
           1     0.7782    0.9092    0.8386       683
           2     0.9300    0.9526    0.9412       781
           4     0.9700    0.9773    0.9736       793
           5     0.9523    0.9781    0.9650       775

    accuracy                         0.9065      3988
   m

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9004014049172102
0.9004014049172102
              precision    recall  f1-score   support

           0     0.8854    0.7323    0.8016       960
           1     0.7867    0.9414    0.8571       666
           2     0.9213    0.9270    0.9241       795
           4     0.9524    0.9744    0.9633       781
           5     0.9560    0.9707    0.9633       784

    accuracy                         0.9004      3986
   macro avg     0.9004    0.9092    0.9019      3986
weighted avg     0.9031    0.9004    0.8988      3986

########## TEST ###########
0.9014543630892679
0.901454363089268
              precision    recall  f1-score   support

           0     0.9057    0.7385    0.8136       975
           1     0.7594    0.9140    0.8296       663
           2     0.9300    0.9442    0.9370       788
           4     0.9612    0.9808    0.9709       783
           5     0.9510    0.9718    0.9613       779

    accuracy                         0.9015      3988
 

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9004014049172102
0.9004014049172102
              precision    recall  f1-score   support

           0     0.9194    0.7213    0.8084      1012
           1     0.7629    0.9635    0.8515       631
           2     0.9175    0.9350    0.9262       785
           4     0.9474    0.9768    0.9619       775
           5     0.9548    0.9706    0.9626       783

    accuracy                         0.9004      3986
   macro avg     0.9004    0.9135    0.9021      3986
weighted avg     0.9066    0.9004    0.8986      3986

########## TEST ###########
0.9069709127382146
0.9069709127382146
              precision    recall  f1-score   support

           0     0.9358    0.7287    0.8194      1021
           1     0.7469    0.9460    0.8347       630
           2     0.9225    0.9584    0.9401       770
           4     0.9662    0.9872    0.9766       782
           5     0.9636    0.9771    0.9703       785

    accuracy                         0.9070      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.905920722528851
0.905920722528851
              precision    recall  f1-score   support

           0     0.8841    0.7532    0.8134       932
           1     0.8043    0.9358    0.8650       685
           2     0.9150    0.9373    0.9260       781
           4     0.9637    0.9759    0.9698       789
           5     0.9623    0.9587    0.9605       799

    accuracy                         0.9059      3986
   macro avg     0.9059    0.9122    0.9070      3986
weighted avg     0.9079    0.9059    0.9048      3986

########## TEST ###########
0.9114844533600802
0.9114844533600802
              precision    recall  f1-score   support

           0     0.9019    0.7611    0.8256       942
           1     0.7907    0.9225    0.8516       684
           2     0.9213    0.9659    0.9431       763
           4     0.9712    0.9761    0.9737       795
           5     0.9724    0.9627    0.9675       804

    accuracy                         0.9115      3988
  

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9101856497742097
0.9101856497742097
              precision    recall  f1-score   support

           0     0.8678    0.7698    0.8159       895
           1     0.8243    0.9267    0.8725       709
           2     0.9375    0.9214    0.9294       814
           4     0.9587    0.9758    0.9672       785
           5     0.9623    0.9783    0.9702       783

    accuracy                         0.9102      3986
   macro avg     0.9101    0.9144    0.9110      3986
weighted avg     0.9108    0.9102    0.9092      3986

########## TEST ###########
0.9129889669007021
0.9129889669007021
              precision    recall  f1-score   support

           0     0.8780    0.7773    0.8246       898
           1     0.8120    0.9038    0.8554       717
           2     0.9425    0.9401    0.9413       802
           4     0.9650    0.9822    0.9735       785
           5     0.9673    0.9796    0.9735       786

    accuracy                         0.9130      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9129453085800301
0.9129453085800301
              precision    recall  f1-score   support

           0     0.8942    0.7667    0.8256       926
           1     0.8156    0.9353    0.8713       695
           2     0.9350    0.9280    0.9315       806
           4     0.9612    0.9821    0.9715       782
           5     0.9585    0.9820    0.9701       777

    accuracy                         0.9129      3986
   macro avg     0.9129    0.9188    0.9140      3986
weighted avg     0.9144    0.9129    0.9118      3986

########## TEST ###########
0.9137412236710131
0.9137412236710131
              precision    recall  f1-score   support

           0     0.9006    0.7757    0.8335       923
           1     0.8120    0.9101    0.8583       712
           2     0.9363    0.9457    0.9410       792
           4     0.9612    0.9859    0.9734       779
           5     0.9585    0.9757    0.9670       782

    accuracy                         0.9137      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9109382839939789
0.9109382839939789
              precision    recall  f1-score   support

           0     0.8942    0.7585    0.8208       936
           1     0.8256    0.9255    0.8727       711
           2     0.9275    0.9357    0.9316       793
           4     0.9512    0.9845    0.9675       772
           5     0.9560    0.9832    0.9694       774

    accuracy                         0.9109      3986
   macro avg     0.9109    0.9175    0.9124      3986
weighted avg     0.9116    0.9109    0.9094      3986

########## TEST ###########
0.9117352056168505
0.9117352056168505
              precision    recall  f1-score   support

           0     0.9082    0.7640    0.8299       945
           1     0.8158    0.9017    0.8566       722
           2     0.9287    0.9501    0.9393       782
           4     0.9524    0.9896    0.9707       769
           5     0.9535    0.9857    0.9693       770

    accuracy                         0.9117      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9119417962870046
0.9119417962870046
              precision    recall  f1-score   support

           0     0.8992    0.7588    0.8231       941
           1     0.8118    0.9377    0.8702       690
           2     0.9263    0.9380    0.9321       790
           4     0.9625    0.9734    0.9679       790
           5     0.9598    0.9858    0.9726       775

    accuracy                         0.9119      3986
   macro avg     0.9119    0.9187    0.9132      3986
weighted avg     0.9138    0.9119    0.9106      3986

########## TEST ###########
0.9137412236710131
0.9137412236710131
              precision    recall  f1-score   support

           0     0.9069    0.7654    0.8302       942
           1     0.8008    0.9247    0.8583       691
           2     0.9250    0.9548    0.9397       775
           4     0.9700    0.9761    0.9730       794
           5     0.9661    0.9784    0.9722       786

    accuracy                         0.9137      3988


  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.909683893627697
0.909683893627697
              precision    recall  f1-score   support

           0     0.8829    0.7661    0.8204       915
           1     0.8231    0.9073    0.8632       723
           2     0.9137    0.9408    0.9271       777
           4     0.9687    0.9724    0.9705       796
           5     0.9598    0.9858    0.9726       775

    accuracy                         0.9097      3986
   macro avg     0.9096    0.9145    0.9108      3986
weighted avg     0.9101    0.9097    0.9085      3986

########## TEST ###########
0.9157472417251755
0.9157472417251755
              precision    recall  f1-score   support

           0     0.8931    0.7837    0.8348       906
           1     0.8296    0.8946    0.8609       740
           2     0.9250    0.9610    0.9427       770
           4     0.9725    0.9786    0.9755       794
           5     0.9585    0.9807    0.9695       778

    accuracy                         0.9157      3988
  

  0%|          | 0/499 [00:00<?, ?batch/s]

########## VAL ###########
0.9124435524335173
0.9124435524335173
              precision    recall  f1-score   support

           0     0.8866    0.7652    0.8215       920
           1     0.8143    0.9516    0.8776       682
           2     0.9400    0.9148    0.9273       822
           4     0.9650    0.9735    0.9692       792
           5     0.9560    0.9883    0.9719       770

    accuracy                         0.9124      3986
   macro avg     0.9124    0.9187    0.9135      3986
weighted avg     0.9142    0.9124    0.9113      3986

########## TEST ###########
0.9180040120361084
0.9180040120361084
              precision    recall  f1-score   support

           0     0.8981    0.7803    0.8351       915
           1     0.8145    0.9380    0.8719       693
           2     0.9513    0.9303    0.9407       818
           4     0.9737    0.9786    0.9762       795
           5     0.9523    0.9883    0.9699       767

    accuracy                         0.9180      3988


# Transformer Models

In [15]:
BATCH_SIZE = 50
del convnet
del mlp

In [16]:
model_type = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)

In [17]:
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)

In [18]:
train_label = label_tensor[train_idx]
val_label = label_tensor[val_idx]
test_label = label_tensor[test_idx]

# masks to get rid of other cyberbullying examples. a little roundabout but I want the train val test split to be consistent across all models
train_label_mask = train_label != labels2id['other_cyberbullying']
val_label_mask = val_label != labels2id['other_cyberbullying']
test_label_mask = test_label != labels2id['other_cyberbullying']

train_label = train_label[train_label_mask]
val_label = val_label[val_label_mask]
test_label = test_label[test_label_mask]


train_x = x_ids[train_idx][train_label_mask]
val_x = x_ids[val_idx][val_label_mask]
test_x = x_ids[test_idx][test_label_mask]

train_mask = x_mask[train_idx][train_label_mask]
val_mask = x_mask[val_idx][val_label_mask]
test_mask = x_mask[test_idx][test_label_mask]

train = TensorDataset(train_x, train_mask, train_label)
val = TensorDataset(val_x, val_mask, val_label)
test = TensorDataset(test_x, test_mask, test_label)
print(len(train), len(val), len(test))

31895 3986 3988


In [19]:
# tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)

In [20]:
# TRAIN_SPLIT = int(0.8 * len(labels))
# VAL_TEST_SPLIT = len(labels) - TRAIN_SPLIT
# VAL_SPLIT = VAL_TEST_SPLIT//2
# TEST_SPLIT = VAL_TEST_SPLIT - VAL_SPLIT

# train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
# val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))

# print(len(train), len(val), len(test))

In [21]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [22]:
class transformer_classifier(nn.Module):
  def __init__(self, model_type, hidden_dim, dropout, device ):
      super(transformer_classifier, self).__init__()

      self.device = device
      self.hidden_dim = hidden_dim
      self.LABELS_NUM = len(labels2id.items())

      self.lm = AutoModel.from_pretrained(model_type).to(self.device)

      config = config = AutoConfig.from_pretrained(model_type, num_labels = self.LABELS_NUM)
      self.lm_sequence_classification = AutoModelForSequenceClassification.from_pretrained(model_type, config=config).to(self.device)

      self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
      self.a1 = nn.GELU().to(self.device)
      self.d1 = nn.Dropout(dropout).to(self.device)
      self.l2 = nn.LazyLinear(self.LABELS_NUM).to(self.device)

      self.loss = nn.CrossEntropyLoss()


  def forward(self, ids, masks, labels, mode):
    if mode:
      output = self.lm_sequence_classification(ids, masks, labels=labels)
      return output.logits.to(self.device), output.loss.to(self.device)
    else:
      x = self.lm(ids, masks).last_hidden_state[:, 0].to(self.device)
      x = self.l1(x).to(self.device)
      x = self.a1(x).to(self.device)
      x = self.d1(x).to(self.device)
      x = self.l2(x).to(self.device)

      loss = self.loss(x, labels).to(self.device)

      return x, loss

  def get_predictions(self, loader, mode):
    with torch.no_grad():
      with tqdm.notebook.tqdm(
      loader,
      unit="batch",
      total=len(loader)) as batch_iterator:

        preds_all = []
        labels_all = []
        accuracies = []

        for iteration, data in enumerate(batch_iterator, start=1):
          ids, masks, labels = data
          ids = ids.to(self.device)
          masks = masks.to(self.device)
          labels = labels.to(self.device)

          self.eval()

          logits, loss = self.forward(ids, masks, labels, mode)

          _, predictions = torch.max(logits, axis=1)

          preds_all += list(predictions.cpu())
          labels_all += list(labels.cpu())


          accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
          accuracies.append(accuracy)

          batch_iterator.set_postfix(accuracy = accuracy, mean_accuracy = np.mean(accuracies))
    return preds_all, labels_all

  def train_model(self, train_loader, val_loader, test_loader, mode, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        ids, masks, labels = data
        ids = ids.to(self.device)
        masks = masks.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(ids, masks, labels, mode)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 400 == 0:
          pred, label = self.get_predictions(val_loader, mode)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader, mode)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          



In [23]:
hidden_dim = 128
dropout = 0.2
learning_rate = 1e-04
epochs = 5

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerN

  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.937531359759157
0.937531359759157
              precision    recall  f1-score   support

           0     0.9093    0.8223    0.8636       878
           1     0.8683    0.9584    0.9111       722
           2     0.9500    0.9669    0.9584       786
           4     0.9750    0.9750    0.9750       799
           5     0.9849    0.9788    0.9818       801

    accuracy                         0.9375      3986
   macro avg     0.9375    0.9403    0.9380      3986
weighted avg     0.9383    0.9375    0.9370      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9358074222668004
0.9358074222668004
              precision    recall  f1-score   support

           0     0.8994    0.8266    0.8614       865
           1     0.8596    0.9462    0.9009       725
           2     0.9575    0.9635    0.9605       795
           4     0.9812    0.9788    0.9800       801
           5     0.9812    0.9738    0.9775       802

    accuracy                         0.9358      3988
   macro avg     0.9358    0.9378    0.9361      3988
weighted avg     0.9366    0.9358    0.9355      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9372804816859006
0.9372804816859006
              precision    recall  f1-score   support

           0     0.8627    0.8405    0.8515       815
           1     0.9021    0.8999    0.9010       799
           2     0.9537    0.9695    0.9616       787
           4     0.9775    0.9836    0.9805       794
           5     0.9899    0.9962    0.9931       791

    accuracy                         0.9373      3986
   macro avg     0.9372    0.9379    0.9375      3986
weighted avg     0.9367    0.9373    0.9369      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9370611835506519
0.9370611835506519
              precision    recall  f1-score   support

           0     0.8566    0.8481    0.8523       803
           1     0.8997    0.8897    0.8947       807
           2     0.9600    0.9673    0.9636       794
           4     0.9875    0.9900    0.9887       797
           5     0.9812    0.9924    0.9867       787

    accuracy                         0.9371      3988
   macro avg     0.9370    0.9375    0.9372      3988
weighted avg     0.9367    0.9371    0.9368      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9284997491219268
0.9284997491219268
              precision    recall  f1-score   support

           0     0.8060    0.8568    0.8306       747
           1     0.9184    0.8725    0.8949       839
           2     0.9550    0.9550    0.9550       800
           4     0.9750    0.9811    0.9780       794
           5     0.9874    0.9752    0.9813       806

    accuracy                         0.9285      3986
   macro avg     0.9284    0.9281    0.9280      3986
weighted avg     0.9299    0.9285    0.9289      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9348044132397192
0.9348044132397192
              precision    recall  f1-score   support

           0     0.8277    0.8773    0.8518       750
           1     0.9261    0.8746    0.8996       845
           2     0.9563    0.9598    0.9580       797
           4     0.9787    0.9836    0.9812       795
           5     0.9849    0.9788    0.9818       801

    accuracy                         0.9348      3988
   macro avg     0.9347    0.9348    0.9345      3988
weighted avg     0.9359    0.9348    0.9351      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9345208228800803
0.9345208228800803
              precision    recall  f1-score   support

           0     0.8929    0.8225    0.8563       862
           1     0.8582    0.9566    0.9048       715
           2     0.9637    0.9391    0.9513       821
           4     0.9700    0.9848    0.9773       787
           5     0.9874    0.9813    0.9843       801

    accuracy                         0.9345      3986
   macro avg     0.9345    0.9369    0.9348      3986
weighted avg     0.9355    0.9345    0.9342      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.940320962888666
0.940320962888666
              precision    recall  f1-score   support

           0     0.9107    0.8322    0.8697       870
           1     0.8596    0.9528    0.9038       720
           2     0.9688    0.9498    0.9592       816
           4     0.9800    0.9911    0.9855       790
           5     0.9824    0.9874    0.9849       792

    accuracy                         0.9403      3988
   macro avg     0.9403    0.9426    0.9406      3988
weighted avg     0.9413    0.9403    0.9400      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9264927245358756
0.9264927245358756
              precision    recall  f1-score   support

           0     0.8438    0.8201    0.8318       817
           1     0.9109    0.8715    0.8908       833
           2     0.9363    0.9677    0.9517       774
           4     0.9675    0.9885    0.9779       782
           5     0.9736    0.9936    0.9835       780

    accuracy                         0.9265      3986
   macro avg     0.9264    0.9283    0.9271      3986
weighted avg     0.9254    0.9265    0.9258      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.928284854563691
0.928284854563691
              precision    recall  f1-score   support

           0     0.8415    0.8342    0.8378       802
           1     0.9185    0.8726    0.8950       840
           2     0.9413    0.9617    0.9514       783
           4     0.9750    0.9886    0.9817       788
           5     0.9648    0.9910    0.9777       775

    accuracy                         0.9283      3988
   macro avg     0.9282    0.9296    0.9287      3988
weighted avg     0.9277    0.9283    0.9278      3988



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerN

  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9315102860010035
0.9315102860010035
              precision    recall  f1-score   support

           0     0.8892    0.8238    0.8552       857
           1     0.8645    0.9400    0.9007       733
           2     0.9425    0.9630    0.9526       783
           4     0.9750    0.9786    0.9768       796
           5     0.9862    0.9608    0.9733       817

    accuracy                         0.9315      3986
   macro avg     0.9315    0.9332    0.9317      3986
weighted avg     0.9321    0.9315    0.9312      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9355566700100301
0.9355566700100301
              precision    recall  f1-score   support

           0     0.8906    0.8349    0.8618       848
           1     0.8672    0.9389    0.9016       737
           2     0.9525    0.9597    0.9561       794
           4     0.9837    0.9813    0.9825       801
           5     0.9837    0.9691    0.9763       808

    accuracy                         0.9356      3988
   macro avg     0.9355    0.9368    0.9357      3988
weighted avg     0.9361    0.9356    0.9354      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9377822378324134
0.9377822378324134
              precision    recall  f1-score   support

           0     0.9106    0.8188    0.8623       883
           1     0.8795    0.9347    0.9063       750
           2     0.9437    0.9667    0.9551       781
           4     0.9712    0.9898    0.9804       784
           5     0.9837    0.9937    0.9886       788

    accuracy                         0.9378      3986
   macro avg     0.9378    0.9407    0.9385      3986
weighted avg     0.9376    0.9378    0.9370      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.941073219658977
0.941073219658977
              precision    recall  f1-score   support

           0     0.9208    0.8281    0.8719       884
           1     0.8810    0.9324    0.9059       754
           2     0.9563    0.9745    0.9653       785
           4     0.9762    0.9949    0.9855       784
           5     0.9711    0.9898    0.9803       781

    accuracy                         0.9411      3988
   macro avg     0.9411    0.9439    0.9418      3988
weighted avg     0.9410    0.9411    0.9403      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9322629202207727
0.9322629202207727
              precision    recall  f1-score   support

           0     0.9030    0.8056    0.8515       890
           1     0.8720    0.9267    0.8985       750
           2     0.9237    0.9685    0.9456       763
           4     0.9725    0.9898    0.9811       785
           5     0.9899    0.9875    0.9887       798

    accuracy                         0.9323      3986
   macro avg     0.9322    0.9356    0.9331      3986
weighted avg     0.9322    0.9323    0.9314      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9340521564694082
0.9340521564694082
              precision    recall  f1-score   support

           0     0.9019    0.8148    0.8561       880
           1     0.8759    0.9345    0.9043       748
           2     0.9263    0.9611    0.9433       771
           4     0.9787    0.9924    0.9855       788
           5     0.9874    0.9813    0.9843       801

    accuracy                         0.9341      3988
   macro avg     0.9340    0.9368    0.9347      3988
weighted avg     0.9341    0.9341    0.9333      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.9297541394882087
0.9297541394882087
              precision    recall  f1-score   support

           0     0.8627    0.8233    0.8426       832
           1     0.8971    0.9120    0.9045       784
           2     0.9300    0.9650    0.9472       771
           4     0.9675    0.9898    0.9785       781
           5     0.9912    0.9645    0.9777       818

    accuracy                         0.9298      3986
   macro avg     0.9297    0.9309    0.9301      3986
weighted avg     0.9294    0.9298    0.9293      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.9348044132397192
0.9348044132397192
              precision    recall  f1-score   support

           0     0.8805    0.8314    0.8552       842
           1     0.8972    0.9109    0.9040       786
           2     0.9313    0.9688    0.9496       769
           4     0.9750    0.9949    0.9848       783
           5     0.9899    0.9752    0.9825       808

    accuracy                         0.9348      3988
   macro avg     0.9348    0.9362    0.9353      3988
weighted avg     0.9343    0.9348    0.9343      3988



  0%|          | 0/638 [00:00<?, ?batch/s]

  0%|          | 0/80 [00:00<?, ?batch/s]

########## VAL ###########
0.935022579026593
0.935022579026593
              precision    recall  f1-score   support

           0     0.8728    0.8400    0.8561       825
           1     0.8921    0.9222    0.9069       771
           2     0.9475    0.9619    0.9547       788
           4     0.9775    0.9666    0.9720       808
           5     0.9849    0.9874    0.9862       794

    accuracy                         0.9350      3986
   macro avg     0.9350    0.9356    0.9352      3986
weighted avg     0.9349    0.9350    0.9348      3986



  0%|          | 0/80 [00:00<?, ?batch/s]

########## TEST ###########
0.932296890672016
0.932296890672016
              precision    recall  f1-score   support

           0     0.8667    0.8321    0.8490       828
           1     0.8772    0.9067    0.8917       772
           2     0.9575    0.9599    0.9587       798
           4     0.9825    0.9752    0.9788       805
           5     0.9774    0.9911    0.9842       785

    accuracy                         0.9323      3988
   macro avg     0.9322    0.9330    0.9325      3988
weighted avg     0.9321    0.9323    0.9320      3988



In [39]:
model_type

'bert-base-uncased'

In [None]:
model_type = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)
tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)
train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [None]:
classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)