Use this notebook to build the pickled embedding files (the vector DB) used by the top-k searcher (`TopKSearcher`).

In [40]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from main import load_data
import pickle

In [41]:
class SentenceEmbedder:
    """Turn sentences into L2-normalised, mean-pooled transformer embeddings.

    Call `setModel` with a tokenizer/model pair before calling `embed`.
    Embeddings are always returned on the CPU (see `mean_pooling`).
    """

    def __init__(self) -> None:
        self.device = 'cpu'
        self.model = None
        self.tokenizer = None

    def setModel(self, device, tokenizer, model):
        """Attach `tokenizer` and `model`, moving the model to `device`.

        The model is also switched to eval mode, so dropout cannot make the
        embeddings non-deterministic.
        """
        self.device = device
        self.tokenizer = tokenizer
        self.model = model.to(device)
        self.model.eval()

    def embed(self, sentences):
        """Embed one sentence or a sequence of sentences.

        Args:
            sentences: a single string, or an iterable of strings.

        Returns:
            A CPU tensor with one row per sentence, each row scaled to unit L2
            norm. An empty input yields a tensor with zero rows.

        Raises:
            RuntimeError: if `setModel` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError('SentenceEmbedder.embed called before setModel()')
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)
        if not sentences:
            # Tokenizers/models cannot process an empty batch; return an empty
            # (0, hidden_size) tensor so callers can still zip/iterate over it.
            hidden_size = getattr(getattr(self.model, 'config', None), 'hidden_size', 0)
            return torch.empty((0, hidden_size))
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = SentenceEmbedder.mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average the token embeddings over the non-padding positions.

        Both inputs are moved to the CPU first, so the result is a CPU tensor.
        """
        # First element of model_output contains all token embeddings (batch, seq, hidden).
        token_embeddings = model_output[0].detach().cpu()
        input_mask_expanded = attention_mask.detach().cpu().unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp avoids division by zero for a row that is entirely padding.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# embed raw_text and annotation sentences

In [42]:
# Run on the first GPU when one is present, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# all-MiniLM-L6-v2 is a small, general-purpose sentence-embedding model:
# https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
model_name = "sentence-transformers/all-MiniLM-L6-v2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# setModel moves `model` onto `device` itself, so no separate .to(device) is needed.
Embedder = SentenceEmbedder()
Embedder.setModel(tokenizer=tokenizer, model=model, device=device)

In [149]:
# Load data via the project-local `main.load_data()`.
# `annotations` is indexed by split name (the 'validation' split is embedded
# below). `id_to_clinical_trial_record` maps each trial id to its report record;
# the record's sections ('intervention', 'eligibility', 'adverse_events',
# 'results') are presumably lists of sentences — they are embedded below.
annotations, id_to_clinical_trial_record = load_data()

In [93]:
# Embed every report section separately. Resulting layout:
#   embedded_r[trial_id][section_id] == [(sentence, embedding), ...]
#   e.g. embedded_r[trial_id]["intervention"] == [(text, embedding), (text, embedding)]
REPORT_SECTIONS = ('intervention', 'eligibility', 'adverse_events', 'results')

embedded_r = {}
# `trial_id` (not `id`) so the builtin id() is not shadowed in the notebook namespace.
for trial_id, report in tqdm(id_to_clinical_trial_record.items()):
    embedded_r[trial_id] = {}
    for section_id in REPORT_SECTIONS:
        section_sentences = report[section_id]
        section_embeddings = Embedder.embed(section_sentences)
        # Iterating a 2-D tensor yields row *views* sharing one storage. Pickling a
        # view serialises the whole underlying storage, so clone each row to keep
        # raw_text_db.pickle proportional to the number of sentences.
        embedded_r[trial_id][section_id] = [
            (sentence, embedding.clone())
            for sentence, embedding in zip(section_sentences, section_embeddings)
        ]

100%|██████████| 999/999 [04:01<00:00,  4.13it/s]


In [94]:
# Persist the report embeddings (relative path: the current working directory).
with open('raw_text_db.pickle', mode='wb') as out_file:
    pickle.dump(embedded_r, out_file, pickle.HIGHEST_PROTOCOL)

In [95]:
# Embed the validation annotations.
# Each record is copied and its 'statement' (str) becomes [statement_text, statement_embedding].
# The source records in `annotations` are left untouched, so re-running this cell is safe
# (the original aliasing mutated them and a second run crashed in the tokenizer).
embedded_a = []
for sample in tqdm(annotations['validation']):
    statement = sample['statement']
    record = dict(sample)
    record['statement'] = [statement, Embedder.embed([statement])[0]]
    embedded_a.append(record)

100%|██████████| 200/200 [00:01<00:00, 130.34it/s]


In [96]:
# save to disk
# Persist the embedded annotations (relative path: the current working directory).
with open('annotations_db_val.pickle', mode='wb') as out_file:
    pickle.dump(embedded_a, out_file, pickle.HIGHEST_PROTOCOL)

# When using a different embedding model, create a new folder vectordb/{folderName}
# and put BOTH pickles (raw_text_db.pickle and annotations_db_val.pickle) in it.
# Do not rename the files: TopKSearcher.load_vector_db looks them up by these exact names.

# Inference

Load the saved vector DB with `TopKSearcher` and retrieve the top-k evidence sentences for a statement.

In [16]:
from pathlib import Path


