# Load SRL

In [12]:
# load pkl from data\srls\mfc\FRISS_srl.pkl

import pickle

# NOTE: pickle.load can execute arbitrary code; only open trusted, locally generated files.
with open('../../data/srls/mfc/FRISS_srl.pkl', 'rb') as srl_pickle:
    data = pickle.load(srl_pickle)

# keys() prints the container's index (one entry per sentence in the run below).
print(data.keys())

RangeIndex(start=0, stop=67480, step=1)


In [13]:
# create statistics

# Number of predicate frames stored for each sentence.
num_preds = [len(frames) for _, frames in data.items()]

print('Number of predicates per sentence')
print('Mean:', sum(num_preds) / len(num_preds))
print('Max:', max(num_preds))
print('Min:', min(num_preds))
print('Number of sentences:', len(num_preds))

Number of predicates per sentence
Mean: 3.1542679312388855
Max: 17
Min: 1
Number of sentences: 67480


In [14]:
# Peek at the SRL frames of the first sentence: a list of {'predicate', 'ARG0', 'ARG1'} dicts.
data[0]

[{'predicate': 'need',
  'ARG0': 'IMM-10005 PRIMARY Immigrants without HOPE',
  'ARG1': 'help entering college Anxiety'},
 {'predicate': 'entering',
  'ARG0': 'IMM-10005 PRIMARY Immigrants without HOPE',
  'ARG1': 'college Anxiety Jose Alvarado'},
 {'predicate': 'gripped',
  'ARG0': 'IMM-10005 PRIMARY Immigrants without HOPE need help entering college Anxiety',
  'ARG1': 'Jose Alvarado'}]

# Load MFC preprocessed data

In [60]:
# data\mfc\data_prepared.json
import pandas as pd

# JSON is UTF-8 by specification; decode it explicitly so the result does not depend
# on the platform's locale encoding (e.g. cp1252 on Windows).
with open('../../data/mfc/data_prepared.json', 'r', encoding='utf-8') as f:
    data_prepared = pd.read_json(f)

In [62]:
# Sanity check: the row count should match the number of SRL entries loaded above.
data_prepared.shape

(67480, 18)

In [5]:
# preprocess text

import re

def preprocess_text(text):
    """Normalize one sentence from the MFC export.

    Collapses every run of whitespace (spaces, tabs, newlines, ...) into a single
    space, strips both ends, and removes a leading "IMM-<digits> PRIMARY" marker.

    Args:
        text (str): Raw sentence text.

    Returns:
        str: The cleaned sentence.
    """
    # A single str.replace('  ', ' ') pass leaves doubled spaces behind for runs of
    # three or more whitespace characters (4 spaces -> 2); a regex collapses any run.
    text = re.sub(r'\s+', ' ', text).strip()

    # some texts start with "IMM-XXXXX PRIMARY" -> remove that marker
    text = re.sub(r'^IMM-\d+ PRIMARY', '', text)

    # removing the marker leaves the separating space at the front
    return text.strip()

# Overwrites the 'text' column in place with the cleaned sentences.
data_prepared['text'] = data_prepared['text'].map(preprocess_text)

In [6]:
# Preview the cleaned sentences with their article id, document frame and per-frame flags.
data_prepared.head()

