Skip to content

Commit

Permalink
Adding generate-extract command, 158. Add cell type templates #159 (#162
Browse files Browse the repository at this point in the history
)

This PR does two things:

- Add a combined generate-extract command, fixes #158
- Adds cell type templates, fixes #159

## Generate-Extract

`ontogpt generate-extract -m gpt-4 -t cell_type "Acinar Cell Of Salivary
Gland"`

This does two things:

1. asks GPT to generate a summary of the cell type
2. parses/extracts knowledge from that cell type

This resuscitates the original HALO idea. We could in principle
**directly generate an entire knowledgebase in structured form from the
latent GPT KB**

Example output:

```yaml
extracted_object:
  cell_type: Acinar cell of a salivary gland
  parents:
    - CL:0000066
  subtypes:
    - CL:0000313
    - CL:0000319
  localizations:
    - UBERON:0001044
    - UBERON:0009842
  diseases:
    - AUTO:Sj%C3%B6gren%27s%20syndrome
    - MONDO:0021357
named_entities:
  - id: CL:0000066
    label: Epithelial cell
  - id: CL:0000313
    label: Serous cells
  - id: CL:0000319
    label: Mucous cells
  - id: UBERON:0001044
    label: Salivary gland
  - id: UBERON:0009842
    label: Acinus
  - id: AUTO:Sj%C3%B6gren%27s%20syndrome
    label: Sjögren's syndrome
  - id: MONDO:0021357
    label: Salivary gland tumors
```

## Cell Type Templates

This PR also demonstrates using subclasses for more refined subtypes

Compare the two:

1. `ontogpt generate-extract -m gpt-4 -t cell_type "L2/3
Intratelencephalic Projecting Glutamatergic Neuron Of The Primary Motor
Cortex"`
2. `ontogpt generate-extract -m gpt-4 -t cell_type.InterneuronDocument
"L2/3 Intratelencephalic Projecting Glutamatergic Neuron Of The Primary
Motor Cortex"`

The first uses the generic base class. The second uses a subclass
designed for interneurons, which has an extra slot for projection fields

Example output:

```yaml
extracted_object:
  cell_type: L2/3 Intratelencephalic Projecting Glutamatergic Neuron of the Primary
    Motor Cortex
  range: Not mentioned
  parents:
    - AUTO:excitatory%20neuron
  subtypes:
    - AUTO:Not%20mentioned
  localizations:
    - UBERON:0000956
    - UBERON:0001384
  genes:
    - AUTO:Not%20mentioned
  diseases:
    - MONDO:0005180
    - MONDO:0020128
  projects_to_or_from:
    - UBERON:0001893
named_entities:
  - id: UBERON:0001893
    label: telencephalon
  - id: AUTO:excitatory%20neuron
    label: excitatory neuron
  - id: AUTO:Not%20mentioned
    label: Not mentioned
  - id: UBERON:0000956
    label: cerebral cortex
  - id: UBERON:0001384
    label: primary motor cortex
  - id: MONDO:0005180
    label: Parkinson's disease
  - id: MONDO:0020128
    label: motor neuron disease
```
  • Loading branch information
cmungall committed Jul 31, 2023
2 parents 40a6645 + 3774dec commit 90d3eaa
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 6 deletions.
28 changes: 23 additions & 5 deletions src/ontogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def write_extraction(
default="yaml",
help="Output format.",
)
# Reusable --auto-prefix option, defined once so that commands can apply it
# as a decorator instead of repeating the click.option(...) call.
auto_prefix_option = click.option(
    "--auto-prefix",
    help="Prefix to use for auto-generated classes. Default is AUTO.",
    default="AUTO",
)


