In [1]:
from types import SimpleNamespace

# Notebook-level configuration: the group category whose mentions are ranked below.
args = SimpleNamespace(group_type='social group')

In [2]:
import os

import pandas as pd

from utils.io import read_jsonlines
from utils.parsing import unnest_sequence_annotations
from utils.embedding import E5SentenceEmbedder
from reconstruction_loss_ranker import ReconstructionLossRanker
from utils.qualtrix import make_mention_categorization_block

In [17]:
# Global random seed and repository-relative locations (the repository root is two levels up).
SEED = 42
base_path = os.path.join(os.pardir, os.pardir)
data_path = os.path.join(base_path, 'data')

### Read the mention annotations 

In [4]:
# Every annotation batch is a sub-folder of `annotations_path` holding a `review_annotations.jsonl` file.
annotations_path = os.path.join(data_path, 'annotations', 'group_mention_detection')
# os.listdir returns entries in arbitrary order; sort so that the row order of `df`
# (and everything derived from it) is the same on every run.
jobs = sorted(nm for nm in os.listdir(annotations_path) if nm.startswith('group-mention-annotation-batch-'))
# Resolve the batch folders relative to the directory they were listed from
# (joining them onto `data_path` points at folders that were never listed).
fps = [os.path.join(annotations_path, job, 'review_annotations.jsonl') for job in jobs]


def parse_entry(x):
    """Keep only the 'id', 'text' and 'label' fields of one annotation record."""
    return {k: x[k] for k in ['id', 'text', 'label']}


data = [parse_entry(line) for fp in fps for line in read_jsonlines(fp)]

# One row per annotated mention; `keep_text=True` is passed so the sentence text is retained.
df = pd.DataFrame(unnest_sequence_annotations(data, keep_text=True))

### Read the mention metadata

In [5]:
# Load the machine-translated manifesto sentences and rename the columns used downstream.
fp = os.path.join(data_path, 'manifestos', 'all_manifesto_sentences_translated.tsv')
sentence_df = pd.read_csv(fp, sep='\t').rename(
    columns={'sentence_id': 'text_id', 'text_mt_m2m_100_1.2b': 'text_en'}
)
# the manifesto ID is the part of the sentence ID before the first '-'
sentence_df['manifesto_id'] = sentence_df['text_id'].str.split('-', expand=True)[0]


def _context_window(frame, offsets):
    """Return, per row, the `text_en` values at the given row offsets within the same manifesto.

    Offsets that fall outside a manifesto are padded with '' by `shift` and removed again,
    so rows close to a manifesto boundary get shorter lists.
    """
    shifted = frame.groupby('manifesto_id')['text_en'].shift(offsets, fill_value='')
    windows = pd.Series(shifted.values.tolist(), index=frame.index)
    return windows.apply(lambda texts: [t for t in texts if t != ''])


# for each line, within manifesto, get the two texts occurring before and after the line in separate columns
sentence_df['prev_texts'] = _context_window(sentence_df, [2, 1])
sentence_df['next_texts'] = _context_window(sentence_df, [-1, -2])

In [6]:
# Attach the surrounding-sentence context to the annotations; the left join keeps every annotation row.
context_columns = ['text_id', 'prev_texts', 'next_texts']
df = df.merge(sentence_df[context_columns], how='left', on='text_id')

In [8]:
# Keep only the mentions whose annotated type equals the configured group category.
is_target_type = df['type'] == args.group_type
examples = df.loc[is_target_type].reset_index(drop=True)

### Embed the mentions

In [7]:
# The compute device can be overridden with `args.device`; the default stays 'mps'
# (PyTorch's Apple-GPU backend), which was previously hard-coded.
embedder = E5SentenceEmbedder(device=getattr(args, 'device', 'mps'))

In [9]:
# We leverage the fact that the E5 embedding model is designed for retrieving relevant answers to a query:
#  asking what kind of group a mention refers to should encode its meaning.
prompt_template = "Query: In the context of the sentence '''{text}''', what kind of group does '''{mention}''' refer to?"
# Build the prompts column-wise instead of with a row-wise `apply(axis=1)`: on an empty
# `examples` frame `apply` returns a DataFrame (which has no `.to_list()`), and the
# row-wise version is considerably slower on large frames.
prompts = [
    prompt_template.format(text=text, mention=mention)
    for text, mention in zip(examples['text'], examples['mention'])
]

In [10]:
# Encode the prompts with the E5 embedder, passing batch_size=32.
# NOTE(review): presumably yields one embedding per prompt, in prompt order — `rank` reads
# `embeddings.shape[0]` and the later `examples.iloc[idxs]` lookup relies on that; confirm.
embeddings = embedder.encode(prompts, batch_size=32)

  0%|          | 0/145 [00:00<?, ?it/s]

In [22]:
from typing import Literal
from transformers import set_seed
from utils import clean_memory

