In [2]:
from sentence_transformers import SentenceTransformer
from foe_foundry_nl.data.monsters2 import get_monsters
from foe_foundry_nl.embeddings.distance import cosine_similarity
from pathlib import Path
import torch
import json
from foe_foundry_nl.data.utils import name_to_key
from tqdm.notebook import tqdm
import numpy as np
import numpy.typing as npt
import pandas as pd

  from tqdm.autonotebook import tqdm, trange


In [3]:
## LOAD MODEL

# Restore the fine-tuned checkpoint stored two directories above the working
# directory and move it onto the GPU when CUDA is available.
print("Loading Fine-Tuned Model...")
fine_tuned_st_dir = Path.cwd().parent.parent / "models" / "minilm-finetuned-st"
model = SentenceTransformer(str(fine_tuned_st_dir))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

## ENCODE MONSTERS

# Seeded generator handed to iter_paragraphs() for every monster below.
rng = np.random.default_rng(20210518)

monsters = get_monsters()

# One record per monster, kept in iteration order so that every array built
# from these records shares the same row index.
monster_records = []

print("encoding monsters...")
for monster_key, monster in tqdm(monsters.items(), total=len(monsters)):
    # Collect the paragraphs first, then encode them one at a time.
    texts = [text for _, text in monster.iter_paragraphs(rng)]
    paragraph_vectors = [model.encode(text) for text in texts]

    monster_records.append(
        dict(
            key=monster_key,
            # A monster is represented by the mean of its paragraph embeddings.
            embedding=np.mean(paragraph_vectors, axis=0),
            srd=monster.srd,
            creature_type=monster.creature_type,
            cr=monster.cr_numeric,
            ac=monster.ac,
            hp=monster.hp,
        )
    )


# Column-wise views of the records; row i of each array describes the same monster.
keys = np.array([rec["key"] for rec in monster_records])
embeddings = np.array([rec["embedding"] for rec in monster_records], dtype=np.float32)
is_srd = np.array([rec["srd"] for rec in monster_records], dtype=bool)
creature_types = np.array([rec["creature_type"] for rec in monster_records])
crs = np.array([rec["cr"] for rec in monster_records])
acs = np.array([rec["ac"] for rec in monster_records])
hps = np.array([rec["hp"] for rec in monster_records])

# Maps a monster key to its row in the arrays above.
index_lookup = {monster_key: row for row, monster_key in enumerate(keys)}

Loading Fine-Tuned Model...
encoding monsters...


  0%|          | 0/2203 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [4]:
# Directory of JSON test-query files, located two levels above the working directory.
test_queries_dir = Path.cwd().parent.parent.joinpath("data", "5e_test_queries")


def get_similar_monsters(
    key: str,
    n: int | None = None,
    threshold: float | None = None,
    skip_self: bool = True,
) -> npt.NDArray[np.str_]:
    """Return the keys of the monsters most similar to the monster ``key``.

    The stored embedding of ``key`` is looked up and used as the query for
    ``get_similar``.

    Args:
        key: Monster key; must exist in ``index_lookup`` (``KeyError`` otherwise).
        n: Maximum number of keys to return, or ``None`` for no limit.
        threshold: Minimum cosine similarity a match must reach, or ``None``
            for no cut-off.
        skip_self: When true, ``key`` itself is left out of the result.

    Returns:
        Array of matching keys, ordered from most to least similar.
    """
    excluded_key = key if skip_self else None
    query_embedding = embeddings[index_lookup[key]]
    return get_similar(
        query_embedding, skip_key=excluded_key, n=n, threshold=threshold
    )


