In [9]:
from retriever import Retriever
from utils import get_label_mapping, get_title_mapping
import torch
from sentence_transformers import SentenceTransformer,  losses, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from datasets import Dataset
import os
import pickle
from tqdm import tqdm
import yaml
from utils import get_pref_label, map_labels
import pandas as pd

from gnd_dataset import GNDDataset
from transformers import  pipeline, set_seed
from prompt_str import SYSTEM_PROMPT, USER_PROMPT, CONTEXT_PROMPT, FS_PROMPT

In [3]:
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"

In [4]:
retriever_model = "BAAI/bge-m3"
retriever = Retriever(
    retriever_model=retriever_model,
    device=DEVICE,
)

In [5]:
gnd_path = "data/gnd.pickle"
config_path = "configs/config_pt_baseline.yaml"
# Load config 
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
gnd_graph = pickle.load(open(gnd_path, "rb"))

In [6]:
gnd_ds = GNDDataset(
    data_dir=config["dataset_path"],
    gnd_graph=gnd_graph,
    config=config,
    load_from_disk=True,
)
train_ds = gnd_ds["train"]


In [7]:
strings, mapping = get_label_mapping(graph=gnd_graph)

In [8]:
index = retriever.fit(labels=strings, batch_size=1024)

Batches: 100%|██████████| 339/339 [03:22<00:00,  1.67it/s]


In [93]:
example = gnd_ds["validate"][400]

text_embeddings = retriever.retriever.encode([example["title"]], show_progress_bar=False, batch_size=1024)
similarity, indices = index.search(text_embeddings, 3)

In [None]:
pipe = pipeline(
        "text-generation",
        model="meta-llama/Llama-3.2-3B-Instruct",
        torch_dtype=torch.bfloat16,
        device=DEVICE,
    )

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 47.47it/s]
Device set to use cuda:1


In [None]:
fs_examples = [
    FS_PROMPT.format(
        train_ds[int(i)]["title"],
        "; ".join(train_ds[int(i)]["label-names"]),
    ) for i in indices[0]
]
system_prompt = f"{SYSTEM_PROMPT} {'\n'.join(fs_examples)}"


In [None]:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": USER_PROMPT.format(example["title"])},
]

In [None]:
outputs = pipe(messages, num_return_sequences=1, do_sample=True, temperature=0.7)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [None]:
outputs[0]["generated_text"][-1]["content"]

'Jugend; Stadt; Gewalttätigkeit; Prävention; Sozialarbeit'

In [None]:
example["title"]
# 'Synökologie; Winterweizen; Aphiden; Pilze; Pflanzenschutz'

'Jugendgewalt im städtischen Raum Strategien und Ansätze im Umgang mit Gewalt'

In [None]:
example["label-names"]

['Jugend',
 'Stadt',
 'Stadtentwicklung',
 'Öffentlicher Raum',
 'Sozialarbeit',
 'Prävention',
 'Segregation (Soziologie)',
 'Gewalttätigkeit',
 'Sozialraumanalyse']

In [12]:
pred_df = pd.read_csv("results/few-shot-baseline/predictions-test-few-shot-seed-42.csv")

In [15]:
processed_predictions = pred_df["raw_predictions"].str.split(";")

In [16]:
mapped_predictions = map_labels(
    prediction_list=processed_predictions,
    index=index,
    retriever=retriever,
    label_mapping=mapping
)

Mapping predictions to GND labels: 100%|██████████| 8414/8414 [08:20<00:00, 16.81it/s]


In [17]:
pred_df = pd.DataFrame(
    {
        "predictions": mapped_predictions,
        "raw_predictions": pred_df["raw_predictions"],
        "label-ids": pred_df["label-ids"],
        "label-names": pred_df["label-names"],
        "title": pred_df["title"],
    }
)

In [18]:
pred_df

Unnamed: 0,predictions,raw_predictions,label-ids,label-names,title
0,"[041665597, 04196201X, 041652436, 040118827]",Kooperativer Föderalismus; Politikverflechtung...,"['040118827', '040320553', '040340139', '04130...","['Deutschland', 'Länder', 'Konferenz', 'Minist...",Die Landesministerkonferenzen und der Bund koo...
1,"[11853453X, 040192865, 041528514, 040205177, 0...","Eros; Philosophie; Platon; Foucault, Michel (1...","['041359380', '041373073', '118530941']","['Plato : Symposium', ""Foucault, Michel : L' u...",Die Geburt der Philosophie im Garten der Lüst...
2,"[042395852, 040501299, 041681843, 040213293, 0...",Glücksspiele; Risiko; Geldwäsche; Deutschland...,"['040118827', '040213293', '040278336', '04076...","['Deutschland', 'Rechtsvergleich', 'Italien', ...",Das Geldwäscherisiko verschiedener Glücksspi...
3,"[947405127, 041369416, 94015739X, 040230287, 0...",CdTe; CdZnTe; Detektorsysteme; Photodetektoren...,"['04124298X', '041471172', '041690257', '04171...","['Masse (Physik)', 'Teilchendetektor', 'Cadmiu...","Entwicklung von großvolumigen CdTe- und (Cd,Zn..."
4,"[960068600, 97269093X, 1021247227, 041479289]",Bioinformatik; Hochdurchsatz-Analyse; Life Sci...,"['042002303', '042900913', '956645356', '96006...","['Bioinformatik', 'Genanalyse', 'Open Source',...",Integrierte bioinformatische Methoden zur repr...
...,...,...,...,...,...
8409,[998819808],"Humanistische Therapie, Psychotherapie, Grundl...",['998819808'],['Humanistische Therapie'],Humanistische Psychotherapie Grundlagen - Rich...
8410,"[04012259X, 040759962, 041738659]",Persönlichkeitspsychologie; Differentielle Psy...,"['04012259X', '041550463']","['Forschungsmethode', 'Differentielle Psycholo...",Differentielle Psychologie und Persönlichkeit...
8411,[1027903738],"Raspberry Pi, Linux, Smarthome, Entertainment,...","['040763706', '04160072X', '1027903738', '9409...","['Programmierung', 'Python (Programmiersprache...",Raspberry Pi – dein Einstieg Der vielseitige L...
8412,"[041143337, 040021742, 041235924]",Kunst; Betrachtung; Kunstwerke,['041253213'],['Kunstbetrachtung'],Neue Sicht auf Kunst Ein Beitrag zur Betrachtu...


