In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import os
import os.path
from os.path import join
import numpy as np
import imodelsx
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import data
import sys
files_dict = data.load_files_dict_single_site()

In [None]:
# site = 'Atlanta'
# site = 'Columbus'
site = 'WashingtonDC'
df = files_dict[site]
qs, responses_df, themes_df = data.split_single_site_df(df)

In [None]:
def numbered_list(responses):
    return '\n'.join([f'{i+1}. {c.strip()}' for i, c in enumerate(responses)])


themes_prompt = '''### You are given a question and a set of responses below.

**Question**: {question}

**Responses**:
{response_list}

### Group all responses into 2 or more non-overlapping themes.
### Return a comma-separated list, where each element is a theme, followed by the numbers of the responses that fall into that theme in brackets.
### **Example answer**: Theme 1: Negative responses [1, 2, 5], Theme 2: Positive responses [3, 4]

**Answer**: Theme 1:'''

llm = imodelsx.llm.get_llm('gpt-4', repeat_delay=3)

**Run single example**

In [None]:
# question, responses, theme_dict = data.get_data_for_question_single_site(
#     question_num=2, qs=qs, responses_df=responses_df, themes_df=themes_df)

# resps = responses[pd.notna(responses)]
# prompt = themes_prompt.format(
#     question=question,
#     response_list=numbered_list(resps)
# )
# print(prompt)
# llm(prompt)

### Screen valid questions
Valid questions have multiple unique responses.

In [None]:
def count_unique(resps):
    resps_match = resps.apply(str.lower)
    resps_match = resps_match.str.replace('[^\w\s]', '')
    # print(set(resps_match))
    return len(set(resps_match))


# screen valid questions
def get_valid_question_nums(qs, responses_df, themes_df):
    valid_question_nums = []
    for question_num in tqdm(range(len(qs)), position=0):

        question, responses, theme_dict = data.get_data_for_question_single_site(
            question_num=question_num, qs=qs, responses_df=responses_df, themes_df=themes_df)
        resps = responses[pd.notna(responses)]

        # valid only if there are multiple unique responses
        if count_unique(resps) > 3:
            # print(resps)
            valid_question_nums.append(question_num)
    return valid_question_nums


valid_question_nums = get_valid_question_nums(qs, responses_df, themes_df)
print('num valid qs', len(valid_question_nums), 'of', len(qs))

### Run generating themes

In [None]:
def get_themes_and_resps(valid_question_nums, qs, responses_df, themes_df):
    themes_generated = {}
    resps_list = {}
    for question_num in tqdm(valid_question_nums, position=0):

        question, responses, theme_dict = data.get_data_for_question_single_site(
            question_num=question_num, qs=qs, responses_df=responses_df, themes_df=themes_df)
        resps = responses[pd.notna(responses)]

        prompt = themes_prompt.format(
            question=question,
            response_list=numbered_list(resps)
        )
        ans = llm(prompt)
        themes_generated[question_num] = [s.strip(' ,:1234567890')
                                          for s in ans.split('Theme')]
        resps.index = np.arange(len(resps)) + 1
        resps_list[question_num] = resps
    return themes_generated, resps_list


themes_generated, resps_list = get_themes_and_resps(
    valid_question_nums, qs, responses_df, themes_df)

In [None]:
def dprint(*args, f):
    # print(*args, file=sys.stdout)
    print(*args, file=f)


with open(f'../figs/themes/themes_generated_{site}.md', 'w') as f:
    for question_num in valid_question_nums:
        dprint('### Question:', qs[question_num], f=f)
        dprint('\nResponses', f=f)
        for i in range(len(resps_list[question_num])):
            dprint(f'{i+1}. {resps_list[question_num].iloc[i]}', f=f)
        dprint('\nThemes', f=f)
        # print(themes_generated[question_num])
        themes = themes_generated[question_num]
        for i, theme in enumerate(themes):
            dprint(f'- Theme {i + 1}:', theme, f=f)
        dprint('', f=f)

# Generate human experiment

In [None]:
SITES = ['Atlanta', 'Columbus', 'WashingtonDC']
vset = None
for site in SITES:
    df = files_dict[site]
    qs, responses_df, themes_df = data.split_single_site_df(df)
    valid_question_nums = get_valid_question_nums(qs, responses_df, themes_df)
    if vset is None:
        vset = set(valid_question_nums)
    else:
        vset = vset.intersection(valid_question_nums)
