In [None]:
# determine if current environment is a python script
is_python_script = '__file__' in globals()

# evaluate below if run as a python script
if not is_python_script:
    from types import SimpleNamespace
    args = SimpleNamespace()

    args.data_splits_path =  '../../data/annotations/group_mention_categorization/splits/fold01/'
    # args.label_cols = 'economic,noneconomic'
    args.label_cols = 'noneconomic__*'
    
    args.id_col = 'mention_id'
    args.text_col = 'text'
    args.mention_col = 'mention'
    args.span_col = 'span'

    # args.model_name = 'sentence-transformers/all-mpnet-base-v2'
    # args.model_name = "sentence-transformers/all-MiniLM-L6-v2"
    args.model_name = "ibm-granite/granite-embedding-english-r2" # can't use because it uses `pooling_mode_cls_token`
    # args.model_name = "nomic-ai/modernbert-embed-base"
    # # args.model_name = "Alibaba-NLP/gte-modernbert-base" # can't use because it uses `pooling_mode_cls_token`
    # args.model_name = "google/embeddinggemma-300m"
    # args.model_name = "Qwen/Qwen3-Embedding-0.6B"


    args.use_span_embeddings = False # or True
    args.concat_strategy = None # 'prefix', 'suffix' or None
    args.concat_sep_token = ': '  # separator token for prefix/suffix concatenation
    
    args.class_weighting_strategy = 'inverse_proportional'  # or 'balanced' or None
    args.class_weighting_smooth_exponent = 0.5  # default: 0.5

    args.head_learning_rate = 0.001 # default: 0.01
    args.train_batch_sizes = [32, 16] # default
    # args.train_batch_sizes = [32, 8] # for gemma
    # args.train_batch_sizes = [16, 4] # for Qwen3 embedding

    args.body_early_stopping_patience = 2
    args.body_early_stopping_threshold = 0.01
    args.head_early_stopping_patience = 5
    args.head_early_stopping_threshold = 0.015

    strategy = 'span_embedding' if args.use_span_embeddings else 'mention_text' if args.concat_strategy is None else f'concat_{args.concat_strategy}'
    args.save_eval_results_to = f'../../results/classifiers/noneconomic_attributes_classification/model_selection/setfit/{args.model_name.replace("/", "--")}/fold01/{strategy}'
    args.overwrite_results = True
    args.do_eval = True
    args.save_eval_results = True
    args.save_eval_predictions = True
    args.do_test = False
    args.save_test_results = False
    args.save_test_predictions = False

    args.save_model = False
    # args.save_model_to = '../../models/'
    # args.save_model_as = 'social-group-mention-attribute-dimension-classifier-v3'

