In [1]:
import os
import pandas as pd
import numpy as np

import random
from itertools import product

from sklearn.metrics import pairwise_distances
from sklearn.model_selection import StratifiedGroupKFold

from datasets import Dataset, DatasetDict

from sentence_transformers import (
    SentenceTransformer, 
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer, 
    losses,
    SentenceTransformerModelCardData
)
from sentence_transformers.evaluation import TripletEvaluator
from transformers import EarlyStoppingCallback

  from .autonotebook import tqdm as notebook_tqdm


## load and prepare the data

In [2]:
data_path = '../../data/annotations/group_mention_categorization'
fp = os.path.join(data_path, 'consolidated_annotations.tsv')
# the consolidated annotations are stored as a tab-separated file
df = pd.read_table(fp)

In [3]:
attributes = ['universal_attributes', 'non-economic_attributes', 'economic_attributes']
# `.copy()` detaches the filtered frame from the full annotation table, so the column
# assignments below (and in later cells) do not raise SettingWithCopyWarning
df = df[df.q_id.isin(attributes)].copy()
df['attribute'] = df.q_id.str.removesuffix('_attributes')
attributes = [a.removesuffix('_attributes') for a in attributes]

# unique (mention, text) pairs that received attribute annotations
texts_df = df[['mention_id', 'text', 'mention']].drop_duplicates()
len(texts_df)

449

In [4]:
# missing categories become '' so the combination key reads e.g. 'universal: '
df['category'] = df['category'].fillna('')
df['attribute_combination'] = df['attribute'] + ": " + df['category']

# pivot to one row per mention and one column per attribute combination
df_wide = df.pivot_table(
    index=['mention_id', 'mention'],
    columns='attribute_combination',
    values='label',
    aggfunc='first',
).reset_index()
df_wide.columns.name = None
# combinations never recorded for a mention count as 'No'
df_wide = df_wide.fillna('No')

# label columns = everything after the two index columns
cols = df_wide.columns[2:].to_list()

# binarize 'Yes' -> 1, anything else -> 0 as proper integer columns
# (vectorized; avoids object-dtype storage and the pandas>=2.1-only `DataFrame.map`)
df_wide[cols] = df_wide[cols].eq('Yes').astype(int)

In [5]:
# display names for each label column: a trailing ': ' (empty category) is stripped, e.g. 'universal: ' -> 'universal'
label_names = [c[:-2] if c.endswith(': ') else c for c in cols]
# per mention, the tuple of active label names (in column order)
df_wide['labels'] = df_wide[cols].apply(
    lambda r: tuple(name for name, c in zip(label_names, cols) if r[c] == 1),
    axis=1,
)
cnts = df_wide['labels'].value_counts()
cnts[cnts > 1]

labels
(universal,)                                                                                                             55
(economic: occupation/profession,)                                                                                       54
(non-economic: shared values/mentalities,)                                                                               41
(non-economic: nationality,)                                                                                             37
(non-economic: age,)                                                                                                     30
(economic: income/wealth/economic status,)                                                                               28
(economic: employment status,)                                                                                           18
(non-economic: age, non-economic: family)                                                                                13
(

## construct contrastive triplets

In [6]:
# pairwise Jaccard *distance* between mentions' binary label vectors
# (0.0 = identical label sets, 1.0 = no label in common)
sims = pairwise_distances(df_wide[cols].values.astype(bool), metric='jaccard')

mentions = df_wide['mention']

# gather (anchor, positive, negative, anchor_labels) triplets
triplets = []
for i, row in enumerate(sims):

    mention = mentions.iloc[i]
    label = df_wide.labels.iloc[i]

    # positives: other mentions with exactly the same label set (distance 0), excluding the anchor row itself
    positive_idxs = [j for j in np.flatnonzero(row == 0.0) if j != i]
    if not positive_idxs:
        continue
    positives = mentions.iloc[positive_idxs]
    positives = positives[positives != mention].to_list()  # drop exact-duplicate mention strings

    # negatives: mentions sharing no label with the anchor (distance 1)
    negatives = mentions.iloc[np.flatnonzero(row == 1.0)]
    negatives = negatives[negatives != mention].to_list()  # sanity check against duplicates

    # balance proportions: as many negatives as positives, capped at the number available
    # (`random.sample` raises ValueError when asked for more items than the population holds).
    # A fresh Random(42) per anchor keeps the draw reproducible per anchor.
    n_neg = min(len(positives), len(negatives))
    negatives = random.Random(42).sample(negatives, n_neg)

    triplets.extend((mention, p, n, label) for p, n in product(positives, negatives))

len(triplets)

471020

## split into train, dev, and test folds

In [7]:
sgk_folder = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)

