In [1]:
import os
import pandas as pd
import numpy as np
import random
from tqdm.auto import tqdm
import pickle
import re
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import datasets
from datasets import ClassLabel, load_dataset, Dataset, DatasetDict
import string
from typing import Dict ,List
import transformers

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW
from transformers import BertTokenizerFast
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from arabert.preprocess import ArabertPreprocessor
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

import warnings

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

<h2 style='color:red'>NOTE: change all occurrences of MPS to CUDA if not training on macOS</h2>

## SETTINGS

In [2]:
DATA_PATH = "/Users/gufran/Developer/Projects/AI/MawqifStanceDetection/data"
MODEL_PATH = "/Users/gufran/Developer/Projects/AI/MawqifStanceDetection/models"

In [3]:
use_gnn = True
num_msg_passing = 1
concatenate_msg_bert = True

In [4]:
target = "all"
#target = 'Covid Vaccine'
#target = 'Digital Transformation'
#target = 'Women empowerment'

In [5]:
model_name = "aubmindlab/bert-base-arabertv02-twitter"
# model_name = "aubmindlab/bert-base-arabertv02"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert = AutoModel.from_pretrained(model_name)

Some weights of BertModel were not initialized from the model checkpoint at aubmindlab/bert-base-arabertv02-twitter and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
batch_size = 16
num_epochs = 20
learning_rate = 2e-5
weight_decay = 1e-5
dropout = 0.1
device = "mps"

In [7]:
seed = 42

random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

#FOR MAC --------------------------------------------
torch.mps.manual_seed(seed)
torch.backends.mps.deterministic=True
torch.backends.mps.benchmark = False

#FOR WINDOWS AND LINUX -------------------------------
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic=True #replace mps with cudnn here
# torch.backends.cudnn.benchmark = False #replace mps with cudnn here

## DATA LOADING AND PREPROCESSING

In [8]:
df = pd.read_csv(os.path.join(DATA_PATH, "cleaned.csv"))
df = df.dropna(subset=["stance"])
df = df.sample(frac=1).reset_index(drop=True)
df = df[["text", "sarcasm", "sentiment", "stance", "sentiment_confidence"]]
df.head()

Unnamed: 0,text,sarcasm,sentiment,stance
0,والوزارة كل يوم مغردين عن التحول الرقمي والت...,No,Neutral,Favor
1,تمكين المرأة يعني فتح المجال لها للعمل في مجال...,No,Neutral,Favor
2,MENTION تمكين المرأة السعودية ورحلة كفاحها هو ...,No,Positive,Favor
3,هذا الدراسه العالميه الحديثه التي تعاون و شارك...,No,Neutral,Favor
4,نحن نتعرض لحملة موافقة اجبارية ممنهجة😅 ماتلاحظ...,No,Neutral,Against


In [9]:
mapping_sarcasm = {"No": 0, "Yes": 1}
df['sarcasm'] = df['sarcasm'].map(lambda x: mapping_sarcasm[x])

mapping_sentiment = {"Negative": 0, "Neutral": 1, "Positive": 2}
df['sentiment'] = df['sentiment'].map(lambda x: mapping_sentiment[x])

mapping_stance = {"Favor": 1, "Against": 0}
df['stance'] = df['stance'].map(lambda x: mapping_stance[x])

In [10]:
arabic_punctuations = '''`÷×؛<>()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ'''
english_punctuations = string.punctuation
punctuations_list = arabic_punctuations + english_punctuations

def remove_hash_URL_MEN(text):
    text = re.sub(r'#',' ',text)
    text = re.sub(r'_',' ',text)
    text = re.sub(r'URL','',text)
    text = re.sub(r'MENTION','',text)
    return text

def normalize_arabic(text):
    text = re.sub("[إآ]", "ا", text)
    text = re.sub("گ", "ك", text)
    return text

def remove_punctuations(text):
    translator = str.maketrans('', '', punctuations_list)
    return text.translate(translator)

def remove_repeating_char(text):
    return re.sub(r'(.)\1+', r'\1', text)

