Skip to content

Commit

Permalink
#388 related
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 22, 2022
1 parent 8890b15 commit 2c6d55f
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 24 deletions.
6 changes: 0 additions & 6 deletions arekit/contrib/networks/core/embedding_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@ class BaseEmbeddingIO(object):
""" API for loading and saving embedding and vocabulary related data.
"""

def save_vocab(self, data, data_folding):
raise NotImplementedError()

def load_vocab(self, data_folding):
raise NotImplementedError()

def save_embedding(self, data, data_folding):
raise NotImplementedError()

Expand Down
5 changes: 3 additions & 2 deletions arekit/contrib/utils/io_utils/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from arekit.contrib.networks.core.embedding_io import BaseEmbeddingIO
from arekit.contrib.utils.io_utils.utils import check_targets_existence
from arekit.contrib.utils.np_utils.embedding import NpzEmbeddingHelper
from arekit.contrib.utils.np_utils.vocab import VocabRepositoryUtils
from arekit.contrib.utils.utils_folding import experiment_iter_index


Expand All @@ -28,11 +29,11 @@ def __init__(self, target_dir):
def save_vocab(self, data, data_folding):
assert(isinstance(data_folding, BaseDataFolding))
target = self.__get_default_vocab_filepath(data_folding)
return NpzEmbeddingHelper.save_vocab(data=data, target=target)
return VocabRepositoryUtils.save(data=data, target=target)

def load_vocab(self, data_folding):
source = self.___get_vocab_source(data_folding)
return NpzEmbeddingHelper.load_vocab(source)
return dict(VocabRepositoryUtils.load(source))

def save_embedding(self, data, data_folding):
assert(isinstance(data_folding, BaseDataFolding))
Expand Down
15 changes: 0 additions & 15 deletions arekit/contrib/utils/np_utils/embedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging

import numpy as np

from arekit.contrib.utils.np_utils.npz_utils import NpzRepositoryUtils

logger = logging.getLogger(__name__)
Expand All @@ -16,22 +14,9 @@ def save_embedding(data, target):
logger.info("Saving embedding [size={shape}]: {filepath}".format(shape=data.shape,
filepath=target))

@staticmethod
def save_vocab(data, target):
logger.info("Saving vocabulary [size={size}]: {filepath}".format(size=len(data),
filepath=target))
np.savez(target, data)

@staticmethod
def load_embedding(source):
embedding = NpzRepositoryUtils.load(source)
logger.info("Embedding read [size={size}]: {filepath}".format(size=embedding.shape,
filepath=source))
return embedding

@staticmethod
def load_vocab(source):
vocab = dict(NpzRepositoryUtils.load(source))
logger.info("Vocabulary read [size={size}]: {filepath}".format(size=len(vocab),
filepath=source))
return vocab
10 changes: 9 additions & 1 deletion arekit/contrib/utils/np_utils/vocab.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import logging

import numpy as np

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class VocabRepositoryUtils(object):

@staticmethod
def save(data, target):
logger.info("Saving vocabulary [size={size}]: {filepath}".format(size=len(data), filepath=target))
np.savetxt(target, data, fmt='%s')

@staticmethod
def load(source):
return np.loadtxt(source, dtype=str)
vocab = np.loadtxt(source, dtype=str)
logger.info("Loading vocabulary [size={size}]: {filepath}".format(size=len(vocab), filepath=source))
return vocab

0 comments on commit 2c6d55f

Please sign in to comment.