In [19]:
pred_df.to_csv(os.path.join("results/few-shot-baseline/predictions-test-few-shot-seed-42.csv"))

# FEW SHOT PROMPTING

In [164]:
def retrieve_negative_keywords(title, n=5):
    text_embeddings = retriever.retriever.encode([title], show_progress_bar=False)
    _, indices = index.search(text_embeddings, n)
    label_list = [mapping[i] for i in indices[0]]
    return label_list

In [160]:
model = SentenceTransformer("BAAI/bge-m3")

In [226]:
train_ds = gnd_ds["train"].select(range(10000))


In [227]:
retriever_dict = {
    "anchor": [],
    "positive": [],
}
for i in tqdm(train_ds):
    gold_labels_ids = i["label-ids"]
    gold_labels = i["label-names"]
    title = i["title"]
    for keyword in gold_labels:
        retriever_dict["anchor"].append(title)
        retriever_dict["positive"].append(keyword)

100%|██████████| 10000/10000 [00:05<00:00, 1819.55it/s]


In [228]:
eval_ds = gnd_ds["validate"].select(range(1000))
eval_dict = {
    "anchor": [],
    "positive": [],
}
for i in tqdm(eval_ds):
    gold_labels_ids = i["label-ids"]
    gold_labels = i["label-names"]
    title = i["title"]
    for keyword in gold_labels:
        eval_dict["anchor"].append(title)
        eval_dict["positive"].append(keyword)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:00<00:00, 2161.36it/s]


In [229]:
train_dataset = Dataset.from_dict(retriever_dict)
eval_dataset = Dataset.from_dict(eval_dict) 

In [232]:
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="retriever/testing",
    # Optional training parameters:
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-5,
    eval_strategy="steps",
    eval_steps=50,
    # Optional tracking/debugging parameters:
    save_total_limit=2,
    logging_steps=50,
    run_name="testing",  # Will be used in W&B if `wandb` is installed
)

Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.


In [233]:
loss = losses.MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

Step,Training Loss,Validation Loss
50,2.7496,1.639668
100,1.5498,1.528611
150,1.4474,1.491853
200,1.4392,1.473561
250,1.4087,1.450809
300,1.4218,1.425395
350,1.3754,1.417233
400,1.3243,1.402283
450,1.2585,1.401772


KeyboardInterrupt: 

In [234]:
import faiss

In [235]:

strings, mapping = get_label_mapping(gnd_graph)
index_fs = faiss.IndexHNSWFlat(model.get_sentence_embedding_dimension(), 200)
label_embeddings = model.encode(strings, show_progress_bar=True, batch_size=1024)
index_fs.add(label_embeddings)

Batches: 100%|██████████| 339/339 [03:26<00:00,  1.64it/s]


In [273]:
example = gnd_ds["validate"][2100]
enc_example = model.encode([example["title"]], show_progress_bar=False, batch_size=1024)
ee = retriever.retriever.encode([example["title"]], show_progress_bar=False, batch_size=1024)
dist, labels =index_fs.search(enc_example, 3)
d, ls = index.search(ee, 3)
example["title"], [get_pref_label(gnd_graph, mapping[i]) for i in labels[0]],example["label-names"]

('Die Bibel als Grundlage der politischen Theorie des Johannes Althusius',
 ['Althusius, Johannes : Politica methodice digesta et exemplis sacris et profanis illustrata',
  'Althusius, Johannes (1563-1638)',
  'Politische Theologie'],
 ['Staat',
  'Theologie',
  'Bibel',
  'Politische Theorie',
  'Kirche',
  'Calvinismus',
  'Althusius, Johannes : Politica methodice digesta et exemplis sacris et profanis illustrata'])

('Muscheln in meiner Hand und andere Geschichten',
 ['Alltagsgeschichte (Fach)', 'Kurzgeschichte', 'Erzählerkommentar'])