Commit be58cac
added functionality for splitting sentences
Michael-Stewart-Webdev committed Apr 5, 2024
1 parent 7e8bf11 commit be58cac
Showing 15 changed files with 1,029 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -8,4 +8,6 @@ data-out.json
 data-out-random.json
 data-out-smart.json
 data-out-manip.json
-htmlcov/*
+htmlcov/*
+example_3.py
+temp_output.json
1 change: 1 addition & 0 deletions docs/basic_functionality.rst
@@ -210,6 +210,7 @@ Puggle also contains functions for manipulating a dataset, for example:
* :func:`~puggle.Dataset.Dataset.drop_relation_class` removes all instances of the given relation class from a dataset.
* :func:`~puggle.Dataset.Dataset.convert_entity_class` converts all entities with the given class to another class.
* :func:`~puggle.Dataset.Dataset.convert_relation_class` converts all relations with the given class to another class.
* :func:`~puggle.Dataset.Dataset.split_sentences` creates a new `Dataset` by splitting each document of the given `Dataset` into sentences based on a delimiter (such as a full stop); see the sketch below.
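
A minimal usage sketch of sentence splitting (assuming a `Dataset` has already been loaded; the variable names are illustrative):

.. code-block:: python

    # Split every document on full stops. Each sentence-level document
    # keeps a document_index pointing back at its source document, and
    # the results dict reports how many cross-sentence relations were
    # dropped along the way.
    sentence_dataset, results = dataset.split_sentences(delimiter=".")
    print(results["relations_removed"])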

For more info, see the :doc:`puggle`.

22 changes: 22 additions & 0 deletions puggle/Dataset.py
@@ -436,6 +436,24 @@ def _load_annotations(self, filename: os.path, anns_format: str):
        logger.debug(f"Loaded {len(annotations)} annotations from {filename}.")
        return annotations

    def get_stats(self):
        """Return a string of some useful stats of this dataset.

        Returns:
            str: Stats (number of documents, mentions, and relations).
        """
        num_mentions = sum(
            len(doc.annotation.mentions) for doc in self.documents
        )
        num_relations = sum(
            len(doc.annotation.relations) for doc in self.documents
        )

        return (
            f"Dataset containing {len(self.documents)} documents, "
            f"{num_mentions} mentions, and {num_relations} relations."
        )

    def __repr__(self):
        """String representation of the dataset.

@@ -538,6 +556,10 @@ def _to_spert(dataset: Dataset) -> List[Dict]:
            "entities": entities,
            "relations": relations,
        }
        # If the document has a document_index (set after sentence
        # splitting), carry that through to the output.
        if doc.document_index is not None:
            sd["document_index"] = doc.document_index
        spert_docs.append(sd)

    return spert_docs
98 changes: 98 additions & 0 deletions puggle/Document.py
@@ -16,17 +16,22 @@ def __init__(
        self,
        structured_fields: List[Dict] = None,
        annotation: Annotation = None,
        document_index: int = None,
    ):
        """Create a new document.

        Args:
            structured_fields (List[Dict], optional): List of fields.
            annotation (Annotation, optional): The Annotation of the textual
                part of this document (such as annotations over the short
                text).
            document_index (int, optional): The index of the original
                document that this sentence came from. This is set when
                splitting a document into sentences.
        """
        super().__init__()
        self.fields = structured_fields
        self.annotation = annotation
        self.document_index = document_index

    def to_dict(self):
        """Return a dict of this Document.

@@ -41,5 +46,98 @@ def to_dict(self):
            else None,
        }

    def split_sentences(self, delimiter):
        """Split this document into sentences, i.e. a list of Documents
        that have been split by the given delimiter.

        Args:
            delimiter (str): The delimiter to split on.

        Returns:
            List[Document]: List of documents, one per sentence.
            List[Relation]: List of relations that were removed because
                they spanned multiple sentences.
        """

        new_docs = []
        sent_tokens = []

        sent_start_idx = 0
        sent_end_idx = 0

        removed_relations = []
        seen_rels = set()

        for i, token in enumerate(self.annotation.tokens):
            if token != delimiter:
                sent_tokens.append(token)

            if token == delimiter or i == (len(self.annotation.tokens) - 1):
                sent_end_idx = i
                if i == len(self.annotation.tokens) - 1:
                    sent_end_idx = i + 1

                sent_mentions = []
                sent_mention_ids = {}

                # Rebuild the list of mentions in this sentence, shifting
                # their indexes so they are relative to the sentence start.
                for m in self.annotation.mentions:
                    if (m.start < sent_start_idx) or (m.end > sent_end_idx):
                        continue

                    m_dict = m.to_dict()
                    m_dict["start"] = m_dict["start"] - sent_start_idx
                    m_dict["end"] = m_dict["end"] - sent_start_idx
                    sent_mentions.append(m_dict)
                    sent_mention_ids[m] = len(sent_mentions) - 1

                # Rebuild the list of relations in this sentence,
                # discarding any cross-sentence relations (whose start or
                # end does not lie within this sentence).
                sent_relations = []
                for r in self.annotation.relations:
                    if (
                        (r.start.start < sent_start_idx)
                        or (r.end.start > sent_end_idx)
                        or (r.start.start > sent_end_idx)
                        or (r.end.start < sent_start_idx)
                    ):
                        continue

                    seen_rels.add(r)
                    r_dict = r.to_dict()
                    r_dict["start"] = sent_mention_ids[r.start]
                    r_dict["end"] = sent_mention_ids[r.end]
                    sent_relations.append(r_dict)

                new_ann = Annotation(
                    tokens=sent_tokens,
                    mentions=sent_mentions,
                    relations=sent_relations,
                )

                new_docs.append(Document(self.fields, new_ann))

                sent_tokens = []
                sent_mentions = []
                sent_relations = []

                sent_start_idx = i + 1

        # Any relation that was never attached to a sentence was dropped
        # for crossing a sentence boundary.
        for r in self.annotation.relations:
            if r not in seen_rels:
                removed_relations.append(r)

        return new_docs, removed_relations

    def __str__(self):
        return str(self.to_dict())
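
A hedged sketch of calling this method directly (not part of the commit; `doc` is assumed to hold a tokenised annotation such as ["one", "three", "two", ".", "six", "two"]):

    # Splitting on "." yields one Document per sentence, plus the list
    # of relations that had to be dropped because their endpoints fell
    # in different sentences.
    sentences, removed = doc.split_sentences(delimiter=".")
    for sent in sentences:
        print(sent.annotation.tokens)
    # e.g. ["one", "three", "two"] then ["six", "two"]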
40 changes: 40 additions & 0 deletions puggle/data_utils/manipulation.py
@@ -147,9 +147,49 @@ def flatten_all_relations(self: Dataset):
    )


def split_sentences(self: Dataset, delimiter="."):
    """Split each document of this Dataset into sentences.

    Args:
        delimiter (str, optional): The delimiter to split on. Defaults to
            a full stop.

    Returns:
        Dataset: A new dataset, where each document is a sentence. Each
            document also has a document_index, allowing the user to know
            which document the sentence originally came from.
        dict: A summary of the operation, i.e. the number of relations
            that were removed for spanning multiple sentences.
    """
    new_dataset = Dataset()
    all_relations_removed = []

    for i, d in enumerate(self.documents):
        sents, relations_removed = d.split_sentences(delimiter=delimiter)
        for s in sents:
            s.document_index = i
            new_dataset.add_document(s)
        all_relations_removed += relations_removed

    results = {"relations_removed": len(all_relations_removed)}

    logger.info("Original dataset: %s" % self.get_stats())

    logger.info(
        "Removed %d relations that spanned multiple sentences."
        % len(all_relations_removed)
    )
    if self.documents:
        # Guard against division by zero on an empty dataset.
        logger.info(
            "(an average of %.2f relations per document)"
            % (len(all_relations_removed) / len(self.documents))
        )

    logger.info("New dataset: %s" % new_dataset.get_stats())

    return new_dataset, results


Dataset.drop_entity_class = drop_entity_class
Dataset.drop_relation_class = drop_relation_class
Dataset.convert_entity_class = convert_entity_class
Dataset.convert_relation_class = convert_relation_class
Dataset.flatten_all_entities = flatten_all_entities
Dataset.flatten_all_relations = flatten_all_relations
Dataset.split_sentences = split_sentences
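
A hedged end-to-end sketch of the Dataset-level function (the import path and filename are illustrative assumptions, not part of the commit):

    from puggle import Dataset

    dataset = Dataset()
    dataset.load_documents(
        anns_filename="annotations.json",
        anns_format="spert",
    )

    # Returns a new sentence-level Dataset plus a results dict; each new
    # document's document_index points back at the original document.
    sent_dataset, results = dataset.split_sentences(delimiter=".")
    print(results["relations_removed"])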
25 changes: 24 additions & 1 deletion tests/conftest.py
@@ -23,6 +23,12 @@ def dataset_json_path(request):
    return os.path.join(FIXTURE_DIR, f"{name}.json")


@pytest.fixture
def dataset_json_path_2(request):
    name = request.param
    return os.path.join(FIXTURE_DIR, f"{name}.json")


@pytest.fixture
def dataset_csv_path(request):
    name = request.param
@@ -36,13 +42,30 @@ def dataset_untyped_path(request):


def _get_dataset(request):
    """Retrieve the given dataset. The format to load (quickgraph or
    spert) depends on the name of the dataset: if the name contains
    "quickgraph", it is loaded as quickgraph; otherwise it is loaded
    as spert.

    If the name is "empty", an empty dataset is returned.

    Args:
        request (pytest.FixtureRequest): The request whose param is the
            name of the dataset to load.

    Returns:
        Dataset: The dataset.
    """
    name = request.param

    data_format = "spert"
    if "quickgraph" in name:
        data_format = "quickgraph"

    if name == "empty":
        return Dataset()
    else:
        d = Dataset()
        d.load_documents(
            anns_filename=os.path.join(FIXTURE_DIR, f"{name}.json"),
-            anns_format="spert",
+            anns_format=data_format,
        )
        return d
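
A hedged sketch of how such fixtures are typically driven via indirect parametrization (the `dataset` fixture name and the fixture file name are assumptions, not part of the commit):

    @pytest.mark.parametrize("dataset", ["quickgraph_split"], indirect=True)
    def test_split_sentences(dataset):
        new_dataset, results = dataset.split_sentences(delimiter=".")
        assert results["relations_removed"] >= 0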
@@ -0,0 +1,65 @@
[
{
"original": "one three two",
"tokens": ["one", "three", "two"],
"document_index": 0,
"entities": [
{
"id": "1",
"start": 0,
"end": 0,
"label": "number"
},
{
"id": "2",
"start": 1,
"end": 1,
"label": "number"
},
{
"id": "3",
"start": 2,
"end": 2,
"label": "number"
}
],
"relations": [
{
"source_id": "2",
"target_id": "1",
"label": "bigger_than"
},
{
"source_id": "2",
"target_id": "3",
"label": "bigger_than"
}
]
},
{
"original": "six two",
"tokens": ["six", "two"],
"document_index": 0,
"entities": [
{
"id": "1",
"start": 0,
"end": 0,
"label": "number"
},
{
"id": "2",
"start": 1,
"end": 1,
"label": "number"
}
],
"relations": [
{
"source_id": "1",
"target_id": "2",
"label": "bigger_than"
}
]
}
]
@@ -0,0 +1,58 @@
[
{
"tokens": ["one", "three", "two"],
"document_index": 0,
"entities": [
{
"start": 0,
"end": 1,
"type": "number"
},
{
"start": 1,
"end": 2,
"type": "number"
},
{
"start": 2,
"end": 3,
"type": "number"
}
],
"relations": [
{
"head": 1,
"tail": 0,
"type": "bigger_than"
},
{
"head": 1,
"tail": 2,
"type": "bigger_than"
}
]
},
{
"tokens": ["six", "two"],
"document_index": 0,
"entities": [
{
"start": 0,
"end": 1,
"type": "number"
},
{
"start": 1,
"end": 2,
"type": "number"
}
],
"relations": [
{
"head": 0,
"tail": 1,
"type": "bigger_than"
}
]
}
]