In [None]:
import json
import itertools
from tokenizers import Encoding
from typing import List
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback
import torch
from collections import defaultdict
from datasets import Dataset
import pandas as pd
import numpy as np
import evaluate
from sklearn.metrics import f1_score
from collections import Counter
from seqeval.metrics import classification_report
import re
import unicodedata

In [None]:
def align_tokens_and_annotations_bio(tokenized: "Encoding", annotations):
    """Project character-span annotations onto tokens using BIO labels.

    Args:
        tokenized: Object exposing ``tokens`` (a sequence of tokens) and
            ``char_to_token(char_ix)``, which returns the index of the token
            covering a character or ``None`` when no token covers it.
        annotations: Iterable of dicts with integer ``"start"``/``"end"``
            character offsets (end exclusive) and a ``"tag"`` string.

    Returns:
        A list with one label per token: ``"O"`` by default, ``"B-<tag>"`` for
        the first token touched by an annotation and ``"I-<tag>"`` for every
        further token of that annotation. Later annotations overwrite earlier
        ones on shared tokens.
    """
    tokens = tokenized.tokens
    aligned_labels = ["O"] * len(tokens)  # one label per token
    for anno in annotations:
        # Collect every token index touched by at least one character of the span.
        annotation_token_ix_set = set()
        for char_ix in range(anno["start"], anno["end"]):
            token_ix = tokenized.char_to_token(char_ix)
            if token_ix is not None:
                annotation_token_ix_set.add(token_ix)
        # First token (in order) opens the entity, all the others continue it.
        # An annotation that maps to no token at all leaves the labels untouched.
        for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
            prefix = "B" if num == 0 else "I"
            aligned_labels[token_ix] = f"{prefix}-{anno['tag']}"
    return aligned_labels

class LabelSet:
    """Bidirectional mapping between BIO label strings and integer ids.

    ``"O"`` always gets id 0. Every entry of ``labels`` then contributes two
    consecutive ids, first ``"B-<label>"`` and then ``"I-<label>"``, in the
    order the labels were given.
    """

    def __init__(self, labels: List[str]):
        self.labels_to_id = {"O": 0}
        self.ids_to_label = {0: "O"}
        # Ids start at 1 because 0 is reserved for "O".
        bio_names = (f"{prefix}-{label}" for label in labels for prefix in "BI")
        for label_id, bio_name in enumerate(bio_names, start=1):
            self.labels_to_id[bio_name] = label_id
            self.ids_to_label[label_id] = bio_name

    def get_aligned_label_ids_from_annotations(self, tokenized_text, annotations):
        """Return the per-token label ids for ``annotations`` on ``tokenized_text``.

        Labels unknown to this label set map to ``None``.
        """
        raw_labels = align_tokens_and_annotations_bio(tokenized_text, annotations)
        return [self.labels_to_id.get(raw_label) for raw_label in raw_labels]


