# Using LLMs and Knowledge graphs to search for PFAS Alternatives

## Project with Saint-Gobain

#### Yu-Chuan (Michael) Hsu, Isabella Stewart, Tarjei Hage, Wei Lu, and Markus J. Buehler, MIT, 2025 
mkychsu@MIT.EDU, istewart@MIT.EDU, tphage@MIT.EDU, wl7@MIT.EDU, mbuehler@MIT.EDU
#### LAMM, Massachusetts Institute of Technology


# Allows for distributed or parallel processing of a dataset

In [1]:
import sys, os

# Work-sharding configuration. When launched as `script.py <thread_i> <total_threads>`,
# the two integers select which share of the documents this process handles.
try:
    thread_i = int(sys.argv[1])  # index of this worker process
    total_threads = int(sys.argv[2])  # number of worker processes running in total
except (IndexError, ValueError):
    # No usable integer arguments (e.g. during a notebook run): fall back to a single process.
    thread_i = 0
    total_threads = 1

# Assigned unconditionally: previously this was only set in the fallback branch above,
# so a command-line run with total_threads == 1 failed with a NameError in the merge loop.
merge_every = 100

In [2]:
#%env TOGETHER_API_KEY

In [3]:
#print(os.getenv("TOGETHER_API_KEY"))

In [4]:
# LLM endpoint settings (one entry). The API key is taken from the environment
# variable TOGETHER_API_KEY and is None when that variable is not set.
config_list = [
    dict(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        api_key=os.getenv("TOGETHER_API_KEY"),
        max_tokens=20000,
    ),
]

# Client Initialization with Together

In [5]:
from together import Together
# Together API client, authenticated with the key stored in the first config_list entry.
client = Together(api_key=config_list[0]["api_key"])


In [6]:
#import os
#from GraphReasoning import *

import sys
# Put this directory first on the module search path; presumably it holds the local
# GraphReasoning checkout that the later `from GraphReasoning import ...` lines should pick up.
# NOTE(review): user-specific absolute path; it has to be edited on any other machine.
sys.path.insert(0, '/orcd/home/002/istewart/orcd/pool/hypergraph/GraphReasoning_SG') #change functions here. 


In [7]:
# Verbosity flag.
# NOTE(review): the generation and merge calls further down pass literal verbatim values
# (False and True respectively) rather than this variable.
verbatim=False

In [8]:
doc_data_dir = './CompositePDFs_marker' # input: one sub-folder per paper, each expected to hold <name>/<name>.md
data_dir='./GRAPHDATA_paper'    # per-document graph pickles and the embedding file live here
data_dir_output='./GRAPHDATA_OUTPUT_paper' # integrated (merged) graph pickles and run logs go here

max_tokens = config_list[0]['max_tokens'] # module-level copy of the generation token limit

embedding_file='composite_LLAMA4_70b.pkl' # name of the node-embedding file inside data_dir


# Embedding the graph with Nomic

In [None]:
if total_threads == 1:
    # Merging mode only: this cell is skipped when the work is split across several processes.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from sentence_transformers import SentenceTransformer
    embedding_tokenizer = ''
    embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

    from GraphReasoning import load_embeddings, save_embeddings, generate_hypernode_embeddings

    import hypernetx as hnx

    import torch

    # Embeddings are generated only when no embedding file has been saved yet.
    # (Assign generate_new_embeddings = True after this statement to force regeneration.)
    generate_new_embeddings = not os.path.exists(f'{data_dir}/{embedding_file}')
    if not generate_new_embeddings:
        print('Found existing embedding file')

    with torch.no_grad():
        if not generate_new_embeddings:
            # Reuse the embeddings computed by an earlier run.
            node_embeddings = load_embeddings(f'{data_dir}/{embedding_file}')
        else:
            # Start from an empty hypergraph, so the node list below is empty too.
            H = hnx.Hypergraph({})
            nodes = list(H.nodes)
            node_embeddings = generate_hypernode_embeddings(
                nodes,
                embedding_tokenizer,
                embedding_model,
            )
            # Persist the (initial) embeddings so the next run takes the load branch.
            save_embeddings(node_embeddings, f'{data_dir}/{embedding_file}')

### Load dataset

In [None]:
### Load dataset of papers

import pandas as pd
import glob