else: # like __name__ == '__main__'
    
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_splits_path', type=str, required=True, help='Path to data splits directory. Should contain files "train.pkl", "val.pkl", and "test.pkl"')
    parser.add_argument('--label_cols', type=str, required=True, help='Comma-separated list of label column names') # TODO: allow glob patterns
    parser.add_argument('--id_col', type=str, default='mention_id', help='Column name for unique mention IDs')
    parser.add_argument('--text_col', type=str , default='text', help='Column name for mention context text')
    parser.add_argument('--mention_col', type=str , default='mention', help='Column name for mention text')
    parser.add_argument('--span_col', type=str , default='span', help='Column name for mention span (start, end)')
    
    parser.add_argument('--model_name', type=str, required=True, help='Name of the model to use. Must be a sentence-transformers compatible model.')
    parser.add_argument('--use_span_embeddings', action='store_true', help='Whether to use custom SeFitForSpanClassification Trainer instead of mention and text concatenation or mention-only strategies')
    parser.add_argument('--concat_strategy', type=str, choices=[None, 'prefix', 'suffix'], default=None, help='If not None, concatenate the mention text as prefix or suffix to the context text using --concat_sep_token')
    parser.add_argument('--concat_sep_token', type=str, default=': ', help='Separator token to use when concatenating mention text to context text')
    
    parser.add_argument('--class_weighting_strategy', type=str, choices=[None, 'balanced', 'inverse_proportional'], default=None, help='Class weighting strategy to use during training')
    parser.add_argument('--class_weighting_smooth_exponent', type=float, default=None, help='Smoothing exponent to use when computing class weights (only relevant if --class_weighting_strategy is set to "inverse_proportional")')

    parser.add_argument('--head_learning_rate', type=float, default=0.01, help='Learning rate to use for classifier head training')
    parser.add_argument('--train_batch_sizes', type=int, nargs='+', default=[32, 8], help='Tuple of batch sizes to use for embedding model and classifier training, respectively')

    parser.add_argument('--body_early_stopping_patience', type=int, default=2, help='Early stopping patience for sentence transformer finetuning')
    parser.add_argument('--body_early_stopping_threshold', type=float, default=0.01, help='Early stopping threshold for sentence transformer finetuning')
    parser.add_argument('--head_early_stopping_patience', type=int, default=5, help='Early stopping patience for classifier head finetuning')
    parser.add_argument('--head_early_stopping_threshold', type=float, default=0.015, help='Early stopping threshold for classifier head finetuning')

    parser.add_argument('--save_eval_results_to', type=str, required=True, help='Directory to save evaluation results to')
    parser.add_argument('--overwrite_results', action='store_true', help='Whether to overwrite existing evaluation results')
    parser.add_argument('--do_eval', action='store_true', help='Whether to perform evaluation on the validation set')
    parser.add_argument('--save_eval_results', action='store_true', help='Whether to save evaluation results to disk')
    parser.add_argument('--save_eval_predictions', action='store_true', help='Whether to save evaluation predictions to disk')
    parser.add_argument('--do_test', action='store_true', help='Whether to perform testing on the test set')
    parser.add_argument('--save_test_results', action='store_true', help='Whether to save test results to disk')
    parser.add_argument('--save_test_predictions', action='store_true', help='Whether to save test predictions to disk')
    
    parser.add_argument('--save_model', action='store_true', help='Whether to save the trained model to disk')
    parser.add_argument('--save_model_to', type=str, help='Directory to save the trained model to')
    parser.add_argument('--save_model_as', type=str, help='Name to save the trained model as')
    
    args = parser.parse_args()

## Setup

In [2]:
import os
from pathlib import Path
import shutil
import warnings
import json

import numpy as np
np.set_printoptions(precision=4, suppress=True)
import pandas as pd
import regex

import torch
torch.set_float32_matmul_precision('high')
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, set_seed
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # to enable deterministic behavior with CuBLAS
SEED = 42
set_seed(SEED, deterministic=True) # for reproducibility

# default setfit body and head
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead

# class weight head
from src.finetuning.setfit_extensions.class_weights_head import (
    compute_class_weights,
    SetFitHeadWithClassWeights
)
# early stopping model, training args, and trainer
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping, 
    EarlyStoppingTrainingArguments,
    EarlyStoppingCallback,
    SetFitEarlyStoppingTrainer
)
# span embedding model, head, and trainer
from src.finetuning.setfit_extensions.span_embedding import (
    SentenceTransformerForSpanEmbedding,
    SetFitModelForSpanClassification,
    SetFitTrainerForSpanClassification,
)

from sklearn.metrics import classification_report
# from utils.metrics import *

