Commit: #19 related refactoring: using providers for candidate selection.
Showing 8 changed files with 212 additions and 68 deletions.
New empty file.
New file core/candidates/base.py (path inferred from the imports below):
@@ -0,0 +1,4 @@
class CandidatesProvider(object):
    """ Base class for candidate selection strategies. """

    def provide(self, speaker_id, label):
        pass
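A provider implements `provide(speaker_id, label)` and returns a list of candidate utterances for the given speaker and known (labeled) response. A minimal sketch of a custom implementation; the `FixedCandidatesProvider` name and its behavior are hypothetical, for illustration only:

class FixedCandidatesProvider(CandidatesProvider):
    """ Hypothetical provider that always draws from the same candidate pool. """

    def __init__(self, candidates):
        self.__candidates = candidates

    def provide(self, speaker_id, label):
        # Return every stored candidate except the known label.
        return [c for c in self.__candidates if c != label]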
New file core/candidates/clustering.py (path inferred from the import in the script below):
@@ -0,0 +1,111 @@
import json
import math

from tqdm import tqdm

from utils_my import MyAPI
from utils import CACHE_DIR

from sentence_transformers import SentenceTransformer
from core.candidates.base import CandidatesProvider


class ALOHANegBasedClusteringProvider(CandidatesProvider):
    """ Cluster-based approach.
        Every cluster provides a list of positive and negative characters.
        For candidates we consider utterances from the "negative" characters,
        and we pick the most relevant ones, hence our interest in embedding
        vectors for the related utterances.
        We follow the negative-based selection approach from the ALOHA paper:
        https://arxiv.org/pdf/1910.08293.pdf
    """

    def __init__(self, dataset_filepath, cluster_filepath, limit_per_char,
                 candidates_limit, closest_candidates_limit=100):
        assert(isinstance(dataset_filepath, str))
        assert(isinstance(cluster_filepath, str))
        self.__candidates_limit = candidates_limit
        self.__closest_candidates_limit = closest_candidates_limit
        self.__neg_clusters_per_speaker = self.__read_cluster(cluster_filepath)
        self.__candidates_per_speaker = self.__create_dict(
            dataset_filepath=dataset_filepath, limit_per_char=limit_per_char)

        self.__model = SentenceTransformer('all-mpnet-base-v2', cache_folder=CACHE_DIR)

    @staticmethod
    def __read_cluster(cluster_filepath):
        """ Reads the JSON-lines clusters file into a speaker_id -> negative ids dict. """
        neg_speakers = {}
        with open(cluster_filepath, "r") as f:
            for line in f.readlines():
                data = json.loads(line)
                speaker_id = data["speaker_id"]
                ids = data["neg"]
                neg_speakers[speaker_id] = ids
        return neg_speakers

    @staticmethod
    def __create_dict(dataset_filepath=None, limit_per_char=100):
        assert(isinstance(limit_per_char, int) and limit_per_char > 0)

        lines = []

        candidates = {}
        for args in MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath):
            if args is None:
                lines.clear()
                continue

            lines.append(args)

            if len(lines) < 2:
                continue

            # This is the type of data we are interested in:
            # the utterance has at least one preceding line.
            speaker = args[0]

            if speaker not in candidates:
                candidates[speaker] = []
            target = candidates[speaker]

            if len(target) == limit_per_char:
                # Do not register the candidate.
                continue

            # Consider the potential candidate.
            target.append(args[1])

        return candidates

    @staticmethod
    def cosine_similarity(v1, v2):
        sumxx, sumxy, sumyy = 0, 0, 0
        for i in range(len(v1)):
            x = v1[i]
            y = v2[i]
            sumxx += x * x
            sumyy += y * y
            sumxy += x * y
        return sumxy / math.sqrt(sumxx * sumyy)

    def provide(self, speaker_id, label):
        # Compose a list of the NON-relevant candidates,
        # taken from the first 10 negative speakers of the given one.
        neg_candidates = []
        for s_id in self.__neg_clusters_per_speaker[speaker_id][:10]:
            neg_candidates.extend(self.__candidates_per_speaker[s_id])

        # Calculate embedding vectors.
        vectors = []
        for c in neg_candidates:
            vectors.append(self.__model.encode(c))

        # Order the candidates by similarity to the label, most similar first.
        label_vector = self.__model.encode(label)
        similarities = [(i, self.cosine_similarity(label_vector, v)) for i, v in enumerate(vectors)]
        most_similar_first = sorted(similarities, key=lambda item: item[1], reverse=True)

        ordered_neg_candidates = [neg_candidates[i] for i, _ in most_similar_first]
        selected = ordered_neg_candidates[:self.__closest_candidates_limit]

        if label in selected:
            selected.remove(label)

        return selected[:self.__candidates_limit]
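As `__read_cluster` implies, the clusters file is expected to be JSON lines: one object per speaker, with a "speaker_id" key and a "neg" list of negative speaker ids. A minimal parsing sketch; the ids here are invented for illustration:

import json

# A hypothetical line of the clusters file; real ids come from the dataset.
line = '{"speaker_id": "55122_0", "neg": ["10001_0", "10001_1", "20404_2"]}'
data = json.loads(line)
print(data["speaker_id"])  # 55122_0
print(data["neg"])         # ['10001_0', '10001_1', '20404_2']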
New file:
@@ -0,0 +1,67 @@
import random
from core.candidates.base import CandidatesProvider
from utils_my import MyAPI


class SameBookRandomCandidatesProvider(CandidatesProvider):
    """ Random candidates selection from the dataset.
        We adopt the same "random" selection approach as in the ALOHA paper:
        https://arxiv.org/pdf/1910.08293.pdf
    """

    def __init__(self, dataset_filepath, candidates_limit, candidates_per_book):
        self.__candidates_limit = candidates_limit
        self.__candidates_per_book = self.__create_dict(dataset_filepath=dataset_filepath,
                                                        limit_per_book=candidates_per_book)

    @staticmethod
    def speaker_to_book_id(speaker_id):
        return int(speaker_id.split('_')[0])

    @staticmethod
    def __create_dict(dataset_filepath, limit_per_book):
        assert(isinstance(limit_per_book, int) and limit_per_book > 0)

        lines = []

        candidates = {}
        for args in MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath):
            if args is None:
                lines.clear()
                continue

            lines.append(args)

            if len(lines) < 2:
                continue

            # This is the type of data we are interested in:
            # the utterance has at least one preceding line.
            speaker_id = args[0]
            book_id = SameBookRandomCandidatesProvider.speaker_to_book_id(speaker_id)
            if book_id not in candidates:
                candidates[book_id] = []

            target = candidates[book_id]

            if len(target) == limit_per_book:
                # Do not register the candidate.
                continue

            # Consider the potential candidate.
            target.append(args[1])

        return candidates

    def provide(self, speaker_id, label):
        # Pick a copy of the candidates.
        book_id = SameBookRandomCandidatesProvider.speaker_to_book_id(speaker_id)
        related = list(self.__candidates_per_book[book_id])
        # Remove the already labeled candidate.
        if label in related:
            related.remove(label)
        # Shuffle the candidates.
        random.shuffle(related)
        # Select the top of the shuffled list.
        return related[:self.__candidates_limit]
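Speaker ids encode the source book id as the prefix before the underscore, which is what makes the per-book grouping above possible; the suffix is presumably a per-book speaker index. For example:

# "55122_0" uses the id format from the script below.
book_id = SameBookRandomCandidatesProvider.speaker_to_book_id("55122_0")
print(book_id)  # 55122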
This file was deleted.
New file:
@@ -0,0 +1,11 @@
from core.candidates.clustering import ALOHANegBasedClusteringProvider
from utils_my import MyAPI


provider = ALOHANegBasedClusteringProvider(
    limit_per_char=100,
    candidates_limit=MyAPI.dataset_candidates_limit,
    dataset_filepath=MyAPI.dataset_filepath,
    cluster_filepath=MyAPI.speaker_clusters_path)

r = provider.provide(speaker_id="55122_0", label="Don't want to talk")
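The random provider can be exercised the same way. A sketch, not part of the commit: the module path core.candidates.default and the candidates_per_book value are assumptions.

from core.candidates.default import SameBookRandomCandidatesProvider  # assumed module path
from utils_my import MyAPI

provider = SameBookRandomCandidatesProvider(
    dataset_filepath=MyAPI.dataset_filepath,
    candidates_limit=MyAPI.dataset_candidates_limit,
    candidates_per_book=1000)  # assumed per-book budget

r = provider.provide(speaker_id="55122_0", label="Don't want to talk")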