In [1]:

from datasets import Dataset
from torch.utils.data import DataLoader
import pandas as pd
import faiss 
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import pickle
from transformers import pipeline

import networkx as nx
from retriever import Retriever
from reranker import BGEReranker
from utils import k_hop_neighbors, get_pref_label, precision_at_k, recall_at_k, f1_at_k, strip_uri, get_label_mapping

from statistics import mean, stdev, median

In [6]:
def apply_pref_label(label_list, graph):
    """Translate label identifiers into their preferred labels.

    Args:
        label_list: Iterable of label identifiers.
        graph: Graph-like object exposing ``nodes``. Identifiers that are not
            members of ``graph.nodes`` are silently dropped, so the result can
            be shorter than ``label_list``.

    Returns:
        A list holding ``get_pref_label(graph, label)`` for every retained
        identifier, in input order.
    """
    pref_labels = []
    for label in label_list:
        if label not in graph.nodes:
            continue
        pref_labels.append(get_pref_label(graph, label))
    return pref_labels

# Load the train and test splits of the title data set. Each file is a
# gzip-compressed TSV; `header=0` together with `names=` replaces the file's own
# header row with the column names used throughout this notebook.
data_path = "data/title/train.tsv.gz"
eval_path = "data/title/test.tsv.gz"
data_df = pd.read_csv(data_path, sep="\t", compression="gzip", header=0, names=["title", "label-idn"])
# presumably turns the raw label field into bare identifiers — TODO confirm against utils.strip_uri
data_df["label-idn"] = data_df["label-idn"].apply(strip_uri)
eval_df = pd.read_csv(eval_path, sep="\t", compression="gzip", header=0, names=["title", "label-idn"])
eval_df["label-idn"] = eval_df["label-idn"].apply(strip_uri)

# Load the pickled graph through a context manager so the file handle is closed
# deterministically (the previous `pickle.load(open(...))` leaked it).
# SECURITY: unpickling executes arbitrary code — only load a gnd.pickle you trust.
with open("data/gnd.pickle", "rb") as gnd_file:
    gnd = pickle.load(gnd_file)

# Resolve every identifier that exists in the graph to its preferred label;
# identifiers missing from `gnd.nodes` are dropped by `apply_pref_label`.
data_df["label_list"] = data_df["label-idn"].apply(lambda x: apply_pref_label(x, gnd))
eval_df["label_list"] = eval_df["label-idn"].apply(lambda x: apply_pref_label(x, gnd))

In [7]:
# Persist the test split, including the resolved `label_list` column, as Feather.
eval_df.to_feather("data/title_test.feather")

In [3]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
# Read the validation split of the title data (gzip-compressed TSV with a header row).
# NOTE(review): `data` is reassigned from "data/title_test_predictions.feather" in a
# later cell, so this frame looks unused downstream — confirm whether this cell is stale.
data_path = "data/title/validate.tsv.gz"
data = pd.read_csv(data_path, sep="\t", compression="gzip", header=0, names=["title", "label-idn"])
# presumably turns the raw label field into bare identifiers — TODO confirm against utils.strip_uri
data["label-idn"] = data["label-idn"].apply(strip_uri)


In [2]:
# Name of the embedding model that is handed to `Retriever` further down.
retriever_model_str = 'BAAI/bge-m3'

In [4]:
# Load the test titles together with their free-text `predictions` column.
data_path = "data/title_test_predictions.feather"
data = pd.read_feather(data_path)

# Load the pickled graph through a context manager so the file handle is closed
# deterministically (the previous `pickle.load(open(...))` leaked it).
# SECURITY: unpickling executes arbitrary code — only load a gnd.pickle you trust.
with open("data/gnd.pickle", "rb") as gnd_file:
    gnd = pickle.load(gnd_file)

In [5]:
# Preview the first rows of the prediction frame.
data.head()

