In [1]:
import os
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import re
import os
import string

import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp, global_add_pool as gaddp
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW
from arabert.preprocess import ArabertPreprocessor
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

import warnings

# Turn off Hugging Face tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Hide library warnings in the notebook output (this also hides deprecation notices).
warnings.filterwarnings("ignore")

<h2 style='color:red'>NOTE: if you are not training on macOS, change every occurrence of "mps" (Apple's MPS backend) to "cuda".</h2>

## SETTINGS

In [2]:
# Root of the project checkout; set the MAWQIF_ROOT environment variable to run elsewhere.
PROJECT_ROOT = os.environ.get("MAWQIF_ROOT", "/Users/gufran/Developer/Projects/AI/MawqifStanceDetection")

DATA_PATH = os.path.join(PROJECT_ROOT, "data")
LOGS_PATH = os.path.join(PROJECT_ROOT, "logs")
MODEL_PATH = os.path.join(PROJECT_ROOT, "models")

# exist_ok makes the cell idempotent and avoids the exists()-then-makedirs() race.
os.makedirs(LOGS_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

In [3]:
# Which Mawqif target to train on. "all" keeps every target; the single-target options are
# 'Covid Vaccine', 'Digital Transformation' and 'Women empowerment'.
target = "all"

In [4]:
# How the outputs of several BERT encoders are fused (only consulted when more than one
# encoder is listed in bert_models).
ensemble_settings = dict(
    ENSEMBLE_AVERAGE=0,
    ENSEMBLE_ATTENTION=1,
    ENSEMBLE_GRAPH=2,
)

ENSEMBLE_METHOD = ensemble_settings["ENSEMBLE_ATTENTION"]

In [5]:
# Strategies for weighting the three task losses (sarcasm, sentiment, stance); each name maps
# to its position in this list.
weighting_settings = {
    name: index
    for index, name in enumerate([
        "EQUAL_WEIGHTING",
        "STATIC_WEIGHTING",
        "RELATIVE_WEIGHTING",
        "HEIRARCHICAL_WEIGHTING",
        "PRIORITIZE_HIGH_CONFIDENCE_WEIGHTING",
        "PRIORITIZE_LOW_CONFIDENCE_WEIGHTING",
        "META_WEIGHTING",
    ])
}

WEIGHTING_METHOD = weighting_settings["EQUAL_WEIGHTING"]

In [6]:
# Hugging Face checkpoints used as encoders; listing more than one builds an ensemble.
# Other candidates: "aubmindlab/bert-base-arabertv02", "UBC-NLP/MARBERT",
# "CAMeL-Lab/bert-base-arabic-camelbert-da".
bert_models = ["aubmindlab/bert-base-arabertv02-twitter"]

In [7]:
# Auxiliary task whose head runs first in the chain: "sentiment" or "sarcasm".
# (The previous `first_task = "sarcasm"` line was immediately overwritten, i.e. dead.)
first_task = "sentiment"
assert first_task in ("sentiment", "sarcasm"), f"unknown first_task: {first_task!r}"

In [8]:
# Head architecture: pooled mode classifies the first-token vector directly; otherwise an
# LSTM (or GRU, optionally bidirectional) runs over all token states.
pool_bert_output = False
use_bi = True
use_gru = False

task_head_layer_size = 256  # passed to the model as subtask_hidden_layer_size
BERT_hidden_state_size = 768
batch_size = 32
num_epochs = 20
learning_rate = 2e-5
weight_decay = 1e-5
dropout = 0.1

# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the notebook
# also runs on machines without MPS instead of failing on the first .to(device).
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [9]:
seed = 42

random.seed(seed)
torch.manual_seed(seed)  # documented to seed the RNG on all devices
np.random.seed(seed)

# Backend-specific seeding is guarded so this cell runs on macOS (MPS), CUDA and CPU-only machines.
if torch.backends.mps.is_available():
    torch.mps.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# NOTE: torch.backends.mps has no `deterministic` / `benchmark` switches, so assigning them
# (as an earlier version did) had no effect. Use torch.use_deterministic_algorithms(True)
# if bitwise-reproducible runs are required.

## DATA LOADING AND PREPROCESSING

In [10]:
df = pd.read_csv(os.path.join(DATA_PATH, "Mawqif_AllTargets_Train.csv"))
df = df.dropna(subset=["stance"])

# Filter by target while the "target" column still exists: the column selection below drops
# it, so filtering afterwards raised AttributeError for any target other than "all".
if target != "all":
    df = df[df["target"] == target]

df = df.sample(frac=1).reset_index(drop=True)
df = df[["text", "sarcasm", "sentiment", "stance" , "sarcasm:confidence", "sentiment:confidence", "stance:confidence"]]

df.head()

Unnamed: 0,text,sarcasm,sentiment,stance,sarcasm:confidence,sentiment:confidence,stance:confidence
0,والوزارة كل يوم مغردين عن التحول الرقمي والت...,No,Neutral,Favor,1.0,1.0,0.3977
1,تمكين المرأة يعني فتح المجال لها للعمل في مجال...,No,Neutral,Favor,1.0,0.6667,1.0
2,MENTION تمكين المرأة السعودية ورحلة كفاحها هو ...,No,Positive,Favor,1.0,1.0,1.0
3,هذا الدراسه العالميه الحديثه التي تعاون و شارك...,No,Neutral,Favor,1.0,0.7029,0.7533
4,نحن نتعرض لحملة موافقة اجبارية ممنهجة😅 ماتلاحظ...,No,Neutral,Against,0.6187,0.6823,1.0


In [11]:
mapping_sarcasm = {"No": 0, "Yes": 1}
mapping_stance = {"Favor": 1, "Against": 0}
mapping_sentiment = {"Negative": 0, "Neutral": 1, "Positive": 2}

# Replace string labels with integer class ids; an unexpected label raises KeyError.
for column, mapping in (("sarcasm", mapping_sarcasm), ("sentiment", mapping_sentiment), ("stance", mapping_stance)):
    df[column] = df[column].map(lambda label, lookup=mapping: lookup[label])

In [12]:
# Characters stripped by remove_punctuations: Arabic-specific marks followed by ASCII punctuation.
arabic_punctuations = '''`÷×؛<>()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ'''
english_punctuations = string.punctuation
punctuations_list = "".join([arabic_punctuations, english_punctuations])

def remove_hash_URL_MEN(text):
    """Replace '#' and '_' with spaces and delete the literal URL / MENTION tokens."""
    substitutions = (('#', ' '), ('_', ' '), ('URL', ''), ('MENTION', ''))
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text

def normalize_arabic(text):
    """Map the hamza/madda alef variants to bare alef and gaf to kaf."""
    for pattern, replacement in (("[إآ]", "ا"), ("گ", "ك")):
        text = re.sub(pattern, replacement, text)
    return text

def remove_punctuations(text):
    """Delete every character that appears in punctuations_list."""
    deletions = str.maketrans({char: None for char in punctuations_list})
    return text.translate(deletions)

def remove_repeating_char(text):
    """Collapse each run of a repeated character to one occurrence ('.' does not match newlines)."""
    return re.sub(r'(.)\1+', lambda match: match.group(1), text)

def process_tweet(tweet):
    """Clean one raw tweet before AraBERT preprocessing.

    Steps: coerce to str, drop @handles and web links, blank '#'/'_' and remove the literal
    URL / MENTION tokens (remove_hash_URL_MEN), then normalise alef/gaf variants.
    """
    tweet = str(tweet)  # coerce first, so a non-string cell (e.g. NaN) does not break re.sub
    # Strip handles and links BEFORE '_' / '#' are turned into spaces; otherwise
    # "@user_name" leaves "name" behind and "https://x.com/a_b" leaves "b".
    tweet = re.sub(r'@[^\s]+', ' ', tweet)
    tweet = re.sub(r'((www\.[^\s]+)|(https?://[^\s]+))', ' ', tweet)
    tweet = remove_hash_URL_MEN(tweet)
    return normalize_arabic(tweet)

arabert_prep = ArabertPreprocessor(model_name="aubmindlab/bert-base-arabertv02-twitter")
# Custom cleaning first, then AraBERT's own preprocessing for the twitter checkpoint.
df.text = df.text.apply(process_tweet).apply(arabert_prep.preprocess)

In [13]:
df.head(n=5)  # sanity check: cleaned text and integer-encoded labels

Unnamed: 0,text,sarcasm,sentiment,stance,sarcasm:confidence,sentiment:confidence,stance:confidence
0,والوزارة كل يوم مغردين عن التحول الرقمي والتعل...,0,1,1,1.0,1.0,0.3977
1,تمكين المرأة يعني فتح المجال لها للعمل في مجال...,0,1,1,1.0,0.6667,1.0
2,تمكين المرأة السعودية ورحلة كفاحها هو الانتصار,0,2,1,1.0,1.0,1.0
3,هذا الدراسه العالميه الحديثه التي تعاون و شارك...,0,1,1,1.0,0.7029,0.7533
4,نحن نتعرض لحملة موافقة اجبارية ممنهجة 😅 ماتلاح...,0,1,0,0.6187,0.6823,1.0


In [14]:
# Inputs: tweet text plus the per-task confidence scores (carried along for loss weighting).
X = df.loc[:, ["text", "sarcasm:confidence", "sentiment:confidence", "stance:confidence"]]
# Targets: integer-encoded labels for the three tasks.
y = df.loc[:, ["sarcasm", "sentiment", "stance"]]

In [15]:
# Reuse the notebook-wide seed (SETTINGS) so the split follows the configured seed.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=seed)
print("X_train shape:", X_train.shape, "y_train shape:", y_train.shape, "X_test shape:", X_test.shape, "y_test shape:", y_test.shape)

X_train shape: (2693, 4) y_train shape: (2693, 3) X_test shape: (476, 4) y_test shape: (476, 3)


In [16]:
def encode_text(text, tokenizer):
    """Tokenize one tweet into PyTorch tensors padded/truncated to exactly 128 tokens."""
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )


encoded_tweets_train, encoded_tweets_test = [], []
for bert_model in bert_models:
    # Each encoder has its own vocabulary, so both splits are tokenized once per model.
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    for split_texts, destination in ((X_train["text"], encoded_tweets_train), (X_test["text"], encoded_tweets_test)):
        destination.append([encode_text(text, tokenizer) for text in split_texts])

In [17]:
def data_generator(encoded_tweets, confidences, labels):
    """Reshape encoded tweets, labels and confidence scores into parallel Python lists.

    Args:
        encoded_tweets: one list of per-tweet encodings for each BERT model.
        confidences: DataFrame with "sarcasm:confidence", "sentiment:confidence" and
            "stance:confidence" columns, row-aligned with ``labels``.
        labels: DataFrame with "sarcasm", "sentiment" and "stance" columns.

    Returns:
        (tweets, sentiments, sarcasms, stances,
         sentiment_confidence, sarcasm_confidence, stance_confidence)
        Note that the confidences come back in sentiment, sarcasm, stance order.
    """
    tweets = [
        [model_encodings[k] for k in range(len(model_encodings))]
        for model_encodings in encoded_tweets
    ]

    row_count = len(labels)

    def take(frame, column):
        # Positional access, one entry per label row.
        values = frame[column]
        return [values.iloc[row] for row in range(row_count)]

    sentiments = take(labels, "sentiment")
    sarcasms = take(labels, "sarcasm")
    stances = take(labels, "stance")
    sarcasm_confidence = take(confidences, "sarcasm:confidence")
    sentiment_confidence = take(confidences, "sentiment:confidence")
    stance_confidence = take(confidences, "stance:confidence")

    return tweets, sentiments, sarcasms, stances, sentiment_confidence, sarcasm_confidence, stance_confidence

In [18]:
CONFIDENCE_COLUMNS = ["sarcasm:confidence", "sentiment:confidence", "stance:confidence"]


def make_tensor_dataset(tweets, sentiments, sarcasms, stances, confidences):
    """Pack data_generator output into a TensorDataset.

    Tensor order: input_ids per encoder, attention_mask per encoder, the three confidence
    lists in data_generator's order (sentiment, sarcasm, stance), then the sarcasm,
    sentiment and stance labels. The training loop unpacks batches in exactly this order.
    """
    return TensorDataset(
        *[torch.cat([item["input_ids"] for item in enc_tweets]) for enc_tweets in tweets],
        *[torch.cat([item["attention_mask"] for item in enc_tweets]) for enc_tweets in tweets],
        torch.tensor(confidences[0]),
        torch.tensor(confidences[1]),
        torch.tensor(confidences[2]),
        torch.tensor(sarcasms),
        torch.tensor(sentiments),
        torch.tensor(stances),
    )


tweets, sentiments, sarcasms, stances, *confidences = data_generator(encoded_tweets_train, X_train[CONFIDENCE_COLUMNS], y_train)
train_dataset = make_tensor_dataset(tweets, sentiments, sarcasms, stances, confidences)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

tweets, sentiments, sarcasms, stances, *confidences = data_generator(encoded_tweets_test, X_test[CONFIDENCE_COLUMNS], y_test)
val_dataset = make_tensor_dataset(tweets, sentiments, sarcasms, stances, confidences)
# Validation metrics do not depend on order, so keep validation batches deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [19]:
# Number of distinct classes per task (NaN would count as a class, as with len(unique())).
num_sarcasm_labels, num_sentiment_labels, num_stance_labels = (
    df[column].nunique(dropna=False) for column in ("sarcasm", "sentiment", "stance")
)

num_sarcasm_labels, num_sentiment_labels, num_stance_labels

(2, 3, 2)

---

## MODEL ARCHITECTURE

In [20]:
class TaskHead(nn.Module):
    """Classification head for one task in the sequential multi-task chain.

    In sequence mode (``pool_bert_output=False``) an LSTM or GRU (optionally bidirectional)
    runs over the token states, and a linear layer classifies a summary vector taken from
    its output. In pooled mode the input is already one vector per example and goes
    straight to the linear layer.

    Args:
        num_labels: number of output classes.
        hidden_size: width of the BERT states fed to the head; also the RNN hidden size.
        is_first_head: True for the head that sees only BERT states. Later heads receive
            BERT states concatenated with the previous head's features, so their input is wider.
        pool_bert_output: the input is one vector per example instead of a token sequence.
        use_bi: bidirectional RNN (doubles the RNN feature width).
        use_gru: GRU instead of LSTM.
        dropout_p: dropout probability; ``None`` falls back to the notebook-level ``dropout``.

    ``forward(inputs)`` returns ``(features, class_scores)``. ``features`` is the full RNN
    output sequence (sequence mode) or the dropped-out input vector (pooled mode); it is what
    the next head in the chain is concatenated with.
    """

    def __init__(self, num_labels, hidden_size, is_first_head=False, pool_bert_output=False, use_bi=False, use_gru=False, dropout_p=None):
        super(TaskHead, self).__init__()
        multiplier_bi = 2 if use_bi else 1

        self.num_labels = num_labels
        self.pool_bert_output = pool_bert_output
        self.hidden_size = hidden_size
        self.use_bi = use_bi
        self.dropout_p = dropout_p

        if pool_bert_output:
            # First head: the pooled BERT vector. Later heads: BERT vector + the first head's
            # features, which have the same width. The RNN flags do not apply in this mode.
            # NOTE(review): a head fed by a *non-first* pooled head would receive 3 * hidden_size.
            classifier_in = hidden_size if is_first_head else 2 * hidden_size
        else:
            # Later heads get BERT states (hidden_size) + the previous head's RNN output
            # (hidden_size * multiplier_bi); a flat 2 * hidden_size was wrong when use_bi=True.
            rnn_in = hidden_size if is_first_head else hidden_size + hidden_size * multiplier_bi
            rnn_cls = nn.GRU if use_gru else nn.LSTM
            self.sequential_model = rnn_cls(rnn_in, hidden_size, batch_first=True, bidirectional=use_bi)
            classifier_in = hidden_size * multiplier_bi

        self.classifier = nn.Linear(classifier_in, num_labels)

    def forward(self, inputs):
        if self.pool_bert_output:
            summary = inputs
        else:
            inputs, _ = self.sequential_model(inputs)
            summary = self._summarize(inputs)

        p = dropout if self.dropout_p is None else self.dropout_p
        summary = F.dropout(summary, p=p, training=self.training)
        res = self.classifier(summary)

        return (summary if self.pool_bert_output else inputs), res

    def _summarize(self, rnn_output):
        """Collapse a (batch, seq, features) RNN output to one vector per example.

        The forward direction is read at the last step and the backward direction at the
        first step, because those are the positions where each has consumed the whole
        sequence. The backward direction's last step has only seen the final token
        (with max_length padding, typically a pad token).
        """
        if not self.use_bi:
            return rnn_output[:, -1, :]
        forward_last = rnn_output[:, -1, :self.hidden_size]
        backward_first = rnn_output[:, 0, self.hidden_size:]
        return torch.cat([forward_last, backward_first], dim=-1)

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

In [21]:
class SequentialMultiTaskModel(nn.Module):
    """Sarcasm / sentiment / stance model built from BERT encoders and chained task heads.

    Every encoder in ``bert_models`` embeds the tweet. With more than one encoder the outputs
    are fused according to ``combination_method`` (average, attention, or a GCN over the
    encoders). The head named by the global ``first_task`` runs on the fused embedding. The
    second auxiliary head receives the fused embedding concatenated with the first head's
    features. The stance head receives the fused embedding concatenated with the sentiment
    head's features.

    Args:
        bert_models: Hugging Face checkpoint names, loaded with ``AutoModel.from_pretrained``.
        sentiment_head, sarcasm_head, stance_head: ``TaskHead`` modules for each task; the
            head that runs first must be built with ``is_first_head=True``.
        subtask_hidden_layer_size: kept as an attribute; no layer currently depends on it.
        combination_method: a value of ``ensemble_settings`` (only used with >1 encoder).
        weighting_method: a value of ``weighting_settings``; ``None`` means EQUAL_WEIGHTING.
        pool_bert_output: use the first-token vector instead of the whole token sequence.

    ``self.loss_weight`` holds the (sarcasm, sentiment, stance) loss multipliers. The
    training loop reads it; the ``update_*_weights`` methods and META_WEIGHTING update it.
    """

    def __init__(self, bert_models, sentiment_head, sarcasm_head, stance_head, subtask_hidden_layer_size, combination_method=None, weighting_method=None, pool_bert_output=False):
        super(SequentialMultiTaskModel, self).__init__()
        # Resolved here rather than in the signature, so defining the class does not need
        # the settings dict to exist yet.
        if weighting_method is None:
            weighting_method = weighting_settings["EQUAL_WEIGHTING"]

        self.pool_bert_output = pool_bert_output
        self.combination_method = combination_method
        self.bert_models = nn.ModuleList([AutoModel.from_pretrained(bert).to(device) for bert in bert_models])

        self.sentiment_head = sentiment_head
        self.sarcasm_head = sarcasm_head
        self.stance_head = stance_head

        self.weighting_method = weighting_method
        self.subtask_hidden_layer_size = subtask_hidden_layer_size

        self.intialize_ensemble_specific_components()
        self.initialize_loss_specific_components()

    def intialize_ensemble_specific_components(self):
        """Create the fusion layer needed by the configured ensemble method."""
        if self.combination_method == ensemble_settings["ENSEMBLE_GRAPH"]:
            graph_hidden_size = self.get_hidden_in()
            self.conv1 = GCNConv(graph_hidden_size, graph_hidden_size)

        if self.combination_method == ensemble_settings["ENSEMBLE_ATTENTION"]:
            self.attention_weights = nn.Linear(self.get_hidden_in(), 1)

    def initialize_loss_specific_components(self):
        """Set the initial (sarcasm, sentiment, stance) loss weights for the weighting method."""
        self.loss_weight = [1, 1, 1]
        if self.weighting_method == weighting_settings["STATIC_WEIGHTING"]:
            self.loss_weight = [0.1, 0.3, 0.6]
        elif self.weighting_method == weighting_settings["RELATIVE_WEIGHTING"]:
            self.loss_weight = [1/3, 1/2, 1]
        elif self.weighting_method == weighting_settings["META_WEIGHTING"]:
            # Applied to the fused embedding, so its input width is the encoder width
            # (not subtask_hidden_layer_size, which did not match).
            self.meta_weight_layer = nn.Linear(self.get_hidden_in(), 3)

    def get_hidden_in(self):
        """Width of the fused embedding; every fusion method preserves the encoder width."""
        return self.bert_models[0].config.hidden_size

    def average(self, bert_outputs):
        """Element-wise mean of the encoders' outputs."""
        return torch.mean(torch.stack(bert_outputs), dim=0)

    def attention_aggregate(self, bert_outputs):
        """Attention-weighted sum of the encoders' outputs.

        A linear scorer assigns each encoder output a score; the scores are softmax-normalised
        across encoders. The weighted outputs are then summed. An earlier version averaged
        them, which shrank the result by a factor of the number of encoders.
        """
        stacked = torch.stack(bert_outputs, dim=0)  # (num_models, *output_shape)
        weights = F.softmax(self.attention_weights(stacked), dim=0)
        return (stacked * weights).sum(dim=0)

    def graph_aggregate(self, bert_outputs):
        """Fuse encoder outputs with one GCN layer over a fully connected graph of encoders.

        Every encoder is a node and the graph is identical for every example, so a single
        edge_index is shared. Batch (and, for unpooled outputs, sequence) dimensions are
        carried as leading dimensions of the node-feature tensor. The original version
        concatenated per-example edge lists without offsets, so each edge appeared once
        per example.
        """
        num_models = len(bert_outputs)
        node_features = torch.stack(bert_outputs, dim=-2)  # (..., num_models, hidden)
        edges = [[src, dst] for src in range(num_models) for dst in range(num_models) if src != dst]
        edge_index = torch.tensor(edges, dtype=torch.long, device=node_features.device).t().contiguous()

        x = torch.tanh(self.conv1(node_features, edge_index))
        x = F.dropout(x, p=dropout, training=self.training)
        return x.mean(dim=-2)  # mean over the encoder nodes

    def update_hw_weights(self, sarcasm_loss, sentiment_loss):
        """HEIRARCHICAL_WEIGHTING: rescale the stance weight by the sarcasm/sentiment loss ratio, clamped to [0.01, 1]."""
        if self.weighting_method != weighting_settings["HEIRARCHICAL_WEIGHTING"]:
            return
        sentiment_value = sentiment_loss.item()
        if sentiment_value == 0:
            return  # ratio undefined; keep the current stance weight
        ratio = sarcasm_loss.item() / sentiment_value
        self.loss_weight[2] = max(min(ratio * self.loss_weight[2], 1), 0.01)

    def update_conf_weights(self, sent_conf, sarc_conf, stance_conf, epoch):
        """Confidence-based weighting: the boost decays as 1 / (epoch * 0.5).

        PRIORITIZE_LOW_CONFIDENCE_WEIGHTING boosts tasks whose batch confidence is low;
        PRIORITIZE_HIGH_CONFIDENCE_WEIGHTING boosts tasks whose confidence is high.
        ``epoch`` must be >= 1.
        """
        mean_conf = [float(conf.double().mean()) for conf in (sarc_conf, sent_conf, stance_conf)]
        scale = epoch * 0.5

        if self.weighting_method == weighting_settings["PRIORITIZE_LOW_CONFIDENCE_WEIGHTING"]:
            self.loss_weight[:] = [1 + (1 - conf) / scale for conf in mean_conf]
        elif self.weighting_method == weighting_settings["PRIORITIZE_HIGH_CONFIDENCE_WEIGHTING"]:
            self.loss_weight[:] = [1 + conf / scale for conf in mean_conf]

    def _combine_outputs(self, outputs):
        """Fuse per-encoder outputs; with a single encoder its output is returned unchanged."""
        if len(self.bert_models) > 1:
            if self.combination_method == ensemble_settings["ENSEMBLE_AVERAGE"]:
                return self.average(outputs)
            if self.combination_method == ensemble_settings["ENSEMBLE_ATTENTION"]:
                return self.attention_aggregate(outputs)
            if self.combination_method == ensemble_settings["ENSEMBLE_GRAPH"]:
                return self.graph_aggregate(outputs)
        return outputs[0]

    def forward(self, input_ids, attention_mask):
        """Return ``(sarcasm_pred, sentiment_pred, stance_pred)`` class scores.

        The order is the same whatever ``first_task`` is. ``input_ids`` and
        ``attention_mask`` are lists with one tensor per encoder, in ``bert_models`` order.
        """
        outputs = []
        for i, bert in enumerate(self.bert_models):
            hidden = bert(input_ids=input_ids[i], attention_mask=attention_mask[i]).last_hidden_state
            outputs.append(hidden[:, 0, :] if self.pool_bert_output else hidden)

        combined_emb = self._combine_outputs(outputs)

        if first_task == "sarcasm":
            sarcasm_features, sarcasm_pred = self.sarcasm_head(combined_emb)
            sentiment_features, sentiment_pred = self.sentiment_head(torch.cat([combined_emb, sarcasm_features], dim=-1))
        else:
            sentiment_features, sentiment_pred = self.sentiment_head(combined_emb)
            _, sarcasm_pred = self.sarcasm_head(torch.cat([combined_emb, sentiment_features], dim=-1))

        # Stance always conditions on the sentiment head's features.
        # NOTE(review): in the sentiment-first chain the sarcasm features are not passed on,
        # and with pool_bert_output=True + sarcasm first the stance input is 3 * hidden wide
        # (TaskHead expects 2 * hidden) — confirm both are intended.
        _, stance_pred = self.stance_head(torch.cat([combined_emb, sentiment_features], dim=-1))

        if self.weighting_method == weighting_settings["META_WEIGHTING"]:
            meta_weights = F.softmax(self.meta_weight_layer(combined_emb), dim=-1)
            # Average over every leading (batch / sequence) position -> one weight per task.
            # NOTE(review): .tolist() detaches, so meta_weight_layer never receives gradients.
            self.loss_weight = meta_weights.reshape(-1, 3).mean(dim=0).tolist()

        return sarcasm_pred, sentiment_pred, stance_pred

## TRAINING & VALIDATION

In [22]:
class LinearDecayLR(_LRScheduler):
    """Hold each parameter group's base LR, then decay it linearly to zero.

    While ``last_epoch <= start_decay`` every group uses its initial LR. Afterwards the LR
    drops by ``base_lr / (n_epoch - start_decay)`` per epoch and reaches 0 at ``n_epoch``.
    It is clamped at 0, never negative, if the scheduler is stepped further.
    """

    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= self.start_decay:
            return list(self.base_lrs)
        decay_epochs = self.n_epoch - self.start_decay
        if decay_epochs <= 0:
            # Past the end of a zero-length decay window: nothing left to decay.
            return [0.0 for _ in self.base_lrs]
        elapsed = self.last_epoch - self.start_decay
        # One LR per parameter group (the old version returned only group 0's, so the
        # other groups were never updated).
        return [max(0.0, base_lr - base_lr / decay_epochs * elapsed) for base_lr in self.base_lrs]

In [23]:
head_kwargs = dict(pool_bert_output=pool_bert_output, use_bi=use_bi, use_gru=use_gru)

# Build every head for its own task. Only the head that runs first (see `first_task`) reads
# the raw BERT states. Passing heads by position previously put the sarcasm head into the
# sentiment_head slot when first_task == "sarcasm".
sentiment_head = TaskHead(num_sentiment_labels, BERT_hidden_state_size, is_first_head=(first_task == "sentiment"), **head_kwargs)
sarcasm_head = TaskHead(num_sarcasm_labels, BERT_hidden_state_size, is_first_head=(first_task == "sarcasm"), **head_kwargs)
stance_head = TaskHead(num_stance_labels, BERT_hidden_state_size, **head_kwargs)

model = SequentialMultiTaskModel(
    bert_models,
    sentiment_head,
    sarcasm_head,
    stance_head,
    task_head_layer_size,
    ENSEMBLE_METHOD,
    WEIGHTING_METHOD,
    pool_bert_output=pool_bert_output,
).to(device)

# NOTE(review): transformers.AdamW is deprecated upstream in favour of torch.optim.AdamW;
# their defaults (e.g. eps) differ, so check before swapping.
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
ce_loss = nn.CrossEntropyLoss()
# Constant LR for the first half of training, then linear decay to zero.
lr_scheduler = LinearDecayLR(optimizer, num_epochs, int(num_epochs * 0.5))

Some weights of BertModel were not initialized from the model checkpoint at aubmindlab/bert-base-arabertv02-twitter and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
def _setting_name(settings, value):
    """Reverse-lookup the first key of `settings` whose value equals `value`."""
    return list(settings.keys())[list(settings.values()).index(value)]


for label, value in (
    ("BERT Models:", bert_models),
    ("Task Head Layer Size:", task_head_layer_size),
    ("Batch Size:", batch_size),
    ("Num Epochs:", num_epochs),
    ("Learning Rate:", learning_rate),
    ("Weight Decay:", weight_decay),
    ("Dropout:", dropout),
    ("Device:", device),
    ("Seed:", seed),
):
    print(label, value)
print("Ensemble Method:", _setting_name(ensemble_settings, ENSEMBLE_METHOD))
print("Weighting Method:", _setting_name(weighting_settings, WEIGHTING_METHOD))
for label, value in (("Pool BERT Output:", pool_bert_output), ("Use Bi:", use_bi), ("Use GRU:", use_gru)):
    print(label, value)

BERT Models: ['aubmindlab/bert-base-arabertv02-twitter']
Task Head Layer Size: 256
Batch Size: 32
Num Epochs: 20
Learning Rate: 2e-05
Weight Decay: 1e-05
Dropout: 0.1
Device: mps
Seed: 42
Ensemble Method: ENSEMBLE_ATTENTION
Weighting Method: EQUAL_WEIGHTING
Pool BERT Output: False
Use Bi: True
Use GRU: False


In [25]:
CONF_WEIGHTING_METHODS = (
    weighting_settings["PRIORITIZE_HIGH_CONFIDENCE_WEIGHTING"],
    weighting_settings["PRIORITIZE_LOW_CONFIDENCE_WEIGHTING"],
)
TASK_NAMES = ("Sarcasm", "Sentiment", "Stance")


def _move_batch_to_device(batch):
    """Split one TensorDataset batch into model inputs, confidences and targets.

    Returns (input_ids, attention_mask, confidences, targets):
    - input_ids / attention_mask are per-encoder lists on `device`;
    - confidences is (sent_conf, sarc_conf, stance_conf), left on the CPU;
    - targets is (sarc_y, sent_y, stance_y) on `device`.
    """
    *tweet_data, sent_conf, sarc_conf, stance_conf, sarc_y, sent_y, stance_y = batch
    num_encoders = len(tweet_data) // 2
    input_ids = [ids.to(device) for ids in tweet_data[:num_encoders]]
    attention_mask = [mask.to(device) for mask in tweet_data[num_encoders:]]
    targets = (sarc_y.to(device), sent_y.to(device), stance_y.to(device))
    return input_ids, attention_mask, (sent_conf, sarc_conf, stance_conf), targets


def _run_epoch(loader, epoch, training):
    """One pass over `loader`; returns (mean_losses, accuracies), each ordered sarcasm, sentiment, stance.

    Reported losses are the unweighted cross-entropies averaged per sample, so training and
    validation numbers are comparable. The loss-weighted sum is only used for backprop.
    """
    model.train(training)
    loss_sums = [0.0, 0.0, 0.0]
    correct = [0, 0, 0]
    total_samples = 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc="Training" if training else "Validation"):
            input_ids, attention_mask, (sent_conf, sarc_conf, stance_conf), targets = _move_batch_to_device(batch)

            if training:
                optimizer.zero_grad()

            # The model always returns (sarcasm, sentiment, stance) scores, whatever
            # first_task is. The old sentiment-first unpacking swapped sarcasm and sentiment.
            logits = model(input_ids, attention_mask)

            if training and WEIGHTING_METHOD in CONF_WEIGHTING_METHODS:
                model.update_conf_weights(sent_conf, sarc_conf, stance_conf, epoch + 1)

            losses = [ce_loss(task_logits, task_y) for task_logits, task_y in zip(logits, targets)]

            if training:
                weighted = [loss * weight for loss, weight in zip(losses, model.loss_weight)]
                total_loss = weighted[0] + weighted[1] + weighted[2]
                total_loss.backward()
                optimizer.step()
                if WEIGHTING_METHOD == weighting_settings["HEIRARCHICAL_WEIGHTING"]:
                    model.update_hw_weights(weighted[0], weighted[1])

            batch_len = targets[0].size(0)
            total_samples += batch_len
            for task, (task_logits, task_y, loss) in enumerate(zip(logits, targets, losses)):
                # ce_loss is a batch mean; scale by batch size so dividing by total_samples gives a per-sample mean.
                loss_sums[task] += loss.item() * batch_len
                correct[task] += (task_logits.argmax(dim=1) == task_y).sum().item()

    denominator = max(total_samples, 1)
    return [s / denominator for s in loss_sums], [c / denominator for c in correct]


def _print_metrics(losses, accuracies, trailer):
    """Print one 'Task -> Loss, Acc' line per task, followed by `trailer`."""
    lines = [f"{name} -> Loss: {loss:.4f}, Acc: {acc:.4f}" for name, loss, acc in zip(TASK_NAMES, losses, accuracies)]
    print("\n".join(lines) + trailer)


for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}] || Learning Rate: {lr_scheduler.get_last_lr()} ----------------------------------------------")

    # TRAINING -------------------------------------------------------------------------------
    train_losses, train_accs = _run_epoch(train_loader, epoch, training=True)
    _print_metrics(train_losses, train_accs, "\n")

    # VALIDATION -------------------------------------------------------------------------------
    valid_losses, valid_accs = _run_epoch(val_loader, epoch, training=False)
    _print_metrics(valid_losses, valid_accs, "\n\n")

    lr_scheduler.step()

Epoch [1/20] || Learning Rate: [2e-05] ----------------------------------------------


Training:   0%|          | 0/85 [00:00<?, ?it/s]

: 

## TESTING