def process_tweet(tweet):     
    tweet=remove_hash_URL_MEN(tweet)
    tweet = re.sub('@[^\s]+', ' ', str(tweet))
    tweet = re.sub('((www\.[^\s]+)|(https?://[^\s]+))',' ',str(tweet))    
    tweet= normalize_arabic(str(tweet))
    
    return tweet

arabert_prep = ArabertPreprocessor(model_name=model_name)
df.text = df.text.apply(lambda x: process_tweet(x))
df.text = df.text.apply(lambda x: arabert_prep.preprocess(x))

In [11]:
df.head()

Unnamed: 0,text,sarcasm,sentiment,stance
0,والوزارة كل يوم مغردين عن التحول الرقمي والتعل...,0,1,1
1,تمكين المرأة يعني فتح المجال لها للعمل في مجال...,0,1,1
2,تمكين المرأة السعودية ورحلة كفاحها هو الانتصار,0,2,1
3,هذا الدراسه العالميه الحديثه التي تعاون و شارك...,0,1,1
4,نحن نتعرض لحملة موافقة اجبارية ممنهجة 😅 ماتلاح...,0,1,0


In [12]:
X = df[["text"]]
y = df[["sentiment", "sarcasm", "stance"]]

In [13]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
print("X_train shape:", X_train.shape, "y_train shape:", y_train.shape, "X_test shape:", X_test.shape, "y_test shape:", y_test.shape)

X_train shape: (2693, 1) y_train shape: (2693, 3) X_test shape: (476, 1) y_test shape: (476, 3)


In [14]:
def encode_text(text):
    return tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

encoded_tweets_train = [encode_text(text) for text in X_train["text"]]
encoded_tweets_test = [encode_text(text) for text in X_test["text"]]

In [15]:
def data_generator(encoded_tweets, labels):    
    main_tweets, context_tweets, sentiments, sarcasms, stances = [], [], [], [], []
    len_encoded_tweets = len(encoded_tweets)
    
    for i in range(len_encoded_tweets):
        context_tweet_index = random.randint(0, len_encoded_tweets-1)
        main_tweets.append(encoded_tweets[i])
        context_tweets.append(encoded_tweets[context_tweet_index])

        sentiments.append(labels.sentiment.iloc[i])
        sarcasms.append(labels.sarcasm.iloc[i])
        stances.append(labels.stance.iloc[i])
            
    return main_tweets, context_tweets, sentiments, sarcasms, stances