In [3]:
def model_init(
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        num_classes: int = 2,
        class_weights: np._typing.NDArray = None,
        multilabel: bool = False,
        use_span_embedding: bool = False,
        body_kwargs: dict = {},
        head_kwargs: dict = {},
        model_kwargs: dict = {},
    ) -> SetFitModelWithEarlyStopping | SetFitModelForSpanClassification:
    """
    Initialize a SetFit model with optional span embeddings and class weights.
    """
    
    body_class = SentenceTransformerForSpanEmbedding if use_span_embedding else SentenceTransformer
    body_kwargs={"device_map": "auto", **body_kwargs}
    body = body_class(model_name, model_kwargs=body_kwargs, trust_remote_code=True)
    
    
    head_class = SetFitHead
    head_kwargs = {
        "in_features": body.get_sentence_embedding_dimension(),
        "out_features": num_classes,
        "device": body.device,
        "multitarget": multilabel,
        **head_kwargs
    }
    if class_weights is not None:
        head_class = SetFitHeadWithClassWeights
        head_kwargs["class_weights"] = class_weights
    head = head_class(**head_kwargs)
    

    model_class = SetFitModelForSpanClassification if use_span_embedding else SetFitModelWithEarlyStopping
    if multilabel and "multi_target_strategy" not in model_kwargs:
        model_kwargs["multi_target_strategy"] = "one-vs-rest"
    return model_class(
        model_body=body,
        model_head=head.to(body.device),
        normalize_embeddings=True,
        **model_kwargs
    )

In [4]:
args.data_splits_path = Path(args.data_splits_path)

if isinstance(args.label_cols, str):
    args.label_cols = [col.strip() for col in args.label_cols.split(',')]

if args.save_eval_results_to is not None:
    args.save_eval_results_to = Path(args.save_eval_results_to)
    if not (args.do_eval or args.do_test):
        raise ValueError("'save_eval_results_to' is specified but neither 'do_eval' nor 'do_test' is set.")
    elif not any([args.save_eval_results, args.save_eval_predictions, args.save_test_results, args.save_test_predictions]):
        warnings.warn("'save_eval_results_to' is specified but none of 'save_eval_results', 'save_eval_predictions', 'save_test_results', or 'save_test_predictions' is set.")
    elif args.save_eval_results_to.exists() and not args.overwrite_results:
        raise ValueError(f"The directory '{args.save_eval_results_to}' already exists. To avoid overwriting, please specify a different path or set 'overwrite_results'")
    else:
        args.save_eval_results_to.mkdir(parents=True, exist_ok=True)

if args.save_model:
    if args.save_model_to is None or args.save_model_as is None:
        raise ValueError("Both 'save_model_to' and 'save_model_as' must be specified if 'save_model' is True.")
    args.save_model_to = Path(args.save_model_to)

## Prepare the datasets

### Load the splits

In [5]:
df = pd.concat({split: pd.read_pickle(args.data_splits_path / f"{split}.pkl") for split in ['train', 'val', 'test']})
df.reset_index(level=0, names='split', inplace=True)
df['split'] = pd.Categorical(df['split'], categories=['train', 'val', 'test'], ordered=True)

### prepare the label column

In [6]:
# consider that entries in args.label_cols may be glob patterns
import fnmatch
expanded_label_cols = []
for lab in args.label_cols:
    matched = fnmatch.filter(df.columns, lab)
    if matched:
        expanded_label_cols.extend(matched)
    else:
        expanded_label_cols.append(lab)
args.label_cols = expanded_label_cols

In [7]:
df['labels'] = df[args.label_cols].apply(list, axis=1)

### format inputs

In [8]:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

In [9]:
if args.use_span_embeddings:
    if "span" not in df.columns:
    # using span embedding strategy
        df['span'] = df.apply(lambda x: regex.search(regex.escape(x[args.mention_col]), x[args.text_col]).span(), axis=1)
    max_length_ = max(tokenizer(df[args.text_col].to_list(), truncation=False, padding=False, return_length=True).length)
    cols = [args.text_col, 'span', 'labels']
    cols_mapping = {args.text_col: 'text', 'span': 'span', 'labels': 'label'}
elif args.concat_strategy is None:
    # default: just the mention text
    max_length_ = max(tokenizer(df[args.mention_col].to_list(), truncation=False, padding=False, return_length=True).length)
    cols = [args.mention_col, 'labels']
    cols_mapping = {args.mention_col: 'text', 'labels': 'label'}
