Skip to content

Commit

Permalink
#429 sync. Fix #427
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 12, 2023
1 parent 35fd8ae commit ea26d16
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
1 change: 1 addition & 0 deletions arekit/contrib/networks/input/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
SynonymObject = "syn_objs"
SynonymSubject = "syn_subjs"
PosTags = "pos_tags"
Text = "text"

ArgsSep = ','
51 changes: 33 additions & 18 deletions arekit/contrib/networks/input/rows_parser.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
import pandas as pd

from arekit.common.data import const
from arekit.common.utils import filter_whitespaces, split_by_whitespaces
from . import const as network_input_const

import arekit.contrib.networks.input.const as network_input_const

empty_list = []


def no_value():
return None


def __process_values_list(value):
return value.split(network_input_const.ArgsSep)


def __process_indices_list(value):
return [int(v) for v in str(value).split(network_input_const.ArgsSep)]
return no_value() if not value else [int(v) for v in str(value).split(network_input_const.ArgsSep)]


def __process_int_values_list(value):
return __process_indices_list(value)


def __handle_text(value):
""" The core method of the input text processing.
"""
assert(isinstance(value, str) or isinstance(value, list))
return filter_whitespaces([term for term in split_by_whitespaces(value)]
if isinstance(value, str) else value)


parse_value = {
const.ID: lambda value: value,
const.DOC_ID: lambda value: int(value),
Expand All @@ -35,18 +46,19 @@ def __process_int_values_list(value):
network_input_const.SynonymObject: lambda value: __process_indices_list(value),
network_input_const.SynonymSubject: lambda value: __process_indices_list(value),
network_input_const.PosTags: lambda value: __process_int_values_list(value),
"text_a": lambda value: filter_whitespaces([term for term in split_by_whitespaces(value)])
network_input_const.Text: lambda value: __handle_text(value)
}


class ParsedSampleRow(object):
"""
Provides a parsed information for a sample row.
TODO. Use this class as API
""" Provides a parsed information for a sample row.
"""

def __init__(self, row):
assert(isinstance(row, pd.Series))
""" row: dict
dict of the pairs ("field_name", value)
"""
assert(isinstance(row, dict))

self.__uint_label = None
self.__params = {}
Expand All @@ -64,13 +76,16 @@ def __init__(self, row):

self.__params[key] = parse_value[key](value)

def __value_or_none(self, key):
return self.__params[key] if key in self.__params else no_value()

@property
def SampleID(self):
return self.__params[const.ID]

@property
def Terms(self):
return self.__params["text_a"]
return self.__params[network_input_const.Text]

@property
def SubjectIndex(self):
Expand All @@ -86,33 +101,33 @@ def UintLabel(self):

@property
def PartOfSpeechTags(self):
return self.__params[network_input_const.PosTags]
return self.__value_or_none(network_input_const.PosTags)

@property
def TextFrameVariantIndices(self):
return self.__params[network_input_const.FrameVariantIndices]
return self.__value_or_none(network_input_const.FrameVariantIndices)

@property
def TextFrameConnotations(self):
return self.__params[network_input_const.FrameConnotations]
return self.__value_or_none(network_input_const.FrameConnotations)

@property
def EntityInds(self):
return self.__params[const.ENTITIES]
return self.__value_or_none(const.ENTITIES)

@property
def SynonymObjectInds(self):
return self.__params[network_input_const.SynonymObject]
return self.__value_or_none(network_input_const.SynonymObject)

@property
def SynonymSubjectInds(self):
return self.__params[network_input_const.SynonymSubject]
return self.__value_or_none(network_input_const.SynonymSubject)

def __getitem__(self, item):
assert (isinstance(item, str) or item is None)
if item not in self.__params:
return None
return self.__params[item] if item is not None else None
return no_value()
return self.__params[item] if item is not None else no_value()

@classmethod
def parse(cls, row):
Expand Down

0 comments on commit ea26d16

Please sign in to comment.