In [1]:
import os
import random

# CUDA_* variables only take effect if set before CUDA is initialised, so they
# must stay above `import torch`; uncomment to pin specific GPUs.
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1, 2"

import torch

# Seed the RNGs used by this notebook so runs are reproducible: `random` drives
# the few-shot example sampling in Prompt.get_prompt, torch covers any
# stochastic ops (e.g. sampling-based generation).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

print(torch.backends.cuda.is_built())
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

torch.cuda.empty_cache()
    

  from .autonotebook import tqdm as notebook_tqdm


True
True
cuda


In [2]:
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
import json

# Alternative checkpoints (uncomment one to switch):
# model_name = 'google/flan-t5-small'
# model_name = 'google/flan-t5-base'
# model_name = 'google/flan-t5-large'
# model_name = 'google/flan-t5-xl'
# model_name = 'google/flan-t5-xxl'
# model_name = "google/pegasus-pubmed"
# model_name = 'microsoft/BioGPT-Large'

model_name = "stanford-crfm/BioMedLM"


# config.n_positions = 2048


# Every branch must define BOTH `model` and `tokenizer`: the generation cells
# further down call `tokenizer(...)` and `tokenizer.batch_decode(...)`.
if model_name == "microsoft/BioGPT-Large":
    # Loaded as a causal (decoder-only) LM.
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
elif model_name == "stanford-crfm/BioMedLM":
    # Loaded directly with the GPT-2 model / tokenizer classes.
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
else:
    # Everything else is loaded as an encoder-decoder (seq2seq) model,
    # e.g. the Flan-T5 / Pegasus checkpoints listed above.
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

model.to(device)

Downloading (…)lve/main/config.json: 100%|██████████| 876/876 [00:00<00:00, 90.0kB/s]
Downloading (…)"pytorch_model.bin";:   3%|▎         | 304M/10.7G [00:27<15:28, 11.2MB/s] 

KeyboardInterrupt: 

Downloading (…)"pytorch_model.bin";:   3%|▎         | 304M/10.7G [00:38<15:28, 11.2MB/s]

In [7]:
## Experiment settings
# Models to run: bioGPT, bioGPT large, Flan-T5, google/pegasus-pubmed
# Valid prompt modes: 'base', 'FG_template', 'evidence-only'
mode, num_shots = 'base', 1

# None -> Prompt.transform_templates falls back to its built-in default
# instruction. Alternatives that have been written out:
#   "Given the evidence, answer the query provided: "
#   "Based on the evidence provided, what is your conclusion regarding the query? "
#   "Write a summary of the evidence and query in your own words: "
#   "What are the potential implications of the given evidences in relation to the query? "
instruction = None

