In [None]:
%load_ext autoreload
%autoreload 2

from collections import defaultdict

import dvu
import matplotlib.pyplot as plt
import pandas as pd
from os.path import join
import os.path
from tqdm import tqdm
import pathlib
import imodelsx.llm
import json
import requests
import numpy as np
import openai
import pubmed
import prompts
# Load the OpenAI key from a local file. Path.read_text() opens and closes the
# file itself, so no file handle is left dangling in the kernel.
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# run on another machine without editing it.
openai.api_key = pathlib.Path('/home/chansingh/.OPENAI_KEY').read_text().strip()
plt.style.use('default')
dvu.set_style()

df = pd.read_csv('../data/main.csv')

# ids of the rows that main.csv flags as having a paper
ids_with_paper = df[df["found_paper (0=no, 1=yes)"] > 0].id.astype(int).values

# ids of the pdfs that are actually present locally (files are named <id>.pdf)
ids_found = sorted(
    [int(x.replace(".pdf", "")) for x in os.listdir("../papers") if x.endswith(".pdf")]
)

# report papers that are flagged in the csv but have no local pdf
ids_found_set = set(ids_found)
for paper_id in ids_with_paper:
    if paper_id not in ids_found_set:
        print('should have paper', paper_id)

# flag papers that exist locally but are not marked as found in the csv
# NOTE(review): ids_with_paper was computed above and is not refreshed here,
# so rows flagged by this loop are not part of ids_with_paper.
ids_flagged = set(ids_with_paper.tolist())
for paper_id in ids_found:
    if paper_id in ids_flagged:
        continue
    print(paper_id, 'in local pdfs but not in main.csv')
    matches = df.index[df.id == paper_id]
    if len(matches) == 0:
        # the pdf has no row at all in main.csv: there is nothing to flag, and
        # indexing [0] on the empty match would raise IndexError
        print(paper_id, 'has no row in main.csv, skipping')
        continue
    idx = matches[0]
    print(df.loc[idx, 'found_paper (0=no, 1=yes)'])
    df.loc[idx, 'found_paper (0=no, 1=yes)'] = 1

In [None]:
# Disabled one-off helper, kept for reference: it looks up the pubmed ids of
# papers without a local pdf, downloads them with pubmed2pdf and renames each
# downloaded file to the paper's id from main.csv.
# # download papers
# refs = pubmed.get_updated_refs(df)
# all_ids = df.id
# ids_missing = [str(id) for id in all_ids if id not in ids_found]
# pmids_missing = {}
# for id in ids_missing:
#     ref = refs[df["id"] == int(id)][0]
    
#     if isinstance(ref, str) and 'pubmed' in ref:
#         paper_id = pubmed.get_paper_id(ref)
#         # print(id, ref, paper_id)
#         pmids_missing[paper_id] = id
# s = ",".join(list(pmids_missing.keys()))
# # !python -m pubmed2pdf pdf --pmids="{s}"

# # rename each pdf file in pubmed2pdf to its id
# pubmed_papers_dir = pathlib.Path("../pubmed2pdf")
# papers_downloaded = os.listdir(pubmed_papers_dir)
# for paper in papers_downloaded:
#     paper_id = paper.split(".")[0]
#     paper_id = pmids_missing[paper_id]
#     os.rename(
#         join(pubmed_papers_dir, paper),
#         join(pubmed_papers_dir, f"{paper_id}.pdf"),
#     )

# Extract text for every paper flagged as found.
# NOTE(review): presumably this writes the ../papers/<id>.txt files that the
# extraction loop reads later -- confirm against the pubmed module.
pubmed.extract_texts_from_pdf(ids_with_paper)

### Ask questions about the text

In [None]:
# Model used for the extraction; the alternatives are kept for quick switching.
# llm = imodelsx.llm.get_llm("gpt-3.5-turbo-0613")
# llm = imodelsx.llm.get_llm("gpt-4-32k-0613")
llm = imodelsx.llm.get_llm("gpt-4-0613")

# Prompt bundle: `properties` is iterated by key later on, `functions` is passed
# to the llm call, and `content_str` is a template with an `{input}` placeholder.
# properties, functions, content_str = prompts.get_prompts_demographics()
properties, functions, content_str = prompts.get_prompts_gender()