# per-triplet metadata: the anchor mention is the CV group (an anchor never spans folds),
# and the anchor's label tuple, joined into one string, is the stratification target
anchors = [triplet[0] for triplet in triplets]
labels = [triplet[-1] for triplet in triplets]
label_strs = ['; '.join(labs) for labs in labels]
# keep only each fold's held-out indices
# NOTE(review): grouping is by anchor only; a test anchor's positives/negatives may still occur as
#               anchors in the training folds
idxs = [held_out for _train, held_out in sgk_folder.split(triplets, label_strs, groups=anchors)]
# NOTE: the warning is expected because some label combinations are unique to single mentions



In [8]:
features = ['anchors', 'positives', 'negatives', 'labels']
# fold 0 -> test, fold 1 -> dev, all remaining folds -> train
df_test = pd.DataFrame([triplets[i] for i in idxs[0]], columns=features)
df_dev = pd.DataFrame([triplets[i] for i in idxs[1]], columns=features)
df_train = pd.DataFrame(
    [triplets[i] for fold in idxs[2:] for i in fold],
    columns=features,
)

In [9]:
# NOTE: to ensure that the model sees diverse examples, we order the data in a way that prioritizes diversity:
#       within each label combination, rows are numbered 0, 1, 2, ... (`r_`) in their current order, and the
#       data is sorted by this rank so that the first occurrences of all label combinations come first.


def order_by_label_rank(frame, shuffle_within_rank=False, random_state=42):
    """Return a copy of `frame` sorted by each row's occurrence rank within its 'labels' group.

    Args:
        frame: DataFrame with a hashable 'labels' column.
        shuffle_within_rank: if True, rows sharing the same rank are shuffled
            (with `random_state`) instead of keeping their sorted order.
        random_state: seed passed to `DataFrame.sample` for every rank group.

    Returns:
        A new DataFrame with an added integer column 'r_' and a fresh RangeIndex.
    """
    ranked = frame.copy()
    ranked['r_'] = ranked.groupby('labels').cumcount()
    # ... make all first occurrences of a label combination come first ...
    ranked = ranked.sort_values(['r_'])
    if shuffle_within_rank:
        # ... and only shuffle within a rank. Selecting the columns explicitly (including the
        # grouping column 'r_') keeps 'r_' in the result and avoids pandas' deprecation warning
        # about `apply` operating on the grouping columns.
        ranked = ranked.groupby('r_')[ranked.columns.tolist()].apply(
            lambda g: g.sample(frac=1, random_state=random_state)
        )
    return ranked.reset_index(drop=True)


df_train = order_by_label_rank(df_train, shuffle_within_rank=True)
# dev and test are only sorted by rank (rows are not shuffled within ranks)
df_dev = order_by_label_rank(df_dev)
df_test = order_by_label_rank(df_test)

  df_train = df_train.groupby('r_').apply(lambda x: x.sample(frac=1, random_state=42)).reset_index(drop=True)


In [10]:
# highest within-combination rank per label combination = (number of training triplets - 1)
# NOTE: we could consider balancing here, but due to the sorting of data and efficiency of contrastive learning, early stopping will avoid overfitting to overrepresented label combinations
df_train.groupby('labels')[['r_']].max().sort_values('r_', ascending=False)

Unnamed: 0_level_0,r_
labels,Unnamed: 1_level_1
"(economic: occupation/profession,)",120156
"(universal,)",103101
"(non-economic: shared values/mentalities,)",52799
"(non-economic: nationality,)",31973
"(non-economic: age,)",19140
"(economic: income/wealth/economic status,)",16766
"(economic: employment status,)",3498
"(non-economic: family,)",1004
"(non-economic: age, non-economic: family)",817
"(non-economic: health,)",545


## prepare training

In [11]:
splits = {'train': df_train, 'dev': df_dev, 'test': df_test}
dataset = DatasetDict({name: Dataset.from_pandas(frame) for name, frame in splits.items()})
# 'labels' and 'r_' were only needed for ordering; keep anchors/positives/negatives
dataset = dataset.remove_columns(['labels', 'r_'])

