In [None]:
import ast
import json
import re
import shutil
import unicodedata
from pathlib import Path

import conllu
import datasets
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import wandb
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Start a W&B run; the evaluation cell uses it (`run.use_artifact`) to download the trained model artifacts.
run = wandb.init()


In [None]:
# Gold UD test splits (CoNLL-U) that every model is evaluated against.
DATA_DIR = Path("../data")
porttinari_test_path = DATA_DIR / "UD_Portuguese-Porttinari" / "pt_porttinari-ud-test.conllu"
dante_test_path = DATA_DIR / "UD_Portuguese-DANTE" / "pt_dante-ud-test.conllu"
petrogold_test_path = DATA_DIR / "UD_Portuguese-PetroGold" / "pt_petrogold-ud-test.conllu"


In [None]:
# W&B model-artifact names (version v0), one list per corpus combination named in
# the variable. Each list holds ten runs; the trailing numbers (42..51) presumably
# are the random seeds of those runs — confirm against the W&B run configs.


def _model_artifacts(*artifact_ids):
    """Build the version-0 artifact name ("model-<id>:v0") for each id, preserving order."""
    return [f"model-{artifact_id}:v0" for artifact_id in artifact_ids]


porttinari_models = _model_artifacts(
    "2s49nst7",  # 42
    "tc1zzw4k",  # 43
    "bm6gp43m",  # 44
    "1o0gqxr3",  # 45
    "31361v0k",  # 46
    "3rjlz0lt",  # 47
    "2bw1espl",  # 48
    "v70g08fj",  # 49
    "9ntjk1ug",  # 50
    "kymqot4e",  # 51
)
dante_models = _model_artifacts(
    "o0ojeh3x",  # 42
    "hfcre0cn",  # 43
    "sm3hnlkk",  # 44
    "tv4yte30",  # 45
    "66omv4vd",  # 46
    "2l1q1co8",  # 47
    "19fciu72",  # 48
    "2akpbr05",  # 49
    "1b7kowrx",  # 50
    "2vl5c8as",  # 51
)
petrogold_models = _model_artifacts(
    "1q90myqb",  # 42
    "2v9qux4u",  # 43
    "97qzwo1n",  # 44
    "1yl1sxe4",  # 45
    "2w6nmfoo",  # 46
    "19jxdvv7",  # 47
    "2k0a4cta",  # 48
    "2gzdimee",  # 49
    "23r6lkqz",  # 50
    "1hbxjamt",  # 51
)
porttinari_dante_models = _model_artifacts(
    "16140tgi",  # 42
    "1vj1vfst",  # 43
    "1nw1r5gm",  # 44
    "3ipj5wvo",  # 45
    "2u66l64n",  # 46
    "2t876ez4",  # 47
    "3ftgmuns",  # 48
    "26vtwt5o",  # 49
    "13sziv1x",  # 50
    "3s1i1gig",  # 51
)
porttinari_petrogold_models = _model_artifacts(
    "3t2o4ugz",  # 42
    "1jmh5yzi",  # 43
    "25j07376",  # 44
    "3ub7r5rd",  # 45
    "vzsfwq56",  # 46
    "29r6wb4o",  # 47
    "39kwlzfy",  # 48
    "37a43ty6",  # 49
    "cbw7mob2",  # 50
    "1m7fe4k1",  # 51
)
dante_petrogold_models = _model_artifacts(
    "3lesa1b2",  # 42
    "vlmx70y6",  # 43
    "2kyu5jev",  # 44
    "216u2b9d",  # 45
    "bnm5t8fw",  # 46
    "25rt674h",  # 47
    "1vwxg7v2",  # 48
    "3h6mbbp2",  # 49
    "suxb490l",  # 50
    "3evjf8e7",  # 51
)
porttinari_dante_petrogold_models = _model_artifacts(
    "2pf394eb",  # 42
    "2c5f8s18",  # 43
    "2oxpwyg9",  # 44
    "a4v60z6o",  # 45
    "35wuwdty",  # 46
    "3nr9mytc",  # 47
    "3dnmzhw0",  # 48
    "azc3woza",  # 49
    "n7rgstlo",  # 50
    "1p59bqxh",  # 51
)

