## Data Preparation

In [1]:
import xmltodict
import json
import numpy as np
from scispacy.abbreviation import AbbreviationDetector
import spacy
import scispacy
import en_core_sci_sm

nlp = en_core_sci_sm.load()
nlp.add_pipe("abbreviation_detector")

def extract_clinical_terms(text):
    doc = nlp(text)
    return [ent.text for ent in doc.ents]


with open('./data/dev/archehr-qa_key.json', 'r', encoding='utf-8') as f:
    ee_labels = json.load(f)
with open('./data/dev/archehr-qa.xml', 'r', encoding='utf-8') as f:
    dev_data = xmltodict.parse(f.read())

processed_dev_data = []
for case_id in range(20):
    essential_sentences = []
    for i, ans in enumerate(ee_labels[case_id]["answers"]):
        if ans["relevance"] == "essential":
            essential_sentences.append(i)

    question_en = extract_clinical_terms(dev_data["annotations"]["case"][case_id]["clinician_question"]) #patient_narrative  #clinician_question
    question = dev_data["annotations"]["case"][case_id]["clinician_question"]
    context = dev_data["annotations"]["case"][case_id]["note_excerpt"]
    sentences = []
    sentences_en = []
    sentence_offsets = [case_id]
    sentences_ids = []
    for sentence in dev_data["annotations"]["case"][case_id]["note_excerpt_sentences"]["sentence"]:
        sentences_ids.append({
            "id": sentence["@id"],
            "text": sentence["#text"]
        })
        sentence_offsets.append(context.find(sentence["#text"], sentence_offsets[-1]))
        sentences_en.append(extract_clinical_terms(sentence["#text"]))
        sentences.append(sentence["#text"])
    processed_dev_data.append({
        "question": question,
        "question_ents": question_en,
        "question_full": question,
        "sentences_ents": sentences_en,
        "sentences_full": sentences,
        "sentences_ids": sentences_ids,
        "gold_relevant_ids": essential_sentences,
        "case_id": dev_data["annotations"]["case"][case_id]["@id"]
    })

  global_matches = self.global_matcher(doc)


# Experiments

## Essential Sentence Extraction

In [2]:
from sklearn.metrics import precision_recall_fscore_support
import numpy as np

def compute_micro_scores(y_true, y_pred):
    flat_true = [item for sublist in y_true for item in sublist]
    flat_pred = [item for sublist in y_pred for item in sublist]
    p, r, f1, _ = precision_recall_fscore_support(flat_true, flat_pred, average='binary')
    return p, r, f1

def compute_macro_scores(y_true, y_pred):
    ps, rs, f1s = [], [], []
    for true, pred in zip(y_true, y_pred):
        p, r, f1, _ = precision_recall_fscore_support(true, pred, average='binary')
        ps.append(p)
        rs.append(r)
        f1s.append(f1)
    return np.mean(ps), np.mean(rs), np.mean(f1s)
    

y_true = []
for ex in processed_dev_data:
    y_true.append([1 if i in ex["gold_relevant_ids"] else 0 for i in range(len(ex["sentences_full"]))])

### Max Cosine Entity Similarity Score

In [5]:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import itertools
import numpy as np
from sklearn.metrics import precision_recall_fscore_support



def get_max_cosine_score(question_ents, sentences_ents):
    scores = []
    for i, sentence_ents in enumerate(sentences_ents):
        q_scores = []
        for q_ent in question_ents:
            sen_scores = []
            for s_ent in sentence_ents:
                emb_q = get_embedding(q_ent, mces_model, mces_tokenizer)
                emb_s = get_embedding(s_ent, mces_model, mces_tokenizer)
                cos_sim = F.cosine_similarity(emb_q, emb_s, dim=0).item()
                sen_scores.append(cos_sim)
            q_scores.append(max(sen_scores) if sen_scores else 0.0)
        scores.append(np.mean(sorted(q_scores, reverse=True)[:1]) if q_scores else 0.0)
    return scores

    