In [12]:
# # DEPRECATED (yields no efficiency gain)
# from sentence_transformers.similarity_functions import SimilarityFunction
# from contextlib import nullcontext
# from sklearn.metrics.pairwise import paired_cosine_distances
# import csv

# import logging
# logger = logging.getLogger(__name__)

# class CustomTripletEvaluator(TripletEvaluator):
#     """
#     Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
#     Checks if distance(sentence, positive_example) < distance(sentence, negative_example).

#     Example:
#         ::

#             from sentence_transformers import SentenceTransformer
#             from sentence_transformers.evaluation import TripletEvaluator
#             from datasets import load_dataset

#             # Load a model
#             model = SentenceTransformer('all-mpnet-base-v2')

#             # Load a dataset with (anchor, positive, negative) triplets
#             dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

#             # Initialize the TripletEvaluator using anchors, positives, and negatives
#             triplet_evaluator = TripletEvaluator(
#                 anchors=dataset[:1000]["anchor"],
#                 positives=dataset[:1000]["positive"],
#                 negatives=dataset[:1000]["negative"],
#                 name="all-nli-dev",
#             )
#             results = triplet_evaluator(model)
#             '''
#             TripletEvaluator: Evaluating the model on the all-nli-dev dataset:
#             Accuracy Cosine Distance:        95.60
#             '''
#             print(triplet_evaluator.primary_metric)
#             # => "all-nli-dev_max_accuracy"
#             print(results[triplet_evaluator.primary_metric])
#             # => 0.956
#     """

#     def __init__(
#         self,
#         **kwargs,
#     ):
#         """
#         Initializes a TripletEvaluator object.

#         Args:
#             anchors (List[str]): Sentences to check similarity to. (e.g. a query)
#             positives (List[str]): List of positive sentences
#             negatives (List[str]): List of negative sentences
#             main_distance_function (Union[str, SimilarityFunction], optional):
#                 The distance function to use. If not specified, use cosine similarity,
#                 dot product, Euclidean, and Manhattan. Defaults to None.
#             name (str): Name for the output. Defaults to "".
#             batch_size (int): Batch size used to compute embeddings. Defaults to 16.
#             show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
#             write_csv (bool): Write results to a CSV file. Defaults to True.
#             truncate_dim (int, optional): The dimension to truncate sentence embeddings to.
#                 `None` uses the model's current truncation dimension. Defaults to None.
#         """
#         if kwargs['main_distance_function'] is None:
#             kwargs['main_distance_function'] = 'cosine'
#         elif kwargs['main_distance_function'] != 'cosine':
#             raise NotImplementedError("Only cosine similarity is supported at the moment.")
#         super().__init__(**kwargs)
#         self.csv_headers = ["epoch", "steps", "accuracy_cosinus"]
        
#     def __call__(
#         self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
#     ) -> dict[str, float]:
#         if epoch != -1:
#             if steps == -1:
#                 out_txt = f" after epoch {epoch}"
#             else:
#                 out_txt = f" in epoch {epoch} after {steps} steps"
#         else:
#             out_txt = ""
#         if self.truncate_dim is not None:
#             out_txt += f" (truncated to {self.truncate_dim})"

#         logger.info(f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")

#         num_triplets, num_correct_cos_triplets = 0, 0 

#         with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
#             embeddings_anchors = model.encode(
#                 self.anchors,
#                 batch_size=self.batch_size,
#                 show_progress_bar=self.show_progress_bar,
#                 convert_to_numpy=True,
#             )
#             embeddings_positives = model.encode(
#                 self.positives,
#                 batch_size=self.batch_size,
#                 show_progress_bar=self.show_progress_bar,
#                 convert_to_numpy=True,
#             )
#             embeddings_negatives = model.encode(
#                 self.negatives,
#                 batch_size=self.batch_size,
#                 show_progress_bar=self.show_progress_bar,
#                 convert_to_numpy=True,
#             )

#         # Cosine distance
#         pos_cos_distance = paired_cosine_distances(embeddings_anchors, embeddings_positives)
#         neg_cos_distances = paired_cosine_distances(embeddings_anchors, embeddings_negatives)

#         for idx in range(len(pos_cos_distance)):
#             num_triplets += 1