def tokenize_token_classification(examples, tokenizer):
    """Tokenize pre-split words and align word-level tags to sub-word tokens.

    Args:
        examples: Batch mapping with ``"tokens"`` (one word list per example)
            and ``"ner_tags"`` (one tag-id list per example, indexed by word).
        tokenizer: Callable returning an encoding that supports
            ``token_to_word(batch_index, token_index)`` and item access.

    Returns:
        The tokenizer output with an added ``"labels"`` tensor. Only the first
        token of each word carries the word's tag; tokens that map to no word
        and continuation tokens are set to -100.
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding='longest',
        return_tensors='pt',
    )

    def _align(batch_ix, word_tags):
        # Walk the tokens of one example and emit one label id per token.
        aligned = []
        previous_word_idx = None
        for token_ix in range(len(tokenized_inputs['input_ids'][batch_ix])):
            word_idx = tokenized_inputs.token_to_word(batch_ix, token_ix)
            if word_idx is None or word_idx == previous_word_idx:
                aligned.append(-100)
            else:
                aligned.append(word_tags[word_idx])
            previous_word_idx = word_idx
        return aligned

    labels = [_align(batch_ix, word_tags) for batch_ix, word_tags in enumerate(examples["ner_tags"])]

    tokenized_inputs["labels"] = torch.tensor(labels)
    return tokenized_inputs

def dict_of_lists(lst_of_dicts):
    """Transpose a list of dicts into a dict of lists.

    Each key maps to the values it had across the input dicts, in input order.
    Keys are ordered by first appearance; a dict lacking a key simply
    contributes nothing to that key's list.
    """
    columns = {}
    for record in lst_of_dicts:
        for key, value in record.items():
            columns.setdefault(key, []).append(value)
    return columns

def list_of_dicts(dict_of_lists):
    """Transpose a dict of equally long lists into a list of dicts.

    Args:
        dict_of_lists: Mapping from key to a list of values.

    Returns:
        One dict per index ``i``, mapping every key to ``dict_of_lists[key][i]``.
        An empty mapping yields an empty list.

    Raises:
        ValueError: If the lists do not all have the same length.
    """
    columns = list(dict_of_lists.values())
    if not columns:
        # Nothing to transpose; previously this case leaked a StopIteration.
        return []

    length = len(columns[0])
    if any(len(column) != length for column in columns):
        raise ValueError("All lists in the dictionary must have the same length")

    return [
        {key: values[i] for key, values in dict_of_lists.items()}
        for i in range(length)
    ]

def _shift_span(span, match_index, len_diff):
    """Adjust one ``{'start', 'end'}`` span in place for a substitution at ``match_index``."""
    if span['start'] <= match_index < span['end']:
        # The substitution happened inside the span: only its end moves.
        span['end'] += len_diff
    elif span['start'] > match_index:
        # The substitution happened before the span: the whole span moves.
        span['start'] += len_diff
        span['end'] += len_diff


def sub_shift_spans(text, ents=None, mappings=None):
    """Apply regex substitutions to ``text`` while keeping span offsets in sync.

    Args:
        text: The string to rewrite.
        ents: Either a list of dicts with ``'start'``/``'end'`` offsets, or a
            single dict whose ``'value'`` entry holds ``'start'``/``'end'``.
            The offsets are updated IN PLACE. Defaults to an empty list.
        mappings: Iterable of dicts with a regex ``'pattern'``, a ``'target'``
            replacement string and a ``'check'`` predicate on single
            characters. When every character of a match passes ``'check'``,
            the substring ``'placeholder'`` inside ``'target'`` is replaced by
            the matched text; otherwise ``'target'`` is inserted verbatim.
            Mappings are applied one after another, each on the result of the
            previous one. Defaults to no mappings.

    Returns:
        ``(text, ents)``: the rewritten text and the (mutated) ``ents`` object.
    """
    # Fresh objects per call: a shared default list would be handed back to
    # the caller and leak state between unrelated calls.
    if ents is None:
        ents = []
    if mappings is None:
        mappings = []

    for mapping in mappings:
        # Matches are located on the text as it was when this mapping started;
        # `adjustment` translates those positions into the text being rewritten.
        adjustment = 0
        pattern = re.compile(mapping['pattern'])
        check = mapping['check']
        target = mapping['target']
        for match in pattern.finditer(text):
            match_index = match.start() + adjustment
            match_contents = match.group()
            if all(check(char) for char in match_contents):
                subbed_text = target.replace('placeholder', match_contents)
            else:
                subbed_text = target
            len_diff = len(subbed_text) - len(match_contents)
            text = text[:match_index] + subbed_text + text[match_index + len(match_contents):]
            if ents:
                if isinstance(ents, list):
                    for ent in ents:
                        _shift_span(ent, match_index, len_diff)
                elif isinstance(ents, dict):
                    _shift_span(ents['value'], match_index, len_diff)

            adjustment += len_diff

    return text, ents

def span_to_words_annotation(samples, target_tag='', mappings=None, labels_model=None):
    """Convert character-span annotations into word-level BIO tags for one tag.

    Args:
        samples: Mapping where ``samples['data'][i]['text']`` is the text of
            sample ``i`` and ``samples['annotations'][i][0]['result']`` is its
            list of annotations. Each annotation is a dict whose ``'value'``
            holds ``'start'``, ``'end'`` and ``'labels'`` (first label = tag).
        target_tag: Only annotations whose first label equals this tag are
            kept; every other annotation is ignored.
        mappings: Substitution rules forwarded to ``sub_shift_spans``; the text
            is normalised with them before being split on whitespace.
        labels_model: Object exposing ``labels_to_id`` with entries for
            ``'O'``, ``'B-<target_tag>'`` and ``'I-<target_tag>'``. Required
            as soon as there is at least one sample.

    Returns:
        One dict per sample with ``'id'`` (sample index), ``'tokens'``
        (whitespace-separated words of the normalised text), ``'ner_tags'``
        (one label id per token) and ``'tag'`` (``target_tag`` when the sample
        has at least one matching annotation, ``'no_tag'`` otherwise).
    """
    if mappings is None:
        mappings = []
    word_pattern = re.compile(r'[^\s]+')

    samples_new = []
    for i, datum in enumerate(samples['data']):
        text = datum['text']
        annotation_list = samples['annotations'][i][0]['result']

        label_ids = labels_model.labels_to_id
        outside_id = label_ids['O']
        begin_id = label_ids['B-' + target_tag]
        inside_id = label_ids['I-' + target_tag]

        # Tokenise once, independently of the annotations. Collecting tokens
        # while walking the first annotation loses them whenever that first
        # annotation carries a different tag and is skipped.
        normalised_text, _ = sub_shift_spans(text, [], mappings=mappings)
        tokens = word_pattern.findall(normalised_text)

        per_annotation_labels = []
        for annotation in annotation_list or []:
            if not isinstance(annotation, dict):
                continue
            if annotation['value']['labels'][0] != target_tag:
                continue
            # sub_shift_spans shifts offsets in place; hand it a copy so the
            # caller's annotations are not shifted again on a later call.
            span = {'value': dict(annotation['value'])}
            shifted_text, span = sub_shift_spans(text, span, mappings=mappings)
            start, end = span['value']['start'], span['value']['end']

            labels_words = []
            first = True
            for word in word_pattern.finditer(shifted_text):
                # A word belongs to the entity only if it lies fully inside the span.
                if start <= word.start() and word.end() <= end:
                    labels_words.append(begin_id if first else inside_id)
                    first = False
                else:
                    labels_words.append(outside_id)
            if labels_words:
                per_annotation_labels.append(labels_words)

        if per_annotation_labels:
            # Merge all target-tag annotations: the highest label id wins per word.
            labels = [max(values) for values in zip(*per_annotation_labels)]
            tag = target_tag
        else:
            # No annotation with the target tag: every word is outside.
            labels = [outside_id] * len(tokens)
            tag = 'no_tag'

        samples_new.append({
            'id': i,
            'ner_tags': labels,
            'tokens': tokens,
            'tag': tag,
        })
    return samples_new

# Substitution rules applied in order by sub_shift_spans. For each regex match,
# 'placeholder' in 'target' is replaced by the matched text only when every
# matched character passes 'check'; otherwise 'target' is inserted verbatim.
regex_tokenizer_mappings = [
    # Surround a non-word, non-space character with spaces when it touches
    # another non-space character on at least one side.
    # NOTE(review): a matched character that is not Unicode punctuation
    # (category not starting with 'P', e.g. a currency or math symbol) fails
    # 'check', so the literal text ' placeholder ' is inserted - confirm that
    # this is intended.
    {
        'pattern': r'(?<!\s)([^\w\s])|([^\w\s])(?!\s)',
        'target': ' placeholder ',
        'check': lambda x: unicodedata.category(x).startswith('P'),
    },
    # Collapse every run of whitespace into a single space.
    {
        'pattern': r'\s+',
        'target': ' ',
        # Raw string: '\s' in a normal string literal is an invalid escape sequence.
        'check': lambda x: re.match(r'\s+', x) is not None,
    },
]

def compute_metrics_wrapper(label_list):
    """Build a ``compute_metrics`` callback bound to ``label_list``.

    Args:
        label_list: Sequence mapping a label id (its index) to the label string.

    Returns:
        A function taking ``eval_preds`` (a ``(logits, labels)`` pair) and
        returning a flat dict of floats/ints:
        ``'micro_f1'`` (token-level, via sklearn), ``'<name>_<metric>'`` for
        every entry of the seqeval classification report, and
        ``'<label>_support'`` with the token count of each gold label.
    """
    def compute_metrics(eval_preds):
        logits, labels = eval_preds
        predictions = np.argmax(logits, axis=2)

        # Keep only positions with a real gold label; -100 marks positions
        # that must be ignored.
        true_predictions = [
            [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [label_list[l] for l in label if l != -100]
            for label in labels
        ]

        # Sequence-level report. seqeval expects (y_true, y_pred); passing them
        # the other way round swaps precision with recall and reports the
        # support of the predictions instead of the gold labels.
        results = classification_report(true_labels, true_predictions, output_dict=True)

        # Token-level micro F1 over the flattened sequences.
        flat_true_predictions = [item for sublist in true_predictions for item in sublist]
        flat_true_labels = [item for sublist in true_labels for item in sublist]
        micro_f1 = f1_score(flat_true_labels, flat_true_predictions, average='micro')

        flat_results = {'micro_f1': float(micro_f1)}

        # Flatten the nested report into '<name>_<metric>' keys.
        for label, metrics in results.items():
            if isinstance(metrics, dict):
                for metric, value in metrics.items():
                    flat_results[f'{label}_{metric}'] = float(value)

        # Token-level support of every gold label.
        for label, count in Counter(flat_true_labels).items():
            flat_results[f'{label}_support'] = count

        return flat_results
    return compute_metrics

In [3]:
# Load the sentence-level training data.
train_data_path = '/home/luca/Desktop/checkthat24_DIT/data/formatted/train_sentences.json'
with open(train_data_path, 'r', encoding='utf8') as f:
    dataset_raw = json.load(f)

df_raw = pd.DataFrame(dataset_raw)

# Keep every sentence that has at least one annotation and draw an equally
# large random sample of sentences without any annotation.
is_annotated = df_raw['annotations'].apply(lambda x: len(x[0]['result']) > 0)
is_unannotated = df_raw['annotations'].apply(lambda x: x[0]['result'] == [])
df_pos = df_raw[is_annotated]
df_neg = df_raw[is_unannotated].sample(len(df_pos))
df = pd.concat([df_pos, df_neg])

# (index, name) pairs of the persuasion techniques, one model is trained per entry.
target_tags = list(enumerate(el.strip() for el in [
    "Appeal_to_Authority", "Appeal_to_Popularity", "Appeal_to_Values", "Appeal_to_Fear-Prejudice",
    "Flag_Waving", "Causal_Oversimplification", "False_Dilemma-No_Choice",
    "Consequential_Oversimplification", "Straw_Man", "Red_Herring", "Whataboutism", "Slogans",
    "Appeal_to_Time", "Conversation_Killer", "Loaded_Language", "Repetition",
    "Exaggeration-Minimisation", "Obfuscation-Vagueness-Confusion", "Name_Calling-Labeling",
    "Doubt", "Guilt_by_Association", "Appeal_to_Hypocrisy", "Questioning_the_Reputation",
]))


In [None]:
# Train one binary (O / B-<tag> / I-<tag>) token classifier per persuasion technique.
shift = 0  # index of the first technique to train; earlier ones are skipped

# Setup that does not depend on the technique is done once, outside the loop.
#model_name = 'bert-base-multilingual-cased'
#model_name = 'xlm-roberta-base'
model_name = 'microsoft/mdeberta-v3-base'
model_name_simple = model_name.split('/')[-1]
tokenizer = AutoTokenizer.from_pretrained(model_name)
batch_size = 16  # NOTE(review): not passed to TrainingArguments, so it has no effect - confirm intent
split_ratio = 0.2
split_seed = 42
columns = [
    'input_ids',
    'token_type_ids',
    'attention_mask',
    'labels',
]
df_list = df.to_dict(orient='records')

for i, tt in enumerate(target_tags):
    if i < shift:
        continue
    print(f'Training model no. {i} of {len(target_tags)} for {tt[1]} persuasion technique...')
    labels_model = LabelSet(labels=[tt[1]])

    # Word-level BIO tags for this technique only.
    df_list_binary = span_to_words_annotation(dict_of_lists(df_list), target_tag=tt[1], mappings=regex_tokenizer_mappings, labels_model=labels_model)
    df_binary = pd.DataFrame(df_list_binary)

    # Balance the data: all positive sentences plus as many random negatives.
    df_binary_pos = df_binary[df_binary['tag'] == tt[1]]
    df_binary_neg = df_binary[df_binary['tag'] != tt[1]].sample(len(df_binary_pos))
    df_binary_subsampled = pd.concat([df_binary_pos, df_binary_neg])#.sample(1000)

    binary_dataset = Dataset.from_pandas(df_binary_subsampled[['id', 'ner_tags', 'tokens']])

    datadict = binary_dataset.train_test_split(split_ratio, seed=split_seed)
    datadict = datadict.map(lambda x: tokenize_token_classification(x, tokenizer), batched=True, batch_size=None)
    datadict.set_format('torch', columns=columns)

    train_data = datadict['train']
    val_data = datadict['test']

    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding='longest')

    model = AutoModelForTokenClassification.from_pretrained(
        model_name,
        num_labels=len(labels_model.ids_to_label.values()),
        label2id=labels_model.labels_to_id,
        id2label=labels_model.ids_to_label,
    )

    training_args = TrainingArguments(
        #output_dir='/home/lgiordano/LUCA/checkthat_GITHUB/models/M2/mdeberta-v3-base-NEW',
        # One directory per technique: with a single shared directory the
        # checkpoints of different models overwrite and rotate each other.
        output_dir=f'/home/luca/Desktop/{model_name_simple}_{tt[1]}',
        save_total_limit=2,
        save_strategy='epoch',
        load_best_model_at_end=True,
        save_only_model=True,
        # Key produced by compute_metrics ('macro avg_f1-score') with the
        # 'eval_' prefix added by the Trainer.
        metric_for_best_model='eval_macro avg_f1-score',
        logging_strategy='epoch',
        evaluation_strategy='epoch',
        learning_rate=5e-5,
        optim='adamw_torch',
        num_train_epochs=10,
    )

    early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_wrapper(label_list=list(labels_model.ids_to_label.values())),
        callbacks=[early_stopping],
    )

    trainer.train()

    trainer.evaluate()