In [1]:
import os
import pandas as pd
import json 
import openai

In [2]:
from utils.openai_query import openai_chat
from utils.prompt_factory import make_user_prompt_with_score
from utils.llm_analysis_utils import process_analysis, save_progress

In [3]:
# Take the OpenAI API key from the OPENAI_API_KEY environment variable
# (the subscript lookup raises KeyError if the variable is not set).
openai.api_key = os.environ["OPENAI_API_KEY"]

In [4]:
# --- Inputs: the gene-set table (read as TSV) and the JSON run configuration ---
inputFilePath = "data/omics.txt"
jsonFilePath = "jsonFiles/OmicsRunLLM.json"

# --- Columns of the input table that the query loop reads ---
nameCol = "GeneSetName"  # reported when a row gets no analysis
genesCol = "GeneList"    # one string of genes per row
geneSep = " "            # delimiter used to split that string into genes

# --- Output: the table with the added LLM columns, written as TSV ---
outputFilePath = "data/omics_LLM_DF.tsv"

In [5]:
# Load the run configuration. The file is opened explicitly as UTF-8 so that
# parsing does not depend on the platform's default text encoding.
with open(jsonFilePath, encoding="utf-8") as json_file:
    config = json.load(json_file)

# Unpack the settings that are later handed to openai_chat. Every key is
# required: a missing one raises KeyError here rather than mid-run.
context = config['CONTEXT']
gpt_model = config['GPT_MODEL']
temperature = config['TEMP']
max_tokens = config['MAX_TOKENS']
rate_per_token = config['RATE_PER_TOKEN']  # presumably the cost per token used for budgeting -- confirm in utils.openai_query
LOG_FILE = config['LOG_NAME'] + 'log.json'  # log path = configured prefix + 'log.json'
DOLLAR_LIMIT = config['DOLLAR_LIMIT']  # presumably the spending cap enforced by openai_chat -- confirm

In [6]:
# Seed passed as the last argument to openai_chat in the query loop
# (presumably for reproducible sampling -- confirm in utils.openai_query).
SEED = 42

In [7]:
# Echo the model name loaded from the config so it is recorded in the notebook output.
gpt_model

'gpt-4-1106-preview'

### Run GPT-4 query pipeline for omics gene sets

In [8]:
# Read the tab-separated gene-set table and attach empty placeholder columns
# that the query loop fills in row by row.
df = pd.read_csv(inputFilePath, sep="\t").assign(
    **{'LLM Name': None, 'LLM Analysis': None, 'Score': None}
)

In [None]:
# Query the LLM once per gene set and store the returned name, analysis text
# and score on the matching row of df.
#
# The loop runs inside try/finally so that an exception or a manual interrupt
# still writes everything collected so far to outputFilePath; without it, up
# to ten completed (and paid-for) queries could be lost between checkpoints.
try:
    for row_pos, (i, row) in enumerate(df.iterrows()):
        term_genes = row[genesCol]

        # An empty cell is read by pandas as NaN (a float), which has no
        # .split(); treat a missing/blank gene list like a failed query
        # instead of crashing the run or sending an empty prompt.
        if isinstance(term_genes, str) and term_genes.strip():
            genes = term_genes.split(geneSep)
            prompt = make_user_prompt_with_score(genes)
            analysis, finger_print = openai_chat(
                context, prompt, gpt_model, temperature, max_tokens,
                rate_per_token, LOG_FILE, DOLLAR_LIMIT, SEED,
            )
        else:
            analysis = None

        if analysis:
            llm_name, llm_score, llm_analysis = process_analysis(analysis)
            # A score that cannot be converted must not abort the whole run;
            # keep the name and analysis and leave the score empty.
            try:
                score = float(llm_score)
            except (TypeError, ValueError):
                print(f'Could not parse score {llm_score!r} for {row[nameCol]}')
                score = None
            df.loc[i, 'LLM Name'] = llm_name
            df.loc[i, 'LLM Analysis'] = llm_analysis
            df.loc[i, 'Score'] = score
        else:
            name = row[nameCol]
            print(f'No analysis for {name}')
            # Reset all three result columns so a re-run over an already
            # filled frame does not keep a stale score.
            df.loc[i, 'LLM Name'] = None
            df.loc[i, 'LLM Analysis'] = None
            df.loc[i, 'Score'] = None

        # Keep saving so no data is lost if something happens. The checkpoint
        # is keyed on the row position, not the index label, so it also works
        # when the frame does not have a default integer index.
        if row_pos % 10 == 1:
            print(i)
            df.to_csv(outputFilePath, sep="\t", index=False)
finally:
    # Always flush the current state, including after an error or interrupt.
    df.to_csv(outputFilePath, sep="\t", index=False)
    

2338
2678
1
1883
1793
2292
1715
2267
2098
1895
2143
2288
2252
11
2472
2236
2215
2087
2109
2683
2236
2050
2404
2540
21
2167
2264
2512
1997
2727
2346
2306
2274
2463
2744
31
2147
2496
1823
2492
2558
2066
1962
2253
2250
1869
41
2029
2083
2736
2199
2006
2214
2075
2392
1406
1415
51
1515
2335
1373
1780
2437
1430
1667
1793
2034
2362
61
1531
1536
1330
1472
1731
1626
1572
1629
1929
1591
71
1536
1676
1656
1315
1824
1665
1986
1719
1380
1384
81
1602
1475
1508
1436
1535
1729
1575
1464
2122
1447
91
1723
1543
2023
1996
2019
1845
2012


In [None]:
# Final write of the complete results table as TSV; this uses the same path as
# the periodic checkpoints in the query loop, so it overwrites the last checkpoint.
df.to_csv(outputFilePath, sep= '\t', index=False)
