In [None]:
# !pip install aiohttp
# !pip install pandas
# !pip install numpy
# !pip install python-dotenv
# !pip install tqdm
#!pip install pyyaml

In [None]:
import sys
import os
import pandas as pd 
import numpy as np

# Extend the import search path with the parent directory so the project-local
# `py_helpers` package can be imported (presumably it lives one level up — TODO confirm).
sys.path.append('./..')
from py_helpers.gpt import get_prompts 
from dotenv import load_dotenv
from py_helpers.sqlite import SQLiteConn

# Project SQLite wrapper; 'data.db' is presumably the database file path relative
# to the notebook's working directory — verify against py_helpers.sqlite.
sqlite = SQLiteConn('data.db')
# Load variables from the local .env file into the process environment;
# OPENAI_API_KEY is read from os.environ by later cells.
load_dotenv('./.env')

## Testing

In [None]:
import yaml
import json 

# Load the few-shot message lists for each prompt version from prompts/<version>.yaml.
# Assistant messages get their content round-tripped through json.loads/json.dumps
# (so it must be valid JSON, otherwise json.loads raises); every other message is
# kept exactly as loaded from the YAML file.
base_prompts = {}
for v in ['v1', 'v2']:
    with open('prompts/' + v + '.yaml') as f:
        messages = []
        for p in yaml.safe_load(f):
            if p['role'] != 'assistant':
                messages.append(p)
                continue
            normalized_content = json.dumps(json.loads(p['content']))
            messages.append({'role': 'assistant', 'content': normalized_content})
        base_prompts[v] = messages

base_prompts

In [None]:
topics_to_avoid = ['Exploring Tokyo', 'Brain Structure']

prompts_list = base_prompts['v2'] + [{'role': 'user', 'content': '## Conversations Pairs: 2\n## Avoid: ' + json.dumps(topics_to_avoid)}]

res = await get_prompts(
    [prompts_list],
    {'model': 'gpt-4o', 'temperature': 1.0, 'response_format': {'type': 'json_object'}}, 
    api_key = os.environ.get('OPENAI_API_KEY')
)
res

## Run

In [None]:
# Create the output table on first run; CREATE TABLE IF NOT EXISTS makes this a
# no-op when the table is already present.
#
# Text columns are declared TEXT, not STRING: SQLite assigns NUMERIC affinity to a
# column declared STRING, so a value that looks like a number (e.g. the topic
# "1984") would be silently converted to an integer/real on insert. TEXT affinity
# stores the value as written.
# NOTE: a database whose `pairs` table was already created with the old STRING
# declaration keeps that schema; this statement does not migrate it.
sqlite.execute(
    """
    CREATE TABLE IF NOT EXISTS pairs (
        id INTEGER PRIMARY KEY,
        prompt TEXT NOT NULL,
        topic TEXT NOT NULL,
        generic TEXT NOT NULL, 
        dog TEXT NOT NULL,
        added_at TEXT NOT NULL 
    )
    """
)

# Show what has been collected so far, newest first.
display(sqlite.get_query('SELECT * FROM pairs ORDER BY added_at DESC'))

In [None]:
from datetime import datetime

