In [1]:
import os 
import pickle
from statistics import mean
from datasets import Dataset

import pandas as pd
from safetensors import safe_open
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import logging

from yaml import safe_load

from prompt_str import SUFFIX_PROMPT, PREFIX_PROMPT
from retriever import Retriever
from reranker import BGEReranker
from utils import init_prompt_model, get_label_mapping, recall_at_k, precision_at_k, f1_at_k, get_pref_label, inference_tokenize, load_model


logging.set_verbosity_error()


In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
config_path = "config.yaml"
checkpoint_path = "hyperparams/best_model/model.safetensors"
test_data = "data/title_test.feather"

In [6]:
test_df = pd.read_feather(test_data)
config = safe_load(open(config_path, "r"))

In [7]:
model, tokenizer = load_model(checkpoint_path, config=config, device=DEVICE, data_parallel=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
test_ds = Dataset.from_pandas(test_df)
test_ds = test_ds.map(
    lambda x: inference_tokenize(x, tokenizer, max_length=75, suffix=SUFFIX_PROMPT, prefix=PREFIX_PROMPT),
)

Map:   0%|          | 0/8414 [00:00<?, ? examples/s]

In [11]:
model.to(DEVICE)
model.eval()
#model.pad_token_id = tokenizer.pad_token_id
predictions = []
for title_batch in tqdm(test_ds):
    input_ids = torch.tensor(title_batch["input_ids"]).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(title_batch["attention_mask"]).to(DEVICE).unsqueeze(0)
    seq_lengths = torch.tensor(title_batch["seq_lengths"]).to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        out = model.module.generate(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            seq_lengths=seq_lengths)
    len_input = len(input_ids[0])
    out = out[0][len_input:]  # Remove the input part of the output
    out_str = tokenizer.decode(out, skip_special_tokens=True)
    out_str = [x.strip() for x in out_str.split(";") if x != ""]
    #out_str = process_output(out_str)
    predictions.append(list(out_str))

  0%|          | 0/8414 [00:00<?, ?it/s]

100%|██████████| 8414/8414 [1:12:58<00:00,  1.92it/s]


In [12]:
n_pred = len(predictions[:10])
for pred, title in zip(predictions, test_ds["title"][:n_pred]):
    print(f"Title: {title}")
    print(f"Predictions: {pred}")
    print("---"*30)

Title: Die Landesministerkonferenzen und der Bund kooperativer Föderalismus im Schatten der Politikverflechtung
Predictions: ['Deutschland', 'Länder', 'Kooperatives Föderalismusmodell']
------------------------------------------------------------------------------------------
Title: Die Geburt der Philosophie im Garten der Lüste Michel Foucaults Archäologie des platonischen Eros
Predictions: ['Philosophie', 'Plato', 'Foucault, Michel (-)']
------------------------------------------------------------------------------------------
Title: Das Geldwäscherisiko verschiedener Glücksspielarten
Predictions: ['Glücksspiel', 'Glücksspielrisiko']
------------------------------------------------------------------------------------------
Title: Entwicklung von großvolumigen CdTe- und (Cd,Zn)Te-Detektorsystemen
Predictions: ['ZnTe', 'Schichtwechsel', 'Halbleiter', 'Dosismessung', 'Cadmium']
------------------------------------------------------------------------------------------
Title: Integri

In [13]:
test_df["predictions"] = predictions    
test_df.to_feather("data/title_best_predictions.feather")

In [14]:
def filter_predictions(pred_list):
    pred_list = [x for x in pred_list if len(x) > 0]
    return list(set(pred_list))

In [15]:
test_df["predictions"] = test_df["predictions"].apply(filter_predictions)

In [16]:
test_df.sample(10)[["title", "predictions", "label_list"]]

Unnamed: 0,title,predictions,label_list
6054,Alles Schicksal? Wie wir uns aus Familienmuste...,"[Sozialpsychologie, Familie, Soz, Identitätsfi...","[Prägung, Familienbeziehung, Identitätsentwick..."
8090,Vom Domänenamt Schöneck zur Domäne Pogutken...,"[Schönken (Ostpreußen, Pogutken (Ostpreußen)]","[Domänenamt, Skarszewy, Generalpächter]"
65,In der Erinnerung ankern Die Trauer von Kinder...,"[Kind, Trauerarbeit, Memory-Book, Jugend]","[Jugend, Kind, Begleitung (Psychologie), Traue..."
7546,Lösungen zum Lehrbuch Corporate Finance Theor...,[Corporate Finance],[Finanzmanagement]
6922,Damit aus meiner Trauer Liebe wird Neue Wege i...,[Trauerarbeit],[Trauerarbeit]
1457,„Beruf und Berufung“ Die evangelische Geistlic...,"[Evangelische Kirche, Geistliche, Geistliches ...","[Evangelische Kirche, Pommern, Konfessionalisi..."
8232,75 Coachingkarten Achtsamkeits- und Weisheitsg...,[Coaching],"[Erzählung, Coaching]"
5580,Bondifaktoren Ein natürlicher Zugang zur spez...,"[Bondifaktor, Relativitätstheorie]",[Spezielle Relativitätstheorie]
6991,Roloff/Matek Maschinenelemente Aufgabensammlun...,[Maschinenelement],[Maschinenelement]
1672,Der Unmittelbarkeitsgrundsatz im Zivilprozess ...,"[Unmittelbarkeit, Zivilprozess, Deutschland]","[Deutschland, Zivilprozess, Unmittelbarkeitsgr..."


In [17]:
gnd_path = "data/gnd.pickle"
gnd = pickle.load(open(gnd_path, "rb"))

In [18]:
retriever_model_str = 'BAAI/bge-m3'
retriever = Retriever(retriever_model_str, device="cuda:1")

In [19]:
label_strings, label_mapping = get_label_mapping(gnd)

In [20]:
index = retriever.fit(labels=label_strings, batch_size=512)

Batches:   0%|          | 0/677 [00:00<?, ?it/s]

In [21]:
mapped_labels = []
for row in tqdm(test_df.iterrows(), total=len(test_df)):
    pred_list = row[1]["predictions"]
    current_mapping = []
    for pred in pred_list:
        distance, idns = retriever.retrieve(
            mapping=label_mapping,
            labels=label_strings,
            index=index,
            texts=[pred],
            top_k=1)
        idn_sim = zip(idns[0], distance[0])
        current_mapping.extend(idn_sim)
    current_mapping = sorted(current_mapping, key=lambda x: x[1])
    current_mapping = [x[0] for x in current_mapping]
    current_mapping = list(set(current_mapping))
    mapped_labels.append(current_mapping)

100%|██████████| 8414/8414 [06:03<00:00, 23.16it/s]


In [22]:
test_df["pred_idns"] = mapped_labels

In [23]:
rec_all = []
prec_all = []
f1_all = []
k = 5
for preds_i, golds_i in test_df[["pred_idns", "label-idn"]].itertuples(index=False):
    rec_all.append(recall_at_k(y_pred=preds_i, y_true=golds_i, k=k))
    prec_all.append(precision_at_k(y_pred=preds_i, y_true=golds_i, k=k))
    f1_all.append(f1_at_k(y_pred=preds_i, y_true=golds_i, k=k))
print(f"Recall: {mean(rec_all)}\nPrecision: {mean(prec_all)}\nF1: {mean(f1_all)}")

Recall: 0.33643151410562866
Precision: 0.19229855003565488
F1: 0.22324013955450137


In [24]:
reranker_str = 'BAAI/bge-reranker-v2-m3'
reranker = BGEReranker(reranker_str, device="cuda:1")

In [25]:

pair_dict = {
    "pair": [],
    "label-idn": [],
    "title-idx": []
}

c = 0
for idx, (row) in tqdm(enumerate(test_df.itertuples()), total=len(test_df)):
    idn_i_list = row.pred_idns
    title_i = row.title
    for idn_i in idn_i_list:
        idn_i_str = get_pref_label(gnd, idn_i)
        pair_dict["pair"].append((title_i, idn_i_str))
        pair_dict["label-idn"].append(idn_i)
        pair_dict["title-idx"].append(idx)
        c += 1

100%|██████████| 8414/8414 [00:00<00:00, 108120.56it/s]


In [26]:
def tokenize(example):
    pair = example["pair"]
    return reranker.tokenizer(pair, padding=True, truncation=True, return_tensors='pt', max_length=64)

In [27]:
ds = Dataset.from_dict(pair_dict)
ds = ds.map(tokenize, batched=True, batch_size=2000)
ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label-idn', 'title-idx'])
dataloader = DataLoader(ds, batch_size=100, shuffle=False)

Map:   0%|          | 0/22780 [00:00<?, ? examples/s]

In [28]:
sim = {
    "title-idx": [],
    "label-idn": [],
    "score": []
}
reranker.model.to(DEVICE)
for batch in tqdm(dataloader):
    scores = reranker.similarities(
        batch["input_ids"].to(DEVICE),
        batch["attention_mask"].to(DEVICE)
    )
    sim["title-idx"].extend(batch["title-idx"])
    sim["label-idn"].extend(batch["label-idn"])
    sim["score"].extend(scores.tolist())

100%|██████████| 228/228 [00:55<00:00,  4.09it/s]


In [29]:
df_sim = pd.DataFrame(sim)
df_sim["title-idx"] = df_sim["title-idx"].astype(int)

In [30]:
recall_dict = {}
precision_dict = {}
f1_dict = {}
ks = [1, 2,  3, 5]

for idx in tqdm(set(df_sim["title-idx"])):
    df_i = df_sim[df_sim["title-idx"] == idx]
    df_i = df_i.sort_values(by="score", ascending=False)
    pred = df_i["label-idn"].tolist()
    gold = test_df["label-idn"].iloc[idx]
    for k in ks:
        if k not in recall_dict:
            recall_dict[k] = []
            precision_dict[k] = []
            f1_dict[k] = []
        recall_dict[k].append(recall_at_k(y_pred=pred, y_true=gold, k=k))
        precision_dict[k].append(precision_at_k(y_pred=pred, y_true=gold, k=k))
        f1_dict[k].append(f1_at_k(y_pred=pred, y_true=gold, k=k))

for k in ks:
    print(f"Recall@{k}: {mean(recall_dict[k])}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

100%|██████████| 8414/8414 [00:03<00:00, 2431.69it/s]

Recall@1: 0.21862605311915984
Precision@1: 0.5177085809365344
F1@1: 0.2804121353384486
-----------------
Recall@2: 0.29315750178300687
Precision@2: 0.3876277632517233
F1@2: 0.3021702355431854
-----------------
Recall@3: 0.32199141118501706
Precision@3: 0.2975596228508042
F1@3: 0.2799625741604447
-----------------
Recall@5: 0.33643151410562866
Precision@5: 0.19229855003565488
F1@5: 0.22324013955450137
-----------------



