In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from my_processing import paths_to_dataset

# Load data: the training JSONL plus the two extra collections (TAR, sysrev).
# `test_only_sources` presumably keeps TAR/sysrev out of the train split — verify in my_processing.
data_path = "training"
paths = [f"data/{name}.jsonl" for name in (data_path, "TAR_data", "sysrev_conv")]
dataset = paths_to_dataset(
    paths,
    test_only_sources=["TAR", "sysrev"],
    # NOTE(review): the loaded checkpoint's dir name suggests 'pubmed-query' was also a
    # training source; the alternative was: ['pubmed-searchrefiner','pubmed-query','raw-jsonl']
    train_sources=["pubmed-searchrefiner", "raw-jsonl"],
)

data/training.jsonl
data/TAR_data.jsonl
data/sysrev_conv.jsonl


Finding similar: 100%|██████████| 2088/2088 [00:00<00:00, 3439.44it/s]
Finding similar: 100%|██████████| 343/343 [00:00<00:00, 13685.88it/s]
Finding similar: 100%|██████████| 79/79 [00:00<00:00, 26291.36it/s]
Finding similar: 100%|██████████| 50/50 [00:00<00:00, 12484.53it/s]
Finding similar: 100%|██████████| 40/40 [00:00<00:00, 19999.07it/s]
Finding similar: 100%|██████████| 3782/3782 [00:04<00:00, 781.83it/s]


In [2]:
N = 10000  # upper bound on the number of training pairs to visualize
SAMPLE_SEED = 0  # fixed seed so the sampled subset is identical on every run

train_split = dataset["train"]
df = pd.DataFrame({
    "nl": train_split["nl_query"],
    "bool": train_split["bool_query"],
    "quality": train_split["quality"],
    "source": train_split["source"],
})
# `min` keeps `sample` valid when the split has fewer than N rows (then it is just a shuffle).
df = df.sample(n=min(N, len(df)), random_state=SAMPLE_SEED).reset_index(drop=True)
df

