In [4]:
from sklearn.metrics import f1_score
import random
from datasets import DatasetDict
from transformers import RobertaConfig, RobertaForTokenClassification, TrainingArguments, Trainer
from data_handling import load_data_file, create_vocab, create_gloss_vocab, prepare_dataset
from uspanteko_morphology import morphology
from tokenizer import WordLevelTokenizer
from taxonomic_loss_model import TaxonomicLossModel
from eval import eval_accuracy

# --- Configuration ---------------------------------------------------------
MODEL_INPUT_LENGTH = 64  # used as the tokenizer's model_max_length and the config's max_position_embeddings
BATCH_SIZE = 64
DEVICE = 'mps'  # single place to change the device used for datasets and models
SEED = 1

random.seed(SEED)  # seeds Python's `random` module only

# --- Data ------------------------------------------------------------------
train_data = load_data_file('../data/usp-train-track2-uncovered')
dev_data = load_data_file('../data/usp-dev-track2-uncovered')
test_data = load_data_file('../data/usp-test-track2-uncovered')

# The input vocabulary is built from the training split only.
train_vocab = create_vocab([line.morphemes() for line in train_data], threshold=1)
tokenizer = WordLevelTokenizer(vocab=train_vocab, model_max_length=MODEL_INPUT_LENGTH)

# Output label inventory, derived from the morphology definition.
glosses = create_gloss_vocab(morphology)

dataset = DatasetDict()

dataset['train'] = prepare_dataset(data=train_data, tokenizer=tokenizer, labels=glosses, device=DEVICE)
dataset['dev'] = prepare_dataset(data=dev_data, tokenizer=tokenizer, labels=glosses, device=DEVICE)
dataset['test'] = prepare_dataset(data=test_data, tokenizer=tokenizer, labels=glosses, device=DEVICE)

# --- Models ----------------------------------------------------------------
# NOTE(review): `config` is not referenced again in the visible cells; the models
# below are all restored from saved checkpoints via `from_pretrained`.
config = RobertaConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=MODEL_INPUT_LENGTH,
    pad_token_id=tokenizer.PAD_ID,
    position_embedding_type='absolute',
    num_labels=len(glosses)
)

flat_model = RobertaForTokenClassification.from_pretrained("../models/full-flat-1").to(DEVICE)

tax_model = TaxonomicLossModel.from_pretrained("../models/full-tax_loss-1 (alt)", loss_sum='linear').to(DEVICE)
tax_model.use_morphology_tree(morphology, max_depth=5)

harmonic_model = TaxonomicLossModel.from_pretrained(
    "../models/full-harmonic_loss-1 (alt)", loss_sum='harmonic').to(DEVICE)
harmonic_model.use_morphology_tree(morphology, max_depth=5)

# Read as a global by `compute_metrics` (inside `create_trainer`) to map label
# indices to categories.
hierarchy_matrix = tax_model.hierarchy_matrix

Map:   0%|          | 0/9774 [00:00<?, ? examples/s]

Map:   0%|          | 0/232 [00:00<?, ? examples/s]

Map:   0%|          | 0/633 [00:00<?, ? examples/s]