def get_embedding(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state  # shape: (1, seq_len, hidden_size)
        attention_mask = inputs['attention_mask']

        # Mean pooling
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        masked_embeddings = token_embeddings * mask
        summed = masked_embeddings.sum(1)
        counts = mask.sum(1)
        mean_pooled = summed / counts
        return mean_pooled.squeeze()

#mces_model_path = "../../resq/saved_models/evidence_extraction"
#mces_model_path = "../../models/Bio_ClinicalBERT"
mces_model_path = "../../models/bert-base-multilingual-cased"
mces_tokenizer = AutoTokenizer.from_pretrained(mces_model_path)
mces_model = AutoModel.from_pretrained(mces_model_path)

# Collect scores
mces_scores = []
for i, ex in enumerate(processed_dev_data):
    mces_scores.append(get_max_cosine_score(ex["question_ents"], ex["sentences_ents"]))
    print(f"{i}. sample processed")

0. sample processed
1. sample processed
2. sample processed
3. sample processed
4. sample processed
5. sample processed
6. sample processed
7. sample processed
8. sample processed
9. sample processed
10. sample processed
11. sample processed
12. sample processed
13. sample processed
14. sample processed
15. sample processed
16. sample processed
17. sample processed
18. sample processed
19. sample processed


#### Threshold Experiment

In [7]:
# Do the threshold experiment, mBERT
for threshold in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]:
    y_pred = []
    for i, sample_scores in enumerate(mces_scores):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    print(f"\tThreshold {threshold}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

	Threshold 0.0:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.1:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.2:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.3:	 micro   F1=48.913 | P=32.609, R=97.826 		 macro   F1=48.551 | P=33.313, R=97.500
	Threshold 0.4:	 micro   F1=47.794 | P=32.020, R=94.203 		 macro   F1=48.026 | P=33.245, R=95.389
	Threshold 0.5:	 micro   F1=48.276 | P=33.521, R=86.232 		 macro   F1=49.539 | P=35.865, R=90.583
	Threshold 0.6:	 micro   F1=46.512 | P=36.145, R=65.217 		 macro   F1=47.558 | P=40.576, R=72.667
	Threshold 0.7:	 micro   F1=42.918 | P=52.632, R=36.232 		 macro   F1=40.815 | P=58.551, R=40.667
	Threshold 0.8:	 micro   F1=19.753 | P=66.667, R=11.594 		 macro   F1=20.689 | P=38.750, R=18.208
	Threshold 0.81:	 micro   F1=18.750 | P=68.182, R=10.870 		 macro   F1=19.578 | P=37.500, R=17.208
	Threshold 0.82:	 m

In [30]:
# Do the threshold experiment, ClinicalBERT
for threshold in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]:
    y_pred = []
    for i, sample_scores in enumerate(mces_scores):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    print(f"\tThreshold {threshold}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

	Threshold 0.0:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.1:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.2:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.3:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.4:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.5:	 micro   F1=48.825 | P=32.530, R=97.826 		 macro   F1=48.455 | P=33.188, R=97.500
	Threshold 0.6:	 micro   F1=48.913 | P=32.609, R=97.826 		 macro   F1=48.606 | P=33.299, R=97.500
	Threshold 0.7:	 micro   F1=48.995 | P=32.763, R=97.101 		 macro   F1=48.792 | P=33.489, R=96.944
	Threshold 0.8:	 micro   F1=50.602 | P=37.906, R=76.087 		 macro   F1=48.878 | P=37.926, R=77.847
	Threshold 0.81:	 micro   F1=50.510 | P=38.976, R=71.739 		 macro   F1=48.147 | P=38.706, R=71.931
	Threshold 0.82:	 m

In [4]:
mces_scores_cb = mces_scores
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(mces_scores_cb):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.4891304347826087, 0.48727272727272725, 0.489945155393053, 0.49224806201550386, 0.5060240963855421, 0.41353383458646614, 0.2717391304347826, 0.15286624203821655, 0.08163265306122448]
[0.32242990654205606, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.32608695652173914, 0.32524271844660196, 0.3276283618581907, 0.335978835978836, 0.37906137184115524, 0.4296875, 0.5434782608695652, 0.631578947368421, 0.6666666666666666]
[1.0, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782

In [6]:
mces_scores_mb = mces_scores
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(mces_scores_mb):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.488245931283906, 0.4891304347826087, 0.48363636363636364, 0.47794117647058826, 0.4714828897338403, 0.4827586206896552, 0.47702407002188185, 0.46511627906976744, 0.46357615894039733, 0.4291845493562232, 0.2857142857142857, 0.19753086419753085, 0.14193548387096774, 0.11920529801324503, 0.11920529801324503, 0.09523809523809523]
[0.32242990654205606, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.3253012048192771, 0.32608695652173914, 0.32281553398058255, 0.32019704433497537, 0.31958762886597936, 0.3352112676056338, 0.34169278996865204, 0.3614457831325301, 0.4268292682926829, 0.5263157894736842, 0.5909090909090909, 0.6666666666666666, 0.6470588235294118, 0.6923076923076923, 0.6923076923076923, 0.7777777777777778]
[1.0, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9782608695652174, 0.9637681159420289, 0

### MedCPT Cross Encoder Score

In [10]:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import itertools
import numpy as np

def get_medcpt_embedding(pairs): 
    with torch.no_grad():
        encoded = medcptce_tokenizer(
            pairs,
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=512,
        )
        logits = medcptce_model(**encoded).logits.squeeze(dim=1)
        probs = (logits - logits.min()) / (logits.max() - logits.min())
    return probs

def rank_sentences_by_similarity(question, sentences):
    pairs = [[question, sentence] for sentence in sentences]
    return get_medcpt_embedding(pairs)


medcptce_model_path = "../../models/MedCPT-Cross-Encoder"
medcptce_tokenizer = AutoTokenizer.from_pretrained(medcptce_model_path)
medcptce_model = AutoModelForSequenceClassification.from_pretrained(medcptce_model_path)

# Collect scores
medcptce_scores = []
for i, ex in enumerate(processed_dev_data):
    #medcptce_scores.append(rank_sentences_by_similarity(ex["question_full"], ex["sentences_full"]))
    medcptce_scores.append(rank_sentences_by_similarity(", ".join(ex["question_ents"]), [", ".join(sen_en) for sen_en in ex["sentences_ents"]]))
    print(f"{i}. sample processed")

0. sample processed
1. sample processed
2. sample processed
3. sample processed
4. sample processed
5. sample processed
6. sample processed
7. sample processed
8. sample processed
9. sample processed
10. sample processed
11. sample processed
12. sample processed
13. sample processed
14. sample processed
15. sample processed
16. sample processed
17. sample processed
18. sample processed
19. sample processed


#### Threshold Experiment

In [9]:
# Do the threshold experiment
for threshold in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]:
    y_pred = []
    for i, sample_scores in enumerate(medcptce_scores):
        preds = []
        for sc in sample_scores:
            if sc > threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    print(f"\tThreshold {threshold}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

	Threshold 0.0:	 micro   F1=49.446 | P=33.168, R=97.101 		 macro   F1=49.785 | P=34.477, R=97.708
	Threshold 0.1:	 micro   F1=49.794 | P=34.770, R=87.681 		 macro   F1=50.577 | P=37.197, R=89.569
	Threshold 0.2:	 micro   F1=50.360 | P=37.634, R=76.087 		 macro   F1=51.752 | P=41.167, R=80.083
	Threshold 0.3:	 micro   F1=50.556 | P=40.991, R=65.942 		 macro   F1=51.384 | P=45.102, R=70.236
	Threshold 0.4:	 micro   F1=51.466 | P=46.746, R=57.246 		 macro   F1=52.188 | P=52.217, R=60.764
	Threshold 0.5:	 micro   F1=51.163 | P=55.000, R=47.826 		 macro   F1=52.864 | P=64.458, R=53.750
	Threshold 0.55:	 micro   F1=50.407 | P=57.407, R=44.928 		 macro   F1=52.118 | P=65.578, R=51.889
	Threshold 0.6:	 micro   F1=46.018 | P=59.091, R=37.681 		 macro   F1=47.558 | P=67.313, R=43.083
	Threshold 0.65:	 micro   F1=40.187 | P=56.579, R=31.159 		 macro   F1=42.334 | P=63.750, R=37.278
	Threshold 0.7:	 micro   F1=36.181 | P=59.016, R=26.087 		 macro   F1=38.706 | P=67.500, R=30.653
	Threshold 0.75:	 

In [9]:
medcptce_scores_full = medcptce_scores
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(medcptce_scores_full):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.4934086629001883, 0.5058823529411764, 0.5112474437627812, 0.5086206896551724, 0.5162790697674419, 0.51010101010101, 0.5096952908587258, 0.49557522123893805, 0.46540880503144655, 0.4689655172413793, 0.45255474452554745, 0.42857142857142855, 0.4, 0.3490566037735849, 0.32323232323232326, 0.2751322751322751, 0.2247191011235955, 0.17647058823529413, 0.14545454545454545, 0.11392405063291139]
[0.32242990654205606, 0.3333333333333333, 0.3467741935483871, 0.3561253561253561, 0.3619631901840491, 0.3801369863013699, 0.39147286821705424, 0.4125560538116592, 0.417910447761194, 0.4111111111111111, 0.4473684210526316, 0.45588235294117646, 0.47368421052631576, 0.4845360824742268, 0.5, 0.5333333333333333, 0.5098039215686274, 0.5, 0.46875, 0.4444444444444444, 0.45]
[1.0, 0.9492753623188406, 0.9347826086956522, 0.9057971014492754, 0.855072463768116, 0.8043478260869565, 0.7318840579710145, 0.6666666666666666, 0.6086956521739131, 0.5362318840579711, 0.4927536231884058, 0.449275362318

In [11]:
medcptce_scores = medcptce_scores
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(medcptce_scores):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.48828125, 0.49794238683127573, 0.5067264573991032, 0.5035971223021583, 0.5103092783505154, 0.5055555555555555, 0.5180722891566265, 0.5146579804560261, 0.49640287769784175, 0.5116279069767442, 0.5040650406504065, 0.46017699115044247, 0.40186915887850466, 0.36180904522613067, 0.27956989247311825, 0.22598870056497175, 0.18823529411764706, 0.16666666666666666, 0.14545454545454545, 0.1518987341772152]
[0.32242990654205606, 0.3342245989304813, 0.34770114942528735, 0.36688311688311687, 0.3763440860215054, 0.396, 0.4099099099099099, 0.44329896907216493, 0.46745562130177515, 0.4928571428571429, 0.55, 0.5740740740740741, 0.5909090909090909, 0.5657894736842105, 0.5901639344262295, 0.5416666666666666, 0.5128205128205128, 0.5, 0.4666666666666667, 0.4444444444444444, 0.6]
[1.0, 0.9057971014492754, 0.8768115942028986, 0.8188405797101449, 0.7608695652173914, 0.717391304347826, 0.6594202898550725, 0.6231884057971014, 0.572463768115942, 0.5, 0.4782608695652174, 0.4492753623188406,

### Context Based Med42 Few Shot Scores

In [None]:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
import torch
import torch.nn.functional as F
import itertools
import numpy as np
import math



def cb_compute_response_score(question, sentence, context):
    few_shot = """You are a medical assistant helping a patient's family member understand the discharge summary. The family member asks a general question about the patient’s condition or expected recovery. From the discharge summary, you are evaluating whether a specific sentence is essential to help them understand what they truly need to know — even if they didn't ask about it directly.

For each example, decide:
- Is the sentence important for answering the underlying concern in the question? ("Yes" or "No")
- Briefly explain why or why not.

### Example 1
Context:
The patient was admitted with signs of dehydration and electrolyte imbalance following several days of vomiting and diarrhea. Intravenous fluids and potassium replacement were administered. He gradually regained strength and tolerated oral intake by day 3. There were no signs of infection. Electrolyte levels normalized. He was encouraged to maintain oral hydration and avoid NSAIDs. Discharge instructions included dietary recommendations. He is to follow up with his primary care physician in one week. The patient lives alone and has limited mobility. Transportation services were arranged for follow-up.

Patient’s Question: How long will it take for him to fully recover?
Sentence: "He is to follow up with his primary care physician in one week."
Answer: Yes
Reason: The scheduled follow-up provides insight into the expected timeline of recovery and monitoring, even though the patient didn't explicitly ask about appointments.

### Example 2
Context:
The patient presented with acute asthma exacerbation. She received nebulized albuterol and corticosteroids in the emergency department. Oxygen saturation improved over 24 hours. There were no signs of pneumonia. She was discharged with a prescription for inhaled corticosteroids and a tapering dose of prednisone. She was advised to avoid known triggers such as smoke or allergens. Patient reported improved breathing at rest but slight shortness of breath during activity. No further imaging was ordered. The pulmonologist will review her progress in 10 days.

Patient’s Question: Is she okay to go back to work next week?
Sentence: "The pulmonologist will review her progress in 10 days."
Answer: Yes
Reason: The timing of the specialist review is crucial for determining readiness to return to work, even though the patient didn't mention the appointment.

### Example 3
Context:
The patient was admitted for routine laparoscopic cholecystectomy. The surgery was uncomplicated. Minimal intraoperative bleeding was noted. Postoperative pain was managed with oral analgesics. Bowel function resumed within 24 hours. She ambulated independently on post-op day 2. The surgical wound was clean and dry. Discharge instructions advised avoiding heavy lifting for two weeks. Follow-up scheduled with surgery clinic in 14 days. Patient was in good spirits and eager to return to normal activities.

Patient’s Question: What should her recovery look like?
Sentence: "Discharge instructions advised avoiding heavy lifting for two weeks."
Answer: Yes
Reason: The lifting restriction is an essential part of understanding the expected recovery process, even if not directly requested.

### Example 4
"""
    # Build full prompt
    #prompt = few_shot.strip() + "\n\n" + f"Context: {sentence}\nQuestion: What information is essential from this context for answering the question \"{question}\"\nAnswer:"
    prompt = few_shot.strip() + "\n" + f"Context:\n{context}\n\nPatient’s Question: {question}\nSentence: \"{sentence}\"\nAnswer:"
    input_ids = med42_tokenizer(prompt, return_tensors="pt").input_ids.to(med42_model.device)
    with torch.no_grad():
        outputs = med42_model(input_ids)
        logits = outputs.logits[:, -1, :]  # only the next token

    # Tokenize input
    inputs = med42_tokenizer(prompt, return_tensors="pt").to(med42_model.device)
    
    # Generate model output
    with torch.no_grad():
        outputs = med42_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=med42_tokenizer.pad_token_id
        )
    
    # Decode generated answer
    generated_tokens = outputs.sequences[0][inputs['input_ids'].shape[-1]:]
    generated_answer = med42_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    
    # Compute log-probability score
    answer_ids = med42_tokenizer(" " + generated_answer, return_tensors="pt", add_special_tokens=False).input_ids.to(med42_model.device)
    logits = torch.stack(outputs.scores, dim=1)[0]  # shape: [tokens, vocab]
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-12)
    
    # Sum log-probs for each token
    score = sum(log_probs[i, token_id].item() for i, token_id in enumerate(answer_ids[0]))
    
    if generated_answer[:3] == "Yes":
        exp_score = math.exp(score)
        if math.isnan(exp_score):
            return 0.0
        return exp_score
    else:
        return 0.0


