In [1]:
# conda env test1

In [2]:
import torch
import torch.nn as nn
import pandas
import re, string
import emoji
from collections import Counter
import torchtext

from sklearn.metrics import f1_score, accuracy_score, classification_report
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, SequentialSampler, random_split)
import tqdm
import numpy as np

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
data = pandas.read_csv('cyberbullying_tweets.csv')

In [4]:
data

Unnamed: 0,tweet_text,cyberbullying_type
0,"In other words #katandandre, your food was cra...",not_cyberbullying
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying
...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity
47688,Turner did not withhold his disappointment. Tu...,ethnicity
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity


In [5]:
data.loc[data['cyberbullying_type'] == 'other_cyberbullying']

Unnamed: 0,tweet_text,cyberbullying_type
23916,"@ikralla fyi, it looks like I was caught by it...",other_cyberbullying
23917,I need to just switch to an organization-based...,other_cyberbullying
23918,RMAed my monoprice. Shoddy power bricks on tho...,other_cyberbullying
23919,@murphy_slaw https://t.co/M8w8xnUnDL,other_cyberbullying
23920,@1Life0Continues i've got the code to interpre...,other_cyberbullying
...,...,...
31734,"@kufr666 @blockbot no, that's @oolon",other_cyberbullying
31735,@AriMelber why are you giving these idiots air...,other_cyberbullying
31736,I am right now watching Enforcers defend Chums...,other_cyberbullying
31737,✨✨✨ misandry is not a word iOS can autocomplet...,other_cyberbullying


In [6]:
labels = list(data.cyberbullying_type.unique())
labels2id = {label: i for i, label in enumerate(labels)}
print(labels2id)

{'not_cyberbullying': 0, 'gender': 1, 'religion': 2, 'other_cyberbullying': 3, 'age': 4, 'ethnicity': 5}


In [7]:
data['label'] = data.cyberbullying_type.map(labels2id)

In [8]:
#Clean emojis from text
def strip_emoji(text):
    return emoji.replace_emoji(text)
    # return re.sub(emoji.get_emoji_regexp(), r"", text) #remove emoji
    
def strip_all_entities(text): 
    text = text.replace('\r', '').replace('\n', ' ').lower() #remove \n and \r and lowercase
    text = re.sub(r"(?:\@|https?\://)\S+", "", text) #remove links and mentions
    text = re.sub(r'[^\x00-\x7f]',r'', text) #remove non utf8/ascii characters such as '\x9a\x91\x97\x9a\x97'
    banned_list= string.punctuation
    table = str.maketrans('', '', banned_list)
    text = text.translate(table)
    text = [word for word in text.split()]
    text = ' '.join(text)
    return text

#remove contractions
def decontract(text):
    text = re.sub(r"can\'t", "can not", text)
    text = re.sub(r"won\'t", "will not", text)
    text = re.sub(r"n\'t", " not", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'s", " is", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'t", " not", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'m", " am", text)
    return text

def clean_hashtags(tweet):
    # new_tweet = " ".join(word.strip() for word in re.split('#(?!(?:hashtag)\b)[\w-]+(?=(?:\s+#[\w-]+)*\s*$)', tweet)) #remove last hashtags
    new_tweet2 = " ".join(word.strip() for word in re.split('#|_', tweet)) #remove hashtags symbol from words in the middle of the sentence
    return new_tweet2

#Filter special characters such as "&" and "$" present in some words
def filter_chars(a):
    sent = []
    for word in a.split(' '):
        if ('$' in word) | ('&' in word):
            sent.append('')
        else:
            sent.append(word)
    return ' '.join(sent)

#Remove multiple sequential spaces
def remove_mult_spaces(text):
    return re.sub("\s\s+" , " ", text)


#Then we apply all the defined functions in the following order
def deep_clean(text):
    text = strip_emoji(text)
    text = decontract(text)
    text = strip_all_entities(text)
    text = clean_hashtags(text)
    text = filter_chars(text)
    text = remove_mult_spaces(text)
    return text

In [9]:
PADDING_VALUE = 256
data['cleaned_text'] =  [deep_clean(x).split()[:256] for x in data.tweet_text]

In [10]:
word_counts = Counter([item for sublist in data['cleaned_text'] for item in sublist])
vocab = torchtext.vocab.vocab(word_counts,
                              min_freq = 5,
                              specials = ['<pad>', '<unk>', ])
vocab.set_default_index(1)

VOCAB_SIZE = len(vocab)

In [11]:
tokenized_data = [torch.tensor(vocab(x), dtype = torch.int64) for x in data.cleaned_text]

labels = torch.tensor(data.label, dtype = torch.int64)

tokenized_data = nn.utils.rnn.pad_sequence(tokenized_data, batch_first=True)

# bag of words input. frequency of token
bow_tokenized_data = torch.zeros(tokenized_data.shape[0], VOCAB_SIZE)
for i in range(tokenized_data.shape[0]):
  tweet = tokenized_data[i,tokenized_data[i].nonzero().flatten()]
  bow_tokenized_data[i].put_(tweet, torch.ones_like(tweet).float(), accumulate=True)