In [8]:
# outputs propt template with tags to be replaced
# could avoid the problem of overlapping with query by splitting contexts before passing it to the two functions 
import random
# (contexts, queries, mode, num_shots, instruction=None):
class Prompt:
    """Build few-shot prompts from evidence/query records.

    Every record is a dict with an ``'evidences'`` list and a ``'query'``
    string; records used as in-context examples must also carry a
    ``'ground_truth'`` entry.

    Prompts are produced in two stages. ``get_context_template`` and
    ``get_query_templates`` emit text containing placeholder tags
    (``<EX_PREFIX>``, ``<EX_OUTPUT>``, ``<EX_SUFFIX>``, ``<QR_PREFIX>``,
    ``<QR_SUFFIX>``); ``transform_templates`` then fills those tags in
    according to the prompt mode.
    """

    # Prompt modes understood by transform_templates / get_prompt.
    _MODES = ('base', 'FG_template', 'evidence-only')

    # Spelled-out counts for the FG_template preface; other counts use digits.
    _NUMBER_WORDS = {1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five',
                     6: 'six', 7: 'seven', 8: 'eight', 9: 'nine', 10: 'ten'}

    def __init__(self, contexts, queries):
        """Store the example pool and the fixed query list.

        Args:
            contexts: records (with ``'ground_truth'``) that may be sampled as
                in-context examples; the ones not sampled are reused as queries.
            queries: records that are always turned into query prompts.
        """
        self.contexts = contexts
        self.queries = queries

    @staticmethod
    def _format_evidences(evidences):
        """Render one 'evidence <i>: <text>' line per evidence, in order."""
        return "".join(f"evidence {i}: {ev} \n" for i, ev in enumerate(evidences))

    def get_context_template(self, contexts_subset, num_examples):
        """Render the first ``num_examples`` records as tagged in-context examples.

        Each example is ``<EX_PREFIX>`` + its evidences + its query +
        ``<EX_OUTPUT>`` + its ground truth; a single ``<EX_SUFFIX>`` closes
        the whole block.

        Returns:
            The tagged template string, or ``None`` when ``num_examples`` is 0.

        Raises:
            IndexError: if ``contexts_subset`` holds fewer than ``num_examples`` records.
            KeyError: if an example lacks ``'evidences'``, ``'query'`` or ``'ground_truth'``.
        """
        if num_examples == 0:
            return None

        parts = []
        for idx in range(num_examples):
            example = contexts_subset[idx]
            parts.append(
                "<EX_PREFIX>"
                + self._format_evidences(example['evidences'])
                + f"query: {example['query']} \n"
                + "<EX_OUTPUT>"
                + f"{example['ground_truth']} \n"
            )
        parts.append("<EX_SUFFIX>")
        return "".join(parts)

    def get_query_templates(self, queries_subset):
        """Render every query record as a tagged query template.

        Returns:
            A list with one string per record, in input order:
            ``<QR_PREFIX>``, a newline, the evidences, the query, ``<QR_SUFFIX>``.
        """
        return [
            "<QR_PREFIX>\n"
            + self._format_evidences(query['evidences'])
            + f"query: {query['query']} \n"
            + "<QR_SUFFIX>"
            for query in queries_subset
        ]

    @classmethod
    def _fg_preface(cls, num_examples):
        """Instruction-induction preface announcing ``num_examples`` input-output pairs."""
        count = cls._NUMBER_WORDS.get(num_examples, str(num_examples))
        noun = "input" if num_examples == 1 else "inputs"
        return (
            f"I gave a friend an instruction and {count} {noun}. "
            "The friend read the instruction and wrote an output for every one of the inputs. "
            "Here are the input-output pairs: \n"
        )

    def transform_templates(self, context_template, query_templates, mode, instruction):
        """Replace the placeholder tags according to ``mode``.

        Modes:
          * ``'base'``: the instruction labels each example's output and is
            appended after every query.
          * ``'FG_template'``: instruction-induction style. The examples are
            rendered once as ``Input:``/``Output:`` pairs after a preface, and
            the prompt ends with ``The instruction was: ``. There is no query
            part, so the returned query list is ``None``.
          * ``'evidence-only'``: the in-context examples are dropped and each
            query is rendered with an empty prefix and suffix.

        Args:
            context_template: output of ``get_context_template`` (may be ``None``).
            query_templates: output of ``get_query_templates``.
            mode: one of ``'base'``, ``'FG_template'``, ``'evidence-only'``.
            instruction: instruction text; ``None`` selects the default
                ``"Summarize given evidences and query: "``.

        Returns:
            ``(context_transformed, queries_transformed)``. ``context_transformed``
            is a string (``""`` when there is no context). ``queries_transformed``
            is a list of strings, or ``None`` in ``'FG_template'`` mode.

        Raises:
            AssertionError: if ``mode`` is not a supported mode.
        """
        assert (mode in self._MODES), "invalid mode"

        if instruction is None:
            instruction = "Summarize given evidences and query: "

        # Tags default to empty; each mode sets only the tags it uses.
        ex_prefix = ex_output = ex_suffix = ""
        qr_prefix = qr_suffix = ""
        header = ""

        if mode == "base":
            ex_output = instruction
            qr_suffix = instruction
        elif mode == "FG_template":
            # The preface describes ALL examples, so emit it once before the
            # first example (not inside every <EX_PREFIX>). It also states the
            # real example count instead of a fixed "five".
            if context_template is not None:
                header = self._fg_preface(context_template.count('<EX_PREFIX>')) + "\n"
            ex_prefix, ex_output, ex_suffix = "Input: \n", "Output: \n", "\nThe instruction was: "
        else:  # "evidence-only": in-context examples are dropped
            context_template = None

        if context_template is None:
            context_transformed = ""
        else:
            context_transformed = header + (
                context_template
                .replace('<EX_PREFIX>', ex_prefix)
                .replace('<EX_OUTPUT>', ex_output)
                .replace('<EX_SUFFIX>', ex_suffix)
            )

        if mode == 'FG_template':
            queries_transformed = None
        else:
            queries_transformed = [
                template.replace('<QR_PREFIX>', qr_prefix).replace('<QR_SUFFIX>', qr_suffix)
                for template in query_templates
            ]

        return context_transformed, queries_transformed

    def get_prompt(self, mode, num_shots, instruction=None, rng=None):
        """Sample ``num_shots`` in-context examples and build the final prompts.

        Contexts that are not sampled as examples are appended to
        ``self.queries`` and become additional queries.

        Args:
            mode: see ``transform_templates``.
            num_shots: number of in-context examples, ``0 <= num_shots <= len(self.contexts)``.
            instruction: see ``transform_templates``.
            rng: optional ``random.Random`` instance for reproducible sampling;
                defaults to the module-level ``random`` state.

        Returns:
            In ``'FG_template'`` mode, a one-element list holding the context
            prompt. Otherwise, one prompt (context + query) per query.

        Raises:
            ValueError: if ``num_shots`` is negative or exceeds the number of contexts.
        """
        if not 0 <= num_shots <= len(self.contexts):
            raise ValueError(
                f"num_shots must be between 0 and {len(self.contexts)} "
                f"(the number of contexts), got {num_shots}"
            )

        sampler = random if rng is None else rng
        context_indices = sampler.sample(range(len(self.contexts)), num_shots)
        chosen = set(context_indices)

        examples = [self.contexts[i] for i in context_indices]
        leftover = [ctx for i, ctx in enumerate(self.contexts) if i not in chosen]
        queries = self.queries + leftover

        context_template = self.get_context_template(examples, num_shots)
        query_templates = self.get_query_templates(queries)

        context_transformed, queries_transformed = self.transform_templates(
            context_template, query_templates, mode, instruction)

        if mode == 'FG_template':
            return [context_transformed]
        return [context_transformed + query for query in queries_transformed]
        