Unnamed: 0,article_id,text,document_frame,Capacity and Resources,Crime and Punishment,Cultural Identity,Economic,External Regulation and Reputation,Fairness and Equality,Health and Safety,"Legality, Constitutionality, Jurisdiction",Morality,Other,Policy Prescription and Evaluation,Political,Public Sentiment,Quality of Life,Security and Defense
0,Immigration1.0-10005,Immigrants without HOPE need help entering col...,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
1,Immigration1.0-10005,It mounted as students went around the room te...,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
2,Immigration1.0-10005,Georgia Tech.,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
3,Immigration1.0-10005,University of Georgia.,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
4,Immigration1.0-10005,"""All I could say was, 'I'm planning to see if ...",Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0


In [7]:
# stats for text length

# Character length of every cleaned sentence.
text_lengths = data_prepared['text'].map(len)

In [8]:
# Distribution of sentence lengths (in characters) after cleaning.
text_lengths.describe()

count    67480.000000
mean       139.985803
std         78.341919
min          1.000000
25%         81.000000
50%        130.000000
75%        188.000000
max       1102.000000
Name: text, dtype: float64

In [9]:
# Independent (deep) copy so later edits to `df` do not touch `data_prepared`.
df = data_prepared.copy(deep=True)

In [56]:
# save to json
# One JSON object per row.
df.to_json(path_or_buf='../../data/mfc/data_prepared_cleaned.json', orient='records')

In [59]:
# load
# Decode explicitly as UTF-8 (the JSON standard encoding) instead of the platform's
# locale default, and read with the same 'records' orient the file was written with.
with open('../../data/mfc/data_prepared_cleaned.json', 'r', encoding='utf-8') as f:
    df = pd.read_json(f, orient='records')

df.shape

(67480, 18)

# Load FrameAxis features

In [69]:
# load data\frameaxis\mfc\frameaxis_contextualized_mft.pkl

# NOTE: pickle.load can execute arbitrary code; only open trusted, locally produced files.
with open('../../data/frameaxis/mfc/frameaxis_contextualized_mft.pkl', 'rb') as frameaxis_pickle:
    data_fa = pickle.load(frameaxis_pickle)

data_fa.shape

(6097, 11)

# Test SRL processor

In [38]:
from collections import defaultdict
import os
import pandas as pd
import pickle
from allennlp.predictors.predictor import Predictor
from tqdm import tqdm

import logging
 
# Root logger (no name). Its default level is WARNING, so logger.info calls are not shown.
logger = logging.getLogger(name=None)


class SRLProcessor:
    """Extract (predicate, ARG0, ARG1) triples from texts with an AllenNLP SRL model.

    Results are cached at ``dataframe_path`` in the ``save_type`` format and are
    reloaded from there on later calls unless ``force_recalculate`` is set.
    """

    SRL_MODEL_URL = (
        "https://storage.googleapis.com/allennlp-public-models/"
        "structured-prediction-srl-bert.2020.12.15.tar.gz"
    )

    def __init__(
        self,
        df,
        dataframe_path="../notebooks/classifier/X_srl_filtered.pkl",
        force_recalculate=False,
        save_type="pickle",
        device=0,
        batch_size=1,
    ):
        """
        Initializes the SRLProcessor.

        Args:
            df (pd.DataFrame): DataFrame with 'article_id' and 'text' columns.
            dataframe_path (str | None): Where the SRL DataFrame is cached. If falsy,
                results are always recomputed and never saved.
            force_recalculate (bool): If True, recompute SRL components even when a
                cache file exists.
            save_type (str): Cache format: 'csv', 'pickle' or 'json'. Used both for
                saving and for loading.
            device (int): Forwarded to AllenNLP ``Predictor.from_path`` as ``cuda_device``.
            batch_size (int): Number of texts sent to the predictor per call. The default
                of 1 keeps the one-text-per-call behaviour; larger values send several
                texts per ``predict_batch_json`` call.

        Raises:
            ValueError: If ``save_type`` is not supported or ``batch_size`` < 1.
        """
        self.df = df
        self.dataframe_path = dataframe_path
        self.force_recalculate = force_recalculate
        self.save_type = save_type
        self.device = device
        self.batch_size = batch_size

        # allowed save types
        if self.save_type not in ["csv", "pickle", "json"]:
            raise ValueError(
                "Invalid save_type. Must be one of 'csv', 'pickle', or 'json'."
            )
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

    def get_srl_embeddings(self):
        """
        Return the SRL DataFrame, loading it from the cache when possible.

        Recomputes when ``force_recalculate`` is set, when no ``dataframe_path`` is
        configured (``os.path.exists(None)`` would raise), or when the cache file
        does not exist yet.
        """
        if (
            self.force_recalculate
            or not self.dataframe_path
            or not os.path.exists(self.dataframe_path)
        ):
            return self._recalculate_srl()
        return self._load_srl()

    def _recalculate_srl(self):
        """
        Run the SRL model over ``self.df`` and cache the result.

        Returns:
            pd.DataFrame: columns 'article_id', 'text' and 'srls', one row per row of
            ``self.df`` in the same order; 'srls' is the list produced by
            ``_extract_srl_batch`` for that text.
        """
        logger.info("Recalculating SRL components...")
        predictor = Predictor.from_path(self.SRL_MODEL_URL, cuda_device=self.device)

        article_ids = self.df["article_id"].tolist()
        texts = self.df["text"].tolist()

        srl_data = []
        for start in tqdm(
            range(0, len(texts), self.batch_size), desc="Processing SRL Batches"
        ):
            stop = start + self.batch_size
            batch_srls = self._extract_srl_batch(texts[start:stop], predictor)
            # _extract_srl_batch returns one entry per input text, in input order.
            for article_id, text, srls in zip(
                article_ids[start:stop], texts[start:stop], batch_srls
            ):
                srl_data.append({"article_id": article_id, "text": text, "srls": srls})

        result_df = pd.DataFrame(srl_data)

        # Save the DataFrame if a path is specified
        if self.dataframe_path:
            self._save_srl(result_df)

        return result_df

    def _save_srl(self, result_df):
        """Write ``result_df`` to ``dataframe_path`` in the configured ``save_type`` format."""
        if self.save_type == "csv":
            result_df.to_csv(self.dataframe_path, index=False)
        elif self.save_type == "json":
            result_df.to_json(self.dataframe_path)
        else:  # "pickle" (save_type is validated in __init__)
            with open(self.dataframe_path, "wb") as f:
                pickle.dump(result_df, f)

    def _read_srl(self):
        """Read the cached SRL DataFrame using the same format ``_save_srl`` wrote."""
        if self.save_type == "csv":
            import ast

            result_df = pd.read_csv(self.dataframe_path)
            if "srls" in result_df.columns:
                # CSV stores each list of SRL dicts as its Python repr; parse it back safely.
                result_df["srls"] = result_df["srls"].apply(
                    lambda value: ast.literal_eval(value) if isinstance(value, str) else value
                )
            return result_df
        if self.save_type == "json":
            return pd.read_json(self.dataframe_path)
        # NOTE: pickle can execute arbitrary code; only load caches this class wrote.
        with open(self.dataframe_path, "rb") as f:
            return pickle.load(f)

    def _load_srl(self):
        """
        Loads the SRL components from the cache file at ``dataframe_path``.
        """
        logger.info(
            "Loading SRL components from %s (%s)...", self.dataframe_path, self.save_type
        )
        return self._read_srl()

    def _extract_srl_batch(self, batched_sentences, predictor):
        """
        Extracts SRL components for a batch of sentences.

        Args:
            batched_sentences (list[str]): Sentences to label.
            predictor: Object exposing ``predict_batch_json`` that returns, per sentence,
                a dict with 'words' and 'verbs' (each verb has 'verb' and BIO 'tags').

        Returns:
            list[list[dict]]: One list per sentence of
            ``{"predicate", "ARG0", "ARG1"}`` dicts. Predicates with neither ARG0 nor
            ARG1 are skipped; a sentence with no remaining predicate gets a single
            placeholder with empty strings.
        """
        requests = [{"sentence": sentence} for sentence in batched_sentences]
        batched_srl = predictor.predict_batch_json(requests)

        results = []
        for srl in batched_srl:
            words = srl["words"]
            sentence_results = []

            for verb_entry in srl["verbs"]:
                arg0_words, arg1_words = [], []
                for word, tag in zip(words, verb_entry["tags"]):
                    # Substring match also covers continuation/reference tags (e.g. C-ARG0).
                    if "ARG0" in tag:
                        arg0_words.append(word)
                    elif "ARG1" in tag:
                        arg1_words.append(word)

                if arg0_words or arg1_words:
                    sentence_results.append(
                        {
                            "predicate": verb_entry["verb"],
                            "ARG0": " ".join(arg0_words),
                            "ARG1": " ".join(arg1_words),
                        }
                    )

            results.append(
                sentence_results or [{"predicate": "", "ARG0": "", "ARG1": ""}]
            )
        return results

    def _batch_process_srl(self, texts, article_ids, predictor, batch_size=32):
        """
        Extracts SRL components for all sentences in batches and groups them by article.

        Args:
            texts (pd.Series): Sentences (sliced positionally and converted with ``tolist``).
            article_ids (pd.Series): Article id for each sentence, aligned with ``texts``.
            predictor: SRL predictor passed through to ``_extract_srl_batch``.
            batch_size (int): Sentences per predictor call.

        Returns:
            defaultdict(list): article_id -> concatenated SRL dicts of its sentences.
        """
        results_by_article = defaultdict(list)
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing SRL Batches"):
            batched_sentences = texts[i : i + batch_size].tolist()
            batch_article_ids = article_ids[i : i + batch_size].tolist()
            batch_results = self._extract_srl_batch(batched_sentences, predictor)
            for article_id, srls in zip(batch_article_ids, batch_results):
                results_by_article[article_id].extend(srls)
        return results_by_article


In [35]:
# Small leading slices of the corpus for quick SRL smoke tests.
df_first_50 = df.iloc[:50]
df_first = df.iloc[:1]

In [64]:
# Preview the rows that are handed to the SRL processor below.
df.head(2)

Unnamed: 0,article_id,text,document_frame,Capacity and Resources,Crime and Punishment,Cultural Identity,Economic,External Regulation and Reputation,Fairness and Equality,Health and Safety,"Legality, Constitutionality, Jurisdiction",Morality,Other,Policy Prescription and Evaluation,Political,Public Sentiment,Quality of Life,Security and Defense
0,Immigration1.0-10005,Immigrants without HOPE need help entering col...,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0
1,Immigration1.0-10005,It mounted as students went around the room te...,Quality of Life,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0


In [65]:
srl_processor = SRLProcessor(
    df,
    dataframe_path="srl.pkl",
    force_recalculate=True,
    device=-1,  # forwarded to the AllenNLP predictor as cuda_device
)

In [66]:
# With force_recalculate=True this always re-runs the SRL model over every row of `df`
# (not the df_first / df_first_50 slices) and is expensive; use a slice for quick tests.
df_srls = srl_processor.get_srl_embeddings()

error loading _jsonnet (this is expected on Windows), treating C:\Users\elias\AppData\Local\Temp\tmp5e3o661f\config.json as plain json
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Processing SRL Batch

KeyboardInterrupt: 

In [63]:
# stats for srls

# Number of SRL frames per processed sentence.
num_srls = df_srls["srls"].map(len).tolist()

print('Number of SRLs per sentence')
print('Mean:', sum(num_srls) / len(num_srls))
print('Max:', max(num_srls))
print('Min:', min(num_srls))

Number of SRLs per sentence
Mean: 2.9
Max: 9
Min: 1


In [67]:
import torch
from torch.utils.data import Dataset
import pandas as pd


class ArticleDataset(Dataset):
    """Map-style dataset that turns one article into fixed-size tensors.

    Item ``idx`` is built from position ``idx`` of each input Series:

    * ``X``: iterable of sentence strings, tokenized to ``max_sentence_length``.
    * ``X_srl``: one entry per sentence, each a list of (or a single) dict with
      'predicate', 'ARG0' and 'ARG1' strings, tokenized to ``max_arg_length``.
    * ``X_frameaxis``: sequence of feature rows (presumably one per sentence with
      ``frameaxis_dim`` values each); NaNs are replaced by 0.
    * ``labels`` (optional): element ``[0]`` of the row becomes the ``labels`` tensor.
      When ``labels`` is None the item has no ``labels`` key.

    Sentences are truncated/padded to ``max_sentences_per_article`` and SRL triples
    to ``max_args_per_sentence``; padding uses id 0 and attention mask 0.
    """

    def __init__(
        self,
        X,
        X_srl,
        X_frameaxis,
        tokenizer,
        labels=None,
        max_sentences_per_article=32,
        max_sentence_length=32,
        max_args_per_sentence=10,
        max_arg_length=16,
        frameaxis_dim=20,
    ):
        self.X = X
        self.X_srl = X_srl
        self.X_frameaxis = X_frameaxis
        self.labels = labels

        self.tokenizer = tokenizer
        self.max_sentences_per_article = max_sentences_per_article
        self.max_sentence_length = max_sentence_length
        self.max_args_per_sentence = max_args_per_sentence
        self.max_arg_length = max_arg_length

        self.frameaxis_dim = frameaxis_dim

    def __len__(self):
        return len(self.X)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to exactly ``max_length`` ids; return (input_ids, attention_mask)."""
        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_attention_mask=True,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    @staticmethod
    def _fit(rows, length, pad_row):
        """Return a NEW list: ``rows`` truncated to ``length``, then padded with ``pad_row``.

        The pad row object is shared between padded slots; this is safe because the
        result is only read (converted to tensors / a DataFrame), never mutated.
        """
        rows = list(rows)[:length]
        return rows + [pad_row] * (length - len(rows))

    def __getitem__(self, idx):
        max_sentences = self.max_sentences_per_article
        max_args = self.max_args_per_sentence
        empty_sentence = [0] * self.max_sentence_length
        empty_arg = [0] * self.max_arg_length
        empty_srl_sentence = [empty_arg] * max_args

        # Sentences: only the first max_sentences are tokenized; the rest would be cut anyway.
        sentence_ids, sentence_attention_masks = [], []
        for sentence in list(self.X.iloc[idx])[:max_sentences]:
            ids, mask = self._encode(sentence, self.max_sentence_length)
            sentence_ids.append(ids)
            sentence_attention_masks.append(mask)
        sentence_ids = self._fit(sentence_ids, max_sentences, empty_sentence)
        sentence_attention_masks = self._fit(
            sentence_attention_masks, max_sentences, empty_sentence
        )

        # FrameAxis: _fit builds a new list, so the rows stored in self.X_frameaxis are
        # never padded in place (appending to the stored list would mutate the dataset).
        frameaxis_rows = self._fit(
            self.X_frameaxis.iloc[idx], max_sentences, [0] * self.frameaxis_dim
        )
        # replace nan values in frameaxis with 0
        frameaxis_data = pd.DataFrame(frameaxis_rows).fillna(0).values.tolist()

        # SRL: per sentence, tokenize up to max_args triples and pad to max_args.
        srl_fields = ("predicate", "ARG0", "ARG1")
        field_ids = {field: [] for field in srl_fields}
        field_masks = {field: [] for field in srl_fields}
        for srl_items in list(self.X_srl.iloc[idx])[:max_sentences]:
            if not isinstance(srl_items, list):
                srl_items = [srl_items]

            sentence_field_ids = {field: [] for field in srl_fields}
            sentence_field_masks = {field: [] for field in srl_fields}
            for item in srl_items[:max_args]:
                for field in srl_fields:
                    ids, mask = self._encode(item[field], self.max_arg_length)
                    sentence_field_ids[field].append(ids)
                    sentence_field_masks[field].append(mask)

            for field in srl_fields:
                field_ids[field].append(
                    self._fit(sentence_field_ids[field], max_args, empty_arg)
                )
                field_masks[field].append(
                    self._fit(sentence_field_masks[field], max_args, empty_arg)
                )

        # Pad the SRL data to max_sentences sentences.
        for field in srl_fields:
            field_ids[field] = self._fit(field_ids[field], max_sentences, empty_srl_sentence)
            field_masks[field] = self._fit(
                field_masks[field], max_sentences, empty_srl_sentence
            )

        data = {
            "sentence_ids": torch.tensor(sentence_ids, dtype=torch.long),
            "sentence_attention_masks": torch.tensor(
                sentence_attention_masks, dtype=torch.long
            ),
            "predicate_ids": torch.tensor(field_ids["predicate"], dtype=torch.long),
            "predicate_attention_masks": torch.tensor(
                field_masks["predicate"], dtype=torch.long
            ),
            "arg0_ids": torch.tensor(field_ids["ARG0"], dtype=torch.long),
            "arg0_attention_masks": torch.tensor(field_masks["ARG0"], dtype=torch.long),
            "arg1_ids": torch.tensor(field_ids["ARG1"], dtype=torch.long),
            "arg1_attention_masks": torch.tensor(field_masks["ARG1"], dtype=torch.long),
            "frameaxis": torch.tensor(frameaxis_data, dtype=torch.float),
        }

        # labels is optional (default None), e.g. for inference-only datasets.
        if self.labels is not None:
            data["labels"] = torch.tensor(self.labels.iloc[idx][0], dtype=torch.long)

        return data