else:
    # using concat strategy
    sep_tok = tokenizer.sep_token if args.concat_sep_token is None else args.concat_sep_token
    if args.concat_strategy == 'prefix':
        df['input'] = df[args.mention_col] + sep_tok + df[args.text_col]
    elif args.concat_strategy == 'suffix':
        df['input'] = df[args.text_col] + sep_tok + df[args.mention_col]
    else:
        raise ValueError(f"Unknown concat strategy: {args.concat_strategy}")
    max_length_ = max(tokenizer(df['input'].to_list(), truncation=False, padding=False, return_length=True).length)
    cols = ['input', 'labels']
    cols_mapping = {"input": "text", "labels": "label"}

### split the data

In [10]:
datasets = DatasetDict({
    s: Dataset.from_pandas(d, preserve_index=False)
    for s, d in df.groupby('split', observed=True)
})

In [11]:
datasets = datasets.remove_columns(set(df.columns)-set(cols))

In [12]:
datasets = datasets.rename_columns(column_mapping=cols_mapping)

In [13]:
datasets.num_rows

{'train': 388, 'val': 81, 'test': 131}

### Prepare fine-tuning

In [14]:
id2label = {i: l for i, l in enumerate(args.label_cols)}
label2id = {l: i for i, l in id2label.items()}

In [15]:
if args.class_weighting_strategy in ['inverse_proportional']:
    class_weighting_args = {
        "multitarget": len(args.label_cols) > 1,
        "smooth_weights": args.class_weighting_strategy is not None and args.class_weighting_strategy != 'balanced',
        "smooth_exponent": args.class_weighting_smooth_exponent if args.class_weighting_smooth_exponent is not None else 0.5
    }
    class_weights = compute_class_weights(datasets['train']['label'], **class_weighting_args)
    print(f"Class weights: {dict(zip(label2id.keys(), class_weights))}")
else:
    class_weights = None

Class weights: {'noneconomic__age': 2.9495762407505253, 'noneconomic__crime': 4.533823502911814, 'noneconomic__ethnicity': 3.65655170486763, 'noneconomic__family': 3.4544657088084305, 'noneconomic__gender_sexuality': 4.671566055592568, 'noneconomic__health': 4.2895221179054435, 'noneconomic__nationality': 2.3614129639774317, 'noneconomic__place_location': 6.489307444643928, 'noneconomic__religion': 4.671566055592568, 'noneconomic__shared_values_mentalities': 2.385299807658105}


In [None]:
from sentence_transformers.losses import ContrastiveLoss

if args.save_model:
    if args.save_model_to is None or args.save_model_as is None:
        raise ValueError("Both 'save_model_to' and 'save_model_as' must be specified if 'save_model' is True.")
    model_dir = args.save_model_to / args.save_model_as 
else:
    from tempfile import TemporaryDirectory
    with TemporaryDirectory() as tmpdirname:
        model_dir = tmpdirname

training_args = EarlyStoppingTrainingArguments(
    output_dir=model_dir,
    loss=ContrastiveLoss,
    
    num_epochs=(1, 25),
    batch_size=tuple(args.train_batch_sizes),

    head_learning_rate = args.head_learning_rate,
    # l2_weight=0.03,# TODO !!! /default 0.01
    # warmup_proportion=0.15, # TODO !!! /default 0.1
    
    # sentence transformer (embedding) finetuning args
    logging_first_step=False,
    eval_strategy="steps",
    eval_steps=50,
    max_steps=750,
    eval_max_steps=250,
    
    # early stopping config
    metric_for_best_model=("embedding_loss", "f1"),
    greater_is_better=(False, True),
    load_best_model_at_end=True,
    save_total_limit=2, # NOTE: currently no effect on (early stopping in) classification head training
    
    # misc
    end_to_end=True,
)