### Load all markdown files (dataset of papers) --- assumes each subfolder in doc_data_dir is named like a paper ID
#doc_list becomes a list of full file paths to .md files

# Build the sorted list of markdown paths, one per paper folder:
#   <doc_data_dir>/<folder>/<folder>.md
# Only directories are considered: a stray file at the top level of doc_data_dir
# (e.g. a log or .DS_Store) is not a paper folder and would yield a path that can never exist.
with os.scandir(f'{doc_data_dir}') as entries:
    doc_list = sorted(
        f'{doc_data_dir}/{entry.name}/{entry.name}.md'
        for entry in entries
        if entry.is_dir()
    )

### Set up LLM client:

In [None]:

import instructor
from typing import List
from PIL import Image
import base64
from pydantic import BaseModel
from instructor import patch


#new for hypergraph
# One extracted relation: a group of source terms linked to a group of target terms.
# Documented with comments rather than a docstring on purpose: pydantic copies a model
# docstring into the generated JSON schema, which would change what is sent to the LLM.
class Event(BaseModel):
    source: List[str]  # source term(s) of the relation
    # target: str  (earlier single-target variant, kept for reference)
    target: List[str]  # target term(s) of the relation
    relation: str  # phrase naming how the sources relate to the targets

# Top-level structured-output schema: the full list of relations returned for one prompt.
# (Comments instead of a docstring, so the generated JSON schema stays unchanged.)
class HypergraphJSON(BaseModel):
    events: List[Event]  # every Event extracted from the context chunk


# Extraction outline: identify phrases or terms that are potential nodes --> decide the type of each node --> determine which nodes are related, and what the relation is

# Module-level schema name; generate() has its own response_model parameter that defaults to the same class.
response_model = HypergraphJSON
# System prompt for hypergraph extraction.
#
# Written as adjacent string literals inside parentheses. The earlier version wrapped this
# same parenthesised source text in a triple-quoted string, so the model was sent the
# literal parentheses, quote characters and indentation as part of its instructions.
#
# The output specification and the examples describe 'source' and 'target' as lists,
# matching the Event schema (both fields are List[str]); the earlier text asked for plain
# strings in the binary case, which contradicts the schema the response is validated against.
system_prompt = (
    "You are a network ontology graph maker who extracts terms and their relations from a given context, using principles from category theory.\n\n"

    "You are provided with a context chunk (delimited by triple backticks: ```). Your task is to extract an ontology of terms mentioned in the context, representing key scientific concepts, systems, materials, and methods using well-defined, technical, and widely accepted terminology.\n\n"

    "Proceed step by step:\n"
    "Thought 1: Traverse the text sentence by sentence. Identify key scientific terms, such as materials, methods, entities, systems, conditions, or acronyms. \n"
    "    - Focus on extracting terms that are atomistic and domain-relevant.\n"
    "    - Group modifiers (e.g., 'collagen scaffold') as one term if they form a recognized concept.\n\n"

    "Thought 2: Determine which terms are related to each other based on their co-occurrence in a sentence or paragraph.\n"
    "    - A term may relate to multiple other terms.\n"
    "    - Look for structural, functional, or procedural relationships.\n\n"

    "Thought 3: For each related group of terms, infer the scientific relationship between them.\n"
    "    - Use category-theoretic relation names when possible, such as: 'is', 'has', 'acts on', 'used for', 'composed of', 'leads to'.\n"
    "    - If 3 or more co-dependent entities relate to a shared target, use an n-ary relation with the source as a list.\n"
    "    - If only 2 entities are involved, use a binary relation.\n\n"

    "Output Specification:\n"
    "Return a JSON object with a single field: 'events'. Each event must contain:\n"
    "- 'source': a list of one or more entities (a single-element list for a binary relation)\n"
    "- 'target': a list of one or more entities naming the main concept or object being acted upon or described\n"
    "- 'relation': a concise, meaningful phrase describing the relation between source and target\n\n"

    "Important:\n"
    "- Always preserve the original wording for technical terms.\n"
    "- Do not invent entities or relations that are not implied in the text.\n"
    "- Do not include any additional fields beyond 'source', 'target', and 'relation'.\n\n"

    "Examples:\n\n"

    "Binary relation:\n"
    "{\n"
    "  \"source\": [\"hydrangea\"],\n"
    "  \"target\": [\"flower\"],\n"
    "  \"relation\": \"is a type of\"\n"
    "}\n\n"

    "N-ary relation:\n"
    "{\n"
    "  \"source\": [\"Sally\", \"Bob\", \"Julia\"],\n"
    "  \"target\": [\"paper 1\"],\n"
    "  \"relation\": \"are equal co-authors of\"\n"
    "}\n\n"

    "Return a JSON object with this structure:\n"
    "{\n"
    "  \"events\": [\n"
    "     {\"source\": ..., \"target\": ..., \"relation\": ...},\n"
    "     {...},\n"
    "     ...\n"
    "  ]\n"
    "}"
)

