In [1]:
import torch
import torch.nn as nn
import pandas
import re, string
import emoji
from collections import Counter, defaultdict
import torchtext

from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, SequentialSampler, random_split)
import tqdm
import numpy as np

from torch import cuda
RANDOM_STATE = 42
device = 'cuda' if cuda.is_available() else 'cpu'
device

2022-11-28 18:23:45.973810: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


'cuda'

In [2]:
data = pandas.read_csv('cyberbullying_tweets.csv')

In [3]:
data

Unnamed: 0,tweet_text,cyberbullying_type
0,"In other words #katandandre, your food was cra...",not_cyberbullying
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying
...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity
47688,Turner did not withhold his disappointment. Tu...,ethnicity
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity


In [4]:
data.loc[data['cyberbullying_type'] == 'other_cyberbullying']

Unnamed: 0,tweet_text,cyberbullying_type
23916,"@ikralla fyi, it looks like I was caught by it...",other_cyberbullying
23917,I need to just switch to an organization-based...,other_cyberbullying
23918,RMAed my monoprice. Shoddy power bricks on tho...,other_cyberbullying
23919,@murphy_slaw https://t.co/M8w8xnUnDL,other_cyberbullying
23920,@1Life0Continues i've got the code to interpre...,other_cyberbullying
...,...,...
31734,"@kufr666 @blockbot no, that's @oolon",other_cyberbullying
31735,@AriMelber why are you giving these idiots air...,other_cyberbullying
31736,I am right now watching Enforcers defend Chums...,other_cyberbullying
31737,✨✨✨ misandry is not a word iOS can autocomplet...,other_cyberbullying


In [5]:
labels = list(data.cyberbullying_type.unique())
labels2id = {label: i for i, label in enumerate(labels)}
print(labels2id)

{'not_cyberbullying': 0, 'gender': 1, 'religion': 2, 'other_cyberbullying': 3, 'age': 4, 'ethnicity': 5}


In [6]:
# map labels to ints
data['label'] = data.cyberbullying_type.map(labels2id)

In [7]:
data

Unnamed: 0,tweet_text,cyberbullying_type,label
0,"In other words #katandandre, your food was cra...",not_cyberbullying,0
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying,0
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying,0
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying,0
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying,0
...,...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity,5
47688,Turner did not withhold his disappointment. Tu...,ethnicity,5
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity,5
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity,5


In [8]:
#Clean emojis from text
def strip_emoji(text):
    return emoji.replace_emoji(text)
    # return re.sub(emoji.get_emoji_regexp(), r"", text) #remove emoji
    
def strip_all_entities(text): 
    text = text.replace('\r', '').replace('\n', ' ').lower() #remove \n and \r and lowercase
    text = re.sub(r"(?:\@|https?\://)\S+", "", text) #remove links and mentions
    text = re.sub(r'[^\x00-\x7f]',r'', text) #remove non utf8/ascii characters such as '\x9a\x91\x97\x9a\x97'
    banned_list= string.punctuation
    table = str.maketrans('', '', banned_list)
    text = text.translate(table)
    text = [word for word in text.split()]
    text = ' '.join(text)
    return text

#remove contractions
def decontract(text):
    text = re.sub(r"can\'t", "can not", text)
    text = re.sub(r"won\'t", "will not", text)
    text = re.sub(r"n\'t", " not", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'s", " is", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'t", " not", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'m", " am", text)
    return text

def clean_hashtags(tweet):
    # new_tweet = " ".join(word.strip() for word in re.split('#(?!(?:hashtag)\b)[\w-]+(?=(?:\s+#[\w-]+)*\s*$)', tweet)) #remove last hashtags
    new_tweet2 = " ".join(word.strip() for word in re.split('#|_', tweet)) #remove hashtags symbol from words in the middle of the sentence
    return new_tweet2

#Filter special characters such as "&" and "$" present in some words
def filter_chars(a):
    sent = []
    for word in a.split(' '):
        if ('$' in word) | ('&' in word):
            sent.append('')
        else:
            sent.append(word)
    return ' '.join(sent)

#Remove multiple sequential spaces
def remove_mult_spaces(text):
    return re.sub("\s\s+" , " ", text)


#Then we apply all the defined functions in the following order
def deep_clean(text):
    text = strip_emoji(text)
    text = decontract(text)
    text = strip_all_entities(text)
    text = clean_hashtags(text)
    text = filter_chars(text)
    text = remove_mult_spaces(text)
    return text

In [9]:
data.to_csv('cyberbullying_tweets_cleaned.csv', index=False)

In [10]:
PADDING_VALUE = 256
data['cleaned_text'] =  [deep_clean(x).split()[:256] for x in data.tweet_text]

In [21]:
MIN_FREQ = 5
word_counts = Counter([item for sublist in data['cleaned_text'] for item in sublist])
vocab = torchtext.vocab.vocab(word_counts,
                              min_freq = MIN_FREQ,
                              specials = ['<pad>', '<unk>', ])
vocab.set_default_index(1)

VOCAB_SIZE = len(vocab)

In [22]:
try: 
    with open('trainvaltest_idx.npy', 'rb') as f:
        train_idx = np.load(f)
        val_idx = np.load(f)
        test_idx = np.load(f)
except:
    # stratified 80/10/10 train val test split
    TRAIN_SIZE=0.8
    indicies = np.arange(0,len(labels),1)

    train_idx, valtest_idx = train_test_split(indicies, train_size=TRAIN_SIZE, random_state=RANDOM_STATE, shuffle=True, stratify=labels)
    val_idx, test_idx = train_test_split(valtest_idx, train_size=0.5, random_state=RANDOM_STATE, shuffle=True, stratify=labels[valtest_idx])
    with open('trainvaltest_idx.npy', 'wb') as f:
        np.save(f, train_idx)
        np.save(f, val_idx)
        np.save(f, test_idx)

# y train-val-test
labels = torch.tensor(data.label, dtype = torch.int64)
ytrain = labels[train_idx]
yval = labels[val_idx]
ytest = labels[test_idx]

tokenized_data = [torch.tensor(vocab(x), dtype = torch.int64) for x in data.cleaned_text]

# x train-val-test for cnn
# pad tokenized_data with 0s to same length as the longest sequence
tokenized_data = nn.utils.rnn.pad_sequence(tokenized_data, batch_first=True)

xtrain = tokenized_data[train_idx]
xval = tokenized_data[val_idx]
xtest = tokenized_data[test_idx]
train = torch.utils.data.TensorDataset(xtrain, ytrain)
val = torch.utils.data.TensorDataset(xval, yval)
test = torch.utils.data.TensorDataset(xtest, ytest)
print(len(train), len(val), len(test))

# x train-val-test for bag of words (frequency of token)
bow_tokenized_data = torch.zeros(tokenized_data.shape[0], VOCAB_SIZE)
for i in range(tokenized_data.shape[0]):
  tweet = tokenized_data[i,tokenized_data[i].nonzero().flatten()]
  bow_tokenized_data[i].put_(tweet, torch.ones_like(tweet).float(), accumulate=True)