Unnamed: 0,nl,bool,quality,source
0,Interventions for preventing falls in people w...,"(((""falls""[Title/Abstract]) OR ""recurrent fal...",0.549417,raw-jsonl
1,Associations of diet and physical activity dur...,(Pregnancy [Mesh] OR pregnan* [tiab] OR gestat...,0.417550,pubmed-searchrefiner
2,Prophylactic platelet transfusions prior to su...,((platelet* OR thrombocyte*) AND (prophyla* O...,0.549417,raw-jsonl
3,Media Multitasking Is Associated With Higher B...,"multitasking/ OR multitask*.ti,ab,id. OR task ...",0.113877,pubmed-searchrefiner
4,Low-value care.,('low value' OR 'low added value' OR harmful O...,0.626325,pubmed-searchrefiner
...,...,...,...,...
3777,Inositol in preterm infants at risk for or hav...,"((infant, newborn[MeSH] OR newborn*[TIAB] OR ...",0.137354,raw-jsonl
3778,A randomized controlled study for the treatmen...,"((""Acne Vulgaris""[Mesh] OR Acne[tiab] OR Vulga...",0.062633,pubmed-searchrefiner
3779,Oral retinoids for psoriasis,"((psoriasis (""psoriasis""[MeSH Terms] OR psori...",0.549417,raw-jsonl
3780,"An extension of a multicenter, randomized, spl...","((""Acne Vulgaris""[Mesh] OR Acne[tiab] OR White...",0.003631,pubmed-searchrefiner


In [3]:
import nltk
import numpy as np

nltk.download("words")
words = nltk.corpus.words.words()
# Seeded generator so the sampled vocabulary is the same on every run.
word_rng = np.random.default_rng(0)
# Draw without replacement so no word is embedded twice; cap at the corpus size.
words = word_rng.choice(words, size=min(10000, len(words)), replace=False).tolist()

[nltk_data] Downloading package words to
[nltk_data]     C:\Users\Simon\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


In [45]:
import torch
from boolrank import DualSiglip2Model

model = DualSiglip2Model('BAAI/bge-small-en-v1.5')
# Forward slashes are accepted on Windows as well as POSIX, so the path is portable.
CHECKPOINT_PATH = (
    "models/clip/bge-small-en-v1.5/"
    "b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4/checkpoint-11288/model.safetensors"
)
model.load(CHECKPOINT_PATH)
# Alternative backbone:
# model = DualSiglip2Model('dmis-lab/biobert-v1.1')
# model.load("models/clip/biobert-v1.1/b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4/checkpoint-14110/model.safetensors")

# Inference only: no autograd graph is needed, which saves GPU memory.
with torch.no_grad():
    embeddings = model.encode_bool(df["bool"].tolist(), batch_size=200).detach().cpu().numpy()
    # embeddings = model.encode_text(words, batch_size=200).detach().cpu().numpy()
torch.cuda.empty_cache()

In [46]:
import umap

# Project the Boolean-query embeddings to 3-D with UMAP; only the first two axes are plotted.
um = umap.UMAP(n_neighbors=15, n_components=3)
trans = um.fit_transform(embeddings)

# Each row of trans.T is one UMAP component (a column of `trans`).
x, y, z = trans.T
df["x"] = x
df["y"] = y

def cutoff(n):
    """Build a function that truncates a string to at most ``n`` characters.

    Strings longer than ``n`` are cut to their first ``n`` characters and get
    a trailing ``"..."``. Strings of length ``n`` or less are returned unchanged.
    """
    def _truncate(text):
        # `<=` so a string of exactly n characters is not marked as truncated.
        return text if len(text) <= n else text[:n] + "..."
    return _truncate
cut = 60  # max characters kept for hover labels
# NOTE: this overwrites df["nl"] / df["bool"] in place. `embeddings` was computed from the
# full df["bool"] text earlier; re-running the embedding cell after this one would embed
# the truncated strings instead.
for column in ("nl", "bool"):
    df[column] = df[column].map(cutoff(cut))

In [170]:
# Use one natural-language query from the pubmed-searchrefiner test split as the probe.
searchrefiner_test = dataset["test"]["pubmed-searchrefiner"]
query = searchrefiner_test["nl_query"][10]
# query = "cancer"
query_emb = model.encode_text(query).detach().cpu().numpy()
query

'Rapid eye movement sleep and slow wave sleep rebounded and related factors during positive airway pressure therapy.'

In [171]:
# Score every embedded Boolean query against the probe query and keep the scores on df.
similarity_tensor = model.get_similarities(embeddings, query_emb)
similarity = similarity_tensor.numpy()
df["sim"] = similarity
df.sim.values

array([ 0.10858485, -0.05632198, -0.20263498, ..., -0.27698457,
       -0.06070346,  0.9894577 ], shape=(3782,), dtype=float32)

In [172]:
TOP_K = 100              # number of most-similar points to highlight
HIGHLIGHT_OPACITY = 0.9  # added on top of BASE_OPACITY for the top-K points
BASE_OPACITY = 0.01      # keeps non-highlighted points faintly visible

# Indices of the TOP_K highest similarities (descending order).
top_n = np.argsort(-similarity)[:TOP_K]
mask = np.zeros_like(similarity)
mask[top_n] = HIGHLIGHT_OPACITY
mask += BASE_OPACITY

In [None]:
import plotly.graph_objects as go

# 2-D UMAP scatter of the sampled Boolean queries; per-point opacity comes from `mask`,
# so the queries most similar to the probe query stand out.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df["x"],
        y=df["y"],
        mode="markers",
        marker=dict(opacity=mask),
        # Show the NL title together with the (truncated) Boolean query on hover.
        hovertext=df["nl"] + "<br>" + df["bool"],
    )
)

fig.update_layout(
    width=1000,
    height=700,
    title="UMAP of Boolean-query embeddings (most similar to the probe query highlighted)",
    xaxis_title="UMAP 1",
    yaxis_title="UMAP 2",
)
fig.show()