def rank(
        embeddings,
        epochs: int=25_000,
        seed: int=42,
        device: Literal['cuda', 'mps', 'cpu']='cpu', 
        verbose: bool=False,
    ):
    """Fit a ReconstructionLossRanker on `embeddings` and return the indices it produces.

    Args:
        embeddings: array-like with a `.shape` attribute; passed to the ranker as `data`.
        epochs: forwarded to the ranker as `num_epochs`.
        seed: passed to `set_seed` and to the ranker.
        device: device the ranker is created on and that `clean_memory` is called for.
        verbose: forwarded to `ranker.fit`.

    Returns:
        The first element of the pair returned by `ranker.fit` (presumably row indices
        in ranked order — confirm against ReconstructionLossRanker).

    The ranker is moved to the CPU and device memory is cleaned even when fitting fails
    or is interrupted, so an aborted run does not keep the model on the accelerator.
    """
    set_seed(seed)
    # NOTE(review): hdim is the number of rows (shape[0]), not the embedding width —
    # confirm this is what ReconstructionLossRanker expects.
    ranker = ReconstructionLossRanker(
        hdim=embeddings.shape[0],
        num_epochs=epochs,
        log_n_steps=2_500,
        device=device,
        seed=seed,
    )
    try:
        idxs, _ = ranker.fit(data=embeddings, verbose=verbose)
    finally:
        # always release the model, also after an exception or KeyboardInterrupt
        ranker.cpu()
        del ranker
        clean_memory(device=device)

    return idxs

In [23]:
# Pass the notebook-wide SEED explicitly so the configured seed is the one actually used
# (it was defined but never referenced; rank() silently fell back to its own default).
idxs = rank(embeddings, epochs=20_000, seed=SEED, verbose=True)

[2024-09-12 20:21:32] epoch [1/20000], loss: 47.1449
[2024-09-12 20:25:41] epoch [2501/20000], loss: 40.4800
[2024-09-12 20:29:48] epoch [5001/20000], loss: 34.6800
[2024-09-12 20:33:56] epoch [7501/20000], loss: 28.8878
[2024-09-12 20:38:23] epoch [10001/20000], loss: 23.1037
[2024-09-12 20:42:30] epoch [12501/20000], loss: 17.3286
[2024-09-12 20:46:32] epoch [15001/20000], loss: 11.5663
[2024-09-12 20:48:37] Final loss after 16300 epochs: 0.0116


In [27]:
# Preview the first 50 mentions in the ranker's order. Columns are selected by name rather
# than by position ([5, 6]) so the preview does not break when columns are added or reordered.
examples.iloc[idxs][['mention', 'text']].head(50)

Unnamed: 0,mention,text
2883,The workforce in the companies,The workforce in the companies is getting older.
2777,women from poorer countries,The extreme expression of this is when women f...
3027,those who are not taxed,Even those who are not taxed.
1360,citizens,There is a clear link between the freedom of c...
1230,a minority group of the population of this cou...,"The law, which will respect current migration ..."
1976,people with mental illness,This is especially true of people with mental ...
1894,the Italians,First the Italians in access to social service...
2649,any worker,the revision of the status of temporary worker...
3739,nursing staff,A Netherlands to be proud again!16 billion tax...
1617,the students,The individual higher education institution is...


In [25]:
# `idxs` lists row positions in ranked order (it is used as `examples.iloc[idxs]`), so the row at
# position idxs[k] has rank k+1. Assigning `idxs+1` as a column would instead give row i the value
# idxs[i]+1, which is not its rank. Reorder the rows first, then number them.
# Run this cell once per `rank()` call: it reorders `examples` in place of the original frame.
examples = examples.iloc[idxs].reset_index(drop=True)
examples['informativeness_rank'] = range(1, len(examples) + 1)

In [29]:
def _join_context(texts):
    """Join a list of context sentences with newlines.

    Tolerates values that are not lists so the cell is safe to re-run and does not
    crash on mentions without sentence metadata.
    """
    if isinstance(texts, str):
        # already joined by a previous run; joining again would put '\n' between characters
        return texts
    if isinstance(texts, (list, tuple)):
        # skip non-string entries (e.g. NaN from untranslated sentences)
        return '\n'.join(t for t in texts if isinstance(t, str))
    # NaN left by the left merge when a text_id has no match in the sentence table
    return ''


examples['prev_texts'] = examples['prev_texts'].apply(_join_context)
examples['next_texts'] = examples['next_texts'].apply(_join_context)

In [30]:
# Write the ranked mentions to `<data>/intermediate/<group_type>_mentions_ranked.tsv`.
group_type = args.group_type.replace(' ', '_')
out_dir = os.path.join(data_path, 'intermediate')
# create the folder if needed so a long ranking run is not lost to a missing directory
os.makedirs(out_dir, exist_ok=True)
fp = os.path.join(out_dir, f'{group_type}_mentions_ranked.tsv')
examples.to_csv(fp, sep='\t', index=False)