In [1]:
import os
import pandas as pd
import numpy as np
import random
from tqdm.auto import tqdm
import pickle
import re
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import datasets
from datasets import ClassLabel, load_dataset, Dataset, DatasetDict
import string
from typing import Dict ,List
import transformers

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW
from transformers import BertTokenizerFast
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from arabert.preprocess import ArabertPreprocessor
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

import warnings

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## SETTINGS

In [2]:
DATA_PATH = "/Users/gufran/Developer/Projects/AI/MawqifStanceDetection/data"
MODEL_PATH = "/Users/gufran/Developer/Projects/AI/MawqifStanceDetection/models"

In [3]:
model_name = "aubmindlab/bert-base-arabertv02-twitter"
# model_name = "aubmindlab/bert-base-arabertv02"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert = AutoModel.from_pretrained(model_name)

Some weights of BertModel were not initialized from the model checkpoint at aubmindlab/bert-base-arabertv02-twitter and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
batch_size = 32
num_epochs = 20
learning_rate = 1e-5
device = "mps"

## DATA LOADING AND PREPROCESSING

In [5]:
df = pd.read_csv(os.path.join(DATA_PATH, "cleaned.csv"))
df = df.dropna(subset=["stance"])
df = df.sample(frac=1).reset_index(drop=True)
df = df[["text", "sarcasm", "sentiment", "stance"]]
df.head()

Unnamed: 0,text,sarcasm,sentiment,stance
0,اكيد واكبر دليل التحول الإلكتروني الي حاصل وا...,No,Neutral,Favor
1,#لا_للتطعيم_الاجباري56 شفت حق الصحة في سناب ي...,No,Neutral,Against
2,"شكرا"" لشركة التحول التقني على مبادراتها وبر...",No,Negative,Favor
3,نصيحتي لطلبة الشهادة الثانوية انهم يقرو البرمج...,No,Positive,Favor
4,ولذلك نتمنى الاهتمام بضرورة التحول الرقمي السر...,No,Positive,Favor


In [6]:
mapping_sarcasm = {"No": 0, "Yes": 1}
df['sarcasm'] = df['sarcasm'].map(lambda x: mapping_sarcasm[x])

mapping_sentiment = {"Negative": 0, "Neutral": 1, "Positive": 2}
df['sentiment'] = df['sentiment'].map(lambda x: mapping_sentiment[x])

mapping_stance = {"Favor": 1, "Against": 0}
df['stance'] = df['stance'].map(lambda x: mapping_stance[x])

In [7]:
arabic_punctuations = '''`÷×؛<>()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ'''
english_punctuations = string.punctuation
punctuations_list = arabic_punctuations + english_punctuations

def remove_hash_URL_MEN(text):
    text = re.sub(r'#',' ',text)
    text = re.sub(r'_',' ',text)
    text = re.sub(r'URL','',text)
    text = re.sub(r'MENTION','',text)
    return text

def normalize_arabic(text):
    text = re.sub("[إآ]", "ا", text)
    text = re.sub("گ", "ك", text)
    return text

def remove_punctuations(text):
    translator = str.maketrans('', '', punctuations_list)
    return text.translate(translator)

def remove_repeating_char(text):
    return re.sub(r'(.)\1+', r'\1', text)

def process_tweet(tweet):     
    tweet=remove_hash_URL_MEN(tweet)
    tweet = re.sub('@[^\s]+', ' ', str(tweet))
    tweet = re.sub('((www\.[^\s]+)|(https?://[^\s]+))',' ',str(tweet))    
    tweet= normalize_arabic(str(tweet))
    
    return tweet

arabert_prep = ArabertPreprocessor(model_name=model_name)
df.text = df.text.apply(lambda x: process_tweet(x))
df.text = df.text.apply(lambda x: arabert_prep.preprocess(x))

In [8]:
df.head()