In [None]:
# --- Evaluation configuration -------------------------------------------------
# Models evaluated by this cell and the W&B project their artifacts are fetched from.
MODELS_TO_EVALUATE = porttinari_dante_petrogold_models
WANDB_MODEL_PROJECT = "huber-ai/pos_porttinari_dante_petrogold"
TEST_FILES = [porttinari_test_path, dante_test_path, petrogold_test_path]
# Per-model copies of all evaluation artifacts go to OUTPUT_ROOT/<model name>.
OUTPUT_ROOT = Path("../tmp")
# Only sentences with EVAL_MIN_TOKENS..EVAL_MAX_TOKENS words (inclusive) are scored;
# longer sentences are excluded from all metrics (kept from the original setup).
EVAL_MIN_TOKENS = 0
EVAL_MAX_TOKENS = 60
FIGURE_DPI = 300
# Unicode categories replaced by "*" before tokenization: control (Cc),
# format (Cf) and private-use (Co) characters.
MASKED_UNICODE_CATEGORIES = frozenset({"Cc", "Cf", "Co"})


def read_conllu(path):
    """Parse the UTF-8 CoNLL-U file at `path` into a list of `conllu` sentences."""
    with open(path, "r", encoding="utf-8") as in_f:
        return conllu.parse(in_f.read())


def syntactic_words(sent):
    """Return the tokens of `sent` whose "id" is a plain int.

    Tokens with non-int ids (e.g. multiword-token ranges) are skipped, so the
    returned tokens line up one-to-one with the tags being predicted/scored.
    """
    return [token for token in sent if isinstance(token["id"], int)]


def sanitize_word(word):
    """Replace characters a subword tokenizer may silently drop with "*".

    NOTE(review): the previous code called `tk.replace("", "*")`. With an empty
    pattern, `str.replace` inserts "*" around every character, which corrupted
    every token. The originally targeted characters appear to have been lost, so
    whole Unicode categories are masked instead — confirm against the corpora.
    """
    masked = "".join(
        "*" if unicodedata.category(char) in MASKED_UNICODE_CATEGORIES else char
        for char in word
    )
    # An empty/whitespace-only word could yield no subword and break alignment.
    return masked if masked.strip() else "*"


def predict_upos(model, tokenizer, words, sent_id=None):
    """Predict one tag per word of a pre-tokenized sentence.

    Each word gets the argmax label of its FIRST subword (the same convention
    the previous -100 label alignment used).

    Raises:
        ValueError: if the number of labelled words differs from len(words),
            e.g. because truncation cut the sentence or a word produced no subword.
    """
    if not words:
        return []
    encoding = tokenizer(
        [sanitize_word(word) for word in words],
        truncation=True,
        is_split_into_words=True,
        return_tensors="pt",
    )
    word_ids = encoding.word_ids(batch_index=0)
    # First subword of a word: a non-special position whose word id differs from the previous position's.
    first_subword_positions = [
        position
        for position, word_idx in enumerate(word_ids)
        if word_idx is not None and (position == 0 or word_ids[position - 1] != word_idx)
    ]
    if len(first_subword_positions) != len(words):
        raise ValueError(
            "Sentence {} contains {} words but {} were labelled".format(
                sent_id, len(words), len(first_subword_positions)
            )
        )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(encoding["input_ids"]).logits[0]
    predicted_ids = logits[first_subword_positions].argmax(dim=-1).tolist()
    return [model.config.id2label[label_id] for label_id in predicted_ids]


def get_tags(sents, min_tokens=EVAL_MIN_TOKENS, max_tokens=EVAL_MAX_TOKENS):
    """Concatenate the "upos" values of every sentence with min_tokens..max_tokens words (inclusive)."""
    tags = []
    for sent in sents:
        sent_tags = [token["upos"] for token in syntactic_words(sent)]
        if min_tokens <= len(sent_tags) <= max_tokens:
            tags.extend(sent_tags)
    return tags