def cbmed42_predict(question, sentences):
    scores = []
    for i, sen in enumerate(sentences):
        scores.append(cb_compute_response_score(question, " ".join(sen.split()), " ".join("\n".join(sentences).split())))
    if sum(scores) == 0.0:
        return scores
    return list(np.array(scores)/sum(scores))

# Load model
med42_model_path = "../../models/Llama3-Med42-8B"
med42_tokenizer = AutoTokenizer.from_pretrained(med42_model_path)
med42_model = AutoModelForCausalLM.from_pretrained(med42_model_path, torch_dtype=torch.float16)
med42_model.eval()

med42_model = med42_model.to('cuda:0')
print(med42_model.device)

# Collect scores
cbmed42_scores = []
for i, ex in enumerate(processed_dev_data):
    cbmed42_scores.append(cbmed42_predict(ex["question_full"], ex["sentences_full"]))
    print(f"{i}. sample processed")

#### Threshold Experiment

In [13]:
# Do the threshold experiment
for threshold in [0.0, 0.02, 0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]:
    y_pred = []
    for i, sample_scores in enumerate(cbmed42_scores):
        preds = []
        for sc in sample_scores:
            if sc > threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    print(f"\tThreshold {threshold}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

	Threshold 0.0:	 micro   F1=33.663 | P=53.125, R=24.638 		 macro   F1=32.630 | P=51.079, R=30.333
	Threshold 0.02:	 micro   F1=25.434 | P=62.857, R=15.942 		 macro   F1=27.408 | P=53.583, R=21.292
	Threshold 0.05:	 micro   F1=23.529 | P=62.500, R=14.493 		 macro   F1=25.277 | P=52.500, R=18.167
	Threshold 0.08:	 micro   F1=22.892 | P=67.857, R=13.768 		 macro   F1=25.620 | P=55.000, R=17.667
	Threshold 0.1:	 micro   F1=18.634 | P=65.217, R=10.870 		 macro   F1=21.722 | P=52.500, R=14.708
	Threshold 0.15:	 micro   F1=16.352 | P=61.905, R=9.420 		 macro   F1=20.588 | P=51.667, R=13.958
	Threshold 0.2:	 micro   F1=16.352 | P=61.905, R=9.420 		 macro   F1=20.588 | P=51.667, R=13.958
	Threshold 0.25:	 micro   F1=16.456 | P=65.000, R=9.420 		 macro   F1=20.635 | P=52.500, R=13.958
	Threshold 0.3:	 micro   F1=14.103 | P=61.111, R=7.971 		 macro   F1=18.302 | P=50.000, R=12.083
	Threshold 0.35:	 micro   F1=13.072 | P=66.667, R=7.246 		 macro   F1=17.663 | P=50.000, R=11.667
	Threshold 0.4:	 mi

In [37]:
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(cbmed42_scores):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.23529411764705882, 0.18633540372670807, 0.16352201257861634, 0.16352201257861634, 0.16455696202531644, 0.14102564102564102, 0.13071895424836602, 0.13157894736842105, 0.13245033112582782, 0.13513513513513514, 0.13513513513513514, 0.12244897959183673, 0.12244897959183673, 0.12244897959183673, 0.1095890410958904, 0.1095890410958904, 0.1095890410958904, 0.08333333333333333, 0.056338028169014086, 0.0425531914893617]
[0.32242990654205606, 0.625, 0.6521739130434783, 0.6190476190476191, 0.6190476190476191, 0.65, 0.6111111111111112, 0.6666666666666666, 0.7142857142857143, 0.7692307692307693, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
[1.0, 0.14492753623188406, 0.10869565217391304, 0.09420289855072464, 0.09420289855072464, 0.09420289855072464, 0.07971014492753623, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.06521739130434782, 0.06521739130434782, 0.06521739130434782, 0.057971014492753624, 0.05797101

### Sentence Relevant Med42 Few Shot Scores

In [None]:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
import torch
import torch.nn.functional as F
import itertools
import numpy as np
import math




def sr_compute_response_score(question, sentence):
    few_shot = """You are a clinical assistant. Given a context and a question, extract only the essential information from the context that is necessary to answer the question. If no information is relevant, respond with "None". Also provide a short explanation for your answer.

Context: The patient has a history of hypertension and presents with progressive shortness of breath. BNP levels are elevated. Physical examination reveals bilateral rales and mild pedal edema.
Question: What information is essential from this context for answering the question "What is causing the patient's breathing difficulty?"
Answer: elevated BNP, bilateral rales, mild pedal edema
Reason: These are all indicators of congestive heart failure, which likely explains the breathing difficulty.

Context: The patient is a 45-year-old male with a history of allergic rhinitis. He was seen in allergy clinic and placed on a regimen of nasal corticosteroids and antihistamines. No new triggers identified. Symptoms are seasonal and well-controlled.
Question: What information is essential from this context for answering the question "What is the most likely cause of the patient's anemia?"
Answer: None
Reason: The context is entirely focused on allergic rhinitis, with no hematologic data or symptoms of anemia.

Context: The patient completed a dental cleaning and X-rays showed mild periodontal disease. Oral hygiene habits were discussed, and the patient agreed to floss daily. No pain or bleeding reported. No antibiotics were prescribed.
Question: What information is essential from this context for answering the question "What medications are responsible for the patient's elevated INR?"
Answer: None
Reason: There is no mention of any anticoagulants or medications that affect coagulation in the context.

Context: Patient underwent knee replacement two years ago. Reports occasional clicking sensation but no pain. X-ray shows proper implant positioning.
Question: What information is essential from this context for answering the question "Is the knee replacement causing complications?"
Answer: occasional clicking sensation, no pain, proper implant positioning
Reason: Clicking may suggest minor mechanical noise but no signs of complications given the lack of pain and good positioning.

Context: Complains of weight loss and fatigue over the past 3 months. Labs show iron deficiency anemia. Colonoscopy reveals a 2 cm mass in the ascending colon.
Question: What information is essential from this context for answering the question "What might explain the patient’s fatigue?"
Answer: iron deficiency anemia, 2 cm mass in ascending colon
Reason: Chronic blood loss from the mass could explain anemia and fatigue.

Context: The patient underwent cataract surgery on the right eye and reports improved vision. Post-op evaluation showed clear lens placement and normal intraocular pressure. No inflammation noted. Scheduled for left eye surgery in two months.
Question: What information is essential from this context for answering the question "Why did the patient develop shortness of breath?"
Answer: None
Reason: The context is limited to ophthalmologic findings and does not mention any pulmonary or cardiovascular symptoms.

Context: Denies smoking, alcohol, or drug use. Family history positive for lung cancer in both parents. Works in construction for 25 years without respiratory protection.
Question: What information is essential from this context for answering the question "What are the patient's risk factors for lung cancer?"
Answer: family history of lung cancer, 25 years in construction without respiratory protection
Reason: Occupational exposure and genetics increase risk even without smoking.

Context: Admitted for severe epigastric pain. Has history of NSAID use for chronic back pain. Labs show decreased hemoglobin. Endoscopy confirms a gastric ulcer.
Question: What information is essential from this context for answering the question "What is the likely cause of the gastrointestinal bleeding?"
Answer: NSAID use, gastric ulcer, decreased hemoglobin
Reason: NSAIDs are known to cause gastric ulcers, which can lead to bleeding.

Context: No prior psychiatric history. The patient has been irritable and withdrawn for the past month. Sleep has decreased to 3 hours/night. Appetite remains normal.
Question: What information is essential from this context for answering the question "Are there signs of depression?"
Answer: irritability, social withdrawal, decreased sleep
Reason: These are common symptoms associated with depressive disorders.

Context: The patient had a colonoscopy last week, which revealed three polyps that were removed. Pathology is pending. The patient denies abdominal pain, nausea, or changes in bowel habits. Family history is negative for colorectal cancer.
Question: What information is essential from this context for answering the question "Why is the patient experiencing chronic fatigue?"
Answer: None
Reason: The context is focused on GI screening and doesn't include symptoms, labs, or findings that would explain fatigue.

Context: Presents with left arm weakness and facial droop for 45 minutes. Symptoms resolved prior to arrival. CT scan shows no acute infarct. History of atrial fibrillation.
Question: What information is essential from this context for answering the question "What might have caused the neurological symptoms?"
Answer: transient symptoms, atrial fibrillation
Reason: AFib can cause transient ischemic attacks, which present with stroke-like symptoms that resolve.

Context: Mother reports that her child, aged 3, has not yet started speaking in full sentences. Hearing test is normal. No social interaction issues observed. Growth chart is appropriate.
Question: What information is essential from this context for answering the question "Is there concern for developmental delay?"
Answer: 3-year-old not speaking in full sentences
Reason: While social and hearing are normal, speech delay is suggestive of possible developmental delay.

Context: Recent travel to sub-Saharan Africa. Developed intermittent fever and chills on return. Blood smear reveals Plasmodium falciparum.
Question: What information is essential from this context for answering the question "What is the likely cause of the patient’s fever?"
Answer: travel to sub-Saharan Africa, Plasmodium falciparum
Reason: These findings point to malaria as the likely cause of the fever.

Context: Complains of morning stiffness lasting more than 1 hour. Joints in both hands are swollen and tender. Positive rheumatoid factor and anti-CCP antibodies.
Question: What information is essential from this context for answering the question "Is this likely to be rheumatoid arthritis?"
Answer: morning stiffness >1 hour, swollen/tender hand joints, positive RF and anti-CCP
Reason: These clinical and serological findings are diagnostic of RA.

Context: A 65-year-old woman was referred to audiology due to recent hearing difficulties. Audiogram showed moderate bilateral sensorineural hearing loss. Hearing aids were recommended. No signs of vertigo or tinnitus were reported.
Question: What information is essential from this context for answering the question "What led to the patient's episodes of syncope?"
Answer: None
Reason: The context only contains auditory assessment and does not address cardiovascular or neurologic causes.

Context: On insulin therapy. Skipped lunch due to meetings. Found diaphoretic and confused. Glucose 42 mg/dL.
Question: What information is essential from this context for answering the question "What explains the patient’s confusion?"
Answer: skipped lunch, insulin therapy, glucose 42 mg/dL
Reason: Hypoglycemia is likely due to missed meal with insulin use.

Context: Reports worsening shortness of breath over 2 weeks. Has COPD. Oxygen saturation drops to 89% on ambulation. Chest X-ray shows no infiltrates.
Question: What information is essential from this context for answering the question "What is likely contributing to the patient’s shortness of breath?"
Answer: COPD history, desaturation with ambulation
Reason: COPD with exertional desaturation is a common cause of dyspnea in such patients.

Context: Diagnosed with hypothyroidism last year. Currently on levothyroxine. Complains of fatigue and cold intolerance. TSH 9.2.
Question: What information is essential from this context for answering the question "Why is the patient still symptomatic?"
Answer: hypothyroidism, TSH 9.2
Reason: Elevated TSH indicates under-replacement with levothyroxine.

Context: Denies any chest pain. Takes beta-blocker for hypertension. EKG reveals bradycardia (HR 48 bpm). Patient feels fatigued.
Question: What information is essential from this context for answering the question "What could explain the fatigue?"
Answer: beta-blocker use, bradycardia
Reason: Bradycardia from beta-blockers may result in reduced cardiac output and fatigue.

Context: The patient was evaluated in the ophthalmology clinic due to complaints of blurry vision. Examination showed no signs of diabetic retinopathy. Blood pressure was within normal range. There were no neurological deficits noted. Follow-up was scheduled in six months.
Question: What information is essential from this context for answering the question "What is the underlying cause of the patient's persistent headaches?"
Answer: None
Reason: The context only discusses ophthalmological findings and vision-related complaints but contains no information about the cause of headaches.

Context: Patient presented for a follow-up regarding their post-operative shoulder surgery. Physical therapy was recommended and patient reports improvement in range of motion. There are no signs of infection or complications. Sleep has improved as well.
Question: What information is essential from this context for answering the question "What factors contributed to the patient's recent weight loss?"
Answer: None
Reason: The context only discusses orthopedic recovery and makes no mention of diet, metabolism, or weight.

Context: The patient was brought in for confusion. No focal neurological deficits noted. BUN and creatinine significantly elevated. Recently started lisinopril.
Question: What information is essential from this context for answering the question "What could explain the altered mental status?"
Answer: elevated BUN/creatinine, started lisinopril
Reason: Acute kidney injury from ACE inhibitors may lead to uremic encephalopathy.

Context: 65-year-old with chronic low back pain. MRI shows mild degenerative disc disease. No nerve compression.
Question: What information is essential from this context for answering the question "Is surgery indicated?"
Answer: mild degenerative disc disease, no nerve compression
Reason: Conservative treatment is favored as no surgical lesion is present.

Context: During the dermatology consultation, the patient described new-onset skin lesions. The rash appeared on the arms and back, non-pruritic and non-painful. No signs of infection were noted. Biopsy was scheduled.
Question: What information is essential from this context for answering the question "Why has the patient developed elevated liver enzymes?"
Answer: None
Reason: The context centers around dermatological symptoms with no hepatic or metabolic findings provided.

Context: History of mechanical heart valve replacement. INR today is 5.2. No active bleeding reported.
Question: What information is essential from this context for answering the question "What explains the elevated INR?"
Answer: mechanical valve replacement
Reason: Patients require anticoagulation for valves, which can overshoot and elevate INR.

Context: Breast mass noted on exam. Mammogram shows suspicious lesion. Biopsy confirms ductal carcinoma in situ.
Question: What information is essential from this context for answering the question "What is the diagnosis?"
Answer: ductal carcinoma in situ
Reason: Biopsy provides definitive diagnosis.

Context: Patient with ESRD on dialysis. Missed last two sessions. Complains of generalized weakness. Potassium level is 6.8.
Question: What information is essential from this context for answering the question "What is the likely cause of weakness?"
Answer: missed dialysis sessions, potassium 6.8
Reason: Hyperkalemia and uremia due to missed dialysis likely explain weakness.
"""
    # Build full prompt
    prompt = few_shot.strip() + "\n\n" + f"Context: {sentence}\nQuestion: What information is essential from this context for answering the question \"{question}\"\nAnswer:"
    input_ids = med42_tokenizer(prompt, return_tensors="pt").input_ids.to(med42_model.device)
    with torch.no_grad():
        outputs = med42_model(input_ids)
        logits = outputs.logits[:, -1, :]  # only the next token

    # Tokenize input
    inputs = med42_tokenizer(prompt, return_tensors="pt").to(med42_model.device)
    
    # Generate model output
    with torch.no_grad():
        outputs = med42_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=med42_tokenizer.pad_token_id
        )
    
    # Decode generated answer
    generated_tokens = outputs.sequences[0][inputs['input_ids'].shape[-1]:]
    generated_answer = med42_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    
    # Compute log-probability score
    answer_ids = med42_tokenizer(" " + generated_answer, return_tensors="pt", add_special_tokens=False).input_ids.to(med42_model.device)
    logits = torch.stack(outputs.scores, dim=1)[0]  # shape: [tokens, vocab]
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-12)
    
    # Sum log-probs for each token
    score = sum(log_probs[i, token_id].item() for i, token_id in enumerate(answer_ids[0]))
    
    if generated_answer[:4] != "None":
        return math.exp(score)
    else:
        return 0.0