Unnamed: 0,text,sarcasm,sentiment,stance
0,اكيد واكبر دليل التحول الالكتروني الي حاصل وال...,0,1,1
1,لا للتطعيم الاجباري 56 شفت حق الصحة في سناب يق...,0,1,0
2,"شكرا "" لشركة التحول التقني على مبادراتها وبرام...",0,0,1
3,نصيحتي لطلبة الشهادة الثانوية انهم يقرو البرمج...,0,2,1
4,ولذلك نتمنى الاهتمام بضرورة التحول الرقمي السر...,0,2,1


In [9]:
X = df[["text"]]
y = df[["sentiment", "sarcasm", "stance"]]

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
print("X_train shape:", X_train.shape, "y_train shape:", y_train.shape, "X_test shape:", X_test.shape, "y_test shape:", y_test.shape)

X_train shape: (2693, 1) y_train shape: (2693, 3) X_test shape: (476, 1) y_test shape: (476, 3)


In [11]:
x_train_extra = len(X_train) % batch_size
x_test_extra = len(X_test) % batch_size

if x_train_extra > 0:
    indices_to_drop = random.sample(X_train.index.tolist(), x_train_extra)
    X_train = X_train.drop(indices_to_drop)
    y_train = y_train.drop(indices_to_drop)
    
if x_test_extra > 0:
    indices_to_drop = random.sample(X_test.index.tolist(), x_test_extra)
    X_test = X_test.drop(indices_to_drop)
    y_test = y_test.drop(indices_to_drop)

print("Train dataframe length divisible by batch_size:", len(X_train) % batch_size == 0)
print("Test dataframe length divisible by batch_size:", len(X_test) % batch_size == 0)
print("X_train shape:", X_train.shape, "y_train shape:", y_train.shape, "X_test shape:", X_test.shape, "y_test shape:", y_test.shape)

Train dataframe length divisible by batch_size: True
Test dataframe length divisible by batch_size: True
X_train shape: (2688, 1) y_train shape: (2688, 3) X_test shape: (448, 1) y_test shape: (448, 3)


In [12]:
def encode_text(text):
    return tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

encoded_tweets_train = [encode_text(text) for text in X_train["text"]]
encoded_tweets_test = [encode_text(text) for text in X_test["text"]]

In [13]:
len(encoded_tweets_train)

2688

In [14]:
def data_generator(encoded_tweets, labels):
    len_encoded_tweets = len(encoded_tweets)
    num_batches = len_encoded_tweets//batch_size
    
    for b in range(num_batches):
        main_tweets, context_tweets, sentiments, sarcasms, stances, similarities = [], [], [], [], [], []
        for i in range(batch_size*b, (batch_size*b) + batch_size):
            context_tweet_index = random.randint(0, len_encoded_tweets-1)
            main_tweets.append(encoded_tweets[i])
            context_tweets.append(encoded_tweets[context_tweet_index])
            
            mt_sent = labels.sentiment.iloc[i]
            mt_sarc = labels.sarcasm.iloc[i]
            ct_sent = labels.sentiment.iloc[context_tweet_index]
            ct_sarc = labels.sarcasm.iloc[context_tweet_index]
            similarity = (int(mt_sent == ct_sent) + int(mt_sarc == ct_sarc))/2
            
            sentiments.append(labels.sentiment.iloc[i])
            sarcasms.append(labels.sarcasm.iloc[i])
            stances.append(labels.stance.iloc[i])
            similarities.append(similarity)
            
        yield main_tweets, context_tweets, sentiments, sarcasms, stances, similarities

In [15]:
num_sarcasm_labels = len(df.sarcasm.unique())
num_sentiment_labels = len(df.sentiment.unique())
num_stance_labels = len(df.stance.unique())

num_sarcasm_labels, num_sentiment_labels, num_stance_labels

(2, 3, 2)

---

## MODEL ARCHITECTURE

In [16]:
class SubTaskHead(nn.Module):
    def __init__(self, num_labels, dropout, hidden_size):
        super(SubTaskHead, self).__init__()
        
        self.num_labels = num_labels
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        
    def forward(self, inputs):
        logits = self.dropout(inputs)
        logits = self.classifier(logits)
        
        return logits