#             if pos_cos_distance[idx] < neg_cos_distances[idx]:
#                 num_correct_cos_triplets += 1

#         accuracy_cos = num_correct_cos_triplets / num_triplets

#         logger.info(f"Accuracy Cosine Distance: \t{accuracy_cos * 100:.2f}")

#         if output_path is not None and self.write_csv:
#             csv_path = os.path.join(output_path, self.csv_file)
#             if not os.path.isfile(csv_path):
#                 with open(csv_path, newline="", mode="w", encoding="utf-8") as f:
#                     writer = csv.writer(f)
#                     writer.writerow(self.csv_headers)
#                     writer.writerow([epoch, steps, accuracy_cos])

#             else:
#                 with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
#                     writer = csv.writer(f)
#                     writer.writerow([epoch, steps, accuracy_cos])

#         self.primary_metric = "cosine_accuracy"
#         metrics = {"cosine_accuracy": accuracy_cos}
#         metrics = self.prefix_name_to_metrics(metrics, self.name)
#         self.store_metrics_in_model_card_data(model, metrics)
#         return metrics

In [13]:
# base checkpoint to fine-tune, placed on the 'mps' device
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SentenceTransformer(model_name_or_path=model_id, device='mps')

In [48]:
# output location of the fine-tuned model (also the trainer's checkpoint directory)
run_id = 'paraphrase-mpnet-base-v2-social-group-mention-attributes-embedding'
model_path = '../../models'
model_dir = os.path.join(model_path, run_id)

In [14]:
# logging, evaluation, and checkpointing share one cadence, so every evaluated step is also a saved checkpoint
steps_ = 100
training_args = SentenceTransformerTrainingArguments(
    output_dir=model_dir,  # required; checkpoints are written here
    # --- reproducibility ---
    seed=42,
    data_seed=42,
    full_determinism=True,
    # --- optimization ---
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # NOTE(review): recent `transformers` releases deprecate `use_mps_device`; confirm against the installed version
    use_mps_device=True,
    # --- logging & evaluation ---
    report_to="none",
    logging_steps=steps_,
    eval_strategy="steps",
    eval_steps=steps_,
    eval_on_start=True,
    # --- checkpointing & best-model selection (for early stopping) ---
    save_strategy="steps",
    save_steps=steps_,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="cosine_accuracy",
    greater_is_better=True,
)



In [15]:
# NOTE: evaluation is time consuming so we only take (at most) the first 10K dev examples.
#       Dev rows are sorted by within-combination rank, so the leading rows include one triplet
#       per label combination present in dev.
#       `min` guards against `select` raising IndexError when dev holds fewer rows.
n_dev_eval = min(10_000, len(dataset['dev']))
dev_evaluator = TripletEvaluator(
    **dataset['dev'].select(range(n_dev_eval)).to_dict(),
    main_distance_function='cosine',
    batch_size=64,
)

## fine-tune

In [16]:
from sentence_transformers.sampler import BatchSampler, DefaultBatchSampler
from torch.utils.data import SequentialSampler

# NOTE: we need to subclass the trainer to avoid shuffling of the training data before each epoch
#       (we already randomly ordered with a priority for diversity above)
class NoShufflingTrainer(SentenceTransformerTrainer):
    """Trainer whose batch sampler walks a dataset in its stored row order.

    The default implementation draws indices through
    `SubsetRandomSampler(range(len(dataset)), generator=generator)`; here every
    dataloader built through `get_batch_sampler` iterates sequentially instead.
    """

    def get_batch_sampler(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        **kwargs,
    ) -> BatchSampler | None:
        """Return a sequential batch sampler; extra keyword arguments are accepted and ignored."""
        in_order = SequentialSampler(range(len(dataset)))
        return DefaultBatchSampler(in_order, batch_size=batch_size, drop_last=drop_last)

In [17]:
# stop once the monitored metric fails to improve by >= 0.01 for 3 consecutive evaluations
early_stopping = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
trainer = NoShufflingTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    # NOTE(review): the loss *class* (not an instance) is passed; this relies on the trainer
    #               instantiating it with `model` — confirm for the installed sentence-transformers
    loss=losses.TripletLoss,
    evaluator=dev_evaluator,
    callbacks=[early_stopping],
)

In [18]:
# fine-tune; early stopping may end the run before the epoch completes
train_output = trainer.train()
train_output

  0%|          | 0/5488 [00:23<?, ?it/s]

