## Query GPT-4 for name and analysis using a toy example

In [1]:
import pandas as pd
import json 
from utils.openai_query import openai_chat
from utils.prompt_factory import make_user_prompt
from tqdm import tqdm
import openai
import os


## check example_config.json for the format of the config file
with open('./jsonFiles/GOLLMrun_config.json') as json_file:
    config = json.load(json_file)

# Settings read from the config; all of them are forwarded unchanged to
# openai_chat for every query below.
context = config['CONTEXT']
gpt_model = config['GPT_MODEL']
temperature = config['TEMP']
max_tokens = config['MAX_TOKENS']
rate_per_token = config['RATE_PER_TOKEN']
LOG_FILE = config['LOG_NAME'] + '_log.json'
DOLLAR_LIMIT = config['DOLLAR_LIMIT']
openai.api_key = os.environ["OPENAI_API_KEY"] # set your openai api key in the environment variable or set in config

# Load the toy gene sets (file: data/GO_term_analysis/toy_example.csv); the
# first CSV column is used as the index.
# Check notebook 0.[Prep GO] Download_and_parse_GO.ipynb for the GO data prep.
df = pd.read_csv('data/GO_term_analysis/toy_example.csv', sep = ',',index_col=0)

# Result columns start empty so that rows without a usable reply stay None.
df['LLM Name'] = None
df['LLM Analysis'] = None

for i, row in tqdm(df.iterrows(), total=df.shape[0]):
    # 'Genes' holds a whitespace-separated list of gene symbols.
    term_genes = row['Genes']
    genes = term_genes.split()
    prompt = make_user_prompt(genes)
    analysis = openai_chat(context, prompt, gpt_model,temperature, max_tokens, rate_per_token, LOG_FILE, DOLLAR_LIMIT)

    # NOTE(review): what openai_chat returns on failure is not visible here.
    # Skip a missing/empty reply instead of crashing, so the results already
    # collected are still written to disk after the loop.
    if not analysis:
        continue

    # Expected layout: line 1 = "Process: <name>", line 2 = separator,
    # everything from line 3 on = the analysis text.
    parts = analysis.split('\n', 2)

    llm_name = parts[0].replace("Process: ", "")
    df.loc[i, 'LLM Name'] = llm_name

    # A reply with fewer than three lines has no analysis section; leave the
    # cell as None rather than raising IndexError.
    if len(parts) > 2:
        llm_analysis = parts[2]
        df.loc[i, 'LLM Analysis'] = llm_analysis

df.to_csv('data/GO_term_analysis/LLM_processed_toy_example.tsv', index=True, sep='\t')


 10%|█         | 1/10 [00:16<02:24, 16.06s/it]

582


 20%|██        | 2/10 [00:36<02:27, 18.47s/it]

754


 30%|███       | 3/10 [00:55<02:13, 19.00s/it]

655


 40%|████      | 4/10 [01:11<01:46, 17.74s/it]

640


 50%|█████     | 5/10 [01:39<01:46, 21.40s/it]

786


 60%|██████    | 6/10 [01:56<01:19, 19.89s/it]

569


 70%|███████   | 7/10 [02:11<00:54, 18.24s/it]

560


 80%|████████  | 8/10 [02:35<00:40, 20.25s/it]

744


 90%|█████████ | 9/10 [02:53<00:19, 19.27s/it]

583


100%|██████████| 10/10 [03:09<00:00, 18.92s/it]

588





In [None]:
# test the script for batch run
# Runs query_llm_for_analysis.py on the toy example file with --start 0 --end 1
# (presumably limiting the run to the first row -- confirm against the script).

input_file = 'data/GO_term_analysis/toy_example.csv'
# NOTE: the first cell stores the parsed JSON dict under `config`; after this
# cell runs, the same name holds the config file *path* string instead.
config = './jsonFiles/GOLLMrun_config.json'
# %run is an IPython magic (not plain Python); $input_file and $config are
# expanded to the values of the variables defined above.
%run query_llm_for_analysis.py --input $input_file --start 0 --end 1 --config $config

## Check out and combine the output from the batch run

In [9]:
from glob import glob
import pandas as pd
import json

### Sanity-check each per-batch result table against its raw LLM responses,
### then merge all batches into a single table.
processed_files = glob('data/GO_term_analysis/LLM_processed_selected_go_terms*.tsv')

for file in processed_files:
    df = pd.read_csv(file, sep='\t').set_index('GO')

    # The 6th and 7th underscore-separated fields of the file stem are reused
    # to locate the matching raw-response JSON for the same batch.
    file_name = file.split('/')[-1]
    ranges = file_name.split('.')[0].split('_')[5:7]
    response_path = f'data/GO_term_analysis/LLM_response_go_terms_{ranges[0]}_{ranges[1]}.json'
    with open(response_path) as fp:
        llm_response_dict = json.load(fp)

    for go_term in df.index:
        raw_response = llm_response_dict[go_term]

        # Terms recorded as 'NO ANALYSIS' are reported and skipped.
        if raw_response == 'NO ANALYSIS':
            print(file_name)
            print(f'No analysis for {go_term}')
            continue

        # The analysis text is everything after the first two lines of the
        # raw response; it must match what was saved in the table.
        llm_analysis = raw_response.split('\n', 2)[2]
        if df.loc[go_term, 'LLM Analysis'] != llm_analysis:
            print(f'LLM analysis for {go_term} is different')

    df = df.reset_index()
    print(df.shape)


# Re-read every batch file and stack them into one table.
batch_tables = [pd.read_csv(f, sep = '\t') for f in processed_files]
combined_df = pd.concat(batch_tables)
print(combined_df.shape)

# Quick integrity summary of the merged table.
print('Any duplicated GO: ',combined_df['GO'].duplicated().sum())
print('Any NAs in the LLM res: ', combined_df['LLM Name'].isna().sum())
print('Any duplicated LLM analysis: ', combined_df['LLM Analysis'].duplicated(keep=False).sum())

combined_df.to_csv('data/GO_term_analysis/LLM_processed_selected_1000_go_terms.tsv', index=False, sep='\t')

(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(50, 6)
(1000, 6)
Any duplicated GO:  0
Any NAs in the LLM res:  0
Any duplicated LLM analysis:  0
