In [8]:
import os
import json

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Path to the vectorstore relative to this notebook
FAISS_DB_PATH = os.path.abspath(os.path.join("..", "her2_faiss_db"))

# Load FAISS vectorstore
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.load_local(FAISS_DB_PATH, embedding_model, allow_dangerous_deserialization=True)

# Load Flan-T5 model
model_id = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
llm_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)


Device set to use cuda:0


In [9]:
def build_prompt(context: str, question: str) -> str:
    return f"""You are a biomedical research assistant. Read the context and answer the question in a detailed, informative way suitable for a graduate-level researcher.

Context:
{context}

Question:
{question}

Answer:"""

def get_answer(query: str) -> str:
    docs = vectorstore.similarity_search(query, k=10)
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = build_prompt(context, query)
    result = llm_pipeline(prompt, max_new_tokens=512, temperature=0.2)
    return result[0]["generated_text"]


In [10]:
with open("../data/qa/her2_qa_dataset.json") as f:
    qa_dataset = json.load(f)


In [11]:
chatbot_predictions = {}

for qa in qa_dataset:
    qid = qa["id"]
    question = qa["question"]
    print(f"Answering: {question}")
    answer = get_answer(question)
    chatbot_predictions[qid] = answer

with open("../data/qa/her2_predictions.json", "w") as f:
    json.dump(chatbot_predictions, f, indent=2)


Answering: What gene is associated with poor prognosis in human breast cancer?


Token indices sequence length is longer than the specified maximum sequence length for this model (1221 > 512). Running this sequence through the model will result in indexing errors


Answering: In how many of the 103 tumors was HER-2/neu gene amplification found in the initial study?
Answering: Does HER-2/neu amplification correlate with hormone receptor status?
Answering: How does HER-2/neu amplification affect relapse and survival rates?
Answering: What statistical test was used to compare survival curves?
Answering: How many patients had HER-2/neu amplification in the second group of 86 node-positive samples?
Answering: Is HER-2/neu amplification an independent prognostic factor?
Answering: What was the median follow-up time in the node-positive patient study?
Answering: What methods were used to detect HER-2/neu amplification?
Answering: Which other gene was compared with HER-2/neu for amplification?


In [12]:
with open("../data/qa/her2_predictions.json", "w") as f:
    json.dump(chatbot_predictions, f, indent=2)


In [None]:
import pandas as pd
import re
import string
from IPython.display import display, HTML

# --- F1-only evaluation functions ---
def normalize_answer(s: str) -> str:
    def remove_articles(text): return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text): return ' '.join(text.split())
    def remove_punc(text): return ''.join(ch for ch in text if ch not in set(string.punctuation))
    def lower(text): return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_f1(a_gold: str, a_pred: str) -> float:
    gold_tokens = normalize_answer(a_gold).split()
    pred_tokens = normalize_answer(a_pred).split()
    common = set(gold_tokens) & set(pred_tokens)
    if not common: return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)
    return 2 * (precision * recall) / (precision + recall)

# --- Build evaluation table ---
records = []

for qa in qa_dataset:
    qid = qa["id"]
    question = qa["question"]
    gold = qa["answer"]
    pred = chatbot_predictions.get(qid, "")
    f1 = compute_f1(gold, pred)
    flag = "⚠️ Weak" if f1 < 0.5 else ""
    
    records.append({
        "Question": question,
        "Gold Answer": gold,
        "Predicted Answer": pred,
        "F1 Score": round(f1, 2),
        "Flag": flag
    })

df_eval = pd.DataFrame(records)

# --- Display ---
display(HTML("<h3>F1 Score per Question (⚠️ = Flagged as Weak)</h3>"))
display(df_eval)

# --- Also return average F1 score ---
avg_f1 = df_eval["F1 Score"].mean()
{"Average F1 Score (%)": round(avg_f1 * 100, 2)}


Unnamed: 0,Question,Gold Answer,Predicted Answer,F1 Score,Flag
0,What gene is associated with poor prognosis in...,HER-2/neu gene amplification is associated wit...,HER-2/neu,0.15,⚠️ Weak
1,In how many of the 103 tumors was HER-2/neu ge...,In 19 out of 103 tumors (18%).,34/86,0.0,⚠️ Weak
2,Does HER-2/neu amplification correlate with ho...,No significant correlation was found between H...,The presence of gene amplification was correla...,0.39,⚠️ Weak
3,How does HER-2/neu amplification affect relaps...,Amplification significantly correlates with sh...,While there was a somewhat shortened time to r...,0.37,⚠️ Weak
4,What statistical test was used to compare surv...,The log-rank test was used to compare Kaplan-M...,log rank test,0.17,⚠️ Weak
5,How many patients had HER-2/neu amplification ...,34 out of 86 node-positive patients had HER-2/...,34/86,0.0,⚠️ Weak
6,Is HER-2/neu amplification an independent prog...,"Yes, multivariate analysis showed HER-2/neu am...",Amplification of the HER-2/neu gene is a signi...,0.13,⚠️ Weak
7,What was the median follow-up time in the node...,The median follow-up time was 46 months.,47 months,0.25,⚠️ Weak
8,What methods were used to detect HER-2/neu amp...,Southern blot analysis with a 32P-labeled HER-...,x2 test. P values werc computed after combinin...,0.09,⚠️ Weak
9,Which other gene was compared with HER-2/neu f...,"The EGFR gene was compared, and found to be am...",N-myc,0.0,⚠️ Weak


{'Average F1 Score (%)': np.float64(15.5)}