
Commit: #15 added

nicolay-r committed Jul 26, 2023
1 parent aae86f7 commit c51d66f
Showing 3 changed files with 13 additions and 9 deletions.

core/candidates/clustering.py (5 additions, 5 deletions)
```diff
@@ -18,13 +18,14 @@ class ALOHANegBasedClusteringProvider(CandidatesProvider):
     """
 
     def __init__(self, dataset_filepath, cluster_filepath, vectorized_utterances_filepath,
-                 embedding_model_name='all-mpnet-base-v2', candidates_limit=20, closest_candidates_limit=100):
+                 embedding_model_name='all-mpnet-base-v2', neg_speakers_limit=20, candidates_limit=20):
         assert(isinstance(dataset_filepath, str))
         assert(isinstance(cluster_filepath, str))
         assert(isinstance(vectorized_utterances_filepath, str))
 
         self.__candidates_limit = candidates_limit
-        self.__closest_candidates_limit = closest_candidates_limit
+        self.__neg_speakers_limit = neg_speakers_limit
 
         self.__neg_clusters_per_speaker = self.__read_cluster(cluster_filepath)
         self.__model = SentenceTransformer(embedding_model_name, cache_folder=CACHE_DIR)
@@ -56,7 +57,7 @@ def cosine_similarity(v1, v2):
     def provide(self, speaker_id, label):
 
         # Compose a SQL-request to obtain vectors and utterances.
-        neg_speakers = self.__neg_clusters_per_speaker[speaker_id]
+        neg_speakers = self.__neg_clusters_per_speaker[speaker_id][self.__neg_speakers_limit]
 
         # Compose WHERE clause that filters the relevant speakers.
         where_clause = 'speakerid in ({})'.format(",".join(['"{}"'.format(s) for s in neg_speakers]))
@@ -72,8 +73,7 @@ def provide(self, speaker_id, label):
         vvv = [(i, self.cosine_similarity(label_vector, v)) for i, v in enumerate(vectors)]
         most_similar_first = sorted(vvv, key=lambda item: item[1], reverse=True)
 
-        ordered_neg_candidates = [neg_candidates[i] for i, _ in most_similar_first]
-        selected = ordered_neg_candidates[:self.__closest_candidates_limit]
+        selected = [neg_candidates[i] for i, _ in most_similar_first]
 
         if label in selected:
             selected.remove(label)
```
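For context, the candidate ranking that `provide` performs after this change can be sketched as a standalone function. This is a hedged illustration, not the repository's code: the cosine formula is the usual one, while the function name, its signature, and the final truncation to `candidates_limit` are assumptions made for the example.

```python
# Hypothetical standalone sketch of the ranking inside provide();
# names mirror the diff, the data handling is simplified.
import numpy as np

def cosine_similarity(v1, v2):
    # Standard cosine similarity between two dense vectors.
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

def rank_negative_candidates(label_vector, vectors, neg_candidates, label, candidates_limit=20):
    # Score every candidate utterance vector against the target (label) vector.
    scored = [(i, cosine_similarity(label_vector, v)) for i, v in enumerate(vectors)]
    # Most similar first; the intermediate closest_candidates_limit cut is gone.
    most_similar_first = sorted(scored, key=lambda item: item[1], reverse=True)
    selected = [neg_candidates[i] for i, _ in most_similar_first]
    # The true response must not leak into the negative set.
    if label in selected:
        selected.remove(label)
    return selected[:candidates_limit]
```

With the `closest_candidates_limit` stage removed, all utterances of the selected negative speakers are ranked in one pass, and trimming presumably happens only once, via the stored `candidates_limit`.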

my_s5_parlai_dataset_convert.py (6 additions, 4 deletions)
```diff
@@ -78,10 +78,12 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_provider, candida
     "": SameBookRandomCandidatesProvider(candidates_per_book=1000,
                                          candidates_limit=MyAPI.dataset_candidates_limit,
                                          dataset_filepath=MyAPI.dataset_filepath),
-    # "clustered": ALOHANegBasedClusteringProvider(limit_per_char=100,
-    #                                              candidates_limit=MyAPI.dataset_candidates_limit,
-    #                                              dataset_filepath=MyAPI.dataset_filepath,
-    #                                              cluster_filepath=MyAPI.hla_cluster_config)
+    "clustered": ALOHANegBasedClusteringProvider(
+        candidates_limit=MyAPI.dataset_candidates_limit,
+        neg_speakers_limit=MyAPI.neg_set_speakers_limit,
+        dataset_filepath=MyAPI.dataset_filepath,
+        cluster_filepath=MyAPI.speaker_clusters_path,
+        vectorized_utterances_filepath=MyAPI.dataset_responses_data_path)
 }
 
 for data_fold_type, data_fold_source in dataset_filepaths.items():
```
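For reference, a usage sketch of the re-enabled provider. Only the constructor arguments and `provide(speaker_id, label)` come from the diff; the concrete speaker id and label text below are made up.

```python
# Hypothetical call site; constructor arguments are taken from the diff above,
# the speaker id and label are illustrative.
provider = ALOHANegBasedClusteringProvider(
    candidates_limit=MyAPI.dataset_candidates_limit,
    neg_speakers_limit=MyAPI.neg_set_speakers_limit,
    dataset_filepath=MyAPI.dataset_filepath,
    cluster_filepath=MyAPI.speaker_clusters_path,
    vectorized_utterances_filepath=MyAPI.dataset_responses_data_path)

# Negatives: utterances of negatively-associated speakers, most similar first.
negatives = provider.provide(speaker_id="speaker_42", label="I would never do that.")
print(len(negatives))
```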

utils_my.py (2 additions, 0 deletions)
```diff
@@ -36,6 +36,8 @@ class MyAPI:
     hla_cluster_config = ClusterConfig(perc_cutoff=10, level2_limit=30, acceptable_overlap=10, weighted=False)
     speaker_clusters_path = join(books_storage, "clusters.jsonl")
     dataset_responses_data_path = join(__current_dir, "./data/ceb_books_annot/dataset_responses_data.sqlite")
+    neg_set_speakers_limit = 20  # The overall process might take too much time,
+                                 # which is the reason for this limit.
 
     prefixes_storage = join(__current_dir, "./data/ceb_books_annot/prefixes")
     # Dialogs with recognized speakers.
```
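A small pre-flight sketch for the new configuration entries; the `MyAPI` attribute names come from the diff, while the import path and the checks themselves are assumptions.

```python
# Hypothetical sanity check before running the conversion script.
import os
from utils_my import MyAPI

assert MyAPI.neg_set_speakers_limit > 0
for path in (MyAPI.speaker_clusters_path, MyAPI.dataset_responses_data_path):
    print(path, "exists:", os.path.exists(path))
```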

0 comments on commit c51d66f
