In [2]:
from typing import Union

from transformers import AutoModel, AutoTokenizer
import torch
from tqdm.auto import tqdm

import logging
import sys

log = logging.getLogger()
log.setLevel(logging.INFO)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
log.addHandler(handler)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)


cuda:0


In [4]:
class EmbedModel:
    def __init__(self, model_name:str, mode:str="token", device:str="cpu") -> None:
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device
        self.mode = mode
        self.model.to(self.device)
    
    def __call__(self, text:Union[str, list]):
        tokenized = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            emb = self.model(**tokenized)
        pooled = self._pool(emb, tokenized['attention_mask'])
        # return pooled
        return self._norm(pooled)
    
    def _norm(self, pooled):
        return pooled / torch.linalg.norm(pooled, dim=1, keepdim=True)

    def _pool(self, emb, mask=None):
        if self.mode=="token":
            masked = emb['last_hidden_state']*mask.unsqueeze(-1)
            return masked.sum(axis=1) / mask.sum(-1, keepdim=True)
        elif self.mode=="pooler":
            return emb['pooler_output']

In [5]:
emb_model = EmbedModel("menadsa/S-BioELECTRA", device=device)

# small smoke test
a = emb_model("text")
print(a.square().sum().sqrt())

print(emb_model("tumor in the chest") @ emb_model("lungs cancer").T)

tensor(1., device='cuda:0')
tensor([[0.6924]], device='cuda:0')


### Data loading

In [152]:
from datasets import load_dataset

sample_size = 100000
seed=1337

# dataset 1
data = load_dataset("medmcqa", split="train").to_pandas()
data = data.dropna()
ans_dict = {
    0: "opa",
    1: "opb",
    2: "opc",
    3: "opd"
}
data['Q'] = data['question']
data['A'] = data.apply(lambda row: row[ans_dict[row["cop"]]]+". "+row['exp'], axis=1)
data = data[['Q', "A"]]
print("Length:", len(data))
# data["text"] = data["question"] + "\n" + data["exp"]
# data = data[data["subject_name"] == "Medicine"]
data = data.sample(min(len(data), sample_size), random_state=seed)


# dataset 2
data1 = load_dataset("AnonymousSub/MedQuAD_47441_Question_Answer_Pairs", split="train").to_pandas()
data1 = data1.dropna()
data1["Q"] = data1["Questions"]
data1["A"] = data1["Answers"]
data1 = data1[["Q", "A"]]
print("Length:", len(data1))


data1 = data1.sample(min(len(data1), sample_size), random_state=seed)

Length: 87077
Length: 16407


### Retrieval

In [153]:
import numpy as np

class AnswerDB:
    def __init__(self, data: list, q: str="Q", a: str="A") -> None:
        self.data = np.array(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

    def encode(self, emb: EmbedModel, tqdm_flag: bool=False, batch_len: int=1):
        if batch_len < 1:
            raise ValueError("batch_len should be >=1")
        batches = range(0, len(self.data), batch_len)
        batched = (list(self.data[i:i+batch_len]) for i in batches)
        if tqdm_flag:
            proxy = tqdm(batched, total=len(self.data)/batch_len)
        else:
            proxy = batched
        self.embedded = torch.cat([emb(batch) for batch in proxy], dim=0)
        return self

In [154]:
db = AnswerDB(data1["A"].tolist()).encode(emb_model, tqdm_flag=True)

  0%|          | 0/16407.0 [00:00<?, ?it/s]

In [155]:
class Query:
    def __init__(self, text):
        self.text = text
    
    def embed(self, emb):
        self.embedded = emb(self.text)
        return self

In [156]:
class GoldenRetrieval:
    def __init__(self, data) -> None:
        self.data = data
        if not hasattr(self.data, "embedded"):
            log.warning("Data for retrieval task doesn't have embeddings")
        
    def find_it(self, q, top_k=None, sorted=True):
        if not top_k:
            top_k = self.data.embedded.shape[0]
        similarity = self.data.embedded @ q.embedded.T
        similarity = torch.topk(similarity.squeeze(-1).cpu(), k=top_k, sorted=sorted)
        return {
            "answers": self.data.data[similarity.indices],
            "similarities": similarity.values
        }


In [157]:
doge = GoldenRetrieval(db)

In [164]:
text = "Can I use paracetamol with pregnancy?"

query = Query(text).embed(emb_model)

res = doge.find_it(query, top_k=4)
res

{'answers': array(['Summary : You may need to take medicines every day, or only once in a while. Either way, you want to make sure that the medicines are safe and will help you get better. In the United States, the Food and Drug Administration is in charge of assuring the safety and effectiveness of both prescription and over-the-counter medicines.    Even safe drugs can cause unwanted side effects or interactions with food or other medicines you may be taking. They may not be safe during pregnancy. To reduce the risk of reactions and make sure that you get better, it is important for you to take your medicines correctly and be careful when giving medicines to children.',
        'When you are pregnant, you are not just "eating for two." You also breathe and drink for two, so it is important to carefully consider what you give to your baby. If you smoke, use alcohol or take illegal drugs, so does your unborn baby.    First, don\'t smoke. Smoking during pregnancy passes nicotine and can

In [163]:
text = "I broke my leg"

query = Query(text).embed(emb_model)

res = doge.find_it(query, top_k=4)
res

{'answers': array(["Mobility aids help you walk or move from place to place if you are disabled or have an injury. They include       - Crutches    - Canes    - Walkers    - Wheelchairs    - Motorized scooters       You may need a walker or cane if you are at risk of falling. If you need to keep your body weight off your foot, ankle or knee, you may need crutches. You may need a wheelchair or a scooter if an injury or disease has left you unable to walk.     Choosing these devices takes time and research. You should be fitted for crutches, canes and walkers. If they fit, these devices give you support, but if they don't fit, they can be uncomfortable and unsafe.",
        'Your legs are made up of bones, blood vessels, muscles, and other connective tissue. They are important for motion and standing. Playing sports, running, falling, or having an accident can damage your legs. Common leg injuries include sprains and strains, joint dislocations, and fractures.    These injuries can affec