In [9]:
## for Google Sheets
# Hand-written evidence/query records.
# NOTE: the next cell reloads `contexts` / `queries` from a JSON file and
# overwrites the values assigned at the bottom of this cell.

item1 = '''Here we report the characterization of the human Notch3 gene which we mapped to the CADASIL critical region. We have identified mutations in CADASIL patients that cause serious disruption of this gene, indicating that Notch3 could be the defective protein in CADASIL patients. All these missense mutations may result in severe disruption of the Notch3 protein, as suggested by the highly conserved nature of the aminoacid residues involved, particularly the cysteines that are key features of EGF likedomains24. These results indicate that these nucleotide substitutions are pathogenic mutations rather than rare polymorphisms.'''
item2 = '''Linkage studies in other families enabled further refinement of this genetic interval16,17 and identification of the mutated gene as NOTCH3 (Notch homolog 3).18'''
# This passage used to be assigned to `item3` and was immediately overwritten
# by the next assignment, so neither set below ever used it. It is kept under
# its own name so it is not silently lost.
# NOTE(review): if set1 was meant to use this passage, reference it explicitly.
item3_vessel_homeostasis = '''CADASIL, a hereditary vascular dementia, suggesting a role for Notch3 in vessel homeostasis (Joutel et al. 1996). CADASIL is a late-onset disorder, and neurological symptoms arise from a slowly developing systemic vasculopathy, characterized ultimately by degeneration of vSMC . maturation of vSMC, and ends around P28 when the artery acquires its final shape. We identify Notch3 to be the first key player of this process, by regulating cell-autonomously the arterial differentiation and maturation of vSMC.'''