{'eval_cosine_accuracy': 0.7036, 'eval_dot_accuracy': 0.3047, 'eval_manhattan_accuracy': 0.711, 'eval_euclidean_accuracy': 0.7073, 'eval_max_accuracy': 0.711, 'eval_runtime': 23.6101, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0}


  2%|▏         | 100/5488 [02:06<1:23:09,  1.08it/s]

{'loss': 4.3505, 'grad_norm': 5.524215221405029, 'learning_rate': 3.642987249544627e-06, 'epoch': 0.02}


  2%|▏         | 100/5488 [02:27<1:23:09,  1.08it/s]

{'eval_cosine_accuracy': 0.8194, 'eval_dot_accuracy': 0.1877, 'eval_manhattan_accuracy': 0.8224, 'eval_euclidean_accuracy': 0.8234, 'eval_max_accuracy': 0.8234, 'eval_runtime': 21.2983, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.02}


  4%|▎         | 200/5488 [04:07<1:23:10,  1.06it/s] 

{'loss': 3.1213, 'grad_norm': 10.639139175415039, 'learning_rate': 7.285974499089254e-06, 'epoch': 0.04}


  4%|▎         | 200/5488 [04:29<1:23:10,  1.06it/s]

{'eval_cosine_accuracy': 0.9528, 'eval_dot_accuracy': 0.0499, 'eval_manhattan_accuracy': 0.9529, 'eval_euclidean_accuracy': 0.9523, 'eval_max_accuracy': 0.9529, 'eval_runtime': 22.0717, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.04}


  5%|▌         | 300/5488 [05:54<1:11:35,  1.21it/s] 

{'loss': 1.3908, 'grad_norm': 20.23415756225586, 'learning_rate': 1.0928961748633882e-05, 'epoch': 0.05}


  5%|▌         | 300/5488 [06:16<1:11:35,  1.21it/s]

{'eval_cosine_accuracy': 0.9581, 'eval_dot_accuracy': 0.0462, 'eval_manhattan_accuracy': 0.9592, 'eval_euclidean_accuracy': 0.9582, 'eval_max_accuracy': 0.9592, 'eval_runtime': 22.2712, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.05}


  7%|▋         | 400/5488 [07:41<1:10:24,  1.20it/s] 

{'loss': 0.8727, 'grad_norm': 13.966049194335938, 'learning_rate': 1.4571948998178507e-05, 'epoch': 0.07}


  7%|▋         | 400/5488 [08:03<1:10:24,  1.20it/s]

{'eval_cosine_accuracy': 0.9749, 'eval_dot_accuracy': 0.0314, 'eval_manhattan_accuracy': 0.9789, 'eval_euclidean_accuracy': 0.9749, 'eval_max_accuracy': 0.9789, 'eval_runtime': 21.9051, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.07}


  9%|▉         | 500/5488 [09:49<1:27:44,  1.06s/it] 

{'loss': 0.3826, 'grad_norm': 4.889810562133789, 'learning_rate': 1.8214936247723133e-05, 'epoch': 0.09}


  9%|▉         | 500/5488 [10:11<1:27:44,  1.06s/it]

{'eval_cosine_accuracy': 0.9889, 'eval_dot_accuracy': 0.0116, 'eval_manhattan_accuracy': 0.989, 'eval_euclidean_accuracy': 0.988, 'eval_max_accuracy': 0.989, 'eval_runtime': 21.6023, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.09}


 11%|█         | 600/5488 [11:53<1:06:43,  1.22it/s] 

{'loss': 0.0989, 'grad_norm': 0.4789236783981323, 'learning_rate': 1.979348046163191e-05, 'epoch': 0.11}


 11%|█         | 600/5488 [12:15<1:06:43,  1.22it/s]

{'eval_cosine_accuracy': 0.992, 'eval_dot_accuracy': 0.0092, 'eval_manhattan_accuracy': 0.9746, 'eval_euclidean_accuracy': 0.9907, 'eval_max_accuracy': 0.992, 'eval_runtime': 22.2271, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.11}


 13%|█▎        | 700/5488 [13:51<1:13:16,  1.09it/s] 

{'loss': 0.0747, 'grad_norm': 2.7142200469970703, 'learning_rate': 1.938854019032193e-05, 'epoch': 0.13}


 13%|█▎        | 700/5488 [14:12<1:13:16,  1.09it/s]