In [5]:
def create_trainer(model: RobertaForTokenClassification, dataset: DatasetDict, tokenizer: WordLevelTokenizer,
                   labels, batch_size, max_epochs, eval_split="test", category_level=2, verbose=True):
    """Build a `Trainer` whose evaluation reports gloss-level and category-level metrics.

    Args:
        model: Token-classification model to train and/or evaluate.
        dataset: Splits keyed by name; `dataset["train"]` and `dataset[eval_split]` are used.
        tokenizer: Not used by this function; kept so existing call sites keep working.
        labels: Sequence mapping a label index to its gloss string.
        batch_size: Per-device batch size for both training and evaluation.
        max_epochs: Number of training epochs.
        eval_split: Name of the split used for evaluation. Defaults to "test", which is
            what the evaluation cells in this notebook report on. Because
            `load_best_model_at_end=True`, pass "dev" when actually training so that
            checkpoint selection does not happen on the test split.
        category_level: Column of the module-level `hierarchy_matrix` used to map a
            label index to its coarser category.
        verbose: When true, print the raw prediction/label arrays and the first decoded
            example on every evaluation.

    Returns:
        A configured `Trainer`.

    Note:
        `compute_metrics` looks up the module-level name `hierarchy_matrix` each time it
        runs, so that name must be defined before `evaluate()` / `train()` is called.
    """
    num_labels = len(labels)
    # Placeholders for a predicted index outside the label inventory. Such a prediction
    # is kept (and scored as wrong) instead of being dropped, so that predictions and
    # gold labels stay aligned position by position.
    invalid_gloss = "[INVALID]"
    invalid_category = -1.0

    def sequence_accuracy(true_seqs, pred_seqs):
        """Fraction of positions that match, pooled over all sequences (0 if there are none)."""
        correct = 0
        total = 0
        for true_seq, pred_seq in zip(true_seqs, pred_seqs):
            correct += sum(t == p for t, p in zip(true_seq, pred_seq))
            total += len(true_seq)
        return correct / total if total > 0 else 0

    def compute_metrics(eval_preds):
        preds, gold_labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        if verbose:
            print("PREDS", preds)
            print("LABELS", gold_labels)
        # Labels with an extra axis: keep only index 0 along axis 1.
        if len(gold_labels.shape) > 2:
            gold_labels = gold_labels.take(axis=1, indices=0)

        if verbose:
            print(gold_labels.shape)

        # Keep exactly the positions whose gold label is a valid index (this drops
        # padding / ignore values such as -100), and take the prediction at the SAME
        # position. Filtering the two sequences independently would shift the pairs
        # whenever an ignored position is not at the very end of the sequence.
        pred_indices = []
        true_indices = []
        for pred_seq, label_seq in zip(preds, gold_labels):
            kept = [(p, g) for p, g in zip(pred_seq, label_seq) if num_labels > g >= 0]
            pred_indices.append([p for p, _ in kept])
            true_indices.append([g for _, g in kept])

        # Decode indices to gloss strings.
        decoded_preds = [[labels[p] if num_labels > p >= 0 else invalid_gloss for p in seq]
                         for seq in pred_indices]
        decoded_labels = [[labels[g] for g in seq] for seq in true_indices]

        if verbose and decoded_preds:
            print('Preds:\t', decoded_preds[0])
            print('Labels:\t', decoded_labels[0])

        accuracy = eval_accuracy(decoded_preds, decoded_labels)

        # Macro F1 over all gloss tokens, pooled across sequences.
        flat_true_labels = [label for seq in decoded_labels for label in seq]
        flat_predicted_labels = [label for seq in decoded_preds for label in seq]
        f1 = f1_score(flat_true_labels, flat_predicted_labels, average='macro')

        # Category-level metrics: map every label index through one column of
        # hierarchy_matrix. The column is fetched once rather than once per token.
        category_of = hierarchy_matrix[category_level]
        pred_categories = [[category_of[p] if num_labels > p >= 0 else invalid_category for p in seq]
                           for seq in pred_indices]
        true_categories = [[category_of[g] for g in seq] for seq in true_indices]
        flat_pred_categories = [category for seq in pred_categories for category in seq]
        flat_true_categories = [category for seq in true_categories for category in seq]

        if verbose and pred_categories:
            print("PRED CATEGORIES", pred_categories[0])
            print("TRUE CATEGORIES", true_categories[0])

        category_accuracy = sequence_accuracy(true_categories, pred_categories)
        category_f1 = f1_score(flat_true_categories, flat_pred_categories, average='macro')

        return {
            "accuracy": accuracy,
            "f1": f1,
            "category_accuracy": category_accuracy,
            "category_f1": category_f1
        }

    def preprocess_logits_for_metrics(logits, _labels):
        # Reduce logits to predicted indices before they are accumulated for metrics.
        # Some models return a tuple of outputs; the logits are its first element.
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=2)

    args = TrainingArguments(
        output_dir="../finetune-training-checkpoints",
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=3,
        save_strategy="epoch",
        save_total_limit=3,
        num_train_epochs=max_epochs,
        # The "best" checkpoint is chosen using the evaluation split (see `eval_split`).
        load_best_model_at_end=True,
        logging_strategy='epoch',
    )

    return Trainer(
        model,
        args,
        train_dataset=dataset["train"],
        eval_dataset=dataset[eval_split],
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )


