Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[BB3] Option to ignore in session memories (#4753)
Browse files Browse the repository at this point in the history
* ignore in session memories

* tests

* update test

* remove commas
  • Loading branch information
klshuster committed Aug 18, 2022
1 parent e215048 commit e3eff5b
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 31 deletions.
3 changes: 2 additions & 1 deletion parlai/opt_presets/gen/opt_bb3.opt
Expand Up @@ -192,5 +192,6 @@
"include_prompt": false,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
}
3 changes: 2 additions & 1 deletion parlai/opt_presets/gen/opt_pt.opt
Expand Up @@ -192,5 +192,6 @@
"include_prompt": true,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
}
36 changes: 27 additions & 9 deletions projects/bb3/agents/opt_bb3_agent.py
Expand Up @@ -153,6 +153,13 @@ def add_cmdline_args(
help='Number of times to retry on API request failures (< 0 for unlimited retry).',
)
parser.add_argument('--metaseq-server-timeout', default=20.0, type=float)
parser.add_argument(
'--ignore-in-session-memories-mkm',
type='bool',
default=False,
help='If true, we do not look at the in-session memories when '
'generating from the MKM',
)
return parser

def __init__(self, opt, shared=None):
Expand All @@ -173,6 +180,9 @@ def __init__(self, opt, shared=None):
self.dictionary = top_agent.dictionary
# continue
self.max_prompt_len = opt.get('max_prompt_len', PROMPT.MAX_PROMPT_LEN)
self.ignore_in_session_memories_mkm = opt.get(
'ignore_in_session_memories_mkm', False
)
self.search_agent = SearchAgent(
{
'server': self.opt.get('search_server', 'default'),
Expand Down Expand Up @@ -415,10 +425,10 @@ def batch_act_knowledge(
for module in Module:
obs = all_obs[module]
if module is Module.MEMORY_KNOWLEDGE and i in memory_indices:
memories = MemoryUtils.get_available_memory(
all_obs['raw'], self.dictionary
memories = MemoryUtils.maybe_reduce_memories(
all_obs['raw']['text'], available_memory[i], self.dictionary
)
memories = '\n'.join(available_memory[i])
memories = '\n'.join(memories)
new_prompt = self._check_and_limit_len(
obs['prompt'].replace(module.opt_pre_context_tok(), memories)
)
Expand Down Expand Up @@ -772,7 +782,15 @@ def batch_act(
for _ in range(len(observations))
]
# Step 1: determine whether we're searching or accessing memory
available_memory = [o['raw']['memories'] for o in observations]
all_memory = [o['raw']['memories'] for o in observations]
available_memory = [
MemoryUtils.get_available_memories(
o['raw']['memories'],
o['raw']['in_session_memories'],
self.ignore_in_session_memories_mkm,
)
for o in observations
]

batch_reply_sdm, search_indices = self.batch_act_decision(
observations,
Expand Down Expand Up @@ -866,7 +884,7 @@ def batch_act(
batch_reply_mgm_partner,
batch_reply_knowledge,
batch_reply_dialogue,
available_memory,
all_memory,
)
for i, reply in enumerate(batch_reply_final):
reply.force_set('id', 'BlenderBot3')
Expand Down Expand Up @@ -900,8 +918,8 @@ def self_observe(self, self_message: Message):
memory_candidate,
MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
):
self.memories.append(
MemoryUtils.add_memory_prefix(
memory_candidate, person, self.MODEL_TYPE
)
memory_to_add = MemoryUtils.add_memory_prefix(
memory_candidate, person, self.MODEL_TYPE
)
self.memories.append(memory_to_add)
self.in_session_memories.add(memory_to_add)
18 changes: 10 additions & 8 deletions projects/bb3/agents/r2c2_bb3_agent.py
Expand Up @@ -421,6 +421,7 @@ def __init__(self, opt, shared=None):
self.agents[Module.SEARCH_KNOWLEDGE] = agent

self.memories = []
self.in_session_memories = set()
self.search_knowledge_responses = ['__SILENCE__']
self.memory_knowledge_responses = ['__SILENCE__']
self.contextual_knowledge_responses = ['__SILENCE__']
Expand Down Expand Up @@ -529,6 +530,7 @@ def reset(self, clones_only: bool = False):
self.contextual_knowledge_responses = ['__SILENCE__']
self.memory_knowledge_responses = ['__SILENCE__']
self.memories = []
self.in_session_memories = set()

def _construct_subagent_opts(self, opt: Opt):
"""
Expand Down Expand Up @@ -657,6 +659,7 @@ def observe(self, observation: Message) -> Dict[Module, Message]:

raw_observation = copy.deepcopy(observation)
raw_observation['memories'] = self.memories
raw_observation['in_session_memories'] = self.in_session_memories
observations['raw'] = raw_observation

if observation.get('episode_done'):
Expand Down Expand Up @@ -1402,15 +1405,14 @@ def self_observe(self, self_message: Message):
),
MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
):
self.memories.append(
MemoryUtils.add_memory_prefix(
self_message[
f'{Module.MEMORY_GENERATOR.message_name()}_{person}'
],
person,
self.MODEL_TYPE,
)
memory_to_add = MemoryUtils.add_memory_prefix(
self_message[f'{Module.MEMORY_GENERATOR.message_name()}_{person}'],
person,
self.MODEL_TYPE,
)

self.memories.append(memory_to_add)
self.in_session_memories.add(memory_to_add)
observation = {
'text': clean_text(
self.agents[Module.SEARCH_KNOWLEDGE].history.get_history_str() or ''
Expand Down
39 changes: 31 additions & 8 deletions projects/bb3/agents/utils.py
Expand Up @@ -13,7 +13,7 @@
import os
import string
import time
from typing import List, Tuple, Optional, Dict, Any
from typing import List, Tuple, Optional, Dict, Any, Set

from parlai.agents.ir_baseline.ir_baseline import score_match, MaxPriorityQueue
from parlai.core.dict import DictionaryAgent
Expand Down Expand Up @@ -332,8 +332,8 @@ def _build_query_representation(
return rep

@staticmethod
def get_available_memory(
observation: Message, dictionary: DictionaryAgent
def maybe_reduce_memories(
text: str, memories: List[str], dictionary: DictionaryAgent
) -> List[str]:
"""
TFIDF-Match memories with the textual input to reduce num memories.
Expand All @@ -347,17 +347,40 @@ def get_available_memory(
return - potentially shortened - list of memories
"""
new_memories = []
mems = observation['memories']
if not mems or len(mems) < 32: # 512 / 16, assuming 16 tokens max per memory
return mems
if (
not memories or len(memories) < 32
): # 512 / 16, assuming 16 tokens max per memory
return memories
mpq = MaxPriorityQueue(1000)
query = MemoryUtils._build_query_representation(observation['text'], dictionary)
for m in mems:
query = MemoryUtils._build_query_representation(text, dictionary)
for m in memories:
score = score_match(query, m, 0, dictionary)
mpq.add(m, score)
new_memories = list(reversed(mpq))[:32]
return new_memories

@staticmethod
def get_available_memories(
    memories: List[str],
    in_session_memories: Set[str],
    ignore_in_session_memories: bool,
) -> List[str]:
    """
    Select the memories that may be used, optionally dropping in-session ones.

    :param memories:
        list of all memories
    :param in_session_memories:
        set of memories generated within the current conversation session
    :param ignore_in_session_memories:
        whether to ignore memories generated within the session

    :return:
        a new list with the retained memories in their original order. When
        ``ignore_in_session_memories`` is false, every memory is retained.
    """
    retained: List[str] = []
    for memory in memories:
        generated_in_session = memory in in_session_memories
        # Only drop a memory when filtering is enabled AND it came from the
        # current session; everything else passes through untouched.
        if generated_in_session and ignore_in_session_memories:
            continue
        retained.append(memory)
    return retained


#################
# OPT API UTILS #
Expand Down
74 changes: 70 additions & 4 deletions tests/nightly/gpu/test_bb3.py
Expand Up @@ -411,15 +411,81 @@ def test_memory_tfidf(self):
agent = create_agent(self.opt)
dictionary = agent.dictionary
memories = self.memories * 100
new_memories = MemoryUtils.get_available_memory(
{'text': 'I wish I could see my cats again!', 'memories': memories},
new_memories = MemoryUtils.maybe_reduce_memories(
'I wish I could see my cats again!',
memories,
dictionary,
)
assert "cats" in new_memories[0]
assert len(new_memories) <= 32
new_memories = MemoryUtils.get_available_memory(
{'text': 'I hope the horses are faster today!', 'memories': memories},
new_memories = MemoryUtils.maybe_reduce_memories(
'I hope the horses are faster today!',
memories,
dictionary,
)
assert "horses" in new_memories[0]
assert len(new_memories) <= 32


class TestIgnoreInSessionMemories(TestOptFtBase):
    """
    Tests covering the ``ignore_in_session_memories_mkm`` option.
    """

    def test_in_session_memories(self):
        """
        Compare an agent built with default settings against one built with
        ``ignore_in_session_memories_mkm`` enabled.
        """

        def build_opt(ignore_in_session: bool):
            # Both agents use 'separate' knowledge conditioning; the ignore flag
            # is only written into the opt when it is being switched on.
            built = copy.deepcopy(self.opt)
            built['knowledge_conditioning'] = 'separate'
            built['override']['knowledge_conditioning'] = 'separate'
            if ignore_in_session:
                built['ignore_in_session_memories_mkm'] = True
                built['override']['ignore_in_session_memories_mkm'] = True
            return built

        agent = create_agent(build_opt(False))
        agent2 = create_agent(build_opt(True))

        # first, check with normal messages
        agent1_acts = []
        agent2_acts = []
        for _ in range(5):
            agent.observe(self.message)
            agent2.observe(self.message)
            agent1_acts.append(agent.act())
            agent2_acts.append(agent2.act())

        memory_dialogue_key = Module.MEMORY_DIALOGUE.message_name()
        # ignore first message for agent1 since there aren't any memories
        later_agent1_acts = agent1_acts[1:]
        assert all(act[memory_dialogue_key] for act in later_agent1_acts)
        # the ignoring agent must have an empty value under that key on every turn
        assert not any(act[memory_dialogue_key] for act in agent2_acts)

        # Check that in session memories is strict subset of memories
        # when using opening message
        agent.reset()
        original_memories = copy.deepcopy(self.memories)
        agent.observe(self.opening_message)
        agent.act()
        for in_session_memory in agent.in_session_memories:
            assert in_session_memory in agent.memories
        for memory in agent.memories:
            assert memory not in agent.in_session_memories

        # set ignore in session memories to True; ensure that final returned memories
        # still have all the memories, but that we don't use the memory module
        agent.in_session_memories = set()
        agent.ignore_in_session_memories_mkm = True
        message = copy.deepcopy(self.message)
        agent.observe(message)
        act = agent.act()
        expected_memories = original_memories + list(agent.in_session_memories)
        assert all(m in act['memories'] for m in expected_memories)
        num_new_memories = len(act['memories']) - len(original_memories)
        assert len(agent.in_session_memories) == num_new_memories

    def test_memory_utils(self):
        """
        Check ``MemoryUtils.get_available_memories`` with the flag off and on.
        """
        new_memories = ['in session memory 1', 'in session memory 2']
        memories = self.memories + new_memories
        in_session_memories = set(new_memories)

        # flag off: every memory is available
        kept_all = MemoryUtils.get_available_memories(
            memories, in_session_memories, ignore_in_session_memories=False
        )
        assert kept_all == memories

        # flag on: only the memories that were not generated in-session remain
        kept_prior = MemoryUtils.get_available_memories(
            memories, in_session_memories, ignore_in_session_memories=True
        )
        assert kept_prior == self.memories

0 comments on commit e3eff5b

Please sign in to comment.