bow_tokenized_data = nn.utils.rnn.pad_sequence(bow_tokenized_data, batch_first=True)
bow_xtrain = bow_tokenized_data[train_idx]
bow_xval = bow_tokenized_data[val_idx]
bow_xtest = bow_tokenized_data[test_idx]
bow_train = torch.utils.data.TensorDataset(bow_xtrain, ytrain)
bow_val = torch.utils.data.TensorDataset(bow_xval, yval)
bow_test = torch.utils.data.TensorDataset(bow_xtest, ytest)
print(len(bow_train), len(bow_val), len(bow_test))

38153 4769 4770
38153 4769 4770


In [23]:
BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train, BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val, BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test, BATCH_SIZE)

bow_train_loader = torch.utils.data.DataLoader(bow_train, BATCH_SIZE)
bow_val_loader = torch.utils.data.DataLoader(bow_val, BATCH_SIZE)
bow_test_loader = torch.utils.data.DataLoader(bow_test, BATCH_SIZE)

In [24]:
# augment data


# BOW

In [25]:
class MLP(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs):
    super(MLP, self).__init__()

    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    
    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)

    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 500 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [26]:
hidden_dim = 128
dropout = 0.
learning_rate = 1e-04
epochs = 15
outputs = len(labels2id.items())

mlp = MLP(hidden_dim, dropout, device, outputs)
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

for i in range(epochs):
  mlp.train_model(bow_train_loader, bow_val_loader, bow_test_loader, optimizer)



  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7322289788215559
0.732228978821556
              precision    recall  f1-score   support

           0     0.4295    0.5702    0.4899       598
           1     0.8130    0.7440    0.7770       871
           2     0.9663    0.7118    0.8197      1086
           3     0.2465    0.6108    0.3512       316
           4     0.9712    0.7512    0.8472      1033
           5     0.9560    0.8798    0.9163       865

    accuracy                         0.7322      4769
   macro avg     0.7304    0.7113    0.7002      4769
weighted avg     0.8225    0.7322    0.7630      4769

########## TEST ###########
0.7247379454926625
0.7247379454926625
              precision    recall  f1-score   support

           0     0.3975    0.5593    0.4647       565
           1     0.8095    0.7267    0.7659       889
           2     0.9738    0.7075    0.8196      1101
           3     0.2289    0.5424    0.3219       330
           4     0.9762    0.7640    0.8571      1021
  

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7976514992660935
0.7976514992660935
              precision    recall  f1-score   support

           0     0.5277    0.5594    0.5431       749
           1     0.7842    0.8840    0.8311       707
           2     0.9463    0.8990    0.9220       842
           3     0.5747    0.5711    0.5729       788
           4     0.9787    0.8947    0.9348       874
           5     0.9686    0.9530    0.9607       809

    accuracy                         0.7977      4769
   macro avg     0.7967    0.7936    0.7941      4769
weighted avg     0.8048    0.7977    0.8003      4769

########## TEST ###########
0.8010482180293501
0.80104821802935
              precision    recall  f1-score   support

           0     0.5082    0.5504    0.5284       734
           1     0.7832    0.8865    0.8317       705
           2     0.9550    0.9041    0.9289       845
           3     0.6010    0.5642    0.5820       833
           4     0.9800    0.9288    0.9537       843
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8072971272803523
0.8072971272803523
              precision    recall  f1-score   support

           0     0.5113    0.5647    0.5367       719
           1     0.7792    0.9146    0.8415       679
           2     0.9450    0.9322    0.9385       811
           3     0.6577    0.5616    0.6059       917
           4     0.9775    0.9221    0.9490       847
           5     0.9686    0.9686    0.9686       796

    accuracy                         0.8073      4769
   macro avg     0.8065    0.8106    0.8067      4769
weighted avg     0.8105    0.8073    0.8070      4769

########## TEST ###########
0.8127882599580712
0.8127882599580712
              precision    recall  f1-score   support

           0     0.5208    0.5687    0.5437       728
           1     0.7769    0.9078    0.8373       683
           2     0.9487    0.9440    0.9464       804
           3     0.6675    0.5631    0.6109       927
           4     0.9837    0.9516    0.9674       826
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8150555672048647
0.8150555672048647
              precision    recall  f1-score   support

           0     0.5063    0.5826    0.5418       690
           1     0.7992    0.9100    0.8510       700
           2     0.9437    0.9449    0.9443       799
           3     0.6935    0.5668    0.6238       958
           4     0.9762    0.9386    0.9571       831
           5     0.9673    0.9735    0.9704       791

    accuracy                         0.8151      4769
   macro avg     0.8144    0.8194    0.8147      4769
weighted avg     0.8185    0.8151    0.8145      4769

########## TEST ###########
0.8176100628930818
0.8176100628930818
              precision    recall  f1-score   support

           0     0.5145    0.5860    0.5479       698
           1     0.7882    0.9037    0.8420       696
           2     0.9463    0.9498    0.9480       797
           3     0.6944    0.5656    0.6234       960
           4     0.9825    0.9573    0.9697       820
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8184105682533026
0.8184105682533026
              precision    recall  f1-score   support

           0     0.4798    0.6038    0.5347       631
           1     0.8043    0.9092    0.8535       705
           2     0.9450    0.9450    0.9450       800
           3     0.7344    0.5637    0.6378      1020
           4     0.9762    0.9466    0.9612       824
           5     0.9673    0.9759    0.9716       789

    accuracy                         0.8184      4769
   macro avg     0.8178    0.8240    0.8173      4769
weighted avg     0.8267    0.8184    0.8187      4769

########## TEST ###########
0.8276729559748428
0.8276729559748428
              precision    recall  f1-score   support

           0     0.5094    0.6212    0.5598       652
           1     0.7970    0.9086    0.8491       700
           2     0.9450    0.9558    0.9503       791
           3     0.7519    0.5776    0.6533      1018
           4     0.9812    0.9620    0.9715       815
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8253302579157056
0.8253302579157056
              precision    recall  f1-score   support

           0     0.5000    0.6203    0.5537       640
           1     0.8105    0.9060    0.8556       713
           2     0.9450    0.9533    0.9492       793
           3     0.7510    0.5742    0.6508      1024
           4     0.9750    0.9570    0.9659       814
           5     0.9673    0.9809    0.9741       785

    accuracy                         0.8253      4769
   macro avg     0.8248    0.8320    0.8249      4769
weighted avg     0.8323    0.8253    0.8250      4769

########## TEST ###########
0.8324947589098532
0.8324947589098532
              precision    recall  f1-score   support

           0     0.5182    0.6280    0.5679       656
           1     0.8008    0.9116    0.8526       701
           2     0.9450    0.9618    0.9533       786
           3     0.7673    0.5888    0.6663      1019
           4     0.9837    0.9632    0.9734       816
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8305724470538897
0.8305724470538898
              precision    recall  f1-score   support

           0     0.5126    0.6252    0.5633       651
           1     0.8168    0.9182    0.8645       709
           2     0.9500    0.9560    0.9530       795
           3     0.7573    0.5802    0.6571      1022
           4     0.9750    0.9653    0.9701       807
           5     0.9686    0.9822    0.9753       785

    accuracy                         0.8306      4769
   macro avg     0.8301    0.8378    0.8306      4769