In [16]:
main_tweets, context_tweets, sentiments, sarcasms, stances = data_generator(encoded_tweets_train, y_train)
train_dataset = TensorDataset(
    torch.cat([item["input_ids"] for item in main_tweets]),
    torch.cat([item["attention_mask"] for item in main_tweets]),
    torch.cat([item["input_ids"] for item in context_tweets]),
    torch.cat([item["attention_mask"] for item in context_tweets]),
    torch.tensor(sarcasms),
    torch.tensor(sentiments),
    torch.tensor(stances),
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

main_tweets, context_tweets, sentiments, sarcasms, stances = data_generator(encoded_tweets_test, y_test)
val_dataset = TensorDataset(
    torch.cat([item["input_ids"] for item in main_tweets]),
    torch.cat([item["attention_mask"] for item in main_tweets]),
    torch.cat([item["input_ids"] for item in context_tweets]),
    torch.cat([item["attention_mask"] for item in context_tweets]),
    torch.tensor(sarcasms),
    torch.tensor(sentiments),
    torch.tensor(stances),
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [17]:
num_sarcasm_labels = len(df.sarcasm.unique())
num_sentiment_labels = len(df.sentiment.unique())
num_stance_labels = len(df.stance.unique())

num_sarcasm_labels, num_sentiment_labels, num_stance_labels

(2, 3, 2)

---

## MODEL ARCHITECTURE

In [18]:
class SubTaskHead(nn.Module):
    def __init__(self, num_labels, hidden_size):
        super(SubTaskHead, self).__init__()
        
        self.num_labels = num_labels
        self.classifier = nn.Linear(hidden_size, num_labels)
        
    def forward(self, inputs):
        logits = F.dropout(inputs, p=dropout, training=self.training)
        logits = self.classifier(logits)
        
        return logits
    
    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

In [19]:
class StanceHead(torch.nn.Module):
    def __init__(self, last_hidden_state_size, hidden_channels, num_classes, concatenate_msg_bert = True):
        super().__init__()
        self.conv1 = GCNConv(last_hidden_state_size, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.concatenate_msg_bert = concatenate_msg_bert
        
        if use_gnn:
            self.classifier = nn.Linear(hidden_channels if not concatenate_msg_bert else hidden_channels + last_hidden_state_size, num_classes)
        else:
            self.classifier = nn.Linear(last_hidden_state_size, num_classes)

    def forward(self, mt_last_hidden_state, sentiments, sarcasms):
        if use_gnn:
            node_features = torch.tensor(mt_last_hidden_state, dtype=torch.float)

            edges = []
            edge_features = []
            for i in range(len(sentiments)):
                for j in range(len(sentiments)):
                    if i == j: continue
                    if sentiments[i] == sentiments[j] or sarcasms[i] == sarcasms[j]:
                        edges.append((i, j))
                        if sentiments[i] == sentiments[j] and sarcasms[i] == sarcasms[j]:
                            edge_features.append(1.0)
                        else:
                            edge_features.append(0.5)

            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_features, dtype=torch.float)
            data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr).to(device)

            x = self.conv1(data.x, data.edge_index, data.edge_attr)
            x = x.relu()
            x = F.dropout(x, p=dropout, training=self.training)
            
            if num_msg_passing == 2:
                x = self.conv2(x, data.edge_index, data.edge_attr)
                x = x.relu()
                x = F.dropout(x, p=dropout, training=self.training)

            if num_msg_passing == 3:
                x = self.conv3(x, data.edge_index, data.edge_attr)
                x = x.relu()
                x = F.dropout(x, p=dropout, training=self.training)
            
            if self.concatenate_msg_bert:
                x = torch.cat((x, data.x), dim=-1)
                
            x = self.classifier(x)
            return x
        
        else:
            logits = F.dropout(mt_last_hidden_state, p=dropout, training=self.training)
            logits = self.classifier(logits)
        
            return logits
    
    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

In [20]:
class MultiTaskModel(nn.Module):
    def __init__(self, bert, sentiment_head, sarcasm_head, stance_head, subtask_hidden_layer_size):
        super(MultiTaskModel, self).__init__()

        self.bert = bert
        self.hidden_layer = nn.Linear(bert.config.hidden_size, subtask_hidden_layer_size)
        self.dropout = nn.Dropout(dropout)
        
        self.sentiment_head = sentiment_head
        self.sarcasm_head = sarcasm_head
        self.stance_head = stance_head
        
    def forward(self, mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask):
        mt_outputs = self.bert(input_ids=mt_input_ids, attention_mask=mt_attention_mask)
        ct_outputs = self.bert(input_ids=ct_input_ids, attention_mask=ct_attention_mask)
                
        mt_last_hidden_state = mt_outputs.last_hidden_state[:, 0, :]
        ct_last_hidden_state = ct_outputs.last_hidden_state[:, 0, :]
        
        hidden_output = self.dropout(F.relu(self.hidden_layer(mt_last_hidden_state)))
        sarcasm_logits = self.sarcasm_head(hidden_output)
        sentiment_logits = self.sentiment_head(hidden_output)
        
        sentiments = sentiment_logits.argmax(axis=1)
        sarcasms = sarcasm_logits.argmax(axis=1)
        stance_logits = self.stance_head(mt_last_hidden_state, sentiments, sarcasms)
        
        return sarcasm_logits, sentiment_logits, stance_logits

## TRAINING & VALIDATION

In [21]:
class LinearDecayLR(_LRScheduler):
    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        self.start_decay=start_decay
        self.n_epoch=n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        last_epoch = self.last_epoch
        n_epoch=self.n_epoch
        b_lr=self.base_lrs[0]
        start_decay=self.start_decay
        if last_epoch>start_decay:
            lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
        else:
            lr=b_lr
        return [lr]

In [22]:
subtask_hidden_layer_size = 256
gnn_hidden_layer_size = 256
bert_emb_size = bert.config.hidden_size

model = MultiTaskModel(
    bert,
    SubTaskHead(num_sentiment_labels, subtask_hidden_layer_size),
    SubTaskHead(num_sarcasm_labels, subtask_hidden_layer_size),
    StanceHead(bert_emb_size, gnn_hidden_layer_size, num_stance_labels, concatenate_msg_bert = concatenate_msg_bert),
    subtask_hidden_layer_size,
).to(device)

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

lr_scheduler=LinearDecayLR(optimizer, num_epochs, int(num_epochs*0.75))

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1} || Learning Rate: {lr_scheduler.get_lr()}")
    
    # TRAINING -------------------------------------------------------------------------------
    model.train()
    train_loader_gen = data_generator(encoded_tweets_train, y_train)
    valid_loader_gen = data_generator(encoded_tweets_test, y_test)
    
    epoch_sarcasm_loss, epoch_sentiment_loss, epoch_stance_loss = 0.0, 0.0, 0.0    
    correct_sarcasm, correct_sentiment, correct_stance, total_samples = 0,0,0,0

    for batch in tqdm(train_loader, desc="Training"):
        mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask, sarc_y, sent_y, stance_y = batch
        mt_input_ids = mt_input_ids.to(device)
        mt_attention_mask = mt_attention_mask.to(device)
        ct_input_ids = ct_input_ids.to(device)
        ct_attention_mask = ct_attention_mask.to(device)

        sarc_y = sarc_y.to(device)
        sent_y = sent_y.to(device)
        stance_y = stance_y.to(device)

        optimizer.zero_grad()
        sarcasm_logits, sentiment_logits, stance_logits = model(mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask)

        sarcasm_loss = ce_loss(sarcasm_logits, sarc_y)
        sentiment_loss = ce_loss(sentiment_logits, sent_y)
        stance_loss = ce_loss(stance_logits, stance_y)

        total_loss = sarcasm_loss + sentiment_loss + stance_loss
        total_loss.backward()
        optimizer.step()

        correct_sarcasm += (sarcasm_logits.argmax(dim=1) == sarc_y).sum().item()
        correct_sentiment += (sentiment_logits.argmax(dim=1) == sent_y).sum().item()
        correct_stance += (stance_logits.argmax(dim=1) == stance_y).sum().item()
        total_samples += mt_input_ids.size(0)

        epoch_sarcasm_loss += sarcasm_loss.item()
        epoch_sentiment_loss += sentiment_loss.item()
        epoch_stance_loss += stance_loss.item()
           
    avg_sarcasm_loss = epoch_sarcasm_loss / total_samples
    avg_sentiment_loss = epoch_sentiment_loss / total_samples
    avg_stance_loss = epoch_stance_loss / total_samples

    sarcasm_acc = correct_sarcasm / total_samples
    sentiment_acc = correct_sentiment / total_samples
    stance_acc = correct_stance / total_samples

    print(f"Sarcasm -> Loss: {avg_sarcasm_loss:.4f}, Acc: {sarcasm_acc:.4f}")
    print(f"Sentiment -> Loss: {avg_sentiment_loss:.4f}, Acc: {sentiment_acc:.4f}")
    print(f"Stance -> Loss: {avg_stance_loss:.4f}, Acc: {stance_acc:.4f}")
    
    
    # VALIDATION -------------------------------------------------------------------------------
    model.eval()
    valid_sarcasm_loss, valid_sentiment_loss, valid_stance_loss = 0.0, 0.0, 0.0
    valid_correct_sarcasm, valid_correct_sentiment, valid_correct_stance, valid_total_samples = 0,0,0,0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask, sarc_y, sent_y, stance_y = batch
            mt_input_ids = mt_input_ids.to(device)
            mt_attention_mask = mt_attention_mask.to(device)
            ct_input_ids = ct_input_ids.to(device)
            ct_attention_mask = ct_attention_mask.to(device)

            sarc_y = sarc_y.to(device)
            sent_y = sent_y.to(device)
            stance_y = stance_y.to(device)

            sarcasm_logits, sentiment_logits, stance_logits = model(mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask)

            sarcasm_loss = ce_loss(sarcasm_logits, sarc_y)
            sentiment_loss = ce_loss(sentiment_logits, sent_y)
            stance_loss = ce_loss(stance_logits, stance_y)

            valid_sarcasm_loss += sarcasm_loss.item()
            valid_sentiment_loss += sentiment_loss.item()
            valid_stance_loss += stance_loss.item()

            valid_correct_sarcasm += (sarcasm_logits.argmax(dim=1) == sarc_y).sum().item()
            valid_correct_sentiment += (sentiment_logits.argmax(dim=1) == sent_y).sum().item()
            valid_correct_stance += (stance_logits.argmax(dim=1) == stance_y).sum().item()
            valid_total_samples += mt_input_ids.size(0)
    
    avg_valid_sarcasm_loss = valid_sarcasm_loss / valid_total_samples
    avg_valid_sentiment_loss = valid_sentiment_loss / valid_total_samples
    avg_valid_stance_loss = valid_stance_loss / valid_total_samples
    
    valid_sarcasm_acc = valid_correct_sarcasm / valid_total_samples
    valid_sentiment_acc = valid_correct_sentiment / valid_total_samples
    valid_stance_acc = valid_correct_stance / valid_total_samples

    print(f"Sarcasm -> Loss: {avg_valid_sarcasm_loss:.4f}, Acc: {valid_sarcasm_acc:.4f}")
    print(f"Sentiment -> Loss: {avg_valid_sentiment_loss:.4f}, Acc: {valid_sentiment_acc:.4f}")
    print(f"Stance -> Loss: {avg_valid_stance_loss:.4f}, Acc: {valid_stance_acc:.4f}\n\n")
    
    lr_scheduler.step()