def custom_collate_fn(batch):
    """Collate ``ArticleDataset`` items into one dict of batched tensors.

    Every tensor field is padded along its first dimension with 0
    (``pad_sequence``, batch first). 0-dim tensors (e.g. scalar class labels)
    cannot go through ``pad_sequence`` and are stacked instead. The ``labels`` key
    is optional, so batches from an unlabelled dataset are supported.

    Args:
        batch (list[dict]): Items returned by ``ArticleDataset.__getitem__``.

    Returns:
        dict: Same keys as the items, each mapped to a batched tensor.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    keys = (
        "sentence_ids",
        "sentence_attention_masks",
        "predicate_ids",
        "predicate_attention_masks",
        "arg0_ids",
        "arg0_attention_masks",
        "arg1_ids",
        "arg1_attention_masks",
        "frameaxis",
        "labels",
    )

    output_dict = {}
    for key in keys:
        if key == "labels" and "labels" not in batch[0]:
            continue
        tensors = [item[key] for item in batch]
        if tensors[0].dim() == 0:
            # pad_sequence needs at least one dimension.
            output_dict[key] = torch.stack(tensors)
        else:
            output_dict[key] = torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=0
            )

    return output_dict


In [159]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel, RobertaTokenizerFast, RobertaModel
from nltk.corpus import stopwords
import nltk
import pickle
from sklearn.metrics.pairwise import cosine_similarity
import string
import json

from logging import getLogger

logger = getLogger(__name__)

# Make the logger print INFO-level messages to the notebook output.
# `logging` is imported here instead of relying on an earlier cell, and the
# handler is only attached once so re-running this cell does not duplicate output.
import logging
import sys

logger.setLevel(logging.INFO)
if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler(sys.stdout))


class FrameAxisProcessor:
    def __init__(
        self,
        df,
        path_antonym_pairs="frameaxis/axes/custom.tsv",
        dataframe_path=None,
        bert_model_name="bert-base-uncased",
        name_tokenizer="bert-base-uncased",
        path_name_bert_model="bert-base-uncased",
        force_recalculate=False,
        save_type="pickle",
        dim_names=["positive", "negative"],
    ):
        """
        FrameAxisProcessor constructor

        Args:
        df (pd.DataFrame): DataFrame with text data
        path_antonym_pairs (str): Path to the antonym pairs file
        dataframe_path (str): Path to save the FrameAxis Embeddings DataFrame for saving and loading
        name_tokenizer (str): Name or path of the model
        path_name_bert_model (str): Name or path of the model
        force_recalculate (bool): If True, recalculate the FrameAxis Embeddings
        save_type (str): Type of file to save the FrameAxis Embeddings DataFrame
        """
        self.df = df
        self.force_recalculate = force_recalculate
        self.dataframe_path = dataframe_path

        if bert_model_name == "bert-base-uncased":
            self.tokenizer = BertTokenizerFast.from_pretrained(name_tokenizer)
            self.model = BertModel.from_pretrained(path_name_bert_model)
        elif bert_model_name == "roberta-base":
            self.tokenizer = RobertaTokenizerFast.from_pretrained(name_tokenizer)
            self.model = RobertaModel.from_pretrained(path_name_bert_model)

        if torch.cuda.is_available():
            logger.info("Using CUDA")
            self.model.cuda()

        self.antonym_pairs = {}
        with open(path_antonym_pairs) as f:
            self.antonym_pairs = json.load(f)

        self.dim_names = dim_names

        # Load the stopwords and non-word characters
        nltk.download("stopwords")
        self.stopwords = set(stopwords.words("english"))
        self.non_word_characters = set(string.punctuation)

        # allowed save types
        self.save_type = save_type
        if save_type not in ["csv", "pickle", "json"]:
            raise ValueError(
                "Invalid save_type. Must be one of 'csv', 'pickle', or 'json'."
            )

    def _load_antonym_pairs(self, axis_path):
        axes_df = pd.read_csv(axis_path, sep="\t", header=None)
        return [tuple(x) for x in axes_df.values]

    def precompute_antonym_embeddings(self):
        """
        Build one embedding per pole of every frame axis from the corpus.

        For every pole word in ``self.antonym_pairs`` the embeddings returned by
        ``get_embeddings_for_words`` (a ``{word: tensor}`` dict per row of ``self.df``)
        are averaged over the corpus. Each pole embedding is then the mean of its
        words' averages.

        Returns:
            dict: ``{axis: {dim_names[0]: tensor, dim_names[1]: tensor}}``. A pole whose
            words never occur gets a zero vector of the model's hidden size on the
            model's device, so it can be combined with the other (device-resident) poles.
        """
        # All pole words of all axes, in definition order (duplicates kept).
        frame_axis_words = [
            word
            for pairs in self.antonym_pairs.values()
            for words in pairs.values()
            for word in words
        ]

        antonym_embeddings = {}

        for _, row in tqdm(
            self.df.iterrows(),
            desc="Generating antonym embeddings",
            total=self.df.shape[0],
        ):
            # Access the article text from the 'text' column
            embeddings = self.get_embeddings_for_words(row["text"], frame_axis_words)

            # Collect every occurrence's embedding under its word
            for word, embedding in embeddings.items():
                antonym_embeddings.setdefault(word, []).append(embedding)

        device = self.model.device
        antonym_avg_embeddings = {}

        for key, value in tqdm(
            self.antonym_pairs.items(), desc="Generating average embeddings"
        ):
            antonym_avg_embeddings[key] = {}
            for dim, words in tqdm(
                value.items(), desc="Processing dimension", leave=False
            ):
                antonym_avg_embeddings[key][dim] = {}

                for word in tqdm(words, desc="Processing word", leave=False):
                    # Words that never occurred in the corpus have no embeddings; skip them.
                    if word in antonym_embeddings:
                        word_embed = [embed.to(device) for embed in antonym_embeddings[word]]
                        antonym_avg_embeddings[key][dim][word] = torch.mean(
                            torch.stack(word_embed), dim=0
                        )

        # Size of the zero fallback; keeps the previously hard-coded 768 only if the
        # model exposes no config. Assumes word embeddings are 1-D (hidden,) vectors.
        hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", 768)

        def mean_or_zeros(embeddings):
            # Mean of the stacked embeddings, or a zero vector on the model device.
            if embeddings:
                return torch.mean(torch.stack(embeddings), dim=0)
            return torch.zeros(hidden_size, device=device)

        microframes = {}

        for key, poles in tqdm(
            antonym_avg_embeddings.items(), desc="Generating microframes"
        ):
            microframes[key] = {
                self.dim_names[0]: mean_or_zeros(list(poles[self.dim_names[0]].values())),
                self.dim_names[1]: mean_or_zeros(list(poles[self.dim_names[1]].values())),
            }

        return microframes

    def calculate_word_contributions(self, df, antonym_pairs_embeddings):
        """
        Score every word of every row against each frame axis.

        The direction of an axis is ``normalize(negative_pole - positive_pole)`` (pole keys
        are ``self.dim_names``). A word's score on that axis is the cosine similarity
        between its embedding and that direction.

        :param df: DataFrame with 'article_id' and 'text' columns. It is modified in place:
            a 'word_contributions' column is added.
        :param antonym_pairs_embeddings: ``{axis: {dim_names[0]: tensor, dim_names[1]: tensor}}``,
            e.g. the output of ``precompute_antonym_embeddings``.
        :return: ``df``, whose 'word_contributions' holds per row a list of dicts
            ``{"word": word, <axis>: score, ...}`` (empty when the text yields no embeddings).
        """
        device = self.model.device
        positive, negative = self.dim_names[0], self.dim_names[1]
        dimensions = list(antonym_pairs_embeddings)

        # Axis directions do not depend on the word. Compute them once per call, moving
        # both poles to the model device before subtracting (they may live on different devices).
        axis_directions = None
        if dimensions:
            axis_directions = F.normalize(
                torch.stack(
                    [
                        antonym_pairs_embeddings[dim][negative].to(device).reshape(-1)
                        - antonym_pairs_embeddings[dim][positive].to(device).reshape(-1)
                        for dim in dimensions
                    ]
                ),
                p=2,
                dim=1,
            )

        def calculate_word_contribution(article_id, text):
            # Presumably returns (list of words, tensor with one embedding per word) — verify.
            words, embeddings = self.get_embeddings_for_text(text)

            if embeddings.numel() == 0:
                print(f"No embeddings found for article {article_id}, words: {words}")
                return []

            # Like zip(words, embeddings): score only as many words as there are embeddings.
            n_words = min(len(words), len(embeddings))
            if n_words == 0:
                return []
            if axis_directions is None:
                return [{"word": word} for word, _ in zip(words, embeddings)]

            # One row per word (also flattens (n_words, 1, hidden) inputs).
            word_vectors = embeddings[:n_words].reshape(n_words, -1).to(device)
            # (n_words, hidden) @ (hidden, n_axes) -> cosine similarity per word and axis.
            scores = (F.normalize(word_vectors, p=2, dim=1) @ axis_directions.T).cpu().tolist()

            return [
                {"word": word, **dict(zip(dimensions, row_scores))}
                for word, row_scores in zip(words, scores)
            ]

        tqdm.pandas(desc="Calculating Word Contributions")
        df['word_contributions'] = df.progress_apply(lambda row: calculate_word_contribution(row['article_id'], row['text']), axis=1)

        return df

    def calculate_microframe_bias(self, df):
        # Initialize a DataFrame to collect microframe bias results
        bias_results = []

        # Iterate over each row in the DataFrame
        for idx, row in df.iterrows():
            # Each 'word_contributions' entry is a list of dictionaries with words and their contributions
            word_contributions = row['word_contributions']

            # Initialize a dictionary to hold bias calculations for this article
            bias_dict = {'article_id': row['article_id']}

            if word_contributions:
                dimensions = [k for k in word_contributions[0].keys() if k not in ['word']]
                for dimension in dimensions:
                    # Calculate weighted contributions for each dimension
                    weighted_contributions = sum(d[dimension] for d in word_contributions if dimension in d)

                    # Calculate microframe bias for the dimension
                    microframe_bias = weighted_contributions / len(word_contributions)
                    bias_dict[dimension + '_bias'] = microframe_bias

            # Append the results for this article to the results list
            bias_results.append(bias_dict)

        # Convert the results list to a DataFrame
        bias_df = pd.DataFrame(bias_results)
        bias_df = bias_df.set_index('article_id')

        return bias_df

    def calculate_baseline_bias(self, df):
        """
        Compute the corpus-wide baseline value of each microframe dimension.

        Every per-word contribution in the 'word_contributions' column is pooled
        across all articles. The baseline for a dimension is the plain mean of
        the pooled values for that dimension.

        :param df: A DataFrame whose 'word_contributions' column holds, per row, a
            list of dicts of the form {'word': ..., <dimension>: <contribution>, ...}.
        :return: A dict mapping each dimension name to its baseline mean, in
            first-seen order. The dict is empty when no row has contributions.
        """
        pooled = {}

        for contributions in df["word_contributions"]:
            # Rows without contributions add nothing to the pool.
            if not contributions:
                continue

            # The first entry decides which keys count as dimensions for this article.
            for dim_name in (key for key in contributions[0].keys() if key != "word"):
                if dim_name not in pooled:
                    pooled[dim_name] = []
                pooled[dim_name].extend(
                    entry[dim_name] for entry in contributions if dim_name in entry
                )

        # Sum each complete list at once so the float result matches a single sum() call.
        return {
            dim_name: sum(values) / len(values)
            for dim_name, values in pooled.items()
        }

    def calculate_microframe_intensity(self, df):
        """
        Calculate the microframe intensity for each article in the DataFrame.

        The intensity of an article on a dimension is the sum of squared deviations
        of its per-word contributions from the corpus-wide baseline, divided by the
        article's number of word entries. The baseline comes from
        ``self.calculate_baseline_bias(df)``; this is the second moment around it.

        :param df: A DataFrame with an 'article_id' column and a
            'word_contributions' column holding, per row, a list of dicts of the
            form {'word': ..., <dimension>: <contribution>, ...}.
        :return: A DataFrame indexed by 'article_id' with one '<dimension>_intensity'
            column per dimension. Articles without word contributions get a row of
            NaNs; an empty input yields an empty DataFrame whose index is still
            named 'article_id'.
        """
        # The baseline is computed once over the whole corpus.
        baseline_bias = self.calculate_baseline_bias(df)

        # Collect plain dicts and build the DataFrame once; pd.concat inside the
        # loop would copy the growing frame on every row (quadratic).
        intensity_records = []

        for article_id, word_contributions in zip(df["article_id"], df["word_contributions"]):
            record = {"article_id": article_id}

            # An empty list would make word_contributions[0] raise IndexError;
            # such articles keep only their id and end up with NaN intensities.
            if word_contributions:
                n_entries = len(word_contributions)
                dimensions = [k for k in word_contributions[0].keys() if k != "word"]
                for dimension in dimensions:
                    deviations_squared = sum(
                        (d[dimension] - baseline_bias[dimension]) ** 2
                        for d in word_contributions
                        if dimension in d
                    )
                    record[dimension + "_intensity"] = deviations_squared / n_entries

            intensity_records.append(record)

        # pd.DataFrame([]) has no 'article_id' column, so build the empty frame explicitly.
        intensity_df = (
            pd.DataFrame(intensity_records)
            if intensity_records
            else pd.DataFrame(columns=["article_id"])
        )
        return intensity_df.set_index("article_id")

    def calculate_all_metrics(self, df, antonym_pairs_embeddings, include_bias=False):
        """
        Run the FrameAxis pipeline: word contributions, microframe intensity and,
        optionally, microframe bias for each article.

        :param df: A DataFrame containing articles with 'article_id' and 'text' columns.
        :param antonym_pairs_embeddings: A dictionary containing the embeddings for antonym pairs for each dimension.
        :param include_bias: If False (the default and the previous behaviour), only
            the intensity DataFrame is returned and bias is not computed. If True,
            bias and intensity are merged into one DataFrame with the structure
            article_id | dim1_bias | dim1_intensity | dim2_bias | ...
        :return: With include_bias=False, the DataFrame returned by
            ``calculate_microframe_intensity``. With include_bias=True, the merged
            DataFrame described above, with article_id as a regular column.
        """
        logger.info("Calculating all metrics...")
        logger.info("Step 1: Calculating word contributions...")
        # Step 1: per-article, per-word contributions for every dimension.
        word_contributions_df = self.calculate_word_contributions(
            df, antonym_pairs_embeddings
        )

        # Step 2: bias is only used in the merged output, so skip it otherwise.
        microframe_bias_df = None
        if include_bias:
            logger.info("Step 2: Calculating microframe bias...")
            microframe_bias_df = self.calculate_microframe_bias(word_contributions_df)

        logger.info("Step 3: Calculating microframe intensity...")
        # Step 3: per-article intensity for each dimension.
        microframe_intensity_df = self.calculate_microframe_intensity(
            word_contributions_df
        )

        if not include_bias:
            return microframe_intensity_df

        logger.info("Step 4: Merging bias and intensity dataframes...")
        # Both frames are indexed by article_id, so join on the index (inner join).
        merged_df = microframe_bias_df.merge(
            microframe_intensity_df, left_index=True, right_index=True
        )
        logger.info(f"Merged metrics shape: {merged_df.shape}")

        # Interleave columns per dimension: <dim>_bias, <dim>_intensity, ...
        # Slice off the suffix rather than str.replace, which would also change
        # any "_bias" inside a dimension name.
        dimensions = [
            col[: -len("_bias")]
            for col in microframe_bias_df.columns
            if col.endswith("_bias")
        ]
        final_columns = ["article_id"]
        for dimension in dimensions:
            final_columns.extend([dimension + "_bias", dimension + "_intensity"])

        return merged_df.reset_index()[final_columns]

    def get_embeddings_for_text(
        self, text, remove_stopwords=True, remove_non_words=True
    ):
        """
        Compute one contextual embedding per word of ``text`` and return the kept words.

        The text is tokenized without special tokens and run through ``self.model``.
        Each word's embedding is the mean of the last-hidden-state vectors of its
        sub-word tokens. Words are visited in order of first appearance in
        ``word_ids()``.

        :param text: The text to embed. Batch index 0 is used throughout, so a
            single string is assumed.
        :param remove_stopwords: Skip words whose normalized form is in ``self.stopwords``.
        :param remove_non_words: Skip words consisting only of characters in
            ``self.non_word_characters``. This includes words that normalize to "".
        :return: ``(words, embeddings)``. ``words`` lists the normalized words
            (lower-cased, with surrounding punctuation and whitespace stripped).
            ``embeddings`` stacks one row per kept word, or is an empty tensor if
            no word was kept.
        """
        # NOTE(review): truncation=True presumably cuts texts longer than the
        # tokenizer's max length, so words past that point are not returned — confirm.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False,
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Drop the batch dimension.
        embeddings = outputs.last_hidden_state.squeeze(0)
        token_ids = inputs.input_ids[0]

        filtered_embeddings = []
        filtered_words = []

        # word_ids() maps each token to its word index (None for special tokens).
        # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
        # set gives no ordering guarantee.
        for w_idx in dict.fromkeys(inputs.word_ids()):
            if w_idx is None:
                continue

            word_tokens_range = inputs.word_to_tokens(w_idx)
            if word_tokens_range is None:
                continue
            start, end = word_tokens_range

            # Rebuild the word from its tokens so it can be checked against the filters.
            word = self.tokenizer.decode(token_ids[start:end])
            normalized_word = word.lower().strip(string.punctuation).strip()

            if remove_stopwords and normalized_word in self.stopwords:
                continue

            # all() over an empty string is True, so words that normalize to "" are dropped too.
            if remove_non_words and all(
                char in self.non_word_characters for char in normalized_word
            ):
                continue

            # Average the sub-word token embeddings into one vector per word.
            filtered_embeddings.append(embeddings[start:end].mean(dim=0))
            filtered_words.append(normalized_word)

        filtered_embeddings_tensor = (
            torch.stack(filtered_embeddings)
            if filtered_embeddings
            else torch.tensor([])
        )

        if filtered_embeddings_tensor.numel() == 0:
            logger.info(
                f"No embeddings found for input text: {text}, after filtering: {filtered_words}"
            )

        return filtered_words, filtered_embeddings_tensor

    def get_embeddings_for_words(self, sentence, words):
        """
        Look up contextual embeddings for selected words of a sentence.

        The sentence is embedded with ``get_embeddings_for_text`` with stopword and
        non-word filtering disabled. Each requested word is then matched against
        the normalized sentence words.

        :param sentence: The sentence that provides the context.
        :param words: Words to look up. They are compared as-is against the
            normalized (lower-cased, punctuation-stripped) sentence words.
        :return: A dict mapping each requested word found in the sentence to the
            embedding of its FIRST occurrence. Words not in the sentence are omitted.
        """
        sentence_words, word_embeddings = self.get_embeddings_for_text(
            sentence, remove_stopwords=False, remove_non_words=False
        )

        # Position of the first occurrence of each sentence word; later repeats are ignored.
        first_position = {}
        for position, token in enumerate(sentence_words):
            first_position.setdefault(token, position)

        return {
            word: word_embeddings[first_position[word]]
            for word in words
            if word in first_position
        }

    def get_frameaxis_data(self):
        """
        Return the FrameAxis metrics DataFrame, computing it or loading it from cache.

        The metrics are recomputed when ``self.force_recalculate`` is set, or when
        ``self.dataframe_path`` is unset or does not exist. After recomputation the
        result is written to ``self.dataframe_path`` (if set) in the format named by
        ``self.save_type`` ("csv", "json" or "pickle"). A cached file is read back
        with the reader that matches ``self.save_type``; any other value falls
        back to pickle.

        Returns:
        pd.DataFrame: DataFrame with FrameAxis metrics, as returned by
        ``calculate_all_metrics``.
        """
        # Without a cache file there is nothing to load, so fall back to computing.
        if not self.force_recalculate and (
            not self.dataframe_path or not os.path.exists(self.dataframe_path)
        ):
            self.force_recalculate = True

        if self.force_recalculate:
            logger.info("Calculating FrameAxis Embeddings")

            antonym_pairs_embeddings = self.precompute_antonym_embeddings()

            frameaxis_df = self.calculate_all_metrics(self.df, antonym_pairs_embeddings)

            if self.dataframe_path:
                if self.save_type == "csv":
                    # Keep the index: the metrics frame stores article_id there.
                    frameaxis_df.to_csv(self.dataframe_path)
                elif self.save_type == "json":
                    frameaxis_df.to_json(self.dataframe_path)
                elif self.save_type == "pickle":
                    with open(self.dataframe_path, "wb") as f:
                        pickle.dump(frameaxis_df, f)
                else:
                    logger.warning(
                        f"Unknown save_type {self.save_type!r}; FrameAxis data was not saved."
                    )

            return frameaxis_df

        logger.info("Loading FrameAxis Embeddings")
        # Read the cache with the reader that matches how it was written.
        if self.save_type == "csv":
            return pd.read_csv(self.dataframe_path, index_col=0)
        if self.save_type == "json":
            return pd.read_json(self.dataframe_path)
        # NOTE: pickle.load can execute arbitrary code; only load cache files you created.
        with open(self.dataframe_path, "rb") as f:
            return pickle.load(f)


In [160]:
# Build the FrameAxis processor for the `df_first_50` sample (assumed to come from an
# earlier cell not shown here), using the antonym pairs in ../../data/axis/mft.json.
# force_recalculate=True makes get_frameaxis_data() ignore any existing frameaxis.pkl
# and recompute; the result is then written back there as a pickle (save_type).
# NOTE(review): the exact role of dim_names is defined in the class constructor — confirm.
frameaxis_preprocessor = FrameAxisProcessor(
    df_first_50,
    path_antonym_pairs="../../data/axis/mft.json",
    dataframe_path="frameaxis.pkl",
    force_recalculate=True,
    save_type="pickle",
    dim_names=["virtue", "vice"],
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\elias\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-d

In [161]:
# Compute (or load from cache) the per-article FrameAxis metrics for the sample.
fx_df = frameaxis_preprocessor.get_frameaxis_data()

Generating antonym embeddings: 100%|██████████| 50/50 [00:06<00:00,  7.73it/s]
Generating average embeddings:   0%|          | 0/5 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Generating average embeddings:  60%|██████    | 3/5 [00:00<00:00, 16.44it/s]
[A
[A
[A
[A
[A
[A
[A
Generating average embeddings: 100%|██████████| 5/5 [00:00<00:00, 15.74it/s]
Generating microframes: 100%|██████████| 5/5 [00:00<00:00, 1248.97it/s]
Calculating Word Contributions:  62%|██████▏   | 31/50 [00:03<00:02,  7.91it/s]

In [None]:
# NOTE(review): this cell has not been executed (In [None]). The output stored below is
# a list of per-word contribution dicts from an earlier run, not a DataFrame head;
# re-run after the previous cell completes to refresh it.
fx_df.head()

[{'word': 'immigrants',
  'care': 0.3375058174133301,
  'loyalty': -0.2993716597557068,
  'authority': 0.006949363276362419,
  'fairness': 0.0,
  'sanctity': -0.37106406688690186},
 {'word': 'without',
  'care': 0.42150604724884033,
  'loyalty': -0.3720255494117737,
  'authority': 0.009534255601465702,
  'fairness': 0.0,
  'sanctity': -0.45589500665664673},
 {'word': 'hope',
  'care': 0.4013954699039459,
  'loyalty': -0.36111950874328613,
  'authority': 0.020152224227786064,
  'fairness': 0.0,
  'sanctity': -0.40954044461250305},
 {'word': 'need',
  'care': 0.419064998626709,
  'loyalty': -0.38503146171569824,
  'authority': 0.020468231290578842,
  'fairness': 0.0,
  'sanctity': -0.40951722860336304},
 {'word': 'help',
  'care': 0.4097133278846741,
  'loyalty': -0.3968883752822876,
  'authority': 0.042825471609830856,
  'fairness': 0.0,
  'sanctity': -0.42373716831207275},
 {'word': 'entering',
  'care': 0.3933647871017456,
  'loyalty': -0.3986949920654297,
  'authority': 0.10797394067