bow_tokenized_data = nn.utils.rnn.pad_sequence(bow_tokenized_data, batch_first=True)

In [12]:
TRAIN_SPLIT = int(0.8 * len(labels))
VAL_TEST_SPLIT = len(labels) - TRAIN_SPLIT
VAL_SPLIT = VAL_TEST_SPLIT//2
TEST_SPLIT = VAL_TEST_SPLIT - VAL_SPLIT

bow_cnn_dataset = torch.utils.data.TensorDataset(tokenized_data, labels)
bow_dataset = torch.utils.data.TensorDataset(bow_tokenized_data, labels)

In [13]:
train, val_test = torch.utils.data.random_split(bow_cnn_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT])
val, test = torch.utils.data.random_split(val_test, [VAL_SPLIT, TEST_SPLIT])

print(len(train), len(val), len(test))

bow_train, bow_val_test = torch.utils.data.random_split(bow_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT])
bow_val, bow_test = torch.utils.data.random_split(bow_val_test, [VAL_SPLIT, TEST_SPLIT])

print(len(bow_train), len(bow_val), len(bow_test))

38153 4769 4770
38153 4769 4770


In [14]:
BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train, BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val, BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test, BATCH_SIZE)

bow_train_loader = torch.utils.data.DataLoader(bow_train, BATCH_SIZE)
bow_val_loader = torch.utils.data.DataLoader(bow_val, BATCH_SIZE)
bow_test_loader = torch.utils.data.DataLoader(bow_test, BATCH_SIZE)

# BOW

In [15]:
class MLP(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs):
    super(MLP, self).__init__()

    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    
    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)

    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 500 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [16]:
hidden_dim = 128
dropout = 0.
learning_rate = 1e-04
epochs = 15
outputs = len(labels2id.items())

mlp = MLP(hidden_dim, dropout, device, outputs)
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

for i in range(epochs):
  mlp.train_model(bow_train_loader, bow_val_loader, bow_test_loader, optimizer)



  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7372614803942126
0.7372614803942126
              precision    recall  f1-score   support

           0     0.2548    0.7256    0.3771       277
           1     0.8075    0.7497    0.7775       867
           2     0.9782    0.6955    0.8130      1097
           3     0.4358    0.5536    0.4877       625
           4     0.9859    0.7744    0.8675       993
           5     0.9586    0.8648    0.9093       910

    accuracy                         0.7373      4769
   macro avg     0.7368    0.7273    0.7053      4769
weighted avg     0.8319    0.7373    0.7683      4769

########## TEST ###########
0.7268343815513627
0.7268343815513627
              precision    recall  f1-score   support

           0     0.2400    0.7283    0.3611       265
           1     0.8088    0.7335    0.7693       848
           2     0.9690    0.6890    0.8053      1180
           3     0.4021    0.5126    0.4507       597
           4     0.9612    0.7767    0.8592      1021
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8066680645837702
0.80666806458377
              precision    recall  f1-score   support

           0     0.4474    0.6281    0.5226       562
           1     0.7801    0.8870    0.8301       708
           2     0.9500    0.8982    0.9234       825
           3     0.6990    0.5640    0.6243       984
           4     0.9872    0.9200    0.9524       837
           5     0.9744    0.9379    0.9558       853

    accuracy                         0.8067      4769
   macro avg     0.8064    0.8059    0.8014      4769
weighted avg     0.8247    0.8067    0.8115      4769

########## TEST ###########
0.79979035639413
0.79979035639413
              precision    recall  f1-score   support

           0     0.4055    0.6293    0.4932       518
           1     0.7633    0.9017    0.8268       651
           2     0.9523    0.8800    0.9147       908
           3     0.7201    0.5315    0.6116      1031
           4     0.9685    0.9173    0.9422       871
       

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8202977563430489
0.820297756343049
              precision    recall  f1-score   support

           0     0.4892    0.6266    0.5495       616
           1     0.7776    0.9165    0.8414       683
           2     0.9487    0.9367    0.9427       790
           3     0.7456    0.5714    0.6470      1036
           4     0.9872    0.9483    0.9673       812
           5     0.9720    0.9591    0.9655       832

    accuracy                         0.8203      4769
   macro avg     0.8201    0.8265    0.8189      4769
weighted avg     0.8313    0.8203    0.8213      4769

########## TEST ###########
0.8083857442348008
0.8083857442348008
              precision    recall  f1-score   support

           0     0.4378    0.6143    0.5113       573
           1     0.7607    0.9169    0.8316       638
           2     0.9368    0.9140    0.9253       860
           3     0.7608    0.5371    0.6297      1078
           4     0.9709    0.9435    0.9570       849
  

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.825539945481233
0.825539945481233
              precision    recall  f1-score   support

           0     0.4905    0.6303    0.5517       614
           1     0.7925    0.9180    0.8507       695
           2     0.9462    0.9523    0.9492       775
           3     0.7620    0.5718    0.6533      1058
           4     0.9885    0.9554    0.9716       807
           5     0.9720    0.9732    0.9726       820

    accuracy                         0.8255      4769
   macro avg     0.8253    0.8335    0.8249      4769
