# SQuAD-style QA over MIMIC-IV clinical notes with LLM RAG

## Imports & Initialization

In [1]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

# NOTE(review): several of these (e.g. pdb, pickle) appear unused in the cells visible here.
import sys, os, time, warnings, pdb, pickle, random, math, re, json
# Silence all warnings for cleaner cell output.
# NOTE(review): this also hides library deprecation notices; consider narrowing the filter
# to specific categories/modules.
warnings.filterwarnings('ignore')
# Make helper modules in ../scripts importable; the path is relative to the kernel's
# working directory, so the notebook must be launched from its own folder.
sys.path.insert(0, '../scripts')

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset

# Display / plotting defaults.
np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

In [2]:
import torch
import chromadb

from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding 

from llama_index.core import PromptTemplate
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext

from auto_gptq import exllama_set_max_input_length

In [3]:
# Answering guidelines injected into every LLM call.
SYSTEM_PROMPT = """You are an AI assistant designed to answer queries about information contained in a set of documents. Here are guidelines you must adhere to:
- Provide the shortest answer possible while maintaining accuracy
- Do not include pleasantries or filler phrases
- If you are not sure of anything, indicate that in your response.
"""
# Mistral-7B-Instruct-v0.1 uses the plain "[INST] {instruction} [/INST]" format: it has no
# system-prompt slot and was not trained on Llama-2's <<SYS>> ... <</SYS>> tags. The
# guidelines are therefore prepended to the user turn inside the [INST] block.
# The tokenizer adds the BOS token itself, so "<s>" is not written here.
query_wrapper_prompt = PromptTemplate("[INST] " + SYSTEM_PROMPT + "\n{query_str} [/INST]")

In [4]:
# Scratch root for this project. Override it with the LLMQA_PROJECT_DIR environment
# variable instead of editing the hard-coded default.
project_dir = Path(os.environ.get('LLMQA_PROJECT_DIR', '/mnt/scratch'))
# Processed-data directory (not created or read in the cells visible here).
proc_data_dir = project_dir/'llmqa/data/'
# On-disk location of the persistent Chroma database; ensure it exists before the client opens it.
db_dir = project_dir/'llmqa/db_dir'
db_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Open (or create) the persistent Chroma store rooted at db_dir.
chroma_db_path = db_dir.as_posix()
db = chromadb.PersistentClient(path=chroma_db_path)

In [6]:
# Chroma collection backing the RAG index (presumably the MIMIC-IV note embeddings);
# get_or_create returns the existing collection or makes a new, empty one.
CHROMA_COLLECTION_NAME = 'mimiciv'
chroma_collection = db.get_or_create_collection(CHROMA_COLLECTION_NAME)
# Wrap it so llama_index can use the collection as a vector store.
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

In [7]:
# Hugging Face checkpoint for the generator LLM (full-precision weights).
# The GPTQ-quantized alternative below presumably relies on the auto_gptq import above.
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
# model_name_or_path = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"

In [8]:
%%time
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
llm = HuggingFaceLLM(
  context_window=4096,
  max_new_tokens=30,
  generate_kwargs={'temperature': 0.2, 'do_sample': True, 'top_p': 0.95, 'top_k': 40, 'repetition_penalty': 1.1},
  query_wrapper_prompt=query_wrapper_prompt,
  tokenizer_name=model_name_or_path,
  model_name=model_name_or_path,
  device_map='auto',
  model_kwargs={'torch_dtype': torch.float16},
)


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

CPU times: user 1min 28s, sys: 20.4 s, total: 1min 48s
Wall time: 3.44 s


In [9]:
# Embedding model. 'exp_finetune' is presumably a locally fine-tuned checkpoint directory,
# resolved relative to the kernel's working directory; the stock base model is kept for comparison.
# embed_model = HuggingFaceEmbedding(model_name='BAAI/bge-large-en-v1.5')
EMBED_MODEL_PATH = 'exp_finetune'
embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_PATH)