In [17]:
class SimilarityHead(nn.Module):
    def __init__(self):
        super(SimilarityHead, self).__init__()
        
    def forward(self, inputs):
        embedding1, embedding2 = inputs

        cosine_sim = F.cosine_similarity(embedding1, embedding2, dim=1)        
        scaled_similarity = 0.5 * (cosine_sim + 1)
        
        return scaled_similarity

In [18]:
class StanceHead(torch.nn.Module):
    def __init__(self, last_hidden_state_size, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(last_hidden_state_size, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = nn.Linear(hidden_channels, num_classes)

    def forward(self, mt_last_hidden_state, sentiments, sarcasms, similarities):
        node_features = torch.tensor(mt_last_hidden_state, dtype=torch.float)
        
        edges = []
        edge_features = []
        for i in range(len(sentiments)):
            for j in range(i+1, len(sentiments)):
                if sentiments[i] == sentiments[j] or sarcasms[i] == sarcasms[j]:
                    edges.append((i, j))
                    if sentiments[i] == sentiments[j] and sarcasms[i] == sarcasms[j]:
                        edge_features.append(1.0)
                    else:
                        edge_features.append(0.5)
        
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr).to(device)
        
        x = self.conv1(data.x, data.edge_index, data.edge_attr)
        x = x.relu()
        x = F.dropout(x, p=0.3, training=self.training)
        
#         x = self.conv2(x, data.edge_index, data.edge_attr)
#         x = x.relu()
#         x = F.dropout(x, p=0.3, training=self.training)

#         x = self.conv3(x, data.edge_index, data.edge_attr)
#         x = x.relu()
#         x = F.dropout(x, p=0.3, training=self.training)

        x = self.classifier(x)
        
        return x

In [19]:
class MultiTaskModel(nn.Module):
    def __init__(self, bert, sentiment_head, sarcasm_head, similarity_head, stance_head, subtask_hidden_layer_size):
        super(MultiTaskModel, self).__init__()

        self.bert = bert
        self.hidden_layer = nn.Linear(bert.config.hidden_size, subtask_hidden_layer_size)
        self.dropout = nn.Dropout(0.2)
        
        self.sentiment_head = sentiment_head
        self.sarcasm_head = sarcasm_head
        self.similarity_head = similarity_head
        self.stance_head = stance_head
        
    def forward(self, mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask):
        mt_outputs = self.bert(input_ids=mt_input_ids, attention_mask=mt_attention_mask)
        ct_outputs = self.bert(input_ids=ct_input_ids, attention_mask=ct_attention_mask)
                
        mt_last_hidden_state = mt_outputs.last_hidden_state[:, 0, :]
        ct_last_hidden_state = ct_outputs.last_hidden_state[:, 0, :]
        
        hidden_output = self.dropout(F.relu(self.hidden_layer(mt_last_hidden_state)))
        sarcasm_logits = self.sarcasm_head(hidden_output)
        sentiment_logits = self.sentiment_head(hidden_output)
        similarity_logits = self.similarity_head((mt_last_hidden_state, ct_last_hidden_state))
        
        sentiments = sentiment_logits.argmax(axis=1)
        sarcasms = sarcasm_logits.argmax(axis=1)
        stance_logits = self.stance_head(mt_last_hidden_state, sentiments, sarcasms, similarity_logits)
        
        return sarcasm_logits, sentiment_logits, similarity_logits, stance_logits

## TRAINING & VALIDATION

In [20]:
class LinearDecayLR(_LRScheduler):
    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        self.start_decay=start_decay
        self.n_epoch=n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        last_epoch = self.last_epoch
        n_epoch=self.n_epoch
        b_lr=self.base_lrs[0]
        start_decay=self.start_decay
        if last_epoch>start_decay:
            lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
        else:
            lr=b_lr
        return [lr]

In [None]:
subtask_hidden_layer_size = 256
gnn_hidden_layer_size = 256
bert_emb_size = bert.config.hidden_size

model = MultiTaskModel(
    bert,
    SubTaskHead(num_sentiment_labels, 0.1, subtask_hidden_layer_size),
    SubTaskHead(num_sarcasm_labels, 0.1, subtask_hidden_layer_size),
    SimilarityHead(),
    StanceHead(bert_emb_size, gnn_hidden_layer_size, num_stance_labels),
    subtask_hidden_layer_size,
).to(device)

optimizer = AdamW(model.parameters(), lr=learning_rate)
ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

lr_scheduler=LinearDecayLR(optimizer, num_epochs, int(num_epochs*0.5))

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1} || Learning Rate: {lr_scheduler.get_lr()} -------------------------------------------")
    
    # TRAINING -------------------------------------------------------------------------------
    model.train()
    train_loader_gen = data_generator(encoded_tweets_train, y_train)
    valid_loader_gen = data_generator(encoded_tweets_test, y_test)
    
    epoch_sarcasm_loss = 0.0
    epoch_sentiment_loss = 0.0
    epoch_similarity_loss = 0.0
    epoch_stance_loss = 0.0
    
    correct_sarcasm = 0
    correct_sentiment = 0
    correct_similarity = 0
    correct_stance = 0
    total_samples = 0
    
    for main_tweets, context_tweets, sent_y, sarc_y, stance_y, sim_y in tqdm(train_loader_gen, desc="Training"):
        dataset = TensorDataset(
            torch.cat([item["input_ids"] for item in main_tweets]),
            torch.cat([item["attention_mask"] for item in main_tweets]),
            torch.cat([item["input_ids"] for item in context_tweets]),
            torch.cat([item["attention_mask"] for item in context_tweets]),
            torch.tensor(sarc_y),
            torch.tensor(sent_y),
            torch.tensor(stance_y),
            torch.tensor(sim_y),
        )
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        for batch in loader:
            mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask, sarc_y, sent_y, stance_y, sim_y = batch
            mt_input_ids = mt_input_ids.to(device)
            mt_attention_mask = mt_attention_mask.to(device)
            ct_input_ids = ct_input_ids.to(device)
            ct_attention_mask = ct_attention_mask.to(device)
            
            sarc_y = sarc_y.to(device)
            sent_y = sent_y.to(device)
            stance_y = stance_y.to(device)
            sim_y = sim_y.to(device)
                        
            optimizer.zero_grad()
            sarcasm_logits, sentiment_logits, similarity_logits, stance_logits = model(mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask)
            
            sarcasm_loss = ce_loss(sarcasm_logits, sarc_y)
            sentiment_loss = ce_loss(sentiment_logits, sent_y)
            similarity_loss = mse_loss(similarity_logits.squeeze(), sim_y.float())
            stance_loss = ce_loss(stance_logits, stance_y)

            total_loss = sarcasm_loss + sentiment_loss + stance_loss # + similarity_loss
            total_loss.backward()
            optimizer.step()
            
            correct_sarcasm += (sarcasm_logits.argmax(dim=1) == sarc_y).sum().item()
            correct_sentiment += (sentiment_logits.argmax(dim=1) == sent_y).sum().item()
            correct_stance += (stance_logits.argmax(dim=1) == stance_y).sum().item()
            correct_similarity += ((similarity_logits.squeeze() > 0.5).float() == sim_y.float()).sum().item()
            total_samples += mt_input_ids.size(0)
            
            epoch_sarcasm_loss += sarcasm_loss.item()
            epoch_sentiment_loss += sentiment_loss.item()
            epoch_stance_loss += stance_loss.item()
            epoch_similarity_loss += similarity_loss.item()
           
    avg_sarcasm_loss = epoch_sarcasm_loss / total_samples
    avg_sentiment_loss = epoch_sentiment_loss / total_samples
    avg_stance_loss = epoch_stance_loss / total_samples
    avg_similarity_loss = epoch_similarity_loss / total_samples

    sarcasm_acc = correct_sarcasm / total_samples
    sentiment_acc = correct_sentiment / total_samples
    stance_acc = correct_stance / total_samples
    similarity_acc = correct_similarity / total_samples

    print("Training Results:")
    print(f"Sarcasm -> Loss: {avg_sarcasm_loss}, Acc: {sarcasm_acc}")
    print(f"Sentiment -> Loss: {avg_sentiment_loss}, Acc: {sentiment_acc}")
    print(f"Stance -> Loss: {avg_stance_loss}, Acc: {stance_acc}")
    # print(f"Similarity -> Loss: {avg_similarity_loss}")
    
    
    # VALIDATION -------------------------------------------------------------------------------
    model.eval()
    valid_sarcasm_loss = 0.0
    valid_sentiment_loss = 0.0
    valid_similarity_loss = 0.0
    valid_stance_loss = 0.0
    
    valid_correct_sarcasm = 0
    valid_correct_sentiment = 0
    valid_correct_similarity = 0
    valid_correct_stance = 0
    valid_total_samples = 0
    
    with torch.no_grad():
        for main_tweets, context_tweets, sent_y, sarc_y, stance_y, sim_y in tqdm(valid_loader_gen, desc="Validation"):
            dataset = TensorDataset(
                torch.cat([item["input_ids"] for item in main_tweets]),
                torch.cat([item["attention_mask"] for item in main_tweets]),
                torch.cat([item["input_ids"] for item in context_tweets]),
                torch.cat([item["attention_mask"] for item in context_tweets]),
                torch.tensor(sarc_y),
                torch.tensor(sent_y),
                torch.tensor(stance_y),
                torch.tensor(sim_y),
            )
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

            for batch in loader:
                mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask, sarc_y, sent_y, stance_y, sim_y = batch
                mt_input_ids = mt_input_ids.to(device)
                mt_attention_mask = mt_attention_mask.to(device)
                ct_input_ids = ct_input_ids.to(device)
                ct_attention_mask = ct_attention_mask.to(device)

                sarc_y = sarc_y.to(device)
                sent_y = sent_y.to(device)
                stance_y = stance_y.to(device)
                sim_y = sim_y.to(device)

                sarcasm_logits, sentiment_logits, similarity_logits, stance_logits = model(mt_input_ids, mt_attention_mask, ct_input_ids, ct_attention_mask)

                sarcasm_loss = ce_loss(sarcasm_logits, sarc_y)
                sentiment_loss = ce_loss(sentiment_logits, sent_y)
                similarity_loss = mse_loss(similarity_logits.squeeze(), sim_y.float())
                stance_loss = ce_loss(stance_logits, stance_y)

                valid_sarcasm_loss += sarcasm_loss.item()
                valid_sentiment_loss += sentiment_loss.item()
                valid_similarity_loss += similarity_loss.item()
                valid_stance_loss += stance_loss.item()

                valid_correct_sarcasm += (sarcasm_logits.argmax(dim=1) == sarc_y).sum().item()
                valid_correct_sentiment += (sentiment_logits.argmax(dim=1) == sent_y).sum().item()
                valid_correct_similarity += ((similarity_logits.squeeze() > 0.5).float() == sim_y.float()).sum().item()
                valid_correct_stance += (stance_logits.argmax(dim=1) == stance_y).sum().item()
                valid_total_samples += mt_input_ids.size(0)
    
    avg_valid_sarcasm_loss = valid_sarcasm_loss / valid_total_samples
    avg_valid_sentiment_loss = valid_sentiment_loss / valid_total_samples
    avg_valid_similarity_loss = valid_similarity_loss / valid_total_samples
    avg_valid_stance_loss = valid_stance_loss / valid_total_samples
    
    valid_sarcasm_acc = valid_correct_sarcasm / valid_total_samples
    valid_sentiment_acc = valid_correct_sentiment / valid_total_samples
    valid_similarity_acc = valid_correct_similarity / valid_total_samples
    valid_stance_acc = valid_correct_stance / valid_total_samples

    print("Validation Results:")
    print(f"Sarcasm -> Loss: {avg_valid_sarcasm_loss}, Acc: {valid_sarcasm_acc}")
    print(f"Sentiment -> Loss: {avg_valid_sentiment_loss}, Acc: {valid_sentiment_acc}")
    print(f"Stance -> Loss: {avg_valid_stance_loss}, Acc: {valid_stance_acc}\n\n")
    # print(f"Similarity -> Loss: {avg_valid_similarity_loss}\n\n")
    
    lr_scheduler.step()