weighted avg     0.8358    0.8255    0.8258      4769

########## TEST ###########
0.8153039832285115
0.8153039832285115
              precision    recall  f1-score   support

           0     0.4478    0.6250    0.5217       576
           1     0.7724    0.9267    0.8426       641
           2     0.9356    0.9257    0.9306       848
           3     0.7792    0.5450    0.6414      1088
           4     0.9745    0.9459    0.9600       850
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8295240092262529
0.8295240092262529
              precision    recall  f1-score   support

           0     0.4994    0.6304    0.5573       625
           1     0.8025    0.9124    0.8539       708
           2     0.9462    0.9584    0.9523       770
           3     0.7670    0.5783    0.6594      1053
           4     0.9872    0.9637    0.9753       799
           5     0.9732    0.9816    0.9774       814

    accuracy                         0.8295      4769
   macro avg     0.8292    0.8375    0.8293      4769
weighted avg     0.8382    0.8295    0.8294      4769

########## TEST ###########
0.8226415094339623
0.8226415094339623
              precision    recall  f1-score   support

           0     0.4627    0.6447    0.5387       577
           1     0.7880    0.9182    0.8481       660
           2     0.9368    0.9357    0.9363       840
           3     0.7924    0.5578    0.6547      1081
           4     0.9733    0.9503    0.9617       845
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8332983854057454
0.8332983854057454
              precision    recall  f1-score   support

           0     0.5158    0.6349    0.5692       641
           1     0.8112    0.9082    0.8570       719
           2     0.9462    0.9609    0.9535       768
           3     0.7645    0.5870    0.6641      1034
           4     0.9872    0.9673    0.9772       796
           5     0.9732    0.9852    0.9792       811

    accuracy                         0.8333      4769
   macro avg     0.8330    0.8406    0.8334      4769
weighted avg     0.8400    0.8333    0.8329      4769

########## TEST ###########
0.8289308176100629
0.8289308176100629
              precision    recall  f1-score   support

           0     0.4813    0.6439    0.5509       601
           1     0.8062    0.9158    0.8575       677
           2     0.9380    0.9425    0.9403       835
           3     0.7924    0.5726    0.6648      1053
           4     0.9733    0.9582    0.9657       838
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8349758859299643
0.8349758859299643
              precision    recall  f1-score   support

           0     0.5108    0.6346    0.5660       635
           1     0.8186    0.9090    0.8614       725
           2     0.9462    0.9609    0.9535       768
           3     0.7720    0.5917    0.6699      1036
           4     0.9872    0.9698    0.9784       794
           5     0.9732    0.9852    0.9792       811

    accuracy                         0.8350      4769
   macro avg     0.8347    0.8419    0.8347      4769
weighted avg     0.8424    0.8350    0.8348      4769

########## TEST ###########
0.8318658280922432
0.8318658280922432
              precision    recall  f1-score   support

           0     0.4863    0.6528    0.5574       599
           1     0.8101    0.9095    0.8569       685
           2     0.9416    0.9450    0.9433       836
           3     0.7963    0.5782    0.6700      1048
           4     0.9733    0.9605    0.9669       836
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.834766198364437
0.834766198364437
              precision    recall  f1-score   support

           0     0.5158    0.6291    0.5669       647
           1     0.8224    0.9106    0.8642       727
           2     0.9462    0.9609    0.9535       768
           3     0.7607    0.5904    0.6648      1023
           4     0.9872    0.9722    0.9796       792
           5     0.9744    0.9852    0.9798       812

    accuracy                         0.8348      4769
   macro avg     0.8344    0.8414    0.8348      4769
weighted avg     0.8407    0.8348    0.8343      4769

########## TEST ###########
0.8341719077568134
0.8341719077568134
              precision    recall  f1-score   support

           0     0.4975    0.6579    0.5666       608
           1     0.8153    0.9087    0.8595       690
           2     0.9416    0.9472    0.9444       834
           3     0.7937    0.5813    0.6711      1039
           4     0.9721    0.9628    0.9674       833
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8358146361920739
0.8358146361920739
              precision    recall  f1-score   support

           0     0.5095    0.6351    0.5654       633
           1     0.8286    0.9075    0.8662       735
           2     0.9500    0.9524    0.9512       778
           3     0.7645    0.5934    0.6681      1023
           4     0.9846    0.9759    0.9802       787
           5     0.9756    0.9852    0.9804       813

    accuracy                         0.8358      4769
   macro avg     0.8355    0.8416    0.8353      4769
weighted avg     0.8431    0.8358    0.8359      4769

