%run ../ipynb_util_tars.py

In [1]:
import os

from qdrant_client import QdrantClient, models

# Connection settings are read from the environment so that no endpoint or
# credential is stored in the notebook itself.
_qdrant_port = os.getenv("QDRANT_API_PORT")

client = QdrantClient(
    os.getenv("QDRANT_API_URL"),
    # os.getenv() always yields a string, while the client's `port` parameter
    # is an integer: convert it explicitly. An unset or empty variable is
    # passed through as None.
    port=int(_qdrant_port) if _qdrant_port else None,
    api_key=os.getenv("QDRANT_API_KEY"),
)



In [2]:
# Scroll over points that HAVE labels (the "labels" payload field is not null).
# Only the first element of the scroll result (the returned points) is used
# below, so at most `limit` points are inspected by this cell.
scroll_result = client.scroll(
    collection_name="publications",
    limit=500,
    with_payload=["id", "labels", "xai"],
    scroll_filter=models.Filter(
        must_not=[
            models.IsNullCondition(
                is_null=models.PayloadField(key="labels")
            )
        ]
    )
)

for point in scroll_result[0]:
    explanations = point.payload["xai"]

    # Report points that do not carry exactly one explanation entry.
    if len(explanations) != 1:
        print(point.payload["id"])
    if not explanations:
        # Nothing to inspect; indexing [0] below would raise IndexError.
        continue

    # token_scores is iterated as rows (one per token) of individual scores.
    shap_values = explanations[0]["xai_values"]["token_scores"]

    # Report the point ONCE if any score lies outside [-0.1, 0.1]. A nested
    # loop with `break` would only leave the inner loop and print the same id
    # again for every further offending token row.
    # NOTE(review): the bound used here is 0.1 - confirm it is the intended one.
    if any(
        score < -0.1 or score > 0.1
        for token_scores in shap_values
        for score in token_scores
    ):
        print(point.payload["id"])

# Show the explanation payload of the first returned point as a sample
# (skipped when the scroll returned no points).
if scroll_result[0]:
    print(scroll_result[0][0].payload["xai"])

oai:www.zora.uzh.ch:125412
oai:www.zora.uzh.ch:125412
oai:www.zora.uzh.ch:125975
oai:www.zora.uzh.ch:127527
oai:www.zora.uzh.ch:143565
oai:www.zora.uzh.ch:148179
oai:www.zora.uzh.ch:148179
oai:www.zora.uzh.ch:168757
oai:www.zora.uzh.ch:168757
[{'model_family': 'scibert', 'model_path': '/srv/scratch2/dbielik//.cache/huggingface/checkpoints/allenai/scibert_scivocab_uncased-ft-zo_up-lower/checkpoint-240/', 'predicted_label': 10, 'probs': [0.013929353095591068, 0.013398760929703712, 0.0121486634016037, 0.004457520321011543, 0.005304357502609491, 0.015620237216353416, 0.042211271822452545, 0.08767326176166534, 0.02332335151731968, 0.42830902338027954, 0.01496716309338808, 0.11171144992113112, 0.03795639052987099, 0.01296242605894804, 0.03673660382628441, 0.05917535722255707, 0.08011487126350403], 'xai_method': 'shap', 'xai_values': {'input_tokens': ['', 'Rein', 'sur', 'ance ', 'or ', 'Sec', 'uri', 'tization', ': ', 'The ', 'Case ', 'of ', 'Natural ', 'Catast', 'roph', 'e ', 'Risk ', 'We ', 

In [6]:
# Pair every scrolled point's id with its explanation payload.
# Despite its name, xai_out_dict is a list of {"id", "xai"} dicts.
xai_out_dict = [
    {"id": point.payload["id"], "xai": point.payload["xai"]}
    for point in scroll_result[0]
]

In [9]:
# save xai_out_dict on disk as json
# import json

# with open("xai_out_dict.json", "w") as f:
#    json.dump(xai_out_dict, f)


In [3]:
def _count_excluding(condition):
    """Return the exact count of "publications" points that do NOT match `condition`."""
    return client.count(
        collection_name="publications",
        count_filter=models.Filter(must_not=[condition]),
        exact=True,
    )


# Points whose "labels" payload field is not null.
points_with_labels_count = _count_excluding(
    models.IsNullCondition(is_null=models.PayloadField(key="labels"))
)
print(points_with_labels_count)

# Payload keys sdg1 ... sdg17 nested under "labels".
sdgs = ["sdg" + str(number) for number in range(1, 18)]
sdg_counts = {}

# For each SDG, count the points whose "labels.<sdg>" field is not empty.
for sdg in sdgs:
    result = _count_excluding(
        models.IsEmptyCondition(
            is_empty=models.PayloadField(key="labels." + sdg)
        )
    )
    sdg_counts[sdg] = result.count

print(sdg_counts)
print(sum(sdg_counts.values()))

count=384
{'sdg1': 6, 'sdg2': 13, 'sdg3': 43, 'sdg4': 2, 'sdg5': 22, 'sdg6': 3, 'sdg7': 17, 'sdg8': 26, 'sdg9': 13, 'sdg10': 31, 'sdg11': 3, 'sdg12': 30, 'sdg13': 41, 'sdg14': 14, 'sdg15': 79, 'sdg16': 34, 'sdg17': 7}
384


In [4]:
# Sanity check: the per-SDG counts must add up to the number of labelled
# points, which holds when each labelled point has exactly one non-empty SDG label.
assert points_with_labels_count.count == sum(sdg_counts.values())

In [5]:
# Count the points whose "xai" payload field is not empty. The call is kept as
# the last expression so the notebook displays the CountResult.
xai_present_filter = models.Filter(
    must_not=[
        models.IsEmptyCondition(is_empty=models.PayloadField(key="xai"))
    ]
)
client.count(
    collection_name="publications",
    count_filter=xai_present_filter,
    exact=True,
)

CountResult(count=384)