training_callbacks = [
    # for sentence transformer finetuning
    EarlyStoppingCallback(
        early_stopping_patience=args.body_early_stopping_patience,
        early_stopping_threshold=args.body_early_stopping_threshold,
    ), 
    # for classifier finetuning
    EarlyStoppingCallback(
        early_stopping_patience=args.head_early_stopping_patience,
        early_stopping_threshold=args.head_early_stopping_threshold,
    ), 
]

In [17]:
trainer_class = SetFitTrainerForSpanClassification if args.use_span_embeddings else SetFitEarlyStoppingTrainer

trainer = trainer_class(
    model_init=lambda: model_init(
        model_name=args.model_name,
        num_classes=len(id2label),
        multilabel=True,
        class_weights=class_weights,
        use_span_embedding=args.use_span_embeddings,
    ),
    metric="f1",
    metric_kwargs={
        "average": "macro" if args.label_cols and len(args.label_cols) > 1 else "binary",
        # "zero_division": 0.0
    },
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
    callbacks=training_callbacks,
    # column_mapping=cols_mapping,
)
# fix max_length issue
trainer._args.max_length = min(trainer.st_trainer.model.tokenizer.model_max_length, int(max_length_*1.1))

# set seeds for reproducibility
trainer._args.seed = SEED
trainer.st_trainer.args.seed = SEED
trainer.st_trainer.args.data_seed = SEED
trainer.st_trainer.args.full_determinism = True

# don't report to wandb or other experiment trackers
trainer._args.report_to = 'none'
trainer.st_trainer.args.report_to = 'none'

Map:   0%|          | 0/388 [00:00<?, ? examples/s]

### Fine-tune

In [18]:
trainer.train()

# clean up
if os.path.exists(model_dir):
    shutil.rmtree(model_dir)

***** Running training *****
  Num unique pairs = 24000
  Batch size = 32
  Num epochs = 1


Step,Training Loss,Validation Loss
50,0.0256,0.022608
100,0.0137,0.017154
150,0.0062,0.015488


Epoch:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1: calculating validation f1 for early stopping


Epoch:   4%|▍         | 1/25 [00:01<00:45,  1.89s/it]

{'training loss': 0.7563324356079102, 'validation loss': 0.6972164809703827, 'f1': 0.0838095238095238}




Epoch 2: calculating validation f1 for early stopping


Epoch:   8%|▊         | 2/25 [00:03<00:42,  1.84s/it]

{'training loss': 0.6237829291820526, 'validation loss': 0.5429519265890121, 'f1': 0.05}




Epoch 3: calculating validation f1 for early stopping


Epoch:  12%|█▏        | 3/25 [00:05<00:35,  1.63s/it]

{'training loss': 0.5250413858890534, 'validation loss': 0.5189195772012075, 'f1': 0.34894605394605394}




Epoch 4: calculating validation f1 for early stopping


Epoch:  16%|█▌        | 4/25 [00:06<00:31,  1.49s/it]

{'training loss': 0.4396515822410583, 'validation loss': 0.40876491367816925, 'f1': 0.34325396825396826}




Epoch 5: calculating validation f1 for early stopping


Epoch:  20%|██        | 5/25 [00:07<00:29,  1.45s/it]

{'training loss': 0.37032562017440795, 'validation loss': 0.3612157727281253, 'f1': 0.5383155080213904}




Epoch 6: calculating validation f1 for early stopping


Epoch:  24%|██▍       | 6/25 [00:09<00:26,  1.40s/it]

{'training loss': 0.33077362179756165, 'validation loss': 0.3514048134287198, 'f1': 0.5108080808080808}




Epoch 7: calculating validation f1 for early stopping


Epoch:  28%|██▊       | 7/25 [00:10<00:25,  1.39s/it]

{'training loss': 0.3076485604047775, 'validation loss': 0.5460466047128042, 'f1': 0.5566666666666666}




Epoch 8: calculating validation f1 for early stopping


Epoch:  32%|███▏      | 8/25 [00:11<00:23,  1.39s/it]

{'training loss': 0.28406064689159394, 'validation loss': 0.550013134876887, 'f1': 0.5746256684491978}