########## TEST ###########
0.8373165618448637
0.8373165618448637
              precision    recall  f1-score   support

           0     0.4950    0.6656    0.5678       598
           1     0.8244    0.9096    0.8649       697
           2     0.9440    0.9462    0.9451       837
           3     0.8016    0.5865    0.6774      1040
           4     0.9721    0.9639    0.9680       832
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8356049486265464
0.8356049486265464
              precision    recall  f1-score   support

           0     0.5158    0.6300    0.5672       646
           1     0.8311    0.9041    0.8660       740
           2     0.9487    0.9548    0.9518       775
           3     0.7557    0.5952    0.6659      1008
           4     0.9846    0.9746    0.9796       788
           5     0.9756    0.9865    0.9810       812

    accuracy                         0.8356      4769
   macro avg     0.8353    0.8409    0.8353      4769
weighted avg     0.8415    0.8356    0.8355      4769

########## TEST ###########
0.8371069182389937
0.8371069182389937
              precision    recall  f1-score   support

           0     0.4975    0.6601    0.5674       606
           1     0.8257    0.9007    0.8616       705
           2     0.9428    0.9496    0.9462       833
           3     0.7963    0.5912    0.6786      1025
           4     0.9733    0.9617    0.9675       835
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8328790102746907
0.8328790102746906
              precision    recall  f1-score   support

           0     0.5146    0.6161    0.5608       659
           1     0.8323    0.9005    0.8651       744
           2     0.9500    0.9549    0.9524       776
           3     0.7368    0.5903    0.6555       991
           4     0.9846    0.9771    0.9808       786
           5     0.9769    0.9865    0.9816       813

    accuracy                         0.8329      4769
   macro avg     0.8325    0.8376    0.8327      4769
weighted avg     0.8374    0.8329    0.8326      4769

########## TEST ###########
0.8373165618448637
0.8373165618448637
              precision    recall  f1-score   support

           0     0.5037    0.6553    0.5696       618
           1     0.8270    0.8996    0.8618       707
           2     0.9428    0.9473    0.9450       835
           3     0.7884    0.5946    0.6780      1009
           4     0.9733    0.9628    0.9681       834
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8339274481023276
0.8339274481023276
              precision    recall  f1-score   support

           0     0.5184    0.6197    0.5645       660
           1     0.8348    0.8960    0.8643       750
           2     0.9500    0.9549    0.9524       776
           3     0.7355    0.5923    0.6562       986
           4     0.9846    0.9808    0.9827       783
           5     0.9781    0.9865    0.9823       814

    accuracy                         0.8339      4769
   macro avg     0.8336    0.8384    0.8337      4769
weighted avg     0.8383    0.8339    0.8337      4769

########## TEST ###########
0.8366876310272536
0.8366876310272536
              precision    recall  f1-score   support

           0     0.5062    0.6543    0.5708       622
           1     0.8283    0.8909    0.8585       715
           2     0.9428    0.9473    0.9450       835
           3     0.7806    0.5958    0.6758       997
           4     0.9721    0.9639    0.9680       832
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8339274481023276
0.8339274481023276
              precision    recall  f1-score   support

           0     0.5158    0.6167    0.5618       660
           1     0.8335    0.8947    0.8630       750
           2     0.9500    0.9549    0.9524       776
           3     0.7380    0.5949    0.6588       985
           4     0.9846    0.9808    0.9827       783
           5     0.9793    0.9865    0.9829       815

    accuracy                         0.8339      4769
   macro avg     0.8336    0.8381    0.8336      4769
weighted avg     0.8385    0.8339    0.8338      4769

########## TEST ###########
0.8377358490566038
0.8377358490566038
              precision    recall  f1-score   support

           0     0.5075    0.6528    0.5710       625
           1     0.8336    0.8890    0.8604       721
           2     0.9452    0.9474    0.9463       837
           3     0.7792    0.5996    0.6777       989
           4     0.9709    0.9662    0.9686       829
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.832459635143636
0.832459635143636
              precision    recall  f1-score   support

           0     0.5171    0.6108    0.5601       668
           1     0.8335    0.8911    0.8614       753
           2     0.9526    0.9550    0.9538       778
           3     0.7280    0.5922    0.6531       976
           4     0.9833    0.9808    0.9821       782
           5     0.9781    0.9889    0.9835       812

    accuracy                         0.8325      4769
   macro avg     0.8321    0.8365    0.8323      4769
weighted avg     0.8362    0.8325    0.8322      4769

########## TEST ###########
0.8358490566037736
0.8358490566037736
              precision    recall  f1-score   support

           0     0.5100    0.6426    0.5687       638
           1     0.8336    0.8854    0.8587       724
           2     0.9452    0.9474    0.9463       837
           3     0.7648    0.5982    0.6713       973
           4     0.9709    0.9662    0.9686       829
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8316208848815265
0.8316208848815265
              precision    recall  f1-score   support

           0     0.5298    0.5989    0.5622       698
           1     0.8323    0.8921    0.8612       751
           2     0.9526    0.9562    0.9544       777
           3     0.7103    0.5943    0.6472       949
           4     0.9833    0.9808    0.9821       782
           5     0.9793    0.9901    0.9847       812

    accuracy                         0.8316      4769
   macro avg     0.8313    0.8354    0.8320      4769