item3 = "There is currently no consensus on whether CADASIL mutations generate hyperactive or hypoactive Notch3 proteins with regard to downstream signaling or whether CADASIL mutations are neutral in terms of Notch signaling. Whereas the R169C mutation appeared to lead to hyperactive Notch signaling (see above), the R1031C or C455R mutations were instead shown to be hypoactive,36 "
item4 = "We believe that increased TGFb3 reflects an inflammatory condition and even an involvement of TGFb in fibrosis in CADASIL. Support for our hypothesis is the outcome of several studies that showed NOTCH3 and TGFb1 signalling play a key role in the pathogenesis and progression of chronic cardiovascular disease. "
item5 = "Over time, VSMCs apoptosis leads to fibrosis and thickening of the arterial wall, progressive lumen stenosis and vascular insufficiency that makes the already poorly perfused terminal regions particularly susceptible to infarcts. SMCs degeneration is followed by the emergence of large"
item6 = "Previously published studies have shown that TGF-b signaling is closely associated with the activity of SYK, and the kinase activity of SYK is essential for the activation of some signaling receptor downstream effector molecules. In the present study, the activation of SYK was shown to increase the progression of peritoneal fibrosis through activation of the TGF-b1/Smad3 signaling pathway, and inhibition of TGF-b1 also resulted in down-regulation of SYK."
item7 = "Our work confirms and highlights the relevance of NOTCH3 expression and signaling in pro-inflammatory macrophage activation and identified its prominent and specific role in the activation of NF-κB. A positive regulation between NOTCH and NF-κB signaling pathway has been described previously in macrophages isolated from patients with atherosclerosis. In those patients, and in contrast with our results, "

# Reference answer for set2 (used as the in-context example output).
gt = "Syk is a therapeutic target for CADASIL. In CADASIL, proinflammatory signaling is activated through the increase of TGFb1 signaling. The apoptosis of vascular smooth muscle cell (VSMC) and increased TGFb1 signaling cause fibrosis and vascular abnormalities in vascular epithelial cells. Syk increases fibrosis through activation of TGF-b1 signaling."

# set2 has a ground truth, so it can serve as an in-context example (context);
# set1 does not, so it is only used as a query.
set1 = {"query": 'What is the cause of the CADASIL?', "evidences": [item1, item2, item3]}
set2 = {"query": 'What is the therapeutic target for CADASIL?', "evidences": [item3, item4, item5, item6, item7], "ground_truth": gt}

contexts = [set2, ]
queries = [set1, ]

In [10]:
import json

# Load the curated examples. These replace the hand-written `contexts` /
# `queries` from the previous cell. Read as UTF-8 explicitly: the evidence
# text contains non-ASCII characters (e.g. "κ"), and the platform default
# encoding is not guaranteed to be UTF-8.
with open('examples_20230207.json', encoding='utf-8') as f:
    loaded_dic = json.load(f)
contexts = loaded_dic['contexts']
queries = loaded_dic['queries']

prompt_setup = Prompt(contexts, queries)


In [11]:
# Build one prompt per query, using the mode / shot count / instruction from
# the settings cell.
prompts = prompt_setup.get_prompt(
    mode=mode,
    num_shots=num_shots,
    instruction=instruction,
)
print(prompts)