Epoch 1 || Learning Rate: [1e-05] -------------------------------------------


Training: 0it [00:00, ?it/s]

Training Results:
Sarcasm -> Loss: 0.0071515783880992485, Acc: 0.9401041666666666
Sentiment -> Loss: 0.030405390298082716, Acc: 0.5230654761904762
Stance -> Loss: 0.021082696061403977, Acc: 0.6142113095238095


Validation: 0it [00:00, ?it/s]

Validation Results:
Sarcasm -> Loss: 0.004013221727551094, Acc: 0.9754464285714286
Sentiment -> Loss: 0.02544799367231982, Acc: 0.6741071428571429
Stance -> Loss: 0.018935120398444787, Acc: 0.6852678571428571


Epoch 2 || Learning Rate: [1e-05] -------------------------------------------


Training: 0it [00:00, ?it/s]

Training Results:
Sarcasm -> Loss: 0.004783256479727459, Acc: 0.9572172619047619
Sentiment -> Loss: 0.02344830163444082, Acc: 0.6904761904761905
Stance -> Loss: 0.01917207956181041, Acc: 0.6770833333333334


Validation: 0it [00:00, ?it/s]

Validation Results:
Sarcasm -> Loss: 0.003545096759418292, Acc: 0.9754464285714286
Sentiment -> Loss: 0.02175393620772021, Acc: 0.7075892857142857
Stance -> Loss: 0.016989849640854766, Acc: 0.7053571428571429


