Skip to content

Commit

Permalink
#26 done
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 1, 2023
1 parent 4695837 commit ab118d9
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 56 deletions.
9 changes: 9 additions & 0 deletions core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from math import ceil


def chunk_into_n(lst, n):
    """ Split `lst` into exactly `n` consecutive chunks of near-equal size.

    The chunks keep the original order and their lengths differ by at most
    one element: the first `len(lst) % n` chunks receive one extra item.
    A plain `ceil(len(lst) / n)` slice size may leave the trailing chunks
    empty (6 items into 4 chunks yields sizes 2, 2, 2, 0), which would
    result in folds without any content.

    :param lst: sequence that supports `len` and slicing.
    :param n: amount of chunks to produce; expected to be a positive integer.
    :return: list of `n` slices of `lst`; slices are empty only
        when `len(lst) < n`.
    :raises ValueError: when `n` is less than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    base_size, remainder = divmod(len(lst), n)

    chunks = []
    start = 0
    for index in range(n):
        # Spread the remainder over the leading chunks, one item each.
        end = start + base_size + (1 if index < remainder else 0)
        chunks.append(lst[start:end])
        start = end

    return chunks
57 changes: 22 additions & 35 deletions my_s3_dataset_folding.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,42 @@
from collections import Counter

from core.utils import chunk_into_n
from utils import cat_files
from utils_my import MyAPI


def write_folded_dataset(k):
def write_folded_dataset(k, dialog_iter_func, speakers):
""" Split the dataset into `k` folds, writing each fold into the file at
MyAPI.dataset_fold_filepath formatted with the fold index.

NOTE(review): this span appears to interleave two versions of the function
(two `def` lines, two `for fold_index` loops) -- confirm which lines are
live before relying on the notes below.

Three-argument form: `speakers` is split into `k` groups by `chunk_into_n`;
a dialog yielded by `dialog_iter_func(fold_index)` is passed to
MyAPI.write_dataset_buffer when `dialog[1][0]` belongs to the group of the
current fold.

One-argument form: lines read from MyAPI.dataset_filepath are buffered, and
a buffer of two lines is written when the running per-speaker count modulo
`k` equals the fold index.
"""
assert(isinstance(k, int))
assert(isinstance(speakers, list))
assert(callable(dialog_iter_func))

buffer = []
for fold_index in range(k):
# One list of speakers per fold index.
speaker_ids_per_fold = chunk_into_n(speakers, n=k)

for fold_index in range(k):
# Set is used for the membership checks below.
fold_speakers_ids = set(speaker_ids_per_fold[fold_index])
with open(MyAPI.dataset_fold_filepath.format(fold_index=str(fold_index)), "w") as file:
for dialog in dialog_iter_func(fold_index):
# presumably the speaker id of the second utterance; verify against
# MyAPI.iter_dataset_as_dialogs.
speaker_id = dialog[1][0]
# Check whether it is a part of the current fold.
if speaker_id in fold_speakers_ids:
MyAPI.write_dataset_buffer(file=file, buffer=dialog)

partners_count = Counter()

lines_it = MyAPI.read_dataset(
dataset_filepath=MyAPI.dataset_filepath,
desc="Prepare for fold {}".format(fold_index))

for line in lines_it:

# A `None` line resets the buffer.
if line is None:
buffer.clear()
continue

s_name = MyAPI._get_meta(line)

buffer.append(line)

# response of the partner.
if len(buffer) == 2:

# Check whether it is a part of the current fold.
if partners_count[s_name] % k == fold_index:
MyAPI.write_dataset_buffer(file=file, buffer=buffer)

# Count the amount of partners.
partners_count[s_name] += 1
def dialog_iter_func(fold_index):
    """ Provide the dialogs of the whole dataset stored at MyAPI.dataset_filepath.

    :param fold_index: used only to compose the `desc` text passed to the reader.
    :return: result of MyAPI.iter_dataset_as_dialogs applied to the dataset lines.
    """
    description = "Prepare fold {}".format(fold_index)
    return MyAPI.iter_dataset_as_dialogs(
        MyAPI.read_dataset(dataset_filepath=MyAPI.dataset_filepath,
                           split_meta=True,
                           desc=description))


# NOTE(review): two variants of the same calls follow one another in this span
# (one-argument / three-argument `write_folded_dataset`, `check_speakers_count` /
# `calc_speakers_count`) -- confirm which ones are live.
# Write MyAPI.dataset_folding_parts folds.
write_folded_dataset(k=MyAPI.dataset_folding_parts)
write_folded_dataset(k=MyAPI.dataset_folding_parts,
dialog_iter_func=dialog_iter_func,
speakers=MyAPI.read_speakers())

# Report the speaker counts for the original dataset file.
print("Original:")
print(MyAPI.check_speakers_count(dataset_filepath=MyAPI.dataset_filepath, pbar=False))
print(MyAPI.calc_speakers_count(dataset_filepath=MyAPI.dataset_filepath, pbar=False))
# Report the counts for every fold file.
print("Folds:")
for i in range(MyAPI.dataset_folding_parts):
c = MyAPI.check_speakers_count(dataset_filepath=MyAPI.dataset_fold_filepath.format(fold_index=i), pbar=False)
print(c)
c = MyAPI.calc_speakers_count(dataset_filepath=MyAPI.dataset_fold_filepath.format(fold_index=i), pbar=False)
# Sum of the per-speaker counts of the fold.
print(sum(c.values()))

# Merge foldings.
Expand All @@ -60,6 +48,5 @@ def write_folded_dataset(k):

print("---")
# Report the counts for the files whose paths are obtained by formatting
# MyAPI.dataset_fold_filepath with "train" and "valid".
# NOTE(review): both the `check_speakers_count` and the `calc_speakers_count`
# variants are present in this span -- confirm which one is live.
for i in ["train", "valid"]:
c = MyAPI.check_speakers_count(dataset_filepath=MyAPI.dataset_fold_filepath.format(fold_index=i), pbar=False)
print(c)
c = MyAPI.calc_speakers_count(dataset_filepath=MyAPI.dataset_fold_filepath.format(fold_index=i), pbar=False)
# Sum of the per-speaker counts of the file.
print(sum(c.values()))
8 changes: 8 additions & 0 deletions test/test_speakers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from core.utils import chunk_into_n
from utils_my import MyAPI

# Split the stored speakers into 5 chunks, print the size of every chunk,
# then a separator followed by the amount of chunks.
speaker_chunks = chunk_into_n(MyAPI.read_speakers(), n=5)
chunk_sizes = [len(chunk) for chunk in speaker_chunks]
for chunk_size in chunk_sizes:
    print(chunk_size)
print("----")
print(len(speaker_chunks))
36 changes: 15 additions & 21 deletions utils_my.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,17 @@ def write_speakers(self, speaker_names_list):
for speaker_name in speaker_names_list:
f.write("{}\n".format(speaker_name))

def read_speakers(self):
@staticmethod
def read_speakers(filepath=None):
""" Read the speakers, one per line, and return them as a list of
whitespace-stripped strings.

