Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
handle distributed (#4023)
Browse files · Browse the repository at this point in the history
  • Loading branch information
klshuster committed Sep 16, 2021
1 parent cf5cb1a commit 6f16710
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions projects/blenderbot2/agents/blenderbot2.py
Expand Up @@ -47,6 +47,7 @@
)
from .sub_modules import RetrievalType, KnowledgeAccessMethod
from parlai.agents.fid.fid import SearchQuerySearchEngineFiDAgent
from parlai.utils.fsdp import is_fsdp


ZOO_QUERY_GENERATOR = 'zoo:blenderbot2/query_generator/model'
Expand Down Expand Up @@ -389,6 +390,13 @@ def rag_model_type(self, model: str):
self._rag_model_type = model
self._rag_model_interface = RAG_MODELS[model](self.opt, self.NULL_IDX)

@property
def model_api(self) -> BlenderBot2RagModel:
if hasattr(self.model, 'module') and not is_fsdp(self.model):
return self.model.module
else:
return self.model

def build_model(self) -> BlenderBot2RagModel:
"""
Build and return BlenderBot2RagModel.
Expand Down Expand Up @@ -507,7 +515,7 @@ def _set_query_vec(self, observation: Message) -> Message:
)
if self.add_person_tokens:
query_str = self._remove_person_tokens(query_str)
observation['query_vec'] = self.model.tokenize_query(query_str)
observation['query_vec'] = self.model_api.tokenize_query(query_str)
return observation

def _set_memory_vec(self, observation: Message) -> Message:
Expand Down Expand Up @@ -541,7 +549,7 @@ def _set_memory_vec(self, observation: Message) -> Message:
m for m in memories if self.opt['memory_extractor_phrase'] in m
]
if memories:
mem_vecs = [self.model.tokenize_memory(mem) for mem in memories]
mem_vecs = [self.model_api.tokenize_memory(mem) for mem in memories]

observation['memory_vec'] = mem_vecs
return observation
Expand All @@ -565,7 +573,7 @@ def _set_query_generator_vec(self, observation: Message) -> Message:
KnowledgeAccessMethod.CLASSIFY,
KnowledgeAccessMethod.SEARCH_ONLY,
]
and self.model.has_query_generator()
and self.model_api.has_query_generator()
):
query_generator_input = observation[self.opt['query_generator_key']]
if self.opt['query_generator_ignore_phrase']:
Expand All @@ -578,7 +586,7 @@ def _set_query_generator_vec(self, observation: Message) -> Message:
query_generator_input = self._remove_person_tokens(
query_generator_input
)
query_generator_vec = self.model.tokenize_query_generator_input(
query_generator_vec = self.model_api.tokenize_query_generator_input(
query_generator_input
)

Expand Down Expand Up @@ -661,7 +669,7 @@ def _set_memory_decoder_vec(self, observation: Message) -> Message:
KnowledgeAccessMethod.CLASSIFY,
KnowledgeAccessMethod.MEMORY_ONLY,
]
and self.model.has_memory_decoder()
and self.model_api.has_memory_decoder()
):
memory_decoder_input = observation[self.opt['memory_decoder_key']]
if self.opt['memory_decoder_ignore_phrase']:
Expand All @@ -678,7 +686,7 @@ def _set_memory_decoder_vec(self, observation: Message) -> Message:
for t in tt.split('\n')
]
memory_decoder_vec = [
self.model.tokenize_memory_decoder_input(i) for i in conv_lines
self.model_api.tokenize_memory_decoder_input(i) for i in conv_lines
]

observation['memory_decoder_vec'] = memory_decoder_vec
Expand Down Expand Up @@ -782,10 +790,10 @@ def eval_step(self, batch):
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
return output
if hasattr(self.model.retriever, 'top_docs'):
output.top_docs = self.model.retriever.top_docs
if hasattr(self.model.retriever, 'search_queries'):
output.search_queries = self.model.retriever.search_queries
if hasattr(self.model_api.retriever, 'top_docs'):
output.top_docs = self.model_api.retriever.top_docs
if hasattr(self.model_api.retriever, 'search_queries'):
output.search_queries = self.model_api.retriever.search_queries
return output

def _model_input(
Expand Down Expand Up @@ -835,11 +843,11 @@ def compute_loss(
if (
KnowledgeAccessMethod(self.opt['knowledge_access_method'])
is KnowledgeAccessMethod.CLASSIFY
and self.model.has_query_generator()
and self.model_api.has_query_generator()
):
_scores, _preds, enc_state, *_ = output
_, _, input_turns_cnt, _, _ = enc_state
retrieval_type = self.model.get_retrieval_type()
retrieval_type = self.model_api.get_retrieval_type()
assert isinstance(retrieval_type, torch.Tensor)
if input_turns_cnt is not None:
new_ret_type = torch.zeros(input_turns_cnt.size(0))
Expand Down

0 comments on commit 6f16710

Please sign in to comment.