In [1]:
import re
import pickle
from statistics import mean

from datasets import Dataset
import pandas as pd
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from retriever import Retriever
from reranker import BGEReranker
from utils import get_label_mapping, get_pref_label, recall_at_k, precision_at_k, f1_at_k

In [2]:
transformers.logging.set_verbosity_error()

In [3]:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [38]:
INSTRUCT_PROMPT = "Du bist ein hilfreicher Assistent für Bibliothekare. Gib mir einige Schlagworte für diesen Buchtitel: {}. Schlagworte: "

In [13]:
model_name = "meta-llama/Llama-3.2-3B-Instruct"
data_path = "data/title_test.feather"
retriever_model_str = 'BAAI/bge-m3'

In [24]:
data_df = pd.read_feather(data_path)

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=DEVICE, torch_dtype="auto")

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:  22%|##2       | 1.10G/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [52]:
predictions = []
count = 0
for index, row in tqdm(data_df.iterrows(), total=len(data_df)):
    title = row["title"]
    text = INSTRUCT_PROMPT.format(title)
    tokenized_text = tokenizer(text, return_tensors="pt").to(DEVICE)
    
    with torch.no_grad():
        output = model.generate(**tokenized_text, max_new_tokens=20)

    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    new_tokens = output_text[len(text):]
    predictions.append(new_tokens.strip())

100%|██████████| 8414/8414 [1:23:29<00:00,  1.68it/s]


In [53]:
data_df["predictions"] = predictions
data_df.to_feather("data/title_test_manual_prompt_predictions.feather")

In [4]:
df = pd.read_feather("data/title_test_manual_prompt_predictions.feather")

In [5]:
df.head()

Unnamed: 0,title,label-idn,label_list,predictions
0,Die Landesministerkonferenzen und der Bund koo...,"[040118827, 040320553, 040340139, 041303059, 0...","[Deutschland, Konferenz, Länder, Minister, Koo...","1. Kooperativer Föderalismus, 2. Politikverfle"
1,Die Geburt der Philosophie im Garten der Lüst...,"[041359380, 041373073, 118530941]","[Plato : Symposium, Foucault, Michel : L' usag...","1. Philosophie, 2. Geschichte der Philosophie,..."
2,Das Geldwäscherisiko verschiedener Glücksspi...,"[040118827, 040213293, 040278336, 040763080, 0...","[Deutschland, Glücksspiel, Italien, Prävention...",1. 1. 2. 3. 4. 5. 6.
3,"Entwicklung von großvolumigen CdTe- und (Cd,Zn...","[04124298X, 041471172, 041690257, 041716140, 0...","[Teilchendetektor, Cadmiumtellurid, Masse (Phy...",1. Cadmiumtellurid (CdTe)-Detektor 2. Cadmiumzink
4,Integrierte bioinformatische Methoden zur repr...,"[042002303, 042900913, 956645356, 960068600]","[Genanalyse, Computational chemistry, Open Sou...","1. Bioinformatik, 2. Hochdurchsatz-Analyse,"


In [6]:
def extract_keywords(text):
    # remove 
    pattern = r"\d[.)]"
    text = re.sub(pattern, "", text)
    text = text.split(",")
    text = [x.strip() for x in text]
    if len(text) == 1:
        text = text[0].split(" ")
    return text

In [7]:
df["pred_list"] = df["predictions"].apply(lambda x: extract_keywords(x))

In [12]:
df.sample(10)[["predictions", "pred_list"]]