Epoch 9: calculating validation f1 for early stopping


Epoch:  36%|███▌      | 9/25 [00:13<00:21,  1.36s/it]

{'training loss': 0.2653011679649353, 'validation loss': 0.3497401873270671, 'f1': 0.5753612854891627}




Epoch 10: calculating validation f1 for early stopping


Epoch:  40%|████      | 10/25 [00:14<00:20,  1.34s/it]

{'training loss': 0.2524099487066269, 'validation loss': 0.36011242618163425, 'f1': 0.5859673460952233}




Epoch 11: calculating validation f1 for early stopping


Epoch:  44%|████▍     | 11/25 [00:15<00:18,  1.32s/it]

{'training loss': 0.2414713567495346, 'validation loss': 0.5789640173316002, 'f1': 0.5844521945800717}




Epoch 12: calculating validation f1 for early stopping


Epoch:  48%|████▊     | 12/25 [00:17<00:17,  1.34s/it]

{'training loss': 0.23357467532157897, 'validation loss': 0.3124457647403081, 'f1': 0.5950582551861323}




Epoch 13: calculating validation f1 for early stopping


Epoch:  52%|█████▏    | 13/25 [00:18<00:15,  1.32s/it]

{'training loss': 0.22657135963439942, 'validation loss': 0.36317818860212964, 'f1': 0.5927855279134051}




Epoch 14: calculating validation f1 for early stopping


Epoch:  56%|█████▌    | 14/25 [00:19<00:14,  1.31s/it]

{'training loss': 0.21680176317691802, 'validation loss': 0.3572480579217275, 'f1': 0.6033915885194657}




Epoch 15: calculating validation f1 for early stopping


Epoch:  60%|██████    | 15/25 [00:20<00:13,  1.31s/it]

{'training loss': 0.2121681135892868, 'validation loss': 0.28647921855250996, 'f1': 0.6033915885194657}




Epoch 16: calculating validation f1 for early stopping


Epoch:  64%|██████▍   | 16/25 [00:22<00:11,  1.30s/it]

{'training loss': 0.20419933915138244, 'validation loss': 0.391651709874471, 'f1': 0.6033915885194657}




Epoch 17: calculating validation f1 for early stopping


Epoch:  64%|██████▍   | 16/25 [00:23<00:13,  1.47s/it]

{'training loss': 0.20189030587673187, 'validation loss': 0.2833680734038353, 'f1': 0.6033915885194657}
Early stopping triggered after 17 epochs
Loading best model from epoch 12





## Evaluate

In [19]:
def get_predictions_df(split: str = 'val') -> pd.DataFrame:
    preds_df = df.loc[df['split'] == split, [args.id_col, args.text_col, args.mention_col, args.span_col, *args.label_cols]].copy()
    
    inputs = trainer.model._normalize_inputs(texts=datasets[split]['text'], spans=datasets[split]['span']) if args.use_span_embeddings else datasets[split]['text']
    
    probs = trainer.model.predict_proba(inputs, as_numpy=True)
    prob_cols = [f"prob_{col}" for col in args.label_cols]
    preds_df[prob_cols] = probs
    
    preds = np.where(probs > 0.5, 1, 0)
    pred_cols = [f"pred_{col}" for col in args.label_cols]
    preds_df[pred_cols] = preds
    
    for lab in args.label_cols:
        preds_df[f"error_{lab}"] = preds_df[f"pred_{lab}"] != preds_df[lab]
    
    return preds_df

### Validation set

In [20]:
inputs = trainer.model._normalize_inputs(texts=datasets['val']['text'], spans=datasets['val']['span']) if args.use_span_embeddings else datasets['val']['text']
preds = trainer.model.predict(inputs, as_numpy=True)