def get_similar(
    embedding,
    skip_key: str | None = None,
    n: int | None = None,
    threshold: float | None = None,
) -> npt.NDArray[np.str_]:
    """Rank all monsters by cosine similarity to ``embedding``.

    Args:
        embedding: 1-D query vector, compared against every row of the
            module-level ``embeddings`` array.
        skip_key: Key to leave out of the result, or ``None`` to keep all keys.
        n: Maximum number of keys to return, or ``None`` for no limit.
        threshold: Minimum similarity a match must reach, or ``None`` for no
            cut-off.

    Returns:
        Array of keys that pass the filters, ordered from most to least similar.
    """
    scores = cosine_similarity(embedding[np.newaxis, :], embeddings).flatten()
    # argsort is ascending; reversing it ranks the best match first.
    ranking = np.argsort(scores)[::-1]

    # Build the filter in storage order, then read it back in ranked order.
    keep = np.ones(scores.shape, dtype=bool)
    if skip_key is not None:
        keep = keep & (keys != skip_key)
    if threshold is not None:
        keep = keep & (scores >= threshold)

    selected = ranking[keep[ranking]]
    if n is not None:
        selected = selected[:n]

    return keys[selected]


rows = []
n = 6  # cut-off k used for precision@k, recall@k and hit-rate@k
similarity_threshold = 0.5  # minimum cosine similarity for a monster to count as a match

# Materialise the file list once so the progress-bar total and the loop below
# cover exactly the same files.
test_files = list(test_queries_dir.glob("*.json"))
n_test_query = 5  # queries assumed per file; only used to size the progress bar

print("Testing Against Test Queries...")
with tqdm(total=len(test_files) * n_test_query) as pbar:
    for test_query_path in test_files:
        # Read the file and release the handle before the slow encoding work.
        with test_query_path.open(encoding="utf-8") as f:
            test_data = json.load(f)

        expected_key = name_to_key(test_data["name"])

        index = index_lookup.get(expected_key)
        if index is None:
            # The monster named in the test file has no embedding; skip its queries.
            print(f"SKIPPING {expected_key}")
            pbar.update(n_test_query)
            continue

        expected_embedding = embeddings[index]

        # Relevant set: the expected monster plus its nearest neighbours in
        # embedding space (skip_self=False keeps the monster itself).
        expected_similar_keys = get_similar_monsters(
            expected_key, n=n, threshold=similarity_threshold, skip_self=False
        )
        queries = [q["query"] for q in test_data["queries"]]
        for query in queries:
            encoded_query = model.encode(query)
            similarity = cosine_similarity(
                encoded_query[np.newaxis, :], expected_embedding[np.newaxis, :]
            ).flatten()[0]
            similar_keys_all = get_similar(encoded_query, threshold=similarity_threshold)
            similar_keys = similar_keys_all[:n]
            relevant_keys = np.intersect1d(similar_keys, expected_similar_keys)
            precision_n = len(relevant_keys) / n
            recall_n = (
                len(relevant_keys) / len(expected_similar_keys)
                if len(expected_similar_keys) > 0
                else 0
            )

            # 1-based position of the expected monster within the top n.
            # Misses are stored as NaN rather than None so the "rank" column
            # is float64 for every mix of hits and misses; an int64 or object
            # column breaks reciprocal-rank arithmetic on it.
            hit_positions = np.where(similar_keys == expected_key)[0]
            hit = hit_positions.size > 0
            rank = float(hit_positions[0] + 1) if hit else np.nan

            rows.append(
                dict(
                    expected_key=expected_key,
                    query=query,
                    cosine_similarity=similarity,
                    relevant_keys=",".join(expected_similar_keys),
                    similar_keys=",".join(similar_keys),
                    retrieved_relevant_keys=",".join(relevant_keys),
                    precision_n=precision_n,
                    recall_n=recall_n,
                    hit_n=1 if hit else 0,
                    rank=rank,
                )
            )
            pbar.update(1)

df = pd.DataFrame(rows)


Testing Against Test Queries...


  0%|          | 0/11130 [00:00<?, ?it/s]

SKIPPING angel_shrouded
SKIPPING clacker_soldier
SKIPPING clacker_swarm
SKIPPING centaur
SKIPPING crab_duffel
SKIPPING demon_vetala
SKIPPING sedge_dragonette
SKIPPING dragon_prismatic_ancient
SKIPPING prismatic_young_dragon
SKIPPING drake_riptide
SKIPPING ettin_kobold


KeyboardInterrupt: 