weighted avg     0.8365    0.8306    0.8298      4769

########## TEST ###########
0.8373165618448637
0.8373165618448637
              precision    recall  f1-score   support

           0     0.5258    0.6411    0.5777       652
           1     0.8120    0.9114    0.8588       711
           2     0.9463    0.9668    0.9564       783
           3     0.7813    0.5949    0.6755      1027
           4     0.9800    0.9691    0.9745       808
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8328790102746907
0.8328790102746906
              precision    recall  f1-score   support

           0     0.5189    0.6280    0.5683       656
           1     0.8243    0.9189    0.8690       715
           2     0.9500    0.9548    0.9524       796
           3     0.7535    0.5859    0.6592      1007
           4     0.9775    0.9654    0.9714       809
           5     0.9698    0.9822    0.9760       786

    accuracy                         0.8329      4769
   macro avg     0.8323    0.8392    0.8327      4769
weighted avg     0.8383    0.8329    0.8323      4769

########## TEST ###########
0.839832285115304
0.839832285115304
              precision    recall  f1-score   support

           0     0.5384    0.6417    0.5855       667
           1     0.8183    0.9146    0.8638       714
           2     0.9450    0.9668    0.9558       782
           3     0.7775    0.6020    0.6786      1010
           4     0.9825    0.9715    0.9770       808
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8332983854057454
0.8332983854057454
              precision    recall  f1-score   support

           0     0.5202    0.6239    0.5673       662
           1     0.8269    0.9217    0.8717       715
           2     0.9513    0.9572    0.9542       795
           3     0.7522    0.5849    0.6581      1007
           4     0.9762    0.9689    0.9726       805
           5     0.9698    0.9834    0.9766       785

    accuracy                         0.8333      4769
   macro avg     0.8328    0.8400    0.8334      4769
weighted avg     0.8380    0.8333    0.8324      4769

########## TEST ###########
0.8368972746331237
0.8368972746331237
              precision    recall  f1-score   support

           0     0.5396    0.6290    0.5809       682
           1     0.8170    0.9093    0.8607       717
           2     0.9425    0.9667    0.9544       780
           3     0.7634    0.5994    0.6715       996
           4     0.9825    0.9727    0.9776       807
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8351855734954917
0.8351855734954916
              precision    recall  f1-score   support

           0     0.5189    0.6329    0.5702       651
           1     0.8306    0.9194    0.8728       720
           2     0.9500    0.9560    0.9530       795
           3     0.7637    0.5886    0.6648      1016
           4     0.9750    0.9713    0.9731       802
           5     0.9698    0.9834    0.9766       785

    accuracy                         0.8352      4769
   macro avg     0.8347    0.8419    0.8351      4769
weighted avg     0.8409    0.8352    0.8345      4769

########## TEST ###########
0.8402515723270441
0.8402515723270441
              precision    recall  f1-score   support

           0     0.5472    0.6416    0.5906       678
           1     0.8195    0.9046    0.8600       723
           2     0.9437    0.9679    0.9557       780
           3     0.7737    0.6056    0.6794       999
           4     0.9800    0.9763    0.9781       802
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.834766198364437
0.834766198364437
              precision    recall  f1-score   support

           0     0.5252    0.6243    0.5705       668
           1     0.8256    0.9190    0.8698       716
           2     0.9525    0.9573    0.9549       796
           3     0.7573    0.5906    0.6637      1004
           4     0.9750    0.9725    0.9738       801
           5     0.9698    0.9847    0.9772       784

    accuracy                         0.8348      4769
   macro avg     0.8342    0.8414    0.8350      4769
weighted avg     0.8391    0.8348    0.8338      4769

########## TEST ###########
0.8379454926624738
0.8379454926624736
              precision    recall  f1-score   support

           0     0.5522    0.6280    0.5877       699
           1     0.8158    0.9042    0.8577       720
           2     0.9413    0.9666    0.9538       779
           3     0.7583    0.6045    0.6727       981
           4     0.9812    0.9776    0.9794       802
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.834766198364437
0.834766198364437
              precision    recall  f1-score   support

           0     0.5290    0.6213    0.5714       676
           1     0.8319    0.9145    0.8712       725
           2     0.9513    0.9584    0.9548       794
           3     0.7484    0.5925    0.6614       989
           4     0.9737    0.9713    0.9725       801
           5     0.9711    0.9860    0.9785       784

    accuracy                         0.8348      4769
   macro avg     0.8342    0.8407    0.8350      4769
weighted avg     0.8382    0.8348    0.8338      4769

########## TEST ###########
0.8385744234800838
0.8385744234800839
              precision    recall  f1-score   support

           0     0.5560    0.6287    0.5901       703
           1     0.8170    0.9043    0.8585       721
           2     0.9425    0.9679    0.9550       779
           3     0.7558    0.6055    0.6724       976
           4     0.9812    0.9776    0.9794       802
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8362340113231286
0.8362340113231286
              precision    recall  f1-score   support

           0     0.5428    0.6219    0.5797       693
           1     0.8294    0.9193    0.8720       719
           2     0.9513    0.9572    0.9542       795
           3     0.7420    0.5965    0.6614       974
           4     0.9750    0.9713    0.9731       802
           5     0.9736    0.9860    0.9798       786

    accuracy                         0.8362      4769
   macro avg     0.8357    0.8421    0.8367      4769
weighted avg     0.8385    0.8362    0.8350      4769

########## TEST ###########
0.8371069182389937
0.8371069182389937
              precision    recall  f1-score   support

           0     0.5610    0.6203    0.5892       719
           1     0.8170    0.9030    0.8579       722
           2     0.9437    0.9679    0.9557       780
           3     0.7404    0.6044    0.6655       958
           4     0.9812    0.9776    0.9794       802
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8364436988886559
0.8364436988886559
              precision    recall  f1-score   support

           0     0.5428    0.6166    0.5774       699
           1     0.8331    0.9184    0.8737       723
           2     0.9500    0.9572    0.9536       794
           3     0.7420    0.6002    0.6636       968
           4     0.9725    0.9749    0.9737       797
           5     0.9749    0.9848    0.9798       788

    accuracy                         0.8364      4769
   macro avg     0.8359    0.8420    0.8370      4769
weighted avg     0.8382    0.8364    0.8352      4769

########## TEST ###########
0.8337526205450734
0.8337526205450734
              precision    recall  f1-score   support

           0     0.5623    0.6041    0.5824       740
           1     0.8183    0.9007    0.8575       725
           2     0.9425    0.9679    0.9550       779
           3     0.7187    0.6017    0.6550       934
           4     0.9812    0.9776    0.9794       802
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8364436988886559
0.8364436988886559
              precision    recall  f1-score   support

           0     0.5441    0.6207    0.5799       696
           1     0.8319    0.9157    0.8718       724
           2     0.9500    0.9584    0.9542       793
           3     0.7420    0.5984    0.6625       971
           4     0.9725    0.9737    0.9731       798
           5     0.9749    0.9860    0.9804       787

    accuracy                         0.8364      4769
   macro avg     0.8359    0.8421    0.8370      4769
