In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd

# JSON-lines files that together make up the query dataset.
paths = [
    "data/training.jsonl",
    "data/TAR_data.jsonl",
    "data/sysrev_conv.jsonl",
]

# Read every file into its own frame; `dataset` is only ever bound to the
# final DataFrame, and no throwaway `df` is left behind in the kernel.
frames = [pd.read_json(path, lines=True) for path in paths]

# ignore_index=True gives the combined frame a fresh 0..n-1 index, so index
# labels are unique instead of restarting at 0 for each source file.
dataset = pd.concat(frames, ignore_index=True)
# dataset = dataset[dataset["nl_query"] != ""]
dataset

In [None]:
# Inspect rows whose natural-language query is missing (NaN).
dataset.loc[dataset["nl_query"].isna()]

In [None]:
# Inspect rows whose natural-language query is the empty string.
dataset.loc[dataset["nl_query"].eq("")]

In [None]:
# Upper bound on the number of rows to embed and plot.
N = 10000
# Fixed seed so the same rows are drawn on every Restart & Run All; the
# embeddings, similarities and top-100 selection below all depend on this sample.
SAMPLE_SEED = 42

# Sample at most N rows (the whole dataset if it is smaller) and renumber them 0..len-1.
df = dataset.sample(n=min(N, dataset.shape[0]), random_state=SAMPLE_SEED).reset_index(drop=True)
df

In [None]:
import nltk
import numpy as np

# Make the NLTK "words" corpus available, then keep a random draw of 10000
# entries from it as a plain Python list.
nltk.download("words")
words = np.random.choice(nltk.corpus.words.words(), size=10000).tolist()

In [None]:
import torch
from utils.boolrank import DualSiglip2Model

# Weights to load into the model.
# NOTE(review): the path uses backslash separators, so it presumably only
# resolves on Windows — confirm before running elsewhere.
checkpoint_path = r"models\clip\bge-small-en-v1.5\b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4\checkpoint-11288\model.safetensors"

model = DualSiglip2Model('BAAI/bge-small-en-v1.5')
model.load(checkpoint_path)
# Alternative backbone and checkpoint:
# model = DualSiglip2Model('dmis-lab/biobert-v1.1')
# model.load(r"models\clip\biobert-v1.1\b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4\checkpoint-14110\model.safetensors")

# Encode every sampled boolean query and bring the result back as a NumPy array on the CPU.
bool_queries = df["bool_query"].tolist()
embeddings = model.encode_bool(bool_queries, batch_size=200).detach().cpu().numpy()
# embeddings = model.encode_text(words, batch_size=200).detach().cpu().numpy()

# Ask torch to release cached CUDA memory now that the embeddings live on the CPU.
torch.cuda.empty_cache()

In [None]:
import umap

# Fit UMAP on the query embeddings and project them down to three components.
um = umap.UMAP(n_components=3, n_neighbors=15)
trans = um.fit_transform(embeddings)

# One coordinate array per projected component.
x, y, z = (trans[:, axis] for axis in range(3))

# NOTE(review): only the first two of the three components are stored on df;
# `z` is not used by any cell visible here — confirm whether a 2-component
# projection (or a 3D plot) was intended.
df["x"], df["y"] = x, y

def cutoff(n):
    """Build a truncation function for display labels.

    Args:
        n: Maximum number of characters to keep from the text.

    Returns:
        A one-argument callable. Given a string of at most ``n`` characters it
        returns the string unchanged; given a longer string it returns the
        first ``n`` characters followed by ``"..."``. Non-string values (for
        example NaN or None from a missing query) are returned unchanged
        instead of raising ``TypeError``.
    """
    def _truncate(text):
        # Missing values have no len(); pass them through untouched.
        if not isinstance(text, str):
            return text
        # `<=` so that a string of exactly n characters, from which nothing
        # would be cut, does not get a misleading "..." appended.
        return text if len(text) <= n else text[:n] + "..."

    return _truncate
# Maximum label length passed to cutoff() for the shortened columns.
cut = 60

# Shortened copies of both query columns, built with one shared truncator.
shorten = cutoff(cut)
df["nl"] = df["nl_query"].map(shorten)
df["bool"] = df["bool_query"].map(shorten)

In [None]:
# Use the natural-language query at position 10 of the full dataset
# (not of the sampled `df`) as the probe.
query = dataset["nl_query"].iloc[10]
# query = "cancer"

# Encode the probe text and bring it back as a NumPy array on the CPU.
query_emb = model.encode_text(query).detach().cpu().numpy()
query

In [None]:
# Score every boolean-query embedding against the probe embedding.
# NOTE(review): assumes get_similarities returns one score per row of df — confirm.
sim_scores = model.get_similarities(embeddings, query_emb)
similarity = sim_scores.numpy()

# Keep the scores next to the rows they belong to and show them.
df["sim"] = similarity
df["sim"].to_numpy()

In [None]:
# Positions of the 100 highest similarity scores (ascending sort of the negated scores).
top_n = np.argsort(-similarity)[:100]

# Per-point opacity: the top hits get 0.9, then every point is lifted by 0.01
# so the remaining points stay faintly visible.
mask = np.zeros_like(similarity)
# bool_mask = mask + 1
# bool_mask[top_n] = 0
mask[top_n] = 0.9
mask += 0.01

In [None]:
import plotly.graph_objects as go

# Marker-only scatter of the two stored projection coordinates; each point's
# opacity comes from `mask` and its hover label is the shortened boolean query.
scatter = go.Scatter(
    x=df["x"],
    y=df["y"],
    mode="markers",
    marker={"opacity": mask},
    fillpattern={"fillmode": "overlay"},
    hovertext=df["bool"],
)

fig = go.Figure(data=[scatter])
fig.update_layout(width=1000, height=700)
fig.show()