# Debating medAlpaca

## Clone repo

In [1]:
%%capture
!git clone https://github.com/kbressem/medAlpaca.git
%cd medAlpaca
!pip install -r requirements.txt

# The latest package versions (installed from requirements) breaks things.
# Install package versions that were available when the repo was developed.
!pip install peft==0.2.0
!pip install transformers==4.28.0

In [2]:
# Clear GPU memory
import torch
import gc
gc.collect()
torch.cuda.empty_cache()
# del model
!nvidia-smi

Wed Jun 14 09:48:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P8    12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

If GPU memory is still not released, force-kill the Python processes holding it on the GPU.

# Debating helper functions

In [12]:
import re

def construct_message(agents, question):
    """Build the follow-up prompt that shows one agent its peers' latest answers.

    Args:
        agents: contexts of the *other* agents; only the last entry of each
            (its most recent turn) is used, with the question text removed.
        question: the original question, stripped out of each peer reply.

    Returns:
        The prompt string. With no peers, a self-verification prompt is
        returned instead.
    """
    # Without peers there is nothing to compare against: ask for introspection.
    if len(agents) == 0:
        return "Can you verify that your answer is correct. Please reiterate your answer."

    peer_sections = "".join(
        "\n\n Another expert's response: ```{}```".format(remove_question(peer[-1], question))
        for peer in agents
    )
    return (
        "These are the recent/updated opinions from other experts: "
        + peer_sections
        + "\n\n Use these opinions carefully as additional advice, can you provide an updated answer?"
    )

def parse_answer(sentence):
    """Return the option letter from the first token shaped like ``(X)``.

    Tokens are whitespace-separated, and a token only needs to *start* with
    the parenthesised capital, so ``"(B)"`` and ``"(B)."`` both yield ``"B"``.

    Returns:
        The letter, or ``None`` if no token matches.
    """
    # split() rather than split(" "): tokens that follow a newline or tab
    # (e.g. the "\n\n(A)" option layout used in the questions) must still be
    # seen as separate tokens.
    for token in sentence.split():
        if re.match(r"\([A-Z]\)", token):
            return token[1]
    return None

def remove_question(string, question):
    """Strip occurrences of ``question`` from ``string``.

    The number of removals equals the number of *overlapping* occurrences of
    ``question`` in the original ``string``; each removal deletes the first
    remaining occurrence.
    """
    # The zero-width lookahead lets findall report overlapping hits, and every
    # captured group is exactly `question`, so only the count matters.
    overlapping_hits = len(re.findall(f'(?=({re.escape(question)}))', string))
    for _ in range(overlapping_hits):
        string = string.replace(question, '', 1)
    return string


def most_frequent(List):
    """Return the element that occurs most often in ``List``.

    Ties go to the element that appears first. An empty sequence raises
    ``IndexError``.
    """
    # Index first so an empty input still fails with IndexError
    # (max() on its own would raise ValueError instead).
    List[0]
    # max() returns the first maximal item, which gives the first-seen tie-break.
    return max(List, key=List.count)

### Eval helper functions

In [18]:
import string

def strip_special_chars(input_str):
    """Trim leading and trailing characters that are not ASCII letters or digits.

    Falsy input (``""`` or ``None``) is returned unchanged. If the input has
    no ASCII letter or digit at all, ``""`` is returned.
    """
    if not input_str:
        return input_str

    alphanumerics = string.ascii_letters + string.digits
    kept_positions = [pos for pos, char in enumerate(input_str) if char in alphanumerics]
    if not kept_positions:
        return ""
    # Slice from the first to the last alphanumeric position, inclusive.
    return input_str[kept_positions[0]:kept_positions[-1] + 1]

def starts_with_capital_letter(input_str):
    """
    Return True if the answer leads with an uppercase option letter.

    The answers should start like this (X is an uppercase letter):
        'X: '
        'X. '
        'X) '   the "(X) option text" layout once the leading "(" is stripped
        'X '
        'X'     a bare letter
    """
    pattern = r'^[A-Z][:.)]?(?: .+|$)'
    return bool(re.match(pattern, input_str))


def extract_letter_from_answer(input_str):
    """Extracts letter from answer.

    Args:
        input_str : answer string - answers should start as mentioned in starts_with_capital_letter.

    Returns:
        letter or "-1".
    """
    # The letter must be followed by whitespace, ':', '.', ')' or end of string.
    # The ')' and end-of-string cases cover "C) text" and a bare "C".
    pattern = r'^([A-Za-z])(?=[\s:.)]|$)'
    match = re.search(pattern, input_str)
    return match.group(1) if match else "-1"

### Debating script

In [26]:
def debate(model, question, num_agents, num_rounds, num_tries=5, verbose=False,
           sampling_kwargs=None):
    """Run a multi-agent debate on one multiple-choice question.

    In round 0 every agent answers ``question`` independently. In each later
    round every agent is shown the other agents' answers from the previous
    round (via ``construct_message``) and asked for an updated answer.

    Args:
        model: callable invoked as ``model(instruction=..., input=..., output=...,
            **sampling_kwargs)``; must return the generated text.
        question: question text including the answer options.
        num_agents: number of debating agents.
        num_rounds: number of rounds, round 0 included.
        num_tries: how many times to re-sample a reply until it leads with an
            option letter (``starts_with_capital_letter``); the last reply is
            kept if none does.
        verbose: print every agent's prompt and reply.
        sampling_kwargs: generation kwargs forwarded to ``model``. Defaults to
            the notebook-level ``sampling`` dict.

    Returns:
        ``(agent_contexts, agent_answers)``. ``agent_contexts[i]`` holds the
        prompt agent i was given in each round, followed by the question plus
        its last reply. ``agent_answers["Agent_i"]["Round_r"]`` is the letter
        parsed from that round's reply, or ``"-1"`` if none could be parsed.
    """
    if sampling_kwargs is None:
        sampling_kwargs = sampling  # notebook-level generation settings

    agent_contexts = [[question] for _ in range(num_agents)]
    agent_answers = {f"Agent_{agent}": {} for agent in range(num_agents)}

    for round_idx in range(num_rounds):
        if verbose:
            print("#######################")
            print(f"DEBATING ROUND {round_idx}")
            print("#######################")
            print("")

        # Freeze each agent's latest reply before anyone answers this round, so
        # every agent debates against the *previous* round. Reading
        # agent_contexts[k][-1] live would show agent i the current-round
        # replies of agents 0..i-1 but the previous-round replies of the rest.
        previous_replies = [[context[-1]] for context in agent_contexts]

        for i, agent_context in enumerate(agent_contexts):
            if round_idx != 0:
                other_replies = previous_replies[:i] + previous_replies[i + 1:]
                message = construct_message(other_replies, question)
                agent_context[-1] += "\n\n" + message

            if verbose:
                print(f"---- AGENT {i} CONTEXT ----")
                print(agent_context[-1])
                print("---------------------------")
                print("")

            response = ""  # parsed as "-1" if num_tries <= 0
            for _ in range(num_tries):
                response = model(
                    instruction="Answer this multiple choice question.",
                    input=agent_context[-1],
                    output="The Answer to the question is:",
                    **sampling_kwargs,
                )
                response = strip_special_chars(response)
                # Re-sample until the reply leads with an option letter.
                if starts_with_capital_letter(response):
                    break

            # only the letter e.g. A, B, C, D
            letter_answer = extract_letter_from_answer(response)
            agent_answers[f"Agent_{i}"][f"Round_{round_idx}"] = letter_answer

            if verbose:
                print(f"---- AGENT {i} PROPOSED ANSWER ----")
                print(response)
                print("-----------------------------------")
                print("")

            agent_context.append(agent_context[0] + "\n\n" + response)

    return agent_contexts, agent_answers

# Qualitative Debate test

In [6]:
from medalpaca.inferer import Inferer

# medAlpaca LoRA adapter (peft=True) applied on top of the LLaMA-7B base model,
# loaded in 8-bit. The prompt template path is relative to the cloned repo,
# which the setup cell `%cd`s into.
model = Inferer(
    base_model="decapoda-research/llama-7b-hf",
    model_name="medalpaca/medalpaca-lora-7b-8bit",
    peft=True,
    load_in_8bit=True,
    prompt_template="medalpaca/prompt_templates/medalpaca.json",
)
print("Model loaded.")


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


Downloading (…)lve/main/config.json:   0%|          | 0.00/427 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/25.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/33 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00002-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00003-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00004-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00005-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00006-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00007-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00008-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00009-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00010-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00011-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00012-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00013-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00014-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00015-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00016-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00017-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00018-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00019-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00020-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00021-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00022-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00023-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00024-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00025-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00026-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00027-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00028-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00029-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00030-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00031-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00032-of-00033.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (…)l-00033-of-00033.bin:   0%|          | 0.00/524M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)/adapter_config.json:   0%|          | 0.00/398 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/67.2M [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


Model loaded.


In [15]:
# Generation kwargs forwarded (as **sampling) to the model on every debate turn.
sampling = dict(
    do_sample=True,
    top_k=50,
    num_beams=1,
    max_new_tokens=128,
    # NOTE(review): per the transformers docs this only affects beam search,
    # so it is likely inert with num_beams=1 — confirm before relying on it.
    early_stopping=True,
    temperature=0.4,
    top_p=0.9,
)

In [16]:
# Question stem, then one " \n\n(X) ..." chunk per option, then the explicit ask.
question = (
    "Question: A 27-year-old woman comes to the office for counseling prior to "
    "conception. She states that a friend recently delivered a newborn with a neural tube "
    "defect and she wants to decrease her risk for having a child with this condition. She "
    "has no history of major medical illness and takes no medications. Physical examination"
    " shows no abnormalities. It is most appropriate to recommend that this patient begin "
    "supplementation with a vitamin that is a cofactor in which of the following processes?"
    " \n\n(A) Biosynthesis of nucleotides"
    " \n\n(B) Protein gamma glutamate carboxylation"
    " \n\n(C) Scavenging of free radicals"
    " \n\n(D) Transketolation"
    " \n\n(E) Triglyceride lipolysis"
    " \n\nWhat is the correct answer (A, B, C, D, or E)?"
)

print(question)

Question: A 27-year-old woman comes to the office for counseling prior to conception. She states that a friend recently delivered a newborn with a neural tube defect and she wants to decrease her risk for having a child with this condition. She has no history of major medical illness and takes no medications. Physical examination shows no abnormalities. It is most appropriate to recommend that this patient begin supplementation with a vitamin that is a cofactor in which of the following processes? 

(A) Biosynthesis of nucleotides 

(B) Protein gamma glutamate carboxylation 

(C) Scavenging of free radicals 

(D) Transketolation 

(E) Triglyceride lipolysis 

What is the correct answer (A, B, C, D, or E)?


### Simulate debate

A debate between 3 agents over 3 rounds.

For detailed output during the debate, pass `verbose=True` to the `debate(...)` call below.

In [27]:
%%time
agent_contexts, answers = debate(model, question, num_agents=3, num_rounds=3)

CPU times: user 6min 12s, sys: 805 ms, total: 6min 13s
Wall time: 6min 16s


In [28]:
# Letter each agent chose in each round; "-1" means no option letter could be parsed from its reply.
answers

{'Agent_0': {'Round_0': 'C', 'Round_1': 'C', 'Round_2': 'C'},
 'Agent_1': {'Round_0': 'A', 'Round_1': 'C', 'Round_2': 'C'},
 'Agent_2': {'Round_0': 'C', 'Round_1': 'C', 'Round_2': 'C'}}

Notes:

*   Debating is expensive. For a single question, time is on the order of minutes.
*   The models don't seem to be able to handle long context well.
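*   Results are very stochastic, and often incorrect. In the run above, all agents converged on (C), but the correct answer is (A): folic acid, the supplement that lowers neural tube defect risk, is a cofactor in nucleotide biosynthesis.
*   The agents often reach consensus, but consensus is not correctness: Agent_1 answered (A) in round 0 and was talked into (C).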

