In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
import os
from transformers import AutoTokenizer, AutoModel, BartForConditionalGeneration, BartTokenizer
import numpy as np
import torch.nn as nn
import pytorch_lightning as pl
import HQS_test_biobart
import argparse
import time
from rank_bm25 import BM25Okapi
import rouge155
import json
import rag_retrieval
from HQS_RAG_train_biobart_end2end_real import BARTTuner
# Number of CUDA devices PyTorch reports for this process.
num_gpus = torch.cuda.device_count()

INFO:nlp.utils.file_utils:PyTorch version 1.10.0+cu111 available.
INFO:faiss.loader:Loading faiss with AVX2 support.
INFO:faiss.loader:Successfully loaded faiss with AVX2 support.
INFO:datasets:PyTorch version 1.10.0+cu111 available.


CUDA_VISIBLE_DEVICES: 0
CUDA is available with 1 GPU(s)!


In [2]:
# Checkpoint handed to BARTTuner_e2e.load_from_checkpoint further down; also used to derive output_path.
checkpoint_path = "biobartcheckpoint"
# Test split from the project-local loader; items are later read via
# item["question"] and item["retrieval"][0]["doc"] in infer_dataset.
test_list = HQS_test_biobart.getDatalist("test_dataset")

In [None]:
# Device shared by every model and tensor in this notebook. It was referenced
# below (and in the helper functions) without ever being defined, so it is
# created here, before its first use. Falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run configuration. Within this notebook only max_input_length,
# max_output_length, learning_rate, adam_epsilon and weight_decay are read;
# the remaining keys presumably mirror the training script — TODO confirm.
args_dict = dict(
    output_dir="...",
    model_name_or_path='GanjinZero/biobart-large',
    tokenizer_name_or_path='GanjinZero/biobart-large',
    max_input_length=1024,
    max_output_length=50,
    freeze_encoder=False,
    freeze_embeds=False,
    num_train_epochs=20,
    eval_batch_size=8,
    learning_rate=0.00006,
    weight_decay=0.0,
    adam_epsilon=1e-7,
    warmup_steps=0,
    train_batch_size=2,
    gradient_accumulation_steps=16,
    n_gpu=-1,
    resume_from_checkpoint=None,
    val_check_interval=0.05,
    n_val=1000,
    n_train=-1,
    n_test=-1,
    early_stop_callback=False,
    fp_16=True,
    opt_level='O1',
    max_grad_norm=1.0,
    seed=42,
    tau=1.0,
    lambda_CL=1.0,
    lambda_medical=0.0021,
    lambda_negation=0.0021,
)
args = argparse.Namespace(**args_dict)

# Dense retriever (BGE-M3), placed on `device` straight away.
bge_model_name = "BAAI/bge-m3"
bge_tokenizer = AutoTokenizer.from_pretrained(bge_model_name)
bge_model = AutoModel.from_pretrained(bge_model_name).to(device)

# Generator backbone (BioBART). It is not moved to `device` here; that happens
# when the Lightning wrapper holding it is moved after checkpoint loading.
bart_model_name = "GanjinZero/biobart-large"
bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name)
bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)

