In [1]:
%%time
import os, re
from time import ctime
import time, math
from pathlib import Path
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import bitsandbytes as bnb
from transformers import BitsAndBytesConfig
import pandas as pd
import numpy as np
import gc

# Characters that may precede the actual answer in a model reply
# (whitespace, quotes, markdown emphasis) and must not affect labelling.
_LEADING_NOISE = " \t\r\n\"'`*"

# A reply that starts with one of these (case-insensitive) is labelled negative.
# "no" also covers "not"/"none"; "cannot" has to be listed separately.
_NEGATIVE_PREFIXES = ("no", "cannot")


def _label_from_response(response):
    """Map a raw model reply to a binary label: 0 if it opens with a negation, else 1."""
    cleaned = response.lstrip(_LEADING_NOISE).lower()
    return 0 if cleaned.startswith(_NEGATIVE_PREFIXES) else 1


def mm_fewshot(model_name, readfile_name, savefile_name):
    """Classify clinical notes for an explicit multiple-myeloma diagnosis.

    Loads a 4-bit quantized causal LM, asks it a few-shot Yes/No question for
    every note in the input file, and writes the raw replies together with a
    derived binary label to a CSV file.

    Args:
        model_name: Name or path passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``. The tokenizer is loaded
            with ``local_files_only=True``.
        readfile_name: Path of a JSON file readable by ``pandas.read_json``. It
            must provide the columns ``reportText``, ``PatientSSN`` and
            ``EntryDate``.
        savefile_name: Path of the CSV file to write. Missing parent
            directories are created.

    Returns:
        None. The result is written to ``savefile_name`` with the columns
        ``PatientSSN``, ``EntryDate``, ``Output`` (raw model reply) and
        ``Label`` (0 when the reply starts with a negation, 1 otherwise --
        note that an empty or unexpected reply therefore counts as 1).
    """
    # Make sure the destination exists *before* the long inference loop so a
    # missing folder cannot discard the results at the very end.
    Path(savefile_name).parent.mkdir(parents=True, exist_ok=True)

    # 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
    qconfig = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    # Release whatever a previous run left behind before loading the weights.
    gc.collect()
    torch.cuda.empty_cache()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda",
        quantization_config=qconfig,
    )

    # NOTE: the text below is sent to the model verbatim (including its
    # indentation); do not reformat it.
    prompt_template = """
    Your task is to determine if a diagnosis of multiple myeloma was explicitly stated in the given clinical note. The diagnosis must be unambiguous and must exactly mention "myeloma".
    Your responses should be either "Yes" or "No". Do not respond any texts other than "Yes" or "No".
    
    Follow these guidelines: 
    To identify the explicit diagnosis of multiple myeloma in clinical notes, you should look for phrases or terms exactly stating "myeloma". Mention of relevant lab results or treatments alone does not qualify as an explicit diagnosis. Additionally:
    Avoid mentions where multiple myeloma is considered only as a suspicion, a concern, or a differential diagnosis.
    A history of multiple myeloma should not be identified as a current diagnosis.
    A mention of conditions or symptoms associated with multiple myeloma should not be identified as a current diagnosis of multiple myeloma.
    A diagnosis of MGUS or monoclonal gammopathy should not be mistaken for multiple myeloma.
    
    Here are some examples: 
    Respond with "No" when the clinical note is "Patient has Igm monoclonal gammopathy, will repeat myeloma labs in 6 months."
    Respond with "Yes" when the clinical note is "Oncology diagnosis: IgG Kappa Multiple Myeloma."
    Respond with "No" when the clinical note is  "A/P: Patient with a hematological hx of MGUS as well as ASCVD. Etiologies of the MGUS include multiple myeloma (most likely), Amyloid and Lymphoma."
    
    Here is the clinical note: {document}
    """

    def llm_VAmodel(user_query):
        """Send one user message through the chat template and return the decoded reply."""
        messages = [
            {"role": "system", "content": "You are an AI assistant."},
            {"role": "user", "content": user_query},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # Only a short answer is expected, hence the tiny token budget.
        # NOTE(review): `temperature` only matters when sampling is enabled;
        # whether it is depends on the model's generation config, which is not
        # visible here -- confirm if strictly deterministic output is required.
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=5,
            temperature=0.001,
        )
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Drop the per-note tensors right away to keep GPU memory flat.
        del model_inputs
        del generated_ids
        gc.collect()
        torch.cuda.empty_cache()
        return response

    df2 = pd.read_json(readfile_name)

    # Materialize the column once instead of rebuilding the list per note.
    notes = df2.reportText.to_list()

    resulttext = []
    Tstart = time.time()
    for i, note in enumerate(notes):
        user_query = prompt_template.format(document=note)
        resulttext.append(llm_VAmodel(user_query))
        print(f"Note {i} done!")
    Tend = time.time()
    print(f"==== {Tend - Tstart} second =====")

    binaryresult = [_label_from_response(reply) for reply in resulttext]

    # Assigning plain lists is positional, so replies stay attached to the
    # right patient row even when the JSON produced a non-default index.
    dfsavefile = df2[["PatientSSN", "EntryDate"]].copy()
    dfsavefile["Output"] = resulttext
    dfsavefile["Label"] = binaryresult
    dfsavefile.to_csv(savefile_name)

CPU times: total: 12.8 s
Wall time: 14.1 s


In [None]:
# Windows-style relative paths. Raw strings keep every backslash literal
# instead of relying on "\L" / "\l" happening not to be escape sequences.
model_name = r".\Llama-3.1-8B-Instruct"
readfile_name = r".\testingnotes_final.json"
savefile_name = r".\llama8b\Llama-8B-fewshot-MM-final.csv"

# Run the few-shot classification over all notes and write the CSV.
mm_fewshot(model_name, readfile_name, savefile_name)