weighted avg     0.8331    0.8316    0.8309      4769

########## TEST ###########
0.8354297693920335
0.8354297693920335
              precision    recall  f1-score   support

           0     0.5199    0.6372    0.5726       656
           1     0.8309    0.8863    0.8577       721
           2     0.9452    0.9474    0.9463       837
           3     0.7543    0.5992    0.6678       958
           4     0.9709    0.9662    0.9686       829
 

# CNN

In [17]:
from torch.nn.modules.conv import Conv1d
class Conv1dBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=2, padding=0, max_pool=2, dropout=0.1):
    super(Conv1dBlock, self).__init__()

    self.max_pool = max_pool
    self.l1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding = padding)
    self.l2 = nn.ReLU()
    self.l3 = nn.MaxPool1d(kernel_size=self.max_pool, stride=1, padding=0)
    self.l4 = nn.Dropout(dropout)
  
  def __call__(self, x):
    x = self.l1(x)
    x = self.l2(x)
    if self.max_pool > 0:
      x = self.l3(x)
    x = self.l4(x)
    return x

class ConvNet(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=2, max_pool=2):
    super(ConvNet, self).__init__()

    self.conv_dim = conv_dim
    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    self.embedding_dim = embedding_dim
    self.max_pool = max_pool
    self.kernel_size = kernel_size

    self.e = nn.Embedding(num_embeddings = VOCAB_SIZE, embedding_dim=self.embedding_dim).to(self.device)

    self.c1 = Conv1dBlock(in_channels=self.embedding_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c2 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c3 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)

    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    embeddings = self.e(x)
    C, L, I = embeddings.shape
    x = embeddings.reshape(C, I, L)
    x = self.c1(x).to(self.device)
    x = self.c2(x).to(self.device)
    x = self.c3(x).to(self.device)
    x = x.flatten(1)
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)
    # x = x[:,0,:]
    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()


        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 500 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [18]:

conv_dim = 32
hidden_dim = 128
dropout = 0.1
learning_rate = 5e-04
epochs = 15
outputs = len(labels2id.items())
embedding_dim = 128

convnet = ConvNet(hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=8, max_pool=2)
optimizer = torch.optim.Adam(convnet.parameters(), lr=learning_rate)

for i in range(epochs):
  convnet.train_model(train_loader, val_loader, test_loader, optimizer)



  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.6756133361291675
0.6756133361291675
              precision    recall  f1-score   support

           0     0.4558    0.4623    0.4591       770
           1     0.7052    0.6932    0.6991       828
           2     0.7545    0.7393    0.7468       794
           3     0.4494    0.4594    0.4544       764
           4     0.8269    0.9253    0.8734       723
           5     0.8499    0.7697    0.8078       890

    accuracy                         0.6756      4769
   macro avg     0.6736    0.6749    0.6734      4769
weighted avg     0.6776    0.6756    0.6758      4769

########## TEST ###########
0.6631027253668763
0.6631027253668763
              precision    recall  f1-score   support

           0     0.4130    0.4661    0.4379       723
           1     0.6676    0.6402    0.6536       781
           2     0.7805    0.6987    0.7373       896
           3     0.4724    0.4643    0.4683       812
           4     0.8170    0.9262    0.8681       718
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.728244915076536
0.7282449150765359
              precision    recall  f1-score   support

           0     0.7324    0.4370    0.5474      1309
           1     0.6978    0.8806    0.7786       645
           2     0.7931    0.9168    0.8504       673
           3     0.2650    0.4882    0.3436       424
           4     0.9184    0.9346    0.9264       795
           5     0.9504    0.8299    0.8861       923

    accuracy                         0.7282      4769
   macro avg     0.7262    0.7478    0.7221      4769
weighted avg     0.7679    0.7282    0.7320      4769

########## TEST ###########
0.7182389937106918
0.7182389937106918
              precision    recall  f1-score   support

           0     0.7243    0.4434    0.5500      1333
           1     0.6542    0.8305    0.7319       590
           2     0.8180    0.8877    0.8514       739
           3     0.2481    0.4681    0.3243       423
           4     0.9201    0.9328    0.9264       803
  

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7498427343258545
0.7498427343258545
              precision    recall  f1-score   support

           0     0.7375    0.4384    0.5499      1314
           1     0.7322    0.8701    0.7952       685
           2     0.9087    0.8570    0.8821       825
           3     0.2292    0.5219    0.3185       343
           4     0.9394    0.9596    0.9494       792
           5     0.9404    0.9358    0.9381       810

    accuracy                         0.7498      4769
   macro avg     0.7479    0.7638    0.7389      4769
weighted avg     0.7978    0.7498    0.7582      4769