questions_selected = list(vset)
pd.Series(questions_selected).to_csv(
    '../figs/human/themes_questions_selected.csv', index=False, header=False)

In [None]:
def remove_brackets(s):
    return s.split('[')[0].strip()


MAX_THEMES = 5
ddf = defaultdict(list)
for i, q in enumerate(questions_selected):
    site = SITES[i % len(SITES)]

    df = files_dict[site]
    qs, responses_df, themes_df = data.split_single_site_df(df)
    themes_generated, resps_list = get_themes_and_resps(
        vset, qs, responses_df, themes_df)
    ts = [remove_brackets(s) for s in themes_generated[q]]
    resps = resps_list[q]
    n_resps = len(resps)
    for resp in resps:
        ddf['responses'].append(resp)
        ddf['site'].append(site)
        ddf['question'].append(qs[q])

        # Themes without citations
        for i in range(len(ts)):
            ddf[f'Theme {i+1}'].append(ts[i])
        for i in range(len(ts), MAX_THEMES + 1):
            ddf[f'Theme {i+1}'].append('')
        ddf['themes'].append(ts)

    # answers (map citations to answers)
    ans_vec = np.zeros((n_resps, MAX_THEMES))
    for i, t in enumerate(themes_generated[q]):
        try:
            nums = t.split('[')[1].split(']')[0].split(',')
        except:
            print(t)
        nums = np.array([int(n.strip()) for n in nums]) - 1
        ans_vec[nums, i] = 1
    for i in range(MAX_THEMES):
        ddf[f'ans {i+1}'] += ans_vec[:, i].tolist()

In [None]:
dx = pd.DataFrame(ddf)
dx.to_csv('../figs/human/themes_template.csv', index=False)

### Analyze responses

In [27]:
template = pd.read_csv('../figs/human/themes_template.csv')

annots = {
    'hum1': 'human1',
    'hum2': 'human2',
    'hum3': 'human3',
}

for k, v in annots.items():
    hum = pd.read_csv(f'../figs/human/collected/themes_{v}.csv', skiprows=1)

    # check for matching index
    def remove_all_whitespace(s):
        return ''.join(s.split())

    assert np.all(hum['response'].apply(remove_all_whitespace).values ==
                  template['responses'].apply(remove_all_whitespace).values)

    # load answer
    def get_clean_annotation(s):
        return s.split(',')[-1].strip()

    template[k] = hum['annotation'].astype(str).apply(
        get_clean_annotation).values.astype(float).astype(int)
    # check that values are in range 1-5
    assert np.all(template[k].values >= 1), np.unique(template[k])
    assert np.all(template[k].values <= 5)
    ans = template[['ans 1', 'ans 2', 'ans 3', 'ans 4', 'ans 5']].values

    # compute mean
    template[f'{k}_correct'] = False
    for i in range(len(template)):
        if ans[i, template[k][i] - 1] == 1:
            template.at[i, f'{k}_correct'] = True
    print(k, template[f'{k}_correct'].mean().round(
        3), template[f'{k}_correct'].sem().round(4))

hum1 0.862 0.0252
hum2 0.867 0.0248
hum3 0.878 0.024


In [28]:
# def agreement(k1, k2):
#     return np.mean(template[k1] == template[k2])
# agreements = [agreement('hum1', 'hum2'), agreement(
#     'hum1', 'hum3'), agreement('hum2', 'hum3')]
x = np.concatenate([template['hum1'], template['hum2'], template['hum3']])
y = np.concatenate([template['hum2'], template['hum3'], template['hum1']])
print('inter-annotator agreement', np.mean(x == y).round(3),
      (np.std(x == y) / np.sqrt(len(x))).round(4))

inter-annotator agreement 0.867 0.0143


In [17]:
themes = template[['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5']]
num_non_nan_themes = (~themes.isna()).sum(axis=1)
print('avg num classes', num_non_nan_themes.mean(),
      'random acc', 1 / num_non_nan_themes.mean())

avg num classes 3.2393617021276597 random acc 0.30870279146141216
