## Settings

In [None]:
from transformers import ( 
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    AutoModel,
    Trainer,
    pipeline,
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM
)
import polars as pl
import warnings
from tqdm import tqdm
import glob, os
from PyPDF2 import PdfReader
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import torch
import torch._dynamo
from datasets import Dataset
import json
import faiss
import re
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, prepare_model_for_kbit_training
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from pathlib import Path
from sentence_transformers import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity as skl_cosine
from typing import List, Tuple

def chunk_text(text, max_length=500, overlap=50):
    """Split ``text`` into sentence-aligned chunks of roughly ``max_length`` characters.

    The text is split on ``". "`` and the pieces are packed greedily; every
    piece is re-terminated with ``". "`` when it is added to a chunk. When the
    next sentence would push the running chunk past ``max_length``, the chunk
    is emitted and the next chunk is seeded with its last ``overlap``
    whitespace-separated words, so neighbouring chunks share some context.

    Args:
        text: Raw document text.
        max_length: Soft character limit for one chunk. A single sentence
            longer than this is still emitted whole, and the overlap words are
            added on top of the limit.
        overlap: Number of trailing words of a finished chunk repeated at the
            start of the next one. ``0`` (or a negative value) disables overlap.

    Returns:
        List of whitespace-stripped, non-empty chunk strings.
    """
    sentences = text.split(". ")
    chunks = []
    current_chunk = ""

    for sent in sentences:
        if len(current_chunk) + len(sent) + 2 <= max_length:
            current_chunk += sent + ". "
            continue

        # The running chunk is full: emit it, but never emit an empty string
        # (happens when the very first sentence is already longer than max_length).
        finished = current_chunk.strip()
        if finished:
            chunks.append(finished)

        # `words[-0:]` would return ALL words, so overlap == 0 needs its own branch.
        overlap_words = current_chunk.split()[-overlap:] if overlap > 0 else []
        current_chunk = " ".join(overlap_words + [sent]) + ". "

    tail = current_chunk.strip()
    if tail:
        chunks.append(tail)
    return chunks