# Single user message; its content is overwritten with the formatted prompt
# before every llm call.
messages = [{"role": "user", "content": content_str}]

# example with answer: One hundred and five patients, 55 males and 50 females
toy_input1 = (
    "This study was about treating diabetes. It was a very difficult study.\n"
    "One hundred and five patients, 55 males and 50 females were included.\n"
    "The study took 200 days to complete. The study was conducted in the United States.\n"
    "The study was conducted by the University of California, San Francisco."
)

# example with answer: One hundred and five patients, 55 males and 50 females, 10 white, 75 black
toy_input2 = (
    "This study was about treating diabetes. It was a very difficult study.\n"
    "One hundred and five patients, 55 males and 50 females were included.\n"
    "The study took 200 days to complete. The study was conducted in the United States.\n"
    "Ten of the patients were white, 20 were asian, and the rest were black.\n"
    "The study was conducted by the University of California, San Francisco."
)

# Smoke test (disabled): run the prompt on each toy input and print the parsed
# function-call arguments.
# for toy_input in (toy_input1, toy_input2):
#     messages[0]['content'] = content_str.format(input=toy_input)
#     msg = llm(messages, functions=functions, return_str=False, temperature=0.0)
#     args = json.loads(msg.get('function_call')['arguments'])
#     print(json.dumps(args, indent=2))

In [97]:
def rename_to_none(x: str):
    """Map "no answer" placeholder strings to ``None``.

    Parameters
    ----------
    x
        Value taken from the parsed function-call arguments. Normally a string,
        but any JSON value (number, list, dict, ``None``) is accepted.

    Returns
    -------
    ``None`` when ``x`` is a string that, ignoring surrounding whitespace and
    letter case, is empty, "unknown" or "n/a"; otherwise ``x`` unchanged.
    """
    # Only strings can be placeholders. Checking the type first also avoids a
    # TypeError from looking up an unhashable value (list/dict) in the set.
    if isinstance(x, str) and x.strip().lower() in {"", "unknown", "n/a"}:
        return None
    return x


def call_on_subsets(x: str, subset_len_tokens=4750, max_calls=3):
    """Query the llm on consecutive windows of ``x`` until one yields arguments.

    The text is cut into windows of ``subset_len_tokens * 4`` characters (the
    file's rough 4-characters-per-token estimate). Each window is inserted into
    ``content_str`` and sent with the module-level ``messages`` / ``functions``.
    The first response that carries parseable function-call arguments wins.

    Parameters
    ----------
    x
        Full text to query.
    subset_len_tokens
        Approximate window size in tokens.
    max_calls
        Maximum number of windows (llm calls) to try.

    Returns
    -------
    dict or None
        Parsed function-call arguments, or ``None`` if no tried window
        produced a usable function call.

    Notes
    -----
    Overwrites ``messages[0]["content"]`` on every call, as before.
    """
    subset_len_chars = subset_len_tokens * 4

    subset_num = 0
    while subset_num < max_calls:
        start = subset_num * subset_len_chars
        subset = x[start : start + subset_len_chars]

        messages[0]["content"] = content_str.format(input=subset)
        msg = llm(messages, functions=functions, return_str=False, temperature=0.0)
        if msg is not None and msg.get("function_call") is not None:
            try:
                return json.loads(msg.get("function_call")["arguments"])
            except (json.JSONDecodeError, TypeError):
                # malformed / missing arguments: treat this window as
                # unanswered and fall through to the next one instead of
                # aborting the whole paper
                pass

        subset_num += 1

        # the next window is only worth a call if at least half a window of
        # text is left for it
        if len(x) < (subset_num + 0.5) * subset_len_chars:
            break

    return None


def check_evidence(ev: str, real_input: str):
    """Tell whether the evidence span ``ev`` occurs in ``real_input``.

    Both strings are compared with all whitespace removed and letter case
    ignored. A missing span (``None``) never counts as evidence.
    """
    if ev is None:
        return False

    def _squash(text):
        # drop every whitespace character, then lowercase
        return "".join(text.split()).lower()

    needle = _squash(ev)
    haystack = _squash(real_input)
    return needle in haystack


# initialize: every extracted-property column starts out empty
for k in properties.keys():
    df.loc[:, k] = None
# df["approx_tokens"] = None

