In [6]:

# determine if current environment is a python script
# (`__file__` is defined when the code runs as a script/module, but not in a notebook kernel)
is_python_script = '__file__' in globals()

# evaluate below if run as a python script
if is_python_script:
    from argparse import ArgumentParser
    parser = ArgumentParser()

    parser.add_argument('--input_file', type=str, required=True, help='Path to input TSV file containing sentences with predicted group mentions.')
    parser.add_argument('--sentence_text_col', type=str, default='sentence_text', help='Name of the column containing the sentence text.')
    parser.add_argument('--mention_text_col', type=str, default='text', help='Name of the column containing the mention text.')
    parser.add_argument('--group_mention_types', type=str, required=True, help='Comma-separated list of group mention types to classify (e.g., "social group").')
    parser.add_argument('--group_mention_type_col', type=str, default='label', help='Name of the column containing the group mention type labels.')
    
    parser.add_argument('--model_path', type=str, required=True, help='Path to the pre-trained SetFit model for classification.')
    parser.add_argument('--use_span_embeddings', action='store_true', help='Whether to use the custom SetFitModelForSpanClassification model instead of mention and text concatenation or mention-only strategies')
    # `None` must not be listed in `choices`: argparse compares the raw command-line string,
    # so "None" could never match it. Omitting the flag keeps the default (mention text only).
    parser.add_argument('--concat_strategy', type=str, choices=['prefix', 'suffix'], default=None, help='If set, concatenate the mention text as prefix or suffix to the context text using --concat_sep_token (default: use the mention text only)')
    parser.add_argument('--concat_sep_token', type=str, default=': ', help='Separator token to use when concatenating mention text to context text')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for inference.')

    parser.add_argument('--output_file', type=str, required=True, help='Path to output TSV file to save predictions.')

    parser.add_argument('--test', action='store_true', help='If set, run in test mode with a smaller subset of data.')
    parser.add_argument('--verbose', action='store_true', help='If set, print verbose output during processing.')
    
    args = parser.parse_args()
else:
    # interactive (notebook) configuration mirroring the command-line arguments above
    from types import SimpleNamespace

    args = SimpleNamespace()

    args.input_file = './../../../data/labeled/manifesto_sentences_predicted_group_mentions_spans.tsv'
    args.sentence_text_col = 'sentence_text'
    args.mention_text_col = 'text'
    args.group_mention_types = 'social group'
    args.group_mention_type_col = 'label'

    args.model_path = './../../../models/all-mpnet-base-v2_economic-attributes-classifier'
    # args.model_path = './../../../models/all-mpnet-base-v2_noneconomic-attributes-classifier'
    args.use_span_embeddings = False # or True
    args.concat_strategy = None # 'prefix', 'suffix' or None
    args.concat_sep_token = ': '  # separator token for prefix/suffix concatenation

    args.batch_size = 64

    args.output_file = './../../../data/labeled/manifesto_sentences_predicted_social_group_mentions_with_economic_attributes_classifications.tsv'
    # args.output_file = './../../../data/labeled/manifesto_sentences_predicted_social_group_mentions_with_noneconomic_attributes_classifications.tsv'

    args.test = False
    args.verbose = True

# Split the comma-separated group mention types into a list of stripped, non-empty names.
# This must run for BOTH branches: when it ran only in the notebook branch, script mode passed
# a plain string to `Series.isin()` downstream, which only accepts list-like values.
args.group_mention_types = [t.strip() for t in args.group_mention_types.split(',') if t.strip()]

In [7]:
import ast
from pathlib import Path

import pandas as pd
import regex

import torch
from setfit import SetFitModel
from src.finetuning.setfit_extensions import SetFitModelForSpanClassification

In [8]:
if args.verbose:
    print('Loading model')

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

if args.verbose:
    print('using device:', str(device))

# The span-aware model class is only needed for the span-embedding strategy;
# otherwise a plain SetFit model is loaded from the same path.
model_cls = SetFitModelForSpanClassification if args.use_span_embeddings else SetFitModel
classifier = model_cls.from_pretrained(args.model_path)
classifier.to(device);

Loading model
using device: cuda:0


In [11]:
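# Alternative (disabled): load a published model by name (presumably from the Hugging Face Hub)
# instead of the local `args.model_path` used above.
# NOTE: `model_name` is not defined anywhere in this notebook; set it (e.g. "hauke-licht/<model name>")
# before uncommenting.
# import torch
# from setfit import SetFitModel

# # model_name = "hauke-licht/{model_name}"
# device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# classifier = SetFitModel.from_pretrained(model_name, device=device)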


In [16]:

# Sanity check: classify a few hand-written group mentions.
texts = ["working class people", "highly-educated professionals", "people without a stable job"]

# One multi-hot prediction row per example text.
predictions = classifier.predict(texts, as_numpy=True)
print(predictions)

# Decode each row into the names of the labels whose flag is set.
decoded_predictions = [
    [classifier.id2label[idx] for idx in range(len(row)) if row[idx] == 1]
    for row in predictions
]
decoded_predictions

[[1 0 0 0 0 0]
 [0 0 0 0 0 1]
 [0 0 0 1 1 0]]


[['economic__class_membership'],
 ['economic__occupation_profession'],
 ['economic__employment_status', 'economic__income_wealth_economic_status']]

In [10]:
# read the input file (the delimiter is inferred from the case-insensitive file extension)
input_path_lower = args.input_file.lower()
if input_path_lower.endswith(('.tsv', '.tab')):
    sep = '\t'
elif input_path_lower.endswith('.csv'):
    sep = ','
else:
    raise ValueError('input file must be a tab-separated (.tsv, .tab) or comma-separated (.csv) file')

df = pd.read_csv(args.input_file, sep=sep)

# expose the group mention type labels under the fixed column name 'group_type'
if args.group_mention_type_col:
    df = df.rename(columns={args.group_mention_type_col: 'group_type'})

# Keep only the requested group mention types. Also accept a comma-separated string,
# because `Series.isin()` rejects non-list-like values.
group_types = args.group_mention_types
if isinstance(group_types, str):
    group_types = [t.strip() for t in group_types.split(',') if t.strip()]
if group_types:
    if 'group_type' not in df.columns:
        raise KeyError(f"cannot filter by group mention type: column '{args.group_mention_type_col}' not found in {args.input_file}")
    df = df[df['group_type'].isin(group_types)]

# Renumber rows 0..n-1 after filtering so that later column assignments from sequences align
# with row positions. This does not affect the written file (it is saved with index=False).
df = df.reset_index(drop=True)

if args.test:
    # test mode: classify a random subset of at most batch_size*1000 rows
    n_ = args.batch_size*1000
    if n_ < len(df):
        df = df.sample(n=n_, random_state=42).reset_index(drop=True)

if args.verbose: print(f'processing {len(df)} texts')

processing 209351 texts


In [11]:
def _mention_span(mention, sentence):
    """Return the (start, end) character offsets of the first literal occurrence of `mention` in `sentence`.

    Only the first occurrence is used; a mention repeated within a sentence always maps to its first match.

    Raises:
        ValueError: if `mention` does not occur in `sentence` (instead of an opaque
            AttributeError from calling `.span()` on a failed search).
    """
    match = regex.search(regex.escape(mention), sentence)
    if match is None:
        raise ValueError(f'mention {mention!r} not found in sentence {sentence!r}')
    return match.span()


def _parse_span(value):
    """Convert a span read back from a CSV/TSV file into a tuple; pass non-strings through unchanged.

    `pd.read_csv` does not reconstruct tuples, so a stored span such as (3, 10) arrives as the string "(3, 10)".
    """
    if isinstance(value, str):
        return tuple(ast.literal_eval(value))
    return value


if args.use_span_embeddings:
    # using span embedding strategy: the model receives the full sentence plus the mention's character span
    if "span" not in df.columns:
        # a list comprehension (rather than DataFrame.apply) also behaves correctly on an empty frame
        df['span'] = pd.Series(
            [
                _mention_span(mention, sentence)
                for mention, sentence in zip(df[args.mention_text_col], df[args.sentence_text_col])
            ],
            index=df.index,
            dtype=object,
        )
    else:
        # NOTE(review): assumes stored spans use Python tuple/list literal syntax, e.g. "(3, 10)"; confirm against the input file
        df['span'] = df['span'].map(_parse_span)
    df['input'] = classifier._normalize_inputs(texts=df[args.sentence_text_col], spans=df['span'])
elif args.concat_strategy is None:
    # default: just the mention text
    df['input'] = df[args.mention_text_col]
else:
    # using concat strategy; fall back to the tokenizer's separator token when no explicit separator is configured
    sep_tok = classifier.model_body.tokenizer.sep_token if args.concat_sep_token is None else args.concat_sep_token
    if args.concat_strategy == 'prefix':
        df['input'] = df[args.mention_text_col] + sep_tok + df[args.sentence_text_col]
    elif args.concat_strategy == 'suffix':
        df['input'] = df[args.sentence_text_col] + sep_tok + df[args.mention_text_col]
    else:
        raise ValueError(f"Unknown concat strategy: {args.concat_strategy}")

In [12]:
# output column names: the label names from the model's label2id mapping, in its key order
label_cols = list(classifier.label2id.keys())

if len(df) > 0:
    # use_labels=False -> numeric predictions instead of label names.
    # NOTE(review): assumes a multi-label head returning one 0/1 column per label; confirm for single-label models.
    preds = classifier.predict(df['input'].to_list(), batch_size=args.batch_size, as_numpy=True, use_labels=False, show_progress_bar=True)
    df[label_cols] = preds
else:
    # Nothing to classify (e.g. no rows matched the requested group mention types).
    # Skip inference, which is not guaranteed to accept an empty batch, but still add
    # the label columns so the output file keeps the same schema.
    if args.verbose: print('no texts to classify')
    for col in label_cols:
        df[col] = 0
del df["input"]

Batches:   0%|          | 0/3272 [00:00<?, ?it/s]

In [13]:
if args.verbose: print(f'Writing mention-level predicted labels to {args.output_file}')
# create the output directory if needed; otherwise to_csv fails when the directory does not exist yet
Path(args.output_file).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(args.output_file, sep='\t', index=False, encoding='utf-8')

Writing mention-level predicted labels to ./../../../data/labeled/manifesto_sentences_predicted_social_group_mentions_with_noneconomic_attributes_classifications.tsv