def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b`` as a Python float.

    Each vector is divided by its L2 norm plus a small epsilon, so an all-zero
    vector yields 0.0 instead of a division-by-zero.
    """
    eps = 1e-8
    unit_a = a / (np.linalg.norm(a) + eps)
    unit_b = b / (np.linalg.norm(b) + eps)
    similarity = np.dot(unit_a, unit_b)
    return float(similarity)
    
# Hugging Face model id bound to MODEL_NAME for the cells that follow.
MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"

# Notebook-wide runtime switches: let TorchDynamo suppress its own errors
# and silence every Python warning.
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings("ignore")

## Finetuning

In [None]:
'''
doc_type = pl.read_excel('datasets_folder/SNUHNOTE/1.0.1/document_type_level_mst.xlsx')
doc_type = doc_type.filter(
    (pl.col("doc_type_id") == "D008") & 
    (pl.col("mdfm_name").str.contains("감염"))
)
# mdfm_id = 41813, 41814, 42345 for 타과의뢰회신
'''
# Build (consult request, consult reply) candidates from the DGNS_1..DGNS_7 sheets.
# In every sheet only rows whose level_path contains 41813, 41814 or 42345 are used
# (the note at the top of this cell lists these as the relevant mdfm_ids).
dfs = []
for i in tqdm(range(1, 8)):
    df = pl.read_excel(f'datasets_folder/SNUHNOTE/1.0.1/4_DGNS/DGNS_{i}.xlsx')
    # df1 = reply-side rows.
    df1 = df.filter(
        ((pl.col("level_path").str.contains("41813"))|(pl.col("level_path").str.contains("41814"))|(pl.col("level_path").str.contains("42345"))) &
        # Drop rows whose whole content is one of these fixed one-line approval phrases
        # (each listed with its trailing space / newline variants).
        ~(pl.col("content")=='승인하였습니다.\n') & ~(pl.col("content")=='승인하였습니다. \n') &  ~(pl.col("content")=='승인하였습니다. \n\n') & ~(pl.col("content")=='승인하였습니다.\n\n') &
        ~(pl.col("content")=='승인하였습니다.\n\n\n') & ~(pl.col("content")=='승인하였습니다. \n\n\n') &
        ~(pl.col("content")=='승인하겠습니다.\n') & ~(pl.col("content")=='승인하겠습니다. \n') &  ~(pl.col("content")=='승인하겠습니다. \n\n') & ~(pl.col("content")=='승인하겠습니다.\n\n') &
        ~(pl.col("content")=='승인하겠습니다.\n\n\n') & ~(pl.col("content")=='승인하겠습니다. \n\n\n') &
        ~(pl.col("content")=='투약 승인하였습니다.\n') & ~(pl.col("content")=='투약 승인하였습니다. \n') & ~(pl.col("content")=='투약 승인하였습니다. \n\n') &~(pl.col("content")=='투약 승인하였습니다.\n\n') &
        # Must NOT contain the request-side keywords (문의 / 의뢰) or "취소된 처방" ...
        ~(pl.col("content").str.contains("문의")) & ~(pl.col("content").str.contains("의뢰")) &~(pl.col("content").str.contains("취소된 처방")) &
        # ... and MUST contain at least one of 추천 / 권장 / 승인 / 고려 / 니다 / 권고.
        ((pl.col("content").str.contains("추천"))|(pl.col("content").str.contains("권장"))|(pl.col("content").str.contains("승인"))|(pl.col("content").str.contains("고려"))|
        (pl.col("content").str.contains("니다")|(pl.col("content").str.contains("권고"))))
    )
    # Collapse replies to one row per nid: the first rec_dt_offset / level_path after the
    # sort, and all content fragments joined with a single space.
    df1 = df1.sort(["nid", "rec_dt_offset"])
    df1 = df1.group_by("nid").agg([
        pl.col("rec_dt_offset").first().alias("rec_dt_offset"),
        pl.col("level_path").first().alias("level_path"),
        pl.col("content").implode().alias("content_list")
    ])
    # level_path is aggregated above but not carried forward by this select.
    df1 = df1.with_columns(
        pl.col("content_list").list.join(" ").alias("content")
    ).select(["nid", "rec_dt_offset", "content"])

    # df2 = request-side rows: same level_path ids, content contains 문의 / 상의 / 의뢰
    # and none of the markers listed in the negated group below.
    df2 = df.filter(
        ((pl.col("level_path").str.contains("41813"))|(pl.col("level_path").str.contains("41814"))|(pl.col("level_path").str.contains("42345"))) &
        ((pl.col("content").str.contains("문의"))|(pl.col("content").str.contains("상의"))|(pl.col("content").str.contains("의뢰"))) &
        ~((pl.col("content").str.contains("추천"))|(pl.col("content").str.contains("권장"))|(pl.col("content").str.contains("승인"))|(pl.col("content").str.contains("고려"))|
        (pl.col("content").str.contains("바랍니다"))|(pl.col("content").str.contains("의뢰주셔서"))|(pl.col("content").str.contains("의뢰 감사"))|(pl.col("content").str.contains("권고"))|
        (pl.col("content").str.contains("분 이내"))|(pl.col("content").str.contains("해당없음"))|(pl.col("content").str.contains("니다")) )
    )

    # Requests are grouped into per-nid lists here and exploded back to one row per
    # request right below, renamed to df2_rec_dt_offset / cst.
    df2 = df2.sort(["nid", "rec_dt_offset"])
    df2 = df2.group_by("nid").agg([
        pl.col("rec_dt_offset").alias("rec_dt_offset_list"),
        pl.col("content").alias("content_list")
    ])
    
    df2_long = (
        df2
        .select(["nid", "rec_dt_offset_list", "content_list"])
        .explode(["rec_dt_offset_list", "content_list"])
        .rename({"rec_dt_offset_list": "df2_rec_dt_offset",
                 "content_list":        "cst"})          
        .sort(["nid", "df2_rec_dt_offset"])
    )

    # As-of join within each nid: strategy="backward" attaches to every reply the request
    # with the largest df2_rec_dt_offset that is <= the reply's rec_dt_offset.
    # Replies without such a request get a null cst.
    final_df = (
        df1.sort(["nid", "rec_dt_offset"])               
           .join_asof(                                   
                df2_long,
                left_on="rec_dt_offset",
                right_on="df2_rec_dt_offset",
                by="nid",
                strategy="backward"                      
           )
           .select(["nid", "rec_dt_offset", "content", "cst"])
    )
    dfs.append(final_df)

# Stack all sheets, keep the first row per nid after sorting by rec_dt_offset,
# and drop replies for which no request was matched.
x1 = pl.concat(dfs)
x1 = x1.sort("rec_dt_offset").group_by("nid").first()
x1 = x1.filter(pl.col("cst").is_not_null())
def remove_im_to_baesang(text):
    """Delete every span running from ``IM`` to the nearest following ``배상``.

    Both markers are removed together with everything between them, and the
    span may cross line breaks. Text without a complete ``IM … 배상`` pair is
    returned unchanged.
    """
    # presumably the consultant's sign-off line — TODO confirm against the raw notes
    im_to_baesang = re.compile(r'IM((?:(?!배상).)*?)배상', flags=re.DOTALL)
    return im_to_baesang.sub('', text)
# Clean every reply with remove_im_to_baesang before it becomes a training answer.
x1 = x1.with_columns(
    pl.col('content').map_elements(remove_im_to_baesang).alias('content')
)


def _consult_prompt(message):
    """Wrap one attending-physician message in the fixed instruction template."""
    return (
        "@@@ Task: \n"
        "You are an infectious-disease consultant.  Compose ONE brief consultation note as the [@@@ Answer] that obeys **all** rules 1~4 below:\n"
        "1. Respond including 3 categories below:\n"
        "   • (1) suspected / confirmed pathogen (if any)\n"
        "   • (2) recommended medications (if any)\n"
        "   • (3) additional labs / cultures (if any)\n"
        "2. You must NOT repeat any phrase or idea in each sentence.\n"
        "3. Do not copy ANY text from [@@@ Attending physician’s message].\n"
        "4. Write in complete sentences and ALWAYS finish each sentence with a period.\n\n"
        "@@@ Attending physician’s message\n"
        f"{message}\n\n"
        "@@@ Answer:"
    )


# One {"question", "answer"} record per row: the request (cst) goes into the prompt,
# the cleaned reply (content) is the target answer.
qa_pairs = [
    {"question": _consult_prompt(record["cst"]), "answer": record["content"]}
    for record in tqdm(x1.iter_rows(named=True), total=x1.height, desc="Generating QA pairs")
]

with open("qa_pairs.json", "w", encoding="utf-8") as out_file:
    json.dump(qa_pairs, out_file, ensure_ascii=False, indent=2)

In [None]:
# Korean -> English translation checkpoint (NLLB-200).
# It gets its own constant so that this cell does not rebind MODEL_NAME, which other
# cells use for the causal LLM that is fine-tuned and evaluated.
NLLB_MODEL_NAME = "facebook/nllb-200-3.3B"
tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL_NAME, src_lang="kor_Hang")
model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL_NAME).eval()

# Run the translator on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def translate_ko2en(text: str, max_len: int = 512, num_beams: int = 6) -> str:
    """Translate one Korean sentence or paragraph into English.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` set up by
    this cell.

    Args:
        text: Korean source text. Input longer than ``max_len`` tokens is
            truncated by the tokenizer before translation.
        max_len: Token limit applied to both the encoded input and the
            generated output.
        num_beams: Beam width passed to ``generate``.

    Returns:
        The decoded English translation with special tokens removed and
        surrounding whitespace stripped.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_len,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Force English as the target language of the multilingual decoder.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=max_len,
            # Previously accepted by the signature but never forwarded, so decoding
            # silently ignored the requested beam width.
            num_beams=num_beams,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    
SRC_JSON = Path("qa_pairs.json")
DST_JSON = Path("qa_pairs_en.json")

# Post-edit for the translator output: the standalone word "investment" is rewritten to
# "administration" (presumably a recurring mistranslation of the Korean term for drug
# administration — TODO confirm on a sample of outputs).
# The former look-arounds `(?<=\S)` / `(?=\S)` required a NON-space character on both
# sides, so a normally spaced "investment" was never replaced; word boundaries match the
# whole word wherever it stands.
INVESTMENT_RE = re.compile(r'\binvestment\b', flags=re.IGNORECASE)

with SRC_JSON.open("r", encoding="utf-8") as f:
    data = json.load(f)

translated = []
# Only the "answer" field is translated; "question" is copied through unchanged.
for item in tqdm(data, desc="Translating 'answer' fields"):
    new_item = item.copy()  # keep the original item untouched
    if isinstance(item.get("answer"), str):
        translated_answer = translate_ko2en(item["answer"])
        translated_answer = INVESTMENT_RE.sub('administration', translated_answer)
        new_item["answer"] = translated_answer
    translated.append(new_item)

with DST_JSON.open("w", encoding="utf-8") as f:
    json.dump(translated, f, ensure_ascii=False, indent=2)
print(f"Saved translated data ➜ {DST_JSON.resolve()}")

In [None]:
# 1. Model + LoRA setup ------------------------------------------------------
# Pin the LLM id in this cell: other cells in the notebook load different checkpoints,
# so the value MODEL_NAME happens to hold when this cell runs cannot be relied on.
MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"
# Hub token is read from the environment instead of being written into the notebook;
# None makes the Hub client fall back to locally stored credentials.
HF_TOKEN = os.environ.get("HF_TOKEN")

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,            # <- real 4-bit weights (nf4)
    bnb_4bit_quant_type="nf4",    # Normal-Float-4 (best accuracy)
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_cfg,
    token=HF_TOKEN,
)
model = prepare_model_for_kbit_training(model)
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj","k_proj","v_proj","o_proj"],  # change to model’s names
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_cfg)
# device_map="auto" has already placed the quantized weights, so the model is NOT moved
# again with `.to(device)`: that is unnecessary and can undo the placement.
# `device` is still defined for code that wants to move plain tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.print_trainable_parameters()
if torch.cuda.is_available() and hasattr(torch, "compile"):
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# 2. Load and split the data --------------------------------------------------
with open("qa_pairs_en.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)
prompts, labels = [], []
for sample in raw_data:
    prompts.append(f"Question: {sample['question']}\n")
    labels.append(" " + str(sample["answer"]))
# 90/10 split; only the training part is used in this cell.
train_p, test_p, train_l, test_l = train_test_split(prompts, labels, test_size=0.1, random_state=42)
train_dataset = Dataset.from_dict({"prompt": train_p, "completion": train_l})
# 3. Tokenizer ----------------------------------------------------------------
# Loaded once, with the same `token` argument as the model (the older
# `use_auth_token` flag is deprecated).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 4. Tokenize (tip 2) --------------------------------------------------------
MAX_LEN = 256
def tokenize_fn(example):
    """Tokenize one prompt/completion pair for causal-LM fine-tuning.

    The prompt and completion are concatenated and tokenized together, with an
    end-of-sequence token appended so the model is trained to stop after the
    answer. Label positions that belong to the prompt are set to -100 so the
    loss only covers the completion.

    Args:
        example: Mapping with string fields ``"prompt"`` and ``"completion"``.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` (unpadded
        lists, truncated to the module-level ``MAX_LEN``). When truncation
        cuts inside the prompt, every label is -100.
    """
    # Without an explicit EOS the completion has no end marker to learn from.
    # NOTE(review): assumes the tokenizer does not append EOS on its own — confirm.
    eos = tokenizer.eos_token or ""
    full_text = example["prompt"] + example["completion"] + eos
    tok_full = tokenizer(full_text, truncation=True, max_length=MAX_LEN, padding=False)
    input_ids = tok_full["input_ids"]

    # Number of leading tokens that come from the prompt, measured by tokenizing the
    # prompt on its own with the same truncation settings.
    prompt_len = len(tokenizer(example["prompt"], truncation=True, max_length=MAX_LEN)["input_ids"])
    masked = min(prompt_len, len(input_ids))
    labels = [-100] * masked + list(input_ids[masked:])
    return {"input_ids": input_ids, "attention_mask": tok_full["attention_mask"], "labels": labels}
train_tok = train_dataset.map(tokenize_fn, remove_columns=["prompt", "completion"])
# Examples truncated inside the prompt have every label == -100. They carry no training
# signal and make the averaged loss undefined, so they are dropped before training.
n_before = len(train_tok)
train_tok = train_tok.filter(lambda ex: any(tok != -100 for tok in ex["labels"]))
print(f"Training examples with supervised tokens: {len(train_tok)} / {n_before}")
# Pads input_ids / attention_mask / labels to the longest sequence in each batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
# 5. Training arguments (tips 4 & 6) ----------------------------------------
# The 4-bit layers compute in bfloat16, so prefer bf16 autocast where the GPU supports it
# and fall back to fp16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
training_args = TrainingArguments(
    output_dir="lora_rag",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
    logging_steps=50,
    save_strategy="epoch",  # checkpoint once per epoch
    save_total_limit=2,
    lr_scheduler_type="cosine",
    report_to="none",
    label_names=[],
    remove_unused_columns=False,
)
# 6. Trainer -----------------------------------------------------------------
trainer = Trainer(model=model, args=training_args, train_dataset=train_tok, data_collator=data_collator)
trainer.train()
# Saves the trained model (the LoRA-wrapped `model`) under lora_rag_weight.
trainer.save_model("lora_rag_weight")

'''
Step	Training Loss
50	3.748400
100	1.965000
150	2.850100
200	1.830300
250	0.645100
'''

## RAG

In [None]:
# Extract the plain text of every PDF under rag/ and store one row per document.
rag_folder = "rag/"
pdf_files = glob.glob(os.path.join(rag_folder, "*.pdf"))
records = []

for pdf_path in tqdm(pdf_files, desc="Parsing PDFs"):
    try:
        reader = PdfReader(pdf_path)
        # Pages without extractable text contribute an empty string.
        text = "".join((page.extract_text() or "") for page in reader.pages).strip()
    except Exception as e:
        # Best effort: report the broken file and continue with the rest.
        print(f"Failed to parse {pdf_path}: {e}")
        continue
    if not text:
        continue
    records.append({
        "title": os.path.basename(pdf_path),
        "main_text": text
    })

df = pd.DataFrame(records, columns=["title","main_text"])
df.to_csv('review_articles.csv', index=False)

In [None]:
# Chunk every article and embed the chunks into a FAISS index.
# The chunking parameters used here (max_length=500, overlap=10) define which text each
# index id refers to; any code that maps ids back to text must re-chunk identically.
df = pd.read_csv('review_articles.csv')
all_chunks = []
for example in tqdm(df['main_text']):
    all_chunks.extend(chunk_text(example, max_length=500, overlap=10))
print(f"chunk total: {len(all_chunks)}")
if not all_chunks:
    raise ValueError("No text chunks were produced from review_articles.csv; nothing to index.")

# Use MedCPT as embedder.
# A dedicated constant is used so this cell does not rebind MODEL_NAME, which the
# evaluation cell passes to AutoModelForCausalLM.
# NOTE(review): this is the MedCPT *query* encoder; MedCPT also ships an article encoder —
# confirm which one is intended for document chunks.
MEDCPT_MODEL_NAME = "ncbi/MedCPT-Query-Encoder"
medcpt_tokenizer = AutoTokenizer.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model = AutoModel.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model.eval()
medcpt_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
medcpt_model = medcpt_model.to(medcpt_device)
BATCH_SIZE = 16

chunk_embeddings_list = []
for i in tqdm(range(0, len(all_chunks), BATCH_SIZE)):
    batch_chunks = all_chunks[i:i+BATCH_SIZE]
    with torch.no_grad():
        encoded = medcpt_tokenizer(
            batch_chunks,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=512
        )
        encoded = {k: v.to(medcpt_device) for k, v in encoded.items()}
        # The first-token ([:, 0, :]) hidden state is used as the chunk embedding.
        emb = medcpt_model(**encoded).last_hidden_state[:, 0, :]
        chunk_embeddings_list.append(emb.cpu().numpy().astype("float32"))
chunk_embeddings = np.concatenate(chunk_embeddings_list, axis=0)

# The embedding width is read from the data itself, so no dummy forward pass is needed.
embedding_dim = chunk_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(chunk_embeddings)

faiss.write_index(index, "faiss_index.index")

## Evaluation

In [None]:
torch.cuda.empty_cache()

index = faiss.read_index("faiss_index.index")
df = pd.read_csv('review_articles.csv')
# Re-create the chunk list that the FAISS ids refer to. The parameters MUST match the
# ones used when the index was built (max_length=500, overlap=10); a different overlap
# shifts the chunk boundaries, and every id returned by the index would then point at
# the wrong text.
all_chunks = []
for example in tqdm(df['main_text']):
    all_chunks.extend(chunk_text(example, max_length=500, overlap=10))
if index.ntotal != len(all_chunks):
    raise ValueError(
        f"FAISS index holds {index.ntotal} vectors but {len(all_chunks)} chunks were rebuilt; "
        "rebuild the index or fix the chunking parameters."
    )

MEDCPT_MODEL_NAME = "ncbi/MedCPT-Query-Encoder"
medcpt_tokenizer = AutoTokenizer.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model = AutoModel.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model.eval()
medcpt_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
medcpt_model = medcpt_model.to(medcpt_device)

# Pin the LLM id in this cell: other cells in the notebook load different checkpoints,
# so the value MODEL_NAME happens to hold when this cell runs cannot be relied on.
MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"
# Hub token is read from the environment instead of being written into the notebook;
# None makes the Hub client fall back to locally stored credentials.
HF_TOKEN = os.environ.get("HF_TOKEN")

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,            # <- real 4-bit weights (nf4)
    bnb_4bit_quant_type="nf4",    # Normal-Float-4 (best accuracy)
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_cfg,
    token=HF_TOKEN,
)
# Attach the adapter saved under lora_rag_weight to the quantized base model.
model = PeftModel.from_pretrained(model, "lora_rag_weight")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
# device_map="auto" has already placed the quantized weights, so the model is NOT moved
# again with `.to(device)`; `device` stays available for plain tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def medcpt_embed(texts):
    """Embed each string in ``texts`` with the module-level MedCPT encoder.

    Every text is encoded on its own (truncated to 512 tokens) and represented
    by the first-token hidden state of the model output.

    Returns:
        A float32 NumPy array with one row per input text, stacked in input order.
    """
    vectors = []
    with torch.no_grad():
        for text in texts:
            batch = medcpt_tokenizer(
                text,
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=512,
            )
            batch = {name: tensor.to(medcpt_device) for name, tensor in batch.items()}
            first_token_state = medcpt_model(**batch).last_hidden_state[:, 0, :]
            vectors.append(first_token_state.cpu().numpy().astype("float32")[0])
    return np.stack(vectors)

class AdvancedRetriever:
    """
    Two-stage passage retriever with optional diversification.

    • Stage-1  : dense search with MedCPT + FAISS
    • Stage-2  : Cross-Encoder re-ranking (point-wise)
    • Optional : MMR diversification
    """
    def __init__(
        self,
        faiss_index: faiss.Index,
        all_chunks: List[str],
        embed_fn,
        cross_encoder_name: str = "cross-encoder/ms-marco-MiniLM-L6-v2",
        device: str | None = None,
    ):
        """
        Args:
            faiss_index: Index whose ids are positions in ``all_chunks``.
            all_chunks: Chunk texts, in the same order they were added to the index.
            embed_fn: Callable mapping a list of strings to a 2-D array of embeddings.
            cross_encoder_name: Model id passed to ``CrossEncoder``.
            device: Device for the cross-encoder; defaults to CUDA when available.
        """
        self.index        = faiss_index
        self.all_chunks   = all_chunks
        self.embed_fn     = embed_fn
        self.cross_enc    = CrossEncoder(
            cross_encoder_name,
            device=device or ("cuda" if torch.cuda.is_available() else "cpu")
        )

    def retrieve(
        self,
        query: str,
        recall_k: int = 50,
        rerank_k: int = 8,
        use_mmr: bool = True,
        lambda_mmr: float = 0.6,
    ) -> Tuple[List[str], List[int]]:
        """
        Retrieve the best context chunks for ``query``.

        Args:
            query: Search text.
            recall_k: Number of neighbours requested from the FAISS index.
            rerank_k: Number of chunks kept after cross-encoder re-ranking.
            use_mmr: Re-order the kept chunks with Maximal Marginal Relevance.
            lambda_mmr: MMR trade-off (higher favours relevance over diversity).

        Returns:
            texts   – final context strings
            indices – their original indices in `all_chunks`
            Both lists are empty when the index yields no usable candidate.
        """
        # 1️⃣ dense recall
        q_emb = self.embed_fn([query]).astype("float32")
        _, I  = self.index.search(q_emb, recall_k)
        # FAISS fills missing neighbours with the id -1 when the index holds fewer than
        # recall_k vectors; indexing all_chunks with -1 would silently return the LAST
        # chunk, so those placeholders are dropped.
        cand_idx   = [i for i in I[0].tolist() if i >= 0]
        if not cand_idx:
            return [], []
        cand_texts = [self.all_chunks[i] for i in cand_idx]

        # 2️⃣ cross-encoder re-ranking: score every (query, chunk) pair, keep the best rerank_k
        pairs       = [(query, text) for text in cand_texts]
        ce_scores   = self.cross_enc.predict(pairs, convert_to_numpy=True)
        ranked      = sorted(
            zip(cand_idx, cand_texts, ce_scores),
            key=lambda x: x[2],
            reverse=True
        )[: rerank_k]
        if not ranked:  # e.g. rerank_k <= 0
            return [], []
        idxs, texts, _ = zip(*ranked)

        # 3️⃣ optional MMR re-ordering of the kept chunks
        if use_mmr and len(texts) > 1:
            embs      = self.embed_fn(list(texts))
            order     = self._mmr(embs, q_emb[0], k=rerank_k, lamb=lambda_mmr)
            idxs      = [idxs[i]  for i in order]
            texts     = [texts[i] for i in order]
        return list(texts), list(idxs)

    @staticmethod
    def _mmr(doc_embs: np.ndarray, query_emb: np.ndarray,
             k: int = 20, lamb: float = 0.6) -> List[int]: # Selecting Top-K
        """
        Maximal Marginal Relevance.

        Greedily picks up to ``k`` rows of ``doc_embs``. The first pick is the
        row most similar to the query; each later pick maximises
        ``lamb * relevance - (1 - lamb) * max similarity to already picked rows``.

        Returns:
            Row indices into ``doc_embs`` in selection order.
        """
        selected, cand = [], list(range(len(doc_embs)))

        # relevance of each doc to the query
        rel = skl_cosine(doc_embs, query_emb.reshape(1, -1)).flatten()

        while len(selected) < min(k, len(doc_embs)):
            if not selected:
                sel = int(np.argmax(rel))
            else:
                # highest similarity of each remaining doc to anything already selected
                redund = skl_cosine(
                    doc_embs[cand],
                    doc_embs[selected]
                ).max(axis=1)
                mmr = lamb * rel[cand] - (1 - lamb) * redund
                sel = cand[int(np.argmax(mmr))]
            selected.append(sel)
            cand.remove(sel)
        return selected

class RAGGenerator:
    """Retrieval-augmented answer generator: retrieve evidence, prompt the LLM, return its answer."""

    def __init__(self, retriever: AdvancedRetriever, llm, tokenizer):
        """
        Args:
            retriever: Object whose ``retrieve(query, recall_k, rerank_k)`` returns
                ``(texts, indices)``.
            llm: Causal language model exposing ``generate`` and ``device``.
            tokenizer: Tokenizer matching ``llm``; its pad token is set to EOS when missing.
        """
        self.ret   = retriever
        self.llm   = llm
        self.tok   = tokenizer
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    @torch.inference_mode()
    def __call__(
        self,
        query: str,
        recall_k: int = 50,
        rerank_k: int = 8,
        max_new_tokens: int = 128,
        top_p: float = 0.4,
        temperature: float = 0.1,
    ) -> str:
        """
        Generate an answer to ``query`` grounded in retrieved evidence.

        Args:
            query: Question text, inserted into the prompt after ``"Question: "``.
            recall_k: Candidates fetched by the dense retriever.
            rerank_k: Evidence chunks placed in the prompt.
            max_new_tokens: Generation budget.
            top_p: Nucleus-sampling threshold.
            temperature: Sampling temperature.

        Returns:
            The generated continuation only (the prompt is not included), stripped.
        """
        ctx_texts, _ = self.ret.retrieve(query, recall_k, rerank_k)
        evidence = "\n\n".join(f"[Doc {i+1}]\n{txt}" for i, txt in enumerate(ctx_texts))
        prompt   = (
            "@@@ Evidence\n"
            f"{evidence}\n\n"
            "@@@ Instruction: Base your [@@@ Answer] ONLY on the evidence above. "
            "Do NOT copy phrases verbatim and keep it under 200 characters.\n\n"
            f"Question: {query}\n@@@ Answer"
        )
        inputs = self.tok(
            prompt, return_tensors="pt", truncation=True, max_length=4096
        ).to(self.llm.device)
        output = self.llm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tok.eos_token_id,
            pad_token_id=self.tok.eos_token_id,
            no_repeat_ngram_size=3,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.15,
            do_sample=True,
        )
        # Decode only the tokens produced after the prompt. Locating the answer by
        # splitting the full decoded text on "@@@ Answer" is fragile: the marker can
        # also occur inside `query`, and the trailing marker is lost when the prompt
        # is truncated at max_length.
        prompt_len = inputs["input_ids"].shape[1]
        answer = self.tok.decode(output[0][prompt_len:], skip_special_tokens=True)
        # If the model repeats the marker itself, keep the text after its last occurrence.
        return answer.split("@@@ Answer")[-1].strip()

# Wire the pipeline: FAISS index + chunk list + MedCPT embedder -> retriever,
# then retriever + LLM + tokenizer -> generator.
adv_retriever = AdvancedRetriever(faiss_index=index, all_chunks=all_chunks, embed_fn=medcpt_embed)

rag_gen = RAGGenerator(retriever=adv_retriever, llm=model, tokenizer=tokenizer)

In [None]:
import ast

torch.cuda.empty_cache()

with open("qa_pairs_en.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)
# The raw question is used as the query: RAGGenerator adds its own "Question: " prefix,
# so prefixing here as well produced "Question: Question: ..." in the final prompt.
prompts = [sample["question"] for sample in raw_data]
labels = [" " + str(sample["answer"]) for sample in raw_data]
# Same test_size and random_state as the split used for fine-tuning, so the 10 % evaluated
# here is the part that was held out from training. (A different seed selects a different
# subset, most of which the model has already been trained on.)
train_p, test_p, train_l, test_l = train_test_split(prompts, labels, test_size=0.1, random_state=42)
train_dataset = pd.DataFrame({"prompt": train_p, "completion": train_l})


def _reference_text(ref):
    """Return the reference answer for one label string.

    If the label is the literal of a non-empty dict, the value of its last key
    is used; otherwise the stripped label itself is the reference.
    """
    text = ref.strip()
    try:
        # SECURITY: labels come from a data file, so they are parsed as a literal only.
        # The former eval() would have executed any expression found in the text.
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return text
    if isinstance(parsed, dict) and parsed:
        return str(list(parsed.values())[-1])
    return text


gen_ans, ref_ans, qs = [], [], []
for p, ref in tqdm(zip(test_p, test_l), total=len(test_p), desc="Testing (RAG)"):
    gen_ans.append(rag_gen(p))
    ref_ans.append(_reference_text(ref))
    qs.append(p)

def cut_to_last_period(text):
    """Trim ``text`` so that it ends at its last ``'.'``.

    Everything after the final period is dropped and the result is stripped.
    Text without any period is returned stripped but otherwise unchanged.
    """
    head, dot, _ = text.rpartition('.')
    if not dot:
        return text.strip()
    return (head + dot).strip()

# Drop any unfinished trailing sentence from each generated answer before scoring.
gen_ans = list(map(cut_to_last_period, gen_ans))

# Row-wise cosine similarity between generated and reference embeddings.
gen_embs = medcpt_embed(gen_ans)
ref_embs = medcpt_embed(ref_ans)
sim = (gen_embs * ref_embs).sum(axis=1) / (
    np.linalg.norm(gen_embs, axis=1) * np.linalg.norm(ref_embs, axis=1)
)

# Corpus BLEU on whitespace tokens, one reference per hypothesis.
bleu_refs = [[reference.split()] for reference in ref_ans]
bleu_hyps = [hypothesis.split() for hypothesis in gen_ans]
bleu_score = corpus_bleu(bleu_refs, bleu_hyps, smoothing_function=SmoothingFunction().method1)

# Mean ROUGE-L F-measure over all pairs (0.0 when there are none).
rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
rouge_l_f1s = [
    rouge.score(reference, hypothesis)['rougeL'].fmeasure
    for reference, hypothesis in zip(ref_ans, gen_ans)
]
rouge_l_f1 = sum(rouge_l_f1s) / len(rouge_l_f1s) if rouge_l_f1s else 0.0

print(f"\nTest set mean cosine similarity: {np.mean(sim):.4f}")
print(f"Test set BLEU: {bleu_score:.4f}")
print(f"Test set ROUGE-LCS F1: {rouge_l_f1:.4f}")

# Show the first three examples.
for shown_q, shown_ref, shown_gen, shown_sim in zip(qs[:3], ref_ans, gen_ans, sim):
    print(f"\nQ: {shown_q}\nReference: {shown_ref}\nGenerated: {shown_gen}\nCosine similarity: {shown_sim:.4f}")

'''
<Example >
Reference: (1) Antibiotics are recommended to be maintained as they are, and adjusted according to the results of culture. (2) Esophageal candidasis is also suspected, and fluconazole dosage is recommended to be changed to 400 mg on day 1, then 200 to 400 mg once daily as a normal new functional baseline. Currently, after application of PTGBD, a reduction in inflammatory markers and a fever spike is not observed, but a suppressed RUQ tenderness is observed.
Generated: : #. There is no specific recommendation for empirical antibiotic therapy because there is limited information about the previous medical condition and current status of the patient who has already recovered from the asymptomatic stage of the past infection until now, but only empirically administered vancomycin and meropenem were given up to the time of confirmation of the BAL culture result, so please continue administration of vanco + mero according to the doctor's discretion until the recovery of the lung function is completed. If you want to change the medication, we recommend confirming the susceptibility test of the cultured bacteria before changing the drug.
'''