def save_confusion_matrix(true_tags, pred_tags, path):
    """Save an annotated confusion-matrix heatmap of gold vs. predicted tags to `path`.

    Rows/columns cover the union of gold and predicted tags, so a tag that only
    the model predicts is counted instead of raising an "unseen label" error.
    """
    labels = sorted(set(true_tags) | set(pred_tags))
    cm = confusion_matrix(true_tags, pred_tags, labels=labels)
    df_cm = pd.DataFrame(cm, index=labels, columns=labels)

    fig, ax = plt.subplots(figsize=(17, 17))
    sns.heatmap(df_cm, annot=True, cmap="gray", fmt="d", annot_kws={"size": 14}, ax=ax)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=45, ha="right", fontsize=14)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=14)
    # Explicit dpi (instead of mutating global rcParams) so every saved image has the same resolution.
    fig.savefig(path, dpi=FIGURE_DPI)
    # Close so figures do not pile up in memory across the model/test-file loop.
    plt.close(fig)


def evaluate_test_file(model, tokenizer, test_file, out_dir):
    """Tag `test_file` with `model`, score it against its gold tags, and save the artifacts.

    Writes <stem>_pred.conllu, <stem>_cls_report.txt, <stem>_results.json and
    <stem>_cf_matrix.png next to `test_file`, then copies them (plus the gold
    file) into `out_dir`. Returns the metrics dict that is written to the JSON file.
    """
    dataset_name = test_file.stem
    gold_sents = read_conllu(test_file)
    # Independent parse whose "upos" fields are overwritten with predictions.
    pred_sents = read_conllu(test_file)

    for sent in tqdm(pred_sents, desc=f"Tagging {dataset_name}"):
        words = syntactic_words(sent)
        predicted = predict_upos(
            model, tokenizer, [token["form"] for token in words], sent.metadata.get("sent_id")
        )
        for token, tag in zip(words, predicted):
            token["upos"] = tag

    pred_path = test_file.parent / f"{dataset_name}_pred.conllu"
    with open(pred_path, "w", encoding="utf-8") as out_f:
        out_f.writelines(sentence.serialize() + "\n" for sentence in pred_sents)

    true_tags = get_tags(gold_sents)
    # A "_" prediction is scored as "X".
    pred_tags = ["X" if tag == "_" else tag for tag in get_tags(pred_sents)]

    cls_report_path = test_file.parent / f"{dataset_name}_cls_report.txt"
    with open(cls_report_path, "w", encoding="utf-8") as out_f:
        out_f.write(classification_report(true_tags, pred_tags))

    metrics = {
        f"{dataset_name}_acc": accuracy_score(true_tags, pred_tags),
        f"{dataset_name}_f1_macro": f1_score(true_tags, pred_tags, average="macro"),
        f"{dataset_name}_f1_weighted": f1_score(true_tags, pred_tags, average="weighted"),
    }
    json_path = test_file.parent / f"{dataset_name}_results.json"
    with open(json_path, "w", encoding="utf-8") as out_f:
        json.dump(metrics, out_f)

    cf_matrix_path = test_file.parent / f"{dataset_name}_cf_matrix.png"
    save_confusion_matrix(true_tags, pred_tags, cf_matrix_path)

    for artifact_file in [cf_matrix_path, json_path, cls_report_path, pred_path, test_file]:
        shutil.copy(str(artifact_file), str(out_dir))
    return metrics


for model_name in MODELS_TO_EVALUATE:
    artifact = run.use_artifact(f"{WANDB_MODEL_PROJECT}/{model_name}", type="model")
    artifact_dir = artifact.download()

    model = AutoModelForTokenClassification.from_pretrained(artifact_dir)
    model.eval()  # inference mode (disables dropout)
    tokenizer = AutoTokenizer.from_pretrained(artifact_dir)

    # parents=True: also create OUTPUT_ROOT itself if it does not exist yet.
    out_dir = OUTPUT_ROOT / model_name.replace(":", "-")
    out_dir.mkdir(parents=True, exist_ok=True)

    for test_file in TEST_FILES:
        evaluate_test_file(model, tokenizer, test_file, out_dir)