weighted avg     0.8383    0.8364    0.8351      4769

########## TEST ###########
0.8339622641509434
0.8339622641509434
              precision    recall  f1-score   support

           0     0.5509    0.6075    0.5778       721
           1     0.8233    0.8975    0.8588       732
           2     0.9437    0.9679    0.9557       780
           3     0.7251    0.5994    0.6563       946
           4     0.9812    0.9788    0.9800       801
 

# CNN

In [27]:
from torch.nn.modules.conv import Conv1d
class Conv1dBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=2, padding=0, max_pool=2, dropout=0.1):
    super(Conv1dBlock, self).__init__()

    self.max_pool = max_pool
    self.l1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding = padding)
    self.l2 = nn.ReLU()
    self.l3 = nn.MaxPool1d(kernel_size=self.max_pool, stride=1, padding=0)
    self.l4 = nn.Dropout(dropout)
  
  def __call__(self, x):
    x = self.l1(x)
    x = self.l2(x)
    if self.max_pool > 0:
      x = self.l3(x)
    x = self.l4(x)
    return x

class ConvNet(nn.Module):
  def __init__(self, hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=2, max_pool=2):
    super(ConvNet, self).__init__()

    self.conv_dim = conv_dim
    self.hidden_dim = hidden_dim
    self.dropout = dropout
    self.device = device
    self.outputs = outputs
    self.embedding_dim = embedding_dim
    self.max_pool = max_pool
    self.kernel_size = kernel_size

    self.e = nn.Embedding(num_embeddings = VOCAB_SIZE, embedding_dim=self.embedding_dim).to(self.device)

    self.c1 = Conv1dBlock(in_channels=self.embedding_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c2 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)
    self.c3 = Conv1dBlock(in_channels=self.conv_dim, out_channels=self.conv_dim, kernel_size=self.kernel_size, padding=0, max_pool=self.max_pool, dropout=0.1).to(self.device)

    self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
    self.a1 = nn.ReLU().to(self.device)
    self.d1 = nn.Dropout().to(self.device)
    self.l2 = nn.LazyLinear(self.outputs).to(self.device)

    self.loss = nn.CrossEntropyLoss()
  
  def forward(self, x, y):
    embeddings = self.e(x)
    C, L, I = embeddings.shape
    x = embeddings.reshape(C, I, L)
    x = self.c1(x).to(self.device)
    x = self.c2(x).to(self.device)
    x = self.c3(x).to(self.device)
    x = x.flatten(1)
    x = self.l1(x).to(self.device)
    x = self.a1(x).to(self.device)
    x = self.d1(x).to(self.device)
    x = self.l2(x).to(self.device)
    # x = x[:,0,:]
    loss = self.loss(x, y).to(device)

    return x, loss

  def get_predictions(self, loader):
    with torch.no_grad():
      preds_all = []
      labels_all = []
      accuracies = []

      for tweets, labels in loader:
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        self.eval()

        logits, loss = self.forward(tweets, labels)

        _, predictions = torch.max(logits, axis=1)

        preds_all += list(predictions.cpu())
        labels_all += list(labels.cpu())


        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

    return preds_all, labels_all
  
  def train_model(self, train_loader, val_loader, test_loader, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        tweets, labels = data
        tweets = tweets.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()


        logits, loss = self.forward(tweets, labels)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 500 == 0:
          pred, label = self.get_predictions(val_loader)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)
  


In [28]:

conv_dim = 32
hidden_dim = 128
dropout = 0.1
learning_rate = 5e-04
epochs = 15
outputs = len(labels2id.items())
embedding_dim = 128

convnet = ConvNet(hidden_dim, dropout, device, outputs, embedding_dim, kernel_size=8, max_pool=2)
optimizer = torch.optim.Adam(convnet.parameters(), lr=learning_rate)

for i in range(epochs):
  convnet.train_model(train_loader, val_loader, test_loader, optimizer)



  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.6569511427972321
0.6569511427972321
              precision    recall  f1-score   support

           0     0.3690    0.4666    0.4121       628
           1     0.6562    0.6138    0.6343       852
           2     0.7362    0.7037    0.7196       837
           3     0.5696    0.4231    0.4856      1054
           4     0.8623    0.9174    0.8890       751
           5     0.7450    0.9165    0.8219       647

    accuracy                         0.6570      4769
   macro avg     0.6564    0.6735    0.6604      4769
weighted avg     0.6578    0.6570    0.6527      4769

########## TEST ###########
0.6654088050314465
0.6654088050314465
              precision    recall  f1-score   support

           0     0.3899    0.4791    0.4300       647
           1     0.6767    0.6345    0.6549       851
           2     0.7087    0.7177    0.7132       790
           3     0.6010    0.4397    0.5078      1069
           4     0.8561    0.9231    0.8883       741
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7406164814426505
0.7406164814426505
              precision    recall  f1-score   support

           0     0.3577    0.5081    0.4198       559
           1     0.7089    0.8060    0.7543       701
           2     0.8875    0.8373    0.8617       848
           3     0.6335    0.4653    0.5365      1066
           4     0.9562    0.9161    0.9357       834
           5     0.8957    0.9369    0.9159       761

    accuracy                         0.7406      4769
   macro avg     0.7399    0.7449    0.7373      4769
weighted avg     0.7557    0.7406    0.7430      4769

########## TEST ###########
0.7433962264150943
0.7433962264150943
              precision    recall  f1-score   support

           0     0.3623    0.5217    0.4276       552
           1     0.7143    0.8062    0.7575       707
           2     0.8725    0.8430    0.8575       828
           3     0.6560    0.4772    0.5525      1075
           4     0.9687    0.9171    0.9422       844
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7676661773956804
0.7676661773956804
              precision    recall  f1-score   support

           0     0.4118    0.5158    0.4580       634
           1     0.7591    0.8067    0.7822       750
           2     0.8800    0.9412    0.9096       748
           3     0.6564    0.4957    0.5648      1037
           4     0.9612    0.9423    0.9517       815
           5     0.9334    0.9465    0.9399       785

    accuracy                         0.7677      4769
   macro avg     0.7670    0.7747    0.7677      4769
weighted avg     0.7728    0.7677    0.7667      4769

########## TEST ###########
0.7643605870020964
0.7643605870020964
              precision    recall  f1-score   support

           0     0.4151    0.5156    0.4599       640
           1     0.7531    0.8111    0.7810       741
           2     0.8538    0.9343    0.8922       731
           3     0.6637    0.4981    0.5691      1042
           4     0.9675    0.9392    0.9531       823
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7729083665338645
0.7729083665338645
              precision    recall  f1-score   support

           0     0.5894    0.4835    0.5312       968
           1     0.7704    0.8309    0.7995       739
           2     0.9100    0.9357    0.9227       778
           3     0.4521    0.5145    0.4813       688
           4     0.9537    0.9646    0.9591       790
           5     0.9548    0.9429    0.9488       806

    accuracy                         0.7729      4769
   macro avg     0.7717    0.7787    0.7738      4769
