In [None]:
import time
import pickle
import glob
import json
import pandas as pd
import seaborn as sns
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from matplotlib import rcParams
import matplotlib.pyplot as plt

import os

# Do not clobber an API key that is already exported in the environment.
# Only when the variable is missing do we fall back to the empty placeholder,
# so the cell behaves as before on machines without a configured key.
# Never paste a real key into this notebook; export OPENAI_API_KEY instead.
os.environ.setdefault('OPENAI_API_KEY', '')

# Open the manuscript dictionaries to get DOIs

In [None]:
# Load the pickled manuscript dictionary. Its keys are handled as PDF file
# names below (the ".pdf" suffix is stripped) and each value exposes a "doi".
# NOTE: pickle.load can execute arbitrary code; only open files you produced.
with open('data_input/scipdf.pkl', 'rb') as handle:
    pdf_dict = pickle.load(handle)

# Map every manuscript ID (key without ".pdf") to the DOI stored in its entry.
pdf_doi_dict = {
    pdf_name.replace(".pdf", ""): record["doi"]
    for pdf_name, record in pdf_dict.items()
}

# Extract all MTA information from the 36 manuscripts

In [None]:
# Prompt sent to the retrieval-QA chain for every manuscript. It is identical
# for all papers and models, so it is defined once instead of per iteration.
query = """Find all the names of significant QTLs, QTNs, MTAs, SNPs, SRR regions or GWAS marker names mentioned in the manuscript. \
                For each significant marker you find, find which trait or condition it is associated with, or which trait or condition was \
                used to genetically map to this trait. If there are multiple traits or conditions associated with a single marker, combine them. 

                Important - Respond in JSON format only, following the schema below:
                    ```json
                    {
                    "marker": string  // QTL, QTN, MTA, SNP, SRR regions or GWAS marker name
                    "trait_full": string  // full, non-abbreviated, names of all traits associated with the given marker
                    "trait_abv": string  // abbreviated names of all traits associated with the given marker
                    "chromosome": string // chromosome name and location information, if available. NaN if not found
                    "genomic_range": string // range of genomic region the marker was mapped to, if available. NaN if not found
                    }
                    ```
                """

# The embedding client does not depend on the paper or the chat model.
embeddings = OpenAIEmbeddings()

for model in ["gpt-3.5-turbo", "gpt-4"]:
    # The sampling temperature is an LLM setting, so it is configured on the
    # chat model rather than passed to RetrievalQA.from_chain_type.
    llm = ChatOpenAI(model_name=model, temperature=0.7)
    marker_frames = []  # one DataFrame per successfully parsed paper
    fail_list = []      # [paper ID, raw LLM answer] for answers that failed to parse

    for pdf_name in pdf_dict.keys():

        # Pause between requests (presumably to stay below API rate limits).
        time.sleep(2)
        key = pdf_name.replace(".pdf", "")
        db = FAISS.load_local("data_faiss/faiss_chunked/faiss_db_1000/{}".format(key), embeddings, allow_dangerous_deserialization=True)
        retriever = db.as_retriever(search_type="similarity_score_threshold",
                                    search_kwargs={"score_threshold": 0.5, "k": 25})

        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
        marker_response = qa(query)
        raw_answer = marker_response['result']

        # Correct the returned JSON format based on observed errors
        try:
            # Keep only the first fenced block when the answer is wrapped in one.
            response_string = raw_answer.split("```")[1] if "```" in raw_answer else raw_answer
            response_string = response_string.replace("```", "")
            response_string = response_string.replace("json\n", "")
            response_string = response_string.replace("}\n{", "},\n{")
            response_string = response_string.replace("\n}\n}\n", "\n}\n")
            response_string = response_string.replace("}\n", "}")

            # Wrap bare objects in a list. Test the actual ends of the payload:
            # a "[" or "]" inside a value (e.g. a genomic range) must not
            # suppress the wrapping.
            response_string = response_string.strip()
            if not response_string.startswith("["):
                response_string = "[" + response_string
            if not response_string.endswith("]"):
                response_string = response_string + "]"

            # Parse the JSON format and turn it into a pandas dataframe
            response_json = json.loads(response_string)
            tmp_df = pd.DataFrame(response_json)

            # Add the PDF and DOI IDs so it is easier to keep track of
            tmp_df["pdf"] = key
            tmp_df["doi"] = pdf_doi_dict[key]
            marker_frames.append(tmp_df)
            print("success,"+key)
        except Exception:
            # Malformed answers are recorded, not fatal. `Exception` (instead
            # of a bare except) lets Ctrl-C / kernel interrupts propagate.
            print("fail,"+key)
            fail_list.append([key, raw_answer])

    # Concatenate once per model; an empty table is still written if nothing parsed.
    marker_results = pd.concat(marker_frames) if marker_frames else pd.DataFrame()
    marker_results.to_csv(f"data_output/markers/36.markers.{model}.tsv", sep="\t", index=None)