NOTE(review): this span appears to interleave two versions of the method
(instance method reading `self.filtered_speakers_filepath` and static method
with the optional `filepath`) -- confirm which lines are live.

:param filepath: path of the file to read; MyAPI.filtered_speakers_filepath
    is used when it is None.
:return: list of the stripped lines of the file.
"""
assert(isinstance(filepath, str) or filepath is None)

# Fall back to the default location of the speakers file.
filepath = MyAPI.filtered_speakers_filepath if filepath is None else filepath

speakers = []
with open(self.filtered_speakers_filepath, "r") as f:
with open(filepath, "r") as f:
for line in f.readlines():
# Drop the trailing newline and the surrounding whitespaces.
speakers.append(line.strip())

return speakers

@staticmethod
Expand Down Expand Up @@ -328,30 +334,18 @@ def iter_dataset_as_dialogs(dataset_lines_iter):
lines.clear()

@staticmethod
def check_speakers_count(dataset_filepath, pbar=True):
def calc_speakers_count(dataset_filepath, pbar=True):
""" Count, per speaker, the amount of the dialogs in which the speaker is
the author of the second utterance (the response).

NOTE(review): this span appears to interleave two versions of the method
(line-based counting with the `utt` buffer and dialog-based counting via
MyAPI.iter_dataset_as_dialogs) -- confirm which lines are live.

:param dataset_filepath: path of the dataset, passed to MyAPI.read_dataset.
:param pbar: forwarded to MyAPI.read_dataset as is.
:return: Counter which maps a speaker onto the amount of the related dialogs.
"""
partners_count = Counter()

utt = []

args_it = MyAPI.read_dataset(
keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath, pbar=pbar)

for args in args_it:

# A `None` entry resets the collected utterances.
if args is None:
utt.clear()
continue

s_name, _ = args

utt.append(s_name)
dialogs_it = MyAPI.iter_dataset_as_dialogs(
MyAPI.read_dataset(keep_usep=False, split_meta=True,
dataset_filepath=dataset_filepath, pbar=pbar))

# response of the partner.
if len(utt) == 2:
# Count the amount of partners.
partners_count[s_name] += 1
for dialog in dialogs_it:
# presumably the speaker name of the second utterance, since the lines are
# read with split_meta=True; verify against MyAPI.iter_dataset_as_dialogs.
partner_id = dialog[1][0]
partners_count[partner_id] += 1

return partners_count

Expand Down

0 comments on commit ab118d9

Please sign in to comment.