weighted avg     0.7720    0.7729    0.7709      4769

########## TEST ###########
0.7735849056603774
0.7735849056603775
              precision    recall  f1-score   support

           0     0.5799    0.4873    0.5296       946
           1     0.7594    0.8256    0.7911       734
           2     0.8888    0.9355    0.9115       760
           3     0.4910    0.5393    0.5141       712
           4     0.9549    0.9634    0.9591       792
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7880058712518347
0.7880058712518347
              precision    recall  f1-score   support

           0     0.3602    0.5641    0.4397       507
           1     0.7440    0.9109    0.8191       651
           2     0.9200    0.9281    0.9240       793
           3     0.7816    0.5008    0.6105      1222
           4     0.9612    0.9660    0.9636       795
           5     0.9585    0.9526    0.9555       801

    accuracy                         0.7880      4769
   macro avg     0.7876    0.8038    0.7854      4769
weighted avg     0.8143    0.7880    0.7898      4769

########## TEST ###########
0.789308176100629
0.789308176100629
              precision    recall  f1-score   support

           0     0.3761    0.5663    0.4520       528
           1     0.7293    0.9151    0.8117       636
           2     0.9150    0.9161    0.9156       799
           3     0.7992    0.5161    0.6272      1211
           4     0.9574    0.9684    0.9629       790
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.7957643111763473
0.7957643111763473
              precision    recall  f1-score   support

           0     0.4723    0.5564    0.5109       674
           1     0.7742    0.9020    0.8332       684
           2     0.9363    0.9224    0.9293       812
           3     0.6897    0.5148    0.5895      1049
           4     0.9499    0.9781    0.9638       776
           5     0.9485    0.9755    0.9618       774

    accuracy                         0.7958      4769
   macro avg     0.7951    0.8082    0.7981      4769
weighted avg     0.7974    0.7958    0.7925      4769

########## TEST ###########
0.7962264150943397
0.7962264150943396
              precision    recall  f1-score   support

           0     0.4528    0.5455    0.4948       660
           1     0.7744    0.8944    0.8301       691
           2     0.9313    0.9141    0.9226       815
           3     0.7238    0.5360    0.6159      1056
           4     0.9499    0.9819    0.9656       773
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8001677500524219
0.8001677500524219
              precision    recall  f1-score   support

           0     0.4584    0.5871    0.5149       620
           1     0.7729    0.8902    0.8274       692
           2     0.9350    0.9315    0.9333       803
           3     0.7011    0.5229    0.5990      1050
           4     0.9625    0.9747    0.9685       789
           5     0.9673    0.9448    0.9559       815

    accuracy                         0.8002      4769
   macro avg     0.7995    0.8085    0.7998      4769
weighted avg     0.8081    0.8002    0.7996      4769

########## TEST ###########
0.8046121593291404
0.8046121593291404
              precision    recall  f1-score   support

           0     0.4352    0.5757    0.4957       601
           1     0.7769    0.8857    0.8278       700
           2     0.9325    0.9279    0.9302       804
           3     0.7430    0.5518    0.6332      1053
           4     0.9637    0.9722    0.9679       792
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8056196267561334
0.8056196267561334
              precision    recall  f1-score   support

           0     0.4408    0.6140    0.5132       570
           1     0.7955    0.8818    0.8364       719
           2     0.9525    0.9104    0.9310       837
           3     0.7203    0.5361    0.6147      1052
           4     0.9549    0.9770    0.9658       781
           5     0.9661    0.9494    0.9577       810

    accuracy                         0.8056      4769
   macro avg     0.8050    0.8114    0.8031      4769
weighted avg     0.8192    0.8056    0.8073      4769

########## TEST ###########
0.8054507337526206
0.8054507337526206
              precision    recall  f1-score   support

           0     0.4239    0.5975    0.4960       564
           1     0.7957    0.8723    0.8322       728
           2     0.9325    0.9031    0.9176       826
           3     0.7481    0.5556    0.6376      1053
           4     0.9612    0.9808    0.9709       783
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8028936884042777
0.8028936884042777
              precision    recall  f1-score   support

           0     0.4471    0.5839    0.5064       608
           1     0.7942    0.8941    0.8412       708
           2     0.9113    0.9554    0.9328       763
           3     0.7280    0.5249    0.6100      1086
           4     0.9687    0.9687    0.9687       799
           5     0.9648    0.9540    0.9594       805

    accuracy                         0.8029      4769
   macro avg     0.8023    0.8135    0.8031      4769
weighted avg     0.8116    0.8029    0.8018      4769

########## TEST ###########
0.8058700209643606
0.8058700209643607
              precision    recall  f1-score   support

           0     0.4239    0.5902    0.4934       571
           1     0.7982    0.8774    0.8360       726
           2     0.9050    0.9476    0.9258       764
           3     0.7698    0.5463    0.6391      1102
           4     0.9662    0.9785    0.9723       789
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8043615013629691
0.8043615013629691
              precision    recall  f1-score   support

           0     0.3866    0.6116    0.4738       502
           1     0.7967    0.8982    0.8444       707
           2     0.9250    0.9499    0.9373       779
           3     0.7854    0.5190    0.6250      1185
           4     0.9650    0.9686    0.9668       796
           5     0.9648    0.9600    0.9624       800

    accuracy                         0.8044      4769
   macro avg     0.8039    0.8179    0.8016      4769
weighted avg     0.8280    0.8044    0.8063      4769

########## TEST ###########
0.809853249475891
0.809853249475891
              precision    recall  f1-score   support

           0     0.3874    0.6247    0.4783       493
           1     0.7920    0.8889    0.8376       711
           2     0.9062    0.9428    0.9242       769
           3     0.8274    0.5465    0.6582      1184
           4     0.9700    0.9736    0.9718       796
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8026840008387502
0.8026840008387502
              precision    recall  f1-score   support

           0     0.3501    0.6450    0.4539       431
           1     0.8105    0.8636    0.8362       748
           2     0.9213    0.9437    0.9323       781
           3     0.8008    0.5216    0.6317      1202
           4     0.9625    0.9710    0.9667       792
           5     0.9686    0.9460    0.9572       815

    accuracy                         0.8027      4769
   macro avg     0.8023    0.8152    0.7963      4769
weighted avg     0.8368    0.8027    0.8082      4769

########## TEST ###########
0.8073375262054507
0.8073375262054507
              precision    recall  f1-score   support

           0     0.3434    0.6516    0.4498       419
           1     0.8020    0.8568    0.8285       747
           2     0.9200    0.9436    0.9316       780
           3     0.8261    0.5433    0.6555      1189
           4     0.9762    0.9738    0.9750       801
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8028936884042777
0.8028936884042777
              precision    recall  f1-score   support

           0     0.4295    0.5839    0.4949       584
           1     0.8130    0.8663    0.8388       748
           2     0.9225    0.9523    0.9371       775
           3     0.7190    0.5286    0.6093      1065
           4     0.9637    0.9661    0.9649       797
           5     0.9661    0.9613    0.9637       800

    accuracy                         0.8029      4769
   macro avg     0.8023    0.8097    0.8015      4769
