Skip to content

Commit

Permalink
fix: bug during w2v training with utf8 characters (#76)
Browse files Browse the repository at this point in the history
* Update base.py

* Update stream.py

* change single quote to double quote

* apply iid_max_col to byte length

* add test of stream for unicode case

* apply utf-8 to uid

* Remove unused numpy
  • Loading branch information
pakhy2380 committed Nov 22, 2023
1 parent 75c77f5 commit 3ddd9bf
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
4 changes: 2 additions & 2 deletions buffalo/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def _create_database(self, path, **kwargs):
iid_max_col = kwargs["iid_max_col"]
uid_max_col = kwargs["uid_max_col"]
idmap = f.create_group("idmap")
idmap.create_dataset("rows", (num_users,), dtype="S%s" % uid_max_col,
idmap.create_dataset("rows", (num_users,), dtype=h5py.string_dtype("utf-8", length=uid_max_col),
maxshape=(num_users,))
idmap.create_dataset("cols", (num_items,), dtype="S%s" % iid_max_col,
idmap.create_dataset("cols", (num_items,), dtype=h5py.string_dtype("utf-8", length=iid_max_col),
maxshape=(num_items,))
return f

Expand Down
16 changes: 8 additions & 8 deletions buffalo/data/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections import Counter

import h5py
import numpy as np
import psutil

from buffalo.data.base import Data, DataOption
Expand Down Expand Up @@ -84,7 +83,7 @@ def get_max_column_length(fname):
with open(fname) as fin:
max_col = 0
for l in fin:
max_col = max(max_col, len(l))
max_col = max(max_col, len(l.encode()))
return max_col
uid_path, iid_path, main_path = P["uid_path"], P["iid_path"], P["main_path"]
if uid_path:
Expand Down Expand Up @@ -121,7 +120,7 @@ def get_max_column_length(fname):
itemids = {iid.strip(): idx + 1 for idx, iid in enumerate(fin)}
else: # in case of item information is not given
itemids = {i: idx + 1 for idx, i in enumerate(itemids)}
iid_max_col = max(len(k) + 1 for k in itemids.keys())
iid_max_col = max(len(k.encode()) + 1 for k in itemids.keys())
num_items = len(itemids)

self.logger.info("Found %d unique itemids" % len(itemids))
Expand All @@ -138,17 +137,18 @@ def get_max_column_length(fname):
# if not given, assume id as is
if uid_path:
with open(uid_path) as fin:
idmap["rows"][:] = np.loadtxt(fin, dtype=f"S{uid_max_col}")
rows = [line.strip() for line in fin.readlines()]
idmap["rows"][:] = rows
else:
idmap["rows"][:] = np.array([str(i) for i in range(1, num_users + 1)],
dtype=f"S{uid_max_col}")
idmap["rows"][:] = [str(i) for i in range(1, num_users + 1)]
if iid_path:
with open(iid_path) as fin:
idmap["cols"][:] = np.loadtxt(fin, dtype=f"S{iid_max_col}")
cols = [line.strip() for line in fin.readlines()]
idmap["cols"][:] = cols
else:
cols = sorted(itemids.items(), key=lambda x: x[1])
cols = [k for k, _ in cols]
idmap["cols"][:] = np.array(cols, dtype=f"S{iid_max_col}")
idmap["cols"][:] = cols
except Exception as e:
self.logger.error("Cannot create db: %s" % (str(e)))
self.logger.error(traceback.format_exc())
Expand Down
27 changes: 27 additions & 0 deletions tests/data/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def setUpClass(cls):
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("""kim\nlee\npark""")
cls.uid_path = f.name
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("""사과 망고 망고 사과 파이 주스 콜라\n파이\n주스 콜라 포도""")
cls.unicode_main_path = f.name
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("""김씨\n이씨\n박씨""")
cls.unicode_uid_path = f.name
cls.temp_files = []

@classmethod
Expand Down Expand Up @@ -84,6 +90,27 @@ def test3_to_matrix(self):
data.sort()
self.assertEqual([uu for uu, _, _ in data], ["apple", "coke", "juice", "juice", "mango", "pie", "pie"])

def test4_unicode(self):
opt = StreamOptions().get_default_option()
opt.input.main = self.unicode_main_path
opt.input.uid = self.unicode_uid_path
mm = Stream(opt)
mm.create()
self.assertTrue(True)
db = mm.handle
if opt.data.sppmi:
self.assertEqual(sorted(db.keys()), sorted(["idmap", "rowwise", "colwise", "vali", "sppmi"]))
else:
self.assertEqual(sorted(db.keys()), sorted(["idmap", "rowwise", "colwise", "vali"]))
header = mm.get_header()
self.assertEqual(header["num_nnz"], 9) # due to validation samples
self.assertEqual(header["num_users"], 3)
self.assertEqual(header["num_items"], 6)

data = [(u, kk) for u, kk in mm.iterate(use_repr_name=True)]
self.assertEqual(len(data), 9)
self.assertEqual([kk for _, kk in data], ["사과", "망고", "망고", "사과", "파이", "주스", "파이", "주스", "콜라"])


if __name__ == "__main__":
unittest.main()

0 comments on commit 3ddd9bf

Please sign in to comment.