
Commit

#503 done
nicolay-r committed Sep 14, 2023
1 parent 29de625 commit 219eeca
Showing 14 changed files with 35 additions and 39 deletions.
28 changes: 0 additions & 28 deletions arekit/contrib/utils/data/ext.py

This file was deleted.

3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/readers/base.py
@@ -1,4 +1,7 @@
 class BaseReader(object):
 
+    def extension(self):
+        raise NotImplementedError()
+
     def read(self, target):
         raise NotImplementedError()
3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/readers/csv_pd.py
@@ -19,6 +19,9 @@ def __init__(self, sep='\t', header='infer', compression='infer', encoding='utf-
         if self.__col_types is None:
             self.__col_types = dict()
 
+    def extension(self):
+        return ".tsv.gz"
+
     def __from_csv(self, filepath):
         pd = importlib.import_module("pandas")
         return pd.read_csv(filepath,
3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/readers/jsonl.py
@@ -4,6 +4,9 @@
 
 class JsonlReader(BaseReader):
 
+    def extension(self):
+        return ".jsonl"
+
     def read(self, target):
         rows = []
         with open(target, "r") as f:
5 changes: 5 additions & 0 deletions arekit/contrib/utils/data/writers/base.py
@@ -1,5 +1,10 @@
 class BaseWriter(object):
 
+    def extension(self):
+        """ Expected output extension type.
+        """
+        raise NotImplementedError()
+
     def open_target(self, target):
         pass
 
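Each concrete writer in this commit answers this contract with its own suffix (".csv", ".tsv.gz", ".jsonl", ".sqlite"). As a rough sketch of the contract only — the class below is hypothetical and not part of this commit — a custom writer would override extension() alongside its other methods:

    from arekit.contrib.utils.data.writers.base import BaseWriter

    class PlainTextWriter(BaseWriter):
        """ Hypothetical writer, shown only to illustrate the new extension() contract;
            the remaining writer methods are omitted.
        """

        def extension(self):
            # Suffix that IO helpers such as SamplesIO append to the target file name.
            return ".txt"

        def open_target(self, target):
            # 'target' arrives with the suffix above already attached.
            self.__f = open(target, "w", encoding="utf-8")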
3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/writers/csv_native.py
@@ -17,6 +17,9 @@ def __init__(self, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL, hea
         self.__header = header
         self.__header_written = None
 
+    def extension(self):
+        return ".csv"
+
     @staticmethod
     def __iter_storage_column_names(storage):
         """ Iter only those columns that existed in storage.
3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/writers/csv_pd.py
@@ -15,6 +15,9 @@ def __init__(self, write_header):
         super(PandasCsvWriter, self).__init__()
         self.__write_header = write_header
 
+    def extension(self):
+        return ".tsv.gz"
+
     def write_all(self, storage, target):
         assert(isinstance(storage, PandasBasedRowsStorage))
         assert(isinstance(target, str))
3 changes: 3 additions & 0 deletions arekit/contrib/utils/data/writers/json_opennre.py
@@ -39,6 +39,9 @@ def __init__(self, text_columns, encoding="utf-8"):
         self.__encoding = encoding
         self.__target_f = None
 
+    def extension(self):
+        return ".jsonl"
+
     @staticmethod
     def __format_row(row, text_columns):
         """ Formatting that is compatible with the OpenNRE.
5 changes: 4 additions & 1 deletion arekit/contrib/utils/data/writers/sqlite_native.py
@@ -13,6 +13,9 @@ def __init__(self, table_name="contents"):
         self.__need_init_table = True
         self.__column_names = None
 
+    def extension(self):
+        return ".sqlite"
+
     @staticmethod
     def __iter_storage_column_names(storage):
         """ Iter only those columns that existed in storage.
@@ -22,7 +25,7 @@ def __iter_storage_column_names(storage):
             yield col_name
 
     def open_target(self, target):
-        self.__conn = sqlite3.connect(target + ".sqlite")
+        self.__conn = sqlite3.connect(target)
         self.__cur = self.__conn.cursor()
 
     def commit_line(self, storage):
3 changes: 1 addition & 2 deletions arekit/contrib/utils/io_utils/opinions.py
@@ -1,6 +1,5 @@
 from os.path import join
 
-from arekit.contrib.utils.data.ext import create_reader_extension
 from arekit.contrib.utils.data.readers.base import BaseReader
 from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
 from arekit.contrib.utils.io_utils.utils import filename_template
@@ -13,7 +12,7 @@ def __init__(self, target_dir, reader=None, prefix="opinion", target_extension="
         self.__target_dir = target_dir
         self.__prefix = prefix
         self.__reader = reader
-        self.__target_extension = create_reader_extension(reader) \
+        self.__target_extension = reader.extension() \
             if target_extension is None else target_extension
 
     @property
5 changes: 2 additions & 3 deletions arekit/contrib/utils/io_utils/samples.py
@@ -1,7 +1,6 @@
 import logging
 from os.path import join
 
-from arekit.contrib.utils.data.ext import create_writer_extension, create_reader_extension
 from arekit.contrib.utils.data.readers.base import BaseReader
 from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
 from arekit.contrib.utils.data.writers.base import BaseWriter
@@ -32,9 +31,9 @@ def __init__(self, target_dir, writer=None, reader=None, prefix="sample", target
 
         if target_extension is None:
             if writer is not None:
-                self.__target_extension = create_writer_extension(writer)
+                self.__target_extension = writer.extension()
             elif reader is not None:
-                self.__target_extension = create_reader_extension(reader)
+                self.__target_extension = reader.extension()
 
     # region public methods
 
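Because SamplesIO now falls back to writer.extension() (or reader.extension()), callers can drop the explicit target_extension argument, as the updated tutorial tests below show. A minimal usage sketch, assuming the module paths match the file layout in this commit:

    from os.path import join

    # Import paths assumed from the directory structure shown above.
    from arekit.contrib.utils.data.writers.csv_pd import PandasCsvWriter
    from arekit.contrib.utils.io_utils.samples import SamplesIO

    writer = PandasCsvWriter(write_header=True)

    # No target_extension needed: SamplesIO derives ".tsv.gz" from the writer.
    samples_io = SamplesIO("out", writer)

    # The same suffix is reused when locating the serialized samples afterwards.
    source = join("out", "sample-train-0" + writer.extension())  # -> "out/sample-train-0.tsv.gz"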
4 changes: 2 additions & 2 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
@@ -88,7 +88,7 @@ def test(self):
             text_provider=text_provider)
 
         writer = PandasCsvWriter(write_header=True)
-        samples_io = SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz")
+        samples_io = SamplesIO(self.__output_dir, writer)
 
         pipeline_item = BertExperimentInputSerializerPipelineItem(
             rows_provider=rows_provider,
@@ -127,6 +127,6 @@ def test(self):
         })
 
         reader = PandasCsvReader()
-        source = join(self.__output_dir, "sample-train-0.tsv.gz")
+        source = join(self.__output_dir, "sample-train-0" + writer.extension())
         storage = reader.read(source)
         self.assertEqual(20, len(storage), "Amount of rows is non equal!")
4 changes: 2 additions & 2 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
@@ -91,7 +91,7 @@ def test(self):
             ctx=ctx)
 
         pipeline_item = NetworksInputSerializerPipelineItem(
-            samples_io=SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz"),
+            samples_io=SamplesIO(self.__output_dir, writer),
             emb_io=NpEmbeddingIO(target_dir=self.__output_dir),
             rows_provider=rows_provider,
             save_labels_func=lambda data_type: data_type != DataType.Test,
@@ -131,6 +131,6 @@ def test(self):
         })
 
         reader = PandasCsvReader()
-        source = join(self.__output_dir, "sample-train-0.tsv.gz")
+        source = join(self.__output_dir, "sample-train-0" + writer.extension())
         storage = reader.read(source)
         self.assertEqual(20, len(storage), "Amount of rows is non equal!")
2 changes: 1 addition & 1 deletion tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
@@ -128,6 +128,6 @@ def test(self):
         })
 
         reader = PandasCsvReader()
-        source = join(self.__output_dir, "prompt-sample-train-0.tsv.gz")
+        source = join(self.__output_dir, "prompt-sample-train-0" + writer.extension())
         storage = reader.read(source)
         self.assertEqual(20, len(storage), "Amount of rows is non equal!")