def parse_response(r, version):
    """
    Turn one chat-completion response into a list of row dicts for the `pairs` table.

    The message content at r['choices'][0]['message']['content'] is parsed as JSON
    and its 'conversation_pairs' entries are converted one by one. Each row holds:
    'prompt' (the given version), 'topic', 'generic' and 'dog' (the pair's 'generic'
    and 'dog_related' values serialized with json.dumps) and 'added_at' (local time,
    '%Y-%m-%d %H:%M:%S').

    Parsing is best-effort: an entry that raises (for example a missing key) is
    printed and skipped. If the response as a whole cannot be read, the error is
    printed and None is returned instead of a list.
    """
    try:
        payload = json.loads(r['choices'][0]['message']['content'])
        rows = []
        for entry in payload['conversation_pairs']:
            try:
                # Read all three fields before building the row so an incomplete
                # entry is rejected as a whole.
                topic, generic, dog = entry['topic'], entry['generic'], entry['dog_related']
                row = {
                    'prompt': version,
                    'topic': topic,
                    'generic': json.dumps(generic),
                    'dog': json.dumps(dog),
                    'added_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                }
            except Exception as err:
                print(err)
            else:
                rows.append(row)
        return rows
    except Exception as err:
        print(err)
        return None

# parsed = [x for xs in [parse_response(r, 'v2') for r in res] for x in xs if xs is not None]
# pd.DataFrame(parsed)

In [None]:
# Preview of the "topics to avoid" sample: t0 takes up to 50 topics from the most
# recently added rows, t1 takes up to 50 topics from randomly ordered rows; the two
# sets are combined, de-duplicated, shuffled and capped at 50. The 'topic' column of
# the result is returned as a plain Python list.
sqlite.get_query(
    """ 
    WITH t0 AS (SELECT topic FROM pairs ORDER BY added_at DESC LIMIT 50),
    t1 AS (SELECT topic FROM pairs ORDER BY RANDOM() LIMIT 50)
    SELECT DISTINCT(topic) 
    FROM (SELECT * FROM t0 UNION ALL SELECT * FROM t1)
    ORDER BY RANDOM() LIMIT 50
    """)['topic'].tolist()

In [None]:
from tqdm import tqdm 

async def pull_data(base_prompt_v1: list, base_prompt_v2: list, existing_topics: list, n_pairs: int = 3):
    """
    Request new conversation pairs from both prompt versions in one get_prompts call.

    Args:
        base_prompt_v1: Message list (dicts with 'role'/'content') for prompt version v1.
        base_prompt_v2: Message list (dicts with 'role'/'content') for prompt version v2.
        existing_topics: Topics the model is asked to avoid; JSON-encoded into the
            request message.
        n_pairs: Number of conversation pairs to ask for per request (default 3).

    Returns:
        A DataFrame with the parsed rows of the v1 response followed by those of the
        v2 response. A response that parse_response cannot read contributes no rows.

    NOTE(review): the two requests are presumably dispatched concurrently by
    get_prompts, and results are assumed to come back in request order — confirm
    both in py_helpers.gpt.
    """
    request_text = '## Conversations Pairs: ' + str(n_pairs) + '\n## Avoid: ' + json.dumps(existing_topics)

    # Build a fresh list (and a fresh user message dict) per version so the shared
    # base prompt lists are never mutated.
    prompts_lists = [
        base_prompt + [{'role': 'user', 'content': request_text}]
        for base_prompt in (base_prompt_v1, base_prompt_v2)
    ]

    res = await get_prompts(
        prompts_lists,
        {'model': 'gpt-4o', 'temperature': 1.0, 'response_format': {'type': 'json_object'}},
        api_key = os.environ.get('OPENAI_API_KEY'),
        verbose = False
    )

    # parse_response returns None on failure; pd.DataFrame(None) is an empty frame,
    # so a failed response simply adds nothing to the concatenated result.
    parsed_1 = pd.DataFrame(parse_response(res[0], 'v1'))
    parsed_2 = pd.DataFrame(parse_response(res[1], 'v2'))

    return pd.concat([parsed_1, parsed_2])


# Main collection loop: 1000 rounds, each sending one v1 and one v2 request through
# pull_data and storing whatever rows were parsed.
for i in tqdm(range(0, 1000)):
    # Avoid-list for this round: up to 30 topics from the most recently added rows
    # plus up to 30 from randomly ordered rows, de-duplicated, shuffled and capped
    # at 40. It is re-queried every round so it reflects the table's current contents.
    topics = sqlite.get_query(
        """ 
        WITH t0 AS (SELECT topic FROM pairs ORDER BY added_at DESC LIMIT 30),
        t1 AS (SELECT topic FROM pairs ORDER BY RANDOM() LIMIT 30)
        SELECT DISTINCT(topic) 
        FROM (SELECT * FROM t0 UNION ALL SELECT * FROM t1)
        ORDER BY RANDOM() LIMIT 40
        """)['topic'].tolist()
    
    pulled_data = await pull_data(base_prompts['v1'], base_prompts['v2'], topics)
    display(pulled_data)
    # Persist this round's rows right away (write_df presumably appends to the
    # `pairs` table — confirm in py_helpers.sqlite), so an interrupted run keeps
    # everything collected in earlier rounds.
    sqlite.write_df('pairs', pulled_data)
    