UI updates (#796)
* UI updates

* core server fixes

* turning off blacklisting for now

* generation fix

* UI and check updates

* small bug fixes

* constant import
Eugene-hu committed Jun 28, 2022
1 parent b92618e commit 63ef589
Showing 13 changed files with 161 additions and 41 deletions.
3 changes: 3 additions & 0 deletions bittensor/__init__.py
@@ -49,6 +49,9 @@ def turn_console_off():
# Substrate ss58_format
__ss58_format__ = 42

# Wallet ss58 address length
__ss58_address_length__ = 48

__networks__ = [ 'local', 'nobunaga', 'nakamoto']

__datasets__ = ['ArXiv', 'BookCorpus2', 'Books3', 'DMMathematics', 'EnronEmails', 'EuroParl', 'Gutenberg_PG', 'HackerNews', 'NIHExPorter', 'OpenSubtitles', 'PhilPapers', 'UbuntuIRC', 'YoutubeSubtitles']
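
The new `__ss58_address_length__` constant pins the expected length of an SS58-encoded wallet address (48 characters for the generic Substrate format 42 defined just above). A minimal sketch of the kind of structural check it enables, mirroring the axon's new `default_synapse_check` further down:

```python
import bittensor

def looks_like_ss58(hotkey: str) -> bool:
    # Cheap structural validation: a well-formed hotkey address has a
    # fixed SS58-encoded length.
    return len(hotkey) == bittensor.__ss58_address_length__
```
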
23 changes: 19 additions & 4 deletions bittensor/_axon/__init__.py
@@ -54,6 +54,7 @@ def __new__(
synapse_last_hidden: 'Callable' = None,
synapse_causal_lm: 'Callable' = None,
synapse_seq_2_seq: 'Callable' = None,
synapse_checks: 'Callable' = None,
thread_pool: 'futures.ThreadPoolExecutor' = None,
server: 'grpc._Server' = None,
port: int = None,
@@ -81,7 +82,9 @@ def __new__(
synapse_causal_lm (:obj:`callable`, `optional`):
function which is called by the causal lm synapse
synapse_seq_2_seq (:obj:`callable`, `optional`):
function which is called by the seq2seq synapse
synapse_checks (:obj:`callable`, `optional`):
function which is called before each synapse to check for stake
thread_pool (:obj:`ThreadPoolExecutor`, `optional`):
Threadpool used for processing server queries.
server (:obj:`grpc._Server`, `required`):
@@ -141,6 +144,8 @@ def __new__(
synapses[bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM] = synapse_causal_lm
synapses[bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ] = synapse_seq_2_seq

synapse_check_function = synapse_checks if synapse_checks != None else axon.default_synapse_check

if priority != None:
priority_threadpool = bittensor.prioritythreadpool(config=config)
else:
@@ -154,6 +159,7 @@ def __new__(
forward = forward_text,
backward = backward_text,
synapses = synapses,
synapse_checks = synapse_check_function,
priority = priority,
priority_threadpool = priority_threadpool,
forward_timeout = config.axon.forward_timeout,
@@ -235,6 +241,15 @@ def check_config(cls, config: 'bittensor.Config' ):
assert config.axon.port > 1024 and config.axon.port < 65535, 'port must be in range [1024, 65535]'
bittensor.wallet.check_config( config )

@classmethod
def default_synapse_check(cls, synapse, hotkey ):
""" default synapse check function
"""
if len(hotkey) == bittensor.__ss58_address_length__:
return True

return False

@staticmethod
def check_backward_callback( backward_callback:Callable, pubkey:str = '_' ):
""" Check and test axon backward callback function
@@ -255,13 +270,13 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []):
"""
if not inspect.ismethod(forward_callback) and not inspect.isfunction(forward_callback):
raise ValueError('The axon forward callback must be a function with signature Callable[inputs_x: torch.Tensor] -> torch.FloatTensor:, got {}'.format(forward_callback))
if len( inspect.signature(forward_callback).parameters) != 2:
raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor, synapses] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback)))
if len( inspect.signature(forward_callback).parameters) != 3:
raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor, synapses, hotkey] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback)))
if 'inputs_x' not in inspect.signature(forward_callback).parameters:
raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback)))

sample_input = torch.randint(0,1,(3, 3))
forward_callback([sample_input], synapses)
forward_callback([sample_input], synapses, hotkey='')

class AuthInterceptor(grpc.ServerInterceptor):
""" Creates a new server interceptor that authenticates incoming messages from passed arguments.
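
The `synapse_checks` hook introduced above accepts any callable taking `(synapse, hotkey)` and returning a bool, so a server can gate requests on more than address shape. A hedged sketch of a stake-gated check — the metagraph lookup shown here is an assumption for illustration; only the hook's signature and its wiring into `bittensor.axon` are established by this diff:

```python
import bittensor

def stake_check(synapse, hotkey: str) -> bool:
    # Hypothetical stake gate: reject hotkeys that are unregistered or
    # hold no stake. The metagraph API usage here is an assumption.
    metagraph = bittensor.metagraph().sync()
    if hotkey not in metagraph.hotkeys:
        return False
    uid = metagraph.hotkeys.index(hotkey)
    return float(metagraph.S[uid]) > 0.0

axon = bittensor.axon(synapse_checks=stake_check)
```
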
26 changes: 19 additions & 7 deletions bittensor/_axon/axon_impl.py
@@ -47,6 +47,7 @@ def __init__(
forward: 'Callable',
backward: 'Callable',
synapses: dict,
synapse_checks: 'Callable',
priority: 'Callable' = None,
priority_threadpool: 'bittensor.prioritythreadpool' = None,
forward_timeout: int = None,
@@ -79,6 +80,7 @@ def __init__(
self.forward_timeout = forward_timeout
self.backward_timeout = backward_timeout
self.synapse_callbacks = synapses
self.synapse_checks = synapse_checks
self.stats = self._init_stats()
self.started = None
self.optimizer_step = None
@@ -277,24 +279,24 @@ def finalize_codes_stats_and_logs():
self.forward_callback,
inputs_x = deserialized_forward_tensors,
synapses = synapses,
priority = priority
priority = priority,
hotkey= request.hotkey
)
forward_response_tensors, forward_codes, forward_messages = future.result( timeout= self.forward_timeout )
else:

forward_response_tensors, forward_codes, forward_messages = self.forward_callback(
inputs_x = deserialized_forward_tensors,
synapses = synapses,
hotkey= request.hotkey
)

synapse_is_response = [ True for _ in synapses ]
# ========================================
# ==== Fill codes from forward calls ====
# ========================================
for index, synapse in enumerate(synapses):
synapse_codes [ index ] = forward_codes [ index ]
synapse_messages [index] = forward_messages [ index ]

# ========================================
# ==== Catch forward request timeouts ====
# ========================================
@@ -351,7 +353,6 @@ def finalize_codes_stats_and_logs():
finalize_codes_stats_and_logs()
return [], synapse_codes[0], request.synapses


# =========================================================
# ==== Set return times for successful forward ===========
# =========================================================
@@ -559,7 +560,7 @@ def finalize_codes_stats_and_logs():
finalize_codes_stats_and_logs()
return [], bittensor.proto.ReturnCode.Success, request.synapses

def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[] ):
def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[], hotkey = None):
"""
The default forward callback when no callback is attached: Is used to call specific synapse functions
@@ -570,6 +571,9 @@ def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[] ):
synapses (:obj: list of bittensor.proto.SynapseArgs, 'Optional')
The proto message that contains additional args for individual synapse functions
hotkey (:obj:`str`, `optional`):
The hotkey of the validator who sent the request
Returns:
response_tensors: (:obj: list of bittensor.proto.Tensor, `required`):
serialized tensor response from the nucleus call or None.
@@ -587,17 +591,25 @@ def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[] ):
# --- calling attached synapses ---
for index, synapse in enumerate(synapses):
try:
if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None:
synapse_check = self.synapse_checks(synapse, hotkey)

if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None and synapse_check:
model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse, model_output)
response_tensors.append(response_tensor)
response_codes.append(bittensor.proto.ReturnCode.Success)
response_messages.append('Success')

elif not synapse_check:
response_tensors.append(None)
response_codes.append(bittensor.proto.ReturnCode.UnknownException)
response_messages.append('Synapse Check Failed')

else:
response_tensors.append(None)
response_codes.append(bittensor.proto.ReturnCode.NotImplemented)
response_messages.append('Not Implemented')

except Exception as e:
# --- Exception Hit in Synapse ---
response_tensors.append(None)
response_codes.append(bittensor.proto.ReturnCode.UnknownException)
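
Note the gating order in `default_forward_callback`: the check runs once per synapse, and a failed check short-circuits before the callback lookup, mapping to `UnknownException` with message `Synapse Check Failed` rather than `NotImplemented`. A condensed sketch of that control flow (return codes shown as strings; the real code uses `bittensor.proto.ReturnCode` values and also threads `model_output` between callbacks):

```python
def forward(inputs, synapses, hotkey, callbacks, synapse_check):
    # Simplified per-synapse gating, following the diff above.
    responses = []
    for index, synapse in enumerate(synapses):
        if not synapse_check(synapse, hotkey):
            # A failed check wins over everything else.
            responses.append((None, 'UnknownException', 'Synapse Check Failed'))
        elif callbacks.get(synapse.synapse_type) is not None:
            tensor = callbacks[synapse.synapse_type](inputs[index], synapse)
            responses.append((tensor, 'Success', 'Success'))
        else:
            responses.append((None, 'NotImplemented', 'Not Implemented'))
    return responses
```
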
13 changes: 13 additions & 0 deletions bittensor/_cli/__init__.py
@@ -103,6 +103,15 @@ def config() -> 'bittensor.config':
default='None',
help='''Miners available through bittensor.neurons'''
)

run_parser.add_argument(
'--synapse',
type=str,
choices= list(bittensor.synapse.__synapses_types__),
default='None',
help='''Synapses available through bittensor.synapse'''
)

bittensor.subtensor.add_args( run_parser )
bittensor.wallet.add_args( run_parser )

@@ -802,6 +811,10 @@ def check_run_config( config: 'bittensor.Config' ):
if config.model == 'None' and not config.no_prompt:
model = Prompt.ask('Enter miner name', choices = list(bittensor.neurons.__text_neurons__.keys()), default = 'template_miner')
config.model = model

if 'server' in config.model and not config.no_prompt:
synapse = Prompt.ask('Enter synapse', choices = list(bittensor.synapse.__synapses_types__), default = 'All')
config.synapse = synapse

def check_help_config( config: 'bittensor.Config'):
if config.model == 'None':
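
With the new argument in place, a run can pin the served synapse up front instead of answering the interactive prompt. A hedged usage sketch — the `btcli run` entry point and the `--model` flag are assumptions based on the surrounding CLI; this diff only establishes `--synapse` and its choice list:

```python
# Hypothetical invocation; only --synapse and its choices come from this diff:
#
#   btcli run --model core_server --synapse TextCausalLM
#
# After argument parsing, the run config then carries:
#   config.model   == 'core_server'
#   config.synapse == 'TextCausalLM'  # from bittensor.synapse.__synapses_types__, or 'All'
```
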
12 changes: 11 additions & 1 deletion bittensor/_cli/cli_impl.py
@@ -190,8 +190,18 @@ def run_miner ( self ):
# Run miner.
if self.config.model == 'template_miner':
bittensor.neurons.template_miner.neuron().run()

elif self.config.model == 'core_server':
bittensor.neurons.core_server.neuron().run()

if self.config.synapse == 'TextLastHiddenState':
bittensor.neurons.core_server.neuron(lasthidden=True, causallm=False, seq2seq = False).run()
elif self.config.synapse == 'TextCausalLM':
bittensor.neurons.core_server.neuron(lasthidden=False, causallm=True, seq2seq = False).run()
elif self.config.synapse == 'TextSeq2Seq':
bittensor.neurons.core_server.neuron(lasthidden=False, causallm=False, seq2seq = True).run()
else:
bittensor.neurons.core_server.neuron().run()

elif self.config.model == 'core_validator':
bittensor.neurons.core_validator.neuron().run()
elif self.config.model == 'multitron_server':
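
The same selection is available programmatically, since the `core_server` dispatch above just forwards booleans into the neuron constructor. For example, to serve only the causal-LM synapse (taken directly from the dispatch table above):

```python
import bittensor

# Serve only the TextCausalLM synapse from a core_server neuron.
bittensor.neurons.core_server.neuron(
    lasthidden=False, causallm=True, seq2seq=False
).run()
```
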
31 changes: 31 additions & 0 deletions bittensor/_neuron/text/core_server/__init__.py
@@ -45,6 +45,14 @@ class neuron:
bittensor axon object
metagraph (:obj:bittensor.metagraph, `optional`):
bittensor metagraph object
lasthidden (:obj:bool, `optional`):
lasthidden synapse control
causallm (:obj:bool, `optional`):
causallm synapse control
seq2seq (:obj:bool, `optional`):
seq2seq synapse control
synapse_list (:obj:list of int, `optional`):
list of synapse types (bittensor.proto.Synapse.SynapseType) to enable; overrides the individual synapse flags
Examples::
>>> subtensor = bittensor.subtensor(network='nakamoto')
@@ -58,10 +66,33 @@ def __init__(
wallet: 'bittensor.wallet' = None,
axon: 'bittensor.axon' = None,
metagraph: 'bittensor.metagraph' = None,
lasthidden = None,
causallm = None,
seq2seq = None,
synapse_list = None,

):
if config == None: config = server.config()
config = config;

if synapse_list != None:
config.neuron.lasthidden = False
config.neuron.causallm = False
config.neuron.seq2seq = False

if bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE in synapse_list:
config.neuron.lasthidden = True

if bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM in synapse_list:
config.neuron.causallm = True

if bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ in synapse_list:
config.neuron.seq2seq = True

config.neuron.lasthidden = lasthidden if lasthidden != None else config.neuron.lasthidden
config.neuron.causallm = causallm if causallm != None else config.neuron.causallm
config.neuron.seq2seq = seq2seq if seq2seq != None else config.neuron.seq2seq

self.check_config( config )
bittensor.logging (
config = config,
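
Alternatively, `synapse_list` selects synapses by their proto enum values: per the constructor logic above, it first disables all three and then re-enables listed members, and the explicit boolean arguments still take precedence afterwards. A sketch grounded in that logic:

```python
import bittensor

# Enable exactly the last-hidden-state and seq2seq synapses via synapse_list.
neuron = bittensor.neurons.core_server.neuron(
    synapse_list=[
        bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE,
        bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ,
    ]
)
```
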
12 changes: 8 additions & 4 deletions bittensor/_neuron/text/core_server/nucleus_impl.py
@@ -328,7 +328,6 @@ def encode_forward_causallm(self, token_batch, tokenizer=None, encode_len=bitten
The nucleus's logit outputs as a torch tensor of shape [batch_size, sequence_len, __vocab_size__]
"""
tokens = self.remapping_token_causallm(token_batch, tokenizer) # remap to server tokenizer

if model_output == None:
if self.config.neuron.remote_train or (self.config.neuron.autocast and self.device[:4] == 'cuda'):
model_output = self.pre_model(input_ids=tokens['input_ids'],
@@ -386,13 +385,12 @@ def remapping_token_causallm(self, token_batch, std_tokenizer=None):
tokens['offset_mapping'] = pad_offsets(tokens['offset_mapping'], to_offsets_batch, pad_offsets_batch)

tokens['offset_mapping_std'] = std_tokens['offset_mapping'] # include std token info

for key in ['input_ids', 'attention_mask']: # form a torch batch tensor
padded_tokens= pad_sequence([torch.LongTensor(tensor) for tensor in tokens[key]], batch_first=True)
tokens[key] = torch.zeros(token_batch.shape, dtype = torch.long)
tokens[key] = torch.zeros(padded_tokens.shape, dtype = torch.long)
tokens[key][:, :padded_tokens.shape[1]] = padded_tokens
tokens[key] = torch.LongTensor(tokens[key])

return tokens

def get_loss_fct(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
@@ -476,6 +474,12 @@ def config ():
parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, allow non-registered peers''', default=False)
parser.add_argument('--neuron.local_train', action='store_true', help='''If true, allow local training''', default=False)
parser.add_argument('--neuron.remote_train', action='store_true', help='''If true, allow remote training''', default=False)
parser.add_argument('--neuron.lasthidden', action='store_false', help='To turn off last hidden synapse', default=True)
parser.add_argument('--neuron.causallm', action='store_false', help='To turn off causallm synapse', default=True)
parser.add_argument('--neuron.seq2seq', action='store_false', help='To turn off seq2seq synapse', default=True)
parser.add_argument('--neuron.lasthidden_stake', type = float, help='the amount of stake to run last hidden synapse',default=0)
parser.add_argument('--neuron.causallm_stake', type = float, help='the amount of stake to run causallm synapse',default=0)
parser.add_argument('--neuron.seq2seq_stake', type = float, help='the amount of stake to run seq2seq synapse',default=0)
parser.add_argument('--neuron.finetune.all', action='store_true', help='Finetune your whole model instead of only on the last (few) layers', default=False)
parser.add_argument('--neuron.finetune.num_layers', type=int, help='The number of layers to finetune on your model.', default=1)
parser.add_argument('--neuron.finetune.layer_name', type=str, help='Specify since which layer to finetune. eg. encoder.layer.11', default=None)
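
Note the inverted flag semantics: each new synapse flag is declared with `action='store_false'` and `default=True`, so every synapse is on by default and passing the flag turns that synapse off. A self-contained sketch of the same pattern:

```python
import argparse

# Mirrors the parser above: store_false + default=True means the flag disables.
parser = argparse.ArgumentParser()
parser.add_argument('--neuron.lasthidden', action='store_false', default=True,
                    help='Pass to turn OFF the last hidden synapse')

print(vars(parser.parse_args([])))                       # {'neuron.lasthidden': True}
print(vars(parser.parse_args(['--neuron.lasthidden'])))  # {'neuron.lasthidden': False}
```
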