In [143]:
class TopKSearcher:
    """Retrieve the report sentences most similar to an annotation statement.

    The vector DB is the pair of pickles written by the embedding cells:
    ``raw_text_db`` maps trial id -> section id -> [(sentence, embedding), ...]
    and ``annotations_db`` is a list of annotation records whose ``'statement'``
    entry is ``[text, embedding]``.
    """

    def __init__(self, topk=20) -> None:
        self.topk = topk
        self.raw_text_db = None
        self.annotations_db = None

    def setTopK(self, k):
        """Set how many evidence sentences `search` retrieves in total."""
        self.topk = k

    def load_vector_db(self, db_path):
        """Load raw_text_db.pickle and annotations_db_val.pickle from the folder `db_path`.

        Raises:
            ValueError: if `db_path` does not exist.
        """
        db_path = Path(db_path)

        if not db_path.exists():
            raise ValueError(f'invalid path: {str(db_path)}')

        # NOTE: pickle.load can execute arbitrary code; only load vector DBs you built yourself.
        with open(db_path / 'raw_text_db.pickle', 'rb') as f:
            self.raw_text_db = pickle.load(f)

        with open(db_path / 'annotations_db_val.pickle', 'rb') as f:
            self.annotations_db = pickle.load(f)

    def search(self, query_text):
        """Build a premise string from the sentences most similar to `query_text`.

        `query_text` must exactly match the text of a statement in
        `annotations_db`; otherwise 'invalid hypothesis.' is printed and ""
        is returned. For a 'single' statement, up to `topk` sentences are taken
        from the primary trial. For any other type, up to `topk // 2` sentences
        are taken from each of the primary and secondary trials. Trials with
        fewer sentences contribute all of them.

        Raises:
            RuntimeError: if the vector DB has not been loaded.
        """
        if self.raw_text_db is None or self.annotations_db is None:
            raise RuntimeError('vector DB not loaded; call load_vector_db() first')

        # First annotation whose statement text matches the query.
        sample = next((x for x in self.annotations_db if x['statement'][0] == query_text), None)
        if sample is None:
            print('invalid hypothesis.')
            return ""

        sample_type = sample['type'].lower()
        # (text, embedding)
        query = sample['statement']

        if sample_type == 'single':
            primary_text = self._evidence_text(sample['primary_id'], query, self.topk)
            premise = f"Primary trial evidence are {primary_text}."
        else:
            per_trial = self.topk // 2
            primary_text = self._evidence_text(sample['primary_id'], query, per_trial)
            secondary_text = self._evidence_text(sample['secondary_id'], query, per_trial)
            premise = (
                f"Primary trial evidence are {primary_text}\n and Secondary "
                + f"trial evidence are {secondary_text}."
            )

        return premise

    def _evidence_text(self, trial_id, query, k):
        """Newline-joined top-`k` sentences of trial `trial_id`, searched across all sections."""
        # Flatten every section: the search ignores which section a sentence came from.
        sentences = [x for sec in self.raw_text_db[trial_id].values() for x in sec]
        return "\n".join(TopKSearcher.search_topk_sentences(sentences, query, k))

    @staticmethod
    def find_topk_tensors(query_tensor, tensor_list, topk):
        """Indices of the `topk` tensors in `tensor_list` most cosine-similar to `query_tensor`.

        Indices are ordered by decreasing similarity. `topk` is clamped to
        ``len(tensor_list)`` so torch.topk does not raise on short reports, and an
        empty `tensor_list` yields an empty index tensor.
        """
        k = max(0, min(topk, len(tensor_list)))
        if k == 0:
            return torch.empty(0, dtype=torch.long)
        tensor_stack = torch.stack(tensor_list)
        similarity_scores = torch.nn.functional.cosine_similarity(query_tensor.unsqueeze(0), tensor_stack, dim=1)
        return torch.topk(similarity_scores, k=k).indices

    @staticmethod
    def search_topk_sentences(fulldoc, hypothesis, topk):
        """Texts of the `topk` entries of `fulldoc` most similar to `hypothesis`.

        `fulldoc` is a list of (text, embedding) pairs and `hypothesis` is a
        (text, embedding) pair; only the embeddings are compared.
        """
        embeddings = [embedding for _, embedding in fulldoc]
        indices = TopKSearcher.find_topk_tensors(hypothesis[1], embeddings, topk)
        return [fulldoc[i][0] for i in indices.tolist()]

In [99]:
def load_vector_data(db_path):
    """Return ``(raw_text_db, annotations_db)`` read from the vector-DB folder `db_path`.

    Thin wrapper over `TopKSearcher.load_vector_db`, for inspecting the raw
    vectors directly. Raises ValueError if `db_path` does not exist.
    """
    searcher = TopKSearcher()
    searcher.load_vector_db(db_path)
    return searcher.raw_text_db, searcher.annotations_db


raw_text_db, annotations_db = load_vector_data('./vectordb/allMiniLML6V2/')

In [102]:
# Example query: the first validation statement, stored as [text, embedding].
query = annotations_db[0]['statement']

# Whole-report search space for one trial: every (sentence, embedding) pair from
# all of its sections, regardless of section id.
db = [pair for section in raw_text_db['NCT00003199'].values() for pair in section]

In [144]:
# Searcher returning 15 evidence sentences, backed by the saved MiniLM vector DB.
topksearcher = TopKSearcher(topk=15)
topksearcher.load_vector_db(db_path='./vectordb/allMiniLML6V2/')

In [147]:
topksearcher.setTopK(k=20)

In [148]:
len(topksearcher.search(query_text=query[0]))

1998