In [1]:
from pydantic import BaseModel, Field
from typing import List, Optional
from openai import OpenAI
import dotenv
import json

In [2]:
# Load variables from a local .env file into the process environment before the
# OpenAI client is created; the displayed bool says whether a .env file was loaded.
env_file_loaded = dotenv.load_dotenv()
env_file_loaded

True

In [3]:
# No explicit arguments: the client takes its settings from the environment that
# load_dotenv() populated above (presumably OPENAI_API_KEY — confirm against .env).
client = OpenAI()

In [11]:
# Author record, nested in Article.authors.
# Pydantic copies the class docstring and each Field's title/description into the
# generated JSON schema, so editing that text changes what the extraction model sees.
# `...` as the Field default marks a field as required.
class Person(BaseModel):
    """
    Represents a person.
    """
    firstname: str = Field(..., title="First name", description="The first name of the person.")
    # May be null; the sample run returned values like "M." (with the period) or null.
    middle_initial: Optional[str] = Field(None, title="Middle initial", description="The middle initial of the person if present.")
    surname: str = Field(..., title="Surname", description="The surname of the person.")
    # One free-text string per author; how multiple affiliations are rendered is up to
    # the model — TODO confirm whether a list is needed.
    affiliation: str = Field(..., title="Affiliation", description="The affiliation of the person.")

# Nested in Article.mentioned_metabolites and Disease.associated_metabolites.
# Docstring, titles and descriptions feed the JSON schema sent to the model.
class Metabolite(BaseModel):
    """
    Represents a metabolite.
    """
    # The evaluation cell matches on this name (case-insensitively), not on the ID.
    name: str = Field(..., title="Name", description="The name of the metabolite.")
    # Required, but the system prompt allows "unknown"; the sample run returned
    # chebi_id='unknown' for formate.
    chebi_id: str = Field(..., title="ChEBI ID", description="The ChEBI ID of the metabolite.")

# Nested in Article.mentioned_proteins and Disease.associated_proteins.
# Docstring, titles and descriptions feed the JSON schema sent to the model.
class Protein(BaseModel):
    """
    Represents a protein.
    """
    name: str = Field(..., title="Name", description="The name of the protein.")
    # Required; the system prompt tells the model to use "unknown" when unsure.
    uniprot_id: str = Field(..., title="UniProt ID", description="The UniProt ID of the protein.")

# Nested in Article.mentioned_genes and Disease.associated_genes.
# Docstring, titles and descriptions feed the JSON schema sent to the model.
class Gene(BaseModel):
    """
    Represents a gene.
    """
    name: str = Field(..., title="Name", description="The name of the gene.")
    # Stored as a string even though Entrez IDs are numeric, so "unknown" (allowed by
    # the system prompt) still validates.
    entrez_id: str = Field(..., title="Entrez ID", description="The Entrez ID of the gene.")

# Nested in Article.mentioned_pathways and Disease.associated_pathways.
# Docstring, titles and descriptions feed the JSON schema sent to the model.
class Pathway(BaseModel):
    """
    Represents a pathway.
    """
    name: str = Field(..., title="Name", description="The name of the pathway.")
    # Required; the system prompt tells the model to use "unknown" when unsure.
    kegg_id: str = Field(..., title="KEGG ID", description="The KEGG ID of the pathway.")

class Drug(BaseModel):
    """
    Represents a drug.
    """
    # Defined exactly once so Disease.inhibiting_drugs and Article.mentioned_drugs
    # share one schema. A second `class Drug` further down used to shadow this one,
    # leaving Disease with a Guide-to-Pharmacology-only Drug and Article with a
    # PubChem-only Drug. Both identifiers now live on this single model.
    name: str = Field(..., title="Name", description="The name of the drug.")
    pubchem_id: str = Field(..., title="PubChem ID", description="The PubChem ID of the drug.")
    # Optional (default None) so existing Drug(name=..., pubchem_id=...) calls still validate.
    guide_to_pharmacology_id: Optional[str] = Field(None, title="Guide to Pharmacology ID", description="The Guide to Pharmacology ID of the drug.")

