In [None]:
!pip install quaterion pytorch_lightning sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


#### The embeddings are fine-tuned with [Quaterion](https://github.com/qdrant/quaterion); detailed usage instructions are in its [quick-start guide](https://quaterion.qdrant.tech/getting_started/quick_start.html). Quaterion is built on top of PyTorch Lightning, which is designed for fast training and efficient run-time memory management.

In [None]:
import argparse
import json
import os
import random
from typing import Any, Dict, List, Union

import pytorch_lightning as pl
import torch

from quaterion import Quaterion, TrainableModel
from quaterion.loss import TripletLoss
from quaterion.dataset.similarity_data_loader import (
    GroupSimilarityDataLoader,
    SimilarityGroupSample,
)

from quaterion_models.encoders import Encoder
from quaterion_models.heads import EncoderHead, GatedHead, SkipConnectionHead
from quaterion_models.types import CollateFnType
from torch.utils.data import Dataset


import torch
import torch.nn as nn

from quaterion.loss import SimilarityLoss, SoftmaxLoss

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling

import pandas as pd

#### THE IDEA HERE IS TO TRAIN A SIMILARITY HEAD ON TOP OF THE FROZEN PRE-TRAINED ENCODER, SO THAT THE RESULTING EMBEDDING SPACE BECOMES VERY SPECIFIC TO THIS DATASET (DRUGS FOR THE SAME MEDICAL CONDITION SHOULD END UP CLOSE TOGETHER).

In [None]:
# IMPORT TRAIN AND VALIDATION SETS
# NOTE(review): no later cell shown here reads train_data / val_data;
# JsonDataset re-reads the same files when the dataloaders are built.
data_dir = "/content/drive/MyDrive/Lights-on-heights/data"

with open(os.path.join(data_dir, "labeled_data_train.json"), "r") as f:
    train_data = json.load(f)

with open(os.path.join(data_dir, "labeled_data_val.json"), "r") as f:
    val_data = json.load(f)

In [None]:
class JsonDataset(Dataset):
    """Labelled drug records loaded from a JSON file, exposed as Quaterion group samples.

    Each record is returned as a ``SimilarityGroupSample`` whose ``group`` is the
    integer id of the record's ``"label"`` (medical condition), so records that
    share a condition share a group.

    Args:
        path: path to a JSON file holding a list of records, each with a ``"label"`` key.
    """

    def __init__(self, path: str):
        super().__init__()
        # Condition name -> 1-based group id, in this fixed order.
        conditions = [
            "acne", "adhd", "aids", "allergies", "alzheimer", "angina", "anxiety",
            "asthma", "bipolar", "bronchitis", "cancer", "cholesterol", "cold",
        ]
        self.translation_dict = {name: group_id for group_id, name in enumerate(conditions, start=1)}
        with open(path, "r") as f:
            self.data = json.load(f)

    def __getitem__(self, index: int) -> SimilarityGroupSample:
        record = self.data[index]
        group_id = self.translation_dict[record["label"]]
        return SimilarityGroupSample(obj=record, group=group_id)

    def __len__(self) -> int:
        return len(self.data)

In [None]:
class SemanticSearchEncoder(Encoder):
    """Quaterion encoder wrapping a sentence-transformers Transformer + Pooling pair.

    Input records are dicts with a ``"drug_information"`` text field; the encoder
    returns one pooled sentence embedding per record.
    """

    def __init__(self, transformer: Transformer, pooling: Pooling):
        super().__init__()
        self.transformer = transformer
        self.pooling = pooling
        # Chaining the two modules yields a features dict whose
        # "sentence_embedding" entry is read in forward().
        self.encoder = nn.Sequential(self.transformer, self.pooling)

    @property
    def trainable(self) -> bool:
        # False => Quaterion treats this encoder as frozen: the pre-trained
        # transformer is NOT fine-tuned, only the head placed on top of it is.
        return False

    @property
    def embedding_size(self) -> int:
        # The encoder outputs the pooled vector, so report the pooling output size
        # (identical to the word-embedding size for a single mean-pooling mode).
        return self.pooling.get_sentence_embedding_dimension()

    def forward(self, batch) -> torch.Tensor:
        return self.encoder(batch)["sentence_embedding"]

    def collate_drug_information(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Tokenize the ``"drug_information"`` text of every record in ``batch``."""
        drug_information = [drug["drug_information"] for drug in batch]
        return self.transformer.tokenize(drug_information)

    def get_collate_fn(self) -> CollateFnType:
        return self.collate_drug_information

    @staticmethod
    def _transformer_path(path: str) -> str:
        return os.path.join(path, "transformer")

    @staticmethod
    def _pooling_path(path: str) -> str:
        return os.path.join(path, "pooling")

    def save(self, output_path: str):
        """Save the transformer and pooling modules into sub-directories of ``output_path``."""
        transformer_path = self._transformer_path(output_path)
        os.makedirs(transformer_path, exist_ok=True)

        pooling_path = self._pooling_path(output_path)
        os.makedirs(pooling_path, exist_ok=True)

        self.transformer.save(transformer_path)
        self.pooling.save(pooling_path)

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """Rebuild the encoder from a directory written by ``save``."""
        # Use the same path helpers as save() so the layouts cannot diverge.
        transformer = Transformer.load(cls._transformer_path(input_path))
        pooling = Pooling.load(cls._pooling_path(input_path))
        return cls(transformer=transformer, pooling=pooling)

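#### THE TRIPLET LOSS IS USED TO OPTIMIZE FOR SIMILARITY: IT PULLS EMBEDDINGS OF DRUGS FROM THE SAME CONDITION (GROUP) TOGETHER AND PUSHES EMBEDDINGS OF DIFFERENT CONDITIONS APART.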

In [None]:
class Model(TrainableModel):
    """Quaterion trainable model: pre-trained sentence-transformer encoder + trainable head.

    Args:
        lr: learning rate for the Adam optimizer.
        pretrained_name: sentence-transformers model from which the transformer
            and pooling modules are taken. Defaults to "multi-qa-MiniLM-L6-cos-v1".
    """

    def __init__(self, lr: float, pretrained_name: str = "multi-qa-MiniLM-L6-cos-v1"):
        # Set before super().__init__(): the encoders are built while the parent
        # initialises (the model download happens when Model(...) is called),
        # and configure_encoders reads self._pretrained_name.
        self._lr = lr
        self._pretrained_name = pretrained_name
        super().__init__()

    def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
        pre_trained = SentenceTransformer(self._pretrained_name)
        # Only the first two modules (transformer, pooling) are reused; any later
        # modules of the pipeline (e.g. a normalisation step, if present) are dropped.
        transformer, pooling = pre_trained[0], pre_trained[1]
        return SemanticSearchEncoder(transformer, pooling)

    def configure_head(self, input_embedding_size) -> EncoderHead:
        return SkipConnectionHead(input_embedding_size)

    def configure_loss(self) -> SimilarityLoss:
        return TripletLoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self._lr)

In [None]:
model = Model(lr=0.001)

# Training and validation loaders share the same data directory and batch size.
data_root = "/content/drive/MyDrive/Lights-on-heights/data"
train_dataset = JsonDataset(f"{data_root}/labeled_data_train.json")
val_dataset = JsonDataset(f"{data_root}/labeled_data_val.json")

train_dataloader = GroupSimilarityDataLoader(train_dataset, batch_size=128)
val_dataloader = GroupSimilarityDataLoader(val_dataset, batch_size=128)

Downloading (…)5fedf/.gitattributes:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)2cb455fedf/README.md:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

Downloading (…)b455fedf/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)edf/data_config.json:   0%|          | 0.00/25.5k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)5fedf/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

Downloading (…)fedf/train_script.py:   0%|          | 0.00/13.8k [00:00<?, ?B/s]

Downloading (…)2cb455fedf/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)455fedf/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [None]:
# TRAIN THE SIMILARITY HEAD (the encoder reports trainable=False, so it stays frozen).
# trainer=None lets Quaterion use its default trainer.
Quaterion.fit(
    trainable_model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    trainer=None,
)

In [None]:
# SAVE FINETUNED MODEL (reloaded later with SimilarityModel.load)
finetuned_model_dir = "/content/drive/MyDrive/Lights-on-heights/artefacts/finetuned_model"
model.save_servable(finetuned_model_dir)



### The model reached a validation loss of 0.0483 with a batch size of 128.

### Evaluate Model Performance

In [None]:
import re
import string

import nltk
import pandas as pd
from sentence_transformers import util

# Fetch the NLTK stop-word corpus (downloaded on first run) and load the English list.
nltk.download('stopwords')
stopwords_lst = nltk.corpus.stopwords.words('english')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [None]:
def clean_input_qst(qst, stopwords=None):
  """Normalise a user question into the keyword string used as model input.

  The text is lower-cased and ASCII punctuation is replaced by spaces *before*
  stop words are removed, so capitalised or punctuated stop words ("What",
  "you?") are dropped as well. Remaining words are joined with single spaces.

  Args:
    qst: raw question text.
    stopwords: optional iterable of lower-case stop words; defaults to the
      NLTK English list `stopwords_lst` loaded earlier in the notebook.

  Returns:
    The cleaned keyword string (possibly empty).
  """
  if stopwords is None:
    stopwords = stopwords_lst
  stop_set = set(stopwords)  # O(1) membership checks
  text = qst.lower()
  text = re.sub('[%s]' % re.escape(string.punctuation), ' ', text)
  # split()/join also collapses the extra spaces left by punctuation removal.
  return ' '.join(word for word in text.split() if word not in stop_set)

In [None]:
def query_shuffler(drug_name, rng=None):
  """Wrap `drug_name` in a randomly chosen question template.

  Used to simulate how a real user might phrase a query.

  Args:
    drug_name: drug or condition name to put in the query.
    rng: optional `random.Random` instance for reproducible choices;
      defaults to the module-level `random` generator (original behaviour).

  Returns:
    One of four query strings built from `drug_name`.
  """
  chooser = random if rng is None else rng
  templates = (
      f"what do you have on {drug_name}?",
      f"i need information on {drug_name}.",
      f"{drug_name}?",
      f"{drug_name}",
  )
  return chooser.choice(templates)

In [None]:
from quaterion_models.model import SimilarityModel

# Reload the servable written by model.save_servable.
finetuned_model_dir = "/content/drive/MyDrive/Lights-on-heights/artefacts/finetuned_model"
finetuned_model = SimilarityModel.load(finetuned_model_dir)

In [None]:
# Search corpus: one row per drug with its description, drug name and condition label.
feat_store_path = "/content/drive/MyDrive/Lights-on-heights/data/feature_store.json"
feat_store = pd.read_json(feat_store_path)

In [None]:
# Wrap each description in the record format the encoder's collate function expects.
feat_stores = [dict(drug_information=text) for text in feat_store["drug_information"].to_list()]

In [None]:
# Peek at the first few encoder input records (one dict per feature-store row).
feat_stores[:6]

[{'drug_information': 'doxycycline miscellaneous antimalarials tetracyclines acticlate adoxa ck adoxa pak adoxa tt alodox avidoxy doryx mondoxyne nl monodox morgidox okebo oracea oraxyl periostat targadox vibramycin calcium vibramycin hyclate vibramycin monohydrate vibra tabs acne amoxicillin prednisone albuterol ciprofloxacin azithromycin cephalexin metronidazole metronidazole topical clindamycin topical clindamycin'},
 {'drug_information': 'spironolactone aldosterone receptor antagonists potassium sparing diuretics aldactone carospir acne amlodipine lisinopril losartan metoprolol furosemide hydrochlorothiazide carvedilol warfarin lasix bumetanide'},
 {'drug_information': 'minocycline tetracyclines dynacin minocin minolira solodyn ximino vectrin myrac acne amoxicillin prednisone doxycycline ciprofloxacin azithromycin cephalexin metronidazole clindamycin topical augmentin dexamethasone'},
 {'drug_information': 'accutane isotretinoin oral miscellaneous antineoplastics miscellaneous unca

In [None]:
# Embed every feature-store record with the fine-tuned model; this is the search corpus.
embeddings = finetuned_model.encode(feat_stores)

In [None]:
# One row per feature-store record, one column per embedding dimension.
embeddings.shape

(999, 384)

#### VISUALIZE THE EMBEDDINGS

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import plotly.express as px

In [None]:
# Per-point condition labels used to colour the t-SNE plots.
visualize_labels = np.asarray(feat_store["label"].to_list())

In [None]:
# CREATE A TWO-DIMENSIONAL T-SNE PROJECTION OF THE EMBEDDING
# t-SNE is stochastic; fixing random_state makes the plot reproducible on re-run.
tsne = TSNE(n_components=2, random_state=42, verbose=1)
tsne_proj = tsne.fit_transform(embeddings)


fig = px.scatter(x=tsne_proj[:, 0], y=tsne_proj[:, 1], color=visualize_labels)
fig.update_layout(
    title="2D t-SNE visualization of embeddings",
    xaxis_title="First t-SNE",
    yaxis_title="Second t-SNE",
)
fig.show()

[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 999 samples in 0.001s...
[t-SNE] Computed neighbors for 999 samples in 0.120s...
[t-SNE] Computed conditional probabilities for sample 999 / 999
[t-SNE] Mean sigma: 1.272233
[t-SNE] KL divergence after 250 iterations with early exaggeration: 51.069466
[t-SNE] KL divergence after 1000 iterations: 0.429648


In [None]:
# CREATE A THREE-DIMENSIONAL T-SNE PROJECTION OF THE EMBEDDING
# t-SNE is stochastic; fixing random_state makes the plot reproducible on re-run.
tsne = TSNE(n_components=3, random_state=42, verbose=1)
tsne_proj = tsne.fit_transform(embeddings)

fig = px.scatter_3d(
    x=tsne_proj[:, 0],
    y=tsne_proj[:, 1],
    z=tsne_proj[:, 2],
    color=visualize_labels,
    opacity=0.8,
)
fig.update_layout(
    title="3D t-SNE visualization of embeddings",
)
fig.show()

[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 999 samples in 0.001s...
[t-SNE] Computed neighbors for 999 samples in 0.051s...
[t-SNE] Computed conditional probabilities for sample 999 / 999
[t-SNE] Mean sigma: 1.272233
[t-SNE] KL divergence after 250 iterations with early exaggeration: 51.088341
[t-SNE] KL divergence after 1000 iterations: 0.345077


###### This projection shows how well the fine-tuned embeddings separate the medical conditions. Clean clusters could mean either that the model learned a good representation or that it over-fitted, so the result should be checked against held-out queries.

In [None]:
raw_qst = "what do you have on cancer?"
qst = clean_input_qst(raw_qst)
qst

'cancer'

#### PREPARING THE QUESTION TO SUIT THE INPUT FORMAT OF THE MODEL (A RECORD WITH A `drug_information` KEY)

In [None]:
# Same record shape as the corpus entries, so the encoder's collate function accepts it.
qst_dict = dict(drug_information=qst)

In [None]:
# Embed the cleaned question with the same fine-tuned model used for the corpus.
qst_embedding = finetuned_model.encode(qst_dict)

In [None]:
# A single query record -> a single embedding row.
qst_embedding.shape

(1, 384)

In [None]:
# Results for the first (only) query, truncated to the 10 best hits.
resp = util.semantic_search(qst_embedding, embeddings)[0][:10]

In [None]:
# Each hit: corpus_id = row position in `embeddings`, score = similarity to the query.
resp

[{'corpus_id': 733, 'score': 0.8086163997650146},
 {'corpus_id': 723, 'score': 0.8009374737739563},
 {'corpus_id': 729, 'score': 0.7964625954627991},
 {'corpus_id': 732, 'score': 0.7813810110092163},
 {'corpus_id': 739, 'score': 0.7315276265144348},
 {'corpus_id': 734, 'score': 0.7312262654304504},
 {'corpus_id': 736, 'score': 0.7265673279762268},
 {'corpus_id': 731, 'score': 0.7207668423652649},
 {'corpus_id': 727, 'score': 0.7156854867935181},
 {'corpus_id': 724, 'score': 0.7119912505149841}]

In [None]:
# Print the condition label and drug name of each hit.
# corpus_id is a row *position* in `embeddings`, which were built from
# feat_store rows in order, so look rows up positionally via the public .iloc
# instead of the private, label-based DataFrame._get_value.
for hit in resp:
  best_row = feat_store.iloc[hit['corpus_id']]
  print(best_row['label'])
  print(best_row['drug_name'])
  print("")

cancer
Toposar

cancer
Leukeran

cancer
vincristine

cancer
Ifex

cancer
Tepadina

cancer
vinblastine

cancer
ifosfamide

cancer
Etopophos

cancer
etoposide

cancer
chlorambucil



##### MINI-TEST SET

In [None]:
# Hand-curated mini test set: (query drug name or condition, expected condition label).
# Kept as pairs so a query and its expected label cannot drift out of alignment.
mini_test_cases = [
    ("gemfibrozil", "cholesterol"),
    ("Xyzal", "allergies"),
    ("amlodipine", "angina"),
    ("Vicks VapoRub", "cold"),
    ("triamcinolone", "asthma"),
    ("Equaline Sleep Aid", "allergies"),  # was misspelled "allrgies"
    ("Lescol XL", "cholesterol"),
    ("Namenda", "alzheimer"),             # was "acne" (labels had shifted by one)
    ("Monodox", "acne"),                  # was "Namenda"; monodox is listed under acne in the feature store
    ("Depakote Sprinkles", "bipolar"),
    ("Bayer Aspirin", "angina"),
    ("Truvada", "aids"),
    ("bronchitis", "bronchitis"),
    ("Acticlate", "acne"),
    ("cancer", "cancer"),
    ("Dxevo", "asthma"),
    ("Namzaric", "alzheimer"),
    ("Depakote ER", "bipolar"),
    ("Benadryl", "cold"),
    ("Tivicay", "aids"),
]
drug_names_condtions = [name for name, _ in mini_test_cases]
labels = [label for _, label in mini_test_cases]

In [None]:
def pre_proceess_text(df):
  """Normalise a drug name / free-text string before building a query.

  Removes tokens containing digits, ASCII punctuation plus em dashes, and the
  curly possessive "’s"; then collapses whitespace, lower-cases and strips.

  Args:
    df: text to clean (converted with str(), so non-string input is accepted).

  Returns:
    The cleaned, lower-case string.
  """
  # The original assigned to a misspelled attribute (string.punctation), so the
  # em dashes were never actually removed; use a local character set instead.
  punctuation = string.punctuation + "——"
  text = str(df)
  text = re.sub(r'\w*\d\w*', '', text)  # remove tokens containing digits
  text = re.sub('[%s]' % re.escape(punctuation), ' ', text)  # punctuation -> space
  text = re.sub('’s', '', text)
  # Collapse whitespace last, so gaps created by the substitutions above are merged too.
  text = re.sub(r'\s+', ' ', text)
  return text.lower().strip()

In [None]:
# Collects the predicted condition label (top-1 hit) for each mini-test query.
save_test_result = list()

In [None]:
# Run each mini-test entry through the same cleaning/encoding path as a real user
# query and record the condition label of the single best-matching corpus entry.
# Reset the results here so re-running this cell does not append duplicates
# (which would break the length check further down), and seed `random` so the
# question templates chosen by query_shuffler are reproducible.
save_test_result = []
random.seed(0)

for drug_name in drug_names_condtions:
  drug_name = pre_proceess_text(drug_name)
  question = query_shuffler(drug_name)  # simulate a real user query
  print(question)
  qst = clean_input_qst(question)
  qst_dict = {"drug_information": qst}
  qst_embedding = finetuned_model.encode(qst_dict)
  resp = util.semantic_search(qst_embedding, embeddings)[0][:1]  # top-1 hit only
  for hit in resp:
    # corpus_id is a row position: embeddings were built from feat_store rows in order.
    best_row = feat_store.iloc[hit['corpus_id']]
    save_test_result.append(best_row['label'])
    print(best_row['drug_name'])
    print("")

gemfibrozil?
Equetro

what do you have on xyzal?
Qelbree

amlodipine
enoxaparin

vicks vaporub?
Vicks NyQuil Cold & Flu Nighttime Relief (Alcohol Free)

what do you have on triamcinolone?
vincristine

i need information on equaline sleep aid.
Equaline Sleep Aid

i need information on lescol xl.
Lescol XL

i need information on namenda.
aducanumab

monodox
Concerta

i need information on depakote sprinkles.
acetaminophen / dextromethorphan

i need information on bayer aspirin.
Bayer Aspirin Extra Strength Plus

truvada
Truvada

what do you have on bronchitis?
cefaclor

acticlate
Inderal LA

cancer?
Toposar

i need information on dxevo.
Emtriva

namzaric?
Angiomax

what do you have on depakote er?
Depakote Sprinkles

i need information on benadryl.
Benadryl

i need information on tivicay.
Tivicay



In [None]:
# Expected condition label for each mini-test query.
labels

['cholesterol',
 'allergies',
 'angina',
 'cold',
 'asthma',
 'allrgies',
 'cholesterol',
 'acne',
 'Namenda',
 'bipolar',
 'angina',
 'aids',
 'bronchitis',
 'acne',
 'cancer',
 'asthma',
 'alzheimer',
 'bipolar',
 'cold',
 'aids']

In [None]:
# Predicted condition label (label of the top-1 hit) for each mini-test query.
save_test_result

['bipolar',
 'adhd',
 'angina',
 'cold',
 'cancer',
 'allergies',
 'cholesterol',
 'adhd',
 'adhd',
 'cold',
 'angina',
 'aids',
 'bronchitis',
 'aids',
 'cancer',
 'angina',
 'angina',
 'bipolar',
 'cold',
 'aids']

In [None]:
# Sanity check: exactly one prediction per expected label.
assert len(save_test_result) == len(labels)

In [None]:
# Number of queries whose predicted label equals the expected label at the same index.
count = sum(1 for i, predicted in enumerate(save_test_result) if predicted == labels[i])
print(count)

10


In [None]:
# Fraction of mini-test queries whose top-1 hit carries the expected condition label.
accuracy = count/len(save_test_result)

In [None]:
# Round to one decimal: raw float arithmetic can print artefacts
# (e.g. 11/20 * 100 evaluates to 55.00000000000001).
print(f"Accuracy on curated dataset: {accuracy * 100:.1f} %")

Accuracy on curated dataset: 50.0 %


In [None]:
# Show the hit(s) kept for the last query of the mini-test loop.
# corpus_id is a row position, so use the public positional .iloc rather than
# the private, label-based DataFrame._get_value.
for hit in resp:
  best_row = feat_store.iloc[hit['corpus_id']]
  print(f"DRUG NAME: {best_row['drug_name']}")
  print(f"MEDICAL CONDITION FOR DRUG: {best_row['label']}")
  print("")

DRUG NAME: Tivicay
MEDICAL CONDITION FOR DRUG: aids



### Although the embedding space looks well defined, it does not generalize well: despite a very low validation loss, the approach reached only average accuracy on the curated mini test set. The main reason is most likely the lack of sufficient high-quality training data.