In [1]:
import json
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from utils import get_llm_response, get_file, get_topics, list_s3


@dataclass
class PersonaDimension:
    """One persona dimension (aspect) that may be proposed for a survey question.

    NOTE(review): not referenced by the visible cells. It is presumably the target
    schema for LLM-extracted personas: the extraction cell ``eval``s model output,
    which could construct these objects if the prompt asks for
    ``PersonaDimension(...)`` literals. TODO confirm against the prompt template.
    """

    name: str  # a concise name of the persona aspect
    description: str  # a detailed description of the persona aspect
    level: str  # the abstractness level of this persona dimension
    candidate_values: Optional[str]  # the candidate values of this persona dimension; no default, so it must be passed explicitly (None allowed)


# Smoke-test the LLM endpoint before launching the batch extraction below;
# a successful reply (shown in the cell output) indicates the endpoint is reachable.
get_llm_response('hi')

'Hello!'

In [2]:
# The topic mapping is stored wrapped in a NumPy array, so ``.item()`` unwraps
# the single stored Python object (presumably a dict consumed by ``get_topics``).
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
# file -- only load this artifact from a trusted bucket.
mapping = np.load(
    get_file('human_resp/topic_mapping.npy'),
    allow_pickle=True,
).item()

In [3]:
# Discover every American Trends Panel wave that has a folder under human_resp/.
SURVEY_PREFIX = "human_resp/American_Trends_Panel"

# Keys look like human_resp/<survey>/<file>, so the second path component is the
# survey folder. The comprehension deduplicates without leaking loop variables
# into the notebook namespace, and sorted() already returns a list.
surveys = sorted({
    key.split("/")[1]
    for key in list_s3("human_resp/")
    if key.startswith(SURVEY_PREFIX)
})
surveys

['American_Trends_Panel_W26',
 'American_Trends_Panel_W27',
 'American_Trends_Panel_W29',
 'American_Trends_Panel_W32',
 'American_Trends_Panel_W34',
 'American_Trends_Panel_W36',
 'American_Trends_Panel_W41',
 'American_Trends_Panel_W42',
 'American_Trends_Panel_W43',
 'American_Trends_Panel_W45',
 'American_Trends_Panel_W49',
 'American_Trends_Panel_W50',
 'American_Trends_Panel_W54',
 'American_Trends_Panel_W82',
 'American_Trends_Panel_W92']

In [4]:
def _check_response(response):
    """Return ``(valid, error_msg)`` for one LLM response string.

    A response counts as valid when it evaluates without raising; on failure the
    exception is printed and its message returned.
    """
    # SECURITY(review): ``eval`` executes arbitrary code contained in model output.
    # ``ast.literal_eval`` would be safe for plain list/dict/str literals, but if
    # the prompt asks for ``PersonaDimension(...)`` calls those only evaluate via
    # ``eval`` against notebook globals -- TODO confirm the output format and
    # switch to a safe parser.
    try:
        eval(response)
    except Exception as e:
        print(e)
        return False, str(e)
    return True, None


def _dump_json(obj, path):
    """Overwrite ``path`` with ``obj`` serialized as indented JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4)


def get_personas_from_survey(
    info_df,
    survey,
    topic_mapping=None,
    prompt_path='prompts/get_personas_from_questions_simple.txt',
    output_dir='outputs_simple',
    model_id='anthropic.claude-3-haiku-20240307-v1:0',
):
    """Ask the LLM to propose persona dimensions for every question of one survey.

    For each row of ``info_df``:
      1. Look up the question's topics.
      2. Render the prompt template.
      3. Query ``model_id`` with a ``'['`` prefill.
      4. Check that the response evaluates.

    The accumulated results and a small log (rows processed, valid ratio) are
    re-written after every row, so partial progress survives an interrupted run.

    Args:
        info_df: DataFrame with at least ``question`` and ``references`` columns.
        survey: Survey name, used in the output file names and the final printout.
        topic_mapping: Mapping passed to ``get_topics``; defaults to the
            notebook-global ``mapping`` loaded in the topic-mapping cell.
        prompt_path: Path of the ``str.format`` prompt template. It may reference
            ``topic_fg``, ``topic_cg``, ``question`` and ``options``.
        output_dir: Directory for the JSON outputs; created if missing.
        model_id: Model identifier forwarded to ``get_llm_response``.

    Returns:
        None. Results go to
        ``{output_dir}/personas_extracted_from_question_{survey}.json`` and
        ``{output_dir}/logs_{survey}.json``; the survey name and logs are printed.
    """
    if topic_mapping is None:
        topic_mapping = mapping  # notebook global, resolved at call time

    with open(prompt_path, encoding='utf-8') as f:
        prompt_template = f.read()

    # Create the output directory up front so the first checkpoint write cannot
    # fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, f'personas_extracted_from_question_{survey}.json')
    logs_path = os.path.join(output_dir, f'logs_{survey}.json')

    res = []
    logs = {}
    valid_cnt = 0
    for _, row in tqdm(info_df.iterrows(), total=len(info_df)):
        topics = get_topics(topic_mapping, row['question'])
        input_dict = {
            "topic_fg": topics['fg'],
            "topic_cg": topics['cg'],
            "question": row['question'],
            "options": row['references'],
        }
        prompt = prompt_template.format(**input_dict)
        # The model is prefilled with '[' so it continues a list; prepend that
        # bracket so the stored text is the full list (assumes the helper returns
        # only the continuation after the prefill).
        response = '[' + get_llm_response(prompt, model_id=model_id, prefill='[')

        valid, error_msg = _check_response(response)
        if valid:
            valid_cnt += 1

        res.append({
            'valid': valid,
            'error_msg': error_msg,
            'input_dict': str(input_dict),
            'response': response,
        })

        logs['res_len'] = len(res)
        logs['valid_ratio'] = valid_cnt / len(res)

        # Checkpoint after every row (deliberately O(n^2) I/O, so an interrupted
        # run keeps everything processed so far).
        _dump_json(res, results_path)
        _dump_json(logs, logs_path)

    print(survey)
    print(logs)

In [5]:
# Number of survey waves to process, taken from the front of the sorted
# `surveys` list; set to None to process every wave.
N_SURVEYS = 1

for survey in surveys[:N_SURVEYS]:
    file_key = f"human_resp/{survey}/info.csv"
    info_df = pd.read_csv(get_file(file_key))
    get_personas_from_survey(info_df, survey)


100%|██████████| 78/78 [05:49<00:00,  4.48s/it]

American_Trends_Panel_W26
{'res_len': 78, 'valid_ratio': 1.0}