In [39]:
# Show the best-recalled queries first. Rebinding (instead of inplace=True)
# keeps the cell idempotent and leaves other references to the frame untouched.
df = df.sort_values(by=["recall_n"], ascending=False)

precision_n = df["precision_n"].mean()
recall_n = df["recall_n"].mean()
hit_rate = df["hit_n"].mean()

# Mean reciprocal rank: a miss (missing or non-positive rank) contributes 0.
# Force a float array first. With an int64 column (every query hit) np.divide
# cannot write its float result into an integer `out` buffer, and with an
# object column of None (every query missed) the `rank > 0` comparison fails.
rank = np.nan_to_num(pd.to_numeric(df["rank"]).to_numpy(dtype=float), nan=0.0)
reciprocal_rank = np.divide(1.0, rank, out=np.zeros_like(rank), where=rank > 0)
mrr = np.mean(reciprocal_rank)

print(f"Precision@{n}: {precision_n:.4f}")
print(f"Recall@{n}: {recall_n:.4f}")
print(f"Hit Rate@{n}: {hit_rate:.4f}")
print(f"MRR: {mrr:.4f}")

df

Precision@6: 0.1995
Recall@6: 0.2016
Hit Rate@6: 0.3593
MRR: 0.2481


Unnamed: 0,expected_key,query,cosine_similarity,relevant_keys,similar_keys,retrieved_relevant_keys,precision_n,recall_n,hit_n,rank
4455,giant_wolf_spider,agile predator spider that climbs walls,0.792245,"giant_wolf_spider,spider,giant_spider,j’ba_fof...","spider,red_banded_line_spider,giant_spider,j’b...","crypt_spider,giant_spider,giant_wolf_spider,j’...",1.0,1.0,1,5.0
7732,pseudodragon_familiar,small dragon familiar that provides magic resi...,0.836979,"pseudodragon_familiar,quasit_familiar,faerie_d...","pseudodragon_familiar,quasit_familiar,pseudodr...","faerie_dragon_familiar,imp_familiar,inkling,ps...",1.0,1.0,1,1.0
3720,field_commander,charismatic leader humanoid with high Strength...,0.867339,"field_commander,hobgoblin_commander,hobgoblin_...","field_commander,hobgoblin_captain,emerald_orde...","dhampir_commander,emerald_order_cult_leader,fi...",1.0,1.0,1,1.0
5249,holy_knight,Holy Knight stats and abilities for D&D,0.926380,"holy_knight,knight,black_knight_commander,deat...","holy_knight,knight,merfolk_knight,death_knight...","black_knight_commander,death_knight,dread_knig...",1.0,1.0,1,1.0
7058,night_scorpion,Night Scorpion abilities including multiple cl...,0.859043,"night_scorpion,giant_scorpion,stygian_fat_tail...","night_scorpion,giant_scorpion,scorpionfolk,sty...","giant_scorpion,insect_scorpion,night_scorpion,...",1.0,1.0,1,1.0
...,...,...,...,...,...,...,...,...,...,...
4,aalpamac,monster with perception skill and amphibious t...,0.219619,"aalpamac,gullkin,a_mi_kuk,orca,imperator,azi_d...","haleshi,merfolk,shellycoat,vila,naiad,sila",,0.0,0.0,0,0.0
3,aalpamac,monstrosity that slows enemies and has high hi...,-0.010996,"aalpamac,gullkin,a_mi_kuk,orca,imperator,azi_d...","clockwork_huntsman,diplodocus,tripwire_patch,c...",,0.0,0.0,0,0.0
2,aalpamac,strong melee attacker with cold resistance tha...,0.371749,"aalpamac,gullkin,a_mi_kuk,orca,imperator,azi_d...","frostjack,thuellai,frost_afflicted,lindwurm,we...",,0.0,0.0,0,0.0
10971,zouyu,monster that excels in mobility and has vulner...,0.453323,"zouyu,nichny,zilaq,hulking_whelp,huli_jing,wyr...","thrummren,ostinato,yali,ala,mad_piper,azza_gre...",,0.0,0.0,0,0.0