['evidence 0: There is currently no consensus on whether CADASIL mutations generate hyperactive or hypoactive Notch3 proteins with regard to downstream signaling or whether CADASIL mutations are neutral in terms of Notch signaling. Whereas the R169C mutation appeared to lead to hyperactive Notch signaling (see above), the R1031C or C455R mutations were instead shown to be hypoactive,36 \nevidence 1: We believe that increased TGFb3 reflects an inflammatory condition and even an involvement of TGFb in fibrosis in CADASIL. Support for our hypothesis is the outcome of several studies that showed NOTCH3 and TGFb1 signalling play a key role in the pathogenesis and progression of chronic cardiovascular disease. \nevidence 2: Over time, VSMCs apoptosis leads to fibrosis and thickening of the arterial wall, progressive lumen stenosis and vascular insufficiency that makes the already poorly perfused terminal regions particularly susceptible to infarcts. SMCs degeneration is followed by the emerg

In [12]:
# Shot-count labels used in output file names. Counts without an entry fall
# back to "<n>-shot" instead of raising KeyError.
num_to_word = {
    0: 'zero-shot',
    1: 'one-shot',
    2: 'two-shot',
    3: 'few-shot',
}
shot_label = num_to_word.get(num_shots, f"{num_shots}-shot")

output_file_name = f"{model_name.replace('/', '-')}_{mode}-mode_{shot_label}"
output_dic = {
    'output' : [],
    'config' : {'mode': mode, 'num_shots': num_shots}
}

# `__file__` is normally undefined inside a Jupyter kernel, so the output
# directory is configured explicitly. Set BIOCHATGPT_OUTPUT_DIR to override
# the default.
cwd = os.environ.get('BIOCHATGPT_OUTPUT_DIR', '/home/chaeeun/workspace/biochatgpt_generation')
file_name_json = os.path.join(cwd, f"{output_file_name}.json")

# Generation settings (unchanged from the original run).
# NOTE(review): for decoder-only models (e.g. the GPT-2-based BioMedLM)
# `max_length` counts the prompt tokens as well as the generated ones, and
# 2048 may exceed the model's position limit (check model.config) - consider
# `max_new_tokens`.
# NOTE(review): the recorded output repeats one sentence many times;
# `no_repeat_ngram_size` / `repetition_penalty` may help.
generation_kwargs = dict(max_length=2048, min_length=100, num_beams=5)

for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
    outputs = model.generate(**inputs, **generation_kwargs)
    generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(generated)

    output_dic['output'].append({'prompt': prompt, 'generated': generated})

# `with` closes the file; no explicit close() is needed afterwards.
with open(file_name_json, "w", encoding="utf-8") as outfile:
    json.dump(output_dic, outfile)


    

['Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that is used to treat CADASIL. Nilvadipine is a non-steroidal anti-inflammatory drug (NSAID) that i

In [13]:
# outfile.close()

In [14]:
# Human-readable report: every prompt followed by its generation.
# NOTE(review): this re-runs generation with max_length=1024, while the JSON
# cell uses 2048, so the .txt and .json results for the same configuration can
# differ. The .txt is also written relative to the current working directory,
# not `cwd`.
with open(f'{output_file_name}.txt', "w", encoding="utf-8") as results: # mujeen_
    results.write(f'<< Model: {model_name} >>\n\n')

    for prompt in prompts:
        print('< PROMPT >\n')
        print(prompt)
        results.write('< PROMPT >\n\n')
        results.write(prompt)

        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
        outputs = model.generate(**inputs, max_length=1024, min_length=100, num_beams=5)
        generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        print('\n\n< GENERATED >\n')
        print(generated)
        print('\n###############################################################################################################################################################\n')
        results.write('\n\n< GENERATED >\n\n')
        # batch_decode returns a list of strings, but write() takes exactly one
        # string. Unpacking with *generated raised TypeError whenever the list
        # did not hold exactly one item, so join the list instead.
        results.write('\n'.join(generated))
        results.write('\n\n###############################################################################################################################################################\n')

< PROMPT >

evidence 0: There is currently no consensus on whether CADASIL mutations generate hyperactive or hypoactive Notch3 proteins with regard to downstream signaling or whether CADASIL mutations are neutral in terms of Notch signaling. Whereas the R169C mutation appeared to lead to hyperactive Notch signaling (see above), the R1031C or C455R mutations were instead shown to be hypoactive,36 
evidence 1: We believe that increased TGFb3 reflects an inflammatory condition and even an involvement of TGFb in fibrosis in CADASIL. Support for our hypothesis is the outcome of several studies that showed NOTCH3 and TGFb1 signalling play a key role in the pathogenesis and progression of chronic cardiovascular disease. 
evidence 2: Over time, VSMCs apoptosis leads to fibrosis and thickening of the arterial wall, progressive lumen stenosis and vascular insufficiency that makes the already poorly perfused terminal regions particularly susceptible to infarcts. SMCs degeneration is followed by t