Epoch 3 || Learning Rate: [1e-05] -------------------------------------------


Training: 0it [00:00, ?it/s]

Training Results:
Sarcasm -> Loss: 0.0041541634085920775, Acc: 0.9572172619047619
Sentiment -> Loss: 0.02019091790896796, Acc: 0.7388392857142857
Stance -> Loss: 0.01734702381128002, Acc: 0.7176339285714286


Validation: 0it [00:00, ?it/s]

Validation Results:
Sarcasm -> Loss: 0.003172670873547239, Acc: 0.9754464285714286
Sentiment -> Loss: 0.021113508913133825, Acc: 0.7142857142857143
Stance -> Loss: 0.015895097383431027, Acc: 0.7455357142857143


Epoch 4 || Learning Rate: [1e-05] -------------------------------------------


Training: 0it [00:00, ?it/s]

## TESTING

In [None]:
# new_tweet = "أنا أؤيد قرار الحكومة الجديدة"
# new_tweet = process_tweet(new_tweet)

# main_tweet_embed1 = torch.randn(1, hidden_size_bert)
# context_tweet_embed1 = torch.randn(1, hidden_size_bert)
# main_tweet_embed2 = torch.randn(1, hidden_size_bert)
# context_tweet_embed2 = torch.randn(1, hidden_size_bert)
# inputs = [
#     [main_tweet_embed1, context_tweet_embed1],
#     [main_tweet_embed2, context_tweet_embed2]
# ]

# # Forward pass through the model
# similarities = model(inputs)
# # sentiment_logits, sarcasm_logits, similarities = model(inputs)

# # Print the outputs
# # print("Sentiment Logits:", sentiment_logits)
# # print("Sarcasm Logits:", sarcasm_logits)
# print("Similarities:", similarities)