Unnamed: 0,title,label-idn,label_list,predictions
0,Die Landesministerkonferenzen und der Bund koo...,"[040118827, 040320553, 040340139, 041303059, 0...","[Deutschland, Konferenz, Länder, Minister, Koo...","[Föderalismus, Deutschland (Bundesrepublik), B..."
1,Die Geburt der Philosophie im Garten der Lüst...,"[041359380, 041373073, 118530941]","[Plato : Symposium, Foucault, Michel : L' usag...","[Philosophie, Michel Fouca, Geburt, Platon, Lu..."
2,Das Geldwäscherisiko verschiedener Glücksspi...,"[040118827, 040213293, 040278336, 040763080, 0...","[Deutschland, Glücksspiel, Italien, Prävention...","[Glücksspiel, Geldwäsche, Wett, Risikobereitsc..."
3,"Entwicklung von großvolumigen CdTe- und (Cd,Zn...","[04124298X, 041471172, 041690257, 041716140, 0...","[Teilchendetektor, Cadmiumtellurid, Masse (Phy...","[Kondensationswirkung, Mischk, Keramische Schu..."
4,Integrierte bioinformatische Methoden zur repr...,"[042002303, 042900913, 956645356, 960068600]","[Genanalyse, Computational chemistry, Open Sou...","[Bioinformatik, Biomarker, Biologische Prozess..."


In [6]:
# Instantiate the project-local retriever with the chosen model on DEVICE.
retriever = Retriever(retriever_model_str, device=DEVICE)

In [7]:
# Derive the label strings and their mapping from the graph; both are passed to
# `retriever.fit` / `retriever.retrieve` below.
# presumably `label_mapping` links label strings back to identifiers — TODO confirm in utils
label_strings, label_mapping = get_label_mapping(gnd)

In [8]:
# Fit the retriever on all label strings in batches of 512; the returned `index`
# is what `retriever.retrieve` is queried against later.
index = retriever.fit(labels=label_strings, batch_size=512)

Batches:   0%|          | 0/677 [00:00<?, ?it/s]

In [24]:
# For every title, look up each predicted string with the retriever and keep
# only the first of the `top_k=2` identifiers returned per predicted string.
mapped_labels = []
for predictions in tqdm(data["predictions"], total=len(data)):
    _similarities, retrieved = retriever.retrieve(
        mapping=label_mapping,
        index=index,
        labels=label_strings,
        texts=predictions,
        top_k=2,
        batch_size=512,
    )
    # One best identifier per predicted string, in prediction order.
    mapped_labels.append([hits[0] for hits in retrieved])

100%|██████████| 8414/8414 [04:07<00:00, 34.06it/s]


In [25]:
# Alias under which the evaluation cells below read the mapped identifiers.
idns = mapped_labels

In [26]:
# Score the mapped identifiers against the gold identifiers of every title.
pred_gold_pairs = list(zip(idns, data["label-idn"]))

# Per-title scores for the cutoffs k = 1..5, keyed by k.
recall_dict = {
    k: [recall_at_k(y_pred=preds_i, y_true=golds_i, k=k) for preds_i, golds_i in pred_gold_pairs]
    for k in range(1, 6)
}
precision_dict = {
    k: [precision_at_k(y_pred=preds_i, y_true=golds_i, k=k) for preds_i, golds_i in pred_gold_pairs]
    for k in range(1, 6)
}
f1_dict = {
    k: [f1_at_k(y_pred=preds_i, y_true=golds_i, k=k) for preds_i, golds_i in pred_gold_pairs]
    for k in range(1, 6)
}

# All predictions
# (cutoff 10 is meant to cover every prediction — assumes at most 10 per title, TODO confirm)
rec_all = [recall_at_k(y_pred=preds_i, y_true=golds_i, k=10) for preds_i, golds_i in pred_gold_pairs]
prec_all = [precision_at_k(y_pred=preds_i, y_true=golds_i, k=10) for preds_i, golds_i in pred_gold_pairs]
f1_all = [f1_at_k(y_pred=preds_i, y_true=golds_i, k=10) for preds_i, golds_i in pred_gold_pairs]