def srmed42_predict(question, sentences):
    scores = []
    for i, sen in enumerate(sentences):
        scores.append(sr_compute_response_score(question, " ".join(sen.split())))
    if sum(scores) == 0.0:
        return scores
    return list(np.array(scores)/sum(scores))

# Load model
med42_model_path = "../../models/Llama3-Med42-8B"
#med42_tokenizer = AutoTokenizer.from_pretrained(med42_model_path)
#med42_model = AutoModelForCausalLM.from_pretrained(med42_model_path, device_map="auto", torch_dtype=torch.float16)
#med42_model.eval()


# Collect scores
srmed42_scores = []
for i, ex in enumerate(processed_dev_data):
    srmed42_scores.append(srmed42_predict(ex["question_full"], ex["sentences_full"]))
    print(f"{i}. sample processed")

#### Threshold Experiment

In [15]:
# Do the threshold experiment
for threshold in [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99]:
    y_pred = []
    for i, sample_scores in enumerate(srmed42_scores):
        preds = []
        for sc in sample_scores:
            if sc > threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    print(f"\tThreshold {threshold}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

	Threshold 0.0:	 micro   F1=21.053 | P=54.545, R=13.043 		 macro   F1=19.550 | P=39.417, R=17.375
	Threshold 0.01:	 micro   F1=15.385 | P=66.667, R=8.696 		 macro   F1=15.840 | P=42.500, R=11.750
	Threshold 0.02:	 micro   F1=14.194 | P=64.706, R=7.971 		 macro   F1=14.650 | P=41.667, R=10.083
	Threshold 0.03:	 micro   F1=14.194 | P=64.706, R=7.971 		 macro   F1=14.650 | P=41.667, R=10.083
	Threshold 0.04:	 micro   F1=14.194 | P=64.706, R=7.971 		 macro   F1=14.650 | P=41.667, R=10.083
	Threshold 0.05:	 micro   F1=14.379 | P=73.333, R=7.971 		 macro   F1=14.983 | P=42.500, R=10.083
	Threshold 0.06:	 micro   F1=14.379 | P=73.333, R=7.971 		 macro   F1=14.983 | P=42.500, R=10.083
	Threshold 0.07:	 micro   F1=14.379 | P=73.333, R=7.971 		 macro   F1=14.983 | P=42.500, R=10.083
	Threshold 0.08:	 micro   F1=14.379 | P=73.333, R=7.971 		 macro   F1=14.983 | P=42.500, R=10.083
	Threshold 0.09:	 micro   F1=14.379 | P=73.333, R=7.971 		 macro   F1=14.983 | P=42.500, R=10.083
	Threshold 0.1:	 mic

In [40]:
f1 = []
p = []
r = []
for threshold in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]:
    y_pred = []
    for i, sample_scores in enumerate(srmed42_scores):
        preds = []
        for sc in sample_scores:
            if sc >= threshold:
                preds.append(1)
            else:
                preds.append(0)
        y_pred.append(preds)
    micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
    macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
    f1.append(micro_f1)
    p.append(micro_precision)
    r.append(micro_recall)