# Short results over all papers

In [None]:
# Number of curated markers per paper (the denominator of "% correct").
true_markers = pd.read_csv("data_figures/36.markers.curated.tsv", sep="\t", encoding = "ISO-8859-1")
true_markers = true_markers.groupby("doi")["marker"].count().reset_index()

# Sum the 'correct' column because some predictions contain multiple markers that are all counted
pred_markers = pd.read_csv("data_figures/36.markers.tsv", sep="\t", encoding = "ISO-8859-1")
pred_markers = pred_markers.drop_duplicates(["Model", "doi", "marker"])
# Aggregate only the score column; summing every column would also try to
# "sum" the free-text columns.
pred_markers = pred_markers.groupby(["Model", "doi"])["correct"].sum().reset_index()

# Score every model on every curated paper. An outer merge plus a blanket
# fillna(0) would (a) create a bogus Model "0" for papers no model answered,
# (b) skip the zero for a model that missed a paper another model covered, and
# (c) divide by zero for DOIs missing from the curated file.
paper_grid = pd.MultiIndex.from_product(
    [pred_markers["Model"].unique(), true_markers["doi"]], names=["Model", "doi"]
).to_frame(index=False)
results = (
    paper_grid
    .merge(pred_markers, on=["Model", "doi"], how="left")
    .merge(true_markers, on="doi", how="left")
)
results["correct"] = results["correct"].fillna(0)  # no prediction for a paper -> 0 correct
results['% correct'] = results["correct"] / results["marker"] * 100

