Commit
#19 related. Refactoring dataset generation. Provide filtering rules for some combination modes of candidates and traits.
nicolay-r committed Jul 27, 2023
1 parent fa604e4 commit 3109ceb
Showing 1 changed file with 29 additions and 38 deletions.
67 changes: 29 additions & 38 deletions my_s5_parlai_dataset_convert.py
@@ -8,47 +8,29 @@
 from utils_my import MyAPI
 
 
-def iter_dataset_lines(dataset_source, traits_func, candidates_provider, candidates_limit, desc=None):
-    assert(isinstance(dataset_source, str))
+def iter_formatted_dialog(dialogs_iter, traits_func, candidates_provider, candidates_limit):
     assert(callable(traits_func))
     assert(isinstance(candidates_provider, CandidatesProvider) or candidates_provider is None)
     assert(isinstance(candidates_limit, int))
 
-    dialog = []
-    speaker_ids = []
+    for dialog in dialogs_iter:
+        assert(len(dialog) == 2)
 
-    read_dataset = MyAPI.read_dataset(
-        keep_usep=False, split_meta=True, dataset_filepath=dataset_source, desc=desc)
+        q_speaker_id, query = dialog[0]
+        r_speaker_id, label = dialog[1]
 
-    for args in read_dataset:
-
-        if args is None:
-            dialog.clear()
-            speaker_ids.clear()
-            continue
-
-        speaker_id = args[0]
-        speaker_ids.append(speaker_id)
-        dialog.append(args[1])
-
-        if len(dialog) < 2:
-            continue
-
-        assert(len(dialog) == len(speaker_ids) == 2)
-
-        label = dialog[1]
         if candidates_provider is not None:
-            candidates = candidates_provider.provide_or_none(speaker_id=speaker_id, label=label)
+            candidates = candidates_provider.provide_or_none(speaker_id=r_speaker_id, label=label)
         else:
            candidates = [label]
 
        if candidates is None:
            continue
 
-        yield format_episode(request=dialog[0],
-                             response=dialog[1],
+        yield format_episode(request=query,
+                             response=label,
                              candidates=candidates,
-                             resp_persona_traits=traits_func(speaker_ids[0], speaker_ids[1]),
+                             resp_persona_traits=traits_func(q_speaker_id, r_speaker_id),
                              resp_persona_prefix=MyAPI.response_persona_prefix,
                              seed=MyAPI.candidates_and_traits_shuffle_seed).encode()
         yield b"\n"
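
Note on the new contract: the refactored generator no longer reads the raw dataset itself; it consumes an iterator of ready-made two-turn dialogs, where each turn is a (speaker_id, utterance) pair. A minimal runnable sketch of that contract (stub names, not repository code; format_episode and the candidates provider are omitted):

def iter_formatted_dialog_sketch(dialogs_iter, traits_func):
    # Mirrors the unpacking above: turn 0 is the query, turn 1 the gold response.
    for dialog in dialogs_iter:
        assert len(dialog) == 2
        q_speaker_id, query = dialog[0]
        r_speaker_id, label = dialog[1]
        # With no candidates provider, the gold response is the only candidate.
        yield {"request": query,
               "response": label,
               "candidates": [label],
               "traits": traits_func(q_speaker_id, r_speaker_id)}

dialogs = [((101, "Where have you been?"), (202, "At the library."))]
for episode in iter_formatted_dialog_sketch(dialogs, lambda q, r: ["none"] * 5):
    print(episode)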
@@ -69,24 +51,29 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_provider, candida
 # genders = ceb_api.get_meta_gender()
 speaker_spectrums = MyAPI.read_speaker_spectrums(MyAPI.spectrum_prompts_filepath)
 
+TRAITS_NO = "original"
+TRAITS_SPECTRUM = "spectrum"
 traits_provider = {
-    "original": lambda your_id, partner_id: ["none"] * MyAPI.traits_per_character,
+    TRAITS_NO: lambda your_id, partner_id: ["none"] * MyAPI.traits_per_character,
     # NOTE: In some cases (less than ~0.07%) speakers might be missed so we need to perform check.
-    "spectrums": lambda your_id, partner_id: speaker_spectrums[partner_id] if partner_id in speaker_spectrums
-    else traits_provider["original"](your_id, partner_id)
+    TRAITS_SPECTRUM: lambda your_id, partner_id: speaker_spectrums[partner_id] if partner_id in speaker_spectrums
+    else traits_provider[TRAITS_NO](your_id, partner_id)
 }
 
+CANDIDATES_UTTERANCE_ONLY = ""
+CANDIDATES_HLA_CLUSTER = "clustered"
 candidates_provider = {
 
     #"_no-cands": None,
 
-    "": SameBookRandomCandidatesProvider(
+    CANDIDATES_UTTERANCE_ONLY: SameBookRandomCandidatesProvider(
         iter_dialogs=MyAPI.iter_dataset_as_dialogs(
             MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=MyAPI.dataset_filepath)
         ),
         candidates_per_book=1000,
         candidates_limit=MyAPI.dataset_candidates_limit),
 
-    "_clustered": ALOHANegBasedClusteringProvider(
+    "clustered": ALOHANegBasedClusteringProvider(
         candidates_limit=MyAPI.dataset_candidates_limit,
         neg_speakers_limit=MyAPI.neg_set_speakers_limit,
         embedding_model_name=MyAPI.utterance_embedding_model_name,
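
One subtlety in the traits_provider dict above is the fallback rule: a partner missing from speaker_spectrums silently degrades to the trait-free "original" behaviour instead of raising a KeyError. A hypothetical standalone rendition (invented data; TRAITS_PER_CHARACTER stands in for MyAPI.traits_per_character):

TRAITS_PER_CHARACTER = 5  # stand-in for MyAPI.traits_per_character
speaker_spectrums = {202: ["curious", "polite", "witty", "calm", "direct"]}

traits_provider = {
    "original": lambda your_id, partner_id: ["none"] * TRAITS_PER_CHARACTER,
    # The lambda looks up traits_provider at call time, after the dict is
    # fully built, so the self-reference below is safe.
    "spectrum": lambda your_id, partner_id:
        speaker_spectrums[partner_id] if partner_id in speaker_spectrums
        else traits_provider["original"](your_id, partner_id),
}

print(traits_provider["spectrum"](101, 202))  # known partner -> spectrum traits
print(traits_provider["spectrum"](101, 999))  # missing partner -> ['none', ...]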
@@ -100,17 +87,21 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_provider, candida
     for candidates_type, candidates_dict in candidates_provider.items():
 
         # This type does not make sense, so we skip such formatting.
-        if trait_type == "spectrum" and candidates_type == "":
+        if trait_type == TRAITS_NO and candidates_type == CANDIDATES_HLA_CLUSTER:
             continue
+        if trait_type == TRAITS_SPECTRUM and candidates_type == CANDIDATES_UTTERANCE_ONLY:
+            continue
 
-        filename = '{}_{}{}.txt'.format(data_fold_type, trait_type, candidates_type)
+        filename = '{}.txt'.format("_".join([data_fold_type, trait_type, candidates_type]))
 
-        data_it = iter_dataset_lines(
-            dataset_source=data_fold_source,
+        data_it = iter_formatted_dialog(
+            dialogs_iter=MyAPI.iter_dataset_as_dialogs(
+                MyAPI.read_dataset(
+                    keep_usep=False, split_meta=True, dataset_filepath=data_fold_source,
+                    desc=filename)),
             traits_func=traits_func,
             candidates_provider=candidates_dict,
-            candidates_limit=MyAPI.dataset_candidates_limit,
-            desc=filename)
+            candidates_limit=MyAPI.dataset_candidates_limit)
 
         z = zipstream.ZipFile()
         z.write_iter(filename, data_it)
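
These two filters are the "filtering rules for some combination modes" from the commit message: of the four possible (trait, candidates) pairs, only (original traits, utterance-only candidates) and (spectrum traits, clustered candidates) are written out. A tiny sketch (constants copied from the diff) that enumerates the surviving grid:

TRAITS_NO, TRAITS_SPECTRUM = "original", "spectrum"
CANDIDATES_UTTERANCE_ONLY, CANDIDATES_HLA_CLUSTER = "", "clustered"

for trait_type in (TRAITS_NO, TRAITS_SPECTRUM):
    for candidates_type in (CANDIDATES_UTTERANCE_ONLY, CANDIDATES_HLA_CLUSTER):
        if trait_type == TRAITS_NO and candidates_type == CANDIDATES_HLA_CLUSTER:
            continue
        if trait_type == TRAITS_SPECTRUM and candidates_type == CANDIDATES_UTTERANCE_ONLY:
            continue
        print("kept:", trait_type, candidates_type or "<utterance-only>")

# Prints:
#   kept: original <utterance-only>
#   kept: spectrum clustered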