print(f1)
print(p)
print(r)

[0.4876325088339223, 0.1437908496732026, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.13245033112582782, 0.12080536912751678, 0.09523809523809523, 0.09523809523809523, 0.09523809523809523, 0.09523809523809523, 0.09523809523809523, 0.09523809523809523, 0.09523809523809523, 0.08275862068965517, 0.08275862068965517, 0.06993006993006994]
[0.32242990654205606, 0.7333333333333333, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.7692307692307693, 0.8181818181818182, 0.7777777777777778, 0.7777777777777778, 0.7777777777777778, 0.7777777777777778, 0.7777777777777778, 0.7777777777777778, 0.7777777777777778, 0.8571428571428571, 0.8571428571428571, 1.0]
[1.0, 0.07971014492753623, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 0.07246376811594203, 

### Combination

In [16]:
# Do the threshold experiment, max entity score with mBERT
print("Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores")
#for srmed_eps, cbmed_eps, mces_eps, cptce_eps in zip([0.0], [0.0], [0.9], [0.5]):
f1s = []
for srmed_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
    for cbmed_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
        for mces_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
            for cptce_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
                y_pred = []
                for srmed_sample_scs, cbmed_sample_scs, mces_sample_scs, cptce_sample_scs in zip(srmed42_scores, cbmed42_scores, mces_scores, medcptce_scores):
                    srmed_preds = [1 if sc > srmed_eps else 0 for sc in srmed_sample_scs]
                    cbmed_preds = [1 if sc > cbmed_eps else 0 for sc in cbmed_sample_scs]
                    mces_preds = [1 if sc > mces_eps else 0 for sc in mces_sample_scs]
                    cptce_preds = [1 if sc > cptce_eps else 0 for sc in cptce_sample_scs]
                    preds = [a | b | c | d for a, b, c, d in zip(srmed_preds, cbmed_preds, mces_preds, cptce_preds)]
                    y_pred.append(preds)
                micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
                macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
                if micro_f1 > 0.56:
                    f1s.append(micro_f1)
                    print(f"\tThreshold {srmed_eps}, {cbmed_eps}, {mces_eps}, {cptce_eps}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

print(max(f1s))

Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores
	Threshold 0.0, 0.0, 0.7, 0.35:	 micro   F1=56.316 | P=44.215, R=77.536 		 macro   F1=56.669 | P=47.672, R=82.056
	Threshold 0.0, 0.0, 0.7, 0.4:	 micro   F1=56.044 | P=45.133, R=73.913 		 macro   F1=56.098 | P=48.238, R=77.167
	Threshold 0.0, 0.0, 0.7, 0.45:	 micro   F1=56.395 | P=47.087, R=70.290 		 macro   F1=56.445 | P=50.153, R=74.778
	Threshold 0.0, 0.0, 0.7, 0.5:	 micro   F1=57.751 | P=49.738, R=68.841 		 macro   F1=57.982 | P=53.356, R=74.028
	Threshold 0.0, 0.0, 0.7, 0.55:	 micro   F1=56.875 | P=50.000, R=65.942 		 macro   F1=57.274 | P=53.811, R=72.167
	Threshold 0.0, 0.05, 0.7, 0.35:	 micro   F1=56.684 | P=44.915, R=76.812 		 macro   F1=56.925 | P=48.158, R=81.431
	Threshold 0.0, 0.05, 0.7, 0.4:	 micro   F1=56.583 | P=46.119, R=73.188 		 macro   F1=56.475 | P=48.855, R=76.542
	Threshold 0.0, 0.05, 0.7, 0.45:	 micro   F1=57.143 | P=48.485, R=69.565 		 macro   F1=56.909 | P=51.

In [18]:
final_pred = []
srmed_eps, cbmed_eps, mces_eps, cptce_eps = 0.9, 0.05, 0.7, 0.5
for srmed_sample_scs, cbmed_sample_scs, mces_sample_scs, cptce_sample_scs in zip(srmed42_scores, cbmed42_scores, mces_scores, medcptce_scores):
    srmed_preds = [1 if sc > srmed_eps else 0 for sc in srmed_sample_scs]
    cbmed_preds = [1 if sc > cbmed_eps else 0 for sc in cbmed_sample_scs]
    mces_preds = [1 if sc > mces_eps else 0 for sc in mces_sample_scs]
    cptce_preds = [1 if sc > cptce_eps else 0 for sc in cptce_sample_scs]
    preds = [a | b | c | d for a, b, c, d in zip(srmed_preds, cbmed_preds, mces_preds, cptce_preds)]
    final_pred.append(preds)
micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, final_pred)
macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, final_pred)
print(micro_f1)
print(final_pred)

