In [1]:

from datasets import Dataset
from torch.utils.data import DataLoader
import pandas as pd
import faiss 
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import pickle
from transformers import pipeline

import networkx as nx
from retriever import Retriever
from reranker import BGEReranker
from utils import k_hop_neighbors, get_pref_label, precision_at_k, recall_at_k, f1_at_k, strip_uri, get_label_mapping

from statistics import mean, stdev, median

In [10]:
# Use the first CUDA device when CUDA is available; otherwise run on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
# Read in text data
# The gzip-compressed TSV is parsed with its first row treated as a header and
# the columns renamed to "title" and "label-idn".
data_path = "data/title/validate.tsv.gz"
data = pd.read_csv(data_path, sep="\t", compression="gzip", header=0, names=["title", "label-idn"])
# strip_uri comes from utils; presumably it reduces the gold-label URIs to bare
# identifiers — confirm against utils.strip_uri.
data["label-idn"] = data["label-idn"].apply(strip_uri)

# Load the pickled GND object through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
# NOTE(security): pickle.load can execute arbitrary code; only load this file
# if it comes from a trusted source.
with open("data/gnd.pickle", "rb") as gnd_file:
    gnd = pickle.load(gnd_file)

In [12]:
# Model identifier handed to the Retriever constructor.
retriever_model_str = 'BAAI/bge-m3'

In [13]:
# Preview the first rows to sanity-check the parsed columns.
data.head()

Unnamed: 0,title,label-idn
0,Theoretica chimica acta a journal for structur...,"[040674886, 04185098X]"
1,Adressbuch deutscher Chemiker,"[004713532, 040098362, 040118827, 040118894, 0..."
2,Österreichischer Amtskalender zusammengestell...,"[040052982, 040432718, 040678709, 041133935, 0..."
3,Die Angestellten-Versicherung Zeitschrift der ...,"[040674886, 041424298]"
4,Die Arbeiten des Statistischen Bundesamtes im ...,"[004017331, 040064328, 041289463, 042571316]"


In [14]:
# Instantiate the project-local Retriever with the chosen model on DEVICE.
retriever = Retriever(retriever_model_str, device=DEVICE)

In [15]:
# Derive the label strings and the accompanying mapping from the GND object;
# both are passed to retriever.retrieve (as `labels` and `mapping`).
label_strings, label_mapping = get_label_mapping(gnd)

In [16]:
# Retrieve candidates for every title with top_k=10.
# `idns` is later zipped against data["label-idn"], i.e. it holds one entry per title.
# NOTE(review): `sim` presumably holds the matching similarity scores — confirm in
# retriever.py. The name `sim` is rebound to a dict by the reranking cell later on.
sim, idns = retriever.retrieve(
    mapping=label_mapping,
    labels=label_strings,
    texts=data["title"].tolist(),
    top_k=10,
    batch_size=512)

Batches:   0%|          | 0/677 [00:00<?, ?it/s]

Batches:   0%|          | 0/144 [00:00<?, ?it/s]

In [17]:
# Score the raw retriever output against the gold labels.
# For every cutoff k in 1..5 a per-title list of metric values is kept, keyed by k.
gold_labels = data["label-idn"]

recall_dict = {
    k: [recall_at_k(y_pred=preds, y_true=golds, k=k) for preds, golds in zip(idns, gold_labels)]
    for k in range(1, 6)
}
precision_dict = {
    k: [precision_at_k(y_pred=preds, y_true=golds, k=k) for preds, golds in zip(idns, gold_labels)]
    for k in range(1, 6)
}
f1_dict = {
    k: [f1_at_k(y_pred=preds, y_true=golds, k=k) for preds, golds in zip(idns, gold_labels)]
    for k in range(1, 6)
}

# All predictions
rec_all = [recall_at_k(y_pred=preds, y_true=golds, k=10) for preds, golds in zip(idns, gold_labels)]
prec_all = [precision_at_k(y_pred=preds, y_true=golds, k=10) for preds, golds in zip(idns, gold_labels)]
f1_all = [f1_at_k(y_pred=preds, y_true=golds, k=10) for preds, golds in zip(idns, gold_labels)]

