In [1]:
from types import SimpleNamespace

def parse_comma_separated(value):
    """Turn a comma-separated option into a list of stripped, non-empty strings.

    Args:
        value: a string such as ``'social group, political group'``, an already
            split list/tuple of strings, or ``None``.

    Returns:
        The list of non-empty, whitespace-stripped items. ``None``, ``''`` and
        strings containing only commas/whitespace yield ``[]`` (so that a
        truthiness check on the result means "no items given").
    """
    if value is None:
        return []
    items = value.split(',') if isinstance(value, str) else value
    # drop empty items: ''.split(',') == [''] would otherwise be a truthy list
    return [item.strip() for item in items if item.strip()]


args = SimpleNamespace()

args.input_file = '../../data/labeled/manifesto_sentences_predicted_group_mentions_spans.tsv'
args.sentence_text_col = 'sentence_text'
args.mention_text_col = 'text'
args.group_mention_types = 'social group'  # comma-separated list of mention types to keep
args.group_mention_type_col = 'label'

args.model_path = '../../models/social-group-mention-attribute-dimension-classifier'
args.batch_size = 128

args.output_file = '../../data/labeled/manifesto_sentences_predicted_social_group_mentions_with_attribute_dimension_classifications.tsv'

args.test = False
args.verbose = True

# parse into a list; an empty result means "do not filter by mention type"
args.group_mention_types = parse_comma_separated(args.group_mention_types)

In [2]:
import pandas as pd
import torch
from setfit import SetFitModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# read the input file; the field separator is chosen from the file extension
if args.input_file.endswith(('.tsv', '.tab')):
    sep = '\t'
elif args.input_file.endswith('.csv'):
    sep = ','
else:
    raise ValueError('input file must be a tab-separated (.tsv, .tab) or comma-separated (.csv) file')

df = pd.read_csv(args.input_file, sep=sep)

# normalize the mention-type column name, then keep only the requested mention types
if args.group_mention_type_col:
    df = df.rename(columns={args.group_mention_type_col: 'group_type'})
if args.group_mention_types:
    if 'group_type' not in df.columns:
        raise KeyError("no 'group_type' column to filter on; set args.group_mention_type_col")
    df = df[df['group_type'].isin(args.group_mention_types)]
# boolean filtering leaves gaps in the index; renumber rows 0..n-1 so predictions
# attached later (built from a positional array) line up with the right rows
df = df.reset_index(drop=True)

if args.verbose: print(f'{len(df)} rows after filtering by group mention type')

# test mode: subsample to 100 batches' worth of rows (fixed seed for reproducibility)
if args.test:
    n_sample = args.batch_size * 100
    if n_sample < len(df):
        df = df.sample(n=n_sample, random_state=42).reset_index(drop=True)

if args.verbose: print(f'processing {len(df)} texts')

209351
processing 209351 texts


In [4]:
# choose the compute device: CUDA GPU if present, otherwise Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

if args.verbose:
    print('using device:', str(device))

# load the SetFit classifier from disk and place it on the chosen device
classifier = SetFitModel.from_pretrained(args.model_path)
classifier.to(device);

using device: mps


In [48]:
# Each model input is "<sentence><sep_token><mention>", using the tokenizer's own
# separator token. Named sep_token so it does not clobber the file separator `sep`.
sep_token = classifier.model_body.tokenizer.sep_token
if sep_token is None:
    raise ValueError('the model tokenizer defines no sep_token; cannot build sentence/mention inputs')

text_cols = [args.sentence_text_col, args.mention_text_col]
n_missing = int(df[text_cols].isna().any(axis=1).sum())
if n_missing:
    # vectorized concatenation would silently yield NaN for these rows
    raise ValueError(f'{n_missing} rows have a missing sentence or mention text')

# vectorized string concatenation (much faster than a row-wise apply on large frames)
inputs = (df[args.sentence_text_col] + sep_token + df[args.mention_text_col]).tolist()
preds = classifier.predict(inputs, batch_size=args.batch_size, as_numpy=True, use_labels=False)

In [52]:
# one output column per label, ordered by label id so that column i of `preds`
# is written under the label whose id is i
label_cols = sorted(classifier.label2id, key=classifier.label2id.get)
if len(preds) != len(df):
    raise ValueError(f'got {len(preds)} predictions for {len(df)} rows')
# build the frame on df.index: assigning a DataFrame aligns on index, and df's
# index need not be 0..n-1 (e.g. after filtering rows)
df[label_cols] = pd.DataFrame(preds, columns=label_cols, index=df.index)

In [84]:
# save the frame (input columns plus prediction columns) as a UTF-8 TSV without the index
if args.verbose:
    print(f'Writing span-level predictions in TSV format to {args.output_file}')
df.to_csv(args.output_file, encoding='utf-8', index=False, sep='\t')

Writing span-level predictions in TSV format to ../../data/labeled/manifesto_sentences_predicted_social_group_mentions_with_attribute_dimension_classifications.tsv