0.5859872611464968
[[0, 1, 0, 1, 0, 1, 1, 1, 1], [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0], [1, 1, 0, 1, 1, 1, 1, 0, 1, 0], [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1], [0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1], [1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0], [0, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], [0, 0, 0, 0, 1, 0, 1

In [123]:
# Do the threshold experiment, max entity score with ClinicalBERT
print("Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores")
#for srmed_eps, cbmed_eps, mces_eps, cptce_eps in zip([0.0], [0.0], [0.9], [0.5]):
f1s = []
for srmed_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
    for cbmed_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
        for mces_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
            for cptce_eps in [-1.0] + [round(x * 0.05, 2) for x in range(21)]:
                y_pred = []
                for srmed_sample_scs, cbmed_sample_scs, mces_sample_scs, cptce_sample_scs in zip(srmed42_scores, cbmed42_scores, mces_scores, medcptce_scores):
                    srmed_preds = [1 if sc > srmed_eps else 0 for sc in srmed_sample_scs]
                    cbmed_preds = [1 if sc > cbmed_eps else 0 for sc in cbmed_sample_scs]
                    mces_preds = [1 if sc > mces_eps else 0 for sc in mces_sample_scs]
                    cptce_preds = [1 if sc > cptce_eps else 0 for sc in cptce_sample_scs]
                    preds = [a | b | c | d for a, b, c, d in zip(srmed_preds, cbmed_preds, mces_preds, cptce_preds)]
                    y_pred.append(preds)
                micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
                macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
                if micro_f1 > 0.56:
                    f1s.append(micro_f1)
                    print(f"\tThreshold {srmed_eps}, {cbmed_eps}, {mces_eps}, {cptce_eps}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f} \t\t macro   F1={macro_f1*100.0:.3f} | P={macro_precision*100.0:.3f}, R={macro_recall*100.0:.3f}")

print(max(f1s))

Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores
	Threshold 0.0, 0.0, 0.9, 0.5:	 micro   F1=56.410 | P=50.575, R=63.768 		 macro   F1=57.942 | P=56.350, R=69.278
	Threshold 0.0, 0.0, 0.9, 0.55:	 micro   F1=56.106 | P=51.515, R=61.594 		 macro   F1=57.299 | P=56.749, R=67.833
	Threshold 0.0, 0.05, 0.9, 0.5:	 micro   F1=56.667 | P=52.469, R=61.594 		 macro   F1=57.899 | P=57.290, R=67.403
	Threshold 0.0, 0.05, 0.9, 0.55:	 micro   F1=56.357 | P=53.595, R=59.420 		 macro   F1=57.272 | P=57.832, R=65.958
	Threshold 0.0, 0.1, 0.9, 0.5:	 micro   F1=56.376 | P=52.500, R=60.870 		 macro   F1=57.529 | P=57.123, R=66.569
	Threshold 0.0, 0.1, 0.9, 0.55:	 micro   F1=56.055 | P=53.642, R=58.696 		 macro   F1=56.902 | P=57.666, R=65.125
	Threshold 0.0, 0.15, 0.9, 0.5:	 micro   F1=56.376 | P=52.500, R=60.870 		 macro   F1=57.529 | P=57.123, R=66.569
	Threshold 0.0, 0.15, 0.9, 0.55:	 micro   F1=56.055 | P=53.642, R=58.696 		 macro   F1=56.902 | P=57.

In [126]:
final_pred = []
srmed_eps, cbmed_eps, mces_eps, cptce_eps = 0.0, 0.4, 0.9, 0.5
for srmed_sample_scs, cbmed_sample_scs, mces_sample_scs, cptce_sample_scs in zip(srmed42_scores, cbmed42_scores, mces_scores, medcptce_scores):
    srmed_preds = [1 if sc > srmed_eps else 0 for sc in srmed_sample_scs]
    cbmed_preds = [1 if sc > cbmed_eps else 0 for sc in cbmed_sample_scs]
    mces_preds = [1 if sc > mces_eps else 0 for sc in mces_sample_scs]
    cptce_preds = [1 if sc > cptce_eps else 0 for sc in cptce_sample_scs]
    preds = [a | b | c | d for a, b, c, d in zip(srmed_preds, cbmed_preds, mces_preds, cptce_preds)]
    final_pred.append(preds)
micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, final_pred)
macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, final_pred)
print(micro_f1)
print(final_pred)