class Disease(BaseModel):
    """
    Represents a disease.
    """
    # Nested in Article.mentioned_diseases; every list below is required, so the model
    # must emit them (possibly empty) for each disease.
    name: str = Field(..., title="Name", description="The name of the disease.")
    doid_id: str = Field(..., title="DOID ID", description="The DOID ID of the disease.")
    inhibiting_drugs: List[Drug] = Field(..., title="Inhibiting drugs", description="The drugs inhibiting the disease.")
    associated_genes: List[Gene] = Field(..., title="Associated genes", description="The genes associated with the disease.")
    associated_proteins: List[Protein] = Field(..., title="Associated proteins", description="The proteins associated with the disease.")
    associated_metabolites: List[Metabolite] = Field(..., title="Associated metabolites", description="The metabolites associated with the disease.")
    associated_pathways: List[Pathway] = Field(..., title="Associated pathways", description="The pathways associated with the disease.")

# TODO: draft model, not yet referenced by Article.
# class Regulation(BaseModel):
#     """
#     Represents a regulation.
#     """
#     regulator: str = Field(..., title="Regulator", description="The regulator of the regulation.")
#     target: str = Field(..., title="Target", description="The target of the regulation.")
#     type: str = Field(..., title="Type", description="The type of the regulation.")
#     evidence: str = Field(..., title="Evidence", description="The evidence of the regulation.")

# Root extraction schema: passed as `response_format` in extract(), and used to
# validate the returned dict via Article.parse_obj(data). Pydantic folds the docstring
# and every Field title/description into the JSON schema, so editing that text changes
# the schema the model is given.
class Article(BaseModel):
    """
    Represents an article.
    """
    title: str = Field(..., title="Title", description="The title of the article.")
    journal: str = Field(..., title="Journal", description="The journal of the article.")
    year: int = Field(..., title="Year", description="The year of the article.")
    # Kept as a string; the sample run returned "136".
    volume: str = Field(..., title="Volume", description="The volume of the article.")
    # The sample run returned "unknown" here, as the system prompt permits.
    pubmed_id: str = Field(..., title="PubMed ID", description="The PubMed ID of the article.")
    authors: List[Person] = Field(..., title="Authors", description="The authors of the article.")
    # One list per entity type; the evaluation cell scores mentioned_metabolites by name.
    mentioned_metabolites: List[Metabolite] = Field(..., title="Mentioned metabolites", description="The metabolites mentioned in the article.")
    mentioned_proteins: List[Protein] = Field(..., title="Mentioned proteins", description="The proteins mentioned in the article.")
    mentioned_genes: List[Gene] = Field(..., title="Mentioned genes", description="The genes mentioned in the article.")
    mentioned_pathways: List[Pathway] = Field(..., title="Mentioned pathways", description="The pathways mentioned in the article.")
    mentioned_drugs: List[Drug] = Field(..., title="Mentioned drugs", description="The drugs mentioned in the article.")
    mentioned_diseases: List[Disease] = Field(..., title="Mentioned diseases", description="The diseases mentioned in the article.")


In [None]:
# "Concise" variant of the system prompt. The next cell reassigns system_message, so
# this text is only in effect if that cell is skipped.
system_message = """You are an expert in extracting structured information from medical journal articles.
Identify key details such as title and authors, plus mentioned entities such as metabolites and pathways.
Present the extracted information in a clear, structured format. Be concise, focusing on essential
content and ignoring unnecessary boilerplate language."""

In [52]:
# Recall-oriented system prompt used by extract(): asks for every entity and lets the
# model write "unknown" for an identifier instead of dropping the entity.
system_message = """You are an expert in extracting structured information from medical journal articles.
Identify key details such as title and authors, plus mentioned entities such as metabolites and pathways.
Present the extracted information in a clear, structured format. Be comprehensive and extract every single
mentioned entity. You will be evaluated on the quality and completeness of the extracted information.

If you are not confident in the identifier for an entity, you can specify it as "unknown". It is better
to include an entity with an "unknown" identifier than to omit it entirely."""

In [53]:
def extract(document, model="gpt-4o-2024-08-06", temperature=0):
    """Extract structured article metadata from ``document`` with an OpenAI chat model.

    Sends the global ``system_message`` as the system prompt and ``document`` as the
    user message, requesting a reply shaped like the ``Article`` schema through
    ``client.beta.chat.completions.parse``.

    Args:
        document: Full text of the article.
        model: Chat model name passed straight to the API.
        temperature: Sampling temperature; 0 by default to keep runs repeatable.

    Returns:
        dict: The model's JSON reply decoded with ``json.loads``; validate it with
        ``Article.parse_obj``.

    Raises:
        ValueError: If the model refuses the request or returns no content.
    """
    response = client.beta.chat.completions.parse(
        model=model,
        temperature=temperature,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": document},
        ],
        response_format=Article,
    )
    message = response.choices[0].message
    # With structured outputs a refusal arrives in `message.refusal` with no JSON
    # content; json.loads(None) would otherwise fail with an opaque TypeError.
    refusal = getattr(message, "refusal", None)
    if refusal:
        raise ValueError(f"Model refused to extract the article: {refusal}")
    if message.content is None:
        raise ValueError("Model returned no content to parse.")
    return json.loads(message.content)

