In [None]:
import numpy as np
import pandas as pd

from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import re

from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json


from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model

import string
import sys
import gc
import inspect

In [None]:
# Expose only GPU index 0 to CUDA for this process.
# NOTE(review): torch is already imported in the cell above; presumably this still
# takes effect because nothing has touched CUDA yet — confirm when reordering cells.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Device that LLM_text moves the tokenized prompt onto before generation.
device = torch.device("cuda")

In [None]:
# Quantization settings forwarded as `q_config` to the AWQ helpers called in
# build_model_and_enc (run_awq, real_quantize_model_weight, pseudo_quantize_model_weight).
q_config = {
    "zero_point": True,  # by default True
    "q_group_size": 128,  # presumably the group size for group-wise quantization — confirm against the awq package
}
# Forwarded as `max_memory` to the device-map helpers in build_model_and_enc only when non-empty.
max_memory = []

In [None]:
def build_model_and_enc(model_path, quantized_file_path, load_quant = True, w_bit = 4):
    """Build a causal LM and its tokenizer, from pre-quantized weights or from fp16.

    Args:
        model_path: Local path of the Hugging Face checkpoint; must exist on disk.
        quantized_file_path: Checkpoint handed to ``load_checkpoint_in_model``;
            only read when ``load_quant`` is true.
        load_quant: When true, create the model with empty weights, initialise the
            quantized layers and load ``quantized_file_path``. When false, load the
            fp16 model and run the AWQ / quantization steps selected by ``args``.
        w_bit: Bit width passed to ``real_quantize_model_weight`` in the
            ``load_quant`` branch. The fp16 branch reads ``args.w_bit`` instead.

    Returns:
        ``(model, enc)``: the dispatched model and the tokenizer.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        NotImplementedError: If ``args.q_backend`` is neither "fake" nor "real"
            (fp16 branch, when ``args.w_bit`` is not None).

    NOTE(review): the ``load_quant=False`` branch reads a module-level ``args``
    object (run_awq, load_awq, dump_awq, w_bit, q_backend, dump_quant) that is not
    defined in the visible part of this notebook — confirm it exists before using
    that branch. That branch also calls ``exit(0)`` after dumping results.
    """
    if not os.path.exists(model_path):  # look into ssd
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # all hf model
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    print(f"Config = {config}")
    # MPT configs name their tokenizer in the config instead of shipping it with the checkpoint.
    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    if load_quant:  # directly load quantized weights
        print("Loading pre-computed quantized weights...")
        # Instantiate the architecture without allocating real weights.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config,
                                                     torch_dtype=torch.float16, trust_remote_code=True)
        model.config.pretraining_tp = 1
        # init_only=True: set up the quantized layers; the weights come from the checkpoint below.
        real_quantize_model_weight(
            model, w_bit=w_bit, q_config=q_config, init_only=True)

        model.tie_weights()

        # Infer device map
        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        # Load checkpoint in the model
        load_checkpoint_in_model(
            model,
            checkpoint= quantized_file_path,
            device_map=device_map,
            offload_state_dict=True,
        )
        # Dispatch model
        model = simple_dispatch_model(model, device_map=device_map)

        model.eval()
    else:  # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
        # Init model on CPU:
        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config, trust_remote_code=True, **kwargs)

        model.eval()

        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"

            awq_results = run_awq(
                model, enc,
                w_bit=args.w_bit, q_config=q_config,
                n_samples=128, seqlen=512,
            )
            if args.dump_awq:
                dirpath = os.path.dirname(args.dump_awq)
                # dirname is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
                if dirpath:
                    os.makedirs(dirpath, exist_ok=True)

                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)

            exit(0)

        if args.load_awq:
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
            apply_awq(model, awq_results)

        # weight quantization
        if args.w_bit is not None:
            if args.q_backend == "fake":
                assert args.dump_quant is None, \
                    "Need to use real quantization to dump quantized weights"
                pseudo_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
            elif args.q_backend == "real":  # real quantization
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
                if args.dump_quant:
                    dirpath = os.path.dirname(args.dump_quant)
                    # Same guard as above: a bare filename has no directory to create.
                    if dirpath:
                        os.makedirs(dirpath, exist_ok=True)

                    print(
                        f"Saving the quantized model at {args.dump_quant}...")
                    torch.save(model.cpu().state_dict(), args.dump_quant)
                    exit(0)
            else:
                raise NotImplementedError

        # Move the model to GPU (as much as possible) for LM evaluation
        kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
        device_map = infer_auto_device_map(
            model,
            # TODO: can we remove this?
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        model = dispatch_model(model, device_map=device_map)

    return model, enc

In [None]:
def generate_question_list(question_text_file_path):
    """Read the question file and return one entry per blank-line-separated block.

    Each block has everything up to and including its first ". " removed
    (e.g. a leading "1. " numbering); a block without ". " is kept unchanged.

    Args:
        question_text_file_path: Path of the text file holding the questions.

    Returns:
        List of question strings in file order.
    """
    with open(question_text_file_path, 'r') as handle:
        raw_text = handle.read()

    questions = []
    for block in raw_text.split("\n\n"):
        # Only the first ". " is treated as the numbering separator.
        pieces = block.split('. ', 1)
        questions.append(pieces[-1])
    return questions

In [None]:
# Lead-in phrases, one per question. build_default_secondary_prompt substitutes
# entry `question_id` (used as a direct list index) for the "[HEADER]" placeholder.
# NOTE(review): the first entry uses a typographic apostrophe (’) while the others
# use the ASCII one ('); confirm whether that difference is intended.
question_header_list = ["the patient’s marital status information is:",
                "the patient's children information is:",
                "the patient's smoking or tobacco usage information is:",
                "the patient's drinking or alcohol usage information is:",
                "the patient's drug usage information is:",
                "the patient's living information is:",
                "the patient's employment information is:",
                "the patient's education information is:",
                "the patient's exercising information is:"]

In [None]:
def build_default_prompt(text, question):
    """Assemble the default single-question prompt in USER/ASSISTANT chat form.

    Args:
        text: Patient social-documentation text inserted at the text placeholder.
        question: Question inserted at the question placeholder.

    Returns:
        The chat header with the instruction body embedded, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Given the following text about a patient's social documentation.
                    [TEXT]
                    Please answer the question: [QUESTION]
                    ASSISTANT: '''
    # Applied in order, so a placeholder occurring inside an earlier substitution
    # is expanded by the later passes as well.
    substitutions = (
        ("[PROMPT]", instruction),
        ("[TEXT]", text),
        ("[QUESTION]", question),
    )
    filled = chat_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled

In [None]:
def build_Q6_prompt(text, question):
    """Assemble the prompt for question 6 (does the patient live alone).

    Adds explicit answering rules to the default layout: "Yes" / "No" only when
    the text states it explicitly, otherwise "Not mentioned".

    Args:
        text: Patient social-documentation text inserted at the text placeholder.
        question: Question inserted at the question placeholder.

    Returns:
        The complete prompt string, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Given the following text about a patient's social documentation.
                    [TEXT]
                    Please answer the question: [QUESTION]
                    Please DO NOT infer the answer according to the text. ONLY give your answers based on the ACTUAL text.
                    Your answer should be "Yes" IF AND ONLY IF the text EXPLICITLY mentions that the patient lives alone.
                    Your answer should be "No" IF AND ONLY IF the text EXPLICITLY mentions that the patient lives with someone.
                    Otherwise, your answer MUST be "Not mentioned".
                    ASSISTANT: '''
    # Order matters: template first, then the note text, then the question.
    with_instruction = chat_template.replace("[PROMPT]", instruction)
    with_text = with_instruction.replace("[TEXT]", text)
    return with_text.replace("[QUESTION]", question)

In [None]:
def build_Q7_prompt(text, question):
    """Assemble the prompt for question 7 (employment status).

    Adds the definitions of the "Employed", "Jobless" and "Retired" answers to
    the default layout.

    Args:
        text: Patient social-documentation text inserted at the text placeholder.
        question: Question inserted at the question placeholder.

    Returns:
        The complete prompt string, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Given the following text about a patient's social documentation.
                    [TEXT]
                    Please answer the question: [QUESTION]
                    Your answer should be "Employed" if the text mentions that the patient has a full-time or part-time job (even if the patient retired or was jobless before).
                    Your answer should be "Jobless" if the text mentions that the patient is not working, is a homemaker, or is on disability.
                    Your answer should be "Retired" if the text mentions that the patient is retired, is a former employee or used to be an employee.
                    ASSISTANT: '''
    filled = chat_template
    # Sequential passes: template, then text, then question.
    for placeholder, value in (("[PROMPT]", instruction), ("[TEXT]", text), ("[QUESTION]", question)):
        filled = filled.replace(placeholder, value)
    return filled

In [None]:
def build_second_Q7_prompt(text):
    """Assemble the Yes/No prompt asking whether the text mentions an occupation.

    Args:
        text: Patient social-documentation text inserted at the text placeholder.

    Returns:
        The complete prompt string, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Given the following text about a patient's social documentation.
                    [TEXT]
                    Please answer the question: Does the text mention the patient's occupation?
                    Your answer must be one of the following:
                    - Yes
                    - No
                    ASSISTANT: '''
    # The question is fixed here, so only the template and text are substituted.
    return chat_template.replace("[PROMPT]", instruction).replace("[TEXT]", text)

In [None]:
def build_few_shot_learning_prompt(text, question, few_shot_learning_prompt_path):
    """Assemble a prompt whose instruction body is read from a few-shot template file.

    Args:
        text: Patient social-documentation text inserted at the text placeholder.
        question: Question inserted at the question placeholder.
        few_shot_learning_prompt_path: Path of the file whose full content is used
            as the instruction body.

    Returns:
        The chat header with the file content embedded and placeholders filled.
    """
    with open(few_shot_learning_prompt_path, 'r') as template_file:
        few_shot_body = template_file.read()

    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    # Sequential passes: template body first, then text, then question.
    filled = chat_template.replace("[PROMPT]", few_shot_body)
    filled = filled.replace("[TEXT]", text)
    return filled.replace("[QUESTION]", question)

In [None]:
def build_default_secondary_prompt(text, question, question_id):
    """Assemble the retry prompt that restates an earlier answer and asks again.

    Args:
        text: Text inserted at the text placeholder (refine_and_check passes the
            previous model response here).
        question: Question inserted at the question placeholder.
        question_id: Index into ``question_header_list`` selecting the lead-in phrase.

    Returns:
        The complete prompt string, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Suppose that [HEADER] [TEXT].
                [QUESTION]
                ASSISTANT: '''
    # Applied strictly in this order: template, lead-in, text, question.
    substitutions = (
        ("[PROMPT]", instruction),
        ("[HEADER]", question_header_list[question_id]),
        ("[TEXT]", text),
        ("[QUESTION]", question),
    )
    filled = chat_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled

In [None]:
def build_Q7_secondary_prompt(text, question):
    """Assemble the retry prompt used for question 7 when only one question is asked.

    Args:
        text: Text inserted at the text placeholder (refine_and_check passes the
            previous model response here).
        question: Question inserted at the question placeholder.

    Returns:
        The complete prompt string, ending in "ASSISTANT: ".
    """
    chat_template = '''A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: [PROMPT]'''
    instruction = '''Given the following text:
                [TEXT]
                [QUESTION]
                ASSISTANT: '''
    # Chained in the same order as the placeholders appear: template, text, question.
    return (
        chat_template
        .replace("[PROMPT]", instruction)
        .replace("[TEXT]", text)
        .replace("[QUESTION]", question)
    )

In [None]:
def get_model(model_name):
    """Load the model and tokenizer for a supported model name.

    A name containing "-w" is treated as a quantized variant: the part before the
    first "-w" selects the base checkpoint directory and the full name selects the
    quantized weight file under ``quant_cache/``. Any other supported name is
    loaded directly with the transformers loaders.

    Args:
        model_name: Supported base model name, optionally followed by a "-w..." suffix.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: If the base name is not in the supported list.
    """
    supported_models = [
        "openchat_3.5",
        "zephyr-7b-beta",
        "vicuna-7b-v1.5",
        "Llama-2-7b-chat-hf",
        "vicuna-13b-v1.5",
        "WizardLM-13B-V1.2",
        "Llama-2-13b-chat-hf",
        "vicuna-33b-v1.3",
        "WizardLM-70B-V1.0",
        "Llama-2-70b-chat-hf",
    ]
    base_name = model_name.split("-w")[0]
    if base_name not in supported_models:
        raise NotImplementedError(f"[ERROR]: Model {model_name} not supported")

    if "-w" not in model_name:
        # Full-precision model: tokenizer and weights come from the same directory.
        checkpoint_dir = f"../.cache/huggingface/transformers/{model_name}"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
        return model, tokenizer

    # Quantized model: base checkpoint for config/tokenizer, quant_cache for weights.
    model_path = f"../.cache/huggingface/transformers/{base_name}"
    quantized_file_path = f"quant_cache/{model_name}.pt"
    model, tokenizer = build_model_and_enc(model_path, quantized_file_path = quantized_file_path)
    return model, tokenizer

In [None]:
def baseline(text, question_id):
    """Answer one question with the regular-expression baseline.

    Each question has an ordered list of (pattern, answer) rules. Patterns are
    searched case-insensitively and the first one that matches decides the
    answer; when none matches the answer is 'Not mentioned'.

    Args:
        text: Text to search.
        question_id: Question number, 1 to 9.

    Returns:
        The answer label of the first matching rule, or 'Not mentioned'.

    Raises:
        ValueError: If ``question_id`` is not 1 through 9.
    """
    not_mentioned = 'Not mentioned'

    # Position k (0-based) holds the rules of question k + 1.
    rule_sets = (
        # Q1: marital status
        (
            (r'\b(widowed)\b', 'Widowed'),
            (r'\b(divorced|separated)\b', 'Divorced'),
            (r'\b(single|no spouse|partner|boyfriend|girlfriend|long-term partner)\b', 'Single'),
            (r'\b(married)\b', 'Married'),
        ),
        # Q2: number of children
        (
            (r'\b(no children|0 children|zero children|no child|0 child|zero child)\b', '0'),
            (r'\b(1 child|1 son|1 daughter|one child|one son|one daughter)\b', '1'),
            (r'\b(2 children|2 sons|2 daughters|two children|two sons|two daughters)\b', '2'),
            (r'\b(3 children|3 sons|3 daughters|three children|three sons|three daughters)\b', '3'),
            (r'\b(4 children|4 sons|4 daughters|four children|four sons|four daughters)\b', '4'),
            (r'\b(5 children|6 children|7 children|8 children|9 children|10 children|five children|six children|seven children|eight children|nine children|ten children)\b', '5 or more'),
        ),
        # Q3: current tobacco use
        (
            (r'\b(never used tobacco|quit smoking|quit tobacco|quit smoke|past tobacco use|never smoke|no smoke)\b', 'No'),
            (r'\b(currently smokes|cigarettes|cigars|smokeless tobacco|tobacco use|smoke)\b', 'Yes'),
        ),
        # Q4: current alcohol use
        (
            (r'\b(never consumed alcohol|sober|no ETOH)\b', 'No'),
            (r'\b(consumes alcohol|drinks alcohol|ETOH)\b', 'Yes'),
        ),
        # Q5: current illicit drug use
        (
            (r'\b(never used drugs|sober from drugs|past drug use|deny illicit drugs)\b', 'No'),
            (r'\b(illicit drugs|uses drugs|cocaine|marijuana|substance abuse)\b', 'Yes'),
        ),
        # Q6: lives alone
        (
            (r'\b(alone)\b', 'Yes'),
            (r'\b(with husband|with wife|with child|with children|with son|with daughter|with boyfriend|with girlfriend|with friend|with grandparents|with grandmother|with grandfather|with uncle|with aunt|with parents|with father|with mother|with dad|with mom)\b', 'No'),
        ),
        # Q7: employment status
        (
            (r'\b(full-time|part-time|employed|work as|employ)\b', 'Employed'),
            (r'\b(stay-at-home|at home|unemployed|on disability|homemaker|not work|no work|jobless)\b', 'Jobless'),
            (r'\b(retire|former employee|worked|used to work)\b', 'Retired'),
        ),
        # Q8: highest education level
        (
            (r'\b(elementary school|1st grade|2nd grade|3rd grade|4th grade|5th grade)\b', 'Elementary school'),
            (r'\b(middle school|6th grade|7th grade|8th grade)\b', 'Middle school'),
            (r'\b(high school|9th grade|10th grade|11th grade)\b', 'High school'),
            (r'\b(college|Associates|Bachelors|BS|BA|freshman|sophomore|junior|senior|post-bacc)\b', 'College'),
            (r'\b(graduate school|grad school|Masters degree|MS|MBA|MPH|MENG|Doctoral degree|PHD|MD)\b', 'Graduate school'),
        ),
        # Q9: exercise
        (
            (r'\b(does not exercise|never exercise|not exercising)\b', 'No'),
            (r'\b(used to exercise|used to walk)\b', 'In the past'),
            (r'\b(exercise|walk|run|dance|gym|treadmill|yoga)\b', 'Yes'),
        ),
    )

    # Compare with == (not a dict lookup) so any id that equals 1..9 selects its rules.
    for number, rules in enumerate(rule_sets, start=1):
        if question_id == number:
            for pattern, answer in rules:
                if re.search(pattern, text, re.IGNORECASE):
                    return answer
            return not_mentioned

    raise ValueError(f"[ERROR]: question_id {question_id} is invalid. question_id must range from 1 to 9.")

In [None]:
def LLM_text(model, tokenizer, prompt):
    """Generate a completion for ``prompt`` and return the assistant's part of it.

    Args:
        model: Object whose ``generate`` is called with the encoded input ids.
        tokenizer: Callable used to encode the prompt and to decode the output.
        prompt: Full prompt text.

    Returns:
        The decoded text after the last "ASSISTANT: " marker, or the whole decoded
        text when the marker does not occur.
    """
    # The encoded prompt is moved to the module-level `device` before generation.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(encoded.input_ids, max_length=2048)
    decoded_batch = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    full_text = decoded_batch[0]
    # rpartition yields the tail after the last marker (the whole string if absent).
    _, _, answer = full_text.rpartition("ASSISTANT: ")
    return answer

In [None]:
def generate_choices(question):
    """Extract the answer choices from a question block.

    The first two lines of ``question`` are skipped; every remaining line is a
    choice, returned with each "- " occurrence removed.

    Args:
        question: Multi-line question text.

    Returns:
        List of choice strings (empty when there are fewer than three lines).
    """
    option_lines = question.split("\n")[2:]
    cleaned = []
    for option in option_lines:
        cleaned.append(option.replace("- ", ""))
    return cleaned

In [None]:
def remove_punctuation(s):
    """Return ``s`` with every ``string.punctuation`` character replaced by a space."""
    pieces = []
    for char in s:
        pieces.append(' ' if char in string.punctuation else char)
    return ''.join(pieces)

In [None]:
def match_response(response, choices_list):
    """Return the choices that occur in ``response``, in ``choices_list`` order.

    The response is compared after punctuation removal, lower-casing and turning
    newlines into spaces. A single-word choice must equal one of the
    space-separated tokens; a multi-word choice only has to occur as a substring.

    Args:
        response: Model response text.
        choices_list: Candidate answer strings.

    Returns:
        List of the matching choices (possibly empty, possibly several).
    """
    def _occurs(choice):
        is_single_word = len(choice.split(" ")) == 1
        needle = choice.lower()
        haystack = remove_punctuation(response).lower().replace("\n", " ")
        # Token membership for one word, substring search otherwise.
        return needle in (haystack.split(" ") if is_single_word else haystack)

    return [candidate for candidate in choices_list if _occurs(candidate)]

In [None]:
def refine_and_check(response, choices_list, question_list, question_id, info_header, level, bumpy_question_threshold):
    """Map a raw LLM response onto exactly one allowed choice, retrying once if needed.

    Steps:
      1. Empty / symbol-only responses, responses containing a "no information"
         phrase, and responses containing Chinese, Japanese or Korean characters
         are mapped to "Not mentioned".
      2. Otherwise the response is matched against ``choices_list``; exactly one
         match is returned as the answer.
      3. With zero or several matches at ``level`` 0, a secondary prompt built
         from the response is sent to the LLM and the new response is refined
         again at ``level + 1``. At any other level the response is returned
         prefixed with "Invalid response: ".

    Args:
        response: Raw response text.
        choices_list: Allowed answer strings.
        question_list: Questions of the current run.
        question_id: 0-based index of the current question in ``question_list``.
        info_header: Prefix used in every log line.
        level: Recursion depth; 0 for the first attempt.
        bumpy_question_threshold: Number of seconds above which the retry is
            logged as a "bumpy question".

    Returns:
        The matched choice, "Not mentioned", or "Invalid response: <response>".

    NOTE(review): the retry step reads ``model`` and ``tokenizer`` from module
    globals; they are not parameters and are not defined in the visible part of
    this notebook — confirm they are set before this function is called.
    """
    print(f"[INFO] (level {level}) {info_header}: Response before refinement: {response}")
    # Check if response is empty, contains only whitespaces, or any of the specified phrases or contains no numbers or English characters
    no_numbers_or_english_pattern = re.compile('^[^0-9A-Za-z]*$')
    if not response.strip() or re.match(no_numbers_or_english_pattern, response) or any(substring in response.lower() for substring in ["not mentioned", "unknown", "not available", "unavailable", "not provided", "none", "not include", "does not provide any information", "cannot determine"]):
        # The value reported here is the result of the refinement, hence "after".
        print(f"[INFO] (level {level}) {info_header}: Refinement SUCCESS! Response after refinement: Not mentioned")
        return "Not mentioned"

    # Regular expression patterns to match Chinese, Japanese, and Korean characters
    chinese_pattern = re.compile('[\u4e00-\u9fff]+')
    japanese_pattern = re.compile('[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]+')
    korean_pattern = re.compile('[\uAC00-\uD7AF]+')

    # Search for Chinese, Japanese, or Korean characters in the string
    if re.search(chinese_pattern, response) or re.search(japanese_pattern, response) or re.search(korean_pattern, response):
        print(f"[INFO] (level {level}) {info_header}: Refinement SUCCESS! Response after refinement: Not mentioned")
        return "Not mentioned"

    if (len(response.split("\n")) > 1): # At least for openchat-3.5, most of its answers are on the first line
        match = match_response(response.split("\n")[0], choices_list)
        # Fall back to the whole response when the first line is not decisive.
        if (len(match) != 1):
            match = match_response(response, choices_list)
    else:
        match = match_response(response, choices_list)

    if (len(match) > 1):
        print(f"[INFO] (level {level}) {info_header}: Mutiple matches found! match = {match}")
    elif (len(match) == 1):
        print(f"[INFO] (level {level}) {info_header}: Single match found! match = {match}")
    else:
        print(f"[INFO] (level {level}) {info_header}: No match found!")

    if len(match) == 1:
        print(f"[INFO] (level {level}) {info_header}: Refinement SUCCESS! Response after refinement: {match[0]}")
        return match[0]
    else:
        if (level == 0):
            print(f"[INFO] (level {level}) {info_header}: Cannot resolve this invalid case using regular expression refinement. Retry with secondary prompt.")
            # Question 7 asked on its own uses a dedicated secondary prompt.
            if (question_id + 1 == 7 and len(question_list) == 1):
                secondary_prompt = build_Q7_secondary_prompt(response, question_list[0])
            else:
                secondary_prompt = build_default_secondary_prompt(response, question_list[question_id], question_id)
            print(f"[INFO] (level {level}) {info_header}: The secondary prompt is\n{secondary_prompt}")
            question_start_time = time.time()
            response = LLM_text(model, tokenizer, secondary_prompt)
            question_end_time = time.time()
            question_time = question_end_time - question_start_time
            print(f"[INFO] (level {level}) {info_header}: Processing this question took {question_time} seconds")
            # Compare the measured time against the caller's threshold.
            if (question_time > bumpy_question_threshold):
                print(f"[INFO] (level {level}) {info_header}: This is a bumpy question, which takes {question_time} seconds to finish (Bumpy question threshold = {bumpy_question_threshold})")
            response = refine_and_check(response, choices_list, question_list, question_id, info_header, level + 1, bumpy_question_threshold)
            if ("Invalid response:" in response):
                print(f"[INFO] (level {level}) {info_header}: Refinement FAIL! Response after refinement: {response}")
            else:
                print(f"[INFO] (level {level}) {info_header}: Refinement SUCCESS! Response after refinement: {response}")
            return response
        else:
            response = "Invalid response: " + response
            print(f"[INFO] (level {level}) {info_header}: Refinement FAIL! Response after refinement: {response}")
            return response

In [None]:
def _build_question_prompt(refined_text, question, question_number, uses_few_shot, few_shot_prompt_path, log_header):
    """Build the prompt for one question and print it to the log.

    Args:
        refined_text: Patient note text the prompt is built from.
        question: Full question text (question line followed by its choices).
        question_number: 1-based question number; 6 and 7 use dedicated prompt builders.
        uses_few_shot: True when this question was configured for few-shot prompting.
        few_shot_prompt_path: Path handed to build_few_shot_learning_prompt (only read
            when uses_few_shot is True).
        log_header: Prefix used for the "[INFO] ..." log line.

    Returns:
        The prompt produced by the selected builder.
    """
    if uses_few_shot:
        prompt = build_few_shot_learning_prompt(refined_text, question, few_shot_prompt_path)
        print(f"[INFO] {log_header}: The few shot learning prompt is\n{prompt}")
        return prompt
    if question_number == 6:
        prompt = build_Q6_prompt(refined_text, question)
    elif question_number == 7:
        prompt = build_Q7_prompt(refined_text, question)
    else:
        prompt = build_default_prompt(refined_text, question)
    print(f"[INFO] {log_header}: The default prompt is\n{prompt}")
    return prompt


def LLM_pipeline(model_name, question_text_file_path, COT_question_text_file_path, SDoH_file_path, first_n, specific_patients, few_shot_learning_question_list, few_shot_learning_prompt_path_list, specific_COT_questions_list, specific_questions_list, bumpy_question_threshold, test_set):
    """Run every selected question on every selected patient note and write the results to CSV.

    While the pipeline runs, everything printed goes to "log (<model_name>).txt"
    (or "log (<model_name>) (Test).txt" when test_set is truthy), opened in append
    mode. The stream that was active before the call is always restored and the
    log file is always closed, even if the pipeline raises.

    For any model other than "baseline" the function reads the module-level
    globals `model` and `tokenizer`; the caller must set them beforehand.

    Args:
        model_name: Name used in log/CSV file names; "baseline" selects the
            non-LLM `baseline` extractor instead of `LLM_text`.
        question_text_file_path: Passed to generate_question_list for the questions.
        COT_question_text_file_path: Passed to generate_question_list for the COT questions.
        SDoH_file_path: CSV read with pandas; the columns PATIENT_NUM,
            OBSERVATION_BLOB and UPDATE_DATE are used.
        first_n: If not None, keep only the first `first_n` rows (when that many exist).
        specific_patients: If a list, keep only rows whose PATIENT_NUM is in it.
        few_shot_learning_question_list: 1-based question numbers that use few-shot
            prompts, or None.
        few_shot_learning_prompt_path_list: Prompt paths, parallel to
            few_shot_learning_question_list, or None.
        specific_COT_questions_list: 1-based question numbers that first run the COT
            question, or None.
        specific_questions_list: 1-based question numbers to run, or None for all.
        bumpy_question_threshold: Seconds above which a question is logged as "bumpy".
        test_set: Truthy adds the " (Test)" suffix to every output file name.

    Raises:
        ValueError: When the few-shot arguments are inconsistent, when
            specific_questions_list is not a list, or when neither first_n nor
            specific_patients is given.
    """
    file_suffix = " (Test)" if test_set else ""
    # Remember the active stream: in a notebook sys.stdout is not sys.__stdout__,
    # so restoring sys.__stdout__ would send later cell output to the terminal.
    previous_stdout = sys.stdout
    log_file = open(f"log ({model_name}){file_suffix}.txt", 'a')
    sys.stdout = log_file
    try:
        print("****************************** LOG START ******************************")
        print(f"[INFO]: Pipeline input parameters:\nmodel_name = {model_name},\nquestion_text_file_path = {question_text_file_path},\nCOT_question_text_file_path = {COT_question_text_file_path},\nSDoH_file_path = {SDoH_file_path},\nfirst_n = {first_n},\nspecific_patients = {specific_patients},\nfew_shot_learning_question_list = {few_shot_learning_question_list},\nfew_shot_learning_prompt_path_list = {few_shot_learning_prompt_path_list},\nspecific_COT_questions_list = {specific_COT_questions_list},\nspecific_questions_list = {specific_questions_list},\nbumpy_question_threshold = {bumpy_question_threshold},\ntest_set = {test_set}")

        # ---- Argument validation ----
        if (few_shot_learning_question_list is None) != (few_shot_learning_prompt_path_list is None):
            raise ValueError("[ERROR]: Few_shot_learning_question_list and few_shot_learning_prompt_path_list must be both None or both not None")

        if (few_shot_learning_question_list is not None and few_shot_learning_prompt_path_list is not None):
            if (type(few_shot_learning_question_list) != list) or (type(few_shot_learning_prompt_path_list) != list):
                raise ValueError("[ERROR]: Few_shot_learning_questions and few_shot_learning_prompt_path_list must be lists")

            if len(few_shot_learning_question_list) != len(few_shot_learning_prompt_path_list):
                raise ValueError("[ERROR]: Few_shot_learning_questions and few_shot_learning_prompt_path_list must have the same length")

        if specific_questions_list is not None and type(specific_questions_list) != list:
            raise ValueError("[ERROR]: Specific_questions_list must be a list")

        if (first_n is None and specific_patients is None):
            raise ValueError("[WARNING]: The current setup will run ALL patients in the data set, which will take a significant amount of time for this program to finish. To reduce the number of patients, please set 'first_n' or 'specific_patients'")

        # The baseline extractor supports neither COT nor few-shot prompting.
        if (model_name == "baseline" and specific_COT_questions_list is not None):
            specific_COT_questions_list = None
            print("[WARNING]: The baseline does not support COT questions. Only LLMs have COT questions. Changing specific_COT_questions_list to None.")

        if (model_name == "baseline" and few_shot_learning_question_list is not None):
            few_shot_learning_question_list = None
            print("[WARNING]: The baseline does not support few-shot learning questions. Only LLMs have few-shot learning questions. Changing few_shot_learning_question_list to None.")

        # The model itself is not loaded here; `model` / `tokenizer` are module globals.
        print(f"[INFO]: Loading model {model_name}")

        # ---- Patient selection ----
        # The "result_df" is for autograder, the "time_df" is for computation resource calculation
        result_df = pd.DataFrame()
        time_df = pd.DataFrame()
        data = pd.read_csv(SDoH_file_path)
        if (specific_patients is not None and type(specific_patients) == list):
            data = data[data["PATIENT_NUM"].isin(specific_patients)]
            print(f"[INFO]: Found {len(data)} out of {len(specific_patients)} patients")
        if first_n is not None:
            if (first_n <= len(data)):
                data = data.head(first_n)
                print(f"[INFO] Processing the first {first_n} patients")
            else:
                print(f"[WARNING] Try to process the first {first_n} patients, but there are only {len(data)} patients in the current data set. Will process all {len(data)} patients instead.")

        MRN_list = data["PATIENT_NUM"].tolist()
        SDoH_text_list = data["OBSERVATION_BLOB"].tolist()
        update_date_list = data["UPDATE_DATE"].tolist()
        num_patients = len(data)

        # ---- Question setup ----
        question_list = generate_question_list(question_text_file_path)
        COT_question_list = generate_question_list(COT_question_text_file_path)
        question_display_list_detailed = [f"Q{i + 1}. {question}" for i, question in enumerate(question_list)] # Question text including the choices
        simplified_question_list = [question.split('\n')[0] for question in question_list]
        COT_simplified_question_list = [question.split('\n')[0] for question in COT_question_list]
        question_display_list = [f"Q{i + 1}. {question}" for i, question in enumerate(simplified_question_list)] # Question line only
        COT_question_display_list = [f"Q{i + 1}. {question}" for i, question in enumerate(COT_simplified_question_list)]

        # Answers are collected in question-file order, so the row labels must be in
        # that order too (and without duplicates) or they would not line up with them.
        if specific_questions_list is not None:
            selected_question_numbers = sorted(set(specific_questions_list))
            row_labels = [question_display_list_detailed[n - 1] for n in selected_question_numbers]
        else:
            row_labels = list(question_display_list_detailed)
        num_questions = len(row_labels)
        result_df[model_name] = ["Text"] + row_labels
        time_df[model_name] = row_labels

        # ---- Main loop: patients x questions ----
        for i, (MRN, text) in enumerate(zip(MRN_list, SDoH_text_list)):
            patient_start_time = time.time()
            print(f"[INFO]: Processing patient {i + 1} out of {num_patients} (MRN: {MRN})")
            responses = []
            time_usage = []
            counter = 0
            for j, (question, COT_question) in enumerate(zip(question_list, COT_question_list)):
                question_number = j + 1
                if specific_questions_list is not None and question_number not in specific_questions_list:
                    continue

                if (model_name != "baseline" and j in [1, 6]):
                    refined_text = text.split("  Socioeconomic")[0] # The part after "  Socioeconomic  Occupational" will confuse the LLM on questions 2 and 7
                    if (j == 6 and "  Socioeconomic" in text):
                        refined_text += text.split("  Socioeconomic", 1)[-1].replace("Occupational", "").replace("Occupation", "")
                else:
                    refined_text = text

                uses_few_shot = few_shot_learning_question_list is not None and question_number in few_shot_learning_question_list
                few_shot_prompt_path = few_shot_learning_prompt_path_list[few_shot_learning_question_list.index(question_number)] if uses_few_shot else None
                prompt_kind = 'Few shot learning' if uses_few_shot else 'Default'

                question_start_time = time.time()
                print(f"[INFO]: Processing question {counter + 1} out of {num_questions} (Prompt: {prompt_kind}) (Question: {simplified_question_list[j]})")
                info_header = f"(Patient {i + 1} (MRN: {MRN}) Question {counter + 1} (Question: {question_display_list[j]}) (Prompt: {prompt_kind}))"
                COT_info_header = f"(Patient {i + 1} (MRN: {MRN}) Question {counter + 1} (Question: {COT_question_display_list[j]}) (Prompt: Default))"

                if (specific_COT_questions_list is not None and question_number in specific_COT_questions_list):
                    # Ask the COT screening question first; a "No" short-circuits the real question.
                    print(f"[INFO]: Question {counter + 1} is a COT question. Running the COT question.")
                    COT_start_time = time.time()
                    COT_prompt = build_default_prompt(refined_text, COT_question)
                    print(f"[INFO] {COT_info_header}: The COT prompt is\n{COT_prompt}")
                    COT_response = LLM_text(model, tokenizer, COT_prompt)
                    COT_choices_list = generate_choices(COT_question)
                    COT_refined_response = refine_and_check(COT_response, COT_choices_list, COT_question_list, j, COT_info_header, 0, bumpy_question_threshold)
                    print(f"[INFO] {COT_info_header}: The refined model response for the COT question is {COT_refined_response}")
                    COT_time = time.time() - COT_start_time
                    print(f"[INFO] {COT_info_header}: Running the COT question took {COT_time} seconds")
                    if (COT_time > bumpy_question_threshold):
                        print(f"[INFO] {COT_info_header}: This is a bumpy question, which takes {COT_time} seconds to finish (Bumpy question threshold = {bumpy_question_threshold})")
                    if (COT_refined_response == "No"):
                        print(f"[INFO] {COT_info_header}: Answer determined from the COT question to be 'Not mentioned'. No need to run the actual question")
                        refined_response = "Not mentioned"
                    else:
                        print(f"[INFO] {COT_info_header}: Answer cannot be determined from the COT question. Running the actual question")
                        prompt = _build_question_prompt(refined_text, question, question_number, uses_few_shot, few_shot_prompt_path, COT_info_header)
                        response = LLM_text(model, tokenizer, prompt)
                        choices_list = generate_choices(question)
                        refined_response = refine_and_check(response, choices_list, question_list, j, info_header, 0, bumpy_question_threshold)
                elif (model_name != "baseline"):
                    prompt = _build_question_prompt(refined_text, question, question_number, uses_few_shot, few_shot_prompt_path, info_header)
                    response = LLM_text(model, tokenizer, prompt)
                    choices_list = generate_choices(question)
                    refined_response = refine_and_check(response, choices_list, question_list, j, info_header, 0, bumpy_question_threshold)
                else:
                    refined_response = baseline(refined_text, question_number)

                if (model_name != "baseline"):
                    print(f"[INFO] {info_header}: The refined model response is {refined_response}")
                else:
                    print(f"[INFO] (Baseline): The baseline extraction result is {refined_response}")

                # Q7 fallback: when the first pass says "Not mentioned", ask a yes/no
                # question about occupation and map "Yes" to "Employed".
                if (model_name != "baseline" and question_number == 7 and refined_response == "Not mentioned"):
                    print(f"[INFO]: Running the second prompt on Q7.")
                    second_prompt = build_second_Q7_prompt(refined_text)
                    print(f"[INFO] {info_header}: The second Q7 prompt is\n{second_prompt}")
                    second_response = LLM_text(model, tokenizer, second_prompt)
                    second_choices_list = ["Yes", "No"]
                    # Continuation lines keep the 20-space indent of the original literal.
                    second_question_list = [
                        "Does the text mention the patient's occupation?\n"
                        "                    Your answer must be one of the following:\n"
                        "                    - Yes\n"
                        "                    - No"
                    ]
                    second_refined_response = refine_and_check(second_response, second_choices_list, second_question_list, j, info_header, 0, bumpy_question_threshold)
                    if (second_refined_response == "Yes"):
                        refined_response = "Employed"

                responses.append(refined_response)
                question_time = time.time() - question_start_time
                print(f"[INFO]: Processing question {counter + 1} took {question_time} seconds")
                time_usage.append(question_time)
                if (question_time > bumpy_question_threshold):
                    print(f"[INFO] {info_header}: This is a bumpy question, which takes {question_time} seconds to finish (Bumpy question threshold = {bumpy_question_threshold})")
                counter += 1
            result_df[MRN] = [text] + responses
            time_df[MRN] = time_usage
            patient_end_time = time.time()
            print(f"[INFO]: Processing patient {i + 1} took {patient_end_time - patient_start_time} seconds")

        # ---- Outputs ----
        result_df.to_csv(f"SDoH extraction performance ({model_name}){file_suffix}.csv", index = False)
        # The "research_df" is for downstream tasks (count number of cases in each SDoH):
        # one row per patient, one column per question, built without mutating result_df.
        research_df = result_df.set_index(model_name).T.reset_index()
        research_df.insert(2, "Date", update_date_list)
        research_df.to_csv(f"SDoH extraction result ({model_name}){file_suffix}.csv", index = False)
        time_df.to_csv(f"SDoH extraction time ({model_name}){file_suffix}.csv", index = False)
        print("****************************** LOG END ******************************\n")
    finally:
        # Always give stdout back and release the log file handle.
        sys.stdout = previous_stdout
        log_file.close()

In [None]:
# Supported model names (All non-quantized models can ONLY be run using CPUs. All quantized models MUST be run using GPUs):
# Unquantized models (FP32)
# openchat_3.5
# zephyr-7b-beta
# vicuna-7b-v1.5
# Llama-2-7b-chat-hf
# vicuna-13b-v1.5
# WizardLM-13B-V1.2
# Llama-2-13b-chat-hf
# vicuna-33b-v1.3
# WizardLM-70B-V1.0
# Llama-2-70b-chat-hf

# Quantized models (INT4, GEMV, AWQ, 4-bit, group size 128)
# NOTE(review): the two 70B entries below carry a "-w3" suffix although this
# header says 4-bit -- confirm which is correct.
# openchat_3.5-w4-g128
# zephyr-7b-beta-w4-g128
# vicuna-7b-v1.5-w4-g128
# Llama-2-7b-chat-hf-w4-g128
# vicuna-13b-v1.5-w4-g128
# WizardLM-13B-V1.2-w4-g128
# Llama-2-13b-chat-hf-w4-g128
# vicuna-33b-v1.3-w4-g128
# WizardLM-70B-V1.0-w3-g128
# Llama-2-70b-chat-hf-w3-g128

# Baseline model (Regular expression pattern matching)
# baseline

# ---- Run configuration ----
question_text_file_path = "./LLM Questions.txt"
COT_question_text_file_path = "./LLM Questions (COT).txt"
SDoH_file_path = "./observation_fact_notes (Social Documentation) (Unique).csv"
first_n = None
data = pd.read_csv(SDoH_file_path)
test_set = True
# Rows 100-199 are used when test_set is True, rows 0-99 otherwise.
patient_rows = slice(100, 200) if test_set else slice(0, 100)
specific_patients = data["PATIENT_NUM"].iloc[patient_rows].tolist()
few_shot_learning_question_list = None
few_shot_learning_prompt_path_list = None
specific_COT_questions_list = [2]
specific_questions_list = None
bumpy_question_threshold = 500
model_name_list = ["Llama-2-70b-chat-hf"]

# ---- Run the pipeline once per model ----
for model_name in model_name_list:
    if model_name == "baseline":
        # The baseline needs no model, so there is nothing to load or release.
        LLM_pipeline(model_name, question_text_file_path, COT_question_text_file_path, SDoH_file_path, first_n, specific_patients, few_shot_learning_question_list, few_shot_learning_prompt_path_list, specific_COT_questions_list, specific_questions_list, bumpy_question_threshold, test_set)
        continue

    # LLM_pipeline reads `model` and `tokenizer` from the notebook's global namespace.
    model, tokenizer = get_model(model_name)
    # Quantized checkpoints are named "<model>-w<bits>-g<group>"; matching the whole
    # family (not only "-w4-g128") also sends the "-w3-g128" models to the GPU.
    is_quantized = re.search(r"-w\d+-g\d+", model_name) is not None
    device = torch.device("cuda" if is_quantized else "cpu")
    model.to(device)
    try:
        LLM_pipeline(model_name, question_text_file_path, COT_question_text_file_path, SDoH_file_path, first_n, specific_patients, few_shot_learning_question_list, few_shot_learning_prompt_path_list, specific_COT_questions_list, specific_questions_list, bumpy_question_threshold, test_set)
    finally:
        # Unload the model from GPU and RAM, even when the pipeline failed
        del model          # Delete the model
        gc.collect()       # Collect garbage
        # Test the device's `type` field rather than comparing the torch.device
        # object itself against a string.
        if device.type == "cuda":
            torch.cuda.empty_cache()  # Clear CUDA cache

In [None]:
# Preview the performance CSV written by the most recent pipeline run.
# The file name must follow the same " (Test)" suffix rule LLM_pipeline uses,
# otherwise a non-test run would try to open a file that was never written.
check_suffix = " (Test)" if test_set else ""
check = pd.read_csv(f"SDoH extraction performance ({model_name}){check_suffix}.csv")
check