# run loop: query the llm once per paper and store the extracted properties
for paper_id in tqdm(ids_with_paper):
    # .index[0] is an index *label*, so the row is fetched with .loc; .iloc
    # would only agree with it while df has a default 0..n-1 index
    i = df[df.id == paper_id].index[0]
    row = df.loc[i]
    paper_file = join("../papers", str(int(row.id)) + ".txt")

    try:
        real_input = pathlib.Path(paper_file).read_text()
        # gpt4 has 8k token window (some of it is functions, etc.), so the
        # text is queried in windows rather than all at once
        args = call_on_subsets(real_input)

        if args is not None:
            for k in properties.keys():
                if k not in args:
                    continue
                value = rename_to_none(args[k])

                # discard evidence spans that are not actually contained in the text
                if k in ["num_male_evidence_span", "num_female_evidence_span"]:
                    if not check_evidence(args[k], real_input):
                        value = None

                # set the value at row label i and column k
                df.loc[i, k] = value
    except Exception as e:
        # best effort: report the paper and keep going. repr() keeps the
        # exception type, which str(e) alone drops (e.g. a bare 'choices').
        print(row.id, repr(e))
print("completed!")

 17%|█▋        | 31/184 [02:43<18:46,  7.37s/it]

cached!
'choices'


In [None]:
def cast_int(x):
    """Convert ``x`` to ``int``, returning ``-1`` when it is not convertible.

    ``-1`` is the sentinel for values such as ``None``, NaN, infinities and
    non-numeric strings. Only conversion errors are caught, so interrupts and
    other unrelated exceptions still propagate.
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported type; ValueError: bad string or NaN;
        # OverflowError: +/- infinity
        return -1


# Accuracy of the extracted counts against the manually corrected columns,
# restricted to rows whose corrected value is present and not 'Unk'.
for k in ('num_male', 'num_female'):
    corrected = df[k + '_corrected']
    has_value = corrected.notnull()
    is_unk = corrected == 'Unk'
    idxs = has_value & ~is_unk

    gt = corrected[idxs].astype(int)
    # predictions that cannot be parsed become -1 and therefore count as wrong
    pred = df[k].apply(cast_int)[idxs].astype(int)

    acc = (gt == pred).mean()
    print(f'{k} acc={acc:0.2f} n={len(gt)}')

In [None]:
# Persist the table, including the columns added by the extraction loop.
# NOTE: this writes to the same path that df was loaded from at the top of the
# notebook, so the input file is overwritten.
df.to_csv('../data/main.csv', index=False)

# Look at gender ratios

In [None]:
# Keep only rows where both corrected counts are present and not 'Unk'.
male_corrected = df['num_male_corrected']
female_corrected = df['num_female_corrected']

male_ok = male_corrected.notnull() & ~(male_corrected == 'Unk')
female_ok = female_corrected.notnull() & ~(female_corrected == 'Unk')
idxs = male_ok & female_ok

male = male_corrected[idxs].astype(int)
female = female_corrected[idxs].astype(int)

In [None]:
# Male-to-female ratio per study, as a plain numpy array. Nothing is filtered
# here: studies with zero females still appear as inf in the printed list.
ratios = male.div(female).to_numpy()
print(sorted(ratios))

In [None]:
plt.figure(figsize=(4, 2), dpi=300)

# keep only finite ratios: drops inf (zero females) and NaN (0 males / 0 females),
# which would otherwise turn the mean into NaN
r = ratios[np.isfinite(ratios)]

# log10(0) is -inf and makes plt.hist fail, so all-female studies (ratio 0)
# are left out of the histogram; they still count in the summary below
logr = np.log10(r[r > 0])

# the count is of ratios above 1 (more males than females), so label it as such
print('mean', r.mean(), 'frac>1', (r > 1).sum(), '/', len(r))

# pink: more females than males (log-ratio < 0); blue: at least as many males
plt.hist(logr[logr < 0], color='pink') #, bins=100)
plt.hist(logr[logr >= 0], color='C0') #, bins=100)
plt.axvline(0, color='black', ls='--')

# relabel the log10 axis as powers of ten; :g avoids float noise in the labels
ticks = plt.xticks()[0]
plt.xticks(ticks, [f'$10^{{{t:g}}}$' for t in ticks])
plt.xlabel('Ratio (male / female)')
plt.ylabel('Count')
plt.show()