########## TEST ###########
0.7379454926624738
0.7379454926624738
              precision    recall  f1-score   support

           0     0.7255    0.4441    0.5510      1333
           1     0.6943    0.8525    0.7653       610
           2     0.9202    0.8405    0.8786       878
           3     0.2193    0.4679    0.2986       374
           4     0.9312    0.9559    0.9434       793
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7624239882574964
0.7624239882574964
              precision    recall  f1-score   support

           0     0.8361    0.4394    0.5761      1486
           1     0.7420    0.9055    0.8157       667
           2     0.8766    0.9459    0.9099       721
           3     0.2036    0.5521    0.2975       288
           4     0.9567    0.9603    0.9585       806
           5     0.9479    0.9538    0.9508       801

    accuracy                         0.7624      4769
   macro avg     0.7605    0.7928    0.7514      4769
weighted avg     0.8300    0.7624    0.7708      4769

########## TEST ###########
0.7450733752620545
0.7450733752620545
              precision    recall  f1-score   support

           0     0.8064    0.4276    0.5588      1539
           1     0.7009    0.8663    0.7749       606
           2     0.8928    0.9323    0.9121       768
           3     0.1742    0.4982    0.2581       279
           4     0.9496    0.9543    0.9520       810
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7657789893059341
0.765778989305934
              precision    recall  f1-score   support

           0     0.7798    0.4545    0.5743      1340
           1     0.7666    0.8914    0.8243       700
           2     0.8946    0.9243    0.9092       753
           3     0.2087    0.5601    0.3041       291
           4     0.9629    0.9558    0.9594       815
           5     0.9690    0.8977    0.9320       870

    accuracy                         0.7658      4769
   macro avg     0.7636    0.7806    0.7505      4769
weighted avg     0.8269    0.7658    0.7784      4769

########## TEST ###########
0.7574423480083857
0.7574423480083857
              precision    recall  f1-score   support

           0     0.7696    0.4499    0.5678      1396
           1     0.7330    0.8659    0.7939       634
           2     0.9177    0.9031    0.9103       815
           3     0.1930    0.5404    0.2844       285
           4     0.9521    0.9580    0.9550       809
  

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7798280561962676
0.7798280561962676
              precision    recall  f1-score   support

           0     0.6338    0.4679    0.5383      1058
           1     0.7592    0.8970    0.8224       689
           2     0.9062    0.9325    0.9192       756
           3     0.4392    0.5368    0.4831       639
           4     0.9654    0.9642    0.9648       810
           5     0.9640    0.9510    0.9575       817

    accuracy                         0.7798      4769
   macro avg     0.7780    0.7916    0.7809      4769
weighted avg     0.7819    0.7798    0.7766      4769

########## TEST ###########
0.7744234800838574
0.7744234800838574
              precision    recall  f1-score   support

           0     0.6262    0.4780    0.5422      1069
           1     0.7370    0.8693    0.7977       635
           2     0.9239    0.9286    0.9263       798
           3     0.4348    0.5195    0.4734       668
           4     0.9558    0.9629    0.9593       808
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7966030614384567
0.7966030614384567
              precision    recall  f1-score   support

           0     0.5531    0.5199    0.5360       831
           1     0.7666    0.9109    0.8326       685
           2     0.9229    0.9229    0.9229       778
           3     0.5992    0.5467    0.5718       856
           4     0.9604    0.9712    0.9658       800
           5     0.9677    0.9524    0.9600       819

    accuracy                         0.7966      4769
   macro avg     0.7950    0.8040    0.7982      4769
weighted avg     0.7919    0.7966    0.7930      4769

########## TEST ###########
0.7832285115303983
0.7832285115303983
              precision    recall  f1-score   support

           0     0.5147    0.5303    0.5224       792
           1     0.7316    0.8712    0.7954       629
           2     0.9302    0.9199    0.9250       811
           3     0.6028    0.5117    0.5535       940
           4     0.9582    0.9714    0.9647       803
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8018452505766408
0.8018452505766407
              precision    recall  f1-score   support

           0     0.4289    0.5877    0.4959       570
           1     0.7826    0.9087    0.8409       701
           2     0.9139    0.9368    0.9252       759
           3     0.7426    0.5316    0.6197      1091
           4     0.9666    0.9525    0.9595       821
           5     0.9665    0.9420    0.9541       827

    accuracy                         0.8018      4769
   macro avg     0.8002    0.8099    0.7992      4769
weighted avg     0.8156    0.8018    0.8025      4769

########## TEST ###########
0.7861635220125787
0.7861635220125787
              precision    recall  f1-score   support

           0     0.3725    0.5769    0.4527       527
           1     0.7490    0.8711    0.8055       644
           2     0.9277    0.9288    0.9283       801
           3     0.7381    0.5009    0.5968      1176
           4     0.9644    0.9573    0.9608       820
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7986999370937303
0.7986999370937302
              precision    recall  f1-score   support

           0     0.6056    0.5192    0.5591       911
           1     0.7838    0.9024    0.8389       707
           2     0.9075    0.9502    0.9283       743
           3     0.5493    0.5571    0.5532       770
           4     0.9617    0.9749    0.9683       798
           5     0.9739    0.9345    0.9538       840

    accuracy                         0.7987      4769
   macro avg     0.7970    0.8064    0.8003      4769
