In [None]:
!nvidia-smi

## Setup

In [1]:
import json
import torch
import scispacy
import spacy
import evaluate
from matplotlib import pyplot
import transformers
import numpy as np
import pandas as pd
import seaborn as sns
from datasets import load_dataset
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

%matplotlib inline
sns.set(rc={'figure.figsize':(8, 6)})
sns.set(rc={'figure.dpi':100})
sns.set(style='white', palette='muted', font_scale=1.2)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

## Dataset


In [2]:
train_path = r"C:\Users\16462\Desktop\Research\dataset\medquad.csv"
data = pd.read_csv(train_path)

data
# test_path = r"C:\Users\16462\Desktop\Research\dataset\All-2479-Answers-retrieved-from-MedQuAD.csv"
# df = pd.read_csv(test_path)


Unnamed: 0,type,Qtype,Q,A,qlen,alen,token_qlen,token_alen
0,CancerGov,information,What is (are) Adult Acute Lymphoblastic Leukem...,Key Points\n - Adult acute ...,8,436,21,794
1,CancerGov,symptoms,What are the symptoms of Adult Acute Lymphobla...,"Signs and symptoms of adult ALL include fever,...",10,127,22,213
2,CancerGov,exams and tests,How to diagnose Adult Acute Lymphoblastic Leuk...,Tests that examine the blood and bone marrow a...,8,445,20,611
3,CancerGov,outlook,What is the outlook for Adult Acute Lymphoblas...,Certain factors affect prognosis (chance of re...,10,66,22,97
4,CancerGov,susceptibility,Who is at risk for Adult Acute Lymphoblastic L...,Previous chemotherapy and exposure to radiatio...,10,114,22,151
...,...,...,...,...,...,...,...,...
45093,MPlusHerbsSuppls,information,What is 5-HTP ?,5-hydroxytryptophan (5-HTP) can be converted t...,4,119,10,209
45094,MPlusHerbsSuppls,how effective is it,How effective is 5-HTP ?,Natural Medicines Comprehensive Database rate...,5,119,11,209
45095,MPlusHerbsSuppls,interactions with medications,Are there interactions between 5-HTP and other...,Moderate Be cautious with this combination. Ca...,9,113,16,217
45096,MPlusHerbsSuppls,interactions with herbs and supplements,Are there interactions between 5-HTP and herbs...,Herbs and supplements with sedative properties...,10,95,19,180


In [None]:
# def question(s):
#     q = s[9:s.find("URL")] 
#     return q
# def answer(s):
#     a = s[s.find("Answer")+8:]
#     return a
# def get_category(s):
#     return s.split('_')[0]

In [None]:
# df["type"] = df.AnswerID.apply(get_category)
# df["Q"] = df.Answer.apply(question)
# df["A"] = df.Answer.apply(answer)
# df = df.drop(["AnswerID","Answer"],axis=1)

## Load Model



In [3]:
BASE_MODEL = "meta-llama/Llama-2-13b-chat-hf"

model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
binary_path: C:\Users\16462\anaconda3\envs\alpaca\lib\site-packages\bitsandbytes\cuda_setup\libbitsandbytes_cuda116.dll
CUDA SETUP: Loading binary C:\Users\16462\anaconda3\envs\alpaca\lib\site-packages\bitsandbytes\cuda_setup\libbitsandbytes_cuda116.dll...


The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
CUTOFF_LEN = 2048

def generate_prompt(data_point):
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
### Instruction:
{data_point["Q"]}
### Input:
{data_point["input"]}
### Response:
{data_point["A"]}"""

def tokenize(prompt, add_eos_token=True):

    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < CUTOFF_LEN
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    tokenized_full_prompt = tokenize(full_prompt)
    return tokenized_full_prompt

In [5]:
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

## Inference

In [6]:
def generate_prompt_inference(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501

    ### Instruction:
    {instruction}

    ### Input:
    {input}

    ### Response:
    """
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501

    ### Instruction:
    {instruction}

    ### Response:
    """
def customize_generate_prompt(instruction, input=None):
    if input:
        return f"""  # noqa: E501

    ### Instruction:
    {instruction}

    ### Input:
    {input}

    ### Response:
    """
    else:
        return f""" Please use the knowledge from the clinical domain such as National Institutes of Health (NIH) to answer the following question in the Response section.  # noqa: E501
    {instruction}
    ### Response: 
    """


In [7]:
def inference(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        **kwargs,
):
        
    prompt = generate_prompt_inference(instruction)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cuda")
    
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
        
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)

    return output.split("### Response:")[1].strip()


In [None]:
bleu = evaluate.load("bleu")

In [19]:
# evaluate by overflap words

nlp = spacy.load("en_core_sci_lg")

# ground truth as denominator 

def common_entities(pred,label):
    doc1, doc2 = nlp(pred), nlp(label)
    x, y = [str(d) for d in doc1.ents], [str(d) for d in doc2.ents]
    if min(len(x),len(y))!=0:
        return (len(list(set(list(x))&set(list(y))))/min(len(x),len(y)))
    elif len(y) == 0:
        return 1
    else:
        return 0

In [None]:
# Inference for each question

types = data.type.unique()
predictions = []
references = []

for i in [data[data.type==t].sample().index[0] for t in types]:
    pred = inference(data.Q[i])
    predictions.append(pred)
    references.append([data.A[i]])
    
results = bleu.compute(predictions=predictions, references=references)
print(results)

In [None]:
# Visualize Result for each type

bleus = [bleu.compute(predictions=[predictions[i]], references=[references[i]])["bleu"] for i in range(len(predictions))]

evaluation_table = pd.DataFrame(
{"Type":types,
 "Prediction":predictions,
 "Ground_Truth":references,
 "BLEU":bleus}
)

evaluation_table

In [None]:
# Inference for each category
# 12 categories, 50 questions for each

### Evaluate BLEU scores

types = data.Qtype.unique()
num_questions = 5
eval_result = []

for j,t in enumerate(types):

    predictions = []
    references = []
    print(t)
    
    for i in data[data.Qtype==t].sample(num_questions,replace=True).index:
        pred = inference(data.Q[i])
        predictions.append(pred)
        references.append([data.A[i]])

    results = bleu.compute(predictions=predictions, references=references)
    eval_result.append(results)

bleus_Qtypes_sample = pd.DataFrame(
{"Type":types,
 "BLEU":eval_result,
 "bleus":[e["bleu"] for e in eval_result]}
)

In [27]:
# Inference for each category

### Evaluate common entities 

types = data.Qtype.unique()
num_questions = 10
eval_result = []

for j,t in enumerate(types):

    overlap_score = []
    print(t)
    
    for i in data[data.Qtype==t].sample(num_questions,replace=True).index:
        pred = inference(data.Q[i])
        label = data.A[i]
        overlap_score.append(common_entities(pred,label))

    eval_result.append(np.mean(overlap_score))

bleus_Qtypes_sample = pd.DataFrame(
{"Type":types,
 "Scores":eval_result}
)

information
symptoms
exams and tests
outlook
susceptibility
stages
treatment
research
genetic changes
prevention
causes
inheritance
frequency
considerations
complications
support groups
when to contact a medical professional
indication
usage
precautions
dietary
forget a dose
side effects
storage and disposal
emergency or overdose
other information
brand names
brand names of combination products
contraindication
severe reaction
how can i learn more
dose
why get vaccinated
how effective is it
interactions with medications
interactions with herbs and supplements
interactions with foods
how does it work


In [29]:
bleus_Qtypes_sample.to_csv(r"C:/Users/16462/Desktop/Research/Clinical_LLM_Evaluation/evaluation_csv/common_entity_question_type_5_sample.csv",index=False)

In [30]:
bleus_Qtypes_sample

Unnamed: 0,Type,Scores
0,information,0.299013
1,symptoms,0.147693
2,exams and tests,0.146859
3,outlook,0.133737
4,susceptibility,0.219584
5,stages,0.360815
6,treatment,0.199954
7,research,0.12928
8,genetic changes,0.279627
9,prevention,0.190115


In [26]:
pd.read_csv(r"C:/Users/16462/Desktop/Research/Clinical_LLM_Evaluation/evaluation_csv/̧bleu_question_type_1_sample.csv")

Unnamed: 0,Type,BLEU,bleus
0,information,"{'bleu': 0.01114735729087514, 'precisions': [0...",0.011147
1,symptoms,"{'bleu': 0.004669027861147354, 'precisions': [...",0.004669
2,exams and tests,"{'bleu': 0.04530888858973332, 'precisions': [0...",0.045309
3,outlook,"{'bleu': 0.0, 'precisions': [0.422680412371134...",0.0
4,susceptibility,"{'bleu': 0.0, 'precisions': [0.608108108108108...",0.0
5,stages,"{'bleu': 0.07008008045353659, 'precisions': [0...",0.07008
6,treatment,"{'bleu': 0.0, 'precisions': [0.123456790123456...",0.0
7,research,"{'bleu': 0.0, 'precisions': [0.586956521739130...",0.0
8,genetic changes,"{'bleu': 0.02818715470469387, 'precisions': [0...",0.028187
9,prevention,"{'bleu': 0.00047713692503738016, 'precisions':...",0.000477


In [None]:
pyplot.rcParams["figure.figsize"] = (7,4)

x = data["qlen"].apply(np.log)
y = data["alen"].apply(np.log)

bins = numpy.linspace(0, 10, 100)

pyplot.hist([x,y], alpha=0.6, bins=bins, label=['question_word_length','answer_word_length'])
pyplot.legend(loc='upper right')


pyplot.show()

data[["qlen","alen"]].describe()

In [None]:
sum(data.alen>1024)

In [None]:
x = data["token_qlen"].apply(np.log)
y = data["token_alen"].apply(np.log)

bins = numpy.linspace(0, 10, 100)

pyplot.hist([x,y], alpha=0.6, bins=bins, label=['question_token_length','answer_token_length'])
pyplot.legend(loc='upper right')
pyplot.show()

data[["token_qlen","token_alen"]].describe()

In [None]:
d = data.sample()

In [None]:
d.Q.item()


In [None]:
d.A.item()

In [None]:
d.Qtype.item()

In [None]:
inference(d.Q)

In [None]:
# - extract entity & compare 
# - 


In [None]:
df.drop()

In [None]:
i = 173

In [None]:
predictions[i]

In [None]:
references[i][0]

In [None]:
for i in range(500):
    print(i)
    print(bleu.compute(predictions=[predictions[i]],references=[references[i]])["bleu"])

In [None]:
for i in range(500):
    print(i)
    print(bleu.compute(predictions=[predictions[i]],references=[references[i]])["bleu"])

In [None]:
bleus_types_sample.to_csv(r"C:/Users/16462/Desktop/Research/Clinical_LLM_Evaluation/evaluation_csv/̧bleu_data_source_5_sample.csv",index=False)

In [None]:
# 1. num of data length
# 2. other stats
# 3. average number of words in each question & answers (tokens)
# 

def spl(s):
    return len(s.split())

def token_spl(s):
    return len(tokenize(s)["input_ids"])

data["qlen"]=data.Q.apply(spl)
data["alen"]=data.A.apply(spl)

data["token_qlen"]=data.Q.apply(token_spl)
data["token_alen"]=data.A.apply(token_spl)