weighted avg     0.8137    0.8029    0.8034      4769

########## TEST ###########
0.8060796645702306
0.8060796645702307
              precision    recall  f1-score   support

           0     0.4252    0.5778    0.4899       585
           1     0.8095    0.8625    0.8352       749
           2     0.9137    0.9457    0.9294       773
           3     0.7442    0.5491    0.6319      1060
           4     0.9725    0.9786    0.9755       794
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8047808764940239
0.8047808764940239
              precision    recall  f1-score   support

           0     0.4055    0.5930    0.4817       543
           1     0.8168    0.8611    0.8384       756
           2     0.9350    0.9315    0.9333       803
           3     0.7356    0.5343    0.6190      1078
           4     0.9700    0.9651    0.9675       803
           5     0.9623    0.9746    0.9684       786

    accuracy                         0.8048      4769
   macro avg     0.8042    0.8099    0.8014      4769
weighted avg     0.8213    0.8048    0.8073      4769

########## TEST ###########
0.810062893081761
0.810062893081761
              precision    recall  f1-score   support

           0     0.3987    0.6167    0.4843       514
           1     0.8158    0.8444    0.8298       771
           2     0.9263    0.9297    0.9280       797
           3     0.7801    0.5561    0.6493      1097
           4     0.9712    0.9724    0.9718       798
   

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8091843153700986
0.8091843153700986
              precision    recall  f1-score   support

           0     0.4395    0.6144    0.5125       568
           1     0.8018    0.8863    0.8419       721
           2     0.9062    0.9527    0.9289       761
           3     0.7663    0.5425    0.6353      1106
           4     0.9725    0.9616    0.9670       808
           5     0.9661    0.9553    0.9606       805

    accuracy                         0.8092      4769
   macro avg     0.8087    0.8188    0.8077      4769
weighted avg     0.8237    0.8092    0.8099      4769

########## TEST ###########
0.8079664570230608
0.8079664570230608
              precision    recall  f1-score   support

           0     0.4314    0.6049    0.5037       567
           1     0.8033    0.8733    0.8368       734
           2     0.8925    0.9571    0.9237       746
           3     0.7775    0.5517    0.6454      1102
           4     0.9700    0.9675    0.9687       801
 

  0%|          | 0/597 [00:00<?, ?batch/s]

########## VAL ###########
0.8106521283287901
0.8106521283287901
              precision    recall  f1-score   support

           0     0.4773    0.5931    0.5290       639
           1     0.7955    0.8917    0.8408       711
           2     0.9300    0.9502    0.9400       783
           3     0.7267    0.5424    0.6212      1049
           4     0.9637    0.9747    0.9692       790
           5     0.9673    0.9661    0.9667       797

    accuracy                         0.8107      4769
   macro avg     0.8101    0.8197    0.8111      4769
weighted avg     0.8164    0.8107    0.8093      4769

########## TEST ###########
0.8142557651991614
0.8142557651991614
              precision    recall  f1-score   support

           0     0.4616    0.5872    0.5169       625
           1     0.8008    0.8912    0.8436       717
           2     0.9200    0.9497    0.9346       775
           3     0.7737    0.5654    0.6533      1070
           4     0.9612    0.9821    0.9715       782
 

# Transformer Models

In [29]:
data

Unnamed: 0,tweet_text,cyberbullying_type,label,cleaned_text
0,"In other words #katandandre, your food was cra...",not_cyberbullying,0,"[in, other, words, katandandre, your, food, wa..."
1,Why is #aussietv so white? #MKR #theblock #ImA...,not_cyberbullying,0,"[why, is, aussietv, so, white, mkr, theblock, ..."
2,@XochitlSuckkks a classy whore? Or more red ve...,not_cyberbullying,0,"[a, classy, whore, or, more, red, velvet, cupc..."
3,"@Jason_Gio meh. :P thanks for the heads up, b...",not_cyberbullying,0,"[meh, p, thanks, for, the, heads, up, but, not..."
4,@RudhoeEnglish This is an ISIS account pretend...,not_cyberbullying,0,"[this, is, an, isis, account, pretending, to, ..."
...,...,...,...,...
47687,"Black ppl aren't expected to do anything, depe...",ethnicity,5,"[black, ppl, are, not, expected, to, do, anyth..."
47688,Turner did not withhold his disappointment. Tu...,ethnicity,5,"[turner, did, not, withhold, his, disappointme..."
47689,I swear to God. This dumb nigger bitch. I have...,ethnicity,5,"[i, swear, to, god, this, dumb, nigger, bitch,..."
47690,Yea fuck you RT @therealexel: IF YOURE A NIGGE...,ethnicity,5,"[yea, fuck, you, rt, if, youre, a, nigger, fuc..."


In [30]:
BATCH_SIZE = 50
# del convnet
# del mlp

In [31]:
model_type = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)

In [32]:
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)

In [33]:
tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)

In [35]:
TRAIN_SPLIT = int(0.8 * len(labels))
VAL_TEST_SPLIT = len(labels) - TRAIN_SPLIT
VAL_SPLIT = VAL_TEST_SPLIT//2
TEST_SPLIT = VAL_TEST_SPLIT - VAL_SPLIT

train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))

print(len(train), len(val), len(test))

38153 4769 4770


