In [1]:
import ast
import json
import re
import time
from pathlib import Path

import numpy as np
import pandas as pd

from model_client import ModelClient, get_prompt
from pdf_extraction import extract_text_from_pdf
from pdf_extraction import extract_text_from_pdf

In [2]:
def screen_paper(pdf_path, model="chatgpt"):
    """Screen every PDF in a directory with the chosen LLM and time each call.

    For each ``*.pdf`` file directly inside ``pdf_path`` (processed in sorted
    path order), extract its text with ``extract_text_from_pdf``, build a
    prompt with ``get_prompt`` and send it to the selected model via
    ``ModelClient``.

    Parameters
    ----------
    pdf_path : pathlib.Path or str
        Directory containing the PDFs to screen (not searched recursively).
    model : str, default "chatgpt"
        One of ``"chatgpt"``, ``"claude"`` or ``"llama"``.

    Returns
    -------
    titles : list of str
        File stems of the screened PDFs.
    responses : list
        Raw model responses, aligned with ``titles``.
    times : list of float
        Seconds spent waiting on each model call, aligned with ``titles``.
        PDF extraction and prompt construction are not included.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported names.
    """
    # Map each supported model name to the ModelClient method that queries it.
    ask_methods = {
        "chatgpt": "ask_chatgpt",
        "claude": "ask_claude",
        "llama": "ask_llama",
    }
    # Validate before creating the client or extracting any PDF, so a typo
    # fails fast (and also fails when the directory holds no PDFs).
    if model not in ask_methods:
        raise ValueError(
            f"Model Not Found: {model!r} (expected one of {sorted(ask_methods)})"
        )

    client = ModelClient()
    ask = getattr(client, ask_methods[model])

    titles = []
    responses = []
    times = []

    # Directory listing order is filesystem-dependent; sort for reproducibility.
    for pdf in sorted(Path(pdf_path).glob("*.pdf")):
        paper_text = extract_text_from_pdf(pdf)
        prompt = get_prompt(paper_text)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        response = ask(prompt)
        elapsed = time.perf_counter() - start

        # Append all three together so the lists always stay aligned.
        titles.append(pdf.stem)
        responses.append(response)
        times.append(elapsed)

    return titles, responses, times

In [3]:
def _parse_response(response):
    """Parse one model response into a new dict.

    Accepts an already-parsed dict, or a string holding a Python-literal or
    JSON object, optionally wrapped in a Markdown code fence (```json ... ```).

    Raises
    ------
    ValueError
        If the response cannot be parsed, or does not parse to a dict.
    """
    if isinstance(response, dict):
        return dict(response)

    text = str(response).strip()
    # Chat models often wrap structured output in a ```lang ... ``` fence.
    fenced = re.fullmatch(r"```[\w-]*\s*(.*?)\s*```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1)

    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # JSON's true/false/null are not Python literals; fall back to json.
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Could not parse model response: {response!r}") from exc

    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a dict from model response, got {type(parsed).__name__}: {response!r}"
        )
    return parsed


def compile_result(titles, responses, times, save_path):
    """Parse model responses into a table, print timing stats and save a CSV.

    Parameters
    ----------
    titles : sequence of str
        Paper titles; stored in the ``"Title"`` column (overriding any
        ``"Title"`` key in a response).
    responses : sequence
        Model responses aligned one-to-one with ``titles``. Each must be a
        dict, or a Python-literal/JSON string encoding one (a surrounding
        Markdown code fence is tolerated).
    times : sequence of float
        Per-call latencies in seconds. Their mean and population standard
        deviation (``ddof=0``) are printed; both are ``nan`` when empty.
    save_path : pathlib.Path or str
        Destination CSV file. Missing parent directories are created.

    Raises
    ------
    ValueError
        If ``titles`` and ``responses`` differ in length, or a response
        cannot be parsed into a dict.
    """
    # zip() would silently drop the unmatched tail, so misaligned inputs are rejected.
    if len(titles) != len(responses):
        raise ValueError(
            f"Got {len(titles)} titles but {len(responses)} responses; "
            "they must be aligned one-to-one."
        )

    results = []
    for title, response in zip(titles, responses):
        result = _parse_response(response)
        result["Title"] = title
        results.append(result)

    times = np.asarray(times, dtype=float)
    # mean/std of an empty array emit RuntimeWarnings; report nan quietly instead.
    mean_time = times.mean() if times.size else float("nan")
    std_dev_time = times.std() if times.size else float("nan")
    print("Mean Time: ", mean_time)
    print("Standard Deviation: ", std_dev_time)

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(results).to_csv(save_path, index=False)

In [4]:
# Input directory of PDFs to screen, and output directory for per-model CSVs.
pdf_path = Path("./data/pdfs")
save_root = Path("./data/results")
# Create the output directory up front: DataFrame.to_csv does not create
# missing parent directories, so a fresh checkout would otherwise fail.
save_root.mkdir(parents=True, exist_ok=True)

In [5]:
# Screen every PDF with the "chatgpt" backend and write its results to CSV.
model = "chatgpt"
titles, responses, times = screen_paper(pdf_path, model=model)
save_path = save_root.joinpath(f"{model}_results.csv")
compile_result(titles, responses, times, save_path)

Mean Time:  2.7227685332298277
Standard Deviation:  2.0831197106799415


In [6]:
# Screen every PDF with the "claude" backend and write its results to CSV.
model = "claude"
titles, responses, times = screen_paper(pdf_path, model=model)
save_path = save_root.joinpath(f"{model}_results.csv")
compile_result(titles, responses, times, save_path)

Mean Time:  18.040290200710295
Standard Deviation:  14.781996082857068


In [7]:
# Screen every PDF with the "llama" backend and write its results to CSV.
model = "llama"
titles, responses, times = screen_paper(pdf_path, model=model)
save_path = save_root.joinpath(f"{model}_results.csv")
compile_result(titles, responses, times, save_path)

Mean Time:  0.29912456274032595
Standard Deviation:  0.27146651580838643
