In [1]:
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from gnd_dataset import GNDDataset, Dataset
import yaml
import pickle
import faiss
from utils import get_label_mapping, get_pref_label


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the sentence-embedding model from a local checkpoint directory.
# The hub id "sentence-transformers/all-MiniLM-L6-v2" is kept as an alternative to swap in.
checkpoint_path = "retriever/partial/checkpoint-17218"
model = SentenceTransformer(checkpoint_path)

In [3]:
# Upload the loaded model to the Hugging Face Hub. The call is left as the
# cell's last expression so its return value (the commit URL) is displayed.
hub_repo_name = "gnd_retriever_100k"
model.push_to_hub(hub_repo_name)

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]
model.safetensors:   0%|          | 1.15M/2.27G [00:00<03:20, 11.3MB/s]
model.safetensors:   0%|          | 8.60M/2.27G [00:00<00:47, 47.5MB/s]
[A
model.safetensors:   1%|          | 16.0M/2.27G [00:00<01:11, 31.6MB/s]
model.safetensors:   1%|          | 27.8M/2.27G [00:00<00:41, 53.7MB/s]
model.safetensors:   2%|▏         | 34.8M/2.27G [00:00<00:55, 40.4MB/s]
model.safetensors:   2%|▏         | 42.6M/2.27G [00:00<00:46, 47.8MB/s]
[A
[A
tokenizer.json: 100%|██████████| 17.1M/17.1M [00:02<00:00, 7.60MB/s]/s]
model.safetensors: 100%|██████████| 2.27G/2.27G [01:06<00:00, 33.9MB/s]
Upload 2 LFS files: 100%|██████████| 2/2 [01:07<00:00, 33.64s/it]


'https://huggingface.co/KatjaK/gnd_retriever_100k/commit/6a2e4d9d43cdd161e12889710fba17e9c3dfeb32'

In [3]:
# Locations of the pickled GND graph and the YAML experiment configuration.
gnd_path = "data/gnd.pickle"
config_path = "configs/config_pt_baseline.yaml"

# Load config
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# Load the graph inside a context manager so the file handle is closed as soon
# as unpickling finishes instead of being left open until garbage collection.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point gnd_path at a pickle that comes from a trusted source.
with open(gnd_path, "rb") as graph_file:
    gnd_graph = pickle.load(graph_file)

In [4]:
# Build the dataset wrapper over "datasets/no_context/", handing it the loaded
# GND graph and the parsed config; load_from_disk=True is passed through as-is.
ds = GNDDataset(
    config=config,
    gnd_graph=gnd_graph,
    data_dir="datasets/no_context/",
    load_from_disk=True,
)

In [6]:

# `strings` are the label texts to embed; `mapping` is later indexed with the
# positions returned by the FAISS search to recover the label identifier
# (presumably GND ids; confirm against utils.get_label_mapping).
strings, mapping = get_label_mapping(gnd_graph)

# Embed every label string in batches of 1024.
label_embeddings = model.encode(strings, batch_size=1024, show_progress_bar=True)

# HNSW index sized to the model's embedding dimension. 200 is the second
# constructor argument (links per node in the HNSW graph, per the FAISS API).
embedding_dim = model.get_sentence_embedding_dimension()
index_fs = faiss.IndexHNSWFlat(embedding_dim, 200)
index_fs.add(label_embeddings)

Batches: 100%|██████████| 339/339 [03:26<00:00,  1.64it/s]


In [105]:
# Spot-check a single validation record: embed its title once, fetch the three
# closest labels from the index, and show them next to the record's own labels.
# (The title was previously encoded and searched twice; the second, unused copy
# of that work has been dropped.)
example = ds["validate"][1300]
enc_example = model.encode([example["title"]], show_progress_bar=False, batch_size=1024)
dist, labels = index_fs.search(enc_example, 3)

# Displayed tuple: (title, preferred names of the retrieved labels, label names
# stored with the record).
example["title"], [get_pref_label(gnd_graph, mapping[i]) for i in labels[0]], example["label-names"]

('Das Ende des Ersten Weltkriegs und die Dolchstoßlegende',
 ['Dolchstoßlegende', 'Kriegsende', 'Zwölfte Isonzoschlacht'],
 ['Erster Weltkrieg', 'Kriegsende', 'Dolchstoßlegende'])

In [108]:
# Embed every title of the test split in batches of 1024.
texts = ds["test"]["title"]
text_embeddings = model.encode(
    texts,
    batch_size=1024,
    show_progress_bar=True,
)

Batches:   0%|          | 0/9 [00:00<?, ?it/s]

Batches: 100%|██████████| 9/9 [00:11<00:00,  1.30s/it]


In [109]:
# For each title embedding, look up the 10 closest label vectors in the index.
# NOTE(review): the index is created without an explicit metric, and FAISS
# defaults IndexHNSWFlat to L2, so `similarity` presumably holds distances
# (smaller = closer) rather than similarity scores. Confirm before using it.
top_k = 10
similarity, indices = index_fs.search(text_embeddings, top_k)

In [110]:
# Translate each row of index positions into the corresponding `mapping` entries,
# producing one list of label identifiers per test title.
label_idn = [[mapping[idx] for idx in top_indices] for top_indices in indices]

In [116]:
import pandas as pd
import os

In [113]:
# Keep a handle on the test split; its "label-ids", "label-names" and "title"
# columns are read when the prediction table is assembled.
test_ds = ds["test"]

In [114]:
# One row per test title: the retrieved label identifiers next to the label
# ids, label names and title taken from the test split.
prediction_columns = {
    "predictions": label_idn,
    "label-ids": test_ds["label-ids"],
    "label-names": test_ds["label-names"],
    "title": test_ds["title"],
}
pred_df = pd.DataFrame(prediction_columns)

In [118]:
# Create the output directory, including the parent "results" folder if it is
# missing. exist_ok=True keeps this cell re-runnable: os.mkdir raised
# FileExistsError on a second run and FileNotFoundError without the parent.
os.makedirs("results/retrieval-ft", exist_ok=True)

In [119]:
# Write the prediction table to predictions.csv in the results folder,
# leaving out the DataFrame index column.
predictions_path = os.path.join("results/retrieval-ft", "predictions.csv")
pred_df.to_csv(predictions_path, index=False)