plt.rcParams.update({'font.weight': 'bold', 'font.size': 13, 
                     'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})
fig, ax = plt.subplots(figsize=(3, 4))

print(results[["Model", "% correct"]].groupby("Model").mean())
# Bars show the per-model mean; the swarm overlays one point per paper.
sns.barplot(data=results, x="Model", y= "% correct", color="black", ax=ax)
sns.swarmplot(data=results, x="Model", y= "% correct", color="orange", ax=ax)

# Analyze a single paper with different chunk sizes and K

In [None]:
# Prompt sent to the retrieval-QA chain; identical for every run of the sweep.
query = """Find all the names of significant QTLs, QTNs, MTAs, SNPs, SRR regions or GWAS marker names mentioned in the manuscript. \
                        For each significant marker you find, find which trait or condition it is associated with, or which trait or condition was \
                        used to genetically map to this trait. If there are multiple traits or conditions associated with a single marker, combine them. 

                        Important - Respond in JSON format only, following the schema below:
                            ```json
                            {
                            "marker": string  // QTL, QTN, MTA, SNP, SRR regions or GWAS marker name
                            "trait_full": string  // full, non-abbreviated, names of all traits associated with the given marker
                            "trait_abv": string  // abbreviated names of all traits associated with the given marker
                            "chromosome": string // chromosome name and location information, if available. NaN if not found
                            "genomic_range": string // range of genomic region the marker was mapped to, if available. NaN if not found
                            }
                            ```
                        """

# The single manuscript analysed in this sweep; `key` is its ID without ".pdf".
key = "s00122-022-04109-9.pdf".replace(".pdf", "")
embeddings = OpenAIEmbeddings()

for model in ["gpt-3.5-turbo", "gpt-4"]:
    # Use the model of the current iteration. The chat model used to be
    # hard-coded to 'gpt-4', so the file written for gpt-3.5-turbo contained
    # GPT-4 answers. The temperature belongs on the LLM, not on the QA chain.
    llm = ChatOpenAI(model_name=model, temperature=0.7)
    marker_frames = []  # one DataFrame per successfully parsed run
    fail_list = []      # [paper ID, size, k, rep, raw LLM answer] for unparsable answers

    for size in [250,500,750,1000]:
        # Load the index once per chunk size; it does not depend on k or rep.
        # The flag matches the other FAISS.load_local call in this notebook.
        db = FAISS.load_local("data_faiss/faiss_chunked/faiss_db_{}/{}".format(size, key), embeddings,
                              allow_dangerous_deserialization=True)
        for k in [5,10,15,20,25]:
            retriever = db.as_retriever(search_type="similarity_score_threshold",
                                        search_kwargs={"score_threshold": 0.5, "k": k})
            qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

            for rep in [1,2,3,4]:
                # Pause between requests (presumably to stay below API rate limits).
                time.sleep(2)
                marker_response = qa(query)
                raw_answer = marker_response['result']

                # Correct the returned JSON format based on observed errors
                try:
                    # Keep only the first fenced block when the answer is wrapped in one.
                    response_string = raw_answer.split("```")[1] if "```" in raw_answer else raw_answer
                    response_string = response_string.replace("```", "")
                    response_string = response_string.replace("json\n", "")
                    response_string = response_string.replace("}\n{", "},\n{")
                    response_string = response_string.replace("\n}\n}\n", "\n}\n")
                    response_string = response_string.replace("}\n", "}")

                    # Wrap bare objects in a list. Test the actual ends of the
                    # payload: a "[" or "]" inside a value must not suppress it.
                    response_string = response_string.strip()
                    if not response_string.startswith("["):
                        response_string = "[" + response_string
                    if not response_string.endswith("]"):
                        response_string = response_string + "]"

                    # Parse the JSON format and turn it into a pandas dataframe
                    response_json = json.loads(response_string)
                    tmp_df = pd.DataFrame(response_json)

                    # Add the PDF and DOI IDs so it is easier to keep track of
                    tmp_df["pdf"] = key
                    tmp_df["doi"] = pdf_doi_dict[key]
                    tmp_df["k"] = k
                    tmp_df["rep"] = rep
                    tmp_df["size"] = size
                    marker_frames.append(tmp_df)
                except Exception:
                    # Record unparsable answers; `Exception` (instead of a bare
                    # except) lets Ctrl-C / kernel interrupts propagate.
                    fail_list.append([key, size, k, rep, raw_answer])

    # Concatenate once per model; an empty table is still written if nothing parsed.
    marker_results = pd.concat(marker_frames) if marker_frames else pd.DataFrame()
    marker_results.to_csv(f"data_output/markers/p10.markers.{model}.tsv", sep="\t", index=None)

# Make plots for the different k and chunk size comparisons

In [None]:
# Count the curated markers per DOI, then look up the total for the single
# paper used in the chunk-size / top-k comparison.
curated_table = pd.read_csv("data_figures/36.markers.curated.tsv", sep="\t", encoding = "ISO-8859-1")
true_markers = curated_table.groupby("doi").count().reset_index()[["doi", "marker"]]
is_target_paper = true_markers["doi"] == "10.1007/s00122-022-04109-9"
n_markers = true_markers.loc[is_target_paper, "marker"].iloc[0]

In [None]:
plt.rcParams.update({'font.weight': 'bold', 'font.size': 13, 
                     'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})
fig, ax = plt.subplots(figsize=(4.5, 3))

# The file name used to be spelled "marekrs", unlike every other p10 path in
# this notebook. Prefer the consistent spelling, but fall back to the old one
# in case the file on disk really carries the typo.
results_path = "data_figures/p10.markers.gpt-3.5-turbo.tsv"
legacy_path = "data_figures/p10.marekrs.gpt-3.5-turbo.tsv"
if not os.path.exists(results_path) and os.path.exists(legacy_path):
    results_path = legacy_path
trait_results = pd.read_csv(results_path, sep="\t")

# Count the rows marked correct (== 1) per (k, rep, size) run. Flagging rows
# instead of filtering them out keeps runs with zero hits in the table as 0 %,
# so they are not silently dropped from the mean and sd drawn below.
trait_results = (
    trait_results
    .assign(correct=trait_results["correct"].eq(1).astype(int))
    .groupby(["k", "rep", "size"])["correct"].sum()
    .reset_index()
)
trait_results["% correct"] = trait_results["correct"] / n_markers * 100

# Rename by column name (not by position) for the axis and legend labels.
trait_results = trait_results.rename(columns={"k": "Top-k Size", "size": "Chunk Size"})
ax = sns.barplot(data=trait_results, x="Chunk Size", y="% correct", hue="Top-k Size", errorbar="sd", ax=ax)
ax.set_title("GPT-3.5")
ax.legend(loc='upper right', title="Top-k Size", prop={'size': 10})
ax.set_ylim(0, 100);

In [None]:
plt.rcParams.update({'font.weight': 'bold', 'font.size': 13, 
                     'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})
fig, ax = plt.subplots(figsize=(4.5, 3))

trait_results = pd.read_csv("data_figures/p10.markers.gpt-4.tsv" ,sep="\t")

# Count the rows marked correct (== 1) per (k, rep, size) run. Flagging rows
# instead of filtering them out keeps runs with zero hits in the table as 0 %,
# so they are not silently dropped from the mean and sd drawn below.
trait_results = (
    trait_results
    .assign(correct=trait_results["correct"].eq(1).astype(int))
    .groupby(["k", "rep", "size"])["correct"].sum()
    .reset_index()
)
trait_results["% correct"] = trait_results["correct"] / n_markers * 100

# Rename by column name (not by position) for the axis and legend labels.
trait_results = trait_results.rename(columns={"k": "Top-k Size", "size": "Chunk Size"})
ax = sns.barplot(data=trait_results, x="Chunk Size", y="% correct", hue="Top-k Size", errorbar="sd", ax=ax)
ax.set_title("GPT-4")
# No legend on this panel (the GPT-3.5 panel draws one with the same hue variable).
ax.get_legend().remove()
ax.set_ylim(0, 100);