Commit

#19 related refactoring. Using providers for candidates selection
nicolay-r committed Jul 25, 2023
1 parent 5fb29bf commit 6ae8c81
Showing 8 changed files with 212 additions and 68 deletions.
Empty file added core/candidates/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions core/candidates/base.py
@@ -0,0 +1,4 @@
class CandidatesProvider(object):

def provide(self, speaker_id, label):
pass
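The base class only fixes the contract: given a speaker id and the gold reply (label), a provider returns a list of candidate utterances. A minimal, purely illustrative subclass built on that contract (not part of the commit; the static pool is a hypothetical stand-in for any real candidate source):

from core.candidates.base import CandidatesProvider


class StaticCandidatesProvider(CandidatesProvider):
    """ Hypothetical provider serving candidates from a fixed pool;
        shown only to illustrate the provide() interface.
    """

    def __init__(self, pool, candidates_limit):
        assert(isinstance(pool, list))
        self.__pool = pool
        self.__candidates_limit = candidates_limit

    def provide(self, speaker_id, label):
        # Never return the gold label among the candidates.
        selected = [c for c in self.__pool if c != label]
        return selected[:self.__candidates_limit]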
111 changes: 111 additions & 0 deletions core/candidates/clustering.py
@@ -0,0 +1,111 @@
import json
import math

from tqdm import tqdm

from utils_my import MyAPI
from utils import CACHE_DIR

from sentence_transformers import SentenceTransformer
from core.candidates.base import CandidatesProvider


class ALOHANegBasedClusteringProvider(CandidatesProvider):
""" Cluster based approach.
Every cluster provides list of positive and negative characters.
For candidates we consider utterances from "Negative" characters.
We also select the most relevant, which makes our interest in embedded vectors for the related utterances.
We consider the same "random" selection approach from the ALOHA paper:
https://arxiv.org/pdf/1910.08293.pdf
"""

def __init__(self, dataset_filepath, cluster_filepath, limit_per_char,
candidates_limit, closest_candidates_limit=100):
assert(isinstance(dataset_filepath, str))
assert(isinstance(cluster_filepath, str))
self.__candidates_limit = candidates_limit
self.__closest_candidates_limit = closest_candidates_limit
self.__neg_clusters_per_speaker = self.__read_cluster(cluster_filepath)
self.__candidates_per_speaker = self.__create_dict(
dataset_filepath=dataset_filepath, limit_per_char=limit_per_char)

self.__model = SentenceTransformer('all-mpnet-base-v2', cache_folder=CACHE_DIR)

@staticmethod
def __read_cluster(cluster_filepath):
neg_speakers = {}
with open(cluster_filepath, "r") as f:
for line in f.readlines():
data = json.loads(line)
speaker_id = data["speaker_id"]
ids = data["neg"]
neg_speakers[speaker_id] = ids
return neg_speakers

@staticmethod
def __create_dict(dataset_filepath=None, limit_per_char=100):
assert (isinstance(limit_per_char, int) and limit_per_char > 0)

lines = []

candidates = {}
for args in MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath):
if args is None:
lines.clear()
continue

lines.append(args)

if len(lines) < 2:
continue

# This is the type of data we are interested in.
speaker = args[0]

if speaker not in candidates:
candidates[speaker] = []
target = candidates[speaker]

if len(target) == limit_per_char:
# Do not register the candidate.
continue

# Consider the potential candidate.
target.append(args[1])

return candidates

@staticmethod
def cosine_similarity(v1, v2):
sumxx, sumxy, sumyy = 0, 0, 0
for i in range(len(v1)):
x = v1[i]
y = v2[i]
sumxx += x * x
sumyy += y * y
sumxy += x * y
return sumxy / math.sqrt(sumxx * sumyy)

def provide(self, speaker_id, label):

# Compose list of the NON-relevant candidates.
neg_candidates = []
for s_id in self.__neg_clusters_per_speaker[speaker_id][:10]:
neg_candidates.extend(self.__candidates_per_speaker[s_id])

# Calculate embedding vectors.
vectors = []
for c in neg_candidates:
vectors.append(self.__model.encode(c))

label_vector = self.__model.encode(label)
sims = [(i, self.cosine_similarity(label_vector, v)) for i, v in enumerate(vectors)]
most_similar_first = sorted(sims, key=lambda item: item[1], reverse=True)

ordered_neg_candidates = [neg_candidates[i] for i, _ in most_similar_first]
selected = ordered_neg_candidates[:self.__closest_candidates_limit]

if label in selected:
selected.remove(label)

return selected[:self.__candidates_limit]
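The cosine_similarity helper above runs one Python loop per candidate. Since SentenceTransformer.encode also accepts a list of sentences and returns a numpy matrix, the ranking step in provide() could be batched and vectorized. A minimal sketch of that alternative (an assumption-level rewrite, not part of the commit), using the same all-mpnet-base-v2 model:

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-mpnet-base-v2')


def rank_most_similar_first(label, neg_candidates):
    # Encode all candidates in a single batch: shape (n, dim).
    cand_vectors = model.encode(neg_candidates)
    # Encode the gold label: shape (dim,).
    label_vector = model.encode(label)
    # Cosine similarity of every candidate row against the label vector.
    sims = cand_vectors @ label_vector / (
        np.linalg.norm(cand_vectors, axis=1) * np.linalg.norm(label_vector))
    # Candidates ordered from most to least similar.
    return [neg_candidates[i] for i in np.argsort(-sims)]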
67 changes: 67 additions & 0 deletions core/candidates/default.py
@@ -0,0 +1,67 @@
import random
from core.candidates.base import CandidatesProvider
from utils_my import MyAPI


class SameBookRandomCandidatesProvider(CandidatesProvider):
""" Random candidates selection from the dataset.
We consider the same "random" selection approach from the ALOHA paper:
https://arxiv.org/pdf/1910.08293.pdf
"""

def __init__(self, dataset_filepath, candidates_limit, candidates_per_book):
self.__candidates_limit = candidates_limit
self.__candidates_per_book = self.__create_dict(dataset_filepath=dataset_filepath,
limit_per_book=candidates_per_book)

@staticmethod
def speaker_to_book_id(speaker_id):
return int(speaker_id.split('_')[0])

@staticmethod
def __create_dict(dataset_filepath, limit_per_book):
assert (isinstance(limit_per_book, int) and limit_per_book > 0)

lines = []

candidates = {}
for args in MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath):
if args is None:
lines.clear()
continue

lines.append(args)

if len(lines) < 2:
continue

# This is the type of data we are interested in.
speaker_id = args[0]
book_id = SameBookRandomCandidatesProvider.speaker_to_book_id(speaker_id)
if book_id not in candidates:
candidates[book_id] = []

target = candidates[book_id]

if len(target) == limit_per_book:
# Do not register the candidate.
continue

# Consider the potential candidate.
target.append(args[1])

return candidates

def provide(self, speaker_id, label):
# pick a copy of the candidates.
book_id = SameBookRandomCandidatesProvider.speaker_to_book_id(speaker_id)
related = list(iter(self.__candidates_per_book[book_id]))
# remove already labeled candidate.
if label in related:
related.remove(label)
# shuffle candidates.
random.shuffle(related)
# select the top of the shuffled.
return related[:self.__candidates_limit]
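A smoke-test style usage of the random provider, mirroring test/test_candidate_clustering.py further below; the speaker id and label are illustrative and assume the same <book_id>_<speaker_index> id scheme:

from core.candidates.default import SameBookRandomCandidatesProvider
from utils_my import MyAPI

provider = SameBookRandomCandidatesProvider(
    candidates_per_book=1000,
    candidates_limit=MyAPI.dataset_candidates_limit,
    dataset_filepath=MyAPI.dataset_filepath)

r = provider.provide(speaker_id="55122_0", label="Don't want to talk")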


39 changes: 0 additions & 39 deletions core/utils_parlai_facebook_formatter.py
@@ -1,5 +1,4 @@
import random
from utils_my import MyAPI


def format_episode(request, response, candidates, resp_persona_traits=None, resp_persona_prefix="", seed=42):
@@ -43,41 +42,3 @@ def __fn(a):
lines.append("\t".join([text, labels, reward, label_candidates]))

return "\n".join(["{} {}".format(i+1, l) for i, l in enumerate(lines)])


def create_candidates_dict(dataset_filepath=None, limit_per_book=1000):
""" Random candidates selection from the dataset.
We consider the same "random" selection approach from the ALOHA paper:
https://arxiv.org/pdf/1910.08293.pdf
"""
assert(isinstance(limit_per_book, int) and limit_per_book > 0)

lines = []

candidates = {}
for args in MyAPI.read_dataset(keep_usep=False, split_meta=True, dataset_filepath=dataset_filepath):
if args is None:
lines.clear()
continue

lines.append(args)

if len(lines) < 2:
continue

# Here is type of data we interested in.
speaker = args[0]
book_id = int(speaker.split('_')[0])
if book_id not in candidates:
candidates[book_id] = []

target = candidates[book_id]

if len(target) == limit_per_book:
# Do not register the candidate.
continue

# Consider the potential candidate.
target.append(args[1])

return candidates
9 changes: 0 additions & 9 deletions my_s5_parlai_dataset_build_candidates.py

This file was deleted.

39 changes: 19 additions & 20 deletions my_s5_parlai_dataset_convert.py
@@ -1,16 +1,17 @@
import random

import zipstream

from core.utils_parlai_facebook_formatter import create_candidates_dict, format_episode
from core.candidates.base import CandidatesProvider
from core.candidates.clustering import ALOHANegBasedClusteringProvider
from core.candidates.default import SameBookRandomCandidatesProvider
from core.utils_parlai_facebook_formatter import format_episode
from utils_ceb import CEBApi
from utils_my import MyAPI


def iter_dataset_lines(dataset_source, traits_func, candidates_dict, candidates_limit, desc=None):
def iter_dataset_lines(dataset_source, traits_func, candidates_provider, candidates_limit, desc=None):
assert(isinstance(dataset_source, str))
assert(callable(traits_func))
assert(isinstance(candidates_dict, dict) or candidates_dict is None)
assert(isinstance(candidates_provider, CandidatesProvider) or candidates_provider is None)
assert(isinstance(candidates_limit, int))

dialog = []
@@ -33,21 +34,13 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_dict, candidates_
if len(dialog) < 2:
continue

book_id = int(speaker_id.split('_')[0])

assert(len(dialog) == len(speaker_ids) == 2)

candidates = [dialog[1]]
if candidates_dict is not None:
# pick a copy of the candidates.
related = list(iter(candidates_dict[book_id]))
# remove already labeled candidate.
if candidates[0] in related:
related.remove(candidates[0])
# shuffle candidates.
random.shuffle(related)
# select the top of the shuffled.
candidates.extend(related[:candidates_limit])
label = dialog[1]
if candidates_provider is not None:
candidates = candidates_provider.provide(speaker_id=speaker_id, label=label)
else:
candidates = [label]

yield format_episode(request=dialog[0],
response=dialog[1],
@@ -82,7 +75,13 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_dict, candidates_

candidates_provider = {
#"_no-cands": None,
"": create_candidates_dict(dataset_filepath=my_api.dataset_filepath, limit_per_book=1000),
"": SameBookRandomCandidatesProvider(candidates_per_book=1000,
candidates_limit=MyAPI.dataset_candidates_limit,
dataset_filepath=MyAPI.dataset_filepath),
# "clustered": ALOHANegBasedClusteringProvider(limit_per_char=100,
# candidates_limit=MyAPI.dataset_candidates_limit,
# dataset_filepath=MyAPI.dataset_filepath,
# cluster_filepath=MyAPI.hla_cluster_config)
}

for data_fold_type, data_fold_source in dataset_filepaths.items():
@@ -93,7 +92,7 @@ def iter_dataset_lines(dataset_source, traits_func, candidates_dict, candidates_
data_it = iter_dataset_lines(
dataset_source=data_fold_source,
traits_func=traits_func,
candidates_dict=candidates_dict,
candidates_provider=candidates_dict,
candidates_limit=MyAPI.dataset_candidates_limit,
desc=filename)

11 changes: 11 additions & 0 deletions test/test_candidate_clustering.py
@@ -0,0 +1,11 @@
from core.candidates.clustering import ALOHANegBasedClusteringProvider
from utils_my import MyAPI


provider = ALOHANegBasedClusteringProvider(
limit_per_char=100,
candidates_limit=MyAPI.dataset_candidates_limit,
dataset_filepath=MyAPI.dataset_filepath,
cluster_filepath=MyAPI.speaker_clusters_path)

r = provider.provide(speaker_id="55122_0", label="Don't want to talk")
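The test above only exercises the call; a few hedged assertions could be appended as a minimal sanity check (assuming dataset_candidates_limit is the intended upper bound on the returned list):

assert(isinstance(r, list))
assert(len(r) <= MyAPI.dataset_candidates_limit)
assert("Don't want to talk" not in r)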