In [36]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [37]:
class transformer_classifier(nn.Module):
  def __init__(self, model_type, hidden_dim, dropout, device ):
      super(transformer_classifier, self).__init__()

      self.device = device
      self.hidden_dim = hidden_dim
      self.LABELS_NUM = len(labels2id.items())

      self.lm = AutoModel.from_pretrained(model_type).to(self.device)

      config = config = AutoConfig.from_pretrained(model_type, num_labels = self.LABELS_NUM)
      self.lm_sequence_classification = AutoModelForSequenceClassification.from_pretrained(model_type, config=config).to(self.device)

      self.l1 = nn.LazyLinear(hidden_dim).to(self.device)
      self.a1 = nn.GELU().to(self.device)
      self.d1 = nn.Dropout(dropout).to(self.device)
      self.l2 = nn.LazyLinear(self.LABELS_NUM).to(self.device)

      self.loss = nn.CrossEntropyLoss()


  def forward(self, ids, masks, labels, mode):
    if mode:
      output = self.lm_sequence_classification(ids, masks, labels=labels)
      return output.logits.to(self.device), output.loss.to(self.device)
    else:
      x = self.lm(ids, masks).last_hidden_state[:, 0].to(self.device)
      x = self.l1(x).to(self.device)
      x = self.a1(x).to(self.device)
      x = self.d1(x).to(self.device)
      x = self.l2(x).to(self.device)

      loss = self.loss(x, labels).to(self.device)

      return x, loss

  def get_predictions(self, loader, mode):
    with torch.no_grad():
      with tqdm.notebook.tqdm(
      loader,
      unit="batch",
      total=len(loader)) as batch_iterator:

        preds_all = []
        labels_all = []
        accuracies = []

        for iteration, data in enumerate(batch_iterator, start=1):
          ids, masks, labels = data
          ids = ids.to(self.device)
          masks = masks.to(self.device)
          labels = labels.to(self.device)

          self.eval()

          logits, loss = self.forward(ids, masks, labels, mode)

          _, predictions = torch.max(logits, axis=1)

          preds_all += list(predictions.cpu())
          labels_all += list(labels.cpu())


          accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
          accuracies.append(accuracy)

          batch_iterator.set_postfix(accuracy = accuracy, mean_accuracy = np.mean(accuracies))
    return preds_all, labels_all

  def train_model(self, train_loader, val_loader, test_loader, mode, optimizer):
    with tqdm.notebook.tqdm(
      train_loader,
      unit="batch",
      total=len(train_loader)) as batch_iterator:

      total_loss = 0.0
      accuracies = []

      for iteration, data in enumerate(batch_iterator, start=1):
        ids, masks, labels = data
        ids = ids.to(self.device)
        masks = masks.to(self.device)
        labels = labels.to(self.device)

        optimizer.zero_grad()
        self.zero_grad()
        self.train()

        logits, loss = self.forward(ids, masks, labels, mode)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predictions = torch.max(logits, axis=1)
        accuracy = accuracy_score(list(labels.cpu()), list(predictions.cpu()))
        accuracies.append(accuracy)

        batch_iterator.set_postfix(mean_loss=total_loss / iteration, current_loss=loss.item(), accuracy = accuracy, mean_accuracy = np.mean(accuracies))

        if iteration % 400 == 0:
          pred, label = self.get_predictions(val_loader, mode)
          print("########## VAL ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          pred, label = self.get_predictions(test_loader, mode)
          print("########## TEST ###########")
          print(accuracy_score(pred, label))
          print(f1_score(pred, label, average='micro'))
          print(classification_report(pred, label, zero_division=0, digits = 4),)

          



In [38]:
hidden_dim = 128
dropout = 0.2
learning_rate = 1e-04
epochs = 5

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight',

  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8500733906479345
0.8500733906479344
              precision    recall  f1-score   support

           0     0.5908    0.6562    0.6218       704
           1     0.8726    0.8726    0.8726       793
           2     0.9579    0.9695    0.9637       821
           3     0.7174    0.6293    0.6704       847
           4     0.9688    0.9789    0.9738       760
           5     0.9707    0.9799    0.9752       844

    accuracy                         0.8501      4769
   macro avg     0.8463    0.8478    0.8463      4769
weighted avg     0.8508    0.8501    0.8497      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8427672955974843
0.8427672955974843
              precision    recall  f1-score   support

           0     0.5864    0.6478    0.6156       744
           1     0.8766    0.8643    0.8704       781
           2     0.9508    0.9741    0.9623       773
           3     0.7107    0.6115    0.6574       888
           4     0.9640    0.9926    0.9781       810
           5     0.9683    0.9858    0.9770       774

    accuracy                         0.8428      4770
   macro avg     0.8428    0.8460    0.8435      4770
weighted avg     0.8422    0.8428    0.8415      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8544768295240093
0.8544768295240093
              precision    recall  f1-score   support

           0     0.4783    0.7421    0.5816       504
           1     0.8979    0.8466    0.8715       841
           2     0.9735    0.9654    0.9694       838
           3     0.7954    0.6124    0.6920       965
           4     0.9766    0.9817    0.9791       764
           5     0.9847    0.9790    0.9819       857

    accuracy                         0.8545      4769
   macro avg     0.8511    0.8545    0.8459      4769
weighted avg     0.8743    0.8545    0.8588      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8433962264150944
0.8433962264150944
              precision    recall  f1-score   support

           0     0.4672    0.7218    0.5672       532
           1     0.8961    0.8415    0.8679       820
           2     0.9634    0.9561    0.9597       798
           3     0.7866    0.5915    0.6753      1016
           4     0.9676    0.9914    0.9794       814
           5     0.9873    0.9848    0.9861       790

    accuracy                         0.8434      4770
   macro avg     0.8447    0.8479    0.8393      4770
weighted avg     0.8635    0.8434    0.8473      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8496540155168798
0.8496540155168798
              precision    recall  f1-score   support

           0     0.4373    0.7500    0.5525       456
           1     0.8852    0.8656    0.8753       811
           2     0.9735    0.9642    0.9689       839
           3     0.8264    0.5837    0.6841      1052
           4     0.9792    0.9792    0.9792       768
           5     0.9777    0.9881    0.9829       843

    accuracy                         0.8497      4769
   macro avg     0.8466    0.8551    0.8405      4769
weighted avg     0.8764    0.8497    0.8545      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8364779874213837
0.8364779874213837
              precision    recall  f1-score   support

           0     0.4075    0.7545    0.5292       444
           1     0.8662    0.8595    0.8629       776
           2     0.9684    0.9599    0.9642       799
           3     0.8377    0.5575    0.6695      1148
           4     0.9736    0.9842    0.9789       825
           5     0.9759    0.9884    0.9821       778

    accuracy                         0.8365      4770
   macro avg     0.8382    0.8507    0.8311      4770
weighted avg     0.8702    0.8365    0.8418      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8368630740197106
0.8368630740197106
              precision    recall  f1-score   support

           0     0.4834    0.6608    0.5583       572
           1     0.8764    0.8580    0.8671       810
           2     0.9747    0.9474    0.9609       855
           3     0.6945    0.5778    0.6308       893
           4     0.9818    0.9593    0.9704       786
           5     0.9836    0.9824    0.9830       853

    accuracy                         0.8369      4769
   macro avg     0.8324    0.8310    0.8284      4769
weighted avg     0.8494    0.8369    0.8404      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8356394129979036
0.8356394129979036
              precision    recall  f1-score   support

           0     0.4842    0.6803    0.5657       585
           1     0.8688    0.8734    0.8711       766
           2     0.9722    0.9471    0.9595       813
           3     0.7304    0.5723    0.6417       975
           4     0.9760    0.9679    0.9719       841
           5     0.9860    0.9835    0.9848       790

    accuracy                         0.8356      4770
   macro avg     0.8363    0.8374    0.8325      4770
weighted avg     0.8493    0.8356    0.8384      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8326693227091634
0.8326693227091634
              precision    recall  f1-score   support

           0     0.4194    0.6521    0.5105       503
           1     0.8562    0.8899    0.8728       763
           2     0.9567    0.9625    0.9596       826
           3     0.7725    0.5616    0.6504      1022
           4     0.9779    0.9665    0.9722       777
           5     0.9906    0.9613    0.9757       878

    accuracy                         0.8327      4769
   macro avg     0.8289    0.8323    0.8235      4769
weighted avg     0.8542    0.8327    0.8371      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8255765199161426
0.8255765199161426
              precision    recall  f1-score   support

           0     0.4148    0.6752    0.5139       505
           1     0.8623    0.8830    0.8725       752
           2     0.9508    0.9592    0.9550       785
           3     0.7683    0.5420    0.6356      1083
           4     0.9772    0.9796    0.9784       832
           5     0.9873    0.9569    0.9719       813

    accuracy                         0.8256      4770
   macro avg     0.8268    0.8327    0.8212      4770
weighted avg     0.8495    0.8256    0.8297      4770



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight',

  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8460893269029146
0.8460893269029145
              precision    recall  f1-score   support

           0     0.4284    0.7719    0.5510       434
           1     0.8600    0.8857    0.8727       770
           2     0.9675    0.9629    0.9652       835
           3     0.8493    0.5742    0.6851      1099
           4     0.9727    0.9842    0.9784       759
           5     0.9812    0.9587    0.9698       872

    accuracy                         0.8461      4769
   macro avg     0.8432    0.8563    0.8370      4769
weighted avg     0.8772    0.8461    0.8510      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8406708595387841
0.8406708595387841
              precision    recall  f1-score   support

           0     0.4282    0.7702    0.5504       457
           1     0.8610    0.8816    0.8712       752
           2     0.9596    0.9632    0.9614       789
           3     0.8613    0.5638    0.6815      1167
           4     0.9628    0.9938    0.9781       808
           5     0.9822    0.9711    0.9767       797

    accuracy                         0.8407      4770
   macro avg     0.8425    0.8573    0.8366      4770
weighted avg     0.8734    0.8407    0.8447      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8481862025581883
0.8481862025581883
              precision    recall  f1-score   support

           0     0.4348    0.7522    0.5511       452
           1     0.8802    0.8747    0.8774       798
           2     0.9675    0.9629    0.9652       835
           3     0.8371    0.5797    0.6850      1073
           4     0.9661    0.9880    0.9770       751
           5     0.9847    0.9756    0.9801       860

    accuracy                         0.8482      4769
   macro avg     0.8451    0.8555    0.8393      4769
weighted avg     0.8760    0.8482    0.8528      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8387840670859539
0.8387840670859539
              precision    recall  f1-score   support

           0     0.4063    0.7540    0.5281       443
           1     0.8818    0.8661    0.8739       784
           2     0.9634    0.9646    0.9640       791
           3     0.8403    0.5587    0.6712      1149
           4     0.9640    0.9950    0.9793       808
           5     0.9886    0.9799    0.9842       795

    accuracy                         0.8388      4770
   macro avg     0.8407    0.8530    0.8334      4770
weighted avg     0.8729    0.8388    0.8441      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8414762004613127
0.8414762004613128
              precision    recall  f1-score   support

           0     0.4437    0.7229    0.5499       480
           1     0.8487    0.8774    0.8628       767
           2     0.9615    0.9744    0.9679       820
           3     0.8318    0.5649    0.6728      1094
           4     0.9727    0.9777    0.9752       764
           5     0.9730    0.9822    0.9776       844

    accuracy                         0.8415      4769
   macro avg     0.8386    0.8499    0.8344      4769
weighted avg     0.8653    0.8415    0.8441      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8314465408805032
0.8314465408805032
              precision    recall  f1-score   support

           0     0.4221    0.7259    0.5338       478
           1     0.8662    0.8753    0.8708       762
           2     0.9545    0.9631    0.9588       785
           3     0.8207    0.5452    0.6552      1150
           4     0.9700    0.9866    0.9782       820
           5     0.9645    0.9806    0.9725       775

    accuracy                         0.8314      4770
   macro avg     0.8330    0.8461    0.8282      4770
weighted avg     0.8591    0.8314    0.8345      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8456699517718599
0.8456699517718599
              precision    recall  f1-score   support

           0     0.5384    0.6672    0.5959       631
           1     0.8562    0.8818    0.8688       770
           2     0.9663    0.9571    0.9617       839
           3     0.7254    0.5923    0.6521       910
           4     0.9766    0.9778    0.9772       767
           5     0.9871    0.9871    0.9871       852

    accuracy                         0.8457      4769
   macro avg     0.8417    0.8439    0.8405      4769
weighted avg     0.8513    0.8457    0.8463      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8364779874213837
0.8364779874213837
              precision    recall  f1-score   support

           0     0.5182    0.6574    0.5796       648
           1     0.8597    0.8922    0.8757       742
           2     0.9583    0.9511    0.9547       798
           3     0.7277    0.5744    0.6420       968
           4     0.9712    0.9794    0.9753       827
           5     0.9860    0.9873    0.9867       787

    accuracy                         0.8365      4770
   macro avg     0.8369    0.8403    0.8357      4770
weighted avg     0.8432    0.8365    0.8368      4770



  0%|          | 0/764 [00:00<?, ?batch/s]

  0%|          | 0/96 [00:00<?, ?batch/s]

########## VAL ###########
0.8402180750681485
0.8402180750681485
              precision    recall  f1-score   support

           0     0.5780    0.6067    0.5920       745
           1     0.8739    0.8772    0.8756       790
           2     0.9687    0.9572    0.9629       841
           3     0.6245    0.6057    0.6150       766
           4     0.9805    0.9792    0.9798       769
           5     0.9859    0.9790    0.9825       858

    accuracy                         0.8402      4769
   macro avg     0.8352    0.8342    0.8346      4769
weighted avg     0.8417    0.8402    0.8409      4769



  0%|          | 0/96 [00:00<?, ?batch/s]

########## TEST ###########
0.8348008385744234
0.8348008385744236
              precision    recall  f1-score   support

           0     0.5669    0.6180    0.5914       754
           1     0.8766    0.8643    0.8704       781
           2     0.9684    0.9528    0.9606       805
           3     0.6309    0.5929    0.6113       813
           4     0.9724    0.9854    0.9789       823
           5     0.9911    0.9836    0.9874       794

    accuracy                         0.8348      4770
   macro avg     0.8344    0.8328    0.8333      4770
weighted avg     0.8369    0.8348    0.8355      4770



In [39]:
model_type

'bert-base-uncased'

In [None]:
model_type = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=False)
tokenized_text = list(data['cleaned_text'])
tokenized_text = [' '.join(x) for x in tokenized_text]
tokenized_text = tokenizer(tokenized_text, padding = 'max_length', max_length = PADDING_VALUE, truncation=True, add_special_tokens=True)
x_ids = torch.tensor([f for f in tokenized_text.input_ids], dtype=torch.long)
x_mask = torch.tensor([f for f in tokenized_text.attention_mask], dtype=torch.long)
label_tensor = torch.tensor(data.label)
tensor_dataset = TensorDataset(x_ids, x_mask, label_tensor)
train, val_test = random_split(tensor_dataset, [TRAIN_SPLIT, VAL_TEST_SPLIT], generator=torch.Generator().manual_seed(13))
val, test = random_split(val_test, [VAL_SPLIT, TEST_SPLIT], generator=torch.Generator().manual_seed(13))
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [None]:
classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, False, optimizer)

classifier = transformer_classifier(model_type, hidden_dim, dropout, device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

for i in range(epochs):
  classifier.train_model(train_loader, val_loader, test_loader, True, optimizer)