def generate(system_prompt=system_prompt, 
             prompt="",temperature=0.333,
             max_tokens=config_list[0]['max_tokens'], response_model=HypergraphJSON, 
            ):     
    """Send one chat request and return the schema-validated result.

    Args:
        system_prompt: Text of the system message; pass None to send no system message.
        prompt: Text of the user message.
        temperature: Sampling temperature forwarded to the API. It was previously
            ignored in favour of a hard-coded 0.333, which is still the default.
        max_tokens: Generation limit forwarded to the API.
        response_model: pydantic model class handed to instructor as the output schema.

    Returns:
        Whatever the instructor-patched ``create`` call returns for ``response_model``
        (presumably an instance of that class; confirm against the instructor docs).
    """
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": f"{system_prompt}"})
    messages.append({"role": "user", "content": f"{prompt}"})

    # instructor wraps the client's create function so the reply is parsed and
    # validated against response_model instead of being returned as raw text.
    create = instructor.patch(
        create=client.chat.completions.create,
        mode=instructor.Mode.JSON_SCHEMA,
    )

    return create(
        messages=messages,
        model=config_list[0]["model"],
        max_tokens=max_tokens,
        temperature=temperature,
        response_model=response_model,
    )

def image_to_base64_data_uri(file_path):
    """Read an image file and return its contents as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension. When the extension is missing
    or is not recognised as an image type, ``image/png`` is used, which is the value
    this function previously applied to every file.

    Args:
        file_path: Path of the image file to encode.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    import mimetypes  # local import keeps this block independent of the file's import cell

    mime_type, _ = mimetypes.guess_type(str(file_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"

    with open(file_path, "rb") as image_file:
        base64_data = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{base64_data}"

def _find_image_file(image):
    """Return the first file below the current directory whose relative path ends with *image*, ignoring case."""
    from pathlib import Path  # local import: Path is not imported anywhere else in this file

    # Strip the working-directory prefix (if present) and any leading separator.
    relative = str(image).split(os.getcwd())[-1].lstrip("/\\")
    wanted = Path(relative).as_posix().casefold()
    for candidate in Path('.').rglob('*'):
        found = candidate.as_posix().casefold()
        if (found == wanted or found.endswith('/' + wanted)) and candidate.is_file():
            return candidate
    raise FileNotFoundError(f"No image matching {relative!r} found under {os.getcwd()!r}")


def generate_figure(image, system_prompt=system_prompt, 
                prompt="", temperature=0,
                ):
    """Ask the chat model for a detailed text description of an image.

    The previous body could not run: it used ``Path`` and ``create`` without either
    being defined in this scope, requested a structured ``response_model`` and then
    read ``.choices`` from the result, and ignored ``temperature``. This version calls
    the plain chat-completions endpoint and returns the message text.

    Args:
        image: Path of the image, absolute or relative; it is looked up below the
            current working directory without regard to letter case.
        system_prompt: Accepted for interface compatibility; not used (a fixed
            image-description system message is sent instead).
        prompt: Accepted for interface compatibility; not used.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text content of the first choice in the response.

    Raises:
        FileNotFoundError: If no matching image file is found.
    """
    image_uri = image_to_base64_data_uri(_find_image_file(image))
    messages = [
        {"role": "system", "content": "You are an assistant who perfectly describes images."},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_uri}},
                {"type": "text", "text": "Describe this image in detail please."},
            ],
        },
    ]

    response = client.chat.completions.create(
        messages=messages,
        model=config_list[0]["model"],
        max_tokens=config_list[0]["max_tokens"],
        temperature=temperature,
    )
    return response.choices[0].message.content

In [None]:
# Report which share of the work this process handles (sanity check of the launch arguments).
print(f'running on {thread_i}-th thread in totally {total_threads} threads') #double check threading 

### Checkpoint Cell: Finds where merging left off with current_merged_i. Handles resuming graph merging in single-threaded mode

#### 1. Finds the latest merged graph from previous runs.

#### 2. Checks if it's valid by trying to load it.

#### 3. If the graph or its parts are corrupted, it deletes them and rolls back to the previous merge index.

#### 4. Sets current_merged_i so merging can resume cleanly.

In [None]:
import os, re, glob, pickle

# A run of digits at the very start of a name, immediately followed by an underscore
# (for example the "12_" in "12_title.pkl"); the digits are captured as group 1.
INT_PREFIX = re.compile(r'^(\d+)_')

def extract_idx(path: str) -> int:
    """Parse the numeric prefix of a file name.

    Only the last path component is inspected. The result is the integer formed by
    the digits that start the name and precede its first underscore, or -1 when the
    name does not begin with digits followed by an underscore.
    """
    prefix = INT_PREFIX.match(os.path.basename(path))
    if prefix is None:
        return -1
    return int(prefix.group(1))

# Directories
# NOTE(review): these repeat the values assigned in the earlier configuration cell;
# if one place is edited the other has to be changed to match.
doc_data_dir = './CompositePDFs_marker'
data_dir = './GRAPHDATA_paper'
data_dir_output = './GRAPHDATA_OUTPUT_paper'

if total_threads == 1:  # merging mode
    # The output folder must exist before the missing-markdown log is written below.
    os.makedirs(data_dir_output, exist_ok=True)

    # 1) find all "*_integrated.pkl" files, highest document index first
    merged_graph_list = sorted(
        glob.glob(f'{data_dir_output}/*_integrated.pkl'),
        reverse=True,
        key=extract_idx
    )

    # 2) determine current merged index; fall back to 0 if none found.
    #    extract_idx returns -1 for names without a numeric prefix, hence the clamp.
    current_merged_i = max(0, extract_idx(merged_graph_list[0])) if merged_graph_list else 0

    # 3) find any per-document files for the current index
    last_graph = sorted(
        glob.glob(f'{data_dir}/{current_merged_i}_*.pkl'),
        reverse=True
    )

    # 4) validate the newest integrated graph and per-document graph by unpickling them
    if merged_graph_list and last_graph:
        try:
            with open(merged_graph_list[0], 'rb') as f:
                _ = pickle.load(f)
            with open(last_graph[0], 'rb') as f:
                _ = pickle.load(f)
        except Exception as load_err:
            # Unpickling a damaged file can raise many exception types, so this stays broad.
            print("Validation failed; cleaning up corrupt files.")
            print(f"  reason: {load_err!r}")
            for fn in [merged_graph_list[0], *last_graph]:
                try:
                    os.remove(fn)
                except OSError as remove_err:
                    # Best effort: report the failure instead of hiding it, then carry on.
                    print(f"  could not remove {fn}: {remove_err!r}")
            # Roll back one document so the deleted step is redone.
            current_merged_i = max(0, current_merged_i - 1)
    elif not merged_graph_list:
        print("No integrated pickles; starting at 0.")
    else:
        print("No per-session pickles to validate; skipping load check.")

    print(f"Start merging from No. {current_merged_i}")

else:
    current_merged_i = 0

# === Filter doc_list to remove missing markdown files ===
# Done in every mode. The position of a document in doc_list becomes part of its graph
# file name, so worker runs and the merging run must enumerate the same list; filtering
# only in merging mode shifted the indices (and let workers fail on a missing file).
valid_doc_list = []
missing_folders = []

for doc_path in doc_list:
    if os.path.exists(doc_path):
        valid_doc_list.append(doc_path)
        continue
    folder_path = os.path.dirname(doc_path)
    print(f"⚠️ Missing markdown for: {folder_path}, skipping.")
    missing_folders.append(folder_path)

if total_threads == 1:
    # Only the merging process rewrites the log, so parallel workers never race on the file.
    missing_log = os.path.join(data_dir_output, "missing_markdown_folders.txt")
    with open(missing_log, "w", encoding="utf-8") as logf:
        logf.writelines(f"{folder_path}\n" for folder_path in missing_folders)

# Replace doc_list with filtered version
doc_list = valid_doc_list
print(f"✓ Filtered doc_list: {len(doc_list)} valid documents (excluded missing markdown files)")

### Generate a Knowledge Graph (KG) from each document using an LLM + embedding pipeline, then (optionally) merge it into a global graph in merging mode.

In [None]:
import os
import pickle
import hypernetx as hnx
from GraphReasoning import make_hypergraph_from_text, add_new_hypersubgraph_from_text, update_hypernode_embeddings
from datetime import datetime
import time
import torch
import shutil
import traceback


# Initialize the "global" graph; in merging mode it accumulates every document's subgraph.
G = hnx.Hypergraph({})

with torch.no_grad():
    for i, doc in enumerate(doc_list):
        # only process docs for this thread
        if i % total_threads != thread_i:
            continue
        # skip already-merged docs
        if i < current_merged_i:
            continue

        # The markdown file name (without ".md") serves as both title and DOI tag.
        title = os.path.basename(doc).rsplit('.md', 1)[0]
        doi = title
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(doc, 'r', encoding='utf-8') as f:
            txt = f.read()

        # define where this doc's subgraph lives
        graph_root = f'{i}_{title[:100]}'
        current_graph = os.path.join(data_dir, f'{graph_root}.pkl')

        # Chunk data of the current document; stays None when generation is skipped
        current_doc_sub_dfs = None

        # Generate until the subgraph file appears. A failed attempt is retried after
        # a 60 s pause with no upper bound on the number of attempts.
        while not os.path.exists(current_graph):
            print(f"Generating KG for {i}: {title}")
            try:
                if not isinstance(txt, str):
                    print("Text is not a string:", txt)
                    break
                now = datetime.now()
                # Capture all return values, including the per-chunk data (sub_dfs)
                current_graph, _, sub_dfs_pkl_path, current_doc_sub_dfs = make_hypergraph_from_text(
                    txt, generate, generate_figure, image_list='',
                    graph_root=graph_root,
                    do_distill=False,
                    do_relabel=False,
                    chunk_size=10000, chunk_overlap=0,
                    repeat_refine=0, verbatim=False,
                    data_dir=data_dir,
                )
                print("Time:", datetime.now() - now)
                print(f"[generation] Captured sub_dfs with {len(current_doc_sub_dfs) if current_doc_sub_dfs else 0} chunks")
            except Exception as e:
                print("Error during KG generation:", repr(e))
                time.sleep(60)

        # ── MERGING MODE ──
        if total_threads == 1:

            # Simplify (with a size threshold) only on every merge_every-th document.
            if i % merge_every == 0:
                do_simplify_graph = True
                size_threshold = 10
            else:
                do_simplify_graph = False
                size_threshold = 0

            _hypergraph_pkl = os.path.join(data_dir_output, f'{graph_root}_integrated.pkl')
            # skip if already merged; the stored graph becomes the starting point for later merges
            if os.path.exists(_hypergraph_pkl):
                with open(_hypergraph_pkl, 'rb') as f:
                    G = pickle.load(f)
                print(f"[merge] already have integrated at {_hypergraph_pkl}")
                continue

            now = datetime.now()
            print(f"[merge] about to merge paper {i}: will write {_hypergraph_pkl!r}")

            try:
                graph_path = current_graph  # save the path
                with open(graph_path, 'rb') as f:
                    H0 = pickle.load(f)
                # Rebuild the subgraph so that every edge carries this document's DOI.
                current_graph = hnx.Hypergraph(
                    H0.incidence_dict,
                    edge_attr={'DOI': {eid: doi for eid in H0.incidence_dict}}
                )
                print(f"[merge] loaded subgraph with {len(current_graph.edges)} edges")
            except Exception as e:
                print(f"[merge] failed loading/annotating {graph_path!r}: {e!r}")
                continue

            # Load the cumulative sub_dfs written by earlier merges, if any
            updated_path = os.path.join(data_dir_output, "updated_sub_dfs.pkl")
            if os.path.exists(updated_path):
                with open(updated_path, "rb") as f:
                    sub_dfs = pickle.load(f)
                print(f"[merge] Loaded cumulative sub_dfs from {updated_path} with {len(sub_dfs)} existing chunks")
            else:
                sub_dfs = []
                print(f"[merge] Starting fresh sub_dfs list")

            # Generation was skipped (the subgraph file already existed): fall back to the saved pickle.
            # NOTE(review): this file name is the same for every document, so it may hold the
            # chunks of a different document than the one being merged; confirm in GraphReasoning.
            if current_doc_sub_dfs is None:
                sub_dfs_pkl_path = os.path.join(data_dir, "original_sub_dfs.pkl")
                if os.path.exists(sub_dfs_pkl_path):
                    with open(sub_dfs_pkl_path, "rb") as f:
                        current_doc_sub_dfs = pickle.load(f)
                    print(f"[merge] Loaded current doc's sub_dfs from {sub_dfs_pkl_path}")

            # Append the current document's chunks to the cumulative list
            if current_doc_sub_dfs:
                if isinstance(current_doc_sub_dfs, list):
                    sub_dfs.extend(current_doc_sub_dfs)
                    print(f"[merge] Added {len(current_doc_sub_dfs)} chunks from doc {i}. Total: {len(sub_dfs)} chunks")
                else:
                    sub_dfs.append(current_doc_sub_dfs)
                    print(f"[merge] Added 1 chunk from doc {i}. Total: {len(sub_dfs)} chunks")
            else:
                print(f"[merge] WARNING: No sub_dfs found for doc {i}")

            # perform the merge
            integrated_path, G, _, node_embeddings, sub_dfs = add_new_hypersubgraph_from_text(
                txt='',
                node_embeddings=node_embeddings,
                tokenizer=embedding_tokenizer,
                model=embedding_model,
                original_graph=G,
                data_dir_output=data_dir_output,
                graph_root=graph_root,
                do_simplify_graph=do_simplify_graph,
                do_relabel=False,
                size_threshold=size_threshold,
                do_update_node_embeddings=do_simplify_graph,
                repeat_refine=0,
                similarity_threshold=0.9,
                do_Louvain_on_new_graph=False,
                return_only_giant_component=False,
                save_common_graph=False,
                G_to_add=current_graph,
                graph_pkl_to_add=None,
                sub_dfs=sub_dfs,  # all previous chunks plus the current document's
                verbatim=True,
            )

            print(f"[merge] After merge, sub_dfs contains {len(sub_dfs)} total chunks")

            # ── CONDITIONAL EMBEDDING UPDATE ──
            # Documents at or after the last multiple of merge_every form the remainder
            # group. When there are fewer than merge_every documents that multiple is 0,
            # so every document qualifies. merge_every replaces a hard-coded 100 here so
            # this test follows the simplification interval used above.
            doc_count = len(doc_list)
            is_last_group = i >= (doc_count // merge_every) * merge_every

            if is_last_group:
                try:
                    print(f"[update] Performing embedding update for doc {i} (triggered by small or remainder group)")
                    node_embeddings = update_hypernode_embeddings(node_embeddings, G, embedding_tokenizer, embedding_model)
                except Exception as e:
                    print(f"[update] Failed to update embeddings: {e!r}")

            # ── check consistency between graph nodes and embedding keys ──
            graph_nodes = set(str(n) for n in G.nodes)
            embedding_keys = set(str(k) for k in node_embeddings)

            missing = graph_nodes - embedding_keys
            extra = embedding_keys - graph_nodes

            if not missing and not extra:
                print(f"[check] Embeddings are aligned with the graph nodes. Count = {len(graph_nodes)}")
            else:
                print(f"[check] Embedding mismatch detected:")
                if missing:
                    print(f"  - Missing embeddings for {len(missing)} nodes: {list(missing)[:5]}")
                if extra:
                    print(f"  - Extra embeddings for {len(extra)} nodes not in graph: {list(extra)[:5]}")

            # Copy the integrated graph of the very last document to a fixed file name
            is_last_doc = (i == len(doc_list) - 1)

            if is_last_doc:
                final_path = os.path.join(data_dir_output, "final_graph.pkl")
                try:
                    shutil.copyfile(integrated_path, final_path)
                    print(f"[merge] Final graph saved as: {final_path}")
                except Exception as e:
                    print(f"[merge] Failed to rename final graph: {e!r}")

            # Persist the embeddings after every merge so that a restart can resume from them
            try:
                save_embeddings(node_embeddings, os.path.join(data_dir, embedding_file))
                print("[merge] embeddings saved")
            except Exception as e:
                print(f"[merge] failed to save embeddings: {e!r}")

            print("Merge elapsed time:", datetime.now() - now)