Skip to content

Commit

Permalink
#8 related refactoring [using dict for rows representation]. Fixed logging bug in unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 22, 2022
1 parent 645c070 commit d5e438c
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 48 deletions.
6 changes: 0 additions & 6 deletions arenets/arekit/common/data/storages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ def _iter_rows(self):
def _get_rows_count(self):
raise NotImplemented()

def iter_column_values(self, column_name, dtype=None):
raise NotImplemented()

def get_row(self, row_index):
raise NotImplemented()

# endregion

# endregion
Expand Down
7 changes: 4 additions & 3 deletions arenets/arekit/common/data/views/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def iter_from_storage(self, storage):

linked = []
current_opinion_id = undefined
for row_index, sample_id in enumerate(storage.iter_column_values(const.ID)):
sample_id = str(sample_id)

for _, row_dict in storage:
sample_id = str(row_dict[const.ID])
opinion_id = self.__row_ids_provider.parse_opinion_in_sample_id(sample_id)
if current_opinion_id != undefined:
if opinion_id != current_opinion_id:
Expand All @@ -26,7 +27,7 @@ def iter_from_storage(self, storage):
else:
current_opinion_id = opinion_id

linked.append(storage.get_row(row_index))
linked.append(row_dict)

if len(linked) > 0:
yield linked
24 changes: 3 additions & 21 deletions arenets/arekit/contrib/utils/data/storages/pandas_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,14 @@ def __init__(self, df=None):
assert(isinstance(df, pd.DataFrame) or df is None)
self._df = df

@staticmethod
def __iter_rows_core(df):
assert(isinstance(df, pd.DataFrame))
for row_index, row in df.iterrows():
yield row_index, row

# region protected methods

def _iter_rows(self):
for row_index, row in self.__iter_rows_core(self._df):
yield row_index, row
assert(isinstance(self._df, pd.DataFrame))
for row_index, row in self._df.iterrows():
yield row_index, row.to_dict()

def _get_rows_count(self):
return len(self._df)

# endregion

# region public methods

def get_row(self, row_index):
return self._df.iloc[row_index]

def iter_column_values(self, column_name, dtype=None):
values = self._df[column_name]
if dtype is None:
return values
return values.astype(dtype)

# endregion
11 changes: 5 additions & 6 deletions arenets/core/input/rows_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pandas as pd

from arenets.arekit.common.data import const
from . import const as network_input_const
from ...arekit.common.utils import filter_whitespaces, split_by_whitespaces
Expand Down Expand Up @@ -40,13 +38,14 @@ def __process_int_values_list(value):


class ParsedSampleRow(object):
"""
Provides a parsed information for a sample row.
TODO. Use this class as API
""" Provides a parsed information for a sample row.
"""

def __init__(self, row):
assert(isinstance(row, pd.Series))
""" row: dict
dict of the pairs ("field_name", value)
"""
assert(isinstance(row, dict))

self.__uint_label = None
self.__params = {}
Expand Down
18 changes: 6 additions & 12 deletions tests/test_samples_iter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from os.path import dirname, join

import pandas as pd
import gzip
import sys
import unittest


sys.path.append('../')

from arenets.arekit.common.data import const
from arenets.arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
from arenets.core.input.rows_parser import ParsedSampleRow
from arenets.context.configurations.base.base import DefaultNetworkConfig
from arenets.sample import InputSample
Expand Down Expand Up @@ -48,13 +46,9 @@ def test_show_shifted_examples_only(self):
@staticmethod
def __iter_tsv_gzip(input_file):
"""Reads a tab separated value file."""
df = pd.read_csv(input_file,
compression='gzip',
sep='\t',
encoding='utf-8')

for row_index, _ in enumerate(df[const.ID]):
yield df.iloc[row_index]
reader = PandasCsvReader(compression='gzip', sep='\t', encoding='utf-8')
for _, row in reader.read(input_file):
yield row

@staticmethod
def __read_vocab(input_file):
Expand Down Expand Up @@ -126,8 +120,8 @@ def __test_core(self, terms_vocab, config, samples_filepath):
print("frame_connots_uint: {}".format(row.TextFrameConnotations))
print("syn_obj: {}".format(row.SynonymObjectInds))
print("syn_subj: {}".format(row.SynonymSubjectInds))
print("terms:".format(row.Terms))
print("pos_tags:".format(row.PartOfSpeechTags))
print("terms: {}".format(row.Terms))
print("pos_tags: {}".format(row.PartOfSpeechTags))

print(self.__terms_to_text_line(terms=row.Terms, frame_inds_set=set(row.TextFrameVariantIndices)))

Expand Down
21 changes: 21 additions & 0 deletions tests/test_storage_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
from os.path import join, dirname

from arenets.arekit.common.data.row_ids.base import BaseIDProvider
from arenets.arekit.common.data.views.samples import LinkedSamplesStorageView
from arenets.arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader


# Smoke test: reads a samples file into a storage and iterates the row groups
# produced by LinkedSamplesStorageView. It only prints; it makes no assertions.
class TestSamplesStorageView(unittest.TestCase):

# Builds a path to `local_filepath` relative to the directory containing this test file.
def __get_local_dir(self, local_filepath):
return join(dirname(__file__), local_filepath)

# Loads the bundled samples file and walks every group yielded by the view.
def test(self):
samples_filepath = self.__get_local_dir("test_data/sample-train.tsv.gz")
# The reader is constructed with its default arguments.
reader = PandasCsvReader()
storage = reader.read(samples_filepath)
samples_view = LinkedSamplesStorageView(row_ids_provider=BaseIDProvider())
# Only the type of each yielded item is printed, not its contents.
for data in samples_view.iter_from_storage(storage):
print(type(data))
# NOTE(review): indentation is lost in this view, so it is unclear whether this
# runs once after the loop or once per yielded item — confirm against the repository.
print(len(storage))

0 comments on commit d5e438c

Please sign in to comment.