weighted avg     0.7944    0.7987    0.7951      4769

########## TEST ###########
0.7895178197064989
0.7895178197064989
              precision    recall  f1-score   support

           0     0.5846    0.5330    0.5576       895
           1     0.7543    0.8561    0.8020       660
           2     0.9264    0.9453    0.9358       786
           3     0.5464    0.5266    0.5363       828
           4     0.9558    0.9786    0.9671       795
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8075068148458796
0.8075068148458796
              precision    recall  f1-score   support

           0     0.5403    0.5524    0.5463       764
           1     0.7887    0.8954    0.8387       717
           2     0.9242    0.9473    0.9356       759
           3     0.6440    0.5570    0.5974       903
           4     0.9691    0.9715    0.9703       807
           5     0.9690    0.9536    0.9612       819

    accuracy                         0.8075      4769
   macro avg     0.8059    0.8129    0.8082      4769
weighted avg     0.8046    0.8075    0.8049      4769

########## TEST ###########
0.7958071278825996
0.7958071278825996
              precision    recall  f1-score   support

           0     0.5025    0.5474    0.5240       749
           1     0.7757    0.8698    0.8200       668
           2     0.9314    0.9528    0.9420       784
           3     0.6353    0.5227    0.5735       970
           4     0.9631    0.9751    0.9691       804
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8098133780666806
0.8098133780666806
              precision    recall  f1-score   support

           0     0.4904    0.5929    0.5368       646
           1     0.7998    0.8979    0.8460       725
           2     0.9165    0.9236    0.9200       772
           3     0.6927    0.5647    0.6222       958
           4     0.9716    0.9621    0.9668       817
           5     0.9777    0.9260    0.9511       851

    accuracy                         0.8098      4769
   macro avg     0.8081    0.8112    0.8071      4769
weighted avg     0.8164    0.8098    0.8106      4769

########## TEST ###########
0.7960167714884696
0.7960167714884696
              precision    recall  f1-score   support

           0     0.4522    0.5885    0.5114       627
           1     0.7664    0.8723    0.8159       658
           2     0.9327    0.9280    0.9303       806
           3     0.6867    0.5249    0.5950      1044
           4     0.9644    0.9632    0.9638       815
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8125393164185364
0.8125393164185364
              precision    recall  f1-score   support

           0     0.4840    0.5870    0.5305       644
           1     0.8145    0.8876    0.8495       747
           2     0.9216    0.9472    0.9342       757
           3     0.7017    0.5615    0.6238       976
           4     0.9691    0.9549    0.9620       821
           5     0.9739    0.9527    0.9632       824

    accuracy                         0.8125      4769
   macro avg     0.8108    0.8151    0.8105      4769
weighted avg     0.8179    0.8125    0.8127      4769

########## TEST ###########
0.7976939203354297
0.7976939203354297
              precision    recall  f1-score   support

           0     0.4522    0.5642    0.5020       654
           1     0.7757    0.8724    0.8212       666
           2     0.9314    0.9468    0.9390       789
           3     0.6892    0.5304    0.5995      1037
           4     0.9705    0.9587    0.9646       824
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8098133780666806
0.8098133780666806
              precision    recall  f1-score   support

           0     0.5595    0.5483    0.5539       797
           1     0.7789    0.9175    0.8425       691
           2     0.9422    0.9349    0.9385       784
           3     0.6389    0.5664    0.6005       881
           4     0.9629    0.9774    0.9701       797
           5     0.9677    0.9524    0.9600       819

    accuracy                         0.8098      4769
   macro avg     0.8084    0.8162    0.8109      4769
weighted avg     0.8064    0.8098    0.8069      4769

########## TEST ###########
0.7987421383647799
0.7987421383647799
              precision    recall  f1-score   support

           0     0.5233    0.5553    0.5388       769
           1     0.7530    0.9038    0.8216       624
           2     0.9501    0.9338    0.9419       816
           3     0.6366    0.5275    0.5769       963
           4     0.9595    0.9726    0.9660       803
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8096036905011533
0.8096036905011533
              precision    recall  f1-score   support

           0     0.4904    0.5682    0.5265       674
           1     0.7789    0.9337    0.8493       679
           2     0.9319    0.9319    0.9319       778
           3     0.7157    0.5406    0.6160      1034
           4     0.9666    0.9775    0.9720       800
           5     0.9653    0.9677    0.9665       804

    accuracy                         0.8096      4769
   macro avg     0.8081    0.8199    0.8104      4769
weighted avg     0.8123    0.8096    0.8069      4769

########## TEST ###########
0.8
0.8000000000000002
              precision    recall  f1-score   support

           0     0.4718    0.5755    0.5185       669
           1     0.7503    0.9065    0.8210       620
           2     0.9377    0.9435    0.9406       797
           3     0.7155    0.5196    0.6020      1099
           4     0.9558    0.9786    0.9671       795
           5    

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8056196267561334
0.8056196267561334
              precision    recall  f1-score   support

           0     0.4942    0.5522    0.5216       699
           1     0.7838    0.9154    0.8445       697
           2     0.9344    0.9491    0.9417       766
           3     0.6620    0.5391    0.5943       959
           4     0.9740    0.9563    0.9651       824
           5     0.9752    0.9539    0.9644       824

    accuracy                         0.8056      4769
   macro avg     0.8039    0.8110    0.8053      4769
