In [1]:
import os, ast, torch, pandas as pd, numpy as np, matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score,
                             confusion_matrix, ConfusionMatrixDisplay,
                             roc_curve, auc)
from transformers import (BertTokenizerFast, BertForSequenceClassification,
                          Trainer, TrainingArguments)

In [2]:
tqdm.pandas()                       # krajšie progress bary
os.environ["CUDA_VISIBLE_DEVICES"] = ""   # Force CPU (zmaž, ak chceš GPU)
device = torch.device("cpu")

CFG = dict(
    name        = "bert_xml_tags",
    epochs      = 7,
    batch_size  = 16,
    lr          = 3e-5,
    data_dir    = "/home/matus/NLPD_18/part1/outputs"   # ← uprav len túto cestu
)


In [3]:
df_train  = pd.read_csv(f"{CFG['data_dir']}/output_train.csv")
df_valid  = pd.read_csv(f"{CFG['data_dir']}/output_valid.csv")
df_test   = pd.read_csv(f"{CFG['data_dir']}/output_test.csv")


In [4]:
df_train_full = pd.concat([df_train, df_valid], ignore_index=True)\
                  .sample(frac=1, random_state=42).reset_index(drop=True)


In [5]:
def merge_adjacent_entities(entities):
    if not entities: return []
    entities = sorted(entities, key=lambda x: x["start"])
    merged = [entities[0]]
    for cur in entities[1:]:
        last = merged[-1]
        if cur["entity"] == last["entity"] and cur["start"] <= last["end"] + 1:
            last["end"] = cur["end"]
        else:
            merged.append(cur)
    return merged

def insert_xml(text, entities):
    if not entities: return text
    if isinstance(entities, str):
        try: entities = ast.literal_eval(entities)
        except Exception: return text
    merged = merge_adjacent_entities(entities)
    merged.sort(key=lambda x: x["start"])
    offset = 0
    for ent in merged:
        tag = ent["entity"].lower()
        open_tag, close_tag = f"<{tag}>", f"</{tag}>"
        s, e = ent["start"] + offset, ent["end"] + offset
        text = text[:s] + open_tag + text[s:e] + close_tag + text[e:]
        offset += len(open_tag) + len(close_tag)
    return text

for df in (df_train_full, df_valid, df_test):
    df["xml_stmt"] = df.progress_apply(
        lambda r: insert_xml(r["statement"], r["A_raw_entities"]), axis=1
    )

  0%|          | 0/20666 [00:00<?, ?it/s]

  0%|          | 0/2297 [00:00<?, ?it/s]

  0%|          | 0/2296 [00:00<?, ?it/s]

In [6]:
df

Unnamed: 0,statement,label,label_binary,A_raw_entities,B_raw_entities,xml_stmt
0,Three doctors from the same hospital 'die sudd...,1,0,"[{'entity': 'MISC', 'score': 0.99962676, 'inde...","[{'word': 'Three', 'entity': 'CARDINAL'}, {'wo...",Three doctors from the same hospital 'die sudd...
1,Say Joe Biden is a pedophile.,0,0,"[{'entity': 'PER', 'score': 0.9993856, 'index'...","[{'word': 'Joe Biden', 'entity': 'PERSON'}]",Say <per>Joe Biden</per> is a pedophile.
2,A photo shows President Joe Biden and Ukrainia...,1,0,"[{'entity': 'PER', 'score': 0.9996147, 'index'...","[{'word': 'Joe Biden', 'entity': 'PERSON'}, {'...",A photo shows President <per>Joe Biden</per> a...
3,"It will cost $50,000 per enrollee in Obamacare...",1,0,"[{'entity': 'MISC', 'score': 0.99520916, 'inde...","[{'word': '50,000', 'entity': 'MONEY'}, {'word...","It will cost $50,000 per enrollee in <misc>Oba..."
4,The Federal Register - which houses all Washin...,3,1,"[{'entity': 'ORG', 'score': 0.6887246, 'index'...","[{'word': 'The Federal Register - which', 'ent...",The <org>Federal Register</org> - which houses...
...,...,...,...,...,...,...
2291,"Says ""Rosie O'Donnell apparently committed the...",1,0,"[{'entity': 'PER', 'score': 0.99957937, 'index...","[{'word': ""Rosie O'Donnell"", 'entity': 'PERSON...","Says ""<per>Rosie O'Donnell</per> apparently co..."
2292,"An image shows ""Ukrainian soldiers praying.",1,0,"[{'entity': 'MISC', 'score': 0.999556, 'index'...","[{'word': 'Ukrainian', 'entity': 'NORP'}]","An image shows ""<misc>Ukrainian</misc> soldier..."
2293,Since 1938 the minimum wage has been increased...,4,1,"[{'entity': 'MISC', 'score': 0.9998487, 'index...","[{'word': '1938', 'entity': 'DATE'}, {'word': ...",Since 1938 the minimum wage has been increased...
2294,Says Wisconsin Supreme Court Justice David Pro...,2,0,"[{'entity': 'ORG', 'score': 0.7034567, 'index'...","[{'word': 'Wisconsin Supreme Court', 'entity':...",Says <org>Wisconsin Supreme Court</org> Justic...


In [7]:
tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
special = ['<per>', '</per>', '<org>', '</org>', '<loc>', '</loc>', '<misc>', '</misc>']
tok.add_special_tokens({'additional_special_tokens': special})

def encode(texts): return tok(texts, truncation=True, padding=True)

enc_train = encode(df_train_full["xml_stmt"].tolist())
enc_valid = encode(df_valid["xml_stmt"].tolist())
enc_test  = encode(df_test ["xml_stmt"].tolist())


In [8]:
class ClsDataset(torch.utils.data.Dataset):
    def __init__(self, enc, labels):
        self.enc, self.labels = enc, labels
    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.enc.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item
    def __len__(self): return len(self.labels)

ds_train = ClsDataset(enc_train, df_train_full["label_binary"].tolist())
ds_valid = ClsDataset(enc_valid, df_valid["label_binary"].tolist())
ds_test  = ClsDataset(enc_test , df_test ["label_binary"].tolist())


In [9]:
def metrics(pred):
    lab, pr = pred.label_ids, pred.predictions.argmax(-1)
    return dict(
        accuracy  = accuracy_score (lab, pr),
        f1        = f1_score       (lab, pr, zero_division=0),
        precision = precision_score(lab, pr, zero_division=0),
        recall    = recall_score   (lab, pr, zero_division=0)
    )