0.5675675675675675
[[0, 1, 0, 1, 0, 1, 1, 1, 1], [0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1], [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1], [0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0], [0, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], [0, 1, 0, 0, 1, 0, 1

### Combination for paper

In [47]:
# Do the threshold experiment, max entity score with ClinicalBERT
print("Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores")
#for srmed_eps, cbmed_eps, mces_eps, cptce_eps in zip([0.0], [0.0], [0.9], [0.5]):
f1s = []
for srmed_eps in [round(x * 0.025, 3) for x in range(41)]:
    print(srmed_eps)
    for cbmed_eps in [round(x * 0.025, 3) for x in range(41)]:
        for mces_eps in [round(x * 0.025, 3) for x in range(41)]:
            for cptce_eps in [round(x * 0.025, 3) for x in range(41)]:
                y_pred = []
                for srmed_sample_scs, cbmed_sample_scs, mces_sample_scs, cptce_sample_scs in zip(srmed42_scores, cbmed42_scores, mces_scores_cb, medcptce_scores):
                    srmed_preds = [1 if sc >= srmed_eps else 0 for sc in srmed_sample_scs]
                    cbmed_preds = [1 if sc >= cbmed_eps else 0 for sc in cbmed_sample_scs]
                    mces_preds = [1 if sc >= mces_eps else 0 for sc in mces_sample_scs]
                    cptce_preds = [1 if sc >= cptce_eps else 0 for sc in cptce_sample_scs]
                    preds = [a | b | c | d for a, b, c, d in zip(srmed_preds, cbmed_preds, mces_preds, cptce_preds)]
                    y_pred.append(preds)
                micro_precision, micro_recall, micro_f1 = compute_micro_scores(y_true, y_pred)
                #macro_precision, macro_recall, macro_f1 = compute_macro_scores(y_true, y_pred)
                if micro_f1 > 0.56:
                    f1s.append(micro_f1)
                    print(f"\tThreshold {srmed_eps}, {cbmed_eps}, {mces_eps}, {cptce_eps}:\t micro   F1={micro_f1*100.0:.3f} | P={micro_precision*100.0:.3f}, R={micro_recall*100.0:.3f}")

print(max(f1s))

Grid Search Experiment for finding the best thresholds for combination of previous precomputed scores
0.0
0.025
	Threshold 0.025, 0.025, 0.825, 0.4:	 micro   F1=56.057 | P=41.696, R=85.507
	Threshold 0.025, 0.025, 0.825, 0.5:	 micro   F1=56.423 | P=43.243, R=81.159
	Threshold 0.025, 0.025, 0.825, 0.525:	 micro   F1=56.709 | P=43.580, R=81.159
	Threshold 0.025, 0.025, 0.825, 0.55:	 micro   F1=56.122 | P=43.307, R=79.710
	Threshold 0.025, 0.025, 0.825, 0.6:	 micro   F1=56.104 | P=43.725, R=78.261
	Threshold 0.025, 0.025, 0.9, 0.525:	 micro   F1=56.164 | P=53.247, R=59.420
	Threshold 0.025, 0.05, 0.825, 0.5:	 micro   F1=56.061 | P=43.023, R=80.435
	Threshold 0.025, 0.05, 0.825, 0.525:	 micro   F1=56.345 | P=43.359, R=80.435
	Threshold 0.025, 0.075, 0.825, 0.5:	 micro   F1=56.203 | P=43.191, R=80.435
	Threshold 0.025, 0.075, 0.825, 0.525:	 micro   F1=56.489 | P=43.529, R=80.435
	Threshold 0.025, 0.075, 0.9, 0.525:	 micro   F1=56.055 | P=53.642, R=58.696
	Threshold 0.025, 0.1, 0.825, 0.525:

## Answer Generation

In [19]:
# Score combination with mBERT (58.6)
final_pred = [[0, 1, 0, 1, 0, 1, 1, 1, 1], [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0], [1, 1, 0, 1, 1, 1, 1, 0, 1, 0], [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1], [0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1], [1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0], [0, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], [0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], [0, 1, 1, 0, 0, 1, 0, 0, 1]]

In [30]:
# Score combination with ClinicalBERT (56.76)
final_pred = [[0, 1, 0, 1, 0, 1, 1, 1, 1], [0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1], [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1], [0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0], [0, 1, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], [0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], [0, 1, 1, 0, 0, 1, 0, 0, 1]]

In [31]:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
import torch
import torch.nn.functional as F
import itertools
import numpy as np

def concisely_paraphrase(sentence):
    few_shot = """You are a clinical assistant specialized in simplifying discharge summaries.
Your task is to take a long clinical sentence and rewrite it as a shorter, natural, and concise sentence that preserves the essential clinical information.
Do not copy the entire sentence or use unnecessary detail. Keep it factual, clear, and brief.

Sentence: The patient was admitted to the hospital due to a sudden episode of chest pain that occurred while he was gardening.
Compressed: Admitted for sudden chest pain during gardening.

Sentence: Following the MRI scan, the patient was found to have a small herniated disc at the L4-L5 level.
Compressed: MRI showed a small herniated disc at L4-L5.

Sentence: The patient has a medical history of hypertension, type 2 diabetes, and chronic kidney disease stage 3.
Compressed: History includes hypertension, diabetes, and stage 3 kidney disease.

Sentence: She was prescribed albuterol inhaler to be used as needed for episodes of shortness of breath.
Compressed: Prescribed albuterol for shortness of breath as needed.

Sentence: During his hospital stay, the patient developed a mild skin rash likely due to a reaction to antibiotics.
Compressed: Developed mild rash from antibiotics.

Sentence: The patient was advised to follow a low-sodium diet and monitor blood pressure regularly at home.
Compressed: Advised low-sodium diet and home blood pressure monitoring.

Sentence: He lives alone but receives weekly assistance from his daughter with groceries and medication management.
Compressed: Lives alone with weekly help from daughter.

Sentence: The patient’s vaccination record was updated during the follow-up visit, including influenza and tetanus boosters.
Compressed: Received flu and tetanus boosters at follow-up.
"""
    prompt = few_shot + f"\n\nSentence: {sentence}\nCompressed:"

    # Tokenization and generation
    inputs = med42_tokenizer(prompt, return_tensors="pt").to(med42_model.device)
    max_tokens = 50

    with torch.no_grad():
        outputs = med42_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=med42_tokenizer.pad_token_id
        )

    # Decode
    generated_text = med42_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    answer = generated_text[len(prompt):].split("\n")[0].strip()
    return answer

def compress_sentence(question, sentence, max_answer_words=15):
    few_shot = """You are a clinical assistant helping family members understand discharge summaries.
Your task is to answer questions based on long clinical sentences, which may include irrelevant information.
Always provide a direct, natural answer that is as concise as possible.
Do not repeat or copy any part of the question in your answer.
Do not begin the answer with phrases like “Because...” or “XYZ was recommended because...”.
If no clear answer is possible, reply with: None

Question: What treatment did the patient receive for pneumonia?
Sentence: The patient was diagnosed with pneumonia and treated with intravenous antibiotics and oxygen therapy.
Answer: He was treated with antibiotics and oxygen therapy.

Question: Why is the patient taking insulin?
Sentence: Due to a recent diagnosis of type 2 diabetes, the patient was prescribed insulin to manage blood sugar levels.
Answer: He was diagnosed with type 2 diabetes.

Question: What caused the patient's shortness of breath?
Sentence: The patient's shortness of breath was likely due to fluid accumulation in the lungs caused by heart failure.
Answer: He had lung fluid from heart failure.

Question: What mobility assistance does the patient need?
Sentence: After hip surgery, the patient requires a walker and supervision while moving.
Answer: He requires a walker and supervision.

Question: Why was a walking cane recommended to the patient?
Sentence: The patient’s vaccination record was updated during the follow-up visit, including influenza and tetanus boosters.
Answer: None

Question: What complications occurred during the patient's hospital stay?
Sentence: The patient experienced atrial fibrillation, transient confusion, and a mild allergic reaction to antibiotics during admission.
Answer: He experienced atrial fibrillation, confusion, and an allergic reaction.
"""
    prompt = few_shot + f"\n\nQuestion: {question}\nSentence: {sentence}\nAnswer:"

    # Tokenization and generation
    inputs = med42_tokenizer(prompt, return_tensors="pt").to(med42_model.device)
    max_tokens = 50

    with torch.no_grad():
        outputs = med42_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=med42_tokenizer.pad_token_id
        )

    # Decode
    generated_text = med42_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    answer = generated_text[len(prompt):].split("\n")[0].strip()
    if answer[:4] == "None":
        answer = concisely_paraphrase(sentence)
    return answer