weighted avg     0.8070    0.8056    0.8040      4769

########## TEST ###########
0.8064989517819706
0.8064989517819706
              precision    recall  f1-score   support

           0     0.5233    0.5882    0.5538       726
           1     0.7583    0.9045    0.8250       628
           2     0.9352    0.9518    0.9434       788
           3     0.6754    0.5401    0.6002       998
           4     0.9730    0.9588    0.9659       826
 

# Transformer Models

In [19]:
data

Unnamed: 0,tweet_text,cyberbullying_type,label,cleaned_text
0,"In other words #katandandre, your food was cra...",not_cyberbullying,0,"[in, other, words, katandandre, your, food, wa..."
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying,0,"[why, is, aussietv, so, white, mkr, theblock, ..."
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying,0,"[a, classy, whore, or, more, red, velvet, cupc..."
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying,0,"[meh, p, thanks, for, the, heads, up, but, not..."
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying,0,"[this, is, an, isis, account, pretending, to, ..."
...,...,...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity,5,"[black, ppl, are, not, expected, to, do, anyth..."
47688,Turner did not withhold his disappointment. Tu...,ethnicity,5,"[turner, did, not, withhold, his, disappointme..."
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity,5,"[i, swear, to, god, this, dumb, nigger, bitch,..."
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity,5,"[yea, fuck, you, rt, if, youre, a, nigger, fuc..."


In [20]:
BATCH_SIZE = 50
del convnet
del mlp

In [21]:
model_type = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [22]:
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)

In [23]:
tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)

In [24]:
train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))

print(len(train), len(val), len(test))

38153 4769 4770


In [25]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [26]:
class transformer_classifier(nn.Module):
  def __init__(self, model_type, hidden_dim, dropout, device ):
      super(transformer_classifier, self).__init__()

      self.device = device
      self.hidden_dim = hidden_dim
      self.LABELS_NUM = len(labels2id.items())

      self.lm = AutoModel.from_pretrained(model_type).to(self.device)

      config = config = AutoConfig.from_pretrained(model_type, num_labels = self.LABELS_NUM)
      self.lm_sequence_classification = AutoModelForSequenceClassification.from_pretrained(model_type, config=config).to(self.device)

      self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
      self.a1 = nn.GELU().to(self.device)
      self.d1 = nn.Dropout(dropout).to(self.device)
      self.l2 = nn.LazyLinear(self.LABELS_NUM).to(self.device)

      self.loss = nn.CrossEntropyLoss()


  def forward(self, ids, masks, labels, mode):
    if mode:
      output = self.lm_sequence_classification(ids, masks, labels=labels)
      return output.logits.to(self.device), output.loss.to(self.device)
    else:
      x = self.lm(ids, masks).last_hidden_state[:, 0].to(self.device)
      x = self.l1(x).to(self.device)
      x = self.a1(x).to(self.device)
      x = self.d1(x).to(self.device)
      x = self.l2(x).to(self.device)

      loss = self.loss(x, labels).to(self.device)

      return x, loss

  def get_predictions(self, loader, mode):
    with torch.no_grad():
      with tqdm.notebook.tqdm(
      loader,
      unit="batch",
      total=len(loader)) as batch_iterator:

        preds_all = []
        labels_all = []
        accuracies = []

        for iteration, data in enumerate(batch_iterator, start=1):
          ids, masks, labels = data
          ids = ids.to(self.device)
          masks = masks.to(self.device)
          labels = labels.to(self.device)

          self.eval()

          logits, loss = self.forward(ids, masks, labels, mode)

          _, predictions = torch.max(logits, axis=1)

          preds_all += list(predictions.cpu())
          labels_all += list(labels.cpu())


          accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
          accuracies.append(accuracy)

          batch_iterator.set_postfix(accuracy = accuracy, mean_accuracy = np.mean(accuracies))
    return preds_all, labels_all

  def train_model(self, train_loader, val_loader, test_loader, mode, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        ids, masks, labels = data
        ids = ids.to(self.device)
        masks = masks.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(ids, masks, labels, mode)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 400 == 0:
          pred, label = self.get_predictions(val_loader, mode)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader, mode)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          



In [27]:
hidden_dim = 128
dropout = 0.2
learning_rate = 1e-04
epochs = 5

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerN

  0%|          | 0/764 [00:00<?, ?batch/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 150.00 MiB (GPU 0; 8.00 GiB total capacity; 7.05 GiB already allocated; 0 bytes free; 7.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [28]:
model_type

'bert-base-uncased'

In [29]:
model_type = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)
tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)
train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [30]:
classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 8.00 GiB total capacity; 7.18 GiB already allocated; 0 bytes free; 7.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF