Functions to parse YAML output to KGX-compliant CSV #349

Merged · 21 commits · Mar 22, 2024
Changes from 2 commits
169 changes: 160 additions & 9 deletions src/ontogpt/io/csv_wrapper.py
@@ -3,7 +3,10 @@
import logging
from pathlib import Path
from typing import Any, List

import yaml
import pandas as pd
import uuid
from tqdm import tqdm
from oaklib import get_adapter
from pydantic import BaseModel

@@ -36,7 +39,8 @@ def output_parser(obj: Any, file) -> List[str]:
    # Extract all 'subject', 'predicate', and 'object' output lines
    with open(output_file, "r") as file:
        to_print = False
-        perpetuators = tuple([" subject:", " predicate:", " object:", " "])
+        perpetuators = tuple(
+            [" subject:", " predicate:", " object:", " "])
        for line in file:
            line = line.strip("\n")
            if line.startswith("extracted_object:"):
@@ -62,15 +66,17 @@ def output_parser(obj: Any, file) -> List[str]:
    i = 0
    while i < len(cleaned_lines):
        if cleaned_lines[i].startswith("extracted_object"):
-            for index, elem in enumerate(cleaned_lines[i + 1 : i + 4]):
+            for index, elem in enumerate(cleaned_lines[i + 1:i + 4]):
                if elem.startswith("extracted_object"):
                    next_index = i + 1 + index
                    del cleaned_lines[i:next_index]
                    i -= 1
        i += 1

    # Separate extracted values into indexed items in dictionary of lists
-    grouped_lines = [cleaned_lines[n : n + 4] for n in range(0, len(cleaned_lines), 4)]
+    grouped_lines = [
+        cleaned_lines[n:n + 4] for n in range(0, len(cleaned_lines), 4)
+    ]
    trimmed_dict: dict = {"genes": [], "relationships": [], "exposures": []}
    for group in grouped_lines:
        group.pop(0)
@@ -94,7 +100,7 @@ def output_parser(obj: Any, file) -> List[str]:
        for index, elem in enumerate(value):
            if ":" in elem:
                try:
-                    prefix = elem[: (elem.index(":"))]
+                    prefix = elem[:(elem.index(":"))]
                    adapter_str = "sqlite:obo:" + str(prefix)
                    curr_adapter = get_adapter(adapter_str)
                    trimmed_dict[key][index] = curr_adapter.label(elem)
@@ -106,12 +112,14 @@ def output_parser(obj: Any, file) -> List[str]:

def write_obj_as_csv(obj: Any, file, minimize=True, index_field=None) -> None:
    if isinstance(obj, BaseModel):
-        obj = obj.model_dump()
+        obj = obj.dict()
    if isinstance(obj, list):
        rows = obj
    elif not isinstance(obj, dict):
        if not index_field:
-            index_fields = [k for k, v in obj.items() if v and isinstance(v, list)]
+            index_fields = [
+                k for k, v in obj.items() if v and isinstance(v, list)
+            ]
            if len(index_fields) == 1:
                index_field = index_fields[0]
                logger.warning(f"Using {index_field} as index field")
@@ -120,7 +128,7 @@ def write_obj_as_csv(obj: Any, file, minimize=True, index_field=None) -> None:
raise ValueError(f"Cannot dump {obj} as CSV")
if isinstance(file, Path) or isinstance(file, str):
file = open(file, "w", encoding="utf-8")
rows = [row.model_dump() if isinstance(row, BaseModel) else row for row in rows]
rows = [row.dict() if isinstance(row, BaseModel) else row for row in rows]
writer = csv.DictWriter(file, fieldnames=rows[0].keys(), delimiter="\t")
writer.writeheader()
for row in rows:
@@ -131,5 +139,148 @@ def _str(s):
            return str(s)

        # row = {k: v for k, v in row.items() if "\n" not in str(v)}
-        row = {k: _str(v).replace("\n", r"\n").replace("\t", " ") for k, v in row.items()}
+        row = {
+            k: _str(v).replace("\n", r"\n").replace("\t", " ")
+            for k, v in row.items()
+        }
        writer.writerow(row)


def schema_plurals_to_camelcase(schema_path):
    """
    Return a dictionary mapping the underscored plural attribute names to the
    CamelCase entity and relation class names defined in the schema. Assumes
    that the user follows the convention that a type defined in singular
    CamelCase appears as an attribute of EntityContainingDocument pluralized
    with underscores; e.g. GeneProteinInteraction --> gene_protein_interactions.

    For issues, tag @serenalotreck

    parameters:
        schema_path, str: path to schema YAML file

    returns:
        camelcase_map, dict: keys are underscored plural names, values are
            CamelCase class names
    """
    # Read in the schema
    with open(schema_path) as stream:
        schema = yaml.load(stream, yaml.FullLoader)

    # Get underscore names
    underscore_names = [
        name
        for name in schema['classes']['EntityContainingDocument']['attributes']
    ]

    # Convert to camelcase names
    camelcase_map = {
        name: ''.join([part.capitalize() for part in name.split('_')])[:-1]
        for name in underscore_names
    }

    # Confirm that the camelcase names exist
    for name in camelcase_map.values():
        assert name in schema['classes'].keys(
        ), f'Name {name} does not appear in classes'

    return camelcase_map
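# Reviewer note (illustrative sketch, not part of the diff): assuming a
# hypothetical template whose EntityContainingDocument attributes are genes,
# exposures, and gene_exposure_relationships, the returned map would look like
#
#     schema_plurals_to_camelcase("templates/gene_exposure.yaml")
#     # -> {'genes': 'Gene',
#     #     'exposures': 'Exposure',
#     #     'gene_exposure_relationships': 'GeneExposureRelationship'}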


def parse_yaml_predictions(yaml_path, schema_path):
    """
    Parse named entities and relations from the YAML output of OntoGPT.
    Currently only supports binary relations.

    For issues, tag @serenalotreck

    parameters:
        yaml_path, str: path to YAML file to parse. Can contain multiple
            YAML documents.
        schema_path, str: path to schema YAML file

    returns:
        ent_df, pandas df: dataframe with entities from YAML output
        rel_df, pandas df: dataframe with relations from YAML output
    """
    # Read in the YAML file
    with open(yaml_path) as stream:
        output_docs = list(yaml.safe_load_all(stream))

    # Get type map
    type_map = schema_plurals_to_camelcase(schema_path)

    # Initialize objects to store data
    ent_rows = []
    rel_rows = []

    # Format entity label to type dict
    # Note: Have to do this here because the named entity list gets added
    # to with every doc, instead of just containing entities for one doc at a
    # time
    ent_types = {}
    for doc in tqdm(output_docs):
        for typ, ent_list in doc['extracted_object'].items():
            for ent in ent_list:
                if isinstance(ent, str):
                    ent_types[ent] = typ
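    # Reviewer note (illustrative sketch, not part of the diff): the parsing
    # below assumes each YAML document looks roughly like the following, where
    # the attribute names come from a hypothetical template and the IDs and
    # labels are only examples:
    #
    #   extracted_object:
    #     id: some-document-id
    #     genes:
    #       - HGNC:11998
    #     exposures:
    #       - CHEBI:16236
    #     gene_exposure_relationships:
    #       - gene: HGNC:11998
    #         exposure: CHEBI:16236
    #   named_entities:
    #     - id: HGNC:11998
    #       label: TP53
    #     - id: CHEBI:16236
    #       label: ethanol
    #
    # so list members that are strings become entities and list members that
    # are dicts become relations, as the note below spells out.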

    # Parse documents
    # Note: assumes that in extracted_object, types with strings in a list are
    # entities, and types with dicts in a list are relations.
    for doc in output_docs:

        # Get the elements we need
        obj = doc['extracted_object']
        ents = doc['named_entities']

        # Index entities by ID
        ent_labels = {ent['id']: ent['label'] for ent in ents}

        # Format relations
        rel_types = {
            k: v
            for k, v in obj.items() if all([isinstance(rl, dict) for rl in v])
        }
        for rel_type, rels in rel_types.items():
            for rel in rels:
                row = {}
                for i, pair in enumerate(
                        rel.items()
                ):  # Allows parsing without needing component entity types
                    # (relies on preservation of insertion order)
                    # Enforce binary relations
                    assert len(rel) == 2, 'At least one relation is n-ary'
                    # Get subject and object
                    if i == 0:
                        row['subject'] = pair[1]
                    elif i == 1:
                        row['object'] = pair[1]
                # Get other relation data
                row['predicate'] = type_map[rel_type]
                row['category'] = rel_type
                row['provided_by'] = obj['id']
                row['id'] = str(uuid.uuid4())
                rel_rows.append(row)
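                # Reviewer note (illustrative sketch, not part of the diff):
                # with the hypothetical document sketched above, one edge row
                # would come out roughly as
                #   {'subject': 'HGNC:11998', 'object': 'CHEBI:16236',
                #    'predicate': 'GeneExposureRelationship',
                #    'category': 'gene_exposure_relationships',
                #    'provided_by': 'some-document-id', 'id': '<uuid4>'}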

        # Format entities
        for ent, lab in ent_labels.items():
            row = {}
            row['id'] = ent
            try:
                row['category'] = ent_types[ent]
            except KeyError:
                row['category'] = 'UNKNOWN'
            row['name'] = lab
            row['provided_by'] = obj['id']
            ent_rows.append(row)

    # Make dataframes
    ent_df = pd.DataFrame(ent_rows)
    rel_df = pd.DataFrame(rel_rows)

    # Drop repeated entities and relations
    ent_df = ent_df.drop_duplicates()
    rel_df = rel_df.drop_duplicates()

    return ent_df, rel_df
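A possible end-to-end usage sketch (reviewer note, not part of the diff): the file paths below are hypothetical, and writing nodes and edges to separate tab-separated files is one common way to produce KGX-style output from the two dataframes returned by the function added in this PR.

from ontogpt.io.csv_wrapper import parse_yaml_predictions

# Hypothetical inputs: an OntoGPT extraction output (possibly a multi-document
# YAML stream) and the LinkML template it was produced with.
ent_df, rel_df = parse_yaml_predictions(
    "outputs/gene_exposure_extractions.yaml",
    "templates/gene_exposure.yaml",
)

# KGX's TSV serialization keeps nodes and edges in separate tab-separated
# files, so write each dataframe out on its own.
ent_df.to_csv("kgx_nodes.tsv", sep="\t", index=False)
rel_df.to_csv("kgx_edges.tsv", sep="\t", index=False)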