class BARTTuner_e2e(pl.LightningModule):
    """Lightning wrapper around a BART generator fed with query + document embeddings.

    For every (query, document) pair the query's token embeddings are
    concatenated with the encoder hidden states of the document, padded to
    ``max_input_length`` and passed to the generator through ``inputs_embeds``.

    The constructor signature and attribute names are kept as they are because
    ``load_from_checkpoint`` is called with these keyword arguments.
    """

    def __init__(self, batchsize, model, tokenizer, max_input_length, top_n, kb, p_tokenizer, pass_encoder):
        super(BARTTuner_e2e, self).__init__()
        self.batch_size = batchsize
        self.model = model
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.top_n = top_n
        self.knowledge_base = kb
        self.passage_tokenizer = p_tokenizer
        self.passage_encoder = pass_encoder
        # BM25 index, built lazily by bm25_retrieve on first use.
        self.bm25 = None
        # Cache for bm25_retrieve: query string -> (indices, scores).
        self.retrieval = {}

    def bm25_retrieve_batch(self, queries, knowledge_base, top_n):
        """Rank ``knowledge_base`` against each query with BM25.

        Args:
            queries: sequence of query strings (whitespace-tokenised here).
            knowledge_base: sequence of document strings.
            top_n: number of best-scoring documents to keep per query.

        Returns:
            ``(indices, scores)``: two lists with one entry per query, each
            entry holding the top-``top_n`` document indices / BM25 scores in
            descending score order.
        """
        bm25 = BM25Okapi([doc.split() for doc in knowledge_base])
        top_n_indices_list = []
        top_n_scores_list = []
        # BM25 scoring is per query, so the former chunking by self.batch_size
        # added nothing except a division by zero when batch_size was 0.
        for query in queries:
            doc_scores = bm25.get_scores(query.split())
            top_n_indices = np.argsort(doc_scores)[::-1][:top_n]
            top_n_indices_list.append(top_n_indices.tolist())
            top_n_scores_list.append([doc_scores[i] for i in top_n_indices])
        return top_n_indices_list, top_n_scores_list

    def get_embeddings(self, text, model, mode):
        """Encode ``text`` with ``model`` and return a squeezed embedding tensor.

        Args:
            text: string or list of strings, tokenised with ``self.tokenizer``.
            model: module called as ``model(**inputs)``.
            mode: ``"mean"`` (mean of ``last_hidden_state`` over dim 1),
                ``"cls"`` (first position of ``last_hidden_state``) or
                ``"wopool"`` (un-pooled ``encoder_last_hidden_state``).

        Returns:
            The selected tensor after ``.squeeze()``.

        Raises:
            ValueError: if ``mode`` is not one of the three supported values.
        """
        if mode not in ("mean", "cls", "wopool"):
            # Checked before the forward pass so a typo does not cost a model call.
            raise ValueError("Invalid mode. Choose between 'mean' and 'cls' and 'wopool'.")

        # Follow the model's own device rather than a notebook-level global.
        model_device = next(model.parameters()).device
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model_device)
        outputs = model(**inputs)
        if mode == "mean":
            return outputs.last_hidden_state.mean(dim=1).squeeze()
        if mode == "cls":
            return outputs.last_hidden_state[:, 0, :].squeeze()
        return outputs.encoder_last_hidden_state.squeeze()

    def pad_to_length(self, tensor, max_length):
        """Zero-pad ``tensor`` at the end of its second-to-last dimension.

        The pad amount is ``max_length - tensor.shape[1]``. When the input is
        already longer, that amount is negative and is passed to
        ``torch.nn.functional.pad`` unchanged (which is expected to trim the
        tail — TODO confirm for the installed torch version).
        """
        pad_size = max_length - tensor.shape[1]
        return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size))

    def bm25_retrieve(self, query, knowledge_base, top_n):
        """Return the BM25 top-``top_n`` ``(indices, scores)`` for one query.

        Results are cached per query string only, so a repeated query returns
        the cached result even if ``knowledge_base`` or ``top_n`` differ. The
        index is built from the first ``knowledge_base`` seen, unless
        ``self.bm25`` was assigned from outside beforehand.
        """
        if query in self.retrieval:
            return self.retrieval[query]
        if self.bm25 is None:
            # Previously nothing ever built the index, so this call failed on None.
            self.bm25 = BM25Okapi([doc.split() for doc in knowledge_base])
        doc_scores = self.bm25.get_scores(query.split())
        top_n_indices = np.argsort(doc_scores)[::-1][:top_n]
        top_n_scores = [doc_scores[i] for i in top_n_indices]
        top_n_indices_list = top_n_indices.tolist()
        self.retrieval[query] = top_n_indices_list, top_n_scores
        return top_n_indices_list, top_n_scores

    def forward(self, input_embeds, attention_mask=None,
                decoder_attention_mask=None,
                lm_labels=None):
        """Run the wrapped generator on pre-computed input embeddings."""
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels
        )

    def _build_inputs(self, query_batch, doc_batch, report_overlength=False):
        """Build padded ``inputs_embeds`` and the matching attention mask.

        Each query's token embeddings are concatenated with the encoder states
        of its document and padded to ``self.max_input_length``.

        Returns:
            ``(input_embeds_batch, mask_batch)``, both on ``self.model.device``.
            The mask is 1 on real positions and 0 on padding.
        """
        input_embeds_list = []
        mask_list = []
        for q, doc in zip(query_batch, doc_batch):
            input_ids = self.tokenizer(q, return_tensors="pt", padding=True, truncation=True).to(self.model.device)
            input_embeds = self.model.model.shared(input_ids.input_ids)
            knowledge_embeds = torch.stack([self.get_embeddings(doc, self.model, 'wopool')], dim=0)
            knowledge_embeds = knowledge_embeds.expand(input_embeds.size(0), -1, -1)
            input_embeds = torch.cat([input_embeds, knowledge_embeds], dim=1)
            concatenate_length = input_embeds.shape[1]
            if report_overlength and concatenate_length > self.max_input_length:
                print("current_length: ", concatenate_length)
            input_embeds = self.pad_to_length(input_embeds, self.max_input_length)
            input_embeds_list.append(input_embeds.squeeze(0))
            concat_attention_mask = torch.zeros(self.max_input_length, dtype=torch.long)
            concat_attention_mask[:concatenate_length] = 1
            mask_list.append(concat_attention_mask)
        input_embeds_batch = torch.stack(input_embeds_list).to(self.model.device)
        mask_batch = torch.stack(mask_list).to(self.model.device)
        return input_embeds_batch, mask_batch

    def training_step(self, batch, batch_idx):
        """Compute the generator loss for one batch.

        Reads ``batch['query']``, ``batch['doc']``, ``batch['labels']`` and
        ``batch['decoder_attention_mask']``.
        """
        input_embeds_batch, mask_batch = self._build_inputs(
            batch['query'], batch['doc'], report_overlength=True
        )
        outputs = self.forward(
            input_embeds=input_embeds_batch,
            # The mask used to be built and then dropped, so the encoder
            # attended to the zero padding as if it were real input.
            attention_mask=mask_batch,
            decoder_attention_mask=batch['decoder_attention_mask'],
            lm_labels=batch['labels']
        )
        loss = outputs[0]
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Generate summaries for one batch and score them with ROUGE-L F.

        Reads ``batch['query']``, ``batch['doc']`` and ``batch['target']``.
        """
        input_embeds_batch, mask_batch = self._build_inputs(batch['query'], batch['doc'])
        outs = self.model.generate(
            inputs_embeds=input_embeds_batch,
            attention_mask=mask_batch,  # keep padding out of encoder attention
            max_new_tokens=args.max_output_length,
            use_cache=True,
            num_beams=5,
            repetition_penalty=1,
            length_penalty=2,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in outs
        ]
        scores = rouge155.rouge_eval(batch["target"], preds)
        val_rouge = scores['rouge_l_f_score']
        tensorboard_logs = {"val_rouge": val_rouge}
        return {"val_rouge": val_rouge, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``val_rouge`` values of one validation epoch."""
        avg_val_rouge = sum(x["val_rouge"] for x in outputs) / len(outputs)
        tensorboard_logs = {'avg_val_rouge': avg_val_rouge}
        return {'avg_val_rouge': avg_val_rouge, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

    def train_dataloader(self):
        """DataLoader over the module-level ``train_dataset``.

        NOTE(review): ``train_dataset`` is not defined anywhere in this
        notebook; it must exist at module level before training is started.
        """
        # Imported locally because the notebook never imports DataLoader.
        from torch.utils.data import DataLoader
        return DataLoader(train_dataset, batch_size=self.batch_size,
                          num_workers=4)

    def val_dataloader(self):
        """DataLoader over the module-level ``validation_dataset``.

        NOTE(review): ``validation_dataset`` is not defined anywhere in this
        notebook; it must exist at module level before validation is run.
        """
        from torch.utils.data import DataLoader
        return DataLoader(validation_dataset,
                          batch_size=self.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """AdamW over all parameters, configured from the module-level ``args``."""
        # Imported locally because the notebook never imports AdamW.
        from torch.optim import AdamW
        # weight_decay is passed explicitly so the configured value (0.0) is
        # used instead of torch's non-zero default.
        return AdamW(
            self.parameters(),
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            weight_decay=args.weight_decay,
        )

# Passage tokenizer / encoder instances required by the BARTTuner_e2e constructor.
p_tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-large")
p_encoder = AutoModel.from_pretrained("GanjinZero/biobart-large")

# Rebuild the Lightning module from the checkpoint. The keyword arguments are
# forwarded to BARTTuner_e2e.__init__; the knowledge base is a one-element
# placeholder list.
model = BARTTuner_e2e.load_from_checkpoint(
    checkpoint_path,
    batchsize=4,
    model=bart_model,
    tokenizer=bart_tokenizer,
    max_input_length=1024,
    top_n=5,
    kb=[""],
    p_tokenizer=p_tokenizer,
    pass_encoder=p_encoder,
)

passage_tokenizer = model.passage_tokenizer
passage_encoder = model.passage_encoder.to(device)

# Evaluation mode, then move the whole module to the working device.
model.eval()
model.to(device)

# Filled by bge_retrieve / bm25_retrieve with one record per retrieval call.
RESULTS = []
# Output file name: fixed prefix + checkpoint base name (without ".ckpt") + ".txt".
output_path = "".join((
    "checkpointsavepath",
    checkpoint_path.split(".ckpt")[0].split("/")[-1],
    ".txt",
))
def get_embeddings(text, tokenizer, model):
    """Return the mean of ``model``'s last hidden state for ``text``.

    The text is tokenised with padding and truncation, moved to the
    module-level ``device`` and encoded without gradient tracking. The mean is
    taken over dim 1 across all positions (padding included) and the result is
    squeezed.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    pooled = hidden_states.mean(dim=1)
    return pooled.squeeze()

def get_embeddings_cls(text, tokenizer, model):
    """Return the first-position vector of ``model``'s last hidden state.

    The text is tokenised with padding and truncation, moved to the
    module-level ``device`` and encoded without gradient tracking; position 0
    of every sequence is selected and the result is squeezed.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    first_token = hidden_states[:, 0, :]
    return first_token.squeeze()

def bge_retrieve(top_n, input_text, db, db_embeds=None):
    """Rank ``db`` against ``input_text`` by cosine similarity of BGE embeddings.

    Args:
        top_n: number of best matches to return.
        input_text: query string.
        db: sequence of document strings.
        db_embeds: optional pre-computed document embedding matrix, one row
            per entry of ``db``. When ``None`` (the default) every document is
            embedded on the fly.

    Returns:
        ``(indices, scores)``: positions in ``db`` of the ``top_n`` most
        similar documents and their cosine similarities, best first.

    Side effects:
        Appends ``{"idxs": indices, "scores": scores}`` to the module-level
        ``RESULTS`` list.
    """
    # `is None` instead of `== None`: the latter goes through the tensor's
    # comparison operator when embeddings are supplied.
    if db_embeds is None:
        knowledge_embeddings = torch.stack(
            [get_embeddings(doc, bge_tokenizer, bge_model) for doc in db]
        ).to(device)
    else:
        knowledge_embeddings = db_embeds
    instruction_embedding = get_embeddings(input_text, bge_tokenizer, bge_model).to(device)

    # Cosine similarity between every document row and the query vector.
    similarities = torch.matmul(knowledge_embeddings, instruction_embedding) / (
        torch.norm(knowledge_embeddings, dim=1) * torch.norm(instruction_embedding)
    )
    top_n_indices = torch.argsort(similarities, descending=True)[:top_n]
    top_n_indices_list = top_n_indices.cpu().numpy().tolist()
    top_n_scores = [similarities[i].item() for i in top_n_indices_list]

    RESULTS.append({"idxs": top_n_indices_list, "scores": top_n_scores})
    return top_n_indices_list, top_n_scores

def bm25_retrieve(query, knowledge_base, top_n):
    """Rank ``knowledge_base`` against ``query`` with BM25 (whitespace tokens).

    A new BM25 index is built on every call.

    Returns:
        ``(indices, scores)``: positions in ``knowledge_base`` of the
        ``top_n`` best-scoring documents and their BM25 scores, best first.

    Side effects:
        Appends ``{"idxs": indices, "scores": scores}`` to the module-level
        ``RESULTS`` list.
    """
    index = BM25Okapi([passage.split() for passage in knowledge_base])
    doc_scores = index.get_scores(query.split())
    ranked = np.argsort(doc_scores)[::-1][:top_n]
    top_n_scores = [doc_scores[pos] for pos in ranked]
    top_n_indices_list = ranked.tolist()
    RESULTS.append({"idxs": top_n_indices_list, "scores": top_n_scores})
    return top_n_indices_list, top_n_scores

def retrieve_definition(query, knowledge_base):
    """Look up the ``"q_definition"`` of the first entry whose question equals ``query``.

    Returns a single space (``" "``) when no entry matches, or when the
    matching entry has no ``"q_definition"`` key.
    """
    first_match = next(
        (entry for entry in knowledge_base if entry["question"] == query),
        None,
    )
    if first_match is None:
        return " "
    return first_match.get("q_definition", " ")

def bge_reranker(query, knowledge_base, top_re, top_n):
    """Two-stage retrieval: BM25 shortlist, then BGE re-ranking.

    Args:
        query: query string.
        knowledge_base: sequence of document strings.
        top_re: size of the BM25 shortlist.
        top_n: number of documents kept after re-ranking.

    Returns:
        ``(kb_rerank, indices, scores)``, where ``kb_rerank`` is the BM25
        shortlist and ``indices`` are positions *within* ``kb_rerank`` (not
        within ``knowledge_base``), best first.

    Side effects:
        Both retrieval stages append a record to the module-level ``RESULTS``.
    """
    top_re_indices_list, _ = bm25_retrieve(query, knowledge_base, top_re)
    kb_rerank = [knowledge_base[idx] for idx in top_re_indices_list]
    # bge_retrieve takes four arguments; db_embeds was missing from this call.
    # None makes it embed the shortlist on the fly.
    top_n_indices_list, top_n_scores = bge_retrieve(top_n, query, kb_rerank, None)
    return kb_rerank, top_n_indices_list, top_n_scores
# Counts inference calls whose document contributed fewer than 4 encoder positions.
NULL_COUNT=0
def inference(model, top_n, input_text, doc, retriever):
    """Generate one summary for ``input_text`` conditioned on ``doc``.

    The query's token embeddings are concatenated with the encoder states of
    ``doc``, cut to ``args.max_input_length`` positions and passed to the
    generator via ``inputs_embeds``.

    Args:
        model: loaded ``BARTTuner_e2e`` instance.
        top_n: accepted for call compatibility; not used here.
        input_text: the question to summarise.
        doc: retrieved document text.
        retriever: accepted for call compatibility; not used here.

    Returns:
        The decoded text of the first generated sequence.
    """
    # Without this declaration the `+=` below made NULL_COUNT a local and
    # raised UnboundLocalError the first time the branch was taken.
    global NULL_COUNT

    # Pure inference: do not keep an autograd graph for the embedding passes.
    with torch.no_grad():
        bart_inputs = bart_tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        bart_inputs = bart_inputs.to(device)
        input_embeds = bart_model.model.shared(bart_inputs.input_ids)

        doc_states = model.get_embeddings([doc], model.model, 'wopool')
        knowledge_embeds = torch.stack([doc_states], dim=0)
        knowledge_embeds = knowledge_embeds.expand(input_embeds.size(0), -1, -1)
        if knowledge_embeds.shape[1] < 4:
            NULL_COUNT += 1

        input_embeds = torch.cat([input_embeds, knowledge_embeds], dim=1)
        if input_embeds.shape[1] > args.max_input_length:
            input_embeds = input_embeds[:, :args.max_input_length, :]
        input_embeds = input_embeds.to(device)

        # No padding is added here, so every position is attended to.
        concat_attention_mask = torch.ones(input_embeds.shape[1], dtype=torch.long)
        concat_attention_mask = concat_attention_mask.unsqueeze(dim=0).to(device)

        outs = model.model.generate(
            inputs_embeds=input_embeds,
            attention_mask=concat_attention_mask,
            max_new_tokens=args.max_output_length,
            use_cache=True,
            num_beams=5,
            repetition_penalty=1,
            length_penalty=2,
            early_stopping=True,
        )

    sentence = bart_tokenizer.decode(
        [int(token_id) for token_id in outs[0]], skip_special_tokens=True
    )
    return sentence

def infer_dataset(dataset, model, top_n, db, retriever, des):
    """Run ``inference`` on every item of ``dataset`` and write the results to ``des``.

    Each item supplies ``item["question"]`` as the input text and
    ``item["retrieval"][0]["doc"]`` as the document. Newlines are removed from
    every generated sentence, and the sentences are written one per line after
    all items have been processed. ``db`` is accepted but not used.
    """
    summaries = [
        inference(
            model,
            top_n,
            sample["question"],
            sample["retrieval"][0]["doc"],
            retriever,
        ).replace("\n", "")
        for sample in dataset
    ]
    with open(des, "w") as out_file:
        out_file.writelines(summary + "\n" for summary in summaries)

def read_dict(path):
    """Load the JSON document stored at ``path`` and return the parsed object.

    The file is read as UTF-8, matching ``load_datalist`` in this notebook,
    instead of relying on the platform's default encoding.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)

# --- Experiment selection ---
dataset = "yahoo"
retriever_list = ["bm25", "bge_rerank", "bge"]
retriever = retriever_list[0]  # first entry of the list above

# Reference-summary file per dataset, used for ROUGE scoring.
targets_dict = {"yahoo": ".../test.target"}
targets_path = targets_dict[dataset]

# Training split; its questions form the retrieval knowledge base.
# NOTE(review): the path string has no "{}" placeholder, so .format(dataset)
# leaves it unchanged.
train_list = HQS_test_biobart.getDatalist("train_dataset".format(dataset))
knowledge_base_v1 = [sample["question"] for sample in train_list]

def load_datalist(file_path):
    """Read a JSON Lines file (UTF-8) and return the parsed records as a list.

    Lines that are empty or contain only whitespace are skipped, so a trailing
    blank line no longer raises ``json.JSONDecodeError``.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            data.append(json.loads(stripped))
    return data

# Dataset wrapper built from test_list by the project-local MyDataset class
# (presumably tokenised with bart_tokenizer — TODO confirm). The inference loop
# in this notebook iterates over test_list directly and does not read this object.
test_dataset = HQS_test_biobart.MyDataset(test_list, bart_tokenizer, args.max_input_length, args.max_output_length)
# When True, ROUGE is computed on the decoded file after inference.
do_eva = True

def read_dict(path):
    """Load the JSON document stored at ``path`` and return the parsed object.

    NOTE(review): this repeats the ``read_dict`` defined earlier in the
    notebook; being the later definition, it is the one that stays bound.
    The file is read as UTF-8, matching ``load_datalist``.
    """
    with open(path, 'r', encoding='utf-8') as file:
        dictionary = json.load(file)
    return dictionary

# Smoke test on the first training question: BM25 top-5 over the training
# questions, then mean-pooled BART embeddings of the retrieved entries.
# Note that bm25_retrieve also appends a record to RESULTS.
query = train_list[0]["question"]
top_n_indices_list, top_n_scores = bm25_retrieve(query, knowledge_base_v1, 5)
bart_input_text = query
bart_inputs = bart_tokenizer(bart_input_text, return_tensors="pt", padding=True, truncation=True)
bart_inputs = bart_inputs.to(device)
input_embeds = bart_model.model.shared(bart_inputs.input_ids)
db = knowledge_base_v1
# The list was left unclosed and without a loop, which made the whole cell a
# SyntaxError. Iterating over the indices retrieved above is the only
# definition of `idx` the surrounding code supports.
knowledge_embeds_list = [
    get_embeddings(db[idx], bart_tokenizer, bart_model.model)
    for idx in top_n_indices_list
]
knowledge_embeds_for_bart = torch.stack(knowledge_embeds_list, dim=0)

In [None]:
# Reset the short-document counter maintained by inference().
NULL_COUNT=0
for i in [0]:
    # Fresh retrieval log for this run.
    RESULTS=[]
    top_n=i+1
    # NOTE(review): the path string has no "{}" placeholders, so .format()
    # leaves it unchanged and every run writes to the same file.
    decodes_path = "savepath".format(dataset, top_n, retriever)
    start_time = time.time()
    # `kb_test` was never defined in this notebook (NameError). infer_dataset
    # does not read its `db` argument, so the only knowledge base that exists
    # here is passed instead. TODO confirm the intended test knowledge base.
    infer_dataset(test_list, model, top_n, knowledge_base_v1, retriever, decodes_path)
    end_time = time.time()
    elapsed_time = end_time - start_time
    if do_eva:
        print(rouge155.calculate_rouge155_md(targets_path, decodes_path))
    # A separate index name keeps the outer loop variable `i` intact.
    for result_idx, item in enumerate(RESULTS):
        item['index'] = result_idx
    with open('save', 'w', encoding='utf-8') as f:
        json.dump(RESULTS, f, ensure_ascii=False, indent=4)