# Every trainer is built with the same settings; only the model differs.
shared_trainer_kwargs = dict(
    dataset=dataset,
    tokenizer=tokenizer,
    labels=glosses,
    batch_size=BATCH_SIZE,
    max_epochs=100,
)

flat_trainer = create_trainer(flat_model, **shared_trainer_kwargs)
tax_trainer = create_trainer(tax_model, **shared_trainer_kwargs)
harmonic_trainer = create_trainer(harmonic_model, **shared_trainer_kwargs)

# Last expression of the cell, so the returned metrics dict is displayed.
flat_trainer.evaluate()

PREDS [[61  1 25 ... 39 39 39]
 [25 38  1 ... 39 39 39]
 [61  1 25 ... 43 43 43]
 ...
 [54  1 26 ... 43 43 43]
 [54  1 26 ... 39 39 39]
 [55  1 60 ... 38 38 38]]
LABELS [[  61    1   25 ... -100 -100 -100]
 [  25   38    1 ... -100 -100 -100]
 [  61    1   25 ... -100 -100 -100]
 ...
 [  54    1   26 ... -100 -100 -100]
 [  54    1   26 ... -100 -100 -100]
 [  60    1   60 ... -100 -100 -100]]
(633, 64)
Preds:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'S', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3', 'SREL', '[SEP]', 'E3', 'S']
Labels:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'VT', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3S', 'SREL', '[SEP]', 'E3S', 'S']
PRED CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 29.0, 1.0, 34.0, 1.0, 29.0, 1.0, 4.0, 44.0, 1.0, 4.0, 29.0]
TRUE CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 25.0, 1.0, 34.0, 1.0, 29.0, 1.0, 3.0, 44.0, 1.0, 3.0, 29.0]


{'eval_loss': 0.11494144797325134,
 'eval_accuracy': {'average_accuracy': 0.7371333708419027,
  'accuracy': 0.7410688768219491},
 'eval_f1': 0.491565613887541,
 'eval_category_accuracy': 0.8283058243217606,
 'eval_category_f1': 0.5964630122531189,
 'eval_runtime': 2.0843,
 'eval_samples_per_second': 303.697,
 'eval_steps_per_second': 4.798}

In [3]:
# Evaluate the taxonomic-loss model (loaded with loss_sum='linear') on the trainer's
# evaluation dataset; the returned metrics dict is displayed as the cell output.
tax_trainer.evaluate()


PREDS [[61  1 25 ... 39 39 39]
 [25 38  1 ... 39 39 39]
 [61  1 25 ... 39 39 39]
 ...
 [54  1 26 ... 39 39 39]
 [54  1 26 ... 39 39 39]
 [60  1 60 ... 43 43 43]]
LABELS [[61  1 25 ... 66 66 66]
 [25 38  1 ... 66 66 66]
 [61  1 25 ... 66 66 66]
 ...
 [54  1 26 ... 66 66 66]
 [54  1 26 ... 66 66 66]
 [60  1 60 ... 66 66 66]]
(633, 64)
Preds:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'S', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3S', 'SREL', '[SEP]', 'E3S', 'S']
Labels:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'VT', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3S', 'SREL', '[SEP]', 'E3S', 'S']
PRED CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 29.0, 1.0, 34.0, 1.0, 29.0, 1.0, 3.0, 44.0, 1.0, 3.0, 29.0]
TRUE CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 25.0, 1.0, 34.0, 1.0, 29.0, 1.0, 3.0, 44.0, 1.0, 3.0, 29.0]