{'eval_cosine_accuracy': 0.9903, 'eval_dot_accuracy': 0.0105, 'eval_manhattan_accuracy': 0.9831, 'eval_euclidean_accuracy': 0.9904, 'eval_max_accuracy': 0.9904, 'eval_runtime': 21.796, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.13}


 15%|█▍        | 800/5488 [15:43<1:07:34,  1.16it/s] 

{'loss': 0.0594, 'grad_norm': 2.8423514366149902, 'learning_rate': 1.8983599919011947e-05, 'epoch': 0.15}


 15%|█▍        | 800/5488 [16:05<1:07:34,  1.16it/s]

{'eval_cosine_accuracy': 0.9914, 'eval_dot_accuracy': 0.0089, 'eval_manhattan_accuracy': 0.9868, 'eval_euclidean_accuracy': 0.9923, 'eval_max_accuracy': 0.9923, 'eval_runtime': 21.8326, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 0.15}


 15%|█▍        | 800/5488 [16:08<1:34:34,  1.21s/it]

{'train_runtime': 968.3625, 'train_samples_per_second': 362.678, 'train_steps_per_second': 5.667, 'train_loss': 1.2938791829347611, 'epoch': 0.15}





TrainOutput(global_step=800, training_loss=1.2938791829347611, metrics={'train_runtime': 968.3625, 'train_samples_per_second': 362.678, 'train_steps_per_second': 5.667, 'total_flos': 0.0, 'train_loss': 1.2938791829347611, 'epoch': 0.1457725947521866})

## Evaluate

In [19]:
# evaluate on the complete test split (slow, but gives the definitive numbers)
test_triplets = dataset['test'].to_dict()
test_evaluator = TripletEvaluator(
    anchors=test_triplets['anchors'],
    positives=test_triplets['positives'],
    negatives=test_triplets['negatives'],
    main_distance_function='cosine',
    batch_size=64,
)
test_evaluator(trainer.model)

{'cosine_accuracy': 0.9854797000258598,
 'dot_accuracy': 0.014455650374967675,
 'manhattan_accuracy': 0.9164339281096457,
 'euclidean_accuracy': 0.9642487716576157,
 'max_accuracy': 0.9854797000258598}

## Save model to disk and clean up

In [27]:
import shutil

# remove the training output directory (intermediate checkpoints) before saving the final model;
# the existence check makes the cell safe to re-run (rmtree raises FileNotFoundError otherwise)
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

In [44]:
# attach model-card metadata to the fine-tuned model before saving it
# NOTE: `model_id` must have the form '<user-or-org>/<name>'; a bare name is discarded (set to None)
#       with a warning, so prefix it with the Hub namespace before publishing.
trainer.model.model_card_data = SentenceTransformerModelCardData(
    language='en',
    model_id=os.path.basename(model_dir),
    model_name=model_id + ' finetuned for social group mention attribute classification',
    # `train_datasets` takes a list of dicts (keys such as 'name' / 'id'), not a bare string
    train_datasets=[
        {'name': 'Social group mention attributes multilabel classifications (Licht & Röth, 2025)'},
    ],
    task_name='text embedding',
)

The provided 'paraphrase-mpnet-base-v2-social-group-mention-embedding' model ID should include the organization or user, such as "tomaarsen/mpnet-base-nli-matryoshka". Setting `model_id` to None.


In [49]:
# save the fine-tuned model held by the trainer (the object the model card was attached to above)
trainer.model.save_pretrained(model_dir)

In [51]:
# all attribute combinations in the filtered annotations, in order of first appearance
pd.unique(df['attribute_combination'])

array(['economic: class membership', 'economic: ecology of group',
       'economic: education level', 'economic: employment status',
       'economic: income/wealth/economic status',
       'economic: occupation/profession', 'economic: other',
       'non-economic: age', 'non-economic: crime',
       'non-economic: ethnicity', 'non-economic: family',
       'non-economic: gender/sexuality', 'non-economic: health',
       'non-economic: nationality', 'non-economic: other',
       'non-economic: place/location', 'non-economic: religion',
       'non-economic: shared values/mentalities', 'universal: '],
      dtype=object)

In [54]:
# move weights off the accelerator, then drop the notebook's references to them
trainer.model.to('cpu')
model.to('cpu')
del trainer
del model