Merge pull request #853 from opentensor/BIT-523-synapse-callback-message-return

[BIT-523] Add message to synapse_callback return outputs
opentaco committed Jul 27, 2022
2 parents df06000 + 21fc454 commit eba5bcd
Showing 7 changed files with 62 additions and 58 deletions.
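In short: synapse callbacks now return a three-tuple (message, model_output, response_tensor) instead of the old two-tuple (model_output, response_tensor), and the axon reports 'Success' whenever the message is None. A minimal sketch of the new contract, using a hypothetical callback (the body below is illustrative filler, not code from this commit):

import torch

def example_synapse_callback(inputs_x: torch.FloatTensor, synapse, model_output=None):
    # Post-#853 contract: return (message, model_output, response_tensor).
    # message may be None; the axon then records 'Success' for this synapse.
    response_tensor = torch.rand(inputs_x.shape[0], inputs_x.shape[1], 512)  # placeholder dim
    return 'Loss: 2.31', model_output, response_tensor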
6 changes: 3 additions & 3 deletions bittensor/_axon/axon_impl.py
@@ -599,10 +599,10 @@ def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[], hotk
                     synapse_check = self.synapse_checks(synapse, hotkey)

                     if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None and synapse_check:
-                        model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse, model_output)
+                        message, model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse, model_output)
                         response_tensors.append(response_tensor)
                         response_codes.append(bittensor.proto.ReturnCode.Success)
-                        response_messages.append('Success')
+                        response_messages.append('Success' if message is None else message)

                     elif not synapse_check:
                         response_tensors.append(None)
@@ -652,7 +652,7 @@ def default_backward_callback(self, inputs_x:torch.FloatTensor, grads_dy:torch.F
             for index, synapse in enumerate(synapses):
                 try:
                     if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None:
-                        model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse)
+                        message, model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse)
                         torch.autograd.backward (
                             tensors = [ response_tensor ],
                             grad_tensors = [ grads_dy[index] ],
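Any externally attached callback has to adopt the new arity, otherwise the three-way unpacking above raises a ValueError. A minimal migration sketch, assuming a hypothetical wrapper around a legacy two-tuple callback:

def with_message(legacy_callback):
    # Adapt an old-style (model_output, response_tensor) callback to the
    # new (message, model_output, response_tensor) contract.
    def wrapped(inputs_x, synapse, model_output=None):
        model_output, response_tensor = legacy_callback(inputs_x, synapse, model_output)
        return None, model_output, response_tensor  # None is reported as 'Success'
    return wrapped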
16 changes: 10 additions & 6 deletions bittensor/_neuron/text/core_server/nucleus_impl.py
@@ -233,7 +233,7 @@ def forward(self, inputs, tokenizer=None):
                    Decoded predictions of the next token in the sentence.
            """
-        decoded_targets = self.local_forward(inputs, tokenizer)[1]
+        message, model_output, decoded_targets = self.local_forward(inputs, tokenizer)[1]

        shift_logits = decoded_targets[..., :-1, :].contiguous()
        shift_labels = inputs[..., 1:].contiguous()
@@ -276,7 +276,7 @@ def local_forward(self, token_batch, tokenizer=None, encode_len=bittensor.__netw
                                       attention_mask=tokens['attention_mask'],
                                       output_hidden_states=True)

-        return model_output, model_output.logits
+        return None, model_output, model_output.logits

    def encode_forward(self,inputs,tokenizer=None, model_output = None):
        r""" Forward pass through the pretrained model and possible mappings between hidden units.
@@ -327,7 +327,7 @@ def encode_forward(self,inputs,tokenizer=None, model_output = None):
        else:
            encoded_hidden = self.mapping(down)

-        return model_output, encoded_hidden
+        return None, model_output, encoded_hidden

    def encode_forward_causallm(self, token_batch, tokenizer=None, encode_len=bittensor.__network_dim__, model_output=None):
        r""" Forward pass through the pretrained model and possible mappings between hidden units.
@@ -374,9 +374,10 @@ def _forward(_model_output=model_output):

            original_loss = self.get_loss_fct(pre_logits, tokens['input_ids'])
            translated_loss = self.get_loss_fct(logits_std, token_batch)
-            logger.info(f'TextCausalLM \t| Server loss: {original_loss: .2f} \t| Translated loss: {translated_loss: .2f}')
+            message = f'Loss: {original_loss:.2f} → {translated_loss:.2f}'
+            # logger.info(f'TextCausalLM \t| Server loss: {original_loss: .2f} \t| Translated loss: {translated_loss: .2f}')

-            return _model_output, logits_std
+            return message, _model_output, logits_std

        if self.config.neuron.remote_train:
            return _forward() # track gradients for training
@@ -435,7 +436,10 @@ def _forward(_model_output=model_output):
            compact_topk, _topk_tokens, _topk_probs, _floor_probs = result
            # compact_topk: [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1)

-            return _model_output, compact_topk
+            original_loss = self.get_loss_fct(_model_output.logits, tokens['input_ids'])
+            message = f'Loss: {original_loss:.2f}'
+
+            return message, _model_output, compact_topk

        if self.config.neuron.remote_train:
            return _forward() # track gradients for training
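The server now folds its per-request loss into the returned message instead of only logging it. A self-contained sketch of the same idea with plain torch cross-entropy (function name and shapes are illustrative, not the server's actual helpers):

import torch
import torch.nn.functional as F

def loss_message(logits: torch.Tensor, labels: torch.Tensor) -> str:
    # logits: [batch, seq, vocab]; labels: [batch, seq] token ids.
    shift_logits = logits[..., :-1, :].contiguous()   # predict token t+1 from position t
    shift_labels = labels[..., 1:].contiguous()
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return f'Loss: {loss.item():.2f}'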
18 changes: 9 additions & 9 deletions bittensor/_neuron/text/core_server/run.py
@@ -117,23 +117,23 @@ def forward_generate( inputs_x:torch.FloatTensor, synapse, model_output = None):
        raw_texts = [model.tokenizer.decode(out) for out in output]
        tokens = [model.std_tokenizer.encode(raw_text, return_tensors="pt")[:,:synapse.num_to_generate].view(-1) for raw_text in raw_texts]
        bittensor_output = pad_sequence(tokens, batch_first=True)
-        return model_output, bittensor_output
+        return None, model_output, bittensor_output

    def forward_hidden_state(inputs_x:torch.FloatTensor, synapse, model_output = None):
-        model_output, hidden = model.encode_forward(inputs_x.to(model.device), model_output = model_output)
-        return model_output, hidden
+        message, model_output, hidden = model.encode_forward(inputs_x.to(model.device), model_output=model_output)
+        return message, model_output, hidden

    def forward_casual_lm(inputs_x:torch.FloatTensor, synapse, model_output = None):
-        model_output, logits = model.encode_forward_causallm(inputs_x.to(model.device), model_output = model_output)
-        return model_output, logits
+        message, model_output, logits = model.encode_forward_causallm(inputs_x.to(model.device), model_output=model_output)
+        return message, model_output, logits

    def forward_casual_lm_next(inputs_x: torch.FloatTensor, synapse, model_output=None):
-        model_output, topk_token_phrases = model.encode_forward_causallmnext(inputs_x.to(model.device),
-                                                                             topk=synapse.topk,
-                                                                             model_output=model_output)
+        message, model_output, topk_token_phrases = model.encode_forward_causallmnext(inputs_x.to(model.device),
+                                                                                      topk=synapse.topk,
+                                                                                      model_output=model_output)
        # topk_token_phrases: [sum_b(sum_k(len(phrase_k) + 1)_b)] contains topk token phrases and probabilities
        # Compacted 1-D tensor >= batch_size * (2 * topk + 1)
-        return model_output, topk_token_phrases
+        return message, model_output, topk_token_phrases

    def optimizer_step():
        optimizer.step()
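For completeness, a hypothetical wiring sketch showing how the updated forward_* callbacks above might be attached to an axon; the attach_synapse_callback name and the SynapseType constants are assumptions inferred from the self.synapse_callbacks dict in axon_impl.py, not part of this diff:

import bittensor

wallet = bittensor.wallet()
axon = bittensor.axon( wallet = wallet, port = 8091 )
# Each callback now returns (message, model_output, response_tensor).
axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE )
axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM )
axon.attach_synapse_callback( forward_casual_lm_next, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT )
axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ )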
14 changes: 7 additions & 7 deletions tests/integration_tests/test_dendrite.py
@@ -266,13 +266,13 @@ def test_dend_del():
 def test_successful_synapse():
     wallet = bittensor.wallet()
     def forward_generate( inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], synapse.num_to_generate)
+        return None, None, torch.rand(inputs_x.shape[0], synapse.num_to_generate)

     def forward_hidden_state( inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)

     def forward_casual_lm(inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)

     axon = bittensor.axon (
         port = 8096,
@@ -312,10 +312,10 @@ def faulty( inputs_x, synapse, model_output = None):
         raise UnknownException

     def forward_hidden_state( inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)

     def forward_casual_lm(inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)

     axon = bittensor.axon (
         port = 8097,
@@ -357,10 +357,10 @@ def forward_casual_lm(inputs_x, synapse, model_output = None):
 def test_missing_synapse():
     wallet = bittensor.wallet()
     def forward_hidden_state( inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__)

     def forward_casual_lm(inputs_x, synapse, model_output = None):
-        return None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)
+        return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__)

     axon = bittensor.axon (
         port = 8098,