Skip to content

Commit

Permalink
#503 Remove target extension parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Sep 15, 2023
1 parent 75631b2 commit d33aa51
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 17 deletions.
5 changes: 2 additions & 3 deletions arekit/contrib/utils/io_utils/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

class OpinionsIO(BaseSamplesIO):

def __init__(self, target_dir, reader=None, prefix="opinion", target_extension=".tsv.gz"):
def __init__(self, target_dir, reader=None, prefix="opinion"):
assert(isinstance(reader, BaseReader))
self.__target_dir = target_dir
self.__prefix = prefix
self.__reader = reader
self.__target_extension = reader.extension() \
if target_extension is None else target_extension
self.__target_extension = reader.extension()

@property
def Reader(self):
Expand Down
14 changes: 6 additions & 8 deletions arekit/contrib/utils/io_utils/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,21 @@ class SamplesIO(BaseSamplesIO):
Samples required for machine learning training/inferring.
"""

def __init__(self, target_dir, writer=None, reader=None, prefix="sample", target_extension=None):
def __init__(self, target_dir, writer=None, reader=None, prefix="sample"):
assert(isinstance(target_dir, str))
assert(isinstance(prefix, str))
assert(isinstance(writer, BaseWriter) or writer is None)
assert(isinstance(reader, BaseReader) or reader is None)
assert(isinstance(target_extension, str) or target_extension is None)
self.__target_dir = target_dir
self.__prefix = prefix
self.__writer = writer
self.__reader = reader
self.__target_extension = target_extension

if target_extension is None:
if writer is not None:
self.__target_extension = writer.extension()
elif reader is not None:
self.__target_extension = reader.extension()
self.__target_extension = None
if writer is not None:
self.__target_extension = writer.extension()
elif reader is not None:
self.__target_extension = reader.extension()

# region public methods

Expand Down
10 changes: 4 additions & 6 deletions tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class TestStreamWriters(unittest.TestCase):

__output_dir = join(dirname(__file__), "out")

def __launch(self, writer, target_extention):
assert(isinstance(target_extention, str))
def __launch(self, writer):

text_b_template = '{subject} к {object} в контексте : << {context} >>'

Expand All @@ -49,7 +48,7 @@ def __launch(self, writer, target_extention):
label_provider=MultipleLabelProvider(SentimentLabelScaler()),
text_provider=text_provider)

samples_io = SamplesIO(self.__output_dir, writer, target_extension=target_extention)
samples_io = SamplesIO(self.__output_dir, writer)

pipeline_item = BertExperimentInputSerializerPipelineItem(
rows_provider=sample_rows_provider,
Expand Down Expand Up @@ -90,10 +89,9 @@ def __launch(self, writer, target_extention):
def test_csv_native(self):
""" Testing writing into CSV format
"""
return self.__launch(writer=NativeCsvWriter(), target_extention=".csv")
return self.__launch(writer=NativeCsvWriter())

def test_json_native(self):
""" Testing writing into CSV format
"""
return self.__launch(writer=OpenNREJsonWriter(text_columns=[BaseSingleTextProvider.TEXT_A]),
target_extention=".jsonl")
return self.__launch(writer=OpenNREJsonWriter(text_columns=[BaseSingleTextProvider.TEXT_A]))

0 comments on commit d33aa51

Please sign in to comment.