# Report the averages over all titles: first the k=10 summary, then each cutoff.
print(f"Recall@10: {mean(rec_all)}")
print(f"Precision@10: {mean(prec_all)}")
print(f"F1@10: {mean(f1_all)}")
print("=====================================")
for k, recalls in recall_dict.items():
    print(f"Recall@{k}: {mean(recalls)}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

Recall@10: 0.2081938444038028
Precision@10: 0.0498029160244957
F1@10: 0.07494055280636612
Recall@1: 0.11850893374622679
Precision@1: 0.2382056492859968
F1@1: 0.14476086593513401
-----------------
Recall@2: 0.14821605207537591
Precision@2: 0.15806271225739577
F1@2: 0.13787387268395257
-----------------
Recall@3: 0.16440495290042154
Precision@3: 0.12057356665166373
F1@3: 0.1254536650707299
-----------------
Recall@4: 0.1754055007272013
Precision@4: 0.0986988365907882
F1@4: 0.1144016151782439
-----------------
Recall@5: 0.18305099985785436
Precision@5: 0.0836699900435085
F1@5: 0.10458279044492849
-----------------


In [18]:
# Expand each title's retrieved identifiers via the GND graph.
# NOTE(review): presumably adds labels reachable over up to k=2 "broader" hops —
# confirm the exact semantics in Retriever.get_neighbors.
idn_plus_neighbors = retriever.get_neighbors(idns, graph=gnd, k=2, relation="broader")

rec_all = []
prec_all = []
f1_all = []
for preds_i, golds_i in zip(idn_plus_neighbors, data["label-idn"]):
    # Evaluate over the complete expanded candidate list, whose length varies per title.
    n_preds = len(preds_i)
    rec_all.append(recall_at_k(y_pred=preds_i, y_true=golds_i, k=n_preds))
    prec_all.append(precision_at_k(y_pred=preds_i, y_true=golds_i, k=n_preds))
    f1_all.append(f1_at_k(y_pred=preds_i, y_true=golds_i, k=n_preds))

# The cutoff is the full candidate list, not 10, so the labels say "@all"
# (the previous "@10" labels misdescribed these numbers).
print(f"Recall@all: {mean(rec_all)}")
print(f"Precision@all: {mean(prec_all)}")
print(f"F1@all: {mean(f1_all)}")

Recall@10: 0.24150659657546164
Precision@10: 0.027874730570504783
F1@10: 0.04775497096106295


In [19]:
# Model identifier for the reranker, and the project-local BGEReranker built from it on DEVICE.
reranker_str = 'BAAI/bge-reranker-v2-m3'
reranker = BGEReranker(reranker_str, device=DEVICE)


In [20]:

# Flatten (title, candidate label) combinations into column lists for the reranker:
#   "pair"      -> (title text, label string returned by get_pref_label)
#   "label-idn" -> identifier of the candidate label
#   "title-idx" -> positional index of the title within `data`
pair_dict = {
    "pair": [],
    "label-idn": [],
    "title-idx": []
}

# The former running counter `c` was never read (it only mirrored
# len(pair_dict["pair"])) and has been removed.
for idx, (row, idn_i_list) in tqdm(enumerate(zip(data.itertuples(), idn_plus_neighbors)), total=len(data)):
    title_i = row.title
    for idn_i in idn_i_list:
        pair_dict["pair"].append((title_i, get_pref_label(gnd, idn_i)))
        pair_dict["label-idn"].append(idn_i)
        pair_dict["title-idx"].append(idx)

100%|██████████| 73319/73319 [00:03<00:00, 18664.76it/s]


In [21]:
# Wrap the column lists in a `datasets.Dataset` so they can be tokenized with `map`.
ds = Dataset.from_dict(pair_dict)

In [22]:
def tokenize(example):
    """Tokenize the "pair" column of a batch with the reranker's tokenizer.

    Called through `ds.map(..., batched=True)`, so `example["pair"]` holds the
    pairs of one map batch. Inputs are truncated to at most 64 tokens.

    NOTE(review): with `padding=True` the padded length presumably depends on
    the contents of each map batch; confirm that every DataLoader batch falls
    inside a single map batch so the rows it collates have equal length.
    """
    return reranker.tokenizer(
        example["pair"],
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=64,
    )

In [23]:
# Tokenize all pairs in map batches of 2000, then expose only the listed columns
# in torch format when the dataset is indexed.
ds = ds.map(tokenize, batched=True, batch_size=2000)
ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label-idn', 'title-idx'])

Map:   0%|          | 0/1785612 [00:00<?, ? examples/s]

In [26]:
# Serve the tokenized pairs in stored order (shuffle=False) in batches of 1000.
dataloader = DataLoader(ds, batch_size=1000, shuffle=False)

In [27]:
# Score every (title, label) pair with the reranker, batch by batch, and collect
# the results as flat columns.
# NOTE: this rebinds `sim`, which previously held the first value returned by
# retriever.retrieve.
sim = {
    "title-idx": [],
    "label-idn": [],
    "score": []
}

# Pure inference: disable autograd so no graph or gradient buffers are kept
# while scoring.
with torch.no_grad():
    for batch in tqdm(dataloader):
        scores = reranker.similarities(
            batch["input_ids"].to(DEVICE),
            batch["attention_mask"].to(DEVICE)
        )
        # Store plain Python ints rather than one 0-d tensor object per pair.
        sim["title-idx"].extend(torch.as_tensor(batch["title-idx"]).tolist())
        sim["label-idn"].extend(batch["label-idn"])
        sim["score"].extend(scores.tolist())

  0%|          | 0/1786 [00:00<?, ?it/s]

  2%|▏         | 31/1786 [02:30<2:21:53,  4.85s/it]


KeyboardInterrupt: 

In [1]:
# Turn the collected reranker scores into a frame with an integer "title-idx" column.
# Depends on kernel state: the import cell and the reranking loop must have run
# in the current session (the recorded NameError came from a fresh kernel).
df = pd.DataFrame(sim).astype({"title-idx": int})

NameError: name 'pd' is not defined

In [20]:
# Evaluate the reranked candidates: per title, order candidates by descending
# reranker score and compute recall/precision/F1 at k = 1..5.
recall_dict = {k: [] for k in range(1, 6)}
precision_dict = {k: [] for k in range(1, 6)}
f1_dict = {k: [] for k in range(1, 6)}

# Sort the whole frame once and split it with groupby instead of re-filtering
# the full frame for every title. groupby keeps the row order inside each group,
# so every group is already ranked by score.
ranked = df.sort_values(by="score", ascending=False)
for idx, df_i in tqdm(ranked.groupby("title-idx", sort=False)):
    pred = df_i["label-idn"].tolist()
    # "title-idx" is the positional index of the title in `data`.
    gold = data["label-idn"].iloc[idx]
    for k in range(1, 6):
        recall_dict[k].append(recall_at_k(y_pred=pred, y_true=gold, k=k))
        precision_dict[k].append(precision_at_k(y_pred=pred, y_true=gold, k=k))
        f1_dict[k].append(f1_at_k(y_pred=pred, y_true=gold, k=k))

for k in range(1, 6):
    print(f"Recall@{k}: {mean(recall_dict[k])}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

100%|██████████| 73319/73319 [03:09<00:00, 387.56it/s]


Recall@1: 0.13912855387171288
Precision@1: 0.3079420068467928
F1@1: 0.17500305148211978
-----------------
Recall@2: 0.1798534548279318
Precision@2: 0.21038202921480106
F1@2: 0.17472219635348804
-----------------
Recall@3: 0.20199366039881259
Precision@3: 0.16202257714007737
F1@3: 0.16226111691061812
-----------------
Recall@4: 0.21664021564085587
Precision@4: 0.13277595166327963
F1@4: 0.14925031066155456
-----------------
Recall@5: 0.22731729348026544
Precision@5: 0.11271293934723606
F1@5: 0.13741702805661196
-----------------


In [82]:
# Persist the reranked scores as a Feather file.
# NOTE(review): the file name says "3hop", but the neighbour expansion in this
# notebook calls get_neighbors with k=2 — confirm which setting produced this frame.
df.to_feather("data/title/validate_reranked_3hop.feather")