# Report the means over all titles.
print(f"Recall@10: {mean(rec_all)}")
print(f"Precision@10: {mean(prec_all)}")
print(f"F1@10: {mean(f1_all)}")
print("=====================================")
for k in range(1, 6):
    print(f"Recall@{k}: {mean(recall_dict[k])}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

Recall@10: 0.22009574986458752
Precision@10: 0.06360827192773949
F1@10: 0.09210844396307787
Recall@1: 0.049492124031582076
Precision@1: 0.1344188257665795
F1@1: 0.06662975655725834
-----------------
Recall@2: 0.10462712979649039
Precision@2: 0.14374851438079392
F1@2: 0.10946985692647684
-----------------
Recall@3: 0.1579886436227059
Precision@3: 0.1482053719990492
F1@3: 0.13819444523016303
-----------------
Recall@4: 0.1967831303076133
Precision@4: 0.14089612550511052
F1@4: 0.14911642443417217
-----------------
Recall@5: 0.21481585920616109
Precision@5: 0.1239838364630378
F1@5: 0.14356928216719617
-----------------


In [44]:
# Expand every title's identifiers with their graph neighbours (k=1) and score
# the complete, untruncated candidate list against the gold identifiers.
idn_plus_neighbors = retriever.get_neighbors(idns, graph=gnd, k=1)

rec_all = []
prec_all = []
f1_all = []
for preds_i, golds_i in zip(idn_plus_neighbors, data["label-idn"]):
    # The cutoff is the full candidate count, so nothing is cut off here.
    n_candidates = len(preds_i)
    rec_all.append(recall_at_k(y_pred=preds_i, y_true=golds_i, k=n_candidates))
    prec_all.append(precision_at_k(y_pred=preds_i, y_true=golds_i, k=n_candidates))
    f1_all.append(f1_at_k(y_pred=preds_i, y_true=golds_i, k=n_candidates))

# These were previously printed as "@10", but the cutoff is len(preds_i), not 10;
# label them as computed over all candidates to avoid misreporting.
print(f"Recall@all: {mean(rec_all)}")
print(f"Precision@all: {mean(prec_all)}")
print(f"F1@all: {mean(f1_all)}")

Recall@10: 0.2607239949259203
Precision@10: 0.07572474756391588
F1@10: 0.10781349618628097


In [45]:
# Name of the reranking model and the project-local reranker built from it on DEVICE.
reranker_str = 'BAAI/bge-reranker-v2-m3'
reranker = BGEReranker(reranker_str, device=DEVICE)


In [46]:

# Flatten the candidates into one row per (title, candidate label) pair.
# "title-idx" stores the positional index of the title within `data`, so the
# scores can be grouped per title again after reranking.
pair_dict = {"pair": [], "label-idn": [], "title-idx": []}

c = 0  # running count of generated pairs
titles_with_candidates = zip(data.itertuples(), idn_plus_neighbors)
for idx, (row, idn_i_list) in tqdm(enumerate(titles_with_candidates), total=len(data)):
    for idn_i in idn_i_list:
        # Pair the title text with the candidate's preferred label.
        pair_dict["pair"].append((row.title, get_pref_label(gnd, idn_i)))
        pair_dict["label-idn"].append(idn_i)
        pair_dict["title-idx"].append(idx)
        c += 1

  0%|          | 0/8414 [00:00<?, ?it/s]

100%|██████████| 8414/8414 [00:00<00:00, 44957.67it/s]


In [47]:
# Wrap the flat pair columns in a `datasets.Dataset` for batched tokenization.
ds = Dataset.from_dict(pair_dict)

In [48]:
def tokenize(example):
    """Tokenize the ``"pair"`` entries of a batch with the reranker's tokenizer.

    Args:
        example: Mapping with a ``"pair"`` entry holding the (title, label)
            pairs to encode.

    Returns:
        Whatever ``reranker.tokenizer`` returns when called with padding and
        truncation enabled, PyTorch tensors requested and ``max_length=64``.
    """
    return reranker.tokenizer(
        example["pair"],
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=64,
    )

In [49]:
# Tokenize in batches of 2000 pairs; `tokenize` pads within each such batch.
ds = ds.map(tokenize, batched=True, batch_size=2000)
# Expose only the columns the scoring loop reads, in torch format.
ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label-idn', 'title-idx'])

Map:   0%|          | 0/88358 [00:00<?, ? examples/s]

In [50]:
# Iterate the tokenized pairs in their original order, 1000 per batch.
# NOTE(review): tokenization pads per map batch of 2000; since 1000 divides 2000,
# each DataLoader batch presumably holds equal-length rows — confirm before
# changing either batch size.
dataloader = DataLoader(ds, batch_size=1000, shuffle=False)

In [51]:
# Score every (title, label) pair with the reranker, batch by batch, and keep
# each score next to the title index and label identifier it belongs to.
sim = {column: [] for column in ("title-idx", "label-idn", "score")}

for batch in tqdm(dataloader):
    # Move the model inputs to the compute device before scoring.
    input_ids = batch["input_ids"].to(DEVICE)
    attention_mask = batch["attention_mask"].to(DEVICE)
    scores = reranker.similarities(input_ids, attention_mask)

    sim["title-idx"].extend(batch["title-idx"])
    sim["label-idn"].extend(batch["label-idn"])
    sim["score"].extend(scores.tolist())

100%|██████████| 89/89 [03:19<00:00,  2.25s/it]


In [52]:
# Collect the reranker scores into a frame with one row per (title, label) pair.
df = pd.DataFrame(sim)
# Cast the collected title indices to plain ints so they can be compared and
# used for positional lookup into `data`.
df["title-idx"] = df["title-idx"].astype(int)

In [53]:
# Evaluate the reranked candidates: per title, order the candidates by reranker
# score (best first) and compute recall / precision / F1 at k = 1..5.
recall_dict = {k: [] for k in range(1, 6)}
precision_dict = {k: [] for k in range(1, 6)}
f1_dict = {k: [] for k in range(1, 6)}

# A single groupby pass replaces the former per-title boolean-mask filter, which
# rescanned the whole frame once per title. The reported means are unaffected
# by the iteration order.
# NOTE(review): only titles that have at least one candidate row in `df`
# contribute to the averages.
for idx, df_i in tqdm(df.groupby("title-idx", sort=False), total=df["title-idx"].nunique()):
    pred = df_i.sort_values(by="score", ascending=False)["label-idn"].tolist()
    # "title-idx" is the positional index of the title within `data`.
    gold = data["label-idn"].iloc[idx]
    for k in range(1, 6):
        recall_dict[k].append(recall_at_k(y_pred=pred, y_true=gold, k=k))
        precision_dict[k].append(precision_at_k(y_pred=pred, y_true=gold, k=k))
        f1_dict[k].append(f1_at_k(y_pred=pred, y_true=gold, k=k))

for k in range(1, 6):
    print(f"Recall@{k}: {mean(recall_dict[k])}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

  0%|          | 0/8414 [00:00<?, ?it/s]

100%|██████████| 8414/8414 [00:04<00:00, 2026.36it/s]

Recall@1: 0.14138843916298158
Precision@1: 0.3555978131685286
F1@1: 0.18475168098415068
-----------------
Recall@2: 0.19283142048770763
Precision@2: 0.2598050867601616
F1@2: 0.20016874813642108
-----------------
Recall@3: 0.2166139686581807
Precision@3: 0.20069725061405594
F1@3: 0.1885176168253043
-----------------
Recall@4: 0.23182684487949523
Precision@4: 0.16478488233895888
F1@4: 0.17515438516221526
-----------------
Recall@5: 0.24194759034906407
Precision@5: 0.13931542666983598
F1@5: 0.16160873484293614
-----------------





In [82]:
# Persist the reranker scores (title index, label identifier, score) as Feather.
# NOTE(review): the file name says "validate" and "3hop", but the cells above
# score "data/title_test_predictions.feather" with get_neighbors(k=1) — confirm
# the target path matches the data that was actually scored.
df.to_feather("data/title/validate_reranked_3hop.feather")