In [1]:
import os
import pandas as pd
import json 
import openai

In [2]:
from utils.openai_query import openai_chat
from utils.prompt_factory import make_user_prompt_with_score
from utils.llm_analysis_utils import process_analysis, save_progress

In [3]:
# Take the OpenAI API key from the OPENAI_API_KEY environment variable rather
# than hard-coding it; os.environ[...] raises KeyError if the variable is unset.
openai.api_key = os.environ["OPENAI_API_KEY"] # Environment variable

In [4]:
# --- Run configuration -------------------------------------------------------
# Inputs: the tab-separated gene-set table and the JSON file of LLM settings.
inputFilePath = "data/omics.txt"
jsonFilePath = "jsonFiles/OmicsRunLLM.json"

# Columns of the input table that the query loop reads from each row.
nameCol = "GeneSetName"
genesCol = "GeneList"
geneSep = " "  # delimiter between genes inside a single genesCol cell

# Output: the input table with the LLM result columns added.
outputFilePath = "data/omics_LLM_DF.tsv"

In [5]:
# Parse the JSON run configuration and unpack the settings used by the query loop.
with open(jsonFilePath) as config_handle:
    config = json.load(config_handle)

# Prompt and model settings, read in the same key order as before.
context, gpt_model, temperature, max_tokens, rate_per_token = (
    config[key]
    for key in ('CONTEXT', 'GPT_MODEL', 'TEMP', 'MAX_TOKENS', 'RATE_PER_TOKEN')
)

# Log file name is the configured prefix with 'log.json' appended directly.
LOG_FILE = config['LOG_NAME'] + 'log.json'
DOLLAR_LIMIT = config['DOLLAR_LIMIT']

In [6]:
# Fixed seed, passed as the last argument to openai_chat in the query loop.
# NOTE(review): presumably forwarded to the API for reproducible sampling —
# confirm in utils.openai_query.
SEED = 42

In [7]:
# Show which model name was loaded from the JSON config before running queries.
gpt_model

'gpt-4-1106-preview'

### Run GPT-4 query pipeline for NeST gene sets

In [8]:
# Load the tab-separated gene-set table and append three empty result columns
# that the query loop fills in row by row.
df = pd.read_csv(inputFilePath, sep="\t").assign(
    **{'LLM Name': None, 'LLM Analysis': None, 'Score': None}
)

In [None]:
# Query the LLM once per row of df (one gene set per row) and write the parsed
# name, analysis text and score back onto that row. The table is checkpointed
# to outputFilePath periodically while the loop runs.
for i, row in df.iterrows():
    term_genes = row[genesCol]

    analysis = None
    if isinstance(term_genes, str):
        genes = term_genes.split(geneSep)
        prompt = make_user_prompt_with_score(genes)
        analysis, finger_print = openai_chat(context, prompt, gpt_model, temperature, max_tokens, rate_per_token, LOG_FILE, DOLLAR_LIMIT, SEED)
    else:
        # An empty cell is read by pandas as NaN (a float), which has no
        # .split(); skip the query for this row instead of aborting the run.
        print(f'No gene list for {row[nameCol]}')

    if analysis:
        llm_name, llm_score, llm_analysis = process_analysis(analysis)
        try:
            score = float(llm_score)
        except (TypeError, ValueError):
            # Keep the name and analysis even when the score is not numeric,
            # rather than losing the whole run to one malformed response.
            print(f'Could not parse score {llm_score!r} for {row[nameCol]}')
            score = None
        df.loc[i, 'LLM Name'] = llm_name
        df.loc[i, 'LLM Analysis'] = llm_analysis
        df.loc[i, 'Score'] = score

    else:
        name = row[nameCol]
        print(f'No analysis for {name}')
        # Clear all three result columns so that re-running this cell never
        # leaves a stale score next to an empty name/analysis.
        df.loc[i, 'LLM Name'] = None
        df.loc[i, 'LLM Analysis'] = None
        df.loc[i, 'Score'] = None

    # Checkpoint on row labels 1, 11, 21, ... so completed rows are not lost
    # if the run is interrupted.
    if (i % 10 == 1):
        print(i)
        df.to_csv(outputFilePath, sep="\t", index=False)
    

2520
2373
1
1914
1612
2305
1521
2114
2158
1837
2130
2207
2524
11
2528
2288
2174
2191
2283
2651
2199
2050
2317
2561
21
2295
2127
2381
2040
2688
2331
2375
2413
2749
2744
31
2147
2355
1784
2431
2617
1962
1950
2095
2077
1882
41
2253
1821
2371
2038
2115
2304
2075
2593
1350
1630
51
1562
2415
1440
1712
1572
1676
1910
1591
1515
1402
61
1576
1354
1279
1622
1399
1876
1531
2219
1541
1471
71
1523
1682
1943
1478
1756
1583
1478
2335
1716
1707
81
1526
1343
1584
1567
1549
1651
1504
1509
1609
1560
91
1661
1505
1698
1996
1494
1426
1469
1294
1464
1656
101
1800
1409


In [None]:
# Final write of the complete results table (tab-separated, without the index
# column) to the same path used for the in-loop checkpoints.
df.to_csv(outputFilePath, sep= '\t', index=False)
