In [1]:
import datetime
import json
import math
import os
import pickle
import random
import re
import string
import subprocess
from pathlib import Path

import contractions
import evaluate
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sklearn
import tensorboard
import textattack
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
from electra_classifier import ElectraClassifier, SarcasmDataModule, SarcasmDataset
from finetuning_scheduler import FinetuningScheduler
from nltk.corpus import stopwords, wordnet
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelSummary
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.utils import compute_class_weight
from textattack.augmentation import Augmenter, CharSwapAugmenter, DeletionAugmenter, EasyDataAugmenter, WordNetAugmenter
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.transformations import (
    BackTranslation,
    CompositeTransformation,
    WordSwapEmbedding,
    WordSwapExtend,
    WordSwapRandomCharacterDeletion,
    WordSwapRandomCharacterInsertion,
)

# from textattack.transformations.sentence_transformations import BackTranslation
from torch.nn import CrossEntropyLoss
from torch.nn.functional import cross_entropy
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, TensorDataset, WeightedRandomSampler, random_split
from torchmetrics import Accuracy, F1Score, Precision, Recall
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
from transformers import (  # cosine_schedule_with_warmup,; get_linear_schedule_with_warmup,
    AdamW,
    AutoTokenizer,
    DataCollatorWithPadding,
    ElectraConfig,
    ElectraForSequenceClassification,
    ElectraModel,
    ElectraTokenizer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.electra.modeling_electra import ElectraClassificationHead

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_path = "/workspaces/sarcasm_detection/sarcasm_detection/project_data/Sarcasm_Headlines_Dataset_v2.json"
sub_data_path_train = "/workspaces/sarcasm_detection/sarcasm_detection/project_data/train.csv"
sub_data_path_test = "/workspaces/sarcasm_detection/sarcasm_detection/project_data/test.csv"
version_number = 6
sub_version_number = 9
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"/workspaces/sarcasm_detection/sarcasm_detection/checkpoints/sarcasm_detection_finetune_ckpt_v{version_number}_{current_time}.ckpt"
sub_checkpoint_path = f"/workspaces/sarcasm_detection/sarcasm_detection/checkpoints/subcat_finetune_ckpt_v{sub_version_number}_{current_time}.ckpt"
checkpoint_directory = os.path.dirname(checkpoint_path)
logdir = "/workspaces/sarcasm_detection/sarcasm_detection/tb_logs"
save_directory = (
    f"/workspaces/sarcasm_detection/sarcasm_detection/saved_models/sarcasm_model_v{version_number}_{current_time}"
)

In [3]:
df = pd.read_csv(sub_data_path_train)
df.head(10)

Unnamed: 0.1,Unnamed: 0,tweet,sarcastic,rephrase,sarcasm,irony,satire,understatement,overstatement,rhetorical_question
0,0,The only thing I got from college is a caffein...,1,"College is really difficult, expensive, tiring...",0.0,1.0,0.0,0.0,0.0,0.0
1,1,I love it when professors draw a big question ...,1,I do not like when professors don’t write out ...,1.0,0.0,0.0,0.0,0.0,0.0
2,2,Remember the hundred emails from companies whe...,1,"I, at the bare minimum, wish companies actuall...",0.0,1.0,0.0,0.0,0.0,0.0
3,3,Today my pop-pop told me I was not “forced” to...,1,"Today my pop-pop told me I was not ""forced"" to...",1.0,0.0,0.0,0.0,0.0,0.0
4,4,@VolphanCarol @littlewhitty @mysticalmanatee I...,1,I would say Ted Cruz is an asshole and doesn’t...,1.0,0.0,0.0,0.0,0.0,0.0
5,5,"@jimrossignol I choose to interpret it as ""XD""...",1,It's a terrible name and the product sounds aw...,0.0,1.0,0.0,1.0,0.0,0.0
6,6,Why would Alexa's recipe for Yorkshire pudding...,1,Great recipe from Alexa,0.0,1.0,0.0,0.0,0.0,1.0
7,7,someone hit me w a horse tranquilizer istg ive...,1,Simply “I’m miserable.”,1.0,0.0,0.0,0.0,0.0,0.0
8,8,Loving season 4 of trump does America. Funnies...,1,this last year of trumps presidency is not goi...,1.0,0.0,0.0,0.0,0.0,0.0
9,9,Holly Arnold ??? Who #ImACeleb #MBE nope not ...,1,"Holly Arnold seem like a nice lady, just feel ...",1.0,0.0,0.0,0.0,0.0,1.0


In [4]:
x_col_types = {
        "tweet": "str",
        "sarcasm": "int32",
        "irony": "int32",
        "satire": "int32",
        "understatement": "int32",
        "overstatement": "int32",
        "rhetorical_question": "int32",
    }

df = (
    df.drop(
        columns=[
            "rephrase",
            "Unnamed: 0",
            "sarcastic",
        ]
    )
    .fillna(0)
    .astype(x_col_types)
)

df.head(20)

Unnamed: 0,tweet,sarcasm,irony,satire,understatement,overstatement,rhetorical_question
0,The only thing I got from college is a caffein...,0,1,0,0,0,0
1,I love it when professors draw a big question ...,1,0,0,0,0,0
2,Remember the hundred emails from companies whe...,0,1,0,0,0,0
3,Today my pop-pop told me I was not “forced” to...,1,0,0,0,0,0
4,@VolphanCarol @littlewhitty @mysticalmanatee I...,1,0,0,0,0,0
5,"@jimrossignol I choose to interpret it as ""XD""...",0,1,0,1,0,0
6,Why would Alexa's recipe for Yorkshire pudding...,0,1,0,0,0,1
7,someone hit me w a horse tranquilizer istg ive...,1,0,0,0,0,0
8,Loving season 4 of trump does America. Funnies...,1,0,0,0,0,0
9,Holly Arnold ??? Who #ImACeleb #MBE nope not ...,1,0,0,0,0,1


In [5]:


aggregated_cols = ["satire", "irony", "overstatement", "understatement", "rhetorical_question"]

print("Original dataset class distribution:")
aggregated_cols = ["satire", "irony", "overstatement", "understatement", "rhetorical_question", "sarcasm"]
for label in aggregated_cols:
    class_counts = df[label].value_counts()
    print(f"{label}:")
    print(class_counts)

all_cols = ['sarcasm', "satire", "irony", "overstatement", "understatement", "rhetorical_question"]

not_sarcastic = df[(df[all_cols]==0).all(axis=1)]
print('not sarcastic')
print(not_sarcastic.shape[0])


Original dataset class distribution:
satire:
satire
0    3443
1      25
Name: count, dtype: int64
irony:
irony
0    3313
1     155
Name: count, dtype: int64
overstatement:
overstatement
0    3428
1      40
Name: count, dtype: int64
understatement:
understatement
0    3458
1      10
Name: count, dtype: int64
rhetorical_question:
rhetorical_question
0    3367
1     101
Name: count, dtype: int64
sarcasm:
sarcasm
0    2755
1     713
Name: count, dtype: int64
not sarcastic
2601


In [6]:
aggregated_cols = ["satire", "overstatement", "understatement", "rhetorical_question"]

overlap_counts = {}

for col in aggregated_cols:
    overlap_counts[col] = len(df[(df['irony'] == 1) & (df[col] == 1)])

print("Overlap between sarcasm and each of the aggregated columns:")
for col, count in overlap_counts.items():
    print(f"{col}: {count}")

Overlap between sarcasm and each of the aggregated columns:
satire: 4
overstatement: 9
understatement: 4
rhetorical_question: 15


In [7]:
df_new = df.copy()
aggregated_cols = ["satire", "irony", "overstatement", "understatement", "rhetorical_question"]
for col in aggregated_cols:
    df_new.loc[(df_new['sarcasm']==1) & df_new[col] == 1, 'sarcasm'] = 0

In [9]:
print(df_new.head(40))

                                                tweet  sarcasm  irony  satire   
0   The only thing I got from college is a caffein...        0      1       0  \
1   I love it when professors draw a big question ...        1      0       0   
2   Remember the hundred emails from companies whe...        0      1       0   
3   Today my pop-pop told me I was not “forced” to...        1      0       0   
4   @VolphanCarol @littlewhitty @mysticalmanatee I...        1      0       0   
5   @jimrossignol I choose to interpret it as "XD"...        0      1       0   
6   Why would Alexa's recipe for Yorkshire pudding...        0      1       0   
7   someone hit me w a horse tranquilizer istg ive...        1      0       0   
8   Loving season 4 of trump does America. Funnies...        1      0       0   
9   Holly Arnold ??? Who #ImACeleb  #MBE nope not ...        0      0       0   
10  ANY PENSIONER AND 4 YEAR OLD WHO DARE TAKE ME ...        0      1       0   
11                          

In [8]:
aggregated_cols = ["satire", "irony", "overstatement", "understatement", "rhetorical_question"]

df['other'] = df[aggregated_cols].any(axis=1).astype(int)
other_count = df['other'].value_counts()

print(other_count)

other
0    3170
1     298
Name: count, dtype: int64


In [6]:
total_ones_in_aggregated_cols = df[aggregated_cols].sum().sum()
print("Total ones in aggregated columns:", total_ones_in_aggregated_cols)

df["other"] = df[aggregated_cols].any(axis=1).astype("int32")
df = df.drop(columns=aggregated_cols)

other_count = df['other'].value_counts()
print("Other counts:\n", other_count)


Total ones in aggregated columns: 1044
Other counts:
 other
0    2601
1     867
Name: count, dtype: int64


In [18]:
train_df, val_df = train_test_split(df, test_size=0.2, random_state=None, shuffle=False)

train_df['not_sarcastic'] = ((train_df['sarcasm'] == 0) & (train_df['other']==0)).astype(int)
labels_list = ["sarcasm", "not_sarcastic", "other"]

train_df.head(20)

Unnamed: 0,tweet,sarcasm,other,not_sarcastic
0,The only thing I got from college is a caffein...,0,1,0
1,I love it when professors draw a big question ...,1,0,0
2,Remember the hundred emails from companies whe...,0,1,0
3,Today my pop-pop told me I was not “forced” to...,1,0,0
4,@VolphanCarol @littlewhitty @mysticalmanatee I...,1,0,0
5,"@jimrossignol I choose to interpret it as ""XD""...",0,1,0
6,Why would Alexa's recipe for Yorkshire pudding...,0,1,0
7,someone hit me w a horse tranquilizer istg ive...,1,0,0
8,Loving season 4 of trump does America. Funnies...,1,0,0
9,Holly Arnold ??? Who #ImACeleb #MBE nope not ...,1,1,0


In [2]:
model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator', num_labels=2)

# Create a sample input
input_ids = torch.tensor([101, 2054, 2003, 2026, 2171, 102]).unsqueeze(0)  # Batch size 1
attention_mask = torch.tensor([1, 1, 1, 1, 1, 1]).unsqueeze(0)  # Batch size 1

# Run the model
outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# Check the outputs
print(outputs)

Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier

SequenceClassifierOutput(loss=None, logits=tensor([[0.0182, 0.0150]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