@click.group()
Expand Down Expand Up @@ -201,11 +206,7 @@ def main(verbose: int, quiet: bool, cache_db: str, skip_annotator):
@click.option("--dictionary")
@output_format_options
@use_textract_options
@click.option(
"--auto-prefix",
default="AUTO",
help="Prefix to use for auto-generated classes. Default is AUTO.",
)
@auto_prefix_option
@click.option(
"--set-slot-value",
"-S",
Expand Down Expand Up @@ -304,6 +305,23 @@ def extract(
write_extraction(results, output, output_format, ke)


@main.command()
@template_option
@model_option
@recurse_option
@output_option_wb
@output_format_options
@auto_prefix_option
@click.argument("entity")
def generate_extract(entity, template, output, output_format, **kwargs):
    """Generate text using GPT and then extract knowledge from it."""
    # Remaining CLI options (collected in **kwargs) are forwarded unchanged
    # to the engine constructor.
    logging.info("Creating for %s", template)
    ke = SPIRESEngine(template, **kwargs)
    logging.debug("Input entity: %s", entity)
    results = ke.generate_and_extract(entity)
    # Pass the engine to the writer, matching the call made by the extract
    # command, so the writer receives the same arguments on both paths.
    write_extraction(results, output, output_format, ke)


@main.command()
@template_option
@model_option
Expand Down
6 changes: 5 additions & 1 deletion src/ontogpt/engines/knowledge_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def __post_init__(self):
self.mappers = [get_adapter("translator:")]

self.set_up_client()
self.encoding = tiktoken.encoding_for_model(self.client.model)
try:
self.encoding = tiktoken.encoding_for_model(self.client.model)
except KeyError:
self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
logger.error(f"Could not find encoding for model {self.client.model}")

def set_api_key(self, key: str):
self.api_key = key
Expand Down
8 changes: 8 additions & 0 deletions src/ontogpt/engines/spires_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,18 @@ def extract_from_text(
named_entities=self.named_entities,
)


def _extract_from_text_to_dict(self, text: str, cls: ClassDefinition = None) -> RESPONSE_DICT:
    """Run the raw extraction on ``text`` and parse the reply into a dict.

    ``cls`` is forwarded unchanged to both the raw-extraction step and the
    parsing step.
    """
    completion = self._raw_extract(text, cls)
    parsed = self._parse_response_to_dict(completion, cls)
    return parsed

def generate_and_extract(self, entity: str, **kwargs) -> ExtractionResult:
    """Ask the client to describe ``entity``, then extract from that text.

    The completion returned by ``self.client.complete`` is handed to
    ``extract_from_text`` together with any extra keyword arguments.
    """
    description_request = f"Generate a comprehensive description of {entity}.\n"
    generated_text = self.client.complete(description_request)
    return self.extract_from_text(generated_text, **kwargs)

def generalize(
self, object: Union[pydantic.BaseModel, dict], examples: List[EXAMPLE]
) -> ExtractionResult:
Expand Down
203 changes: 203 additions & 0 deletions src/ontogpt/templates/cell_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from __future__ import annotations
from datetime import datetime, date
from enum import Enum
from typing import List, Dict, Optional, Any, Union, Literal
from pydantic import BaseModel as BaseModel, Field
from linkml_runtime.linkml_model import Decimal

# Module-level version markers; both hold the literal string "None",
# not the None object.
metamodel_version = "None"
version = "None"

# Thin BaseModel subclass whose only addition is a ``__weakref__`` slot.
class WeakRefShimBaseModel(BaseModel):
    __slots__ = "__weakref__"

# Shared parent for the models below: it adds no fields and only fixes the
# model configuration through class keyword arguments.
class ConfiguredBaseModel(
    WeakRefShimBaseModel,
    validate_assignment=True,
    validate_all=True,
    underscore_attrs_are_private=True,
    extra="forbid",
    arbitrary_types_allowed=True,
):
    pass


# String-valued enumeration with a single member, ``dummy``.
class BrainRegionIdentifier(str, Enum):
    dummy = "dummy"


# String-valued enumeration; every member's value equals its own name.
class NullDataOptions(str, Enum):
    UNSPECIFIED_METHOD_OF_ADMINISTRATION = "UNSPECIFIED_METHOD_OF_ADMINISTRATION"
    NOT_APPLICABLE = "NOT_APPLICABLE"
    NOT_MENTIONED = "NOT_MENTIONED"



# Generic cell type record: two optional scalar fields plus five optional
# string lists that default to empty.
class CellTypeDocument(ConfiguredBaseModel):
    cell_type: Optional[str] = Field(
        default=None,
        description="the name of the cell type described",
    )
    range: Optional[str] = Field(default=None)
    parents: Optional[List[str]] = Field(
        default_factory=list,
        description="categorization",
    )
    subtypes: Optional[List[str]] = Field(default_factory=list)
    localizations: Optional[List[str]] = Field(default_factory=list)
    genes: Optional[List[str]] = Field(default_factory=list)
    diseases: Optional[List[str]] = Field(default_factory=list)



# CellTypeDocument plus a ``projects_to_or_from`` list. The inherited fields
# are redeclared here with the same types and defaults as on the parent.
class InterneuronDocument(CellTypeDocument):
    projects_to_or_from: Optional[List[str]] = Field(
        default_factory=list,
        description="Brain structures from which this cell type projects into or receives projections from",
    )
    cell_type: Optional[str] = Field(
        default=None,
        description="the name of the cell type described",
    )
    range: Optional[str] = Field(default=None)
    parents: Optional[List[str]] = Field(
        default_factory=list,
        description="categorization",
    )
    subtypes: Optional[List[str]] = Field(default_factory=list)
    localizations: Optional[List[str]] = Field(default_factory=list)
    genes: Optional[List[str]] = Field(default_factory=list)
    diseases: Optional[List[str]] = Field(default_factory=list)



class ExtractionResult(ConfiguredBaseModel):
    """
    A result of extracting knowledge on text
    """

    # Details of the input and of the completion call; all optional.
    input_id: Optional[str] = Field(default=None)
    input_title: Optional[str] = Field(default=None)
    input_text: Optional[str] = Field(default=None)
    raw_completion_output: Optional[str] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    # Extracted payload; both fields are typed with Any.
    extracted_object: Optional[Any] = Field(
        default=None,
        description="The complex objects extracted from the text",
    )
    named_entities: Optional[List[Any]] = Field(
        default_factory=list,
        description="Named entities extracted from the text",
    )



# Identifier/label pair; parent of the entity classes declared after it.
class NamedEntity(ConfiguredBaseModel):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class Gene(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class Pathway(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class AnatomicalStructure(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# AnatomicalStructure subtype; redeclares the inherited id/label fields unchanged.
class BrainRegion(AnatomicalStructure):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class CellType(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class Disease(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class Drug(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# Field-less model; Triple derives from it.
class CompoundExpression(ConfiguredBaseModel):
    pass



class Triple(CompoundExpression):
    """
    Abstract parent for Relation Extraction tasks
    """

    # Core statement parts; all optional strings.
    subject: Optional[str] = Field(default=None)
    predicate: Optional[str] = Field(default=None)
    object: Optional[str] = Field(default=None)
    # Optional qualifiers on the statement, its subject and its object.
    qualifier: Optional[str] = Field(
        default=None,
        description='A qualifier for the statements, e.g. "NOT" for negation',
    )
    subject_qualifier: Optional[str] = Field(
        default=None,
        description='An optional qualifier or modifier for the subject of the statement, e.g. "high dose" or "intravenously administered"',
    )
    object_qualifier: Optional[str] = Field(
        default=None,
        description='An optional qualifier or modifier for the object of the statement, e.g. "severe" or "with additional complications"',
    )



# Pairs an optional publication with a list of triples. ``Publication`` is
# declared later in this module, so that annotation is a forward reference.
class TextWithTriples(ConfiguredBaseModel):
    publication: Optional[Publication] = Field(default=None)
    triples: Optional[List[Triple]] = Field(default_factory=list)



# NamedEntity subtype; redeclares the inherited id/label fields unchanged.
class RelationshipType(NamedEntity):
    id: str = Field(
        default=None,
        description="A unique identifier for the named entity",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label (name) of the named thing",
    )



# Publication record; every field is an optional string.
class Publication(ConfiguredBaseModel):
    id: Optional[str] = Field(
        default=None,
        description="The publication identifier",
    )
    title: Optional[str] = Field(
        default=None,
        description="The title of the publication",
    )
    abstract: Optional[str] = Field(
        default=None,
        description="The abstract of the publication",
    )
    combined_text: Optional[str] = Field(default=None)
    full_text: Optional[str] = Field(
        default=None,
        description="The full text of the publication",
    )


# Three optional strings: a subject text plus an object id and object text.
class AnnotatorResult(ConfiguredBaseModel):
    subject_text: Optional[str] = Field(default=None)
    object_id: Optional[str] = Field(default=None)
    object_text: Optional[str] = Field(default=None)




# Update forward refs
# see https://pydantic-docs.helpmanual.io/usage/postponed_annotations/
# The classes are visited in the order they are declared in this module.
for _model in (
    CellTypeDocument,
    InterneuronDocument,
    ExtractionResult,
    NamedEntity,
    Gene,
    Pathway,
    AnatomicalStructure,
    BrainRegion,
    CellType,
    Disease,
    Drug,
    CompoundExpression,
    Triple,
    TextWithTriples,
    RelationshipType,
    Publication,
    AnnotatorResult,
):
    _model.update_forward_refs()
# Drop the loop variable so it does not linger as a module attribute.
del _model

0 comments on commit 90d3eaa

Please sign in to comment.