Epoch 1 || Learning Rate: [2e-05]


Training:   0%|          | 0/169 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0108, Acc: 0.9528
Sentiment -> Loss: 0.0512, Acc: 0.6417
Stance -> Loss: 0.0298, Acc: 0.7635


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0079, Acc: 0.9622
Sentiment -> Loss: 0.0445, Acc: 0.7164
Stance -> Loss: 0.0236, Acc: 0.8445


Epoch 2 || Learning Rate: [2e-05]


Training:   0%|          | 0/169 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0076, Acc: 0.9614
Sentiment -> Loss: 0.0383, Acc: 0.7438
Stance -> Loss: 0.0220, Acc: 0.8455


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0072, Acc: 0.9622
Sentiment -> Loss: 0.0415, Acc: 0.7374
Stance -> Loss: 0.0208, Acc: 0.8466


Epoch 3 || Learning Rate: [2e-05]


Training:   0%|          | 0/169 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0062, Acc: 0.9688
Sentiment -> Loss: 0.0277, Acc: 0.8377
Stance -> Loss: 0.0190, Acc: 0.8801


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0069, Acc: 0.9622
Sentiment -> Loss: 0.0443, Acc: 0.7227
Stance -> Loss: 0.0197, Acc: 0.8655


Epoch 4 || Learning Rate: [2e-05]


Training:   0%|          | 0/169 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0050, Acc: 0.9759
Sentiment -> Loss: 0.0189, Acc: 0.8986
Stance -> Loss: 0.0194, Acc: 0.8752


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Sarcasm -> Loss: 0.0069, Acc: 0.9622
Sentiment -> Loss: 0.0478, Acc: 0.7038
Stance -> Loss: 0.0196, Acc: 0.8634


Epoch 5 || Learning Rate: [2e-05]


Training:   0%|          | 0/169 [00:00<?, ?it/s]

KeyboardInterrupt: 

## TESTING