In [54]:
# Plain-text copy of the article to extract from. The encoding is pinned so non-ASCII
# characters (e.g. "β") decode the same on every platform — assumes the export is
# UTF-8; confirm if the source file changes.
with open('../data/MolGenetMetab_136_306_2022.txt', 'r', encoding='utf-8') as file:
    contents = file.read()

In [55]:
# One API call over the whole article text; the result is a plain dict in the Article shape.
data = extract(document=contents)

In [61]:
# Show the raw reply, then validate it against the schema; the model repr is the cell output.
pretty_reply = json.dumps(data, indent=4)
print(pretty_reply)
Article.parse_obj(data)

{
    "title": "Metabolomics analysis reveals dysregulation in one carbon metabolism in Friedreich Ataxia",
    "journal": "Molecular Genetics and Metabolism",
    "year": 2022,
    "volume": "136",
    "pubmed_id": "unknown",
    "authors": [
        {
            "firstname": "Thomas",
            "middle_initial": "M.",
            "surname": "O'Connell",
            "affiliation": "Department of Otolaryngology-Head & Neck Surgery, Indiana University School of Medicine, Indianapolis, IN, United States of America"
        },
        {
            "firstname": "David",
            "middle_initial": "L.",
            "surname": "Logsdon",
            "affiliation": "Department of Anatomy, Cell Biology & Physiology, Indiana University School of Medicine, Indianapolis, IN, United States of America"
        },
        {
            "firstname": "R. Mark",
            "middle_initial": null,
            "surname": "Payne",
            "affiliation": "Department of Pediatrics, Division of C

Article(title='Metabolomics analysis reveals dysregulation in one carbon metabolism in Friedreich Ataxia', journal='Molecular Genetics and Metabolism', year=2022, volume='136', pubmed_id='unknown', authors=[Person(firstname='Thomas', middle_initial='M.', surname="O'Connell", affiliation='Department of Otolaryngology-Head & Neck Surgery, Indiana University School of Medicine, Indianapolis, IN, United States of America'), Person(firstname='David', middle_initial='L.', surname='Logsdon', affiliation='Department of Anatomy, Cell Biology & Physiology, Indiana University School of Medicine, Indianapolis, IN, United States of America'), Person(firstname='R. Mark', middle_initial=None, surname='Payne', affiliation='Department of Pediatrics, Division of Cardiology, and Herman B Wells Center for Pediatric Research, Indiana University School of Medicine, Indianapolis, IN, United States of America')], mentioned_metabolites=[Metabolite(name='formate', chebi_id='unknown'), Metabolite(name='sarcosine

In [None]:
# Hand-curated entity lists for this article, used to score the extraction.
# The newline-separated lists are stripped before splitting: splitting text that ends
# in "\n" would otherwise leave a trailing "" entry that can never be matched and is
# always counted as a false negative.
ground_truth_metabolites = """formate
sarcosine
hypoxanthine
homocysteine
ATP
NAD
NADH
fatty acids
dihydroceramide
unsaturated fatty acid
phosphocholines
cholesterol ester
hydroxypropionylcarnitine
acetate
dodecanedioic acid
indoxylsulfate
arginine
glucose
carnitine
glutamate
lysine
histidine
branched chain amino acids
leucine
valine
choline
threonine
ornithine
lactate
succinate
cholesterol ester 20:0
triglycerides
phosphatidylcholine
betaine
methionine
s-adenosylmethionine
s-adenosylhomocysteine
dimethylglycine
iron-sulfur clusters
phosphocreatine
nicotinamide mononucleotide
oxylipin
reactive oxygen species
taurine
β-alanine
pyruvate
""".strip().splitlines()

ground_truth_pathways = """iron-sulfur cluster biogenesis
cellular energy metabolism
mitochondrial electron transport chain
krebs cycle
electron transfer flavoprotein
nuclear gene expression
glycolysis
carbohydrate and fatty acid metabolism
energy metabolism
one-carbon metabolism
folate cycle
methionine salvage
purine nucleotide salvage and synthesis
pyruvate metabolism
""".strip().splitlines()

ground_truth_proteins = [
    "frataxin",
]

ground_truth_drugs = [
    "Etravirine",
    "Resveratrol",
    "SS-31",
    "deferoxamine",
    "BAPTA-AM",
    "antioxidants",
]

ground_truth_diseases = [
    "friedreich ataxia",
    "dyslipidemia",
    "pre-diabetic state",
    "diabetes"
]

assert "" not in ground_truth_metabolites and "" not in ground_truth_pathways

In [60]:
article = Article.parse_obj(data)


def score_entities(predicted_names, expected_names):
    """Score extracted entity names against a ground-truth list, printing each verdict.

    Names are compared case-insensitively with surrounding whitespace ignored, so an
    extracted "atp" matches a ground-truth "ATP". Blank ground-truth entries (such as
    the empty string left by splitting newline-terminated text) are skipped rather
    than counted as misses.

    Args:
        predicted_names: Entity names produced by the extraction, in output order.
        expected_names: Ground-truth entity names.

    Returns:
        tuple[int, int, int]: (true_positives, false_positives, false_negatives).
        Every predicted name is scored, duplicates included.
    """
    def normalize(name):
        return name.strip().lower()

    # Sets give O(1) membership tests in both directions.
    expected = {normalize(name) for name in expected_names if normalize(name)}
    predicted = {normalize(name) for name in predicted_names}

    tp = fp = 0
    for name in predicted_names:
        if normalize(name) in expected:
            print(f"True positive: {name}")
            tp += 1
        else:
            print(f"False positive: {name}")
            fp += 1

    fn = 0
    for name in expected_names:
        key = normalize(name)
        if key and key not in predicted:
            print(f"False negative: {name}")
            fn += 1

    return tp, fp, fn


true_positives, false_positives, false_negatives = score_entities(
    [metabolite.name for metabolite in article.mentioned_metabolites],
    ground_truth_metabolites,
)

print(f"True positives: {true_positives}")
print(f"False positives: {false_positives}")
print(f"False negatives: {false_negatives}")

True positive: formate
True positive: sarcosine
True positive: hypoxanthine
True positive: homocysteine
True positive: choline
True positive: threonine
True positive: ornithine
True positive: lactate
True positive: hydroxypropionylcarnitine
True positive: succinate
True positive: acetate
True positive: dodecanedioic acid
True positive: indoxylsulfate
True positive: arginine
True positive: glucose
True positive: carnitine
True positive: glutamate
True positive: lysine
True positive: histidine
True positive: leucine
True positive: valine
False negative: ATP
False negative: NAD
False negative: NADH
False negative: fatty acids
False negative: dihydroceramide
False negative: unsaturated fatty acid
False negative: phosphocholines
False negative: cholesterol ester
False negative: branched chain amino acids
False negative: cholesterol ester 20:0
False negative: triglycerides
False negative: phosphatidylcholine
False negative: betaine
False negative: methionine
False negative: s-adenosylmethion

In [63]:
import networkx as nx
from ipycanvas import Canvas
from ipywidgets import VBox, widgets

class GraphCanvas:
    """Draw a NetworkX graph with pre-computed node positions onto an ipycanvas Canvas.

    Every node must carry numeric ``x`` and ``y`` attributes (canvas pixel
    coordinates). Edges are drawn as straight lines, then nodes as filled circles
    on top of them.
    """

    def __init__(self, graph, width=600, height=400):
        """
        Args:
            graph: NetworkX graph whose nodes have ``x``/``y`` attributes.
            width: Canvas width in pixels (default 600).
            height: Canvas height in pixels (default 400).
        """
        self.graph = graph
        self.canvas = Canvas(width=width, height=height)
        self.node_color = "blue"
        self.edge_color = "black"
        self.node_radius = 5
        self.draw_graph()

    def draw_graph(self):
        """Clear the canvas and redraw every edge, then every node."""
        self.canvas.clear()
        positions = {node: (attrs['x'], attrs['y']) for node, attrs in self.graph.nodes(data=True)}

        # Styles do not change inside the loops, so each is set once.
        self.canvas.stroke_style = self.edge_color
        for u, v in self.graph.edges:
            x1, y1 = positions[u]
            x2, y2 = positions[v]
            self.canvas.stroke_line(x1, y1, x2, y2)

        self.canvas.fill_style = self.node_color
        for x, y in positions.values():
            self.canvas.fill_circle(x, y, self.node_radius)

    def get_widget(self):
        """Return the canvas wrapped in a VBox for display."""
        return VBox([self.canvas])

# Sample input: three nodes with explicit pixel coordinates, joined into a triangle.
G = nx.Graph()
G.add_nodes_from(
    (node_id, {"x": px, "y": py})
    for node_id, (px, py) in {1: (100, 150), 2: (300, 100), 3: (500, 200)}.items()
)
G.add_edges_from([(1, 2), (2, 3), (3, 1)])

# The VBox returned on the last line is what the notebook displays.
graph_widget = GraphCanvas(G)
graph_widget.get_widget()


VBox(children=(Canvas(height=400, width=600),))

In [64]:
import networkx as nx
from ipywidgets import HTML, VBox

class GraphSVG:
    """Render a NetworkX graph with pre-computed node positions as a static SVG widget.

    Every node must carry numeric ``x`` and ``y`` attributes (SVG user units).
    The markup is generated once, in ``__init__``; later changes to colours or sizes
    are not reflected in ``svg_widget`` unless it is rebuilt from ``generate_svg()``.
    """

    def __init__(self, graph):
        # Import the widget class explicitly: a later cell runs
        # `from IPython.display import HTML`, rebinding the global `HTML` to a display
        # object that VBox cannot hold.
        from ipywidgets import HTML as HTMLWidget

        self.graph = graph
        self.node_color = "blue"
        self.edge_color = "black"
        self.node_radius = 5
        self.width = 600
        self.height = 400
        self.svg_template = """
        <svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
            {edges}
            {nodes}
        </svg>
        """
        self.svg_widget = HTMLWidget(self.generate_svg())

    def _position(self, node):
        """Return the ``(x, y)`` attributes stored on ``node``."""
        attrs = self.graph.nodes[node]
        return attrs['x'], attrs['y']

    def generate_edges(self):
        """Return one ``<line>`` element per edge, newline-separated."""
        lines = []
        for u, v in self.graph.edges:
            (x1, y1), (x2, y2) = self._position(u), self._position(v)
            lines.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{self.edge_color}" stroke-width="2"/>')
        return "\n".join(lines)

    def generate_nodes(self):
        """Return one ``<circle>`` element per node, newline-separated."""
        circles = []
        for _, attrs in self.graph.nodes(data=True):
            x, y = attrs['x'], attrs['y']
            circles.append(f'<circle cx="{x}" cy="{y}" r="{self.node_radius}" fill="{self.node_color}" />')
        return "\n".join(circles)

    def generate_svg(self):
        """Fill ``svg_template`` with the current size, edges and nodes."""
        return self.svg_template.format(
            width=self.width,
            height=self.height,
            edges=self.generate_edges(),
            nodes=self.generate_nodes(),
        )

    def get_widget(self):
        """Return the SVG widget wrapped in a VBox for display."""
        return VBox([self.svg_widget])

# Sample input: three positioned nodes connected as a triangle.
G = nx.Graph()
G.add_node(1, x=100, y=150)
G.add_node(2, x=300, y=100)
G.add_node(3, x=500, y=200)
G.add_edge(1, 2)
G.add_edge(2, 3)
G.add_edge(3, 1)

# Build the SVG view; the VBox on the last line is the displayed output.
graph_svg = GraphSVG(G)
graph_svg.get_widget()


VBox(children=(HTML(value='\n        <svg width="600" height="400" xmlns="http://www.w3.org/2000/svg" xmlns:xl…

In [110]:
import networkx as nx
import uuid
import json
from IPython.display import HTML

def networkx_to_d3(graph):
    """Serialise a NetworkX graph into the ``{"nodes": [...], "links": [...]}`` shape D3 expects.

    Node keys are stringified into ``id``; a node's ``x``/``y`` attributes are copied
    over, defaulting to 0 when absent. Each edge becomes a ``source``/``target`` pair
    of stringified node keys.
    """
    def as_node(key):
        attrs = graph.nodes[key]
        return {"id": str(key), "x": attrs.get('x', 0), "y": attrs.get('y', 0)}

    def as_link(edge):
        source, target = edge
        return {"source": str(source), "target": str(target)}

    return {
        "nodes": list(map(as_node, graph.nodes)),
        "links": list(map(as_link, graph.edges)),
    }

def create_d3_graph(graph, height=400):
    """Render ``graph`` as an interactive, draggable force-directed D3.js (v7) diagram.

    Node ``x``/``y`` attributes (0 when absent, see ``networkx_to_d3``) seed the
    simulation's node positions. D3 is loaded as an ES module from the jsDelivr
    CDN, so the browser needs network access.

    Args:
        graph: NetworkX graph to draw; node keys are stringified into D3 ids.
        height: Drawing height in pixels (default 400). The width follows the
            output container and is re-measured on window resize.

    Returns:
        IPython.display.HTML: Displays the diagram when it is a cell's last expression.
    """
    # Import locally: another cell runs `from ipywidgets import HTML`, which would
    # otherwise turn this into a widget in which the <script> never executes.
    from IPython.display import HTML as DisplayHTML

    graph_data = networkx_to_d3(graph)
    # Node ids can be arbitrary text (e.g. extracted entity names). Escape the
    # HTML-significant characters so a "</script>" in the data cannot close the script
    # element early. These characters only occur inside JSON strings, where the \u
    # escapes decode back to the same characters in JavaScript.
    graph_json = (
        json.dumps(graph_data)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )
    container_id = "graph-" + str(uuid.uuid4())

    html_template = f"""
    <div id="{container_id}" style="margin:0;padding:0"></div>
    <script type="module">
        import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm";

        const graph = {graph_json};
        const container = document.getElementById("{container_id}");
        const height = {int(height)};
        // Fall back to a fixed width if the output area has no layout width yet.
        const measureWidth = () => container.getBoundingClientRect().width || 600;
        let width = measureWidth();

        const svg = d3.select(container)
            .append("svg")
            .attr("width", width)
            .attr("height", height)
            .style("padding", 0)
            .style("margin", 0);

        const simulation = d3.forceSimulation(graph.nodes)
            .force("link", d3.forceLink(graph.links).id(d => d.id).distance(100))
            .force("charge", d3.forceManyBody().strength(-300))
            .force("center", d3.forceCenter(width / 2, height / 2));

        // Track the container width and re-centre the layout. Once the output is
        // cleared or the cell re-run, the container leaves the DOM: unregister then,
        // instead of throwing on every later resize.
        const onResize = () => {{
            if (!container.isConnected) {{
                window.removeEventListener("resize", onResize);
                simulation.stop();
                return;
            }}
            width = measureWidth();
            svg.attr("width", width);
            simulation.force("center").x(width / 2);
            simulation.alpha(0.3).restart();
        }};
        window.addEventListener("resize", onResize);

        // Draw links
        const link = svg.append("g")
            .selectAll("line")
            .data(graph.links)
            .enter()
            .append("line")
            .attr("stroke", "#999")
            .attr("stroke-width", 2);

        // Draw nodes
        const node = svg.append("g")
            .selectAll("circle")
            .data(graph.nodes)
            .enter()
            .append("circle")
            .attr("r", 8)
            .attr("fill", "blue")
            .call(d3.drag()
                .on("start", (event, d) => {{
                    if (!event.active) simulation.alphaTarget(0.3).restart();
                    d.fx = d.x;
                    d.fy = d.y;
                }})
                .on("drag", (event, d) => {{
                    d.fx = event.x;
                    d.fy = event.y;
                }})
                .on("end", (event, d) => {{
                    if (!event.active) simulation.alphaTarget(0);
                    d.fx = null;
                    d.fy = null;
                }})
            );

        // Update positions
        simulation.on("tick", () => {{
            link
                .attr("x1", d => d.source.x)
                .attr("y1", d => d.source.y)
                .attr("x2", d => d.target.x)
                .attr("y2", d => d.target.y);

            node
                .attr("cx", d => d.x)
                .attr("cy", d => d.y);
        }});
    </script>
    """
    return DisplayHTML(html_template)


In [111]:
# Sample input: a three-node cycle with seeded coordinates.
G = nx.Graph()
cycle_order = [1, 2, 3]
seed_positions = {1: (100, 150), 2: (300, 100), 3: (500, 200)}
G.add_nodes_from(
    (node_id, {"x": seed_positions[node_id][0], "y": seed_positions[node_id][1]})
    for node_id in cycle_order
)
# Pair each node with its successor, wrapping around: (1, 2), (2, 3), (3, 1).
G.add_edges_from(zip(cycle_order, cycle_order[1:] + cycle_order[:1]))

# The returned HTML object is the cell's displayed output.
create_d3_graph(G)