{'eval_loss': 0.20068112015724182,
 'eval_accuracy': {'average_accuracy': 0.8610567136041077,
  'accuracy': 0.8673906830523006},
 'eval_f1': 0.6621791917978597,
 'eval_category_accuracy': 0.9182318345664959,
 'eval_category_f1': 0.7377983066869656,
 'eval_runtime': 2.8114,
 'eval_samples_per_second': 225.158,
 'eval_steps_per_second': 3.557}

In [6]:
# Evaluate the taxonomic-loss model loaded with loss_sum='harmonic' on the trainer's
# evaluation dataset; the returned metrics dict is displayed as the cell output.
harmonic_trainer.evaluate()

PREDS [[61  1 25 ... 43 43 43]
 [25 38  1 ... 43 43 43]
 [61  1 25 ... 43 43 43]
 ...
 [54  1 26 ... 43 43 43]
 [54  1 26 ... 43 43 43]
 [60  1 60 ... 43 43 43]]
LABELS [[61  1 25 ... 66 66 66]
 [25 38  1 ... 66 66 66]
 [61  1 25 ... 66 66 66]
 ...
 [54  1 26 ... 66 66 66]
 [54  1 26 ... 66 66 66]
 [60  1 60 ... 66 66 66]]
(633, 64)
Preds:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'S', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3S', 'SREL', '[SEP]', 'E3S', 'S']
Labels:	 ['CONJ', '[SEP]', 'COM', 'VI', '[SEP]', 'VT', '[SEP]', 'DIM', '[SEP]', 'S', '[SEP]', 'E3S', 'SREL', '[SEP]', 'E3S', 'S']
PRED CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 29.0, 1.0, 34.0, 1.0, 29.0, 1.0, 3.0, 44.0, 1.0, 3.0, 29.0]
TRUE CATEGORIES [47.0, 1.0, 15.0, 24.0, 1.0, 25.0, 1.0, 34.0, 1.0, 29.0, 1.0, 3.0, 44.0, 1.0, 3.0, 29.0]


{'eval_loss': 0.1058390811085701,
 'eval_accuracy': {'average_accuracy': 0.826985117743412,
  'accuracy': 0.852529294084024},
 'eval_f1': 0.6349280903489274,
 'eval_category_accuracy': 0.907987099222159,
 'eval_category_f1': 0.6897642733988156,
 'eval_runtime': 2.5751,
 'eval_samples_per_second': 245.817,
 'eval_steps_per_second': 3.883}

In [14]:
# Display the hierarchy matrix that compute_metrics indexes as hierarchy_matrix[2][label_index].
hierarchy_matrix

Unnamed: 0,0,1,2,3,4
0,0.0,0.0,0.0,0.0,0.0
1,1.0,1.0,1.0,1.0,1.0
2,2.0,2.0,2.0,2.0,2.0
3,2.0,2.0,2.0,2.0,3.0
4,2.0,2.0,2.0,3.0,4.0
...,...,...,...,...,...
61,7.0,19.0,46.0,53.0,61.0
62,8.0,20.0,47.0,54.0,62.0
63,8.0,21.0,48.0,55.0,63.0
64,8.0,22.0,49.0,56.0,64.0


66

In [23]:
# Spot-check the lookup used for category metrics: column 2, label index 61.
# NOTE(review): the saved output of this cell (46.0) differs from the category printed
# for label 61 in the saved evaluation outputs (47.0) - the cells were presumably run
# against different kernel state; re-run the notebook top to bottom to confirm.
hierarchy_matrix[2][61]

46.0

In [40]:
import numpy as np

# Number of trainable parameters in the flat model. `numel()` returns a tensor's
# element count directly, so no temporary list or numpy round-trip is needed.
params = sum(p.numel() for p in flat_model.parameters() if p.requires_grad)
params


87860802