In [21]:
if args.do_eval:
    print(classification_report(y_pred=preds, y_true=datasets['val']['label'], target_names=args.label_cols, zero_division=0))

                                        precision    recall  f1-score   support

                      noneconomic__age       1.00      0.71      0.83         7
                    noneconomic__crime       0.00      0.00      0.00         3
                noneconomic__ethnicity       0.50      1.00      0.67         3
                   noneconomic__family       1.00      1.00      1.00         5
         noneconomic__gender_sexuality       1.00      0.70      0.82        10
                   noneconomic__health       1.00      0.67      0.80         3
              noneconomic__nationality       0.40      0.50      0.44         4
           noneconomic__place_location       0.00      0.00      0.00         1
                 noneconomic__religion       0.50      0.75      0.60         4
noneconomic__shared_values_mentalities       0.90      0.69      0.78        13

                             micro avg       0.78      0.68      0.73        53
                             macro avg

In [22]:
if args.save_eval_results:
    res = classification_report(y_pred=preds, y_true=datasets['val']['label'], target_names=args.label_cols, zero_division=0, output_dict=True)
    fp = args.save_eval_results_to / 'eval_results.json'
    with open(fp, 'w') as f:
        json.dump(res, f)

In [24]:
if args.save_eval_predictions:
    preds_df = get_predictions_df(split='val')
    fp = args.save_eval_results_to / 'eval_predictions.pkl'
    preds_df.to_pickle(fp)

### Test set

In [25]:
inputs = trainer.model._normalize_inputs(texts=datasets['test']['text'], spans=datasets['test']['span']) if args.use_span_embeddings else datasets['test']['text']
preds = trainer.model.predict(inputs, as_numpy=True)

In [26]:
if args.do_test:
    print(classification_report(y_pred=preds, y_true=datasets['test']['label'], target_names=args.label_cols, zero_division=0))
print(classification_report(y_pred=preds, y_true=datasets['test']['label'], target_names=args.label_cols, zero_division=0))

                                        precision    recall  f1-score   support

                      noneconomic__age       0.92      0.61      0.73        18
                    noneconomic__crime       1.00      0.33      0.50         9
                noneconomic__ethnicity       1.00      0.75      0.86         4
                   noneconomic__family       1.00      0.27      0.42        15
         noneconomic__gender_sexuality       1.00      0.95      0.98        21
                   noneconomic__health       1.00      0.38      0.55         8
              noneconomic__nationality       0.80      0.73      0.76        11
           noneconomic__place_location       0.00      0.00      0.00         7
                 noneconomic__religion       0.67      0.50      0.57         4
noneconomic__shared_values_mentalities       0.89      0.50      0.64        16

                             micro avg       0.93      0.55      0.69       113
                             macro avg

In [27]:
if args.save_test_results:
    res = classification_report(y_pred=preds, y_true=datasets['test']['label'], target_names=args.label_cols, zero_division=0, output_dict=True)
    fp = args.save_eval_results_to / 'test_results.json'
    with open(fp, 'w') as f:
        json.dump(res, f)

In [28]:
if args.save_test_predictions:
    preds_df = get_predictions_df(split='test')
    fp = args.save_eval_results_to / 'test_predictions.pkl'
    preds_df.to_pickle(fp)

In [29]:
# highlight = lambda text, mention: text.replace(mention, '\u001B[30m\u001B[43m'+mention+'\033[0m')

# for labs, subdf in df_test.groupby(args.label_cols):
#     print("\033[1mtrue\033[0m:", [id2label[i] for i, l in enumerate(labs) if l==1])
#     subdf = subdf[subdf[error_cols].any(axis=1)]
#     for preds, subsubdf in subdf.groupby(pred_cols):
#         print(" ↳ \033[1m\033[3mpred\033[0m:", [id2label[i] for i, l in enumerate(preds) if l==1])
#         for i, row in subsubdf.sample(n=min(4, len(subsubdf)), random_state=42).iterrows():
#             print(f"    - {str(i).rjust(3)}: {highlight(row['text'], row['mention'])}")
#     print()

## Save the model

In [None]:
if args.save_model:
    trainer.model.save_pretrained(model_dir)

: 