med42_model_path = "../../models/Llama3-Med42-8B"
#med42_tokenizer = AutoTokenizer.from_pretrained(med42_model_path)
#med42_model = AutoModelForCausalLM.from_pretrained(med42_model_path, device_map="auto", torch_dtype=torch.float16)
med42_model.eval()

# GENERATE SUBMISSION ANSWERS

max_answer_words = 75
final_submission = []
for ex_id, (ex, pred) in enumerate(zip(processed_dev_data, final_pred)):
    sentence_ids = []
    compressed_sentences = []
    full_sentences = []
    one_sentence_limit = int(max_answer_words/sum(pred))
    
    # Collect compressed sentences
    for i, label in enumerate(pred):
        if label == 1:
            compressed_sentences.append(compress_sentence(ex["question_full"], " ".join(ex["sentences_ids"][i]["text"].split()), one_sentence_limit-1))
            sentence_ids.append(ex["sentences_ids"][i]["id"])
            full_sentences.append(" ".join(ex["sentences_ids"][i]["text"].split()))

    # Ensure the max answer words limit
    current_longest = 0
    prev_length = -1
    current_length = sum([len(sen.split()) for sen in compressed_sentences])
    while sum([len(sen.split()) for sen in compressed_sentences]) > max_answer_words:
        long_sen_id = np.argsort([len(sen.split()) for sen in compressed_sentences])[::-1][current_longest]
        new_sen = concisely_paraphrase(compressed_sentences[long_sen_id])
        if len(new_sen.split()) < len(compressed_sentences[long_sen_id].split()):
            compressed_sentences[long_sen_id] = new_sen
        if current_length == sum([len(sen.split()) for sen in compressed_sentences]):
            current_longest += 1
        else:
            current_length = sum([len(sen.split()) for sen in compressed_sentences])
        if current_longest == len(compressed_sentences):
            current_longest = 0
            if prev_length == current_length:
                # Remove last word from the longest sentence
                compressed_sentences[long_sen_id] = " ".join(compressed_sentences[long_sen_id].split(" ")[:-1]) + "."
                print("Warning: Last word of the longest sentence was removed")
            prev_length = current_length
            current_length = sum([len(sen.split()) for sen in compressed_sentences])
            print(current_length)
            print("Warning: Unable to shorten sentences. Retrying...")
    
    
    # If the length is small, use original sentences - start with the smallest ones
    sorted_indices = sorted(range(len(full_sentences)), key=lambda i: len(full_sentences[i]))
    for short_sentence_i in sorted_indices:
        if sum([len(sen.split()) for sen in compressed_sentences]) - len(compressed_sentences[short_sentence_i].split()) + len(full_sentences[short_sentence_i].split()) <= max_answer_words:
            compressed_sentences[short_sentence_i] = full_sentences[short_sentence_i]
        else:
            break

    # Collect compressed sentences to one answer string with cites
    answer = [com_sen + " |" + sen_id + "|" for com_sen, sen_id in zip(compressed_sentences, sentence_ids)]
    sample_submission = {
        "case_id" : ex["case_id"],
        "answer": "\n".join(answer)
    }
    
    print(f"\t{ex_id}. submission answer generated")
    print(len(sample_submission["answer"].split(' ')))
    final_submission.append(sample_submission)

print(final_submission)

	0. submission answer generated
73
	1. submission answer generated
70
	2. submission answer generated
75
	3. submission answer generated
61
	4. submission answer generated
73
	5. submission answer generated
76
	6. submission answer generated
74
	7. submission answer generated
72
94
87
83
81
80
79
78
76
	8. submission answer generated
76
	9. submission answer generated
76
78
76
76
	10. submission answer generated
75
	11. submission answer generated
66
	12. submission answer generated
66
	13. submission answer generated
33
	14. submission answer generated
71
102
91
85
81
79
77
77
	15. submission answer generated
77
	16. submission answer generated
45
	17. submission answer generated
75
	18. submission answer generated
75
	19. submission answer generated
67
[{'case_id': '1', 'answer': 'ERCP was recommended over medication because of stones and sludge in the bile duct. |2|\nHe had pus draining from the bile duct. |4|\nReturned to ERCP on day 4 post-procedure for biliary stent re-evaluation

## Evaluation

### Generate Submission File

In [32]:
import json

with open("./data/dev/cb_submission.json", "w", encoding="utf-8") as f:
    json.dump(final_submission, f, ensure_ascii=False, indent=2)

### Evaluate

In [33]:
!python scoring.py \
            --submission_path ./data/dev/cb_submission.json \
            --key_path ./data/dev/archehr-qa_key.json \
            --data_path ./data/dev/archehr-qa.xml \
            --out_file_path scores.json

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[nltk_data] Downloading package punkt to /home/lanz/nltk_data...
[nltk_data]   Package punkt is already up-to-date!

Loading submission
Number of cases in submission: 20

Validating submission

Computing factuality scores

Computing relevance scores
------------ BLEU ------------
--------- BERTSCORE ----------
--------- ALIGNSCORE ---------
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Lightning automatically upgraded your loaded checkpoint from v1.7.7 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file https:/huggingface.co/yzha/AlignScore/resolve/main/AlignScore-base.ckpt`
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and a