Unnamed: 0,predictions,pred_list
4424,"1. Bayerisch-ligistische Armee, 2. Kriegskommi...","[Bayerisch-ligistische Armee, Kriegskommissariat]"
1465,"1. Quantenmechanik, 2. Quantenphysik, 3. Welt","[Quantenmechanik, Quantenphysik, Welt]"
3674,0 Kommentare\n\nEin Schlagwort für diesen Buch...,[0 Kommentare\n\nEin Schlagwort für diesen Buc...
2946,"1. Organisation, 2. Hochschulen, 3. Stand-by-M...","[Organisation, Hochschulen, Stand-by-Modus, ]"
1757,1. Schallmessung 2. Akustik 3. Körpersch,"[Schallmessung, , Akustik, , Körpersch]"
1689,"1. Isolierte Extremitätenperfusion, 2. Sarkom,","[Isolierte Extremitätenperfusion, Sarkom, ]"
4950,1. Geschlechterdifferenzierung 2. Geschlechter...,"[Geschlechterdifferenzierung, , Geschlechterdi..."
7294,1. Jenseitsnarrative 2. Literatur 3. Kunst 4.,"[Jenseitsnarrative, , Literatur, , Kunst]"
7160,1. Bildung 2. Handlungspraxis 3. Literaturunte...,"[Bildung, , Handlungspraxis, , Literaturunterr..."
5083,"1. Bildung, 2. Politik, 3. Kultur, 4.","[Bildung, Politik, Kultur, ]"


In [9]:
gnd_path = "data/gnd.pickle"
gnd = pickle.load(open(gnd_path, "rb"))

In [14]:
retriever = Retriever(retriever_model_str, device=DEVICE)

In [15]:
label_strings, label_mapping = get_label_mapping(gnd)

In [16]:
index = retriever.fit(labels=label_strings, batch_size=512)

Batches:   0%|          | 0/677 [00:00<?, ?it/s]

In [17]:
mapped_labels = []
for row in tqdm(df.iterrows(), total=len(df)):
    pred_list = row[1]["pred_list"]
    current_mapping = []
    for pred in pred_list:
        distance, idns = retriever.retrieve(
            mapping=label_mapping,
            labels=label_strings,
            index=index,
            texts=[pred],
            top_k=1)
        idn_sim = zip(idns[0], distance[0])
        current_mapping.extend(idn_sim)
    current_mapping = sorted(current_mapping, key=lambda x: x[1])
    current_mapping = [x[0] for x in current_mapping]
    current_mapping = list(set(current_mapping))
    mapped_labels.append(current_mapping)


 49%|████▉     | 4129/8414 [05:06<05:18, 13.45it/s]


KeyboardInterrupt: 

In [None]:
df["pred_idns"] = mapped_labels

In [None]:
rec_all = []
prec_all = []
f1_all = []
k = 10
for preds_i, golds_i in df[["pred_idns", "label-idn"]].itertuples(index=False):
    rec_all.append(recall_at_k(y_pred=preds_i, y_true=golds_i, k=k))
    prec_all.append(precision_at_k(y_pred=preds_i, y_true=golds_i, k=k))
    f1_all.append(f1_at_k(y_pred=preds_i, y_true=golds_i, k=k))
print(f"Recall: {mean(rec_all)}")
print(f"Precision: {mean(prec_all)}")
print(f"F1: {mean(f1_all)}")

Recall: 0.22855588164544527
Precision: 0.06264559068219634
F1: 0.09172787547280405


In [None]:
reranker_str = 'BAAI/bge-reranker-v2-m3'
reranker = BGEReranker(reranker_str, device=DEVICE)

In [None]:
df.head()

Unnamed: 0,title,label-idn,label_list,predictions,pred_list,pred_idns
0,Die Landesministerkonferenzen und der Bund koo...,"[040118827, 040320553, 040340139, 041303059, 0...","[Deutschland, Konferenz, Länder, Minister, Koo...","1. Kooperativer Föderalismus, 2. Politikverfle","[Kooperativer Föderalismus, Politikverfle]","[04196201X, 041652436]"
1,Die Geburt der Philosophie im Garten der Lüst...,"[041359380, 041373073, 118530941]","[Plato : Symposium, Foucault, Michel : L' usag...","1. Philosophie, 2. Geschichte der Philosophie,...","[Philosophie, Geschichte der Philosophie, Plat...","[040457915, 040463036, 943591538]"
2,Das Geldwäscherisiko verschiedener Glücksspi...,"[040118827, 040213293, 040278336, 040763080, 0...","[Deutschland, Glücksspiel, Italien, Prävention...",1. 1. 2. 3. 4. 5. 6.,[],[943591538]
3,"Entwicklung von großvolumigen CdTe- und (Cd,Zn...","[04124298X, 041471172, 041690257, 041716140, 0...","[Teilchendetektor, Cadmiumtellurid, Masse (Phy...",1. Cadmiumtellurid (CdTe)-Detektor 2. Cadmiumzink,"[Cadmiumtellurid, (CdTe)-Detektor, , Cadmiumzink]","[043030580, 041471172, 040092747, 943591538]"
4,Integrierte bioinformatische Methoden zur repr...,"[042002303, 042900913, 956645356, 960068600]","[Genanalyse, Computational chemistry, Open Sou...","1. Bioinformatik, 2. Hochdurchsatz-Analyse,","[Bioinformatik, Hochdurchsatz-Analyse, ]","[1021247227, 960068600, 943591538]"


In [None]:

pair_dict = {
    "pair": [],
    "label-idn": [],
    "title-idx": []
}

c = 0
for idx, (row) in tqdm(enumerate(df.itertuples()), total=len(df)):
    idn_i_list = row.pred_idns
    title_i = row.title
    for idn_i in idn_i_list:
        idn_i_str = get_pref_label(gnd, idn_i)
        pair_dict["pair"].append((title_i, idn_i_str))
        pair_dict["label-idn"].append(idn_i)
        pair_dict["title-idx"].append(idx)
        c += 1

  0%|          | 0/8414 [00:00<?, ?it/s]

100%|██████████| 8414/8414 [00:00<00:00, 65140.56it/s]


In [None]:
def tokenize(example):
    pair = example["pair"]
    return reranker.tokenizer(pair, padding=True, truncation=True, return_tensors='pt', max_length=64)

In [None]:
ds = Dataset.from_dict(pair_dict)
ds = ds.map(tokenize, batched=True, batch_size=2000)
ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label-idn', 'title-idx'])
dataloader = DataLoader(ds, batch_size=100, shuffle=False)

Map:   0%|          | 0/35401 [00:00<?, ? examples/s]

In [None]:
sim = {
    "title-idx": [],
    "label-idn": [],
    "score": []
}

for batch in tqdm(dataloader):
    scores = reranker.similarities(
        batch["input_ids"].to(DEVICE),
        batch["attention_mask"].to(DEVICE)
    )
    sim["title-idx"].extend(batch["title-idx"])
    sim["label-idn"].extend(batch["label-idn"])
    sim["score"].extend(scores.tolist())

100%|██████████| 355/355 [01:27<00:00,  4.07it/s]


In [None]:
df_sim = pd.DataFrame(sim)
df_sim["title-idx"] = df_sim["title-idx"].astype(int)

In [None]:
recall_dict = {}
precision_dict = {}
f1_dict = {}
ks = [1, 3, 5]

for idx in tqdm(set(df_sim["title-idx"])):
    df_i = df_sim[df_sim["title-idx"] == idx]
    df_i = df_i.sort_values(by="score", ascending=False)
    pred = df_i["label-idn"].tolist()
    gold = df["label-idn"].iloc[idx]
    for k in ks:
        if k not in recall_dict:
            recall_dict[k] = []
            precision_dict[k] = []
            f1_dict[k] = []
        recall_dict[k].append(recall_at_k(y_pred=pred, y_true=gold, k=k))
        precision_dict[k].append(precision_at_k(y_pred=pred, y_true=gold, k=k))
        f1_dict[k].append(f1_at_k(y_pred=pred, y_true=gold, k=k))

for k in ks:
    print(f"Recall@{k}: {mean(recall_dict[k])}")
    print(f"Precision@{k}: {mean(precision_dict[k])}")
    print(f"F1@{k}: {mean(f1_dict[k])}")
    print("-----------------")

  0%|          | 0/8414 [00:00<?, ?it/s]

100%|██████████| 8414/8414 [00:03<00:00, 2122.22it/s]

Recall@1: 0.14716058762535922
Precision@1: 0.36938435940099834
F1@1: 0.1923714085349455
-----------------
Recall@3: 0.2224457335071298
Precision@3: 0.20172727993027492
F1@3: 0.19122816328436515
-----------------
Recall@5: 0.2283835498175394
Precision@5: 0.12517233182790588
F1@5: 0.14747908923301956
-----------------