# Bundle the LLM, embedder and chunking settings for the index.
# NOTE(review): ServiceContext is deprecated in llama_index >= 0.10 in favour of
# `llama_index.core.Settings`. chunk_size / chunk_overlap presumably only matter when
# documents are (re-)ingested, not when attaching to an existing vector store.
service_context = ServiceContext.from_defaults(
  llm=llm,
  embed_model=embed_model,
  chunk_size=1024,
  chunk_overlap=256,
)

In [10]:
%%time
# index = VectorStoreIndex.from_documents(docs, storage_context=storage_context,
# service_context=service_context)
index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)

CPU times: user 574 µs, sys: 0 ns, total: 574 µs
Wall time: 581 µs


In [11]:
# Query engine over the index using llama_index's default retriever / response-synthesis
# settings (no top-k or response mode is overridden here).
query_engine = index.as_query_engine()

In [12]:
# Validation QA set. As used below: `queries` maps query id -> question text,
# `relevant_docs` maps query id -> list of corpus ids, and `corpus` maps corpus id -> chunk text.
# Override the location with the LLMQA_DATA_PATH environment variable instead of editing it here.
DATA_PATH = os.environ.get('LLMQA_DATA_PATH', "/home/75y/data_ragMimic/data/")
VAL_DATASET_FPATH = os.path.join(DATA_PATH, 'val_dataset.json')
# Open read-only ('r+' would needlessly require write permission) with an explicit UTF-8
# encoding so the result does not depend on the platform's locale.
with open(VAL_DATASET_FPATH, 'r', encoding='utf-8') as f:
    val_dataset = json.load(f)

corpus = val_dataset['corpus']
queries = val_dataset['queries']
relevant_docs = val_dataset['relevant_docs']

# Every query needs relevance labels, because the sampling cell looks them up by query id.
missing_labels = set(queries) - set(relevant_docs)
assert not missing_labels, f"{len(missing_labels)} queries have no entry in relevant_docs"


In [13]:
# Spot-check one validation query end-to-end: question, gold context and the RAG response.
# Fixed seed for reproducibility; set SAMPLE_SEED = None to draw a fresh example each run.
SAMPLE_SEED = 42
rng = np.random.default_rng(SAMPLE_SEED)

# Draw the index over the query ids, not the corpus. The two counts need not match, and
# sampling from len(corpus) either never reaches later queries or raises IndexError.
query_ids = list(queries)
idx = rng.integers(len(query_ids))
qid = query_ids[idx]
q = queries[qid]
c = corpus[relevant_docs[qid][0]]
r = query_engine.query(q)

# NOTE: the printed context and response are MIMIC-derived clinical text. Clear this cell's
# output before sharing the notebook, because PhysioNet's data use agreement restricts redistribution.
print(f"Question:\n{q}")
print("-"*100)
print(f"Context:\n{c}")
print("-"*100)
print(f"Response:\n{r}")

Question:
What is the patient's name?
----------------------------------------------------------------------------------------------------
Context:
- the patient was given:  
 azithro 500mg  
 500cc normal saline (for brief period of relative hypotension  with sbp ___  
 vitals prior to transfer were:  
 64 105/61 20 99% ra  
 upon arrival to the floor, patient endorses no complaints.  

 
past medical history:  sjogren's syndrome  
 osteoarthritis  
 hearing loss - sensorineural, unspec  
 cancer - breast s/p partial mastectomy  
 hypercholesterolemia  
 spinal stenosis - lumbar  
 pulmonary embolism  
 dvts (bilateral), postphlebitic syndrome  
 history total knee replacement  
 hypertension - essential, unspec  
 diastolic heart failure  
 ckd stage iii  
 
social history: ___
family history: no family history of kidney or heart disease. patient's brother  with diabetes.
 
 vitals - 97.2 129/49 62 18 100ra  
 wt 80.3  
 general - pleasant, well-appearing, in no apparent distress  
 