diff --git a/.circleci/config.yml b/.circleci/config.yml index 7855427ace..c95b9ce1c8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2.1 orbs: python: circleci/python@2.0.3 - coveralls: coveralls/coveralls@1.0.6 + # coveralls: coveralls/coveralls@1.0.6 jobs: build-and-test: @@ -76,21 +76,21 @@ jobs: - store_artifacts: path: test-results - - when: - condition: - equal: ["3.10.5", << parameters.python-version >> ] - steps: - - run: - name: Upload Coverage - command: | - . env/bin/activate && coveralls - env: - CI_NAME: circleci - CI_BUILD_NUMBER: $CIRCLE_BUILD_NUM - CI_BUILD_URL: $CIRCLE_BUILD_URL - CI_BRANCH: $CIRCLE_BRANCH - CI_JOB_ID: $CIRCLE_NODE_INDEX - COVERALLS_PARALLEL: true + #- when: + #condition: + #equal: ["3.10.5", << parameters.python-version >> ] + #steps: + #- run: + #name: Upload Coverage + #command: | + #. env/bin/activate && coveralls + #env: + #CI_NAME: circleci + #CI_BUILD_NUMBER: $CIRCLE_BUILD_NUM + #CI_BUILD_URL: $CIRCLE_BUILD_URL + #CI_BRANCH: $CIRCLE_BRANCH + #CI_JOB_ID: $CIRCLE_NODE_INDEX + #COVERALLS_PARALLEL: true unit-tests-all-python-versions: docker: @@ -120,6 +120,6 @@ workflows: - unit-tests-all-python-versions: requires: - build-and-test - - coveralls: - requires: - - build-and-test + #- coveralls: + #requires: + #- build-and-test diff --git a/README.md b/README.md index 385f143f75..4af0fe67e1 100644 --- a/README.md +++ b/README.md @@ -230,19 +230,19 @@ The template server follows a similar structure as the template miner. ```bash $ cd bittensor -$ python3 ./bittensor/_neuron/text/template_server/main.py --wallet.name --wallet.hotkey +$ python3 ./bittensor/_neuron/text/core_server/main.py --wallet.name --wallet.hotkey ``` or ```python3 >> import bittensor ->> bittensor.neurons.text.template_server.neuron().run() +>> bittensor.neurons.text.core_server.neuron().run() ``` For the full list of settings, please run ```bash $ cd bittensor -$ python3 ./bittensor/_neuron/text/template_server/main.py --help +$ python3 ./bittensor/_neuron/text/core_server/main.py --help ``` diff --git a/benchmarks/advanced_server.py b/benchmarks/advanced_server.py deleted file mode 100644 index 8a509463b4..0000000000 --- a/benchmarks/advanced_server.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" Benchmarking pytest fixture. 
- -Example: - $ python3 benchmarks/advanced_server.py --neuron.model_name albert-base-v1 - -""" -from benchmarks import QueryBenchmark -import multiprocessing -import bittensor - -class Benchmark ( QueryBenchmark ): - r""" Benchmark pytest class. - """ - - @staticmethod - def miner_name() -> str: - r""" Return miner name - """ - return 'advanced_server' - - @staticmethod - def run_neuron( config , subtensor, metagraph, wallet ): - r""" To be implemented in the subclass, runs the neuron. - Args: - config (bittensor.Config) - Run config - """ - bittensor.neurons.text.advanced_server.neuron( config,subtensor=subtensor, metagraph=metagraph,wallet=wallet).run() - - @staticmethod - def config() -> 'bittensor.Config': - r""" Return config - Returns: - config (bittensor.Config) - Run config. - """ - config = bittensor.neurons.text.advanced_server.neuron.config() - config.neuron.blacklist.stake.forward = 0 - config.neuron.blacklist.stake.backward = 0 - config.neuron.blacklist_allow_non_registered = True - config.neuron.blacklist.time = False - return config - - -if __name__ == '__main__': - benchmark = Benchmark() - benchmark.run() - diff --git a/benchmarks/template_server.py b/benchmarks/core_server.py similarity index 86% rename from benchmarks/template_server.py rename to benchmarks/core_server.py index 558602d6dd..0bebe6d3f6 100644 --- a/benchmarks/template_server.py +++ b/benchmarks/core_server.py @@ -18,7 +18,7 @@ """ Benchmarking pytest fixture. Example: - $ python3 benchmarks/template_server.py --neuron.model_name albert-base-v1 + $ python3 benchmarks/core_server.py --neuron.model_name albert-base-v1 """ from benchmarks import QueryBenchmark @@ -33,7 +33,7 @@ class Benchmark ( QueryBenchmark ): def miner_name() -> str: r""" Return miner name """ - return 'template_server' + return 'core_server' @staticmethod def run_neuron( config , subtensor, metagraph, wallet ): @@ -42,7 +42,7 @@ def run_neuron( config , subtensor, metagraph, wallet ): config (bittensor.Config) Run config """ - bittensor.neurons.text.template_server.neuron( config,subtensor=subtensor, metagraph=metagraph,wallet=wallet).run() + bittensor.neurons.text.core_server.neuron( config,subtensor=subtensor, metagraph=metagraph,wallet=wallet).run() @staticmethod def config() -> 'bittensor.Config': @@ -51,7 +51,7 @@ def config() -> 'bittensor.Config': config (bittensor.Config) Run config. """ - config = bittensor.neurons.text.template_server.neuron.config() + config = bittensor.neurons.text.core_server.neuron.config() return config diff --git a/bittensor/__init__.py b/bittensor/__init__.py index ab6c16a943..c92a69c810 100644 --- a/bittensor/__init__.py +++ b/bittensor/__init__.py @@ -18,7 +18,7 @@ from rich.console import Console # Bittensor code and protocol version. -__version__ = '2.0.4' +__version__ = '3.0.0' version_split = __version__.split(".") __version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2])) @@ -34,7 +34,7 @@ def turn_console_off(): # Vocabulary dimension. #__vocab_size__ = len( tokenizer ) + len( tokenizer.additional_special_tokens) + 100 # Plus 100 for eventual token size increase. -__vocab_size__ = 50378 +__vocab_size__ = 50258 # Tensor dimension. # NOTE (const): if/when this increases peers must be responsible for trimming or expanding output to this size. 
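For reference, a quick worked check of the version encoding shown in the hunk above (a minimal sketch, not part of this diff; the formula is copied verbatim from `bittensor/__init__.py`):

```python
# Worked example of __version_as_int__ for the new version string (sketch only).
version_split = '3.0.0'.split(".")
version_as_int = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))
assert version_as_int == 300  # the previous '2.0.4' encoded to 204
```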
@@ -49,6 +49,9 @@ def turn_console_off(): # Substrate ss58_format __ss58_format__ = 42 +# Wallet ss58 address length +__ss58_address_length__ = 48 + __networks__ = [ 'local', 'nobunaga', 'nakamoto'] __datasets__ = ['ArXiv', 'BookCorpus2', 'Books3', 'DMMathematics', 'EnronEmails', 'EuroParl', 'Gutenberg_PG', 'HackerNews', 'NIHExPorter', 'OpenSubtitles', 'PhilPapers', 'UbuntuIRC', 'YoutubeSubtitles'] @@ -102,6 +105,7 @@ def turn_console_off(): from bittensor._subtensor import subtensor as subtensor from bittensor._tokenizer import tokenizer as tokenizer from bittensor._serializer import serializer as serializer +from bittensor._synapse import synapse as synapse from bittensor._dataset import dataset as dataset from bittensor._receptor import receptor_pool as receptor_pool from bittensor._wandb import wandb as wandb @@ -122,7 +126,12 @@ def turn_console_off(): from bittensor._dataset.dataset_impl import Dataset as Dataset from bittensor._receptor.receptor_pool_impl import ReceptorPool as ReceptorPool from bittensor._threadpool.priority_thread_pool_impl import PriorityThreadPoolExecutor as PriorityThreadPoolExecutor -from bittensor._ipfs.ipfs_impl import Ipfs +from bittensor._ipfs.ipfs_impl import Ipfs as Ipfs +from bittensor._synapse.synapse_impl import Synapse as Synapse +from bittensor._synapse.text_causallm_impl import TextCausalLM as TextCausalLM +from bittensor._synapse.text_causallmnext_impl import TextCausalLMNext as TextCausalLMNext +from bittensor._synapse.text_lasthiddenstate_impl import TextLastHiddenState as TextLastHiddenState +from bittensor._synapse.text_seq2seq_impl import TextSeq2Seq as TextSeq2Seq # DEFAULTS defaults = Config() diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index eabdcd0902..0018ed26c7 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -51,10 +51,11 @@ def __new__( wallet: 'bittensor.Wallet' = None, forward_text: 'Callable' = None, backward_text: 'Callable' = None, - forward_image: 'Callable' = None, - backward_image: 'Callable' = None, - forward_tensor: 'Callable' = None, - backward_tensor: 'Callable' = None, + synapse_last_hidden: 'Callable' = None, + synapse_causal_lm: 'Callable' = None, + synapse_causal_lm_next: 'Callable' = None, + synapse_seq_2_seq: 'Callable' = None, + synapse_checks: 'Callable' = None, thread_pool: 'futures.ThreadPoolExecutor' = None, server: 'grpc._Server' = None, port: int = None, @@ -77,14 +78,16 @@ def __new__( function which is called on forward text requests. backward_text (:obj:`callable`, `optional`): function which is called on backward text requests. - forward_image (:obj:`callable`, `optional`): - function which is called on forward image requests. - backward_image (:obj:`callable`, `optional`): - function which is called on backward image requests. - forward_tensor (:obj:`callable`, `optional`): - function which is called on forward tensor requests. - backward_tensor (:obj:`callable`, `optional`): - function which is called on backward tensor requests. 
+ synapse_last_hidden (:obj:`callable`, `optional`): + function which is called by the last hidden synapse + synapse_causal_lm (:obj:`callable`, `optional`): + function which is called by the causal lm synapse + synapse_causal_lm_next (:obj:`callable`, `optional`): + function which is called by the TextCausalLMNext synapse + synapse_seq_2_seq (:obj:`callable`, `optional`): + function which is called by the seq2seq synapse + synapse_checks (:obj:`callable`, 'optional'): + function which is called before each synapse to check for stake thread_pool (:obj:`ThreadPoolExecutor`, `optional`): Threadpool used for processing server queries. server (:obj:`grpc._Server`, `required`): @@ -139,8 +142,13 @@ def __new__( ('grpc.keepalive_timeout_ms', 500000)] ) - forwards = [forward_text, forward_image, forward_tensor] - backwards = [backward_text, backward_image, backward_tensor] + synapses = {} + synapses[bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE] = synapse_last_hidden + synapses[bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM] = synapse_causal_lm + synapses[bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT] = synapse_causal_lm_next + synapses[bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ] = synapse_seq_2_seq + + synapse_check_function = synapse_checks if synapse_checks != None else axon.default_synapse_check if priority != None: priority_threadpool = bittensor.prioritythreadpool(config=config) @@ -152,8 +160,10 @@ def __new__( server = server, ip = config.axon.ip, port = config.axon.port, - forwards = forwards, - backwards = backwards, + forward = forward_text, + backward = backward_text, + synapses = synapses, + synapse_checks = synapse_check_function, priority = priority, priority_threadpool = priority_threadpool, forward_timeout = config.axon.forward_timeout, @@ -200,7 +210,7 @@ def add_args( cls, parser: argparse.ArgumentParser, prefix: str = None ): parser.add_argument('--' + prefix_str + 'axon.backward_timeout', type=int, help='Number of seconds to wait for backward axon request', default=2*bittensor.__blocktime__) parser.add_argument('--' + prefix_str + 'axon.forward_timeout', type=int, - help='Number of seconds to wait for forward axon request', default=bittensor.__blocktime__) + help='Number of seconds to wait for forward axon request', default=5*bittensor.__blocktime__) parser.add_argument('--' + prefix_str + 'axon.priority.max_workers', type = int, help='''maximum number of threads in thread pool''', default = bittensor.defaults.axon.priority.max_workers) parser.add_argument('--' + prefix_str + 'axon.priority.maxsize', type=int, @@ -217,13 +227,13 @@ def add_args( cls, parser: argparse.ArgumentParser, prefix: str = None ): def add_defaults(cls, defaults): """ Adds parser defaults to object from enviroment variables. 
""" - defaults.axon = bittensor.Config() + defaults.axon = bittensor.config() defaults.axon.port = os.getenv('BT_AXON_PORT') if os.getenv('BT_AXON_PORT') != None else 8091 defaults.axon.ip = os.getenv('BT_AXON_IP') if os.getenv('BT_AXON_IP') != None else '[::]' defaults.axon.max_workers = os.getenv('BT_AXON_MAX_WORERS') if os.getenv('BT_AXON_MAX_WORERS') != None else 10 defaults.axon.maximum_concurrent_rpcs = os.getenv('BT_AXON_MAXIMUM_CONCURRENT_RPCS') if os.getenv('BT_AXON_MAXIMUM_CONCURRENT_RPCS') != None else 400 - defaults.axon.priority = bittensor.Config() + defaults.axon.priority = bittensor.config() defaults.axon.priority.max_workers = os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') if os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') != None else 10 defaults.axon.priority.maxsize = os.getenv('BT_AXON_PRIORITY_MAXSIZE') if os.getenv('BT_AXON_PRIORITY_MAXSIZE') != None else -1 @@ -236,56 +246,42 @@ def check_config(cls, config: 'bittensor.Config' ): assert config.axon.port > 1024 and config.axon.port < 65535, 'port must be in range [1024, 65535]' bittensor.wallet.check_config( config ) + @classmethod + def default_synapse_check(cls, synapse, hotkey ): + """ default synapse check function + """ + if len(hotkey) == bittensor.__ss58_address_length__: + return True + + return False + @staticmethod - def check_backward_callback( backward_callback:Callable, modality:int, pubkey:str = '_' ): + def check_backward_callback( backward_callback:Callable, pubkey:str = '_' ): """ Check and test axon backward callback function """ if not inspect.ismethod(backward_callback) and not inspect.isfunction(backward_callback): raise ValueError('The axon backward callback must be a function with signature Callable[inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ) -> torch.FloatTensor:, got {}'.format(backward_callback)) - if len( inspect.signature(backward_callback).parameters) != 2: - raise ValueError('The axon backward callback must have signature Callable[ inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ) -> torch.FloatTensor:, got {}'.format(inspect.signature(backward_callback))) + if len( inspect.signature(backward_callback).parameters) != 3: + raise ValueError('The axon backward callback must have signature Callable[ inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, synapses ) -> torch.FloatTensor:, got {}'.format(inspect.signature(backward_callback))) if 'inputs_x' not in inspect.signature(backward_callback).parameters: raise ValueError('The axon backward callback must have signature Callable[inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ) -> torch.FloatTensor:, got {}'.format(inspect.signature(backward_callback))) if 'grads_dy' not in inspect.signature(backward_callback).parameters: raise ValueError('The axon backward callback must have signature Callable[inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ) -> torch.FloatTensor:, got {}'.format(inspect.signature(backward_callback))) - if modality == bittensor.proto.Modality.TEXT: - sample_input = torch.randint(0,1,(3, 3)) - grads_raw = torch.rand(3, 3, bittensor.__network_dim__) - backward_callback(sample_input,grads_raw) - - if modality == bittensor.proto.Modality.IMAGE: - sample_input = torch.rand(1,1,3,512,512) - grads_raw = torch.rand(512, 512, bittensor.__network_dim__) - backward_callback(sample_input,grads_raw) - - if modality == bittensor.proto.Modality.TENSOR: - sample_input = torch.rand(1,1,1) - grads_raw = torch.rand(1, 1, bittensor.__network_dim__) - backward_callback(sample_input,grads_raw) @staticmethod - def 
check_forward_callback( forward_callback:Callable, modality:int, pubkey:str = '_'): + def check_forward_callback( forward_callback:Callable, synapses:list = []): """ Check and test axon forward callback function """ if not inspect.ismethod(forward_callback) and not inspect.isfunction(forward_callback): raise ValueError('The axon forward callback must be a function with signature Callable[inputs_x: torch.Tensor] -> torch.FloatTensor:, got {}'.format(forward_callback)) - if len( inspect.signature(forward_callback).parameters) != 1: - raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback))) + if len( inspect.signature(forward_callback).parameters) != 3: + raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor, synapses, hotkey] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback))) if 'inputs_x' not in inspect.signature(forward_callback).parameters: raise ValueError('The axon forward callback must have signature Callable[ inputs_x: torch.Tensor] -> torch.FloatTensor:, got {}'.format(inspect.signature(forward_callback))) - if modality == bittensor.proto.Modality.TEXT: - sample_input = torch.randint(0,1,(3, 3)) - forward_callback(sample_input) - - if modality == bittensor.proto.Modality.IMAGE: - sample_input = torch.rand(1,1,3,512,512) - forward_callback(sample_input) - - if modality == bittensor.proto.Modality.TENSOR: - sample_input = torch.rand(1,1,1) - forward_callback(sample_input) + sample_input = torch.randint(0,1,(3, 3)) + forward_callback([sample_input], synapses, hotkey='') class AuthInterceptor(grpc.ServerInterceptor): """ Creates a new server interceptor that authenticates incoming messages from passed arguments. diff --git a/bittensor/_axon/axon_impl.py b/bittensor/_axon/axon_impl.py index b20f49a639..2aafa4f9b7 100644 --- a/bittensor/_axon/axon_impl.py +++ b/bittensor/_axon/axon_impl.py @@ -44,8 +44,10 @@ def __init__( ip: str, port: int, server: 'grpc._Server', - forwards: List = [], - backwards: List = [], + forward: 'Callable', + backward: 'Callable', + synapses: dict, + synapse_checks: 'Callable', priority: 'Callable' = None, priority_threadpool: 'bittensor.prioritythreadpool' = None, forward_timeout: int = None, @@ -73,13 +75,15 @@ def __init__( self.port = port self.wallet = wallet self.server = server - self.forward_callback = forwards - self.backward_callback = backwards + self.forward_callback = forward if forward != None else self.default_forward_callback + self.backward_callback = backward if backward != None else self.default_backward_callback self.forward_timeout = forward_timeout self.backward_timeout = backward_timeout - self.modality = self.find_modality() + self.synapse_callbacks = synapses + self.synapse_checks = synapse_checks self.stats = self._init_stats() self.started = None + self.optimizer_step = None # -- Priority self.priority = priority @@ -107,20 +111,18 @@ def Forward(self, request: bittensor.proto.TensorMessage, context: grpc.Servicer response (bittensor.proto.TensorMessage): proto response carring the nucleus forward output or None under failure. 
""" - tensor, code, time, message = self._forward( request ) + forward_response_tensors, code, synapses = self._forward( request ) response = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = self.wallet.hotkey.ss58_address, return_code = code, - message = message, - tensors = [tensor] if tensor is not None else [], - requires_grad = True, + tensors = forward_response_tensors if forward_response_tensors is not None else [], + requires_grad = request.requires_grad, + synapses = synapses, ) - # ---- Update stats for this request. - self.update_stats_for_request( request, response, time, code) return response - def Backward(self, request: bittensor.proto.TensorMessage, context: grpc.ServicerContext) -> bittensor.proto.TensorMessage: + def Backward( self, request: bittensor.proto.TensorMessage, context: grpc.ServicerContext ) -> bittensor.proto.TensorMessage: r""" The function called by remote GRPC Backward requests from other neurons. Backward is equivalent to a 'backward' gradient descent pass through a neural network. After checking request validity, passes the request to the nucleus for processing. @@ -136,147 +138,20 @@ def Backward(self, request: bittensor.proto.TensorMessage, context: grpc.Service response (:obj:`bittensor.proto.TensorMessage`): proto response carring the nucleus backward output or None under failure. """ - tensor, code, time, message = self._backward( request ) + backward_response_tensors, code, synapses = self._backward( request ) response = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = self.wallet.hotkey.ss58_address, return_code = code, - message = message, - tensors = [tensor] if tensor is not None else [], - requires_grad = True, + tensors = backward_response_tensors, + requires_grad = request.requires_grad, + synapses = synapses ) - self.update_stats_for_request( request, response, time, code ) return response - def _call_forward( - self, - public_key: str, - inputs_x: torch.Tensor, - modality: bittensor.proto.Modality - ) -> Tuple[ torch.FloatTensor, int, str ]: - r""" Calls the forward callback served by the nucleus. - - Args: - public_key (str, `required`): - public key of the sender - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - modality ( bittensor.proto.Modality, `required`): - modality of inputs. - - Returns: - response (:obj:`torch.FloatTensor, `required`): - Torch tensor response from miner processes. - code (:obj:`bittensor.proto.ReturnCode, `required`) - return code associated with forward call i.e. Success of Timeout. - message (str, `required`): - message associated with forward call, potentially error, or 'success'. - - """ - - # Check forward has been subscribed. - if self.forward_callback[modality] == None: - message = "Forward callback is not yet subscribed on this axon." - return None, bittensor.proto.ReturnCode.NotImplemented, message - - # Make forward call. 
- try: - if self.priority != None: - priority = self.priority(public_key,inputs_x=inputs_x, request_type = bittensor.proto.RequestType.FORWARD) - future = self.priority_threadpool.submit(self.forward_callback[modality],inputs_x=inputs_x,priority=priority) - - try: - response_tensor = future.result(timeout= self.forward_timeout) - except concurrent.futures.TimeoutError : - raise TimeoutError('TimeOutError') - except Exception as e: - logger.error('Error found: {}, with message {}'.format(repr(e), e)) - - else: - response_tensor = self.forward_callback[modality]( inputs_x= inputs_x) - - message = "Success" - code = bittensor.proto.ReturnCode.Success - return response_tensor, code, message - - except Exception as e: - response_tensor = None - message = "Error calling forward callback: {}".format(e) - if isinstance(e, TimeoutError): - code = bittensor.proto.ReturnCode.Timeout - else: - code = bittensor.proto.ReturnCode.UnknownException - return response_tensor, code, message - - def _call_backward( - self, - public_key: str, - inputs_x: torch.Tensor, - grads_dy: torch.FloatTensor, - modality: bittensor.proto.Modality - ) -> Tuple[ torch.FloatTensor, int, str ]: - r""" Calls the backward callback. - - Args: - public_key (str, `required`): - public key of the sender - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be backward processed. - grads_dy ( :obj:`torch.Tensor`, `required`): - torch gradient inputs to be backward processed with inputs. - modality ( bittensor.proto.Modality, `required`): - modality of inputs. - - Returns: - response (:obj:`torch.FloatTensor, `required`): - Torch tensor response from miner processes. - code (:obj:`bittensor.proto.ReturnCode, `required`) - return code associated with forward call i.e. Success of Timeout. - message (str, `required`): - message associated with forward call, potentially error, or 'success'. - """ - # Check backward has been subscribed. - if self.backward_callback[modality] == None: - message = "Backward callback is not yet subscribed on this axon." - return None, bittensor.proto.ReturnCode.NotImplemented, message - - if modality == bittensor.proto.Modality.TEXT: - if self.priority != None: - try: - priority = self.priority(public_key,inputs_x=inputs_x, request_type = bittensor.proto.RequestType.BACKWARD) - future = self.priority_threadpool.submit(self.backward_callback[modality],inputs_x=inputs_x,grads_dy=grads_dy,priority=priority) - except concurrent.futures.TimeoutError : - raise TimeoutError('TimeOutError') - except Exception as e: - logger.error('Error found: {}, with message {}'.format(repr(e), e)) - else: - self.backward_callback[modality](inputs_x, grads_dy) - - response_tensor = torch.ones(inputs_x.size()) - message = "Success" - code = bittensor.proto.ReturnCode.Success - return response_tensor, code, message - - # Make backward call. - try: - response_tensor = self.backward_callback[modality]( inputs_x, grads_dy) - message = "Success" - code = bittensor.proto.ReturnCode.Success - return response_tensor, code, message - - except Exception as e: - response_tensor = None - message = "Error calling backward callback: {}".format(e) - if isinstance(e, TimeoutError): - code = bittensor.proto.ReturnCode.Timeout - else: - code = bittensor.proto.ReturnCode.UnknownException - - return response_tensor, code, message - def _forward(self, request): r""" Performs validity checks on the grpc request before passing the tensors to the forward queue. - Returns the output, message and code from the backend forward call. 
+ Returns the outputs and synapses from the backend forward call. Args: request (:obj:`bittensor.proto`, `required`): @@ -284,279 +159,553 @@ def _forward(self, request): Returns: response (:obj:`bittensor.proto.Tensor, `required`): serialized tensor response from the nucleus call or None. - code (:obj:`bittensor.proto.ReturnCode, `required`) - return code associated with forward call i.e. Success of Timeout. - time (:type:`float`, `required`): - Length of call in seconds. - message (str, `required`): - message associated with forward call, potentially error, or 'success'. + code (:obj:`bittensor.proto.ReturnCode`, `required`): + Code from the call. This specifies if the overall function call was a success. + This is separate from the synapse returns codes which relate to the individual synapse call. + synapses (:obj:`List[ 'bittensor.proto.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Synapse wire protos with return codes from forward request. """ + # =================================================================== + # ==== First deserialize synapse wire protos to instance objects ==== + # =================================================================== + synapses: List['bittensor.Synapse'] = [] + for synapse_wire_proto in request.synapses: + synapses.append( bittensor.synapse.deserialize( synapse_wire_proto ) ) + + + # =================================== + # ==== Init params from synapses ==== + # =================================== + # These items are filled through the call and the function returns + # when all codes are non-success or the function finishes completely. + synapse_messages = [ "Success" for _ in synapses ] + synapse_codes = [ bittensor.proto.ReturnCode.Success for _ in synapses ] + synapse_inputs = [ None for _ in synapses ] + synapse_responses = [ synapse.empty() for synapse in synapses ] # We fill nones for non success. 
+ synapse_is_response = [ False for _ in synapses ] + synapse_call_times = [ 0 for _ in synapses ] start_time = clock.time() - try: - # ---- Check Empty request ---- - if len(request.tensors) == 0: - code = bittensor.proto.ReturnCode.EmptyRequest - message = "Forward request contains {} tensors, expected 1 tensor in the forward call".format(len(request.tensors)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=None, outputs=None, message=message ) - return None, code, call_time, message - - # ---- Check deserialization ---- - tensor_inputs = request.tensors[0] - modality = tensor_inputs.modality - try: - deserializer = bittensor.serializer( serialzer_type = tensor_inputs.serializer ) - torch_inputs = deserializer.deserialize(tensor_inputs, to_type = bittensor.proto.TensorType.TORCH) - except Exception as e: - code = bittensor.proto.ReturnCode.RequestDeserializationException - message = "Request deserialization exception: {}".format(str(e)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=None, outputs=None, message=message ) - return None, code, call_time, message - - # ---- Check shape and modality ---- - if list(torch_inputs.shape)[0] < 1: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward request batch dim exception with batch_size = {} ".format(list(torch_inputs.shape)[0]) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - if list(torch_inputs.shape)[1] < 1: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward request sequence dim exception with sequence_dim = {} ".format(list(torch_inputs.shape)[1]) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - if modality == bittensor.proto.Modality.TEXT: - if len(list(torch_inputs.shape)) != 2: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward text input shape exception with len(request.shape) = {} must have rank 2.".format(len(list(torch_inputs.shape))) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - if modality == bittensor.proto.Modality.IMAGE: - if len(list(torch_inputs.shape)) != 5: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward image input shape exception for len(shape) = {} must have rank 5".format(len(list(torch_inputs.shape))) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - if modality == bittensor.proto.Modality.TENSOR: - if len(list(torch_inputs.shape)) != 3: - code = 
bittensor.proto.ReturnCode.RequestShapeException - message = "Forward message tensor input shape exception len(shape) = {} must have rank 3".format(len(list(torch_inputs.shape))) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - except Exception as e: - code = bittensor.proto.ReturnCode.UnknownException - message = 'exception in preprocessing forward call with error: {}'.format(e) + # ================================================================== + # ==== Function which returns true if all codes are non success ==== + # ================================================================== + def check_if_should_return() -> bool: + for code in synapse_codes: + if code == bittensor.proto.ReturnCode.Success: + return False + return True + + + # ============================================================== + # ==== Function which prints all log statements per synapse ==== + # ============================================================== + def finalize_codes_stats_and_logs(): + for index, synapse in enumerate( synapses ): + request.synapses [ index ].return_code = synapse_codes[ index ] # Set synapse wire proto codes. + request.synapses [ index ].message = synapse_messages[ index ] # Set synapse wire proto message + bittensor.logging.rpc_log ( + axon = True, + forward = True, + is_response = synapse_is_response [index], + code = synapse_codes[ index ], + call_time = synapse_call_times[ index ], + pubkey = request.hotkey, + inputs = synapse_inputs [index] , + outputs = None if synapse_responses[index] == None else list( synapse_responses[index].shape ), + message = synapse_messages[ index ], + synapse = synapse.synapse_type + ) + + # ====================================== + # ==== Check Empty request ==== + # ====================================== + if len(request.tensors) == 0: + code = bittensor.proto.ReturnCode.EmptyRequest + message = "Forward request contains {} tensors, expected 1 tensor in the forward call".format(len(request.tensors)) call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - # Post process. - try: + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], code, request.synapses - # ---- Make nucleus forward call. ---- - code = bittensor.proto.ReturnCode.Success - message = None + + # ====================================== + # ==== Check request length ==== + # ====================================== + if len( request.tensors ) != len( synapses ): + # Not enough responses per request. 
+ code = bittensor.proto.ReturnCode.RequestShapeException call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - outputs, code, message = self._call_forward( - public_key = request.hotkey, - inputs_x = torch_inputs, - modality = modality - ) - if code != bittensor.proto.ReturnCode.Success: - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - # ---- Catch empty ---- - if outputs == None: - code = bittensor.proto.ReturnCode.EmptyResponse - message = None - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message - - # ---- Serialize response ---- + message = "Request length doesn't match synape length." + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.RequestShapeException, request.synapses + + + # =================================== + # ==== Deserialize/Check inputs ==== + # =================================== + deserialized_forward_tensors = [ None for _ in synapses] + for index, synapse in enumerate( synapses ): try: - serializer = bittensor.serializer ( bittensor.proto.Serializer.MSGPACK ) - outputs_serialized = serializer.serialize ( outputs, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH ) - except Exception as e: - code = bittensor.proto.ReturnCode.ResponseDeserializationException - message = e - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message + deserialized_forward_tensors [index] = synapse.deserialize_forward_request_tensor ( request.tensors [index] ) + except ValueError as e: + synapse_codes [index] = bittensor.proto.ReturnCode.RequestShapeException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input shape exception with error:{}'.format(str(e)) + + except Exception as e: + synapse_codes [index] = bittensor.proto.ReturnCode.RequestDeserializationException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input deserialization exception with error:{}'.format(str(e)) + # Check if the call can stop here. + if check_if_should_return(): + finalize_codes_stats_and_logs() + return [], synapse_codes[0] , request.synapses + + + # =================================== + # ==== Make forward calls. 
========= + # =================================== + try: + finalize_codes_stats_and_logs() + if self.priority != None: + priority = self.priority( request.hotkey, inputs_x = deserialized_forward_tensors, request_type = bittensor.proto.RequestType.FORWARD ) + future = self.priority_threadpool.submit ( + self.forward_callback, + inputs_x = deserialized_forward_tensors, + synapses = synapses, + priority = priority, + hotkey= request.hotkey + ) + forward_response_tensors, forward_codes, forward_messages = future.result( timeout= self.forward_timeout ) + else: + + forward_response_tensors, forward_codes, forward_messages = self.forward_callback( + inputs_x = deserialized_forward_tensors, + synapses = synapses, + hotkey= request.hotkey + ) + synapse_is_response = [ True for _ in synapses ] + # ======================================== + # ==== Fill codes from forward calls ==== + # ======================================== + for index, synapse in enumerate(synapses): + synapse_codes [ index ] = forward_codes [ index ] + synapse_messages [index] = forward_messages [ index ] + # ======================================== + # ==== Catch forward request timeouts ==== + # ======================================== + except concurrent.futures.TimeoutError: + if self.priority != None: + future.cancel() + code = bittensor.proto.ReturnCode.Timeout + call_time = clock.time() - start_time + message = "Request reached timeout" + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.Timeout, request.synapses + + # ================================== + # ==== Catch unknown exceptions ==== + # ================================== except Exception as e: code = bittensor.proto.ReturnCode.UnknownException - message = 'exception in processing forward call: {}'.format(e) call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(torch_inputs.shape), outputs=None, message=message ) - return None, code, call_time, message + message = str ( e ) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.UnknownException, request.synapses + + # ================================================= + # ==== Encode/serialize responses and synapses ==== + # ================================================== + response_synapses = [] + for index, synapse in enumerate( synapses ): + try: + if synapse_codes[index] == bittensor.proto.ReturnCode.Success: + synapse_responses [ index ] = synapse.serialize_forward_response_tensor( deserialized_forward_tensors[ index ], forward_response_tensors [ index ] ) + else: + synapse_responses [ index ] = synapse.empty() + + except ValueError as e: + if str(e) == 'Empty Response': + synapse_codes [ index ]= bittensor.proto.ReturnCode.EmptyResponse + else: + synapse_codes [ index ]= bittensor.proto.ReturnCode.ResponseShapeException + + synapse_call_times [ index ] = clock.time() - start_time + synapse_messages [index] = "Synapse response shape exception with error: {}".format( str( e ) ) + synapse_responses [ index ] = synapse.empty() + + except Exception as e: + synapse_codes [ index ]= bittensor.proto.ReturnCode.ResponseSerializationException + 
synapse_call_times [ index ] = clock.time() - start_time + synapse_messages [index] = "Synapse response serialization exception with error: {}".format( str( e ) ) + synapse_responses [ index ] = synapse.empty() + + response_synapses.append(synapse.serialize_to_wire_proto(code = synapse_codes[index], message= synapse_messages[index] )) - # ---- Return successful response ---- - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=True, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(list(torch_inputs.shape)), outputs=outputs_serialized.shape, message=None ) - return outputs_serialized, code, call_time, message + + # Check if the call can stop here. + if check_if_should_return(): + finalize_codes_stats_and_logs() + return [], synapse_codes[0], request.synapses + + # ========================================================= + # ==== Set return times for successfull forward =========== + # ========================================================= + for index, _ in enumerate( synapses ): + if synapse_codes[index] == bittensor.proto.ReturnCode.Success: + synapse_call_times[index] = clock.time() - start_time + + finalize_codes_stats_and_logs() + return synapse_responses, bittensor.proto.ReturnCode.Success, response_synapses def _backward(self, request): - r""" Performs validity checks on the grpc request before piping the request to backend queue. - Returns the output, message and code from the call. + r""" Performs validity checks on the grpc request before piping the request to the backend queue. + Returns the outputs and synapses (with codes and messages from the backward call.) Args: request (:obj:`bittensor.proto`, `required`): Tensor request proto. Returns: response: (:obj:`bittensor.proto.Tensor, `required`): - serialized tensor response from the nucleus call or None. - code: (:obj:`bittensor.proto.ReturnCode, `required`) - return code associated with backward call i.e. Success of Timeout. - time (:type:`float`, `required`): - Length of call in seconds. - message: (str, `required`): - message associated with backward call, potentially error, or 'success'. + serialized tensor gradient responses. This is always an empty vector until gradients are allowed. + code (:obj:`bittensor.proto.ReturnCode`, `required`): + Code from the call. This specifies if the overall function call was a success. + This is separate from the synapse returns codes which relate to the individual synapse call. + synapses (:obj:`List[ 'bittensor.proto.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Synapse wire protos with return codes from forward request. """ + + # =================================================================== + # ==== First deserialize synapse wire protos to instance objects ==== + # =================================================================== + synapses: List['bittensor.Synapse'] = [] + for synapse_wire_proto in request.synapses: + synapses.append( bittensor.synapse.deserialize( synapse_wire_proto ) ) + + # =================================== + # ==== Init params from synapses ==== + # =================================== + # These items are filled through the call and the function returns + # when all codes are non-success or the function finishes completely. 
+ synapse_messages = [ "Success" for _ in synapses ] + synapse_codes = [ bittensor.proto.ReturnCode.Success for _ in synapses ] + deserialized_forward_tensors = [ None for _ in synapses ] + deserialized_forward_gradients = [ None for _ in synapses ] + synapse_is_response = [ False for _ in synapses ] + synapse_call_times = [ 0 for _ in synapses ] start_time = clock.time() - # ---- Check request inputs ----. - if len(request.tensors) == 2: - inputs_x = request.tensors[0] - grads_dy = request.tensors[1] - modality_x = inputs_x.modality - else: - code = bittensor.proto.ReturnCode.InvalidRequest - message = "During backward: There are {} tensors in the request, expected 2.".format(len(request.tensors)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey = request.hotkey, inputs=None, outputs=None, message = message ) - return None, code, call_time, message - # ---- Deserialize request --- - try: - serializer = bittensor.serializer( inputs_x.serializer ) - inputs_x = serializer.deserialize( inputs_x, to_type = bittensor.proto.TensorType.TORCH ) - grads_dy = serializer.deserialize( grads_dy, to_type = bittensor.proto.TensorType.TORCH ) - except Exception as e: - code = bittensor.proto.ReturnCode.RequestDeserializationException - message = "Request serialization exception with error: {}".format(str(e)) + # ================================================================== + # ==== Function which returns true if all codes are non success ==== + # ================================================================== + def check_if_should_return() -> bool: + for code in synapse_codes: + if code == bittensor.proto.ReturnCode.Success: + return False + return True + + # ============================================================== + # ==== Function which prints all log statements per synapse ==== + # ============================================================== + def finalize_codes_stats_and_logs(): + for index, synapse in enumerate( synapses ): + request.synapses [ index ].return_code = synapse_codes[ index ] # Set synapse wire proto codes. + request.synapses [ index ].message = synapse_messages[ index ] # Set synapse wire proto message + bittensor.logging.rpc_log ( + axon = True, + forward = False, + is_response = synapse_is_response [index], + code = synapse_codes[ index ], + call_time = synapse_call_times[ index ], + pubkey = request.hotkey, + inputs = None if deserialized_forward_gradients[index] == None else deserialized_forward_gradients[index].shape , + outputs = None, # we never return from backward. 
+ message = synapse_messages[ index ], + synapse = synapse.synapse_type + ) + + # ====================================== + # ==== Check Empty request ==== + # ====================================== + if len(request.tensors) == 0: + code = bittensor.proto.ReturnCode.EmptyRequest + message = "Empty Request" call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=None, outputs=None, message=message ) - return None, code, call_time, message - - # ---- Check shapes ---- - if modality_x == bittensor.proto.Modality.TEXT: - if len(inputs_x.shape) != 2: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward text input shape exception with len(request.shape) = {} must have rank 2.".format(len(inputs_x.shape)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message - - if modality_x == bittensor.proto.Modality.IMAGE: - if len(inputs_x.shape) != 5: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward image input shape exception for len(shape) = {} must have rank 5".format(len(inputs_x.shape)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message - - if modality_x == bittensor.proto.Modality.TENSOR: - if len(inputs_x.shape) != 3: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Forward message tensor input shape exception len(shape) = {} must have rank 3".format(len(inputs_x.shape)) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message - - if len(grads_dy.shape) != 3: - code = bittensor.proto.ReturnCode.RequestShapeException - message = "Passed gradients must have rank 3 but got {}".format(len(grads_dy.shape)) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], code, request.synapses + + # ====================================== + # ==== Check Invalid request ==== + # ====================================== + if len(request.tensors) < 2: + code = bittensor.proto.ReturnCode.InvalidRequest + message = "Backward request contains {} tensors, expected atleast 2 tensor in the backward call".format(len(request.tensors)) call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message - - if grads_dy.shape[0] != inputs_x.shape[0] or grads_dy.shape[1] != inputs_x.shape[1]: + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], code, request.synapses + + # ====================================== + # ==== 
Check request length ==== + # ====================================== + if len( request.tensors ) != 2 * len( synapses ): # 2 per synapse (1 input + 1 grad). + # Not enough responses per request. code = bittensor.proto.ReturnCode.RequestShapeException - message = "Passed gradients must same first and second dimension as passed inputs got shapes {} and {}".format(grads_dy.shape, inputs_x.shape) - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message - - # ---- Make nucleus backward call. ---- - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=False, code=bittensor.proto.ReturnCode.Success, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=None ) - outputs, code, message = self._call_backward( - public_key = request.hotkey, - inputs_x = inputs_x, - grads_dy = grads_dy, - modality = modality_x - ) - if code != bittensor.proto.ReturnCode.Success: call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message + message = "Request length doesn't match synape length." + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], code, request.synapses + + # =================================== + # ==== Deserialize/Decode inputs ==== + # =================================== + for index, synapse in enumerate( synapses ): + try: + deserialized_forward_tensors [index] = synapse.deserialize_forward_request_tensor ( request.tensors [index] ) + deserialized_forward_gradients [index] = synapse.deserialize_backward_request_gradient ( deserialized_forward_tensors [index], request.tensors [ len( synapses ) + index ] ) - # ---- Catch empty ---- - if outputs == None: - code = bittensor.proto.ReturnCode.EmptyResponse - message = None - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message + except ValueError as e: + synapse_codes [index] = bittensor.proto.ReturnCode.RequestShapeException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input shape exception with error:{}'.format(str(e)) - # ---- Deserialize response ---- + except Exception as e: + synapse_codes [index] = bittensor.proto.ReturnCode.RequestDeserializationException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input deserialization exception with error:{}'.format(str(e)) + # Check if the call can stop here. + if check_if_should_return(): + finalize_codes_stats_and_logs() + return [], synapse_codes[0], request.synapses + + + # =================================== + # ==== Make backward calls. 
========= + # =================================== try: - serializer = bittensor.serializer( bittensor.proto.Serializer.MSGPACK ) - outputs_serialized = serializer.serialize( outputs, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH ) + finalize_codes_stats_and_logs() + synapse_is_response = [ True for _ in synapses ] + if self.priority != None: + # No wait on backward calls. + priority = self.priority( request.hotkey, inputs_x = deserialized_forward_tensors, request_type = bittensor.proto.RequestType.BACKWARD ) + self.priority_threadpool.submit( + self.backward_callback, + inputs_x = deserialized_forward_tensors, + grads_dy = deserialized_forward_gradients, + synapses = synapses, + priority = priority + ) + + else: + # Calling default + backward_response_tensors, backward_codes, backward_messages = self.backward_callback ( deserialized_forward_tensors, deserialized_forward_gradients, synapses = synapses ) + + # ======================================== + # ==== Fill codes from forward calls ==== + # ======================================== + for index, synapse in enumerate(synapses): + synapse_codes [ index ] = backward_codes [ index ] + synapse_messages [index] = backward_messages [ index ] + + # ======================================== + # ==== Catch backward request timeouts ==== + # ======================================== + except concurrent.futures.TimeoutError: + code = bittensor.proto.ReturnCode.Timeout + call_time = clock.time() - start_time + message = "Request reached timeout" + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.Timeout, request.synapses + + # ================================== + # ==== Catch unknown exceptions ==== + # ================================== except Exception as e: - code = bittensor.proto.ReturnCode.ResponseSerializationException - message = "Backward request serialization failed with error {} and inputs {}".format(e, outputs) + code = bittensor.proto.ReturnCode.UnknownException call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=None, message=message ) - return None, code, call_time, message + message = str ( e ) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.UnknownException, request.synapses + + # Check if the call can stop here. 
+ if check_if_should_return(): + finalize_codes_stats_and_logs() + return [], synapse_codes[0], request.synapses + + # ============================== + # ==== Finalize call times ===== + # ============================== + for index, _ in enumerate( synapses ): + if synapse_codes[index] == bittensor.proto.ReturnCode.Success: + synapse_call_times[index] = clock.time() - start_time + + finalize_codes_stats_and_logs() + return [], bittensor.proto.ReturnCode.Success, request.synapses + + def default_forward_callback(self, inputs_x:torch.FloatTensor, synapses=[], hotkey = None): + """ + The default forward callback when no callback is attached: Is used to call specific synapse functions + + Args: + inputs_x (:obj:`torch.FloatTensor`, `required`): + The inputs that will be passed to the synapse functions + + synapses (:obj: list of bittensor.proto.SynapseArgs, 'Optional') + The proto message that contains additional args for individual synapse functions + + hotkey (:obj: str of the hotkey, 'Optional') + The hotkey of the validator who sent the request + + Returns: + response_tensors: (:obj: list of bittensor.proto.Tensor, `required`): + serialized tensor response from the nucleus call or None. + response_codes: (:obj: list of bittensor.proto.ReturnCode, `required`) + return code associated with forward call i.e. Success of Timeout. + response_messages: (:obj: list of strings, `required`) + return message associated with synapse call + """ + # --- initialize response variables --- + response_tensors = [] + response_codes = [] + response_messages = [] + model_output = None + + # --- calling attached synapses --- + for index, synapse in enumerate(synapses): + try: + synapse_check = self.synapse_checks(synapse, hotkey) - # ---- Finaly return ---- - call_time = clock.time() - start_time - bittensor.logging.rpc_log( axon=True, forward=False, is_response=True, code=code, call_time = call_time, pubkey=request.hotkey, inputs=list(grads_dy.shape), outputs=list(outputs_serialized.shape), message=None ) - return outputs_serialized, code, call_time, message + if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None and synapse_check: + message, model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse, model_output) + response_tensors.append(response_tensor) + response_codes.append(bittensor.proto.ReturnCode.Success) + response_messages.append('Success' if message is None else message) + + elif not synapse_check: + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.UnknownException) + response_messages.append('Synapse Check Failed') + + else: + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.NotImplemented) + response_messages.append('Not Implemented') + + except Exception as e: + # --- Exception Hit in Synapse --- + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.UnknownException) + response_messages.append(str(e)) + + return response_tensors, response_codes, response_messages - def attach( self, servicer:object, modality:int): + def default_backward_callback(self, inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, synapses=[] ): """ - Attaches the forward and backward callbacks to the passed object. 
+ The default backward callback when no callback is attached: Is used to call specific synapse functions + + Args: + inputs_x (:obj:`torch.FloatTensor`, `required`): + The inputs that will be passed to the synapse functions + grads_dy (:obj:`torch.FloatTensor`, `required`): + The gradients that will be passed to the synapse functions + synapses (:obj: list of bittensor.proto.SynapseArgs, 'Optional') + The proto message that contains additional args for individual synapse functions Returns: - servicer (:object:`object`, `required`): - object with callbacks servicer.forward and servicer.backward + response_tensors: (:obj: list of bittensor.proto.Tensor, `required`): + serialized tensor response from the nucleus call or None. + response_codes: (:obj: list of bittensor.proto.ReturnCode, `required`) + return code associated with forward call i.e. Success of Timeout. + response_messages: (:obj: list of strings, `required`) + return message associated with synapse call """ - self.attach_forward_callback( servicer.forward , modality) - self.attach_backward_callback( servicer.backward , modality) + # --- initialize response variables --- + response_tensors = [] + response_codes = [] + response_messages = [] + + # --- calling attached synapses --- + with torch.enable_grad() and torch.autograd.set_detect_anomaly(True): + for index, synapse in enumerate(synapses): + try: + if synapse.synapse_type in self.synapse_callbacks and self.synapse_callbacks[synapse.synapse_type] != None: + message, model_output, response_tensor = self.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse) + torch.autograd.backward ( + tensors = [ response_tensor ], + grad_tensors = [ grads_dy[index] ], + retain_graph=True + ) + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.Success) + response_messages.append('Success') + else: + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.NotImplemented) + response_messages.append('Not Implemented') + except Exception as e: + # --- Exception Hit in Synapse --- + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.UnknownException) + response_messages.append(str(e)) + + if self.optimizer_step != None: + self.optimizer_step() + + return response_tensors, response_codes, response_messages - def attach_forward_callback(self, forward_callback: Callable[ [str, torch.Tensor, int], torch.Tensor ] , modality: int): + def attach_forward_callback(self, forward_callback: Callable[ [str, torch.Tensor, int], torch.Tensor ]): """ Assigns the forward_callback. Returns: forward_callback (:callabl:`Callable[ [str, torch.Tensor, int], torch.Tensor `, `required`): Forward function called on recieving a forward request. """ - bittensor.axon.check_forward_callback(forward_callback,modality) - self.forward_callback[modality] = forward_callback + bittensor.axon.check_forward_callback(forward_callback) + self.forward_callback = forward_callback - def attach_backward_callback(self, backward_callback: Callable[ [str, torch.Tensor, torch.Tensor, int], torch.Tensor ], modality: int ): + def attach_synapse_callback(self, synapse_callback: Callable[[str, torch.Tensor, int],torch.Tensor], synapse_type ): + """ Assigns the callback to a specific synapse. + + Args: + synapse_callback (:callabl:`Callable[ [str, torch.Tensor, int], torch.Tensor `, `required`): + function called for a specific synapse. 
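Continuing from the `attach_forward_callback` / `attach_synapse_callback` definitions above, a hedged wiring sketch; wallet names are placeholders and `last_hidden_callback` is the hypothetical callback sketched earlier:

```python
import bittensor

wallet = bittensor.wallet( name = 'my_wallet', hotkey = 'my_hotkey' )   # placeholder names
axon = bittensor.axon( wallet = wallet )

# Serve a single synapse by registering a callback for its type.
axon.attach_synapse_callback(
    last_hidden_callback,
    synapse_type = bittensor.proto.Synapse.SynapseType.TextLastHiddenState,
)
axon.start()
```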
+ """ + self.synapse_callbacks[synapse_type] = synapse_callback + + def attach_backward_callback(self, backward_callback: Callable[ [str, torch.Tensor, torch.Tensor, int], torch.Tensor ] ): """ Assigns the backward_callback call to this neuron. Returns: backward_callback (:callabl:`Callable[ [torch.Tensor, torch.Tensor], torch.Tensor `, `required`): Backward callback called on recieving a backward request. """ - bittensor.axon.check_backward_callback(backward_callback,modality) - self.backward_callback[modality] = backward_callback + bittensor.axon.check_backward_callback(backward_callback) + self.backward_callback = backward_callback def __del__(self): r""" Called when this axon is deleted, ensures background threads shut down properly. @@ -592,7 +741,6 @@ def serve( raise RuntimeError('Failed to serve neuron.') return self - def start(self) -> 'Axon': r""" Starts the standalone axon GRPC server thread. """ @@ -613,35 +761,16 @@ def stop(self) -> 'Axon': logger.success("Axon Stopped:".ljust(20) + "{}", self.ip + ':' + str(self.port)) self.started = False return self - - def find_modality(self): - r""" Detects modality from forward callbacks - """ - modality_list= [index for index, v in enumerate(self.forward_callback) if v != None] - - if len(modality_list) > 1: - raise NotImplementedError('Multiple modality detected. We do not currently support multi-modality miners.') - elif len(modality_list) == 1: - if modality_list[0] == 0: - return bittensor.proto.Modality.TEXT - if modality_list[0] == 1: - return bittensor.proto.Modality.IMAGE - if modality_list[0] == 2: - return bittensor.proto.Modality.TENSOR - elif len(modality_list) == 0: - logger.warning('No modality detected. Defaulting to the text modality') - return bittensor.proto.Modality.TEXT def check(self): r""" Checks axon's forward and backward callbacks """ pubkey = self.wallet.hotkey.ss58_address - for index,forward in enumerate(self.forward_callback): - if forward != None: - bittensor.axon.check_forward_callback(forward,index,pubkey) - for index, backward in enumerate(self.backward_callback): - if backward != None: - bittensor.axon.check_backward_callback(backward,index,pubkey) + if self.forward_callback != None: + bittensor.axon.check_forward_callback(self.forward_callback,index,pubkey) + + if self.backward_callback != None: + bittensor.axon.check_backward_callback(backward,index,pubkey) return self def _init_stats(self): @@ -674,6 +803,7 @@ def _init_stats(self): avg_out_bytes_per_pubkey = {} ) + #TODO: Replace/update axon and dendrite stats def update_stats_for_request(self, request, response, time, code): r""" Updates statistics for this request and response. 
Args: diff --git a/bittensor/_cli/__init__.py b/bittensor/_cli/__init__.py index 3bb2b131b3..e74946b58e 100644 --- a/bittensor/_cli/__init__.py +++ b/bittensor/_cli/__init__.py @@ -63,7 +63,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) overview_parser.add_argument( @@ -100,7 +100,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) run_parser.add_argument( @@ -110,6 +110,15 @@ def config() -> 'bittensor.config': default='None', help='''Miners available through bittensor.neurons''' ) + + run_parser.add_argument( + '--synapse', + type=str, + choices= list(bittensor.synapse.__synapses_types__), + default='None', + help='''Synapses available through bittensor.synapse''' + ) + bittensor.subtensor.add_args( run_parser ) bittensor.wallet.add_args( run_parser ) @@ -121,7 +130,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.subtensor.add_args( metagraph_parser ) @@ -160,7 +169,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( inspect_parser ) @@ -182,7 +191,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( query_parser ) @@ -198,7 +207,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( weights_parser ) @@ -212,7 +221,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) set_weights_parser.add_argument ("--uids", type=int, required=False, nargs='*', action='store', help="Uids to set.") @@ -228,7 +237,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( list_parser ) @@ -253,6 +262,10 @@ def config() -> 'bittensor.config': 'regen_coldkey', help='''Regenerates a coldkey from a passed value''' ) + regen_coldkeypub_parser = cmd_parsers.add_parser( + 'regen_coldkeypub', + help='''Regenerates a coldkeypub from the public part of the coldkey.''' + ) regen_hotkey_parser = cmd_parsers.add_parser( 'regen_hotkey', help='''Regenerates a hotkey from a passed mnemonic''' @@ -296,7 +309,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - 
help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) regen_coldkey_parser.add_argument( @@ -308,6 +321,41 @@ def config() -> 'bittensor.config': bittensor.wallet.add_args( regen_coldkey_parser ) + regen_coldkeypub_parser.add_argument( + "--public_key", + "--pubkey", + dest="public_key_hex", + required=False, + default=None, + type=str, + help='The public key (in hex) of the coldkey to regen e.g. 0x1234 ...' + ) + regen_coldkeypub_parser.add_argument( + "--ss58_address", + "--addr", + "--ss58", + dest="ss58_address", + required=False, + default=None, + type=str, + help='The ss58 address of the coldkey to regen e.g. 5ABCD ...' + ) + regen_coldkeypub_parser.add_argument( + '--no_prompt', + dest='no_prompt', + action='store_true', + help='''Set true to avoid prompting the user.''', + default=False, + ) + regen_coldkeypub_parser.add_argument( + '--overwrite_coldkeypub', + default=False, + action='store_true', + help='''Overwrite the old coldkeypub file with the newly generated coldkeypub''' + ) + bittensor.wallet.add_args( regen_coldkeypub_parser ) + + # Fill arguments for the regen hotkey command. regen_hotkey_parser.add_argument( "--mnemonic", @@ -332,7 +380,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) regen_hotkey_parser.add_argument( @@ -370,7 +418,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) new_coldkey_parser.add_argument( @@ -408,7 +456,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) new_hotkey_parser.add_argument( @@ -446,7 +494,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( unstake_parser ) @@ -484,7 +532,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( stake_parser ) @@ -508,7 +556,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( transfer_parser ) @@ -520,7 +568,7 @@ def config() -> 'bittensor.config': '--no_prompt', dest='no_prompt', action='store_true', - help='''Set true to protect the generated bittensor key with a password.''', + help='''Set true to avoid prompting the user.''', default=False, ) bittensor.wallet.add_args( register_parser ) @@ -530,7 +578,7 @@ def config() -> 'bittensor.config': run_parser.add_argument( '--path', dest="path", - default=os.path.expanduser('miners/text/template_miner.py'), + 
default=os.path.expanduser('miners/text/core_server.py'), type=str, required=False ) @@ -559,6 +607,8 @@ def check_config (config: 'bittensor.Config'): cli.check_new_hotkey_config( config ) elif config.command == "regen_coldkey": cli.check_regen_coldkey_config( config ) + elif config.command == "regen_coldkeypub": + cli.check_regen_coldkeypub_config( config ) elif config.command == "regen_hotkey": cli.check_regen_hotkey_config( config ) elif config.command == "metagraph": @@ -609,7 +659,7 @@ def check_transfer_config( config: 'bittensor.Config'): # Get destination. if not config.dest: dest = Prompt.ask("Enter destination public key: (ss58 or ed2519)") - if not bittensor.utils.is_valid_destination_address( dest ): + if not bittensor.utils.is_valid_bittensor_address_or_public_key( dest ): sys.exit() else: config.dest = str(dest) @@ -789,6 +839,19 @@ def check_regen_coldkey_config( config: 'bittensor.Config' ): else: config.mnemonic = prompt_answer + def check_regen_coldkeypub_config( config: 'bittensor.Config' ): + if config.wallet.get('name') == bittensor.defaults.wallet.name and not config.no_prompt: + wallet_name = Prompt.ask("Enter wallet name", default = bittensor.defaults.wallet.name) + config.wallet.name = str(wallet_name) + if config.ss58_address == None and config.public_key_hex == None: + prompt_answer = Prompt.ask("Enter the ss58_address or the public key in hex") + if prompt_answer.startswith("0x"): + config.public_key_hex = prompt_answer + else: + config.ss58_address = prompt_answer + if not bittensor.utils.is_valid_bittensor_address_or_public_key(address = config.ss58_address if config.ss58_address else config.public_key_hex): + sys.exit(1) + def check_run_config( config: 'bittensor.Config' ): # Check network. @@ -806,12 +869,16 @@ def check_run_config( config: 'bittensor.Config' ): # Check Miner if config.model == 'None' and not config.no_prompt: - model = Prompt.ask('Enter miner name', choices = list(bittensor.neurons.__text_neurons__.keys()), default = 'template_miner') + model = Prompt.ask('Enter miner name', choices = list(bittensor.neurons.__text_neurons__.keys()), default = 'core_server') config.model = model + + if 'server' in config.model and not config.no_prompt: + synapse = Prompt.ask('Enter synapse', choices = list(bittensor.synapse.__synapses_types__), default = 'All') + config.synapse = synapse def check_help_config( config: 'bittensor.Config'): if config.model == 'None': - model = Prompt.ask('Enter miner name', choices = list(bittensor.neurons.__text_neurons__.keys()), default = 'template_miner') + model = Prompt.ask('Enter miner name', choices = list(bittensor.neurons.__text_neurons__.keys()), default = 'core_server') config.model = model def check_update_config( config: 'bittensor.Config'): diff --git a/bittensor/_cli/cli_impl.py b/bittensor/_cli/cli_impl.py index 04bfe6c0a4..cec4734fd4 100644 --- a/bittensor/_cli/cli_impl.py +++ b/bittensor/_cli/cli_impl.py @@ -67,6 +67,8 @@ def run ( self ): self.create_new_hotkey() elif self.config.command == "regen_coldkey": self.regen_coldkey() + elif self.config.command == "regen_coldkeypub": + self.regen_coldkeypub() elif self.config.command == "regen_hotkey": self.regen_hotkey() elif self.config.command == "metagraph": @@ -102,6 +104,12 @@ def regen_coldkey ( self ): wallet = bittensor.wallet(config = self.config) wallet.regenerate_coldkey( mnemonic = self.config.mnemonic, seed = self.config.seed, use_password = self.config.use_password, overwrite = self.config.overwrite_coldkey ) + def regen_coldkeypub ( self ): + r""" 
Creates a new coldkeypub under this wallet. + """ + wallet = bittensor.wallet(config = self.config) + wallet.regenerate_coldkeypub( ss58_address=self.config.get('ss58_address'), public_key=self.config.get('public_key_hex'), overwrite = self.config.overwrite_coldkeypub ) + def regen_hotkey ( self ): r""" Creates a new coldkey under this wallet. """ @@ -154,8 +162,9 @@ def inspect ( self ): registered = '[bold white]Yes[/bold white]' stake = bittensor.Balance.from_tao( neuron.stake ) emission = bittensor.Balance.from_rao( neuron.emission * 1000000000 ) - _, c, t = dendrite.forward_text( endpoints = endpoint, inputs = 'hello world') - latency = "{}".format(t.tolist()[0]) if c.tolist()[0] == 1 else 'N/A' + synapses = [bittensor.synapse.TextLastHiddenState()] + _, c, t = dendrite.text( endpoints = endpoint, inputs = 'hello world', synapses=synapses) + latency = "{}".format((t[0]).tolist()[0]) if (c[0]).tolist()[0] == 1 else 'N/A' cold_balance = wallet.get_balance( subtensor = subtensor ) bittensor.__console__.print("\n[bold white]{}[/bold white]:\n [bold grey]{}[bold white]{}[/bold white]\n {}[bold white]{}[/bold white]\n {}{}\n {}{}\n {}{}\n {}{}\n {}{}[/bold grey]".format( wallet, "coldkey:".ljust(15), wallet.coldkeypub.ss58_address, "hotkey:".ljust(15), wallet.hotkey.ss58_address, "registered:".ljust(15), registered, "balance:".ljust(15), cold_balance.__rich__(), "stake:".ljust(15), stake.__rich__(), "emission:".ljust(15), emission.__rich_rao__(), "latency:".ljust(15), latency ), highlight=True) @@ -190,14 +199,19 @@ def run_miner ( self ): self.register() # Run miner. - if self.config.model == 'template_miner': - bittensor.neurons.template_miner.neuron().run() - elif self.config.model == 'template_server': - bittensor.neurons.template_server.neuron().run() + if self.config.model == 'core_server': + + if self.config.synapse == 'TextLastHiddenState': + bittensor.neurons.core_server.neuron(lasthidden=True, causallm=False, seq2seq = False).run() + elif self.config.synapse == 'TextCausalLM': + bittensor.neurons.core_server.neuron(lasthidden=False, causallm=True, seq2seq = False).run() + elif self.config.synapse == 'TextSeq2Seq': + bittensor.neurons.core_server.neuron(lasthidden=False, causallm=False, seq2seq = True).run() + else: + bittensor.neurons.core_server.neuron().run() + elif self.config.model == 'core_validator': bittensor.neurons.core_validator.neuron().run() - elif self.config.model == 'advanced_server': - bittensor.neurons.advanced_server.neuron().run() elif self.config.model == 'multitron_server': bittensor.neurons.multitron_server.neuron().run() @@ -207,14 +221,10 @@ def help ( self ): sys.argv = [sys.argv[0], '--help'] # Run miner. 
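The `run_miner` mapping above translates the new `--synapse` choice into `core_server` constructor flags; the same selection can be made programmatically (flag names taken directly from that mapping):

```python
import bittensor

# Serve only the causal language-model synapse:
bittensor.neurons.core_server.neuron( lasthidden = False, causallm = True, seq2seq = False ).run()

# Or serve everything the server supports (the default branch above):
# bittensor.neurons.core_server.neuron().run()
```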
- if self.config.model == 'template_miner': - bittensor.neurons.template_miner.neuron().run() - elif self.config.model == 'template_server': - bittensor.neurons.template_server.neuron().run() + if self.config.model == 'core_server': + bittensor.neurons.core_server.neuron().run() elif self.config.model == 'core_validator': bittensor.neurons.core_validator.neuron().run() - elif self.config.model == 'advanced_server': - bittensor.neurons.advanced_server.neuron().run() elif self.config.model == 'multitron_server': bittensor.neurons.multitron_server.neuron().run() diff --git a/bittensor/_dataset/dataset_impl.py b/bittensor/_dataset/dataset_impl.py index 44996eaf73..f86f0c7aff 100644 --- a/bittensor/_dataset/dataset_impl.py +++ b/bittensor/_dataset/dataset_impl.py @@ -459,10 +459,10 @@ def construct_text_corpus(self, min_data_len = 0): Contents of the text data. """ self.IPFS_fails = 0 + data_corpus = [] try: # --- Get directories from a random dataset_hash directories = list(self.get_hashes_from_dataset()) - data_corpus = [] # --- Generate a random order of the directories random.shuffle(directories) diff --git a/bittensor/_dataset/dataset_mock.py b/bittensor/_dataset/dataset_mock.py index d245375063..1cf2d0cf6d 100644 --- a/bittensor/_dataset/dataset_mock.py +++ b/bittensor/_dataset/dataset_mock.py @@ -136,7 +136,7 @@ def __getitem__(self, idx): start_idx = (idx * self.block_size) % len(self.data) end_idx = start_idx + self.block_size if self.no_tokenizer == False: - tokenized_text = torch.tensor(self.tokenizer(" ".join(self.data[start_idx:end_idx]), padding=True, truncation=True)['input_ids'], dtype=torch.long) + tokenized_text = torch.tensor(self.tokenizer(" ".join(self.data[start_idx:end_idx]), truncation=True)['input_ids'], dtype=torch.long) elif self.no_tokenizer == True: tokenized_text = " ".join(self.data[start_idx:end_idx]) diff --git a/bittensor/_dataset/thread_queue.py b/bittensor/_dataset/thread_queue.py index 9b3e0101b8..750412f965 100644 --- a/bittensor/_dataset/thread_queue.py +++ b/bittensor/_dataset/thread_queue.py @@ -52,7 +52,7 @@ def run(self): if not self.queue.full(): item = self.target(*self.arg, self.queue.qsize()+1 ) self.queue.put(item) - time.sleep(10) + time.sleep(2) return def stop(self): diff --git a/bittensor/_dendrite/dendrite_impl.py b/bittensor/_dendrite/dendrite_impl.py index 393dd6986a..0b05b931c4 100644 --- a/bittensor/_dendrite/dendrite_impl.py +++ b/bittensor/_dendrite/dendrite_impl.py @@ -28,9 +28,12 @@ from torch.autograd.function import once_differentiable from loguru import logger from transformers.utils.logging import enable_explicit_format +from yaml import serialize import bittensor from bittensor._endpoint.endpoint_impl import Endpoint +from bittensor._serializer import serializer, serializer_impl +from bittensor._synapse import TextCausalLM, synapse import bittensor.utils.stats as stat_utils import bittensor.utils.codes as codes @@ -42,15 +45,6 @@ DUMMY = torch.empty(0, requires_grad=True) -# Helper function for filling nill (zero) responses on failures. -def nill_response_for(inputs): - """ Get zero matrix with the same size as inputs - """ - if torch.numel(inputs) == 0: - return torch.tensor([]) - return torch.zeros((inputs.size(0), inputs.size(1), bittensor.__network_dim__), dtype=torch.float32) - - class Dendrite(torch.autograd.Function): r""" This is the implementation class for a bittensor.dendrite(). 
The dendrite class operates as a normal torch autograd friendly operation which accepts a list of bittensor.endpoints and a list of torch tensors. The passed endpoints are queried with the passed inputs and either return @@ -108,7 +102,7 @@ def forward( dendrite: 'bittensor.Dendrite', dummy: torch.Tensor, endpoints: List['bittensor.Endpoint'], - modality: bittensor.proto.Modality, + synapses: List[ 'bittensor.Synapse' ], timeout: int, requires_grad: bool, *inputs: torch.Tensor @@ -129,11 +123,9 @@ def forward( endpoints (:obj:`List[bittensor.Endpoint']` of shape :obj:`(n_endpoints)`, `required`): List of endpoints which match length of inputs. Inputs are sent forward to these endpoints. - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality or type ENUM [TEXT, IMAGE, TENSOR] - - inputs (:obj:`List[torch.Tensor]` of shape :obj:`(n_endpoints)`, `required`): - List of torch tensors to be sent to the associated endpoints. + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. timeout (int): request timeout. @@ -141,6 +133,9 @@ def forward( requires_grad (int, default = dendrite.requires_grad, `optional`): If true, the backward pass triggers passing gradients on the wire. + inputs (:obj:`List[torch.Tensor]` of shape :obj:`(n_endpoints)`, `required`): + List of torch tensors to be sent to the associated endpoints. + Returns: codes (:obj:`torch.LongTensor` of shape :obj:`(n_endpoints)` `required`): Return code associated with forward call. @@ -148,23 +143,38 @@ def forward( times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): times per call. - outputs (:obj:`List[torch.FloatTensor]` of shape :obj:`n_endpoints * (batch_size, sequence_len, bittensor.__network_dim__)`, `required`): - Output encodings of inputs produced by the remote endpoints. Non-responses are zeroes of common shape. + outputs (:obj:`List[torch.FloatTensor` of shape :obj:`num_synapses * n_endpoints * (-1, -1, -1) `, `required`): + List of outputs from each synapses and each endpoint unfolded into a single list. Non-responses are zeroes of expected shape. """ ctx.receptor_pool = dendrite.receptor_pool - ctx.endpoints, ctx.inputs, ctx.modality, ctx.timeout, ctx.does_requires_grad = endpoints, inputs, modality, timeout, requires_grad - inputs = [x.cpu().clone().detach() for x in inputs] - forward_outputs, forward_codes, forward_times = ctx.receptor_pool.forward( - endpoints=endpoints, - inputs=inputs, - modality=modality, - timeout=timeout + ctx.endpoints, ctx.synapses, ctx.inputs, ctx.timeout, ctx.does_requires_grad = endpoints, synapses, inputs, timeout, requires_grad + inputs:List[torch.Tensor] = [x.cpu().clone().detach() for x in inputs] + + # Ouputs are list of lists where the outer list corresponds to the endpoints and the + # inner list corresponds to the synapses. + forward_outputs, forward_codes, forward_times = ctx.receptor_pool.forward ( + endpoints = endpoints, + synapses = synapses, + inputs = inputs, + timeout = timeout, ) ctx.forward_codes = forward_codes - forward_times = [-1 if t is None else t for t in forward_times] - return (torch.tensor(forward_codes, dtype=torch.int64), - torch.tensor(forward_times, dtype=torch.float32), - *forward_outputs) + + # We need to flatten the outputs across the synapse dimension. 
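The zero-size `DUMMY` tensor with `requires_grad=True` is what keeps `Dendrite.apply(...)` on the autograd graph even though the real inputs are detached before being handed to the receptor pool. A minimal, self-contained sketch of that pattern (the doubling is a stand-in for the remote call); the flattening helper continues below:

```python
import torch

DUMMY = torch.empty(0, requires_grad=True)

class RemoteCall(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, x):
        return x * 2                       # stand-in for receptor_pool.forward

    @staticmethod
    def backward(ctx, grad_out):
        # The real dendrite would ship grad_out over the wire here (receptor_pool.backward).
        return None, grad_out * 2          # one gradient slot per forward argument

x = torch.ones(3)                          # detached, like the .cpu().clone().detach() inputs above
y = RemoteCall.apply(DUMMY, x)
y.sum().backward()                         # still reaches RemoteCall.backward, thanks to DUMMY
```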
+ def flatten(t): + return [item for sublist in t for item in sublist] + # flattened items now have length num_endpoints * num_synapses + # where endpoint i's jth outputs is at position (num_synapses * i ) + j + flattened_forward_codes: List[ bittensor.proto.ReturnCode ] = flatten( forward_codes ) + flattened_forward_times: List[float] = flatten( forward_times ) + flattened_forward_outputs: List[torch.Tensor] = flatten( forward_outputs ) + + # We will pack all the codes and times into a single tensor + flattened_torch_codes: torch.LongTensor = torch.tensor(flattened_forward_codes, dtype=torch.int64) + flattened_torch_times: torch.FloatTensor = torch.tensor(flattened_forward_times, dtype=torch.float32) + + # Return all outputs as a tuple of torch tensors of length 2 + (num_endpoints * num_synapses) + return (flattened_torch_codes, flattened_torch_times, *flattened_forward_outputs) @staticmethod @once_differentiable @@ -188,6 +198,7 @@ def backward( grads (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`): Gradients of this function's outputs computed during the loss.backward() call. + This is a list item of size num_endpoints * num_synapses. Returns: DUMMY, None, None, None, @@ -195,278 +206,318 @@ def backward( Gradient results for each input. """ + # output_grads is a list of gradients per synapse. They need to be packed (unflattened) + # into a list of lists. + packed_grads: List[ List [ torch.FloatTensor ] ] = [ output_grads[ s : s + len(ctx.synapses) ] for s in range (0, len(output_grads), len( ctx.synapses )) ] if ctx.does_requires_grad: - grads_cpu = [x.cpu().clone().detach() for x in output_grads] input_grads, _, _ = ctx.receptor_pool.backward( - endpoints=ctx.endpoints, - inputs_x=ctx.inputs, - grads_dy=grads_cpu, - modality=ctx.modality, - timeout=ctx.timeout, + endpoints = ctx.endpoints, + inputs = ctx.inputs, + synapses = ctx.synapses, + grads = packed_grads, + timeout = ctx.timeout, ) - return (None, None, None, None, None, None, *input_grads) + # Input grads is a list of lists + # We need to flatten the outputs across the synapse dimension. + def flatten(t): + return [item for sublist in t for item in sublist] + flattened_input_grads: List[torch.FloatTensor] = flatten( input_grads ) + return (None, None, None, None, None, None, *flattened_input_grads) else: - input_grads = [nill_response_for(inp) for inp in ctx.inputs] + # Create nill responses for each input and each synapse. + input_grads = [ syn.nill_backward_response_tensor ( inp ) for inp in ctx.inputs for syn in ctx.synapses ] return (None, None, None, None, None, None, *input_grads) def _forward( self, - endpoints: List['bittensor.Endpoint'], - inputs: List[torch.Tensor], - modality: bittensor.proto.Modality, - timeout: int = None, - requires_grad: bool = None - ) -> Tuple[List[torch.Tensor], torch.LongTensor, torch.FloatTensor]: + endpoints: List [ 'bittensor.Endpoint' ], + synapses: List[ 'bittensor.Synapse' ], + inputs: List [ torch.Tensor ], + timeout: Optional [ int ] = None, + requires_grad: Optional [ bool ] = None, + ) -> Tuple [ List[ torch.Tensor ], List[ torch.LongTensor ], List [ torch.FloatTensor ]]: r""" Internal Forward tensor inputs to a list of neuron endpoints. Args: endpoints (:obj:`List[bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): List of remote endpoints which match length of inputs. Tensors from inputs are sent forward to these endpoints. 
+ synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. + inputs (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): List of tensors to send to corresponding endpoints. Tensors are of arbitrary type and shape depending on the - modality. - - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] + synapse. - timeout (int, default = dendrite.timeout, `required`): + timeout (int, default = dendrite.timeout, `optional`): request timeout. requires_grad (int, default = dendrite.requires_grad, `optional`): If true, the backward pass triggers passing gradients on the wire. Returns: - responses (:obj:`List[torch.FloatTensor]` of shape :obj:`(batch_size, sequence_len, bittensor.__network_dim__)`, `required`): + outputs (:obj:`List[torch.FloatTensor]` of shape :obj:`(batch_size, sequence_len, bittensor.__network_dim__)`, `required`): Output encodings of inputs produced by the remote endpoints. Non-responses are zeroes of common shape. codes (:obj:`List[torch.LongTensor]` of shape :obj:`[num_endpoints]`, `required`): - dendrite call return codes. + Return codes per endpoint per synapse. times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): - times per call. + Call times per endpoint per synapse. """ - timeout = timeout if timeout is not None else self.config.dendrite.timeout - requires_grad = requires_grad if requires_grad is not None else self.config.dendrite.requires_grad - forward_response = Dendrite.apply( + timeout:int = timeout if timeout is not None else self.config.dendrite.timeout + requires_grad:bool = requires_grad if requires_grad is not None else self.config.dendrite.requires_grad + + # The forwarnd response is a tuple with shape (flattened_torch_codes, flattened_torch_times, *flattened_forward_outputs) + # packed with torch tensors of length 2 + (num_endpoints * num_synapses). The first two tensors are codes and times + # the last (num_endpoints * num_synapses) tensors are per endpoint per synapse output tensors. + forward_response: List[torch.Tensor] = Dendrite.apply ( self, DUMMY, endpoints, - modality, + synapses, timeout, requires_grad, *inputs ) - codes = forward_response[0] - times = forward_response[1] - responses = forward_response[2:] - return responses, codes, times - def forward_image( - self, - endpoints: Union[List['bittensor.Endpoint'], 'bittensor.Endpoint'], - inputs: List[torch.FloatTensor], - timeout: int = None, - requires_grad: bool = None - ) -> Tuple[Union[List[torch.FloatTensor], torch.FloatTensor], torch.LongTensor, torch.FloatTensor]: - r""" Forward image inputs to endpoints. + # Split codes into num_synapse lists of codes + # split_codes is a list of tensors codes each with length num_synapses + codes: torch.LongTensor = forward_response[0] + packed_codes: List[torch.LongTensor] = torch.split( codes, len( synapses ) ) + + # Split times into num_synapse lists of codes + # split_times is a list of tensors times each with length num_synapses + times: torch.FloatTensor = forward_response[1] + packed_times: List[torch.FloatTensor] = torch.split( times, len( synapses ) ) + + # Output responses is a list with length num_endpoints num_synapses + # we need to pack the responses into a list of lists corresponding to + # each endpoint. 
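A quick sanity check of the flatten/unflatten indexing described in the comments above and applied in the packing code that follows (illustrative only):

```python
num_endpoints, num_synapses = 2, 3
outputs = [[f"e{i}s{j}" for j in range(num_synapses)] for i in range(num_endpoints)]

flat = [item for sublist in outputs for item in sublist]      # same flatten() as above
assert flat[num_synapses * 1 + 2] == "e1s2"                   # endpoint i, synapse j -> num_synapses * i + j

packed = [flat[s : s + num_synapses] for s in range(0, len(flat), num_synapses)]
assert packed == outputs                                      # chunks of num_synapses recover the per-endpoint lists
```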
+ outputs: List[torch.Tensor] = forward_response[2:] + packed_outputs: List[ List[torch.Tensor] ] = [ outputs[ s : s + len(synapses) ] for s in range (0, len(outputs), len( synapses )) ] + + return packed_outputs, packed_codes, packed_times + + def text ( + self, + endpoints: Union[ torch.LongTensor, List[torch.LongTensor], List['bittensor.Endpoint'], 'bittensor.Endpoint' ], + synapses: List[ 'bittensor.Synapse' ], + inputs: Union[str, List[str], List[torch.LongTensor], torch.LongTensor], + timeout: int = None, + requires_grad: bool = None, + ) -> Tuple[ Union[List[torch.FloatTensor], torch.FloatTensor], torch.LongTensor, torch.FloatTensor]: + r""" Forward text inputs to a list of neuron endpoints and returns logit encodings or timeout. - Args: - endpoints (:obj:`Union[List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): - List or single of endpoints which match the length of inputs. Inputs are sent forward to these endpoints. + Args: + endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): + Endpoints to send inputs to. Endpoint can be one of the following types: + - a single endpoint tensor shape [250] + - a set of endpoint tensors shape [n, 250] + - a list of endpoints tensors each of shape [250] + - a single endpoint object. Inputs will be sent to this endpoint alone. + - a list of endpoint objects. All inputs will be sent to these endpoints. - inputs (:obj:`Union[List[torch.FloatTensor], torch.FloatTensor]` of shape :obj:`(num_endpoints * [ batch_size, sequence_len, channels, rows, cols ])`, `required`): - List or single of image-tensors to send to corresponding endpoints. Tensors are images encoded using the - torch.toTensor() or other encoding which produces the shape [batch_size, channels, rows, cols]. + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. - timeout (int, default = dendrite.timeout `optional`): - Request timeout. + inputs (:obj:`Union[str, List[str], List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`): + Tokenized sentences to send on the wire. Inputs can be one of the following types: + - a single string: the string will be tokenized using the bittensor tokenizer. + - a list of strings: the strings will be tokenized using the bittensor tokenizer. + - a tensor with shape [batch_size, sequence_len], assumed to be the output of bittensor tokenizer. + - a tensor with shape [n, batch_size, sequence_len], the operation will unbind the tensor and pass inputs to endpoints. + - a list of tensors of type long each representing a tokenized sentence to be sent to each endpoint. + If inputs are tensors they will be cast to int64 format before sending on the wire. - requires_grad (int, default = dendrite.requires_grad, `optional`): - If true, the backward pass triggers passing gradients on the wire. + timeout (:type:`int`, default = dendrite.timeout `optional`): + Request timeout. Queries that do not respond will be replaced by zeros. - Returns: - responses (:obj:`Union[ List[torch.FloatTensor], torch.FloatTensor] ` of shape :obj:`(batch_size, sequence_len, bittensor.__network_dim__)`, `required`): - Output encodings of inputs produced by remote endpoints. 
Non-responses are zeroes of input shape plus output dimension. + requires_grad (:type:`int`, default = dendrite.requires_grad, `optional`): + If true, the backward pass triggers passing gradients on the wire. - codes (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`): - dendrite call return ops. + Returns: + outputs (:obj:`List[ List[ torch.FloatTensor ] ]` of shape :obj:`num_synapses * ( num_endpoints * ( -1, -1, -1 ) )`, `required`): + List of outputs from synapses, each a list of size num_endpoints of tensors with relevant size. Non-responses are zeroes of relevant + synapse shape. + + codes (:obj:`List [ torch.LongTensor ]` of shape :obj:`[ num_endpoints ]`, `required`): + Return code per call per synapse. + + times (:obj:`List [ torch.FloatTensor ]` of shape :obj:`[ num_endpoints ]`, `required`): + Times per call per synapse. + """ + formatted_endpoints, formatted_inputs = self.format_text_inputs ( + endpoints = endpoints, + inputs = inputs + ) + outputs, codes, times = self._forward ( + endpoints = formatted_endpoints, + synapses = synapses, + inputs = formatted_inputs, + timeout = timeout, + requires_grad = requires_grad, + ) + # Return. + self.update_stats( formatted_endpoints, synapses, formatted_inputs, outputs, codes, times ) + return outputs, codes, times + + def text_causal_lm ( + self, + endpoints: Union [ torch.LongTensor, List [ torch.LongTensor ], List[ 'bittensor.Endpoint' ], 'bittensor.Endpoint' ], + inputs: Union [ str, List[ str ], List [ torch.LongTensor ], torch.LongTensor], + synapse: Optional[ 'bittensor.synapse.TextCausalLM' ] = synapse.TextCausalLM(), + timeout: Optional [ int ] = None, + requires_grad: Optional [ bool ] = None, + ) -> Tuple[Union[List[torch.FloatTensor], torch.FloatTensor], torch.LongTensor, torch.FloatTensor]: + r""" Forward text inputs to a list of neuron endpoints and returns logit encodings or timeout. - times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): - times per call. - """ - # Check types. - if not isinstance(endpoints, list) and not isinstance(endpoints, Endpoint): - raise ValueError('endpoints must be of type list or bittensor.Endpoint. Got {}'.format(type(endpoints))) - - if not isinstance(inputs, list) and not isinstance(inputs, torch.FloatTensor): - raise ValueError( - 'inputs must be of type list[torch.FloatTensor] or torch.FloatTensor. Got {}'.format(type(inputs))) - - # Format to list. - non_list_inputs = False - if not isinstance(inputs, list): - non_list_inputs = True - inputs = [inputs] - - # Format to list. - if not isinstance(endpoints, list): - endpoints = [endpoints] - - # Catch inputs != List and endpoints == List - elif non_list_inputs and isinstance(endpoints, list): - raise ValueError( - 'endpoints and inputs must be of same type. Got endpoints {} and inputs {} '.format(type(endpoints), - type(inputs[0]))) - - # Check length. - if len(inputs) < 1: - raise ValueError('inputs list must have at least one element. Got len {}'.format(len(inputs))) - if len(endpoints) < 1: - raise ValueError('endpoints list must have at least one item. 
Got len {}'.format(len(endpoints))) - if len(inputs) != len(endpoints): - error_msg = 'List of tensor inputs should have the same length as passed destination endpoints, got {} and {}'.format( - len(inputs), len(endpoints)) - raise ValueError(error_msg) + Args: + endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): + Endpoints to send inputs to. Endpoint can be one of the following types: + - a single endpoint tensor shape [250] + - a set of endpoint tensors shape [n, 250] + - a list of endpoints tensors each of shape [250] + - a single endpoint object. Inputs will be sent to this endpoint alone. + - a list of endpoint objects. All inputs will be sent to these endpoints. - # Check list types. - if not isinstance(inputs[0], torch.FloatTensor): - raise ValueError('inputs must be of type torch.FloatTensor. Got {}'.format(type(inputs[0]))) - if not isinstance(endpoints[0], Endpoint): - raise ValueError('endpoints must be of type bittensor.Endpoint. Got {}'.format(type(endpoints))) - # Check shape. - if len(inputs[0].shape) != 5: - error_msg = 'Image inputs should be rank 5 with semantic shape: [batch_size, sequence_len, channels, rows, cols], got {}'.format( - inputs[0].shape) - raise ValueError(error_msg) + inputs (:obj:`Union[str, List[str], List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`): + Tokenized sentences to send on the wire. Inputs can be one of the following types: + - a single string: the string will be tokenized using the bittensor tokenizer. + - a list of strings: the strings will be tokenized using the bittensor tokenizer. + - a tensor with shape [batch_size, sequence_len], assumed to be the output of bittensor tokenizer. + - a tensor with shape [n, batch_size, sequence_len], the operation will unbind the tensor and pass inputs to endpoints. + - a list of tensors of type long each representing a tokenized sentence to be sent to each endpoint. + If inputs are tensors they will be cast to int64 format before sending on the wire. - # Make calls. - responses, codes, times = self._forward( - endpoints=endpoints, - inputs=inputs, - modality=bittensor.proto.Modality.IMAGE, - timeout=timeout, - requires_grad=requires_grad - ) + synapse (:type:`'bittensor.synapse.TextCausalLM'`, default = bittensor.synapse.TextCausalLM(), `optional`): + Synapse axon function call which defaults to bittensor.synapse.TextCausalLM(). + + timeout (:type:`int`, default = dendrite.timeout `optional`): + Request timeout. Queries that do not respond will be replaced by zeros. - # Format to singletons. - if non_list_inputs: - responses = responses[0] + requires_grad (:type:`int`, default = dendrite.requires_grad, `optional`): + If true, the backward pass triggers passing gradients on the wire. + + Returns: + outputs (:obj:`List[ torch.FloatTensor ]` of shape :obj:`num_endpoints * (batch_size, sequence_len, bittensor.__vocab_size__ )`, `required`): + List of output logit encodings of inputs produced by each remote endpoints. Non-responses are zeroes of input shape plus output dimension. + The first dimension will match the number of endpoints queried. + codes (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`): + dendrite call return ops. + + times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): + times per call. 
+ """ + if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextCausalLM: + raise ValueError( "Passed synapse must have type: {} got {} instead".formate( bittensor.proto.Synapse.SynapseType.TextCausalLM, synapses.synapse_type ) ) + + # Format inputs. + formatted_endpoints, formatted_inputs = self.format_text_inputs ( + endpoints = endpoints, + inputs = inputs + ) + # Optionally convert synapses and set typing info. + synapses = [ synapse ] + # Make calls. + outputs, codes, times = self._forward ( + endpoints = formatted_endpoints, + synapses = synapses, + inputs = formatted_inputs, + timeout = timeout, + requires_grad = requires_grad, + ) # Return. - self.update_stats( endpoints, inputs, responses, codes, times ) - return responses, codes, times + self.update_stats( formatted_endpoints, synapses, formatted_inputs, outputs, codes, times ) + return outputs[0], codes[0], times[0] - def forward_tensor( + def text_causal_lm_next( self, - endpoints: Union[List['bittensor.Endpoint'], 'bittensor.Endpoint'], - inputs: List[torch.FloatTensor], - timeout: int = None, - requires_grad: bool = None + endpoints: Union[torch.LongTensor, List[torch.LongTensor], List['bittensor.Endpoint'], 'bittensor.Endpoint'], + inputs: Union[str, List[str], List[torch.LongTensor], torch.LongTensor], + synapse: Optional['bittensor.synapse.TextCausalLMNext'] = synapse.TextCausalLMNext(), + timeout: Optional[int] = None, + requires_grad: Optional[bool] = None, ) -> Tuple[Union[List[torch.FloatTensor], torch.FloatTensor], torch.LongTensor, torch.FloatTensor]: - r""" Forward tensor inputs to endpoints. + r""" Forward text inputs to a list of neuron endpoints and returns logit encodings or timeout. - Args: - endpoints (:obj:`Union[List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): - List or single of endpoints which match the length of inputs. Inputs are sent forward to these endpoints. + Args: + endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): + Endpoints to send inputs to. Endpoint can be one of the following types: + - a single endpoint tensor shape [250] + - a set of endpoint tensors shape [n, 250] + - a list of endpoints tensors each of shape [250] + - a single endpoint object. Inputs will be sent to this endpoint alone. + - a list of endpoint objects. All inputs will be sent to these endpoints. - inputs (:obj:`Union[List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`): - List or single tensors to send to corresponding endpoints. Tensors are of float type and - with shape [batch_size, sequence_len, bittensor.__network_dim__]. - timeout (int, default = dendrite.timeout `optional`): - Request timeout. + inputs (:obj:`Union[str, List[str], List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`): + Tokenized sentences to send on the wire. Inputs can be one of the following types: + - a single string: the string will be tokenized using the bittensor tokenizer. + - a list of strings: the strings will be tokenized using the bittensor tokenizer. + - a tensor with shape [batch_size, sequence_len], assumed to be the output of bittensor tokenizer. + - a tensor with shape [n, batch_size, sequence_len], the operation will unbind the tensor and pass inputs to endpoints. 
+ - a list of tensors of type long each representing a tokenized sentence to be sent to each endpoint. + If inputs are tensors they will be cast to int64 format before sending on the wire. - requires_grad (int, default = dendrite.requires_grad, `optional`): - If true, the backward pass triggers passing gradients on the wire. + synapse (:type:`'bittensor.synapse.TextCausalLMNext'`, default = bittensor.synapse.TextCausalLMNext(), `optional`): + Synapse axon function call which defaults to bittensor.synapse.TextCausalLMNext(). - Returns: - responses (:obj:`Union[ List[torch.FloatTensor], torch.FloatTensor] ` of shape :obj:`(batch_size, sequence_len, bittensor.__network_dim__)`, `required`): - Output encodings of inputs produced by remote endpoints. Non-responses are zeroes of input shape plus output dimension. + timeout (:type:`int`, default = dendrite.timeout `optional`): + Request timeout. Queries that do not respond will be replaced by zeros. - codes (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`): - dendrite call return ops. + requires_grad (:type:`int`, default = dendrite.requires_grad, `optional`): + If true, the backward pass triggers passing gradients on the wire. - times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): - times per call. - """ - # Check types. - if not isinstance(endpoints, list) and not isinstance(endpoints, Endpoint): - raise ValueError('endpoints must be of type list or bittensor.Endpoint. Got {}'.format(type(endpoints))) - - if not isinstance(inputs, list) and not isinstance(inputs, torch.FloatTensor): - raise ValueError( - 'inputs must be of type list[torch.FloatTensor] or torch.FloatTensor. Got {}'.format(type(inputs))) - - # Format to list. - non_list_inputs = False - if not isinstance(inputs, list): - non_list_inputs = True - inputs = [inputs] - - # Format to list. - if not isinstance(endpoints, list): - endpoints = [endpoints] - - # Catch inputs != List and endpoints == List - elif non_list_inputs and isinstance(endpoints, list): - raise ValueError( - 'endpoints and inputs must be of same type. Got endpoints {} and inputs {} '.format(type(endpoints), - type(inputs[0]))) - - # Check length. - if len(inputs) < 1: - raise ValueError('inputs list must have at least one element. Got len {}'.format(len(inputs))) - if len(endpoints) < 1: - raise ValueError('endpoints list must have at least one item. Got len {}'.format(len(endpoints))) - if len(inputs) != len(endpoints): - error_msg = 'List of tensor inputs should have the same length as passed destination endpoints, got {} and {}'.format( - len(inputs), len(endpoints)) - raise ValueError(error_msg) + Returns: + outputs (:obj:`List[ torch.FloatTensor ]` of shape :obj:`num_endpoints * ( >= batch_size * (2 * topk + 1) )`, `required`): + List of output topk phrases encodings of inputs produced by each remote endpoints. + Non-responses are zeroes of input shape plus output dimension. + The first dimension will match the number of endpoints queried. - # Check list types. - if not isinstance(inputs[0], torch.FloatTensor): - raise ValueError('inputs must be of type torch.FloatTensor. Got {}'.format(type(inputs[0]))) - if not isinstance(endpoints[0], Endpoint): - raise ValueError('endpoints must be of type bittensor.Endpoint. Got {}'.format(type(endpoints))) + codes (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`): + dendrite call return ops. - # Check shape. 
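A hedged usage sketch for the `text_causal_lm_next` call documented above; the metagraph `endpoint_objs` accessor and the number of peers queried are assumptions for illustration:

```python
import bittensor

wallet = bittensor.wallet()
dendrite = bittensor.dendrite( wallet = wallet )
graph = bittensor.metagraph()
graph.sync()

phrases, codes, times = dendrite.text_causal_lm_next(
    endpoints = graph.endpoint_objs[:5],                  # assumed accessor; query five peers
    inputs = "The quick brown fox",
    synapse = bittensor.synapse.TextCausalLMNext(),
)
# Each entry of phrases encodes one peer's top-k next-token phrases; codes and times are per call.
```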
- if len(inputs[0].shape) != 3: - error_msg = 'Tensor inputs should be rank 3 with semantic shape: [batch_size, sequence_len, bittensor.__network_dim__]' - raise ValueError(error_msg) - if inputs[0].shape[2] != bittensor.__network_dim__: - error_msg = 'Passed tensor must have last dimension {} got {}'.format(bittensor.__network_dim__, - inputs[0].shape[2]) - raise ValueError(error_msg) + times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): + times per call. + """ + if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextCausalLMNext: + raise ValueError(f"Passed synapse must have type: {bittensor.proto.Synapse.SynapseType.TextCausalLMNext} " + f"got {synapse.synapse_type} instead") - # Make calls. - responses, codes, times = self._forward( + # Format inputs. + formatted_endpoints, formatted_inputs = self.format_text_inputs( endpoints=endpoints, - inputs=inputs, - modality=bittensor.proto.Modality.TENSOR, + inputs=inputs + ) + # Optionally convert synapses and set typing info. + synapses = [synapse] + # Make calls. + outputs, codes, times = self._forward( + endpoints=formatted_endpoints, + synapses=synapses, + inputs=formatted_inputs, timeout=timeout, - requires_grad=requires_grad + requires_grad=requires_grad, ) - - # Format to singletons. - if non_list_inputs: - responses = responses[0] - # Return. - self.update_stats( endpoints, inputs, responses, codes, times ) - return responses, codes, times + self.update_stats(formatted_endpoints, synapses, formatted_inputs, outputs, codes, times) + return outputs[0], codes[0], times[0] - def forward_text( + def text_last_hidden_state( self, - endpoints: Union[ - torch.LongTensor, List[torch.LongTensor], List['bittensor.Endpoint'], 'bittensor.Endpoint'], + endpoints: Union[ torch.LongTensor, List[torch.LongTensor], List['bittensor.Endpoint'], 'bittensor.Endpoint' ], inputs: Union[str, List[str], List[torch.LongTensor], torch.LongTensor], + synapse: Optional[ 'bittensor.synapse.TextLastHiddenState' ] = synapse.TextLastHiddenState(), timeout: int = None, - requires_grad: bool = None + requires_grad: bool = None, ) -> Tuple[Union[List[torch.FloatTensor], torch.FloatTensor], torch.LongTensor, torch.FloatTensor]: - r""" Forward text inputs to a list of neuron endpoints and block until responses or timeout. + r""" Forward text inputs to a list of neuron endpoints and block until last hidden state responses or timeout. Args: endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): @@ -485,6 +536,9 @@ def forward_text( - a tensor with shape [n, batch_size, sequence_len], the operation will unbind the tensor and pass inputs to endpoints. If inputs are tensors they will be cast to int64 format before sending on the wire. + synapse (:type:`'bittensor.synapse.TextLastHiddenState'`, default = bittensor.synapse.TextLastHiddenState(), `optional`): + Synapse axon function call which defaults to bittensor.synapse.TextLastHiddenState(). + timeout (:type:`int`, default = dendrite.timeout `optional`): Request timeout. Queries that do not respond will be replaced by zeros. @@ -492,8 +546,8 @@ def forward_text( If true, the backward pass triggers passing gradients on the wire. Returns: - responses (:obj:`torch.FloatTensor` of shape :obj:`(n, batch_size, sequence_len, bittensor.__network_dim__)`, `required`): - Output encodings of inputs produced by remote endpoints. Non-responses are zeroes of input shape plus output dimension. 
+            outputs (:obj:`List [ torch.FloatTensor ]` of shape :obj:` num_endpoints * ( -1, sequence_len, bittensor.__network_dim__ )`, `required`):
+                List of output last hidden state encodings of inputs produced by remote endpoints. Non-responses are zeroes of input shape plus output dimension.
                 The first dimension will match the number of endpoints queried.

             codes (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`):
@@ -502,7 +556,57 @@ def forward_text(
             times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`):
                 times per call.
         """
+        if synapse.synapse_type != bittensor.proto.Synapse.SynapseType.TextLastHiddenState:
+            raise ValueError( "Passed synapse must have type:{} got:{} instead".format( bittensor.proto.Synapse.SynapseType.TextLastHiddenState, synapse.synapse_type ) )
+        # Format inputs.
+        formatted_endpoints, formatted_inputs = self.format_text_inputs (
+            endpoints = endpoints,
+            inputs = inputs
+        )
+        synapses = [ synapse ]
+        # Make calls.
+        outputs, codes, times = self._forward (
+            endpoints = formatted_endpoints,
+            synapses = synapses,
+            inputs = formatted_inputs,
+            timeout = timeout,
+            requires_grad = requires_grad,
+        )
+        # Return.
+        self.update_stats( formatted_endpoints, synapses, formatted_inputs, outputs, codes, times )
+        return outputs[0], codes[0], times[0]
+
+    def format_text_inputs (
+        self,
+        endpoints: Union[ torch.LongTensor, List[torch.LongTensor], List['bittensor.Endpoint'], 'bittensor.Endpoint' ],
+        inputs: Union[str, List[str], List[torch.LongTensor], torch.LongTensor],
+    ) -> Tuple[ List['bittensor.Endpoint'], List[torch.LongTensor] ]:
+        r""" Formats endpoint and inputs args to a common format.
+            Args:
+                endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`):
+                    Endpoints to send inputs to. Endpoint can be one of the following types:
+                        - a single endpoint tensor shape [250]
+                        - a set of endpoint tensors shape [n, 250]
+                        - a list of endpoints tensors each of shape [250]
+                        - a single endpoint object. Inputs will be sent to this endpoint alone.
+                        - a list of endpoint objects. All inputs will be sent to these endpoints.
+
+                inputs (:obj:`Union[str, List[str], List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`):
+                    Tokenized sentences to send on the wire. Inputs can be one of the following types:
+                        - a single string: the string will be tokenized using the bittensor tokenizer.
+                        - a list of strings: the strings will be tokenized using the bittensor tokenizer.
+                        - a tensor with shape [batch_size, sequence_len], assumed to be the output of bittensor tokenizer.
+                        - a tensor with shape [n, batch_size, sequence_len], the operation will unbind the tensor and pass inputs to endpoints.
+                    If inputs are tensors they will be cast to int64 format before sending on the wire.
+
+            Returns:
+                formatted_endpoints (:obj:`Union[torch.LongTensor, List[torch.LongTensor], List[bittensor.Endpoint], bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`):
+                    A list of endpoint objects. All inputs will be sent to these endpoints.
+
+                formatted_inputs (:obj:`Union[str, List[str], List[torch.LongTensor], torch.LongTensor]` of shape :obj:`(num_endpoints * [batch_size, sequence_len])`, `required`):
+                    A list of tensors of type long each representing a tokenized sentence to be sent to each endpoint.
+        """
         # To be filled. Inputs and endpoint must be list with the same number of elements.
formatted_inputs = [] formatted_endpoints = [] @@ -528,8 +632,7 @@ def cast_and_check_tensor_input(tensor_input) -> torch.LongTensor: raise ValueError(error_msg) return tensor_input - # ---- Endpoints is singular. - + # ---- Endpoints is singular. if isinstance(endpoints, bittensor.Endpoint): formatted_endpoints = [endpoints] @@ -581,7 +684,7 @@ def cast_and_check_tensor_input(tensor_input) -> torch.LongTensor: elif isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], str): # Encode to tensors. tokenizer = bittensor.tokenizer() - tokenized_sentences = tokenizer(inputs, padding=True, truncation=True)['input_ids'] + tokenized_sentences = tokenizer(inputs, padding = True, truncation=True)['input_ids'] tokenizer_tensor = cast_and_check_tensor_input(torch.tensor(tokenized_sentences, dtype=torch.int64)) formatted_inputs = [tokenizer_tensor for _ in formatted_endpoints] @@ -616,18 +719,7 @@ def cast_and_check_tensor_input(tensor_input) -> torch.LongTensor: len(inputs), len(endpoints)) raise ValueError(error_msg) - # Make calls. - responses, codes, times = self._forward( - endpoints=formatted_endpoints, - inputs=formatted_inputs, - modality=bittensor.proto.Modality.TEXT, - timeout=timeout, - requires_grad=requires_grad, - ) - - # Return. - self.update_stats( formatted_endpoints, formatted_inputs, responses, codes, times ) - return responses, codes, times + return formatted_endpoints, formatted_inputs def _init_stats(self): return SimpleNamespace( @@ -654,31 +746,43 @@ def _init_stats(self): qps_per_pubkey = {}, ) - def update_stats(self, endpoints, requests, responses, return_ops, query_times): + def update_stats( + self, + endpoints: List[ 'bittensor.Endpoint'], + synapses: List[ 'bittensor.proto.Synapse' ], + inputs: List[torch.Tensor], + outputs: List[ List[ torch.Tensor ] ], + codes: List [ List[ torch.LongTensor ] ], + times: List [ List[ torch.FloatTensor ] ] + ): r""" Update dendrite stat according to the response we get from peers. Updates were saved to self.stats. Args: endpoints (:obj:`List[bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): The set of endpoints that dendrite sent request to. - requests (List[torch.Tensor] of shape :obj:`[ num_endpoints ]`, `required`): - Requests from the call. + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. - responses (List[torch.FloatTensor] of shape :obj:`[ num_endpoints ]`, `required`): - Responses from the call. + inputs (:obj:`List[torch.Tensor]` of shape :obj:`(n_endpoints)`, `required`): + List of torch tensors to be sent to the associated endpoints. + + outputs (:obj:`List[ List[ torch.FloatTensor ] ]` of shape :obj:`num_synapses * ( num_endpoints * ( -1, -1, -1 ) )`, `required`): + List of outputs from synapses, each a list of size num_endpoints of tensors with relevant size. Non-responses are zeroes of relevant + synapse shape. - return_ops (:obj:`torch.LongTensor` of shape :obj:`[ num_endpoints ]`, `required`): - Dendrite call return ops. + codes (:obj:`List [ torch.LongTensor ]` of shape :obj:`[ num_endpoints ]`, `required`): + Return code per call per synapse. - query_times (:obj:`torch.FloatTensor` of shape :obj:`[ num_endpoints ]`, `required`): - Times per call. + times (:obj:`List [ torch.FloatTensor ]` of shape :obj:`[ num_endpoints ]`, `required`): + Times per call per synapse. 
""" self.stats.qps.event() self.stats.total_requests += 1 total_in_bytes_per_second = 0 - self.stats.avg_out_bytes_per_second.event( float(sys.getsizeof(requests)) ) - for (e_i, req_i, resp_i, code_i, time_i) in list(zip(endpoints, requests, responses, return_ops.tolist(), query_times.tolist())): - pubkey = e_i.hotkey - + self.stats.avg_out_bytes_per_second.event( float(sys.getsizeof(inputs)) ) + for (end_i, syn_i, inps_i, outs_i, codes_i, times_i) in list( zip ( endpoints, synapses, inputs, outputs, codes, times ) ): + pubkey = end_i.hotkey # First time for this pubkey we create a new entry. if pubkey not in self.stats.requests_per_pubkey: self.stats.requests_per_pubkey[pubkey] = 0 @@ -690,15 +794,16 @@ def update_stats(self, endpoints, requests, responses, return_ops, query_times): self.stats.qps_per_pubkey[pubkey] = stat_utils.EventsPerSecondRollingAverage( 0, 0.01 ) self.stats.requests_per_pubkey[pubkey] += 1 - self.stats.successes_per_pubkey[pubkey] += 1 if code_i == 1 else 0 - self.stats.query_times_per_pubkey[pubkey].event( float(time_i) ) - self.stats.avg_in_bytes_per_pubkey[pubkey].event( float(sys.getsizeof(resp_i)) ) - self.stats.avg_out_bytes_per_pubkey[pubkey].event( float(sys.getsizeof(req_i)) ) + self.stats.successes_per_pubkey[pubkey] += (codes_i == 1).sum().int() + self.stats.query_times_per_pubkey[pubkey].event( float( times_i.max() ) ) + self.stats.avg_in_bytes_per_pubkey[pubkey].event( float(sys.getsizeof( outs_i )) ) + self.stats.avg_out_bytes_per_pubkey[pubkey].event( float(sys.getsizeof( inps_i )) ) self.stats.qps_per_pubkey[pubkey].event() - total_in_bytes_per_second += sys.getsizeof(resp_i) if code_i == 1 else 0 + total_in_bytes_per_second += sys.getsizeof(outs_i) if (codes_i == 1).sum().int() == len( synapses ) else 0 try: - if bittensor.proto.ReturnCode.Name(code_i) in self.stats.codes_per_pubkey[pubkey].keys(): - self.stats.codes_per_pubkey[pubkey][bittensor.proto.ReturnCode.Name(code_i)] += 1 + for code_i_s in codes_i: + if bittensor.proto.ReturnCode.Name(code_i_s) in self.stats.codes_per_pubkey[pubkey].keys(): + self.stats.codes_per_pubkey[pubkey][bittensor.proto.ReturnCode.Name(code_i_s)] += 1 except: # Code may be faulty. pass @@ -755,3 +860,4 @@ def to_wandb( self ): except Exception as e: bittensor.logging.error( prefix='failed dendrite.to_wandb()', sufix = str(e)) return {} + diff --git a/bittensor/_dendrite/dendrite_mock.py b/bittensor/_dendrite/dendrite_mock.py index 021813f518..27e7a91574 100644 --- a/bittensor/_dendrite/dendrite_mock.py +++ b/bittensor/_dendrite/dendrite_mock.py @@ -37,14 +37,6 @@ # dummy tensor that triggers autograd DUMMY = torch.empty(0, requires_grad=True) -# Helper function for filling nill (zero) responses on failures. -def nill_response_for(inputs): - """ Get zero matrix with the same size as inputs - """ - if torch.numel(inputs) == 0: - return torch.tensor([]) - return torch.zeros((inputs.size(0), inputs.size(1), bittensor.__network_dim__), dtype=torch.float32) - class DendriteMock(torch.autograd.Function): @@ -556,7 +548,7 @@ def cast_and_check_tensor_input(tensor_input) -> torch.LongTensor: elif isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], str): # Encode to tensors. 
tokenizer = bittensor.tokenizer() - tokenized_sentences = tokenizer(inputs, padding=True, truncation=True)['input_ids'] + tokenized_sentences = tokenizer(inputs, truncation=True)['input_ids'] tokenizer_tensor = cast_and_check_tensor_input(torch.tensor(tokenized_sentences, dtype=torch.int64)) formatted_inputs = [tokenizer_tensor for _ in formatted_endpoints] diff --git a/bittensor/_logging/__init__.py b/bittensor/_logging/__init__.py index d1a5b9b241..f541b97ff8 100644 --- a/bittensor/_logging/__init__.py +++ b/bittensor/_logging/__init__.py @@ -213,7 +213,7 @@ def log_formatter(cls, record): """ extra = record['extra'] if 'rpc' in extra: - log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + extra['code_str'] + " | {extra[prefix]} | {extra[direction]} | {extra[arrow]} | {extra[uid_str]} | {extra[inputs]} | {extra[call_time]} | {extra[key_str]} | {extra[rpc_message]} \n" + log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + extra['code_str'] + " | {extra[prefix]} | {extra[direction]} | {extra[arrow]} | {extra[uid_str]} | {extra[inputs]} | {extra[call_time]} | {extra[key_str]} | {extra[rpc_message]} | {extra[synapse]} \n" return log_format elif 'receptor' in extra: log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + extra['action'] + " | uid:{extra[uid]} | {extra[ip_str]} | hotkey:{extra[hotkey]} | coldkey:{extra[coldkey]} \n" @@ -242,7 +242,20 @@ def log_save_formatter(cls, record): @classmethod - def rpc_log( cls, axon: bool, forward: bool, is_response: bool, code:int, call_time: float, pubkey: str, uid: int = None, inputs:list = None, outputs:list = None, message:str = ''): + def rpc_log( + cls, + axon: bool, + forward: bool, + is_response: bool, + code:int, + call_time: float, + pubkey: str, + uid: int = None, + inputs:list = None, + outputs:list = None, + message:str = '', + synapse:'bittensor.Synapse' = None + ): """ Debug logging for the communication between endpoints with axon/dendrite """ @@ -282,8 +295,24 @@ def rpc_log( cls, axon: bool, forward: bool, is_response: bool, code:int, call_t inputs = str(list(inputs)) if inputs != None else '[x]' inputs = inputs.center(15) + if synapse != None: + synapse = codes.code_to_synapse(synapse) + rpc_message = message if message != None else 'None' - logger.debug( 'rpc', rpc=True, prefix=prefix, direction=direction, arrow=arrow, call_time = call_time_str, uid_str=uid_str, key_str=key_str, code_str=code_str, inputs = inputs, rpc_message = rpc_message) + logger.debug( + 'rpc', + rpc=True, + prefix=prefix, + direction=direction, + arrow=arrow, + call_time = call_time_str, + uid_str=uid_str, + key_str=key_str, + code_str=code_str, + inputs = inputs, + rpc_message = rpc_message, + synapse = synapse + ) @classmethod diff --git a/bittensor/_neuron/__init__.py b/bittensor/_neuron/__init__.py index eaad26a974..aa56a0ce6f 100644 --- a/bittensor/_neuron/__init__.py +++ b/bittensor/_neuron/__init__.py @@ -20,18 +20,12 @@ version_split = __version__.split(".") __version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2])) -from .text import template_miner, template_server, advanced_server, core_validator, multitron_server - - -__all_neurons__ = { 'text_template_miner': template_miner.neuron, +from .text import core_validator, core_server +__all_neurons__ = { 'text_core_validator': core_validator.neuron, - 'text_template_server':template_server.neuron, - 'text_advanced_server':advanced_server.neuron, - 'multitron_server': multitron_server.neuron} + 'core_server': core_server.neuron} -__text_neurons__ = { 
'template_miner': template_miner.neuron, +__text_neurons__ = { 'core_validator': core_validator.neuron, - 'template_server':template_server.neuron, - 'advanced_server':advanced_server.neuron, - 'multitron_server': multitron_server, + 'core_server': core_server.neuron } diff --git a/bittensor/_neuron/text/advanced_server/__init__.py b/bittensor/_neuron/text/advanced_server/__init__.py deleted file mode 100644 index f018d8b0e0..0000000000 --- a/bittensor/_neuron/text/advanced_server/__init__.py +++ /dev/null @@ -1,112 +0,0 @@ -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -""" Advanced server neurons - -Example: - $ import neurons - $ neurons.text.advanced_server().run() - -""" - -import bittensor -import threading -import os - -from .nucleus_impl import server -from .run import serve - -class neuron: - r""" - Creates a bittensor neuron that specializes in serving the bittensor network. The advanced server - trains itself while accepting requests from the bittensor network. This is done by accumulating - gradients over the wire and applying them in a single step. Blacklist features are enabled in - advanced servers to determine who can apply gradients. 
- - Args: - config (:obj:`bittensor.Config`, `optional`): - bittensor.server.config() - subtensor (:obj:bittensor.subtensor , `optional`): - bittensor subtensor connection - dataset (:obj:bittensor.dataset , `optional`): - bittensor dataset - wallet (:obj:bittensor.wallet, `optional`): - bittensor wallet object - axon (:obj:bittensor.axon, `optional`): - bittensor axon object - metagraph (:obj:bittensor.metagraph, `optional`): - bittensor metagraph object - - Examples:: - >>> subtensor = bittensor.subtensor(network='nakamoto') - >>> server = bittensor.neuron.text.advanced_server.neuron(subtensor=subtensor) - >>> server.run() - """ - def __init__( - self, - config: 'bittensor.config' = None, - subtensor: 'bittensor.subtensor' = None, - dataset: 'bittensor.dataset' = None, - wallet: 'bittensor.wallet' = None, - axon: 'bittensor.axon' = None, - metagraph: 'bittensor.metagraph' = None, - ): - if config == None: config = server.config() - config = config; - self.check_config( config ) - bittensor.logging ( - config = config, - logging_dir = config.neuron.full_path, - ) - - self.model = server( config = config ) - self.config = config - - self.subtensor = subtensor - self.dataset = dataset - self.wallet = wallet - self.axon = axon - self.metagraph = metagraph - - def run(self): - serve( - self.config, - self.model, - subtensor=self.subtensor, - wallet = self.wallet, - metagraph=self.metagraph, - axon= self.axon) - - @classmethod - def config(cls): - return server.config() - - @classmethod - def check_config( cls, config: 'bittensor.Config' ): - r""" Checks/validates the config namespace object. - """ - bittensor.logging.check_config( config ) - bittensor.wallet.check_config( config ) - bittensor.subtensor.check_config( config ) - bittensor.metagraph.check_config( config ) - bittensor.dataset.check_config( config ) - bittensor.axon.check_config( config ) - bittensor.wandb.check_config( config ) - full_path = os.path.expanduser('{}/{}/{}/{}'.format( config.logging.logging_dir, config.wallet.get('name', bittensor.defaults.wallet.name), config.wallet.get('hotkey', bittensor.defaults.wallet.hotkey), config.neuron.name )) - config.neuron.full_path = os.path.expanduser(full_path) - if not os.path.exists(config.neuron.full_path): - os.makedirs(config.neuron.full_path) diff --git a/bittensor/_neuron/text/advanced_server/main.py b/bittensor/_neuron/text/advanced_server/main.py deleted file mode 100644 index d3db580ed7..0000000000 --- a/bittensor/_neuron/text/advanced_server/main.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" The bittensor advanced server class - -Example: - $ python3 miners/text/advanced_server.py --logging.debug - -""" - -import bittensor -if __name__ == "__main__": - bittensor.utils.version_checking() - template = bittensor.neurons.advanced_server.neuron().run() \ No newline at end of file diff --git a/bittensor/_neuron/text/advanced_server/nucleus_impl.py b/bittensor/_neuron/text/advanced_server/nucleus_impl.py deleted file mode 100644 index 9b0c99ddf6..0000000000 --- a/bittensor/_neuron/text/advanced_server/nucleus_impl.py +++ /dev/null @@ -1,305 +0,0 @@ -import argparse -import bittensor -import torch -import torch.nn.functional as F -from torch import nn -from transformers import AutoModel,AutoTokenizer,AutoConfig -from torch.nn.utils.rnn import pad_sequence -from loguru import logger; logger = logger.opt(colors=True) -from typing import Tuple, Optional - -class server(torch.nn.Module): - def __init__(self, - config: 'bittensor.config' = None, - pretrained: bool = None, - model_name: str = None, - padding: bool =None, - interpolate: bool =None, - inter_degree: str = None, - model = None, - tokenizer = None, - mapping_function = None, - token_remap = None, - checking= None): - r"""" Creates a server that serves up a pretrained miner on the bittensor network - Args: - config (:obj:`bittensor.Config`, `required`): - bittensor.server.config() - pretrained (:obj:bool , `optional`): - if the model should pretrained or not - model_name (:obj:string , `optional`): - name of the pretrained model from huggingface to use - padding (:obj:bool, `optional`): - If the server should pad out to match the hidden units that the bittensor network is using - If set to False, it will instead create a mapping layer to do the same thing. - interpolate (:obj:bool, `optional`): - If the server should interpolate between sequence length differences. 
- If set to false, there should be a mapping function that takes care of the differnces - inter_degree (:obj:str, `optional`): - The Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area) - model (:obj:torch.module, `optional`): - Overrides the huggingface pretrained model with your own pretrained model - tokenizer (:obj:huggingface.tokenizer, `optional`): - Overrides the huggingface tokenizer with your tokenizer - mapping_function (:obj:Callable, `optional`): - Custom mapping function that maps between sequence length differences between tokenizers - token_remap (:obj:Callable, `optional`): - Custom function that maps between tokenizers (defaults to self.remapping_token) - """ - super(server, self).__init__() - if config == None: config = server.config() - self.config = config;print(config) - - #setting up pretrained model - self.model_name = model_name if model_name != None else config.neuron.model_name - self.pretrained = pretrained if pretrained != None else config.neuron.pretrained - if self.pretrained == True: - self.pre_model = model if model != None else AutoModel.from_pretrained(self.model_name) - self.tokenizer = tokenizer if tokenizer != None else AutoTokenizer.from_pretrained(self.model_name) - elif self.pretrained == False: - model_config = AutoConfig.from_pretrained(self.model_name) - model_config.vocab_size= bittensor.__vocab_size__ - self.pre_model = model if model != None else AutoModel.from_config(model_config) - self.tokenizer = bittensor.tokenizer() - - #parameters of the models - self.final_dim = bittensor.__network_dim__ - self.pre_dimension = self.pre_model.config.hidden_size - self.device = config.neuron.device - self.padding = padding if padding != None else config.neuron.padding - self.interpolate = interpolate if interpolate != None else config.neuron.interpolate - self.inter_degree = inter_degree if inter_degree != None else config.neuron.inter_degree - self.checking = checking if checking != None else config.neuron.checking - self.mapping_function= mapping_function - self.token_remap = token_remap if token_remap != None else self.remapping_token - - if self.padding == False: - self.mapping = torch.nn.Linear( self.pre_dimension, self.final_dim) - - self.decoder = torch.nn.Linear( self.final_dim, bittensor.__vocab_size__ , bias=False) - self.loss_fct = torch.nn.CrossEntropyLoss() - - self.outputs_cache = None - self.gradients_cache = None - - #checking if the parameters of the server makes sense - if self.checking and pretrained == True: - self.check() - - # -- keeps track of gradients applied - self.backward_gradients = 0 - self.set_fine_tuning_params() - - def set_fine_tuning_params(self) -> Tuple[bool, str]: - r''' Set to tune only the parameter of the last layer - Returns: - reached_last_layer (:type:`bool`): - If we have set partial of the model to requires grad. - - last_layer_name (:type:`string`): - The name of the last layer that user specified or we found. - None if the user did not specify and we couldnt find it. - ''' - def find_last_layer(model: torch.nn.Module) -> Optional[str]: - r''' Recursively find the last layer in a nn.ModuleList - Args: - model (:obj:`torch.module`): - The model (or sub-model) to fine the last layer from. - Returns: - name (:type:`str`): - The name (or sub-name) of the last layer. 
- None if not found - ''' - reverted_child_list = [(name, child) for name, child in model.named_children()] - reverted_child_list.reverse() - - for name, child in reverted_child_list: - if isinstance(child, nn.ModuleList): - if self.config.neuron.finetune.num_layers > len(child): - logger.warning(f'Number of finetune layers was set higher then the layers avaliable {len(child)}') - return None - return (name + '.' +str(len(child) - self.config.neuron.finetune.num_layers)) - - for name, child in reverted_child_list: - name_ = find_last_layer(child) - if name_ != None: - return (name+'.'+ name_) - - return None - - if self.config.neuron.finetune.layer_name == None: - last_layer_name = find_last_layer(self.pre_model) - else: - last_layer_name = self.config.neuron.finetune.layer_name - - reached_last_layer = False - - # set the non-last layer parameters not to require grads - if (self.config.neuron.finetune.all) or (last_layer_name == None): - return False, last_layer_name - - logger.success(f'Set to finetune layer {last_layer_name} and onwards') - - for name, param in self.pre_model.named_parameters(): - if last_layer_name in name or reached_last_layer == True: - param.requires_grad = True - reached_last_layer = True - else: - param.requires_grad = False - - if reached_last_layer == False: - if self.config.neuron.finetune.all: - logger.warning('Set to finetune the whole model, this will significantly increase the memory usage.') - else: - logger.warning(f'Cannot identify the last layer of the model with name {last_layer_name}, setting to finetune on all of the parameters.') - - return reached_last_layer, last_layer_name - - def forward(self, inputs,tokenizer=None): - """ - Forward pass through the whole server model. Returns the loss and decoded predictions. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer (:obj:'huggingface.tokenizer', optional): - The tokenizer which was used to tokenize the inputs - Returns: - loss (:obj:`torch.FloatTensor`): - MLM loss from the inputs - decoded_targets (:obj:`torch.FloatTensor`): - Decoded predictions of the next token in the sentence. - - """ - decoded_targets = self.decoder(self.encode_forward(inputs,tokenizer)) - - shift_logits = decoded_targets[..., :-1, :].contiguous() - shift_labels = inputs[..., 1:].contiguous() - loss = self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - - return loss, decoded_targets - - def encode_forward(self,inputs,tokenizer=None): - r""" Forward pass through the pretrained model and possible mappings between hidden units. - The response tensor should be the hidden units computed using the local context and with shape: [batch_size, sequence_len, __network_dim__]. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer ( huggingface.tokenizer, `optional`): - The tokenizer which was used to tokenize the inputs - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - sen_len = inputs.size() - inputs = self.token_remap(inputs,tokenizer).to(self.device) - pre_hidden = self.pre_model(inputs).last_hidden_state - - if self.interpolate: - down= F.interpolate(pre_hidden.unsqueeze(1),size=[sen_len[1],pre_hidden.size()[2]],mode=self.inter_degree).squeeze(1) - elif self.mapping_function: - down = self.mapping_function(pre_hidden) - else: - raise Exception('interpolation off but no mapping function found. 
Please attach a mapping function') - - if self.padding: - padding_l = (self.final_dim-self.pre_dimension)//2 - padding_r = (self.final_dim-self.pre_dimension) - padding_l - encoded_hidden = F.pad(down, (padding_l, padding_r), "constant", 0) - else: - encoded_hidden = self.mapping(down) - return encoded_hidden - - def remapping_token(self,input, old_tokenizer=None): - r""" Default remapping of tokenizers; decodes the message and then remaps the message using a new tokenizer - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - old_tokenizer ( huggingface.tokenizer, `required`): - The tokenizer which was used to tokenize the input (defaults to bittensor tokenizer if none is given) - """ - if old_tokenizer == None: - old_tokenizer = bittensor.tokenizer() - new_data = [] - for i in range(input.shape[0]): - decoded = old_tokenizer.decode(input[i]) - hugging = self.tokenizer(decoded) - new_data += [torch.LongTensor(hugging.input_ids)] - new_data = pad_sequence(new_data,batch_first=True) - return new_data - - def check(self): - r"""Checks the server settings - """ - assert self.tokenizer.name_or_path == self.pre_model.name_or_path, 'incorrect model ({}) and tokenizer ({})'.format(self.pre_model.name_or_path,self.tokenizer.name_or_path) - if self.interpolate == False: - assert self.mapping_function != None, 'Incorrect Settings; needs atleast one mapping function for sequence length changes' - - def save(self, path): - try: - state_dict = { - 'model': self.pretrained, - 'pretrained_model': self.pre_model.state_dict(), - 'decoder': self.decoder.state_dict() - } - if self.padding == False: - state_dict['mapping'] = self.mapping.state_dict() - torch.save( state_dict, "{}/model.torch".format( path) ) - bittensor.logging.success(prefix='Saved model', sufix='{}/model.torch'.format( path ) ) - except Exception as e: - logger.exception('Failed to save model with error:{}', e) - - def load(self, path): - try: - state_dict= torch.load("{}/model.torch".format( path )) - if self.pretrained == state_dict['model']: - self.pre_model.load_state_dict(state_dict['pretrained_model'], strict=False) - self.decoder.load_state_dict(state_dict['decoder']) - if self.padding == False: - self.mapping.load_state_dict(state_dict['mapping']) - - bittensor.logging.success( prefix = 'Reloaded model', sufix = '{}/model.torch'.format( path )) - - - except Exception as e: - logger.warning('No saved model found with error: {}', e) - - @staticmethod - def config (): - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='If set, defaults are overridden by passed file.') - parser.add_argument('--neuron.learning_rate', type=float, help='Training initial learning rate.', default=0.01) - parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8) - parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0) - parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) - parser.add_argument('--neuron.model_name', type=str, help='pretrained model from hugging face',default='gpt2') - parser.add_argument('--neuron.pretrained', action='store_false', help='if the model should be pretrained',default=True) - parser.add_argument('--neuron.padding', action='store_false', help='To pad out final dimensions',default=True) - 
parser.add_argument('--neuron.interpolate', action='store_false', help='To interpolate between sentence length',default=True) - parser.add_argument('--neuron.inter_degree', type=str, help='Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area)', default='nearest') - parser.add_argument('--neuron.name', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ', default='advanced_server') - parser.add_argument('--neuron.checking', action='store_false', help='To check if server settings are correct',default=True) - parser.add_argument('--neuron.restart', action='store_true', help='If set, train the neuron from the beginning', default=False) - parser.add_argument('--neuron.blacklist.stake.forward', type=float, help='Amount of stake (tao) in order not to get blacklisted for forward requests', default=10) - parser.add_argument('--neuron.blacklist.stake.backward', type=float, help='Amount of stake (tao) in order not to get blacklisted for backward requests', default=100) - parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, allow non-registered peers''', default=False) - parser.add_argument('--neuron.metagraph_sync', type=float, help='how often to sync the metagraph', default=100000) - parser.add_argument('--neuron.blocks_per_set_weights', type=float, help='how often to sync set weights', default=100) - parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch', default=2) - parser.add_argument('--neuron.blacklist.time', type=int, help='how often a peer can query you (seconds) ', default=0) - parser.add_argument('--neuron.finetune.all', action='store_true', help='Finetune your whole model instead of only on the last (few) layers', default=False) - parser.add_argument('--neuron.finetune.num_layers', type=int, help='The number of layers to finetune on your model.', default=1) - parser.add_argument('--neuron.finetune.layer_name', type=str, help='Specify since which layer to finetune. eg. encoder.layer.11', default=None) - - bittensor.wallet.add_args( parser ) - bittensor.axon.add_args( parser ) - bittensor.subtensor.add_args( parser ) - bittensor.logging.add_args( parser ) - bittensor.wandb.add_args(parser) - bittensor.prioritythreadpool.add_args( parser ) - bittensor.dataset.add_args( parser ) - bittensor.metagraph.add_args( parser ) - return bittensor.config( parser ) - diff --git a/bittensor/_neuron/text/advanced_server/run.py b/bittensor/_neuron/text/advanced_server/run.py deleted file mode 100644 index 79847afa9e..0000000000 --- a/bittensor/_neuron/text/advanced_server/run.py +++ /dev/null @@ -1,380 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" Advanced server neuron. - -Example: - $ python miners/text/advanced_server/main.py - -""" -from time import time -import bittensor -import torch -import wandb -import pandas -import datetime -import traceback -import sys -import os - -from loguru import logger; logger = logger.opt(colors=True) -from torch.nn.utils import clip_grad_norm_ -from datetime import datetime,timedelta -from threading import Lock -os.environ['TOKENIZERS_PARALLELISM'] = 'false' - -def serve( - config, - gp_server= None, - subtensor = None, - wallet = None, - metagraph = None, - axon = None -): - config.to_defaults() - - # Create Subtensor connection - subtensor = bittensor.subtensor(config = config) if subtensor == None else subtensor - - # Load/Create our bittensor wallet. - if wallet == None: - wallet = bittensor.wallet( config = config ).create().register(subtensor=subtensor) - else: - wallet.register(subtensor=subtensor) - - # Load/Sync/Save our metagraph. - if metagraph == None: - metagraph = bittensor.metagraph ( - subtensor = subtensor - ).load().sync().save() - else: - metagraph.load().sync().save() - - # Instantiate the model we are going to serve on the network. - # Creating a threading lock for updates to the model - mutex = Lock() - gp_server = gp_server.to(gp_server.device) - - # Create our optimizer. - optimizer = torch.optim.SGD( - [ {"params": gp_server.parameters()} ], - lr = config.neuron.learning_rate, - momentum = config.neuron.momentum, - ) - bittensor.tokenizer() - timecheck = {} - - n_topk_peer_weights = subtensor.min_allowed_weights - # Define our forward function. - def forward_text ( inputs_x ): - r""" Forward function that is called when the axon recieves a forward request from other peers - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - return gp_server.encode_forward( inputs_x.to(gp_server.device) ) - - # Define our backward function. - def backward_text (inputs_x, grads_dy ): - r"""Backwards function that is called when the axon recieves a backwards request from other peers. - Updates the server parameters with gradients through the chain. - - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs from previous forward call. - grads_dy ( :obj:`torch.Tensor`, `required`): - torch grads of forward output. - - """ - # -- normalized grads -- - grads_dy = grads_dy/(grads_dy.sum() + 0.00001) - - with mutex: - outputs_y = gp_server.encode_forward( inputs_x.to(gp_server.device) ) - with torch.autograd.set_detect_anomaly(True): - torch.autograd.backward ( - tensors = [ outputs_y ], - grad_tensors = [ grads_dy.to(gp_server.device) ], - retain_graph=True - ) - logger.info('Backwards axon gradient applied') - - gp_server.backward_gradients += inputs_x.size(0) - - def priority(pubkey:str, request_type:bittensor.proto.RequestType, inputs_x) -> float: - r"""Calculates the priority on requests based on stake and size of input - - Args: - pubkey ( str, `required`): - The public key of the caller. - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. 
- request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - try: - uid = metagraph.hotkeys.index(pubkey) - priority = metagraph.S[uid].item()/ sys.getsizeof(inputs_x) - - except: - # zero priority for those who are not registered. - priority = 0 - - return priority - - def blacklist(pubkey:str, request_type:bittensor.proto.RequestType) -> bool: - r"""Axon security blacklisting, used to blacklist message from low stake members - Args: - pubkey ( str, `required`): - The public key of the caller. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - - def registration_check(): - # If we allow non-registered requests return False = not blacklisted. - is_registered = pubkey in metagraph.hotkeys - if not is_registered: - if config.neuron.blacklist_allow_non_registered: - return False - raise Exception('Registration blacklist') - - # Check for stake - def stake_check() -> bool: - - # Check stake. - uid = metagraph.hotkeys.index(pubkey) - if request_type == bittensor.proto.RequestType.FORWARD: - if metagraph.S[uid].item() < config.neuron.blacklist.stake.forward: - raise Exception('Stake blacklist') - - return False - - elif request_type == bittensor.proto.RequestType.BACKWARD: - if metagraph.S[uid].item() < config.neuron.blacklist.stake.backward: - raise Exception('Stake blacklist') - - return False - - def validator_check(): - - uid = metagraph.hotkeys.index(pubkey) - if (metagraph.W[uid] >0).sum() >= n_topk_peer_weights: - return False - raise Exception('Validator blacklist') - - - # Check for time - def time_check(): - current_time = datetime.now() - if pubkey in timecheck.keys(): - prev_time = timecheck[pubkey] - if current_time - prev_time >= timedelta(seconds=config.neuron.blacklist.time): - timecheck[pubkey] = current_time - return False - else: - timecheck[pubkey] = current_time - raise Exception('Time blacklist') - else: - timecheck[pubkey] = current_time - return False - - # Blacklist checks - try: - registration_check() - - stake_check() - - time_check() - - validator_check() - - return False - - #blacklisted - except Exception as e: - return True - - if axon == None: - # Create our axon server - axon = bittensor.axon ( - config = config, - wallet = wallet, - forward_text = forward_text, - backward_text = backward_text, - blacklist = blacklist, - priority = priority - ) - - # Training Data - dataset = bittensor.dataset(config=config) - - # load our old model - if not config.neuron.restart : - gp_server.load(config.neuron.full_path) - - if config.wandb.api_key != 'default': - # --- Init Wandb. - bittensor.wandb( - config = config, - cold_pubkey = wallet.coldkeypub.ss58_address, - hot_pubkey = wallet.hotkey.ss58_address, - root_dir = config.neuron.full_path - ) - - nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) - - # --- last sync block - last_sync_block = subtensor.get_current_block() - last_set_block = last_sync_block - - # -- Main Training loop -- - try: - # -- download files from the mountain - data = next(dataset) - - # --- creating our chain weights - # --- query the chain for the most current number of peers on the network - chain_weights = torch.zeros(subtensor.n) - uid = nn.uid - chain_weights[uid] = 1 - - # -- serve axon to the network. 
- axon.start().serve(subtensor = subtensor) - - while True: - - # --- Check registration and optionally re-register - nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) - if not wallet.is_registered( subtensor = subtensor ): - wallet.register( subtensor = subtensor ) - axon.serve( subtensor = subtensor ) # Re-serve the erased axon data. - nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) - - - # --- Run - current_block = subtensor.get_current_block() - end_block = current_block + config.neuron.blocks_per_epoch - interation = 0 - - # --- Training step. - while end_block >= current_block: - if current_block != subtensor.get_current_block(): - loss, _ = gp_server( next( dataset ).to(gp_server.device) ) - if interation > 0 : - losses += loss - else: - losses = loss - interation += 1 - current_block = subtensor.get_current_block() - - - #Custom learning rate - if gp_server.backward_gradients > 0: - optimizer.param_groups[0]['lr'] = 1/(gp_server.backward_gradients) - else: - optimizer.param_groups[0]['lr'] = 0.1 - - # --- Update parameters - if interation != 0 or gp_server.backward_gradients != 0: - with mutex: - logger.info('Backpropagation Started') - if interation != 0: - losses.backward() - clip_grad_norm_(gp_server.parameters(), 1.0) - - optimizer.step() - optimizer.zero_grad() - logger.info('Backpropagation Successful: Model updated') - - nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) - - gp_server.backward_gradients = 0 - # --- logging data - wandb_data = { - 'block': end_block, - 'loss': losses.cpu().item()/interation, - 'stake': nn.stake, - 'rank': nn.rank, - 'incentive': nn.incentive, - 'trust': nn.trust, - 'consensus': nn.consensus, - 'incentive': nn.incentive, - 'dividends': nn.dividends, - 'emission': nn.emission, - } - bittensor.__console__.print('[green]Current Status:[/green]', wandb_data) - - # Add additional wandb data for axon, metagraph etc. - if config.wandb.api_key != 'default': - if uid in metagraph.W: - df = pandas.concat( [ - bittensor.utils.indexed_values_to_dataframe( prefix = 'w_i_{}'.format(nn.uid), index = metagraph.uids, values = metagraph.W[:, uid] ), - bittensor.utils.indexed_values_to_dataframe( prefix = 's_i'.format(nn.uid), index = metagraph.uids, values = metagraph.S ), - axon.to_dataframe( metagraph = metagraph ), - ], axis = 1) - df['uid'] = df.index - stats_data_table = wandb.Table( dataframe = df ) - wandb_info_axon = axon.to_wandb() - wandb.log( { **wandb_data, **wandb_info_axon }, step = current_block ) - wandb.log( { 'stats': stats_data_table }, step = current_block ) - wandb.log( { 'axon_query_times': wandb.plot.scatter( stats_data_table, "uid", "axon_query_time", title="Axon Query time by UID") } ) - wandb.log( { 'in_weights': wandb.plot.scatter( stats_data_table, "uid", 'w_i_{}'.format(nn.uid), title="Inward weights by UID") } ) - wandb.log( { 'stake': wandb.plot.scatter( stats_data_table, "uid", 's_i', title="Stake by UID") } ) - - # Save the model - gp_server.save(config.neuron.full_path) - - if current_block - last_set_block > config.neuron.blocks_per_set_weights: - - # --- Setting weights - try: - last_set_block = current_block - # Set self weights to maintain activity. - did_set = subtensor.set_weights( - uids=torch.arange(0,subtensor.n), - weights = chain_weights, - wait_for_inclusion = False, - wallet = wallet, - ) - - if did_set: - logger.success('Successfully set weights on the chain') - else: - logger.error('Failed to set weights on chain. 
(Timeout)') - except Exception as e: - logger.error('Failure setting weights on chain with error: {}', e) - - - if current_block - last_sync_block > config.neuron.metagraph_sync: - metagraph.sync() - last_sync_block = current_block - - - except KeyboardInterrupt: - # --- User ended session ---- - axon.stop() - dataset.close() - - except Exception as e: - # --- Unknown error ---- - logger.exception('Unknown exception: {} with traceback {}', e, traceback.format_exc()) - diff --git a/bittensor/_neuron/text/template_server/__init__.py b/bittensor/_neuron/text/core_server/__init__.py similarity index 68% rename from bittensor/_neuron/text/template_server/__init__.py rename to bittensor/_neuron/text/core_server/__init__.py index 1a6507eea7..2fefb88748 100644 --- a/bittensor/_neuron/text/template_server/__init__.py +++ b/bittensor/_neuron/text/core_server/__init__.py @@ -19,7 +19,7 @@ Example: $ import neurons - $ neurons.text.template_server.neuron().run() + $ neurons.text.core_server.neuron().run() """ import bittensor @@ -45,10 +45,20 @@ class neuron: bittensor axon object metagraph (:obj:bittensor.metagraph, `optional`): bittensor metagraph object + lasthidden (:obj:bool, `optional`): + lasthidden synapse control + causallm (:obj:bool, `optional`): + causallm synapse control + causallmnext (:obj:bool, `optional`): + causallmnext synapse control + seq2seq (:obj:bittensor.metagraph, `optional`): + seq2seq synapse control + synapse_list (:obj:list of int, `optional`): + Examples:: >>> subtensor = bittensor.subtensor(network='nakamoto') - >>> server = bittensor.neuron.text.template_server.neuron(subtensor=subtensor) + >>> server = bittensor.neuron.text.core_server.neuron(subtensor=subtensor) >>> server.run() """ def __init__( @@ -58,10 +68,39 @@ def __init__( wallet: 'bittensor.wallet' = None, axon: 'bittensor.axon' = None, metagraph: 'bittensor.metagraph' = None, + lasthidden = None, + causallm = None, + causallmnext = None, + seq2seq = None, + synapse_list = None, ): if config == None: config = server.config() config = config; + + if synapse_list != None: + config.neuron.lasthidden = False + config.neuron.causallm = False + config.neuron.causallmnext = False + config.neuron.seq2seq = False + + if bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE in synapse_list: + config.neuron.lasthidden = True + + if bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM in synapse_list: + config.neuron.causallm = True + + if bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT in synapse_list: + config.neuron.causallmnext = True + + if bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ in synapse_list: + config.neuron.seq2seq = True + + config.neuron.lasthidden = lasthidden if lasthidden != None else config.neuron.lasthidden + config.neuron.causallm = causallm if causallm != None else config.neuron.causallm + config.neuron.causallmnext = causallmnext if causallmnext is not None else config.neuron.causallmnext + config.neuron.seq2seq = seq2seq if seq2seq != None else config.neuron.seq2seq + self.check_config( config ) bittensor.logging ( config = config, diff --git a/bittensor/_neuron/text/template_server/main.py b/bittensor/_neuron/text/core_server/main.py similarity index 56% rename from bittensor/_neuron/text/template_server/main.py rename to bittensor/_neuron/text/core_server/main.py index f0775932ba..0b65b25578 100644 --- a/bittensor/_neuron/text/template_server/main.py +++ b/bittensor/_neuron/text/core_server/main.py @@ -1,4 +1,4 @@ import bittensor if __name__ == "__main__": 
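A hedged sketch of how the synapse toggles introduced in the core_server constructor above might be used; the flag and enum names come from this diff, while everything else (wallet, chain connection, config) is assumed to be set up elsewhere.

    import bittensor

    # Serve only the last-hidden-state and CausalLMNext synapses, disabling the rest explicitly.
    server = bittensor.neurons.text.core_server.neuron(
        lasthidden = True,
        causallm = False,
        causallmnext = True,
        seq2seq = False,
    )

    # Equivalently, select synapses by proto enum value.
    server = bittensor.neurons.text.core_server.neuron( synapse_list = [
        bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE,
        bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT,
    ])
    server.run()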
bittensor.utils.version_checking() - template = bittensor.neurons.template_server.neuron().run() + template = bittensor.neurons.core_server.neuron().run() diff --git a/bittensor/_neuron/text/core_server/nucleus_impl.py b/bittensor/_neuron/text/core_server/nucleus_impl.py new file mode 100644 index 0000000000..661e0cf38e --- /dev/null +++ b/bittensor/_neuron/text/core_server/nucleus_impl.py @@ -0,0 +1,567 @@ +import argparse +import math +import bittensor +import torch +from torch import nn +import torch.nn.functional as F +from types import SimpleNamespace +from typing import Tuple, Optional + +from transformers import AutoModel,AutoTokenizer,AutoConfig, AutoModelForCausalLM +from torch.nn.utils.rnn import pad_sequence +from bittensor.utils.tokenizer_utils import prep_tokenizer, get_translation_map, translate_logits_to_probs_std, \ + translate_special_token_text, pad_offsets, topk_token_phrases, compact_topk_token_phrases + +from loguru import logger; logger = logger.opt(colors=True) + +class server(torch.nn.Module): + def __init__(self, + config: 'bittensor.config' = None, + pretrained: bool = None, + model_name: str = None, + padding: bool =None, + interpolate: bool =None, + inter_degree: str = None, + model = None, + tokenizer = None, + mapping_function = None, + token_remap = None, + checking= None): + r"""" Creates a server that serves up a pretrained miner on the bittensor network + Args: + config (:obj:`bittensor.Config`, `required`): + bittensor.server.config() + pretrained (:obj:bool , `optional`): + if the model should pretrained or not + model_name (:obj:string , `optional`): + name of the pretrained model from huggingface to use + padding (:obj:bool, `optional`): + If the server should pad out to match the hidden units that the bittensor network is using + If set to False, it will instead create a mapping layer to do the same thing. + interpolate (:obj:bool, `optional`): + If the server should interpolate between sequence length differences. 
+ If set to false, there should be a mapping function that takes care of the differnces + inter_degree (:obj:str, `optional`): + The Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area) + model (:obj:torch.module, `optional`): + Overrides the huggingface pretrained model with your own pretrained model + tokenizer (:obj:huggingface.tokenizer, `optional`): + Overrides the huggingface tokenizer with your tokenizer + mapping_function (:obj:Callable, `optional`): + Custom mapping function that maps between sequence length differences between tokenizers + token_remap (:obj:Callable, `optional`): + Custom function that maps between tokenizers (defaults to self.remapping_token) + """ + super(server, self).__init__() + if config == None: config = server.config() + self.config = config;print(config) + self.std_tokenizer = bittensor.tokenizer() + self.device = config.neuron.device + + #setting up pretrained model + self.model_name = model_name if model_name != None else config.neuron.model_name + self.pretrained = pretrained if pretrained != None else config.neuron.pretrained + if self.pretrained == True: + self.pre_model = model if model != None else AutoModelForCausalLM.from_pretrained(self.model_name) + self.tokenizer = tokenizer + if tokenizer is None: + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + except ValueError: # when fast not available as in https://github.com/huggingface/tokenizers/pull/1005 + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) + + elif self.pretrained == False: + model_config = AutoConfig.from_pretrained(self.model_name) + model_config.vocab_size= bittensor.__vocab_size__ + self.pre_model = model if model != None else AutoModel.from_config(model_config) + self.tokenizer = bittensor.tokenizer() + + # Define PAD Token = EOS Token (GPT2 generate convention, when PAD Token is None) + # https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + if self.pre_model.config.pad_token_id is None and self.pre_model.config.eos_token_id is not None: + self.pre_model.config.pad_token_id = self.pre_model.config.eos_token_id + + self.tokenizer = prep_tokenizer(self.tokenizer, self.std_tokenizer) + self.to_translation_map = get_translation_map(self.tokenizer, self.std_tokenizer) + self.from_translation_map = get_translation_map(self.std_tokenizer, self.tokenizer) + self.split_map_cache = {} + + if self.config.neuron.local_train or self.config.neuron.remote_train: + self.pre_model.train() + self.set_fine_tuning_params() + + else: + self.pre_model.eval() + + if self.config.neuron.autocast and self.device[:4] == 'cuda': + self.pre_model.half() + + #parameters of the models + self.final_dim = bittensor.__network_dim__ + self.pre_dimension = self.pre_model.config.hidden_size + self.padding = padding if padding != None else config.neuron.padding + self.interpolate = interpolate if interpolate != None else config.neuron.interpolate + self.inter_degree = inter_degree if inter_degree != None else config.neuron.inter_degree + self.checking = checking if checking != None else config.neuron.checking + self.mapping_function= mapping_function + self.token_remap = token_remap if token_remap is not None else self.remapping_token + + if self.config.neuron.padding == False: + self.mapping = torch.nn.Linear( self.pre_dimension, self.final_dim) + + self.decoder = torch.nn.Linear( self.final_dim, bittensor.__vocab_size__ , bias=False) + self.loss_fct 
= torch.nn.CrossEntropyLoss() + + self.outputs_cache = None + self.gradients_cache = None + self.best_loss = math.inf + + #checking if the parameters of the server makes sense + if self.checking and pretrained == True: + self.check() + + # -- keeps track of gradients applied + self.backward_gradients_count = 0 + + + def set_fine_tuning_params(self) -> Tuple[bool, str]: + r''' Set to tune only the parameter of the last layer + Returns: + reached_last_layer (:type:`bool`): + If we have set partial of the model to requires grad. + + last_layer_name (:type:`string`): + The name of the last layer that user specified or we found. + None if the user did not specify and we couldnt find it. + ''' + def find_last_layer(model: torch.nn.Module) -> Optional[str]: + r''' Recursively find the last layer in a nn.ModuleList + Args: + model (:obj:`torch.module`): + The model (or sub-model) to fine the last layer from. + Returns: + name (:type:`str`): + The name (or sub-name) of the last layer. + None if not found + ''' + reverted_child_list = [(name, child) for name, child in model.named_children()] + reverted_child_list.reverse() + + for name, child in reverted_child_list: + if isinstance(child, nn.ModuleList): + if self.config.neuron.finetune.num_layers > len(child): + logger.warning(f'Number of finetune layers was set higher then the layers avaliable {len(child)}') + return None + return (name + '.' +str(len(child) - self.config.neuron.finetune.num_layers)) + + for name, child in reverted_child_list: + name_ = find_last_layer(child) + if name_ != None: + return (name+'.'+ name_) + + return None + + if self.config.neuron.finetune.layer_name == None: + last_layer_name = find_last_layer(self.pre_model) + else: + last_layer_name = self.config.neuron.finetune.layer_name + + reached_last_layer = False + + # set the non-last layer parameters not to require grads + if (self.config.neuron.finetune.all) or (last_layer_name == None): + return False, last_layer_name + + logger.success(f'Set to finetune layer {last_layer_name} and onwards') + + for name, param in self.pre_model.named_parameters(): + if last_layer_name in name or reached_last_layer == True: + param.requires_grad = True + reached_last_layer = True + else: + param.requires_grad = False + + if reached_last_layer == False: + if self.config.neuron.finetune.all: + logger.warning('Set to finetune the whole model, this will significantly increase the memory usage.') + else: + logger.warning(f'Cannot identify the last layer of the model with name {last_layer_name}, setting to finetune on all of the parameters.') + + return reached_last_layer, last_layer_name + + def remapping_token(self, token_batch, std_tokenizer=None, return_offsets_mapping=False): + r""" Tokenizer remapping; decodes the message and then remaps the message using a new tokenizer + Args: + token_batch ( :obj:`torch.LongTensor`, `required`): + token_batch to be retokenized, [batch_size, sequence_len] + std_tokenizer ( :obj:`transformers.Tokenizer`, `optional`): + The standard tokenizer which was used to tokenize the input. + return_offsets_mapping ( :obj:`bool`, `required`): + Return offsets_mapping in tokenization to delineate token segment positions. 
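To make the layer-name resolution above concrete, a small hedged example assuming a GPT-2 style model (module names depend on the model class):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('gpt2')  # 12 transformer blocks in the nn.ModuleList model.transformer.h

    # With neuron.finetune.num_layers = 1, find_last_layer walks the children in
    # reverse, reaches the ModuleList at transformer.h and returns 'h.' + str(12 - 1),
    # giving last_layer_name = 'transformer.h.11'. Only parameters whose names contain
    # this prefix, or that appear after it, keep requires_grad = True.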
+ """ + if std_tokenizer is None: + std_tokenizer = self.std_tokenizer + + text_batch = std_tokenizer.batch_decode(token_batch) # decode tokens to original text + result = translate_special_token_text(text_batch, std_tokenizer, self.tokenizer) # translate special tokens + to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch = result + + tokens = self.tokenizer(to_text_batch, padding=True, truncation=True, return_tensors='pt', + add_special_tokens=False).to(self.device) # assume tokenizer.padding_side = 'left' + + if return_offsets_mapping: # get offsets_mapping in tokenization to delineate token segment positions + server_tokens = self.tokenizer(to_text_batch, return_offsets_mapping=True, add_special_tokens=False) + std_tokens = std_tokenizer(text_batch, return_offsets_mapping=True) # encode again to get offsets mapping + + # pad offsets so that special token offset widths match for continued correct alignment + tokens['offset_mapping'] = pad_offsets(server_tokens['offset_mapping'], to_offsets_batch, pad_offsets_batch) + tokens['offset_mapping_std'] = pad_offsets(std_tokens['offset_mapping'], from_offsets_batch, + pad_offsets_batch) + return tokens + + def forward(self, inputs, tokenizer=None): + """ + Forward pass through the whole server model. Returns the loss and decoded predictions. + + Args: + inputs ( :obj:`torch.Tensor`, `required`): + torch inputs to be forward processed. + tokenizer (:obj:'huggingface.tokenizer', optional): + The tokenizer which was used to tokenize the inputs + Returns: + loss (:obj:`torch.FloatTensor`): + MLM loss from the inputs + decoded_targets (:obj:`torch.FloatTensor`): + Decoded predictions of the next token in the sentence. + + """ + message, model_output, decoded_targets = self.local_forward(inputs, tokenizer)[1] + + shift_logits = decoded_targets[..., :-1, :].contiguous() + shift_labels = inputs[..., 1:].contiguous() + loss = self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) + + return loss, decoded_targets + + def local_forward(self, token_batch, tokenizer=None, encode_len=bittensor.__network_dim__, model_output = None): + r""" Forward pass through the pretrained model and possible mappings between hidden units. + The response tensor should be the hidden units computed using the local context and + with shape: [batch_size, sequence_len, __vocab_size__]. + + Args: + token_batch ( :obj:`torch.LongTensor`, `required`): + torch inputs to be forward processed, [batch_size, sequence_len] + tokenizer ( huggingface.tokenizer, `optional`): + The tokenizer which was used to tokenize the inputs + encode_len ( :obj:`int`, `optional`): + logit encoding length, default bittensor.__network_dim__ length + model_output (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `optional`): + The output of huggingface auto model. + + Returns: + model_outputs (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `required`): + The output of huggingface auto model. 
+ + logits (:obj:`torch.FloatTensor`): + The nucleus's logit outputs as a torch tensor of shape [batch_size, sequence_len, __vocab_size__] + """ + tokens = self.token_remap(token_batch, std_tokenizer=tokenizer) # remap to server tokenizer + + if model_output == None: + if self.config.neuron.local_train: + model_output = self.pre_model(input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + output_hidden_states=True) + else: + with torch.no_grad(): + model_output = self.pre_model(input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + output_hidden_states=True) + + return None, model_output, model_output.logits + + def encode_forward(self,inputs,tokenizer=None, model_output = None): + r""" Forward pass through the pretrained model and possible mappings between hidden units. + The response tensor should be the hidden units computed using the local context and with shape: [batch_size, sequence_len, __network_dim__]. + + Args: + inputs ( :obj:`torch.Tensor`, `required`): + torch inputs to be forward processed. + tokenizer ( huggingface.tokenizer, `optional`): + The tokenizer which was used to tokenize the inputs + model_outputs (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `optional`): + The output of huggingface auto model. + + Returns: + model_outputs (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `required`): + The output of huggingface auto model. + + encoded_hidden (:type:`torch.Tensor`, `required`) + The hidden layer output as a torch tensor of shape [batch_size, sequence_len, __network_dim__ ] + """ + sen_len = inputs.size() + tokens = self.token_remap(inputs, tokenizer) # remap to server tokenizer + + if model_output == None: + if self.config.neuron.remote_train: + model_output = self.pre_model(input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + output_hidden_states=True) + else: + with torch.no_grad(): + model_output = self.pre_model(input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + output_hidden_states=True) + + pre_hidden = model_output.hidden_states[-1] + + if self.interpolate and sen_len[1] != pre_hidden.size()[1]: + down= F.interpolate(pre_hidden.unsqueeze(1),size=[sen_len[1],pre_hidden.size()[2]],mode=self.inter_degree).squeeze(1) + elif self.mapping_function: + down = self.mapping_function(pre_hidden) + else: + down = pre_hidden + + if self.padding: + padding_l = (self.final_dim-self.pre_dimension)//2 + padding_r = (self.final_dim-self.pre_dimension) - padding_l + encoded_hidden = F.pad(down, (padding_l, padding_r), "constant", 0) + else: + encoded_hidden = self.mapping(down) + + return None, model_output, encoded_hidden + + def encode_forward_causallm(self, token_batch, tokenizer=None, encode_len=bittensor.__network_dim__, model_output=None): + r""" Forward pass through the pretrained model and possible mappings between hidden units. + The response tensor should be the hidden units computed using the local context and + with shape: [batch_size, sequence_len, __vocab_size__]. 
+ + Args: + token_batch ( :obj:`torch.LongTensor`, `required`): + torch inputs to be forward processed, [batch_size, sequence_len] + tokenizer ( huggingface.tokenizer, `optional`): + The tokenizer which was used to tokenize the inputs + encode_len ( :obj:`int`, `optional`): + logit encoding length, default bittensor.__network_dim__ length + model_output (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `optional`): + The output of huggingface auto model. + + Returns: + model_outputs (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `required`): + The output of huggingface auto model. + + logits_std (:obj:`torch.FloatTensor`): + The nucleus's logit outputs as a torch tensor of shape [batch_size, sequence_len, __vocab_size__] + """ + tokens = self.token_remap(token_batch, std_tokenizer=tokenizer, return_offsets_mapping=True) # remap to server tokenizer + + def _forward(_model_output=model_output): + if _model_output is None: + # transformer models like gerpt2 typically perform worse with left-side attention mask, so turning it off + _model_output = self.pre_model(input_ids=tokens['input_ids'], + #attention_mask=tokens['attention_mask'], + output_hidden_states=True) + pre_logits = _model_output.logits # [batch_size, sequence_len, self.tokenizer.vocab_len] + + probs_std = translate_logits_to_probs_std(pre_logits, + tokens['offset_mapping'], tokens['offset_mapping_std'], + self.tokenizer, self.std_tokenizer, + self.split_map_cache, + self.to_translation_map, self.from_translation_map, + tokens['input_ids'], token_batch) + probs_std = probs_std.to(self.device) + logits_std = torch.log(probs_std + 1e-40) + + #removing the loss calculation for stablity testing + original_loss = self.get_loss_fct(pre_logits, tokens['input_ids']).item() + translated_loss = self.get_loss_fct(logits_std, token_batch).item() + #message = 'Success' + message = f'Loss: {original_loss:.2f} → {translated_loss:.2f}' + # logger.info(f'TextCausalLM \t| Server loss: {original_loss: .2f} \t| Translated loss: {translated_loss: .2f}') + + return message, _model_output, logits_std + + if self.config.neuron.remote_train: + return _forward() # track gradients for training + + with torch.no_grad(): + return _forward() # no gradients + + def encode_forward_causallmnext(self, token_batch, std_tokenizer=None, topk: int = 4096, model_output=None): + r""" + Forward pass through the pretrained model and select topk tokenizer logits and retokenize with std_tokenizer, + then compact new token phrases and probabilities + into 1-D tensor [ >= batch_size * (2 * topk + 1)] prob + at least 1 token per phrase + floor_prob. + The floor probability is the mean probability of token phrases not captured in topk, required since + the server tokenizer vocab_size may not be known to the receiver/validator. + + Args: + token_batch ( :obj:`torch.LongTensor`, `required`): + torch inputs to be forward processed, [batch_size, std_sequence_len]. + std_tokenizer ( :obj:`PreTrainedTokenizerBase`, `optional`): + The standard tokenizer which was used to tokenize the inputs. + topk ( :obj:`int`, `optional`): + Amount of std_tokenized server phrases with highest probability to produce. + model_output (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `optional`): + The output of transformers AutoModel. + + Returns: + model_outputs (:obj:`transformers.modeling_outputs.BaseModelOutputWithCrossAttentions`, `required`): + The output of transformers AutoModel. 
+ topk_tensor (:obj:`torch.Tensor`, `required`): + [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob + in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding. + Content structure: + [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?], + [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?], + [...], + [prob_floor_b=0, ignore_index, ..., ignore_index]], + [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?], + [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?], + [...], + [prob_floor_b=1, ignore_index, ..., ignore_index]], + [...]] + """ + if std_tokenizer is None: + std_tokenizer = self.std_tokenizer + + # remap to server tokenizer, expect right-aligned sequences so that last position keeps continuation prediction + tokens = self.token_remap(token_batch, std_tokenizer) + + def _forward(_model_output=model_output): + if _model_output is None: + _model_output = self.pre_model(input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + output_hidden_states=True) + + # model_output.logits: [batch_size, sequence_len, server_vocab_size] + last_logits = _model_output.logits[:, -1, :] # [batch_size] server prediction of continuation, right-aligned + + # Select topk tokenizer logits and retokenize with std_tokenizer, + # then compact new token phrases and probabilities into 1-D tensor + topk_tensor = topk_token_phrases(last_logits, self.tokenizer, topk=topk) # [batch_size, (topk + 1), max_len] + + original_loss = self.get_loss_fct(_model_output.logits, tokens['input_ids']).item() + message = f'Loss: {original_loss:.2f}' + #message = 'Success' + + return message, _model_output, topk_tensor + + if self.config.neuron.remote_train: + return _forward() # track gradients for training + + with torch.no_grad(): + return _forward() # no gradients + + def get_loss_fct(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor: + """ + Calculate loss_fct, CausalLM loss, next-token prediction loss. 
+ Args: + logits (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len, bittensor.__network_dim__] + labels (:obj:`torch.LongTensor`, `required`): + [batch_size, sequence_len] + + Returns: + loss (:obj:`torch.FloatTensor`): + scalar + """ + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = self.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss + + def check(self): + r"""Checks the server settings + """ + assert self.tokenizer.name_or_path == self.pre_model.name_or_path, 'incorrect model ({}) and tokenizer ({})'.format(self.pre_model.name_or_path,self.tokenizer.name_or_path) + if self.interpolate == False: + assert self.mapping_function != None, 'Incorrect Settings; needs atleast one mapping function for sequence length changes' + + def save(self, path): + try: + state_dict = { + 'model': self.pretrained, + 'pretrained_model': self.pre_model.state_dict(), + 'decoder': self.decoder.state_dict(), + 'best_loss': self.best_loss, + } + if self.padding == False: + state_dict['mapping'] = self.mapping.state_dict() + torch.save( state_dict, "{}/model.torch".format( path) ) + bittensor.logging.success(prefix='Saved model', sufix='{}/model.torch'.format( path ) ) + except Exception as e: + logger.exception('Failed to save model with error:{}', e) + + def load(self, path): + try: + state_dict= torch.load("{}/model.torch".format( path )) + if self.pretrained == state_dict['model']: + self.pre_model.load_state_dict(state_dict['pretrained_model'], strict=False) + self.decoder.load_state_dict(state_dict['decoder']) + if self.padding == False: + self.mapping.load_state_dict(state_dict['mapping']) + self.best_loss = state_dict['best_loss'] + bittensor.logging.success( prefix = 'Reloaded model', sufix = '{}/model.torch'.format( path )) + + + except Exception as e: + logger.warning('No saved model found with error: {}', e) + + @staticmethod + def config (): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, help='If set, defaults are overridden by passed file.') + + # ML model arguements + parser.add_argument('--neuron.learning_rate', type=float, help='Training initial learning rate.', default=0.01) + parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8) + parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0) + parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) + parser.add_argument('--neuron.model_name', type=str, help='pretrained model from hugging face',default='gpt2') + parser.add_argument('--neuron.pretrained', action='store_false', help='if the model should be pretrained',default=True) + parser.add_argument('--neuron.padding', action='store_false', help='To pad out final dimensions',default=True) + parser.add_argument('--neuron.interpolate', action='store_false', help='To interpolate between sentence length',default=True) + parser.add_argument('--neuron.inter_degree', type=str, help='Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area)', default='nearest') + parser.add_argument('--neuron.autocast', action='store_true', help='(experimental) autocasts the model to float16. 
Requires cuda', default=False)
+ parser.add_argument('--neuron.local_train', action='store_true', help='''If true, allow local training''', default=False)
+ parser.add_argument('--neuron.remote_train', action='store_true', help='''If true, allow remote training''', default=False)
+ parser.add_argument('--neuron.finetune.all', action='store_true', help='Finetune your whole model instead of only on the last (few) layers', default=False)
+ parser.add_argument('--neuron.finetune.num_layers', type=int, help='The number of layers to finetune on your model.', default=1)
+ parser.add_argument('--neuron.finetune.layer_name', type=str, help='Specify from which layer to finetune, e.g. encoder.layer.11', default=None)
+
+ # Miner arguments
+ parser.add_argument('--neuron.name', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ', default='core_server')
+ parser.add_argument('--neuron.checking', action='store_false', help='To check if server settings are correct',default=True)
+ parser.add_argument('--neuron.restart', action='store_true', help='If True, train the neuron from the beginning', default=False)
+ parser.add_argument('--neuron.blacklist.stake', type=float, help='Amount of stake (tao) in order not to get blacklisted', default=10)
+ parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch', default=10)
+ parser.add_argument('--neuron.blacklist.time', type=int, help='how often a peer can query you (seconds) ', default=1)
+ parser.add_argument('--neuron.blocks_per_set_weights', type=float, help='how often to set weights', default=100)
+ parser.add_argument('--neuron.metagraph_sync', type=float, help='how often to sync the metagraph', default=100000)
+ parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, allow non-registered peers''', default=False)
+ parser.add_argument('--neuron.disable_blacklist', action='store_true', help='Turns off blacklisting', default=False)
+ parser.add_argument('--neuron.disable_priority', action='store_true', help='Turns off priority threadpool', default=False)
+
+ # Synapse arguments
+ parser.add_argument('--neuron.lasthidden', action='store_false', help='To turn off last hidden synapse', default=True)
+ parser.add_argument('--neuron.causallm', action='store_false', help='To turn off causallm synapse', default=True)
+ parser.add_argument('--neuron.causallmnext', action='store_false', help='To turn off causallmnext synapse', default=True)
+ parser.add_argument('--neuron.seq2seq', action='store_false', help='To turn off seq2seq synapse', default=True)
+ parser.add_argument('--neuron.lasthidden_stake', type = float, help='the amount of stake to run last hidden synapse',default=0)
+ parser.add_argument('--neuron.causallm_stake', type = float, help='the amount of stake to run causallm synapse',default=0)
+ parser.add_argument('--neuron.causallmnext_stake', type=float, help='the amount of stake to run causallmnext synapse', default=0)
+ parser.add_argument('--neuron.seq2seq_stake', type = float, help='the amount of stake to run seq2seq synapse',default=0)
+
+
+ bittensor.wallet.add_args( parser )
+ bittensor.axon.add_args( parser )
+ bittensor.subtensor.add_args( parser )
+ bittensor.logging.add_args( parser )
+ bittensor.wandb.add_args(parser)
+ bittensor.prioritythreadpool.add_args( parser )
+ bittensor.dataset.add_args( parser )
+ bittensor.metagraph.add_args( parser )
+ return bittensor.config( parser )
+
diff --git
a/bittensor/_neuron/text/core_server/run.py b/bittensor/_neuron/text/core_server/run.py new file mode 100644 index 0000000000..acc095bef3 --- /dev/null +++ b/bittensor/_neuron/text/core_server/run.py @@ -0,0 +1,429 @@ +#!/bin/python3 +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +""" The Exodus base client. + +Example: + $ python miners/text/template_client.py + +""" +import bittensor +import sys +import time +import datetime +from threading import Lock +from datetime import datetime,timedelta +from loguru import logger; logger = logger.opt(colors=True) +from torch.nn.utils.rnn import pad_sequence + +import wandb +import pandas +import torch +import torch.nn.functional as F +from torch.nn.utils import clip_grad_norm_ + +def serve( + config, + model, + subtensor = None, + wallet = None, + axon= None, + metagraph = None, + ): + config.to_defaults() + model= model.to(model.device) + + # Create Subtensor connection + subtensor = bittensor.subtensor(config = config) if subtensor == None else subtensor + + # Load/Create our bittensor wallet. + if wallet == None: + wallet = bittensor.wallet( config = config ).create().reregister(subtensor=subtensor) + else: + wallet.reregister(subtensor=subtensor) + + + # Load/Sync/Save our metagraph. + if metagraph == None: + metagraph = bittensor.metagraph ( + subtensor = subtensor + ) + + metagraph.load().sync().save() + + # Create our optimizer. + optimizer = torch.optim.SGD( + [ {"params": model.parameters()} ], + lr = config.neuron.learning_rate, + momentum = config.neuron.momentum, + ) + mutex = Lock() + + timecheck_dicts = {bittensor.proto.RequestType.FORWARD:{}, bittensor.proto.RequestType.BACKWARD:{}} + n_topk_peer_weights = subtensor.min_allowed_weights + + def priority(pubkey:str, request_type:bittensor.proto.RequestType, inputs_x) -> float: + r"""Calculates the priority on requests based on stake and size of input + Args: + pubkey ( str, `required`): + The public key of the caller. + inputs_x ( :obj:`torch.Tensor`, `required`): + torch inputs to be forward processed. + request_type ( bittensor.proto.RequestType, `required`): + the request type ('FORWARD' or 'BACKWARD'). + """ + try: + uid = metagraph.hotkeys.index(pubkey) + priority = metagraph.S[uid].item() + + except: + # zero priority for those who are not registered. 
+ priority = 0 + + return priority + + def forward_generate( inputs_x:torch.FloatTensor, synapse, model_output = None): + tokens = model.token_remap(inputs_x.to(model.device)) + output = model.pre_model.generate( + input_ids=tokens['input_ids'], + attention_mask=tokens['attention_mask'], + max_length=max(tokens['input_ids'].shape[1] + 1, synapse.num_to_generate), + num_beams=synapse.num_beams, + no_repeat_ngram_size=synapse.no_repeat_ngram_size, + early_stopping = synapse.early_stopping, + do_sample=synapse.do_sample, + top_p=synapse.top_p, + num_return_sequences=synapse.num_return_sequences, + temperature = synapse.temperature, + repetition_penalty = synapse.repetition_penalty, + length_penalty = synapse.length_penalty, + max_time = synapse.max_time, + num_beam_groups = synapse.num_beam_groups, + ) + raw_texts = [model.tokenizer.decode(out) for out in output] + tokens = [model.std_tokenizer.encode(raw_text, return_tensors="pt")[:,:synapse.num_to_generate].view(-1) for raw_text in raw_texts] + bittensor_output = pad_sequence(tokens, batch_first=True) + return None, model_output, bittensor_output + + def forward_hidden_state(inputs_x:torch.FloatTensor, synapse, model_output = None): + with mutex: + message, model_output, hidden = model.encode_forward(inputs_x.to(model.device), model_output=model_output) + return message, model_output, hidden + + def forward_casual_lm(inputs_x:torch.FloatTensor, synapse, model_output = None): + with mutex: + message, model_output, logits = model.encode_forward_causallm(inputs_x.to(model.device), model_output=model_output) + return message, model_output, logits + + def forward_casual_lm_next(inputs_x: torch.FloatTensor, synapse, model_output=None): + with mutex: + message, model_output, topk_token_phrases = model.encode_forward_causallmnext(inputs_x.to(model.device), + topk=synapse.topk, + model_output=model_output) + # topk_token_phrases: [sum_b(sum_k(len(phrase_k) + 1)_b)] contains topk token phrases and probabilities + # Compacted 1-D tensor >= batch_size * (2 * topk + 1) + return message, model_output, topk_token_phrases + + def optimizer_step(): + optimizer.step() + optimizer.zero_grad() + + def blacklist(pubkey:str, request_type:bittensor.proto.RequestType) -> bool: + r"""Axon security blacklisting, used to blacklist message from low stake members + Args: + pubkey ( str, `required`): + The public key of the caller. + request_type ( bittensor.proto.RequestType, `required`): + the request type ('FORWARD' or 'BACKWARD'). + """ + # Check for registrations + + def registration_check(): + # If we allow non-registered requests return False = not blacklisted. + is_registered = pubkey in metagraph.hotkeys + if not is_registered: + if config.neuron.blacklist_allow_non_registered: + + return False + raise Exception('Registration blacklist') + + # Check for stake + def stake_check() -> bool: + # Check stake. 
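The stake check that follows, like the priority and synapse checks around it, boils down to one lookup pattern: find the caller's UID from its hotkey, read its stake from metagraph.S, and compare against a threshold. A self-contained sketch of that pattern with placeholder data (hypothetical names, not code from this patch):

```python
import torch

hotkeys = ['5CallerHotkey', '5OtherHotkey']   # stand-in for metagraph.hotkeys
stakes = torch.tensor([25.0, 3.0])            # stand-in for metagraph.S (tao per UID)
stake_threshold = 10.0                        # e.g. config.neuron.blacklist.stake

def has_enough_stake(pubkey: str) -> bool:
    if pubkey not in hotkeys:                 # unregistered callers have no stake
        return False
    uid = hotkeys.index(pubkey)               # UID is the caller's index in the metagraph
    return stakes[uid].item() >= stake_threshold

print(has_enough_stake('5CallerHotkey'))      # True
print(has_enough_stake('5OtherHotkey'))       # False
```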
+ uid = metagraph.hotkeys.index(pubkey)
+ if metagraph.S[uid].item() < config.neuron.blacklist.stake:
+ raise Exception('Stake blacklist')
+ return False
+
+ # Check for time
+ def time_check():
+ current_time = datetime.now()
+ # Only check if the requests are forward requests
+ timecheck = timecheck_dicts[request_type]
+ if pubkey in timecheck.keys():
+ prev_time = timecheck[pubkey]
+ if current_time - prev_time >= timedelta(seconds=config.neuron.blacklist.time):
+ timecheck[pubkey] = current_time
+ else:
+ timecheck[pubkey] = current_time
+ raise Exception('Time blacklist')
+ else:
+ timecheck[pubkey] = current_time
+
+ return False
+
+
+ # Black list or not
+ try:
+ registration_check()
+
+ time_check()
+
+ #stake_check()
+
+ return False
+
+ except Exception as e:
+ return True
+
+ def synapse_check(synapse, hotkey):
+ """
+ Custom synapse function to protect certain synapse functions depending on the stake and weight.
+ Certain synapses require more compute than others. For instance, TEXT_SEQ_2_SEQ requires significantly
+ more commitment by the server than a request for TEXT_CAUSAL_LM_NEXT.
+
+ Args:
+ synapse (:obj:`bittensor.proto.SynapseArgs`, `required`):
+ The proto message that contains additional args for individual synapse functions
+ hotkey (:obj:`str`, `required`):
+ The hotkey that sent the request
+
+ """
+ ## Uid that sent the request
+ incoming_uid = metagraph.hotkeys.index(hotkey)
+ if synapse.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE:
+
+ if metagraph.S[incoming_uid] < config.neuron.lasthidden_stake:
+ return False
+
+ elif synapse.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM:
+
+ if metagraph.S[incoming_uid] < config.neuron.causallm_stake:
+ return False
+
+ elif synapse.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT:
+
+ if metagraph.S[incoming_uid] < config.neuron.causallmnext_stake:
+ return False
+
+ elif synapse.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ:
+
+ if (metagraph.S[incoming_uid] < config.neuron.seq2seq_stake) and (metagraph.S[incoming_uid, uid]):
+ return False
+ else:
+ return False
+
+ return True
+
+ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, synapses=[] ):
+ """
+ The default backward callback when no callback is attached; it is used to call specific synapse functions
+
+ Args:
+ inputs_x (:obj:`torch.FloatTensor`, `required`):
+ The inputs that will be passed to the synapse functions
+ grads_dy (:obj:`torch.FloatTensor`, `required`):
+ The gradients that will be passed to the synapse functions
+ synapses (:obj: list of bittensor.proto.SynapseArgs, 'Optional')
+ The proto message that contains additional args for individual synapse functions
+
+ Returns:
+ response_tensors: (:obj: list of bittensor.proto.Tensor, `required`):
+ serialized tensor response from the nucleus call or None.
+ response_codes: (:obj: list of bittensor.proto.ReturnCode, `required`)
+ return code associated with forward call i.e. Success or Timeout.
+ response_messages: (:obj: list of strings, `required`) + return message associated with synapse call + """ + # --- initialize response variables --- + response_tensors = [] + response_codes = [] + response_messages = [] + + if not config.neuron.remote_train: + return response_tensors, response_codes, response_messages + + # --- calling attached synapses --- + with mutex and torch.enable_grad() and torch.autograd.set_detect_anomaly(True): + for index, synapse in enumerate(synapses): + try: + if synapse.synapse_type in axon.synapse_callbacks and axon.synapse_callbacks[synapse.synapse_type] != None: + model_output, response_tensor = axon.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse) + grads_dy_norm = grads_dy[index]/(grads_dy[index].sum() + 0.00001) + torch.autograd.backward ( + tensors = [ response_tensor ], + grad_tensors = [ grads_dy_norm ], + retain_graph=True + ) + model.backward_gradients_count += inputs_x[index].size(0) + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.Success) + response_messages.append('Success') + else: + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.NotImplemented) + response_messages.append('Not Implemented') + except Exception as e: + # --- Exception Hit in Synapse --- + response_tensors.append(None) + response_codes.append(bittensor.proto.ReturnCode.UnknownException) + response_messages.append(str(e)) + + return response_tensors, response_codes, response_messages + + # Create our axon server and subscribe it to the network. + if axon == None: + axon = bittensor.axon( + config = config, + wallet = wallet, + synapse_checks=synapse_check, + synapse_last_hidden = forward_hidden_state if model.config.neuron.lasthidden else None, + synapse_causal_lm = forward_casual_lm if model.config.neuron.causallm else None, + synapse_causal_lm_next = forward_casual_lm_next if model.config.neuron.causallmnext else None, + synapse_seq_2_seq = forward_generate if model.config.neuron.seq2seq else None , + blacklist = blacklist if not model.config.neuron.disable_blacklist else None, + priority = priority if not model.config.neuron.disable_priority else None, + ).start().serve(subtensor=subtensor) + + axon.optimizer_step = optimizer_step + axon.attach_backward_callback(backward_callback) + # Training Data + if config.neuron.local_train: + dataset = bittensor.dataset(config=config) + dataset.set_data_size(10, 64) + data = next(dataset) + + # load our old model + if not config.neuron.restart : + model.load(config.neuron.full_path) + + if config.wandb.api_key != 'default': + # --- Init Wandb. + bittensor.wandb( + config = config, + cold_pubkey = wallet.coldkeypub.ss58_address, + hot_pubkey = wallet.hotkey.ss58_address, + root_dir = config.neuron.full_path + ) + + last_set_block = subtensor.get_current_block() + + + # --- Run Forever. + while True: + + iteration = 0 + local_data = {} + nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) + uid = metagraph.hotkeys.index( wallet.hotkey.ss58_address ) + current_block = subtensor.get_current_block() + end_block = current_block + config.neuron.blocks_per_epoch + if config.neuron.local_train: + # --- Training step. 
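The local training step that follows accumulates one loss per new chain block and averages before calling backward. The same accumulate-then-average pattern, reduced to a self-contained sketch with a dummy model and data (hypothetical names, not code from this patch):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1)

def next_batch():                        # stand-in for next(dataset)
    return torch.randn(4, 8), torch.randn(4, 1)

losses, iteration = None, 0
for _ in range(3):                       # stands in for "one pass per new block"
    x, y = next_batch()
    loss = F.mse_loss(model(x), y)
    losses = loss if iteration == 0 else losses + loss
    iteration += 1

if iteration != 0:
    (losses / iteration).backward()      # backprop the block-averaged loss
```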
+ while end_block >= current_block:
+ if current_block != subtensor.get_current_block():
+ loss, _ = model( next( dataset ).to(model.device) )
+ if iteration > 0 :
+ losses += loss
+ else:
+ losses = loss
+ iteration += 1
+ current_block = subtensor.get_current_block()
+ logger.info(f'local training\titeration: {iteration}\tloss: {loss}')
+
+ if iteration != 0:
+ (losses/iteration).backward()
+
+ else:
+ while end_block >= current_block:
+ time.sleep(12)
+ current_block = subtensor.get_current_block()
+
+
+ # --- Update parameters
+ if (config.neuron.local_train and iteration > 0) or (config.neuron.remote_train and model.backward_gradients_count > 0):
+ # Custom learning rate
+ if model.backward_gradients_count > 0:
+ optimizer.param_groups[0]['lr'] = 0.1/(model.backward_gradients_count)
+ else:
+ optimizer.param_groups[0]['lr'] = 0.1
+
+ logger.info('Backpropagation Started')
+ clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
+ model.backward_gradients_count = 0
+ logger.info('Backpropagation Successful: Model updated')
+ local_data = {'local/loss': losses.detach().item() / iteration}
+
+ if local_data['local/loss'] < model.best_loss:
+ model.best_loss = local_data['local/loss']
+ model.save(config.neuron.full_path)
+
+ wandb_data = {
+ 'stake': nn.stake,
+ 'rank': nn.rank,
+ 'trust': nn.trust,
+ 'consensus': nn.consensus,
+ 'incentive': nn.incentive,
+ 'emission': nn.emission,
+ }
+
+ if config.wandb.api_key != 'default':
+
+ df = pandas.concat( [
+ bittensor.utils.indexed_values_to_dataframe( prefix = 'w_i_{}'.format(nn.uid), index = metagraph.uids, values = metagraph.W[:, uid] ),
+ axon.to_dataframe( metagraph = metagraph ),
+ ], axis = 1)
+ df['uid'] = df.index
+ wandb_info_axon = axon.to_wandb()
+ wandb.log( { **wandb_data, **wandb_info_axon, **local_data }, step = current_block )
+ wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block )
+
+ if current_block - last_set_block > config.neuron.blocks_per_set_weights:
+ try:
+ bittensor.__console__.print('[green]Current Status:[/green]', {**wandb_data, **local_data})
+
+ last_set_block = current_block
+ # Set self weights to maintain activity.
+ # --- query the chain for the most current number of peers on the network
+ chain_weights = torch.zeros(subtensor.n)
+ chain_weights [ uid ] = 1
+ did_set = subtensor.set_weights(
+ uids=torch.arange(0,subtensor.n),
+ weights = chain_weights,
+ wait_for_inclusion = False,
+ wallet = wallet,
+ )
+
+ metagraph.sync()
+ if did_set:
+ logger.success('Successfully set weights on the chain')
+ else:
+ logger.error('Failed to set weights on chain.
(Timeout)') + except Exception as e: + logger.error('Failure setting weights on chain with error: {}', e) diff --git a/bittensor/_neuron/text/core_validator/__init__.py b/bittensor/_neuron/text/core_validator/__init__.py index 7f24525598..2b85205f64 100644 --- a/bittensor/_neuron/text/core_validator/__init__.py +++ b/bittensor/_neuron/text/core_validator/__init__.py @@ -21,10 +21,9 @@ $ python3 miners/text/core_validator.py --logging.debug """ -import sys import argparse import time -from types import SimpleNamespace +import datetime import bittensor import torch import os @@ -34,21 +33,59 @@ import traceback from rich import print from rich.console import Console +from rich.style import Style +from rich.table import Table from rich.traceback import install -from ..neuron_utilities import joining_context, partial_contexts, ThreadQueue -import torch.nn as nn -import random +from typing import List, Tuple, Callable, Dict, Any, Union + +from ..neuron_utilities import ThreadQueue, PositionalEncoding, calc_loss_fct +from bittensor.utils.tokenizer_utils import phrase_cross_entropy + from torch.nn.utils import clip_grad_norm_ -import torch.nn.functional as F from torch.nn import TransformerEncoder, TransformerEncoderLayer from loguru import logger -import cProfile from threading import Lock logger = logger.opt( colors=True ) console = Console() install(show_locals=True) +# Neuron stats recorded by validator neuron/nucleus +# [Column_name, key_name, format_string, rich_style] # description +neuron_stats_columns = [ + ['UID', 'uid', '{:.0f}', 'cyan'], # neuron UID + ['Upd!', 'updates!', '{}', 'bright_yellow'], # number of exponential moving average updates with zeroing on + ['nUpd', 'updates_shapley_values_nxt', '{}', 'bright_yellow'], # number of exponential moving average updates to nShap + ['mUpd', 'updates_shapley_values_min', '{}', 'bright_yellow'], # number of exponential moving average updates to mShap + ['nTime', 'response_time_nxt', '{:.2f}', 'yellow'], # response time to TextCausalLMNext forward requests [TextCausalLMNext] + ['sTime', 'response_time', '{:.2f}', 'yellow'], # response time to TextCausalLM forward requests + ['Route', 'routing_score', '{:.3f}', 'grey30'], # validator routing score (higher preferred) + ['Weight', 'weight', '{:.5f}', 'green'], # weight set on substrate (each epoch) + ['nShap!', 'shapley_values_nxt!', '{:.0f}', 'magenta'], # Shapley value (=vBase+vSyn) for phrase validation (zeroing) [TextCausalLMNext] + ['nShap', 'shapley_values_nxt', '{:.0f}', 'magenta'], # Shapley value (=vBase+vSyn) for phrase validation [TextCausalLMNext] + ['mShap!', 'shapley_values_min!', '{:.0f}', 'bright_magenta'], # min(Shap, vShap) of sequence and validation Shapley (zeroing) + ['mShap', 'shapley_values_min', '{:.0f}', 'bright_magenta'], # min(Shap, vShap) of sequence and validation Shapley + ['sLoss', 'loss', '{:.2f}', 'bright_cyan'], # next token prediction loss average over sequence + ['vLoss', 'loss_val', '{:.2f}', 'bright_cyan'], # next token prediction loss for validation task + ['nvLoss', 'loss_val_nxt', '{:.2f}', 'bright_cyan'], # next token prediction loss for validation task [TextCausalLMNext] + ['nLoss', 'loss_nxt', '{:.2f}', 'bright_cyan'], # next token phrase prediction loss for phrase validation task [TextCausalLMNext] + ['RLoss', 'routing_loss', '{:.3f}', 'grey30'], # MSE between routing_score and conditioned loss + ['nRLoss', 'routing_loss_nxt', '{:.3f}', 'grey30'], # MSE between routing_score_nxt and conditioned loss [TextCausalLMNext] + ['sShap', 'shapley_values', 
'{:.0f}', 'magenta'], # Shapley value (=Base+Syn) over sequence + ['vShap', 'shapley_values_val', '{:.0f}', 'magenta'], # Shapley value (=vBase+vSyn) for validation + ['sBase', 'base_params', '{:.0f}', ''], # parameter count estimate via adjusted scaling law + ['vBase', 'base_params_val', '{:.0f}', ''], # square root parameter count estimate for validation task + ['nBase', 'base_params_nxt', '{:.0f}', ''], # square root parameter count estimate for phrase validation task [TextCausalLMNext] + ['nParam~', 'est_params_nxt', '{:.2g}', 'magenta'], # parameter count estimate for phrase validation task [TextCausalLMNext] + ['sSyn', 'synergy', '{:.0f}', 'white'], # Shapley pairwise synergy over sequence loss (parameter count estimate) + ['vSyn', 'synergy_val', '{:.0f}', 'white'], # Shapley pairwise synergy over validation loss (count estimate) + ['nSyn', 'synergy_nxt', '{:.0f}', 'white'], # Shapley pairwise synergy over phrase validation loss (count estimate) [TextCausalLMNext] + ['sSynD', 'synergy_loss_diff', '{:.2f}', 'bright_blue'], # Shapley pairwise synergy over sequence loss (loss difference) + ['vSynD', 'synergy_loss_diff_val', '{:.2f}', 'bright_blue'], # Shapley pairwise synergy over validation loss (loss difference) + ['nSynD', 'synergy_loss_diff_nxt', '{:.2f}', 'bright_blue'], # Shapley pairwise synergy over phrase validation loss (loss difference) [TextCausalLMNext] +] + + class neuron: r""" Creates a bittensor neuron that specializes validating other peers. The core validator @@ -109,13 +146,29 @@ def __init__( self.dendrite = bittensor.dendrite ( config = self.config, wallet = self.wallet ) if dendrite == None else dendrite self.device = torch.device ( device = self.config.neuron.device ) self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device ) - self.dataset = bittensor.dataset ( config = self.config, batch_size = self.subtensor.validator_batch_size, block_size = self.subtensor.validator_sequence_length ) if dataset == None else dataset - + self.dataset = (bittensor.dataset(config=self.config, batch_size=self.subtensor.validator_batch_size, + block_size=self.subtensor.validator_sequence_length + self.config.neuron.validation_len) + if dataset is None else dataset) + self.optimizer = torch.optim.SGD( + self.nucleus.parameters(), lr=self.config.neuron.learning_rate, momentum=self.config.neuron.momentum + ) + # === Create thread queue === - self.forward_thread_queue = ThreadQueue(num_jobs = self.config.neuron.forward_num, target = self.forward) self.loss = None self.loss_agg_mutex = Lock() - self.moving_avg_scores = None + + # === Neuron statistics variables === + self.neuron_stats = {} + self.alpha = 0.05 # EMA coefficient in [0, 1], higher alpha discounts older observations faster + + if self.config.neuron.validation_synapse == 'TextCausalLMNext': + self.weight_key = 'shapley_values_nxt' # stat key + ! to calculate neuron weights with + # stat keys to duplicate (['key']->['key!']) and push zero to its EMA if neuron non-responsive + self.synapse_keys = ['shapley_values_nxt'] + else: + self.weight_key = 'shapley_values_min' # stat key + ! 
to calculate neuron weights with + # stat keys to duplicate (['key']->['key!']) and push zero to its EMA if neuron non-responsive + self.synapse_keys = ['shapley_values_min'] @classmethod def check_config( cls, config: 'bittensor.Config' ): @@ -142,12 +195,14 @@ def add_args( cls, parser ): parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8 ) parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch, -1 value means we use the chain value.', default = -1 ) parser.add_argument('--neuron.epochs_until_reset', type=int, help='Number of epochs before weights are reset.', default = -1 ) + parser.add_argument('--neuron.validation_len', type=int, help='Number of tokens to holdout for phrase validation beyond sequence context.', default=8) parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0 ) parser.add_argument('--neuron.restart_on_failure', action='store_true', help='''Restart neuron on unknown error.''', default=True ) parser.add_argument('--neuron._mock', action='store_true', help='To turn on neuron mocking for testing purposes.', default=False ) parser.add_argument('--neuron.wait_for_finalization', action='store_true', help='''when setting weights the miner waits for trnasaction finalization.''', default=False) parser.add_argument('--neuron.forward_num', type=int, help='''How much forward request before a backward call.''', default=3) + parser.add_argument('--neuron.validation_synapse', type=str, help='''Synapse used for validation.''', default='TextCausalLMNext', choices = ['TextCausalLMNext', 'TextCausalLM']) @classmethod def config ( cls ): @@ -162,7 +217,15 @@ def config ( cls ): bittensor.dataset.add_args( parser ) bittensor.wandb.add_args(parser) return bittensor.config( parser ) - + + def __repr__(self) -> str: + return self.__str__() + + def __str__(self) -> str: + return (f'[bold]UID {self.uid}[/bold] \[{self.dendrite.receptor_pool.external_ip}] ' + f'({self.wallet.name}:[bold]{self.wallet.coldkeypub.ss58_address[:7]}[/bold]/' + f'{self.config.wallet.hotkey}:[bold]{self.wallet.hotkey.ss58_address[:7]}[/bold])') + def __del__(self): self.__exit__() @@ -172,16 +235,16 @@ def __exit__ ( self, exc_type, exc_value, exc_traceback ): print(exc_type, exc_value, exc_traceback) self.dataset.close() self.dendrite.__del__() - self.forward_thread_queue.stop() - self.forward_thread_queue.join() def __enter__(self): r""" Sanity checks and begin validator. """ # === Wallet === - # Connects wallett to network. + # Connects wallet to network. + self.wallet.create() # NOTE: This registration step should likely be solved offline first. - self.wallet.create().register( subtensor = self.subtensor ) + self.wallet.reregister( subtensor = self.subtensor ) + # === UID === # Get our uid from the chain. @@ -198,17 +261,6 @@ def __enter__(self): root_dir = self.config.neuron.full_path ) - def forward(self): - r""" Run the nucleus forward request - This function is supposed to be ran multi-threaded. - """ - result = self.nucleus( next(self.dataset) , self.metagraph, self.dendrite ) - - # === Backward === - # Backwards gradients through model to train gating and remote endpoints. 
- (result.loss / self.config.neuron.forward_num).backward() - return result - def run ( self ): r""" Run the validator and terminate on Keyboard interrupt. """ @@ -218,7 +270,6 @@ def run ( self ): # === Start forward requests === self.metagraph_sync() - self.forward_thread_queue.start() # === Run === # Iterates through epochs. @@ -252,18 +303,23 @@ def run_epoch( self ): current_block = self.subtensor.block batch_size = self.subtensor.validator_batch_size sequence_length = self.subtensor.validator_sequence_length + validation_len = self.config.neuron.validation_len # Number of tokens to holdout for phrase validation beyond sequence context n_topk_peer_weights = self.subtensor.min_allowed_weights max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset + + # === Update dataset size === + if (batch_size != self.dataset.batch_size) or (sequence_length + validation_len != self.dataset.block_size): + self.dataset.set_data_size(batch_size, sequence_length + validation_len) + # === Logs === - print ( '\nEra:', '\n\t batch_size:', batch_size, '\n\t sequence_length:', sequence_length, '\n\t n_topk_peer_weights:', n_topk_peer_weights, - '\n\t max_allowed_ratio:', max_allowed_ratio, '\n\t blocks_per_epoch:', blocks_per_epoch, '\n\t epochs_until_reset:', epochs_until_reset, - '\n\t until_reset:', self.epoch % epochs_until_reset, '\n\t current_block:', current_block, '\n') if self.config.using_wandb: - wandb.log( { 'era/batch_size': batch_size, 'era/sequence_length': sequence_length, 'era/n_topk_peer_weights': n_topk_peer_weights, - 'era/max_allowed_ratio': max_allowed_ratio, 'era/blocks_per_epoch': blocks_per_epoch, 'era/epochs_until_reset': epochs_until_reset, - }, step = current_block ) + wandb.log({'era/batch_size': batch_size, 'era/sequence_length': sequence_length, + 'era/validation_len': validation_len, + 'era/n_topk_peer_weights': n_topk_peer_weights, 'era/max_allowed_ratio': max_allowed_ratio, + 'era/blocks_per_epoch': blocks_per_epoch, 'era/epochs_until_reset': epochs_until_reset}, + step=current_block) # === Run Epoch === # Each block length lasts blocks_per_epoch blocks. @@ -272,27 +328,6 @@ def run_epoch( self ): self.metagraph_sync() # Reset metagraph. epoch_steps = 0 - # === Reset Epochs with new params. === - # Pulls new default validator training parameters and resets - # the model and dataset for the following epoch. 
- if self.epoch % epochs_until_reset == 0: - print ('\n\n=== Reset ===\n\n') - # === Resetting model + dataset === - if (batch_size != self.dataset.batch_size) or (sequence_length != self.dataset.block_size): - self.dataset.set_data_size(batch_size, sequence_length) - - self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device ) - self.optimizer = torch.optim.SGD ( - self.nucleus.parameters(), lr = self.config.neuron.learning_rate, momentum = self.config.neuron.momentum - ) - - # === Reset Scores === - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - - # Checks if moving avg has been initiated - if self.moving_avg_scores == None: - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - start_block = self.subtensor.block while self.subtensor.block < start_block + blocks_per_epoch: start_time = time.time() @@ -300,56 +335,110 @@ def run_epoch( self ): # === Forward === # Forwards inputs through the network and returns the loss # and endpoint scores using shapely approximation of salience. - forward_results = self.forward_thread_queue.get() - print(f'Run\t| Got forward result in {round(time.time() - start_time, 3)}') - loss, scores, uids = self.nucleus.compute_shapely_scores(forward_results) - # === Scoring === + loss, stats = self.nucleus( next(self.dataset) , self.metagraph, self.dendrite ) + + # === Backward === + # Backwards gradients through model to train gating and remote endpoints. + if hasattr(loss, 'grad_fn') and loss.grad_fn is not None: + logger.info(f'Backward (loss: {loss:.3f})') + start_time = time.time() + (loss / self.config.neuron.forward_num).backward() + logger.info(f'Backward [{time.time() - start_time:.3g}s]') + + # === Stats update === # Updates moving averages and history. - self.moving_avg_scores[uids] = self.moving_avg_scores[uids]*(0.99) + scores*(0.01) - + responsive_uids, queried_uids = self.neuron_stats_update(stats) + # === State update === # Prints step logs to screen. 
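neuron_stats_update (shown later in this file) folds each new per-step measurement into an exponential moving average with coefficient alpha = 0.05; non-responsive UIDs instead get a zero pushed into the '!'-suffixed copy of the key. The core recurrence, shown in isolation with illustrative values (not code from this patch):

```python
alpha = 0.05                                   # EMA coefficient from __init__
ema = 120.0                                    # previous moving average, e.g. shapley_values_nxt!
observation = 95.0                             # new measurement for a responsive UID

ema_responsive = (1 - alpha) * ema + alpha * observation   # normal update
ema_zeroed = (1 - alpha) * ema                             # zeroing update when the UID did not respond
print(f'{ema_responsive:.2f} {ema_zeroed:.2f}')            # 118.75 114.00
```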
epoch_steps += 1 self.global_step += 1 current_block = self.subtensor.block step_time = time.time() - start_time + + if epoch_steps % 25 == 1: + # validator identifier status console message (every 25 validation steps) + print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | " + f"{f'[bright_white]core_validator[/bright_white]'.center(16 + len('[bright_white][/bright_white]'))} | " + f"UID [cyan]{self.uid}[/cyan] " + f"[dim white not bold][{self.dendrite.receptor_pool.external_ip}][/dim white not bold] " + f"[white not bold]cold:[bold]{self.wallet.name}[/bold]:" + f"[bright_white not bold]{self.wallet.coldkeypub.ss58_address}[/bright_white not bold] " + f"[dim white]/[/dim white] " + f"hot:[bold]{self.config.wallet.hotkey}[/bold]:" + f"[bright_white not bold]{self.wallet.hotkey.ss58_address}[/bright_white not bold][/white not bold]") + + # validator update status console message + print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | " + f"{f'UID [bright_cyan]{self.uid}[/bright_cyan]'.center(16 + len('[bright_cyan][/bright_cyan]'))} | " + f'Updated [yellow]{current_block - self.metagraph.last_update[self.uid]}[/yellow] [dim]blocks ago[/dim] | ' + f'Dividends [green not bold]{self.metagraph.dividends[self.uid]:.5f}[/green not bold] | ' + f'Stake \u03C4[magenta not bold]{self.metagraph.stake[self.uid]:.5f}[/magenta not bold] ' + f'[dim](retrieved [yellow]{current_block - start_block}[/yellow] blocks ago from {self.subtensor.network})[/dim]') + + # step update console message (every validation step) + print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | " + f"{f'[magenta dim not bold]#{current_block}[/magenta dim not bold]'.center(16 + len('[magenta dim not bold][/magenta dim not bold]'))} | " + f'[green not bold]{current_block - start_block}[/green not bold]/' + f'[white not bold]{blocks_per_epoch}[/white not bold] [dim]blocks/epoch[/dim] | ' + f'[white not bold]Step {epoch_steps}[white not bold] ' + f'[dim] Epoch {self.epoch}[/dim] | ' + f'[bright_green not bold]{len(responsive_uids)}[/bright_green not bold]/' + f'[white]{len(queried_uids)}[/white] ' + f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] ' + f'[[yellow]{step_time:.3g}[/yellow]s]') + + if self.config.logging.debug or self.config.logging.trace: + # === Print stats update (table) === + # Prints exponential moving average statistics of valid neurons from latest validator forward + stats_table({uid: self.neuron_stats[uid] + for uid, stat in stats.items() if len(set(stat.keys()) & set(self.synapse_keys))}, + self.weight_key, self.config.get('width', None), + f'[white] Stats update [/white] | ' + str(self), # title + f'#{current_block}: ' + f'[bold]{current_block - start_block}[/bold]/{blocks_per_epoch} (blocks/epoch) | ' + f'Epoch {self.epoch} | ' + f'[white] Step {epoch_steps} ({self.global_step} global) \[{step_time:.3g}s] [/white]') # caption + + # === Calculate neuron weights === + topk_uids, topk_weights = self.calculate_weights() + self.weights_table(topk_uids, topk_weights, + include_uids=list(stats.keys()), num_rows=2 * len(stats)) # print weights table # === Logs === - print( '\nStep:', '\n\t epoch:', self.epoch, '\n\t epoch_steps:', epoch_steps, '\n\t global_steps:', self.global_step, '\n\t step_time:', step_time, '\n\t loss:', loss.item(), - '\n\t current_block', current_block, '\n\t blocks remaining:', current_block - start_block, '/', blocks_per_epoch, '\n') if 
self.config.using_wandb: - wandb.log( { 'epoch/epoch': self.epoch, 'epoch/epoch_steps': epoch_steps, 'epoch/global_steps': self.global_step, 'epoch/loss': loss.item(), 'epoch/time': step_time }, step = current_block ) - step_topk_scores, step_topk_uids = bittensor.unbiased_topk( self.moving_avg_scores, k = n_topk_peer_weights ) - step_topk_normalized = bittensor.utils.weight_utils.normalize_max_multiple( x = step_topk_scores, multiple = max_allowed_ratio ) - for i, w in list(zip(step_topk_uids.tolist(), step_topk_normalized.tolist()) ): - wandb.log( {'weights/w_{}'.format( i ): w }, step = current_block ) + for uid, vals in self.neuron_stats.items(): + for key in vals: # detailed neuron evaluation fields, e.g. loss, shapley_values, synergy + wandb.log({f'stats/{key}_{uid}': vals[key]}, step=current_block, commit=False) + + wandb.log({'epoch/epoch': self.epoch, 'epoch/epoch_steps': epoch_steps, + 'epoch/global_steps': self.global_step, 'epoch/loss': loss.item(), + 'epoch/time': step_time}, step=current_block, commit=True) # Do the backward request after the a queue of forward requests got finished. - if self.forward_thread_queue.paused() and self.forward_thread_queue.is_empty(): - print('Run\t| Model update') + if epoch_steps % self.config.neuron.forward_num == 1: + start_time = time.time() + logger.info('Model update \t| Optimizer step') # === Apply gradients === # Applies local gradients to parameters. clip_grad_norm_(self.nucleus.parameters(), self.config.neuron.clip_gradients) self.optimizer.step() - self.optimizer.zero_grad() + self.optimizer.zero_grad() + logger.info(f'Model update \t| Optimizer step [{time.time() - start_time:.3g}s]') - # === Get another round of forward requests === - self.forward_thread_queue.resume() - # Iterate epochs. self.epoch += 1 - # === Set weights === - # Find the n_topk_peer_weights peers to set weights to. - # We use the mean of the epoch weights. - topk_scores, topk_uids = bittensor.unbiased_topk(self.moving_avg_scores, k = n_topk_peer_weights ) - topk_scores = bittensor.utils.weight_utils.normalize_max_multiple( x = topk_scores, multiple = max_allowed_ratio ) - print( '\nScores:', '\n\t weights:', topk_scores.sort()[0].tolist(), '\n\t sum:', topk_scores.sum().item(), - '\n\t min:', topk_scores.min().item(), '\n\t max:', topk_scores.max().item(), '\n\t max/min:', (topk_scores.max()/topk_scores.min()).item() ) + # === Calculate neuron weights === + topk_uids, topk_weights = self.calculate_weights() + + if self.config.logging.debug or self.config.logging.trace: + self.weights_table(topk_uids, topk_weights) # print weights table + self.subtensor.set_weights( uids = topk_uids.detach().to('cpu'), - weights = topk_scores.detach().to('cpu'), + weights = topk_weights.detach().to('cpu'), wallet = self.wallet, wait_for_finalization = self.config.neuron.wait_for_finalization, ) @@ -359,65 +448,131 @@ def run_epoch( self ): if self.config.using_wandb: # Logging history to wandb. 
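calculate_weights above selects the top-k moving-average stats and then bounds the max/min ratio via bittensor.utils.weight_utils.normalize_max_multiple before calling set_weights. A simplified stand-in for that shaping step (the clamp below is illustrative only, not the library implementation):

```python
import torch

scores = torch.tensor([0.0, 0.7, 0.1, 0.4, 0.9, 0.2])   # per-UID moving-average stats
k, max_ratio = 4, 10.0                                    # min_allowed_weights, max_allowed_min_max_ratio

topk_weights, topk_uids = torch.topk(scores, k=k)
topk_weights = topk_weights.clamp(min=topk_weights.max().item() / max_ratio)  # enforce max/min <= max_ratio
topk_weights = topk_weights / topk_weights.sum()                              # normalize before set_weights
print(topk_uids.tolist(), [round(w, 3) for w in topk_weights.tolist()])
```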
df = pandas.concat( [ - bittensor.utils.indexed_values_to_dataframe( prefix = 'weights', index = topk_uids, values = torch.zeros( self.metagraph.n ).scatter( dim = 0, src = topk_scores, index = topk_uids ) ), + bittensor.utils.indexed_values_to_dataframe( prefix = 'weights', index = topk_uids, values = torch.zeros( self.metagraph.n ).scatter( dim = 0, src = topk_weights, index = topk_uids ) ), self.dendrite.to_dataframe( metagraph = self.metagraph ) ], axis = 1); df['uid'] = df.index wandb_data_dend = self.dendrite.to_wandb() + wandb_weight = {f'stats/weight_{uid}': weight for uid, weight in zip (topk_uids, topk_weights)} wandb_data = { 'stake': self.metagraph.S[ self.uid ].item(), 'dividends': self.metagraph.D[ self.uid ].item() } - wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block ) - wandb.log( { **wandb_data, **wandb_data_dend }, step = current_block ) - + wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block, commit=False) + wandb.log( { **wandb_data, **wandb_data_dend, **wandb_weight }, step = current_block, commit=True) + def metagraph_sync(self): r""" Syncing metagraph together with other metagraph-size related objects """ old_hotkeys = self.metagraph.hotkeys self.metagraph.sync() - - # === Create if None - if self.moving_avg_scores == None: - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - # === Match size for the moving average score - if self.metagraph.n > len(self.moving_avg_scores): - size_incerease = self.metagraph.n - len(self.moving_avg_scores) - self.moving_avg_scores = torch.concat([self.moving_avg_scores, torch.ones(size_incerease) * -1]) - - # === Reset moving average score if uid got replaced + # === Reset neuron stats if uid got replaced for uid, old_hotkey in enumerate(old_hotkeys): if old_hotkey != self.metagraph.hotkeys[uid]: - self.moving_avg_scores[uid] = -1 -class PositionalEncoding(nn.Module): - r""" Positional Encoder which adds information based on the relative position of each token - - """ - def __init__(self, d_model: int, dropout: float, max_len: int = 5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - - # === Create position matrix === - # Creates a positional matrix with alternating frequencies - # pe: (torch.FloatTensor) positional encoding matrix - # pe.shape: [1, max_len, network_dim] - pe = torch.zeros(1, max_len, d_model) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, : , 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward(self, x: torch.tensor) -> torch.tensor: + if uid in self.neuron_stats: + del self.neuron_stats[uid] + + def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]): + r""" Updates self.neuron_stats with new individual dictionaries per uid. """ - Args: - x: Tensor, shape [batch_size, seq_len, embedding_dim] + responsive_uids = [] + for _uid, _stats in neuron_stats.items(): + stats = self.neuron_stats.setdefault(_uid, {}) + + # === EMA zeroing update === + # Push zero into EMA for synapse_keys to exponentially decay weighting keys if neuron non-responsive + if 'updates!' in stats: + stats['updates!'] += 1 # increment number of EMA zeroing updates + else: + stats.setdefault('updates!', 1) # number of EMA zeroing updates init to zero + + for key in self.synapse_keys: + zkey = key + '!' # zeroing key + stats.setdefault(zkey, 0.) 
# initialize zkey val to zero to gradually increase with observations + if key in _stats and not math.isnan(_stats[key]): + stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * _stats[key] + else: + stats[zkey] = (1 - self.alpha) * stats[zkey] # + self.alpha * 0 + + # === EMA normal update === + # If synapse responsive push available values into EMA for normal update. + # Normal EMA values provide a view on neuron performance if fully responsive. + for key in self.synapse_keys: + if key in _stats: + updates = 'updates_' + key + if updates in stats: + stats[updates] += 1 # increment number of normal EMA updates made + responsive_uids += [_uid] + else: + stats.setdefault(updates, 1) # add updates fields for new uid entries + + for key in _stats: # detailed neuron evaluation fields, e.g. loss, shapley_values, synergy + if math.isnan(_stats[key]): + continue + if key in stats: + stats[key] = (1 - self.alpha) * stats[key] + self.alpha * _stats[key] # update EMA + else: + stats.setdefault(key, _stats[key]) + + return responsive_uids, list(neuron_stats.keys()) # responsive_uids, queried_uids + + def calculate_weights(self): + r""" Calculates neuron set-weights from weight_key mapped values. Defines weight_key as the neuron stats key + used to obtain the mapped stat value (typically a Shapley value) that the final set-weights are calculated from. """ - # === Positional Encoding === - # Inject some information of the relative position of the token in the sequence. - # Finally, Dropout is applied to tokens - # x: (torch.FloatTensor) input sequence tokens with position information injected - # x.shape: [batch_size, seq_len, network_dim] - x = x + self.pe[0, :x.size(1)] - return self.dropout(x) + + weight_key = self.weight_key + '!' # use zeroing key to penalize non-responsive neurons + n_topk_peer_weights = self.subtensor.min_allowed_weights + max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio + + # === Calculate neuron weights === + neuron_weights = torch.zeros_like(self.metagraph.S) # allow unevaluated UIDs to be selected to meet min topk + + for uid in self.neuron_stats: + if weight_key in self.neuron_stats[uid]: + neuron_weights[uid] = torch.tensor([self.neuron_stats[uid][weight_key]]) + + # Find the n_topk_peer_weights peers to set weights to. + topk_weights, topk_uids = bittensor.unbiased_topk(neuron_weights, k=n_topk_peer_weights) + topk_weights = bittensor.utils.weight_utils.normalize_max_multiple(x=topk_weights, + multiple=max_allowed_ratio) + return topk_uids, topk_weights + + def weights_table(self, topk_uids, topk_weights, include_uids=None, num_rows: int = None): + r""" Prints weights table given topk_uids and topk_weights. 
+ """ + n_topk_peer_weights = self.subtensor.min_allowed_weights + max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio + + # === Weight table === + # Prints exponential moving average statistics of valid neurons and latest weights + _neuron_stats = {} + unvalidated = [] + for uid, weight in zip(topk_uids.tolist(), topk_weights.tolist()): + if uid in self.neuron_stats: + _neuron_stats[uid] = {k: v for k, v in self.neuron_stats[uid].items()} + _neuron_stats[uid]['weight'] = weight + else: + unvalidated += [uid] + + avail_include_uids = None + if include_uids is not None and num_rows is not None: + avail_include_uids = list(set(_neuron_stats.keys()) & set(include_uids)) # exclude include_uids with no stats + if len(_neuron_stats) > num_rows: # limit table to included_uids and remaining topk up to num_rows + remaining_uids = set(_neuron_stats.keys()) - set(include_uids) # find topk remaining, loses topk ordering + remaining_uids = [uid for uid in _neuron_stats if uid in remaining_uids] # recover topk ordering + limited_uids = avail_include_uids + remaining_uids[:num_rows - len(include_uids)] + _neuron_stats = {uid: stats for uid, stats in _neuron_stats.items() if uid in limited_uids} + + print() + stats_table(_neuron_stats, 'weight', self.config.get('width', None), + f'[white] Neuron weights [/white] | ' + str(self), # title + f'Validated {n_topk_peer_weights}/' + f'[bold]{len(self.neuron_stats)}[/bold]/{self.metagraph.n} (min/[bold]valid[/bold]/total) | ' + f'sum:{topk_weights.sum().item():.2g} ' + f'[white] max:[bold]{topk_weights.max().item():.4g}[/bold] / ' + f'min:[bold]{topk_weights.min().item():.4g}[/bold] [/white] ' + f'\[{topk_weights.max().item() / topk_weights.min().item():.1f}:1] ' + f'({max_allowed_ratio} allowed)', # caption + mark_uids=avail_include_uids) + class nucleus( torch.nn.Module ): """ Nucleus class which holds the validator model. @@ -426,7 +581,10 @@ def __init__( self, config, device, subtensor ): super(nucleus, self).__init__() self.config = config self.device = device - self.max_n = subtensor.max_n + self.max_n = subtensor.max_n + + tokenizer = bittensor.tokenizer() + self.pad_token = tokenizer(tokenizer.pad_token)['input_ids'][0] # Token embeddings project int64 tokens onto representations. self.token_embedding = torch.nn.Embedding( bittensor.__vocab_size__, bittensor.__network_dim__ ) @@ -450,6 +608,9 @@ def __init__( self, config, device, subtensor ): # SGMOE Gates: Instantiating the gates per expert. self.gates = torch.nn.Linear( bittensor.__network_dim__, self.max_n, bias=True ).to( self.device ) + + self.sigmoid = torch.nn.Sigmoid() + self.reset_weights() @classmethod @@ -461,6 +622,8 @@ def add_args( cls, parser ): parser.add_argument('--nucleus.dropout', type=float, help='the dropout value', default=0.2) parser.add_argument('--nucleus.importance', type=float, help='hyperparameter for the importance loss', default=3) parser.add_argument('--nucleus.noise_multiplier', type=float, help='Standard deviation multipler on weights', default=2 ) + parser.add_argument('--nucleus.dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False ) + parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 
3 → 6 nats for 0.5.', default=0.5) @classmethod def config ( cls ): @@ -487,31 +650,15 @@ def init_xavier( component ): self.encoder.apply( init_xavier ) torch.nn.init.xavier_uniform_( self.gates.weight ) - # === Compute loss given joined responses === - # This function computes target loss for next token prediction given - # the joined responses as a hidden unit input. - # target_loss: (torch.float64): loss after decoding responses to targets. - # target_loss.shape = [ 1 ] - def get_target_loss ( self, hidden, targets ): - # hidden: (torch.float64): [ batch_size, sequence_len, __network_dim__ ] - # Hidden units which are encoded and decoded onto targets for loss computation. - # targets: (torch.float64): [n] - # Token targets, - src_mask = torch.triu(torch.ones(hidden.size(1), hidden.size(1)) * float('-inf'), diagonal=1) - src_mask = src_mask.to(self.config.neuron.device) - encoded_hidden = self.encoder( hidden, mask = src_mask ) - decoded_targets = self.decoder( encoded_hidden ) - shift_logits = decoded_targets[..., :-1, :].contiguous() - shift_labels = targets[..., 1:].contiguous() - return self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - - def forward ( - self, - inputs: torch.FloatTensor, - metagraph: 'bittensor.Metagraph', - dendrite: 'bittensor.Dendrite', + def forward( + self, + inputs: torch.FloatTensor, + metagraph: 'bittensor.Metagraph', + dendrite: 'bittensor.Dendrite', ): - r""" Forward validator pass. Selects peer to query, joins results and computes scoring. + r""" + Forward validator pass. Selects endpoints to query and validate, calculates routing_score and Shapley values + for validated synapses. Args: inputs (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, *-1*)`, `required`): Tensor inputs to distribute to neurons using query context. @@ -520,29 +667,35 @@ def forward ( dendrite (bittensor.Dendrite): Dendrite RPC client used to make network queries. Returns: - global_loss (torch.FloatTensor, [1] ): - Loss for training validator nucleus. - scores (torch.FloatTensor, [ metagraph.n ]): - Scores per endpoint for this batch. - """ + loss (:obj:`torch.FloatTensor`): + Loss for training validator nucleus and dendrite backward to endpoints. + neuron_stats (:obj:`Dict`, `required`): + Statistics per endpoint for this batch. + """ + start_time = time.time() + + val_len = self.config.neuron.validation_len # Number of tokens to holdout for phrase validation beyond sequence context + inputs = inputs.to(self.device) + inputs_seq = inputs[..., :-val_len] # input sequence without last validation tokens [batch_size, sequence_len] + # === Create the local context used to select endpoints === # The context tensor returns a hidden unit representation for the text inputs # this context can be used as input to the gates in the next step. # embedding: retrieve learned representation vectors for input vocabulary tokens. # inputs.shape = [batch_size, sequence_len] # embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] - embedding = self.token_embedding( inputs )* math.sqrt( bittensor.__network_dim__ ) - + embedding = self.token_embedding(inputs_seq) * math.sqrt(bittensor.__network_dim__) + # === Create an attention mask === # The attention mask will mask out parts of the context # This prevents cheating and forward-looking when predicting each token in the sequence. 
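A minimal standalone sketch of the additive causal mask built here, using the same `torch.triu` pattern as the surrounding code; the sequence length is arbitrary and chosen only for illustration:

```python
import torch

seq_len = 4

# Additive causal mask: -inf strictly above the diagonal, 0 on and below it,
# so position i can only attend to positions j <= i.
src_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
print(src_mask)
# Roughly:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]
```

The nucleus passes a mask of this form to its transformer encoders through the `mask=` argument, which adds it to the attention logits so that softmax gives zero weight to future positions.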
# src_mask: (torch.FloatTensor) attention mask adds -inf to positions not allowed to attend # src_mask.shape = [sequence_len, sequence_len] src_mask = torch.triu(torch.ones(embedding.size(1), embedding.size(1)) * float('-inf'), diagonal=1) - src_mask = src_mask.to(self.config.neuron.device) + src_mask = src_mask.to(self.device) # === Apply the positional encoding to help select endpoints === - # The positional encoder provides information based on the relative postion of each token + # The positional encoder provides information based on the relative postion of each token # embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] # pos_embedding: (torch.FloatTensor) positional encoded embedding. # pos_embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] @@ -550,120 +703,560 @@ def forward ( # routing_context: (torch.FloatTensor): context tensor which is used to select endpoints. # routing_context.shape = [ batch size, __network_dim__ ] - routing_context = self.routing_encoder( pos_embedding, mask = src_mask ) + routing_context = self.routing_encoder(pos_embedding, mask=src_mask) - # === Get weights for uids. === - # We iterate over each of the network uids and compute a querying score for each + # === Get gate values for UIDs. === + # We iterate over each of the network UIDs and compute a querying score for each # using the gating function. This returns a score per endpoint per example. - # routing_weights: (torch.FloatTensor): score per example, per endpoint. - # routing_weights.shape = [ batch size, __network_n__ ] + # routing_score: (torch.FloatTensor): score per example, per endpoint. + # routing_score.shape = [metagraph.n] # The gates act over the last embedding of the routing_context. - routing_weights = self.gates( routing_context[:,-1,:] ) - - # === Normalize routing_weights across batch dimension and add noise. === - # We are summing across the batch dimension to create a per-batch score per endpoint. - # The resulting routing_weights tensor is a score per expert. - # routing_weights: (torch.FloatTensor): normalized weights across batch dimension with noise. - # routing_weights.shape = [ n_filtered ] - batchwise_routing_weights = torch.mean(routing_weights, axis = 0)[:metagraph.n] - noisy_routing_weights = torch.normal( 0, torch.std(batchwise_routing_weights).item(), size=( batchwise_routing_weights.size())).to( self.config.neuron.device ) - noisy_routing_weights = batchwise_routing_weights + noisy_routing_weights * self.config.nucleus.noise_multiplier - + routing_score = torch.mean(self.sigmoid(self.gates(routing_context[:, -1, :])), dim=0) + + # Ensure number of queried neurons does not exceed metagraph.n + num_endpoints = min([self.config.nucleus.topk, metagraph.n]) - # === Get indices and values for uids with highest scores === - # We are taking the topk routing weights and returning their uids. - # First we ensure topk is smaller than the network size then use the torch.topk. - # topk_routing_weights: (torch.float64): scores of uids with highest scores. - # topk_routing_weights.shape = [ self.config.nucleus.topk ] - # topk_routing_uids: (torch.LongTensor): uids with highest scores. 
- # topk_routing_uids.shape = [ self.config.nucleus.topk ] - top_k_routing_weights, routing_uids = torch.topk( noisy_routing_weights, self.config.nucleus.topk, dim=0) + logger.info(f'Forward \t| Routing forward [{time.time() - start_time:.3g}s]') + logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}') + request_start_time = time.time() - # === Get endpoint information for the highest scoring uids === + # === Randomly select num_endpoints UIDs === + random_uids = torch.randperm(metagraph.n)[:num_endpoints] + + # === Get endpoint information for the selected UIDs === # We index into the metagraph's endpoints and return a list of the filtered set of endpoints we wish to query. - # routing_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids. + # random_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids. # len(neurons) == self.config.nucleus.topk - routing_endpoints = [ metagraph.endpoints[ uid ] for uid in routing_uids ] + random_endpoints = [metagraph.endpoints[uid] for uid in random_uids] + + # === Define which synapse we want to use === + # The synapse defines the task we are sending to the neurons + # synapses: List[bittensor.synapse]: synapse information + # TODO: WORK IN PROGRESS, prototype + if self.config.neuron.validation_synapse == 'TextCausalLMNext': + synapses = [(bittensor.synapse.TextCausalLMNext(), textcausallmnext)] + else: + synapses = [(bittensor.synapse.TextCausalLM(), textcausallm)] # === Query the endpoints === - # Makes the dendrite call into the network returning the representations + # Makes the dendrite call into the network returning the representations # for each of the endpoints. The return ops can be used to filter weights and outputs. # query_responses: (List[torch.float64]): responses from each endpoint. - # query_responses.shape = self.config.nucleus.topk * [ batch_size, sequence_len, __network_dim__ ] + # query_responses.shape = self.config.nucleus.topk * num_synapses * [batch_size, sequence_len, synapse_dim] # return_ops: (torch.int64): Return ops. - # return_ops.shape = [ self.config.nucleus.topk ] - query_responses, return_ops, times = dendrite.forward_text ( - endpoints = routing_endpoints, - inputs = inputs + # return_ops.shape = self.config.nucleus.topk * [num_synapses] + query_responses, return_ops, times = dendrite.text( + endpoints=random_endpoints, + inputs=inputs_seq, + synapses=[syn for syn, _ in synapses], + timeout=bittensor.__blocktime__ ) + + if not self.config.nucleus.dendrite_backward: + query_responses = [[syn.detach().to(self.device) for syn in res] for res in query_responses] + return_ops = [ops.detach().to(self.device) for ops in return_ops] + times = [t.detach().to(self.device) for t in times] + # Send responses to device. This is required to ensure we move the responses # Onto the correct device. - for response in query_responses: - response.to( self.device ) - - # === Compute global loss === - # Computes the global training loss for the nucleus by decoding all the responses - # onto the targets. - # target_loss: (torch.float64): loss after decoding all responses and a variance loss. 
- # target_loss.shape = [ 1 ] - responses_hidden, _ = joining_context( return_ops, batchwise_routing_weights[routing_uids], query_responses) - target_loss = self.get_target_loss ( responses_hidden, inputs ) - print ('Loss\t|\t{}'.format( target_loss.item() )) - - # === Compute Importance loss === - # Computes the importance loss based on the stardard error of batchwise_routing_weights - # This ensures that gates do not converge onto a few experts - # importance_loss: (torch.float64) the importance loss based on the stardard error - # target_loss: (torch.float64): the total loss (global training loss + importance loss) - # target_loss.shape = [ 1 ] - importance_loss = self.config.nucleus.importance * (torch.std(batchwise_routing_weights)/torch.mean(batchwise_routing_weights))**2 - loss = target_loss + importance_loss - - state_dict = SimpleNamespace( - inputs = inputs, - batchwise_routing_weights = batchwise_routing_weights, - routing_uids = routing_uids, - query_responses = query_responses, - return_ops = return_ops, - responses_hidden = responses_hidden, - loss = loss, - n = metagraph.n.item() - ) - - return state_dict + for responses in query_responses: + for response in responses: + response.to(self.device) - def compute_shapely_scores(self, state_dict): - - # === Compute shapely scores === - # Computes shapely scores for each endpoint by masking the response and - # computing the change in loss induced. - # shapely_scores: (torch.float32): shapely scores per query_response - # shapely_scores.shape = [ metagraph.n ] - masked_contexts = partial_contexts( - state_dict.return_ops, - state_dict.routing_uids, - state_dict.batchwise_routing_weights[state_dict.routing_uids], - state_dict.query_responses - ) - # Turn off gradient computation for shapely scores. - # shapely_scores.shape = [ nucleus.topk ] - # This sets non queried peers as if non-responsive - shapely_scores = torch.zeros( state_dict.routing_uids.size()) - # Turn off gradient computation for shapely scores. - with torch.no_grad(): - self.eval() - - unmasked_loss = self.get_target_loss(state_dict.responses_hidden, state_dict.inputs) - # Iterate over all responses creating a masked context. - for i, uid in enumerate(masked_contexts): - # Create mask by zeroing out the response at index. 
- masked_loss = self.get_target_loss ( masked_contexts[uid], state_dict.inputs ) - shapely_score = unmasked_loss - masked_loss - print ('Shapely\t|\tuid: {}\tweight: {}\tscore: {}\tcode: {}\tsum: {}'.format( uid, state_dict.batchwise_routing_weights[state_dict.routing_uids][i], -shapely_score.item(), state_dict.return_ops[i], state_dict.query_responses[i].sum())) - shapely_scores[ i ] = -shapely_score if not torch.abs(1 - state_dict.query_responses[i].std()).item() < 0.05 else -1 - - # Ensures that the nonresponsive peers are not rewarded - shapely_scores[state_dict.return_ops != 1 ] = -1 - - # === Done === - return state_dict.loss, shapely_scores, state_dict.routing_uids + logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)} ' + f'[{time.time() - request_start_time:.3g}s]') + + # === Prepare validation parameter set === + console_width = self.config.get('width', None) # console width for rich table displays of synapse measures + validation_params = (random_uids, query_responses, return_ops, times, routing_score, + inputs, val_len, self.loss_fct, self.config.nucleus.scaling_law_power, console_width, + self.config.logging.debug or self.config.logging.trace) + + loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses + neuron_stats = {} # to gather neuron synapse validation measures and statistics + + # === Validate synapse responses === + # Iterate over all queried synapses and validate responses + for i, (synapse, validate_func) in enumerate(synapses): + _loss, stats = validate_func(*validation_params, synapse=synapse, index_s=i) # validate individual synapse + loss += _loss # add neuron_loss and routing_loss + + for _uid, _stats in stats.items(): + neuron_stats.setdefault(_uid, {}) + neuron_stats[_uid].update(_stats) # gather neuron synapse validation measures and statistics + + return loss, neuron_stats + + +def scaling_law_loss_to_params(loss): + r""" (OpenAI scaling laws) Kaplan, Jared, et al. "Scaling laws for neural language models." arXiv:2001.08361 (2020) + """ + num_params = torch.exp(torch.log(torch.tensor(8.8e13).to(loss.device)) - + torch.log(torch.clamp(loss, 1.69)) / 0.076) # loss lower bound 1.69 is entropy of natural text + return num_params + + +def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor], + times: List[torch.FloatTensor], routing_score: torch.FloatTensor, + inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float, + console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0 + ) -> Tuple[torch.FloatTensor, Dict]: + r""" + Calculate Shapley values and neuron response validation measure statistics, given TextCausalLM synapse responses. + Args: + uids (:obj:`torch.Tensor`, `required`): [num_neurons] + Neuron UIDs. + query_responses (:obj:`List[List[torch.FloatTensor]]`, `required`): + List of outputs from synapses, each a list of size num_endpoints of tensors with relevant size. Non-responses are zeroes of relevant + synapse shape. Shape num_synapses * ( num_endpoints * ( -1, -1, -1 ) ) + return_ops (:obj:`List[torch.LongTensor]` of shape :obj:`[num_endpoints]`, `required`): + Return code per call per synapse. + times (:obj:`List [torch.FloatTensor]` of shape :obj:`[num_endpoints]`, `required`): + Times per call per synapse. + routing_score (:obj:`torch.FloatTensor`, `required`): + [metagraph.n] Predictive routing score per endpoint in the metagraph, mean over the batch. 
+ inputs (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len + validation_len] Token batch of original inputs with validation tokens. + validation_len (:obj:`int`, `required`): + Number of held-out phrase token batch for extended validation, not sent to neurons. + loss_fct (:obj:`Callable`, `required`): + CrossEntropy loss function to use. + scaling_law_power (:obj:`float`, `required`): + Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. + console_width (:obj:`int`, `required`): + Config console width for table print. + logging (:obj:`bool`, `required`): + Log tables to console. + synapse (:obj:`bittensor.TextCausalLM`, `optional`): + TextCausalLM synapse object. + index_s (:obj:`int`, `optional`): + Index of synapse to extract responses. + + Returns: + loss (:obj:`torch.FloatTensor`): + Loss for training validator nucleus and dendrite backward to endpoints. + stats (:obj:`Dict`, `required`): + Statistics per endpoint for this batch. + """ + + inputs_seq = inputs[..., :-validation_len] # input sequence without last token [batch_size, sequence_len] + inputs_val = inputs[..., -validation_len] # input validation with next token [batch_size] + + def _base_params(_stats, query_response): + _stats.update({'logits': query_response[:, :-1, :], + 'logits_val': query_response[:, -1:, :]}) + + for target, _ext in [(inputs_seq[:, 1:], ''), (inputs_val, '_val')]: + _loss = calc_loss_fct(loss_fct, _stats['logits' + _ext], target) # CausalLM loss + if _loss.isnan() or _loss.isinf(): + _loss = 20 # assign large loss + + # estimate the effective number of model parameters, modified with the scaling_law_power + _num_params = scaling_law_loss_to_params(_loss) + + # powered down number of params, e.g. dynamic range 3 → 6 nats for scaling_law_power=0.5 + _pow_num_params = torch.pow(_num_params, scaling_law_power) + + _stats.update({'loss' + _ext: _loss, + 'est_params' + _ext: _num_params, 'base_params' + _ext: _pow_num_params, + 'synergy' + _ext: 0, 'synergy_loss_diff' + _ext: 0}) + + def _synergy(first, second, target, _ext): + # Combined logits: log of average probabilities per token between responses + combined_logits = torch.log((torch.softmax(first['logits' + _ext], dim=-1) + + torch.softmax(second['logits' + _ext], dim=-1)) / 2 + 1e-40) + measured_loss = calc_loss_fct(loss_fct, combined_logits, target) # actual measured loss + + return measured_loss + + shapley_start_time = time.time() + + loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score, + _base_params, index_s, ext='') + + logger.info(f'{str(synapse)} \t| Shapley base values [{time.time() - shapley_start_time:.3g}s]') + + synergy_start_time = time.time() + + syn_loss_diff = shapley_synergy(stats, _synergy, ext='', target=inputs_seq[:, 1:], + scaling_law_power=scaling_law_power) + syn_loss_diff_val = shapley_synergy(stats, _synergy, ext='_val', target=inputs_val, + scaling_law_power=scaling_law_power) + + # === Shapley value combination === + # Combine base values with synergy approximation to get final Shapley values. 
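A small self-contained sketch of the scaling-law conversion introduced by this diff; the constants (8.8e13, 1.69, 0.076) mirror `scaling_law_loss_to_params` above, while the example loss values and the 0.5 exponent are illustrative only:

```python
import torch

def loss_to_params(loss: torch.Tensor) -> torch.Tensor:
    # OpenAI scaling law inverted: effective parameter count implied by a language-model loss,
    # with 1.69 nats (approx. entropy of natural text) as the lower bound on loss.
    return torch.exp(torch.log(torch.tensor(8.8e13)) - torch.log(torch.clamp(loss, 1.69)) / 0.076)

scaling_law_power = 0.5  # powered-down exponent, as in the validator config default

for loss in [torch.tensor(2.5), torch.tensor(3.5)]:
    est = loss_to_params(loss)
    base = torch.pow(est, scaling_law_power)  # 'base_params' style value used for Shapley base scores
    print(f'loss={loss.item():.2f}  est_params={est.item():.3g}  powered_params={base.item():.3g}')
```

Powering the estimate down compresses its dynamic range, which is the point of the `scaling_law_power` argument.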
+ for s in stats.values(): + for ext in ['', '_val']: + if 'base_params' + ext in s and 'synergy' + ext in s: + s['shapley_values' + ext] = (s['base_params' + ext] + s['synergy' + ext]) + + if 'logits' + ext in s: + del s['logits' + ext] # remove logits - not needed for stats anymore + + if 'shapley_values' in s and 'shapley_values_val' in s: + s['shapley_values_min'] = torch.min(s['shapley_values'], s['shapley_values_val']) + + for key in s: + if hasattr(s[key], 'item'): + s[key] = s[key].item() + + logger.info(f'{str(synapse)} \t| Shapley synergy values [{time.time() - synergy_start_time:.3g}s]') + + if logging: + # === Synergy table === + # Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal) + synergy_table(stats, syn_loss_diff, 'shapley_values_min', console_width=console_width) + + # === Neuron responses (table) === + # Prints the evaluation of the neuron responses to the validator request + synapse_table(str(synapse), stats, 'shapley_values_min', console_width, shapley_start_time) + + # === Unsuccessful responses === + # Prints the return codes and response times of unsuccessful responses + unsuccess(str(synapse), unsuccessful) + + return loss, stats + + +def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor], + times: List[torch.FloatTensor], routing_score: torch.FloatTensor, + inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float, + console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0 + ) -> Tuple[torch.FloatTensor, Dict]: + r""" + Calculate Shapley values and neuron response validation measure statistics, given TextCausalLMNext synapse responses. + Args: + uids (:obj:`torch.Tensor`, `required`): [num_neurons] + Neuron UIDs. + query_responses (:obj:`List[List[torch.FloatTensor]]`, `required`): + List of outputs from synapses, each a list of size num_endpoints of tensors with relevant size. Non-responses are zeroes of relevant + synapse shape. Shape num_synapses * ( num_endpoints * ( -1, -1, -1 ) ) + return_ops (:obj:`List[torch.LongTensor]` of shape :obj:`[num_endpoints]`, `required`): + Return code per call per synapse. + times (:obj:`List [torch.FloatTensor]` of shape :obj:`[num_endpoints]`, `required`): + Times per call per synapse. + routing_score (:obj:`torch.FloatTensor`, `required`): + [metagraph.n] Predictive routing score per endpoint in the metagraph, mean over the batch. + inputs (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len + validation_len] Token batch of original inputs with validation tokens. + validation_len (:obj:`int`, `required`): + Number of held-out phrase token batch for extended validation, not sent to neurons. + loss_fct (:obj:`Callable`, `required`): + CrossEntropy loss function to use. + scaling_law_power (:obj:`float`, `required`): + Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. + console_width (:obj:`int`, `required`): + Config console width for table print. + logging (:obj:`bool`, `required`): + Log tables to console. + synapse (:obj:`bittensor.TextCausalLMNext`, `optional`): + TextCausalLMNext Synapse object. + index_s (:obj:`int`, `optional`): + Index of synapse to extract responses. + + Returns: + loss (:obj:`torch.FloatTensor`): + Loss for training validator nucleus and dendrite backward to endpoints. + stats (:obj:`Dict`, `required`): + Statistics per endpoint for this batch. 
+ """ + + inputs_nxt = inputs[..., -validation_len:] # input validation with next token target phrase [batch_size, val_len] + + def _base_params(_stats, query_response): + # topk_tensor = unravel_topk_token_phrases(query_response, topk=synapse.topk) # [batch_size, topk + 1, max_len] + _losses_val, _losses = phrase_cross_entropy(inputs_nxt, query_response, reduce=False) + _losses_val[_losses_val.isnan()] = 20 # assign large loss + _losses[_losses.isnan()] = 20 # assign large loss + _loss_val = _losses_val.mean() + _loss = _losses.mean() + + # estimate the effective number of model parameters, modified with the scaling_law_power + _num_params = scaling_law_loss_to_params(_loss) + + # powered down number of params, e.g. dynamic range 3 → 6 nats for scaling_law_power=0.5 + _pow_num_params = torch.pow(_num_params, scaling_law_power) + + _stats.update({'loss_val_nxt': _loss_val, 'losses_nxt': _losses, 'loss_nxt': _loss, + 'est_params_nxt': _num_params, 'base_params_nxt': _pow_num_params, + 'synergy_nxt': 0, 'synergy_loss_diff_nxt': 0}) + + def _synergy(first, second, target, ext): + # average first + second probabilities per batch item, convert to loss + measured_loss = -torch.log((torch.exp(-first['losses_nxt']) + + torch.exp(-second['losses_nxt'])) / 2 + 1e-40).mean() + + return measured_loss + + shapley_start_time = time.time() + + loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score, + _base_params, index_s, ext='_nxt') + + logger.info(f'{str(synapse)} \t| Shapley base values [{time.time() - shapley_start_time:.3g}s]') + + synergy_start_time = time.time() + + syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=scaling_law_power) + + # === Shapley value combination === + # Combine base values with synergy approximation to get final Shapley values. + for s in stats.values(): + if 'base_params_nxt' in s and 'synergy_nxt' in s: + s['shapley_values_nxt'] = s['base_params_nxt'] + s['synergy_nxt'] + + if 'losses_nxt' in s: + del s['losses_nxt'] # remove batch losses - not needed for stats anymore + + for key in s: + if hasattr(s[key], 'item'): + s[key] = s[key].item() + + logger.info(f'{str(synapse)} \t| Shapley synergy values [{time.time() - synergy_start_time:.3g}s]') + + if logging: + # === Synergy table === + # Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal) + synergy_table(stats, syn_loss_diff, 'shapley_values_nxt', console_width) + + # === Neuron responses (table) === + # Prints the evaluation of the neuron responses to the validator request + synapse_table(str(synapse), stats, 'shapley_values_nxt', console_width, shapley_start_time) + + # === Unsuccessful responses === + # Prints the return codes and response times of unsuccessful responses + unsuccess(str(synapse), unsuccessful) + + return loss, stats + + +def shapley_base(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor], + times: List[torch.FloatTensor], routing_score: torch.FloatTensor, + base_params: Callable, index_s: int = 0, ext: str = None) -> Tuple[Union[float, torch.FloatTensor], + Dict, + List]: + r""" + Calculate Shapley base values and neuron response validation measure statistics, given responses from a synapse. + Args: + uids (:obj:`torch.Tensor`, `required`): [num_neurons] + Neuron UIDs. 
+ query_responses (:obj:`List[List[torch.FloatTensor]]`, `required`): + List of outputs from synapses, each a list of size num_endpoints of tensors with relevant size. Non-responses are zeroes of relevant + synapse shape. Shape num_synapses * ( num_endpoints * ( -1, -1, -1 ) ) + return_ops (:obj:`List[torch.LongTensor]` of shape :obj:`[num_endpoints]`, `required`): + Return code per call per synapse. + times (:obj:`List [torch.FloatTensor]` of shape :obj:`[num_endpoints]`, `required`): + Times per call per synapse. + routing_score (:obj:`torch.FloatTensor`, `required`): + [metagraph.n] Predictive routing score per endpoint in the metagraph, mean over the batch. + base_params (:obj:`Callable`, `required`): + CrossEntropy loss function to use. + index_s (:obj:`int`, `optional`): + Index of synapse to extract responses. + ext (:obj:`str`, `optional`): + Extension to parameter string for stats key. + + Returns: + loss (:obj:`torch.FloatTensor`): + Loss for training validator nucleus and dendrite backward to endpoints. + stats (:obj:`Dict`, `required`): + Statistics per endpoint for this batch. + unsuccessful (:obj:`List`, `required`): + Unsuccessful endpoints [(uid, return_op, time)]. + """ + stats = {} + unsuccessful = [] + neuron_loss = 0. # neuron losses to accumulate to then backward() via dendrite + routing_loss = 0. # validator routing loss for local model update + + # === Base parameter estimation === + # Shapley values - base level - coalition size 1 + # Collect successful neuron responses, calculate base Shapley values. + # Measured in effective number of model parameters, according to OpenAI scaling laws. + for index, _uid in enumerate(uids.tolist()): + if return_ops[index][index_s] == bittensor.proto.ReturnCode.Success: + _stats = {'uid': _uid, + 'response_time' + ext: times[index][index_s], + 'routing_score': routing_score[_uid]} + + try: + base_params(_stats, query_responses[index][index_s]) + + neuron_loss += _stats['loss' + ext] # add sequence loss to be backward() to neuron + + # === Add routing loss === + # MSE loss between predicted routing score and ideal target routing score. + # The Bayes risk approx. 1.69, i.e. the minimal loss achievable for next-token + # prediction on the full distribution 𝑃, a.k.a the "entropy of natural text" + # Hoffmann, Jordan, et al. "Training Compute-Optimal Large Language Models." arXiv:2203.15556 (2022). + routing_score_target = torch.exp(-torch.clamp(_stats['loss' + ext].detach() - 1.69, 0)) + _routing_loss = (routing_score[_uid] - routing_score_target) ** 2 # MSE loss + routing_loss += _routing_loss + _stats.update({'routing_score_target' + ext: routing_score_target, 'routing_loss' + ext: _routing_loss}) + + stats[_uid] = _stats + except Exception as e: + logger.warning(f'Synapse {index_s} error (shapley_base)\t| ' + f'UID {_uid} [{times[index][index_s]:.2f}s]: {e}') + stats[_uid] = _stats + unsuccessful += [(_uid, return_ops[index][index_s], times[index][index_s])] + else: + stats[_uid] = {'uid': _uid, + 'response_time' + ext: times[index][index_s], + 'routing_score': routing_score[_uid]} + unsuccessful += [(_uid, return_ops[index][index_s], times[index][index_s])] + + return neuron_loss + routing_loss, stats, unsuccessful + + +def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tensor = None, scaling_law_power: float = 0.5): + r""" + Calculates Shapley synergy for coalition size 2, measured performance above expected performance. + Measured in effective number of model parameters, just like base Shapley values. 
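A hedged sketch of the routing-score target computed in `shapley_base` above, where 1.69 nats is the cited entropy-of-natural-text floor; the gate score and loss values here are hypothetical:

```python
import torch

def routing_score_target(neuron_loss: torch.Tensor) -> torch.Tensor:
    # Ideal gate score: 1.0 when the neuron reaches the ~1.69 nat floor,
    # decaying exponentially as its measured loss rises above that floor.
    return torch.exp(-torch.clamp(neuron_loss - 1.69, 0))

predicted_score = torch.tensor(0.40)  # hypothetical sigmoid gate output for one UID
for loss in [1.7, 2.5, 4.0]:
    target = routing_score_target(torch.tensor(loss))
    mse = (predicted_score - target) ** 2  # per-neuron routing loss term, as in shapley_base
    print(f'loss={loss:.2f}  target={target.item():.3f}  routing_mse={mse.item():.4f}')
```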
+ Args: + stats (:obj:`Dict`, `required`): + Statistics per endpoint for this batch. + synergy (:obj:`Callable`, `required`) + Function to calculate measured loss. + ext (:obj:`str`, `optional`): + Extension to parameter string for stats key. + target (:obj:`torch.Tensor`, `optional`): + Target to measure loss against. + scaling_law_power (:obj:`float`, `optional`): + Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. + + Returns: + syn_loss_diff (:obj:`Dict`, `required`): + Dictionary table of pairwise synergies as loss reductions, with direct loss on diagonal. + """ + # === Shapley synergy approximation === + # Shapley values - second level - coalition size 2 + # Synergy = measured performance above expected performance + # Measured in effective number of model parameters, just like base Shapley values. + syn_loss_diff = {} # expected_loss - measured_loss (where > 0) + for _first, first in stats.items(): + if 'loss' + ext not in first: + continue + first_diff = syn_loss_diff.setdefault(_first, {}) + first_diff[_first] = first['loss' + ext] # diagonal keeps direct loss + + for _second, second in stats.items(): + if 'loss' + ext not in second or _second <= _first: + continue + second_diff = syn_loss_diff.setdefault(_second, {}) + + with torch.no_grad(): + expected_loss = torch.min(first['loss' + ext], second['loss' + ext]) # expecting min loss + + measured_loss = synergy(first, second, target, ext) + + loss_diff_share = torch.clamp(expected_loss - measured_loss, 0) / 2 # record direct loss diff + first['synergy_loss_diff' + ext] += loss_diff_share + second['synergy_loss_diff' + ext] += loss_diff_share + + # pairwise loss reduction of expected to measured loss due to synergy between first and second + first_diff[_second] = loss_diff_share + second_diff[_first] = loss_diff_share + + measured_params = scaling_law_loss_to_params(measured_loss) + expected_params = scaling_law_loss_to_params(expected_loss) + + # powered down number of params, e.g. 
dynamic range 3 → 6 nats for scaling_law_power=0.5 + pow_measured_params = torch.pow(measured_params, scaling_law_power) + pow_expected_params = torch.pow(expected_params, scaling_law_power) + + synergy_share = torch.clamp(pow_measured_params - pow_expected_params, 0) / 2 + first['synergy' + ext] += synergy_share # share synergy amongst coalition members + second['synergy' + ext] += synergy_share + + return syn_loss_diff + + +def synergy_table(stats, syn_loss_diff, sort_col, console_width): + r""" Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal) + """ + sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s], reverse=True, key=lambda _row: _row[1]) + uid_col = neuron_stats_columns[0] # [Column_name, key_name, format_string, rich_style] + columns = [uid_col] + [[f'{s[0]}', '', '{:.2f}', ''] for s in sort] + rows = [[uid_col[2].format(s[0])] + + [('[bright_cyan]{:.2f}[/bright_cyan]' if t == s else + '[magenta]{:.2f}[/magenta]' if syn_loss_diff[s[0]][t[0]] > 0 else + '[dim]{:.0f}[/dim]').format(syn_loss_diff[s[0]][t[0]]) for t in sort] for s in sort] + + # === Synergy table === + table = Table(width=console_width, box=None) + table.title = f'[white] Synergy [/white]' + table.caption = f'loss decrease' + + for col, _, _, stl in columns: # [Column_name, key_name, format_string, rich_style] + table.add_column(col, style=stl, justify='right') + for row in rows: + table.add_row(*row) + + if len(rows): + print(table) + print() + + +def stats_table(stats, sort_col, console_width, title, caption, mark_uids=None): + r""" Gathers data and constructs neuron statistics table and prints it + """ + # === Gather columns and rows === + if mark_uids is None: + mark_uids = list() + stats_keys = [set(k for k in stat) + for stat in stats.values() if sort_col in stat] # all available stats keys with sort_col + + if len(stats_keys) == 0: + return # nothing to print + + stats_keys = set.union(*stats_keys) + columns = [c[:] for c in neuron_stats_columns if c[1] in stats_keys] # available columns intersecting with stats_keys + rows = [[('', 0) if key not in stat + else (('* ' if key == 'uid' and mark_uids and uid in mark_uids else '') + txt.format(stat[key]), stat[key]) + for _, key, txt, _ in columns] + for uid, stat in stats.items() if sort_col in stat] # only keep rows with at least one non-empty cell + + if len(columns) == 0 or len(rows) == 0: + return # nothing to print + + # === Sort rows === + col_keys = [c[1] for c in columns] + if sort_col in col_keys: + sort_idx = col_keys.index(sort_col) # sort column with key of sort_col + columns[sort_idx][0] += '\u2193' # ↓ downwards arrow (sort) + rows = sorted(rows, reverse=True, key=lambda _row: _row[sort_idx][1]) # sort according to sortcol + + # === Instantiate stats table === + table = Table(width=console_width, box=None, row_styles=[Style(bgcolor='grey15'), ""]) + table.title = title + table.caption = caption + + for col, _, _, stl in columns: # [Column_name, key_name, format_string, rich_style] + table.add_column(col, style=stl, justify='right') + for row in rows: + table.add_row(*[txt for txt, val in row]) + + # === Print table === + print(table) + + +def synapse_table(name, stats, sort_col, console_width, start_time): + r""" Prints the evaluation of the neuron responses to the validator request + """ + + stats_table(stats, sort_col, console_width, + f'[white] \[{name}] responses [/white] | Validator forward', # title + f'[bold]{len([s for s in stats.values() if 
len(s)])}[/bold]/{len(stats)} (respond/topk) | ' + f'[bold]Synapse[/bold] | [white]\[{time.time() - start_time:.3g}s][/white]' # caption + ) + + +def unsuccess(_name, _unsuccessful): + r""" Prints the return codes and response times of unsuccessful responses + """ + # === Unsuccessful responses === + unsuccess_txt = f'{_name} \t| Unsuccessful UID[return_op time]: ' + for _uid, _return_op, _time in _unsuccessful: + unsuccess_txt += f'{_uid}[{_return_op} {_time:.2f}] ' + logger.info(unsuccess_txt) diff --git a/bittensor/_neuron/text/multitron_server/.nucleus_impl.py.swp b/bittensor/_neuron/text/multitron_server/.nucleus_impl.py.swp deleted file mode 100644 index 4d985593ee..0000000000 Binary files a/bittensor/_neuron/text/multitron_server/.nucleus_impl.py.swp and /dev/null differ diff --git a/bittensor/_neuron/text/multitron_server/__init__.py b/bittensor/_neuron/text/multitron_server/__init__.py deleted file mode 100644 index 27099457ed..0000000000 --- a/bittensor/_neuron/text/multitron_server/__init__.py +++ /dev/null @@ -1,69 +0,0 @@ -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -""" Advanced server neurons - -Example: - $ import neurons - $ neurons.text.multitron_server().run() - -""" - -import bittensor -import os - -from .nucleus_impl import server -from .ddp_run import Server - -class neuron: - - def __new__( - self, - config: 'bittensor.config' = None - ): - if config == None: config = neuron.config() - config = config; - self.check_config( config ) - bittensor.logging ( - config = config, - logging_dir = config.neuron.full_path, - ) - - self.model = server(config=config) - self.config = config - return Server(self.config, self.model) - - @staticmethod - def config (): - return server.config() - - @staticmethod - def check_config( config: 'bittensor.Config' ): - r""" Checks/validates the config namespace object. 
- """ - bittensor.logging.check_config( config ) - bittensor.wallet.check_config( config ) - bittensor.subtensor.check_config( config ) - bittensor.metagraph.check_config( config ) - bittensor.dataset.check_config( config ) - bittensor.axon.check_config( config ) - bittensor.wandb.check_config( config ) - full_path = os.path.expanduser('{}/{}/{}/{}'.format( config.logging.logging_dir, config.wallet.get('name', bittensor.defaults.wallet.name), config.wallet.get('hotkey', bittensor.defaults.wallet.hotkey), config.neuron.name )) - config.neuron.full_path = os.path.expanduser(full_path) - assert config.neuron.device != 'cpu', "multitron_server must be ran on cuda device. Please consider mining with template_server or advanced_server instead." - if not os.path.exists(config.neuron.full_path): - os.makedirs(config.neuron.full_path) diff --git a/bittensor/_neuron/text/multitron_server/ddp_run.py b/bittensor/_neuron/text/multitron_server/ddp_run.py deleted file mode 100644 index 26e1bf5ba4..0000000000 --- a/bittensor/_neuron/text/multitron_server/ddp_run.py +++ /dev/null @@ -1,407 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" Advanced server neuron. - -Example: - $ python miners/text/multitron_server/main.py - -""" -from re import I - -import bittensor -import torch -import pandas -import datetime -import traceback -import sys -import os - -from loguru import logger; logger = logger.opt(colors=True) -from torch.nn.utils import clip_grad_norm_ -from datetime import datetime,timedelta -from threading import Lock -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.distributed as dist -import torch.multiprocessing as mp -import time -from multiprocessing import Process, Manager, Event -import threading - -os.environ['TOKENIZERS_PARALLELISM'] = 'false' - -torch.autograd.set_detect_anomaly(True) - -class DDPPipe(): - def __init__( self, config: 'bittensor.config', gp_server, wallet: 'bittensor.wallet', forward_q, events, outputs): - r""" Initializes the neuron with the passed config. 
- """ - torch.autograd.set_detect_anomaly(True) - self.config = config - self.config.to_defaults() - self.gp_server = gp_server# .to(gp_server.device) - self.wallet = wallet - self.world_size = config.neuron.world_size - self.forward_q = forward_q - self.events = events - self.outputs = outputs - - - - def init_process(self, rank): - r""" For each process, anchor them to the process group - so that they know how to communication with each other. - - Args: - rank (int): - rank (id) of the process. - """ - os.environ['MASTER_ADDR'] = self.config.neuron.address - os.environ['MASTER_PORT'] = self.config.neuron.port - if 'cuda' in self.config.neuron.device: - backend = 'nccl' - else: - backend = 'gloo' - - dist.init_process_group( - backend, - rank=rank, - world_size=self.world_size, - ) - - def init_bit(self, rank = 0): - r""" Init bittensor modules after spawning process. - - Args: - rank (int): - rank (id) of the process. - """ - self.device = torch.device( device = f'cuda:{rank}' ) - self.gp_server.device = self.device - self.gp_server = self.gp_server.to(self.device) - self.subtensor = bittensor.subtensor ( config = self.config ) - self.metagraph = bittensor.metagraph ( config = self.config, subtensor = self.subtensor ) - self.metagraph.sync() - self.optimizer = torch.optim.SGD( - [ {'params': self.gp_server.parameters() } ], - lr = self.config.neuron.learning_rate, - momentum = self.config.neuron.momentum, - ) - - if rank == 0 : - logger.success( self.subtensor ) - self.subtensor.register( self.wallet ) - - bittensor.tokenizer() - - def cleanup(self): - r""" Kill the process. - """ - dist.destroy_process_group() - - def run_parallel( self, ready = None): - r""" Spawn multiple processes. - """ - self.process_ctx = mp.spawn(self.run, - args=(self.world_size, ready), - nprocs=self.world_size, - join = True - ) - - def run(self, rank = 0, world_size = 0, ready= None): - self.init_bit(rank) - if self.config.neuron.restart == False: - self.gp_server.load(self.config.neuron.full_path) - - self.gp_server = self.gp_server.to(self.device) - - nn = self.subtensor.neuron_for_pubkey(self.wallet.hotkey.ss58_address) - uid = nn.uid - - # --- last sync block - last_sync_block = self.subtensor.get_current_block() - last_set_block = last_sync_block - last_log_block = last_sync_block - last_log_time = time.time() - # -- Main Training loop -- - if ready != None and rank == 0 : - ready.set() - - try: - torch.cuda.empty_cache() - while True: - try: - request_id, inputs_x = self.forward_q.get(timeout = self.config.neuron.console_log_time) - if inputs_x != None: - inputs_x = inputs_x.to(self.device) - output = self.gp_server.encode_forward(inputs_x) - output_clone = output.detach().clone().to(device = 'cpu') - self.outputs[request_id] = output_clone - self.events[request_id].set() - del output - del output_clone - del inputs_x - torch.cuda.empty_cache() - except Exception as e: - logger.warning(e) - if 'out of memory' in str(e): - for p in self.gp_server.pre_model.parameters(): - if p.grad is not None: - del p.grad - if inputs_x != None: - del inputs_x - torch.cuda.empty_cache() - bittensor.logging.success('cleaned memory', sufix = f'rank: {rank}, {e}') - - # log if a certain time period had passed - # checking with time instead of block here to avoid frequent syncing from subtensor in a while loop - if time.time() - last_log_time > self.config.neuron.console_log_time: - last_log_time = time.time() - - # ---- syncing metagraph for all rank - current_block = self.subtensor.get_current_block() - if current_block - 
last_sync_block > self.config.neuron.metagraph_sync: - self.metagraph.sync() - last_sync_block = current_block - - # ---- console logging - if rank == 0: - # ---- data - data = { - 'block': current_block, - 'stake': nn.stake, - 'rank': nn.rank, - 'incentive': nn.incentive, - 'trust': nn.trust, - 'consensus': nn.consensus, - 'incentive': nn.incentive, - 'dividends': nn.dividends, - 'emission': nn.emission, - } - - # ---- console logging - bittensor.__console__.print('[green]Current Status:[/green]', data) - - except Exception as e: - # --- Unknown error ---- - logger.exception('Unknown exception: {} with traceback {}', e, traceback.format_exc()) - -class Server: - def __init__( self, config: 'bittensor.config', gp_server): - r""" Initializes the neuron with the passed config. - """ - self.config = config - self.wallet = bittensor.wallet( config = config ).create().register() - self.subtensor = bittensor.subtensor ( config = self.config ) - logger.success( self.subtensor ) - - ctx = mp.get_context('spawn') - self.forward_q = ctx.Queue() - - self.manager = Manager() - self.events = self.manager.dict() - self.outputs = self.manager.dict() - - self.axon = bittensor.axon ( - config = self.config, - wallet = self.wallet, - forward_text = self.forward_text, - backward_text = lambda x : None, - blacklist = self.blacklist, - priority = self.priority - ) - - self.axon_pipe = DDPPipe(config, gp_server, self.wallet, self.forward_q, self.events, self.outputs ) - self.timecheck = {} - self.subtensor = bittensor.subtensor ( config = self.config ) - self.metagraph = bittensor.metagraph ( config = self.config, subtensor = self.subtensor ) - self.futures = {} - self.last_sync_block = None - self.last_set_weight_block = None - - # Instantiate the model we are going to serve on the network. - # Creating a threading lock for updates to the model - # Define our forward function. - def forward_text ( self, inputs_x): - r""" Forward function that is called when the axon recieves a forward request from other peers - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - result = None - request_id = id(inputs_x) - self.forward_q.put( (request_id, inputs_x) ) - self.events[request_id] = self.manager.Event() - - if self.events[request_id].wait(12): - result = self.outputs[request_id] - - del self.events[request_id] - del self.outputs[request_id] - - return result - - def priority(self, pubkey:str, request_type:bittensor.proto.RequestType, inputs_x) -> float: - r"""Calculates the priority on requests based on stake and size of input - - Args: - pubkey ( str, `required`): - The public key of the caller. - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - uid = self.metagraph.hotkeys.index(pubkey) - priority = self.metagraph.S[uid].item()/ sys.getsizeof(inputs_x) - - return priority - - def blacklist(self, pubkey:str, request_type:bittensor.proto.RequestType) -> bool: - r"""Axon security blacklisting, used to blacklist message from low stake members - Args: - pubkey ( str, `required`): - The public key of the caller. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). 
- """ - - # Check for stake - def stake_check() -> bool: - # If we allow non-registered requests return False = not blacklisted. - is_registered = pubkey in self.metagraph.hotkeys - if not is_registered: - if self.config.neuron.blacklist_allow_non_registered: - return False - else: - return True - - # Check stake. - uid = self.metagraph.hotkeys.index(pubkey) - if request_type == bittensor.proto.RequestType.FORWARD: - if self.metagraph.S[uid].item() < self.config.neuron.blacklist.stake.forward: - return True - else: - return False - - elif request_type == bittensor.proto.RequestType.BACKWARD: - if self.metagraph.S[uid].item() < self.config.neuron.blacklist.stake.backward: - return True - else: - return False - - # Check for time - def time_check(): - current_time = datetime.now() - if pubkey in self.timecheck.keys(): - prev_time = self.timecheck[pubkey] - if current_time - prev_time >= timedelta(seconds=self.config.neuron.blacklist.time): - self.timecheck[pubkey] = current_time - return False - else: - self.timecheck[pubkey] = current_time - return True - else: - self.timecheck[pubkey] = current_time - return False - - # Black list or not - if stake_check() or time_check(): - return True - else: - return False - - def run(self): - def serve_when_ready(serve_kwargs, pipe_ready): - r""" Start to serve Axon when DDP have started - Args: - serve_kwargs(map): - Arguments for serving axon. - pipe_ready(manager.Event): - The Event when the DDP is ready - """ - if pipe_ready.wait(): - self.axon.start().serve(**serve_kwargs) - - return - - def sync(keyboard_interupt): - r""" Sync with metagraph and set weight to chain. - Args: - keyboard_interupt(manager.Event): - Whether we have tried to stop the program with keyboard_interupt. - """ - while not keyboard_interupt.is_set(): - current_block = self.subtensor.get_current_block() - if (self.last_sync_block == None) or (current_block - self.last_sync_block > self.config.neuron.metagraph_sync): - self.last_sync_block = current_block - self.metagraph.sync() - bittensor.logging.success('Metagraph synced', sufix = f'{self.last_sync_block} --> {current_block}') - - if (self.last_set_weight_block == None) or (current_block - self.last_set_weight_block > self.config.neuron.blocks_per_set_weights): - self.last_set_weight_block = current_block - chain_weights = torch.zeros(self.metagraph.n) - chain_weights [ self.uid ] = 1 - did_set = self.subtensor.set_weights( - uids=self.metagraph.uids, - weights = chain_weights, - wait_for_inclusion = False, - wallet = self.wallet, - ) - - if did_set: - logger.success('Successfully set weights on the chain') - else: - logger.error('Failed to set weights on chain. (Timeout)') - - time.sleep(self.config.neuron.check_sync_time) - - try: - self.wallet.create() - self.subtensor.register( self.wallet ) - self.metagraph.sync() - neuron = self.subtensor.neuron_for_pubkey(self.wallet.hotkey.ss58_address) - self.uid = neuron.uid - - pipe_ready = self.manager.Event() - keyboard_interupt = self.manager.Event() - axon_start_thread = threading.Thread( target = serve_when_ready, args = ({'subtensor': self.subtensor}, pipe_ready) ) - sync_thread = threading.Thread( target = sync, args = (keyboard_interupt, )) - axon_start_thread.start() - sync_thread.start() - self.axon_pipe.run_parallel(ready = pipe_ready) - - # Just to keep this run function alive. 
- while True: - time.sleep(20) - - except KeyboardInterrupt: - keyboard_interupt.set() - logger.success('Keyboard Interuped') - self.axon.stop() - axon_start_thread.join() - sync_thread.join() - except Exception as e: - # --- Unknown error ---- - logger.exception('Unknown exception: {} with traceback {}', e, traceback.format_exc()) - - - diff --git a/bittensor/_neuron/text/multitron_server/main.py b/bittensor/_neuron/text/multitron_server/main.py deleted file mode 100644 index 4a5d1d823b..0000000000 --- a/bittensor/_neuron/text/multitron_server/main.py +++ /dev/null @@ -1,3 +0,0 @@ -import bittensor -if __name__ == "__main__": - template = bittensor.neurons.multitron_server.neuron().run() \ No newline at end of file diff --git a/bittensor/_neuron/text/multitron_server/nucleus_impl.py b/bittensor/_neuron/text/multitron_server/nucleus_impl.py deleted file mode 100644 index 6703ed4ca6..0000000000 --- a/bittensor/_neuron/text/multitron_server/nucleus_impl.py +++ /dev/null @@ -1,247 +0,0 @@ -import argparse -import bittensor -import torch -import torch.nn.functional as F - -from transformers import AutoModel,AutoTokenizer,AutoConfig -from torch.nn.utils.rnn import pad_sequence -# from loguru import logger; logger = logger.opt(colors=True) - -class server(torch.nn.Module): - def __init__(self, - config: 'bittensor.config' = None, - pretrained: bool = None, - model_name: str = None, - padding: bool =None, - interpolate: bool =None, - inter_degree: str = None, - model = None, - tokenizer = None, - mapping_function = None, - token_remap = None, - checking= None): - r"""" Creates a server that serves up a pretrained miner on the bittensor network - Args: - config (:obj:`bittensor.Config`, `required`): - bittensor.server.config() - pretrained (:obj:bool , `optional`): - if the model should pretrained or not - model_name (:obj:string , `optional`): - name of the pretrained model from huggingface to use - padding (:obj:bool, `optional`): - If the server should pad out to match the hidden units that the bittensor network is using - If set to False, it will instead create a mapping layer to do the same thing. - interpolate (:obj:bool, `optional`): - If the server should interpolate between sequence length differences. 
- If set to false, there should be a mapping function that takes care of the differnces - inter_degree (:obj:str, `optional`): - The Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area) - model (:obj:torch.module, `optional`): - Overrides the huggingface pretrained model with your own pretrained model - tokenizer (:obj:huggingface.tokenizer, `optional`): - Overrides the huggingface tokenizer with your tokenizer - mapping_function (:obj:Callable, `optional`): - Custom mapping function that maps between sequence length differences between tokenizers - token_remap (:obj:Callable, `optional`): - Custom function that maps between tokenizers (defaults to self.remapping_token) - """ - super(server, self).__init__() - if config == None: config = server.config() - self.config = config;print(config) - - #setting up pretrained model - self.model_name = model_name if model_name != None else config.neuron.model_name - self.pretrained = pretrained if pretrained != None else config.neuron.pretrained - if self.pretrained == True: - self.pre_model = model if model != None else AutoModel.from_pretrained(self.model_name) - self.tokenizer = tokenizer if tokenizer != None else AutoTokenizer.from_pretrained(self.model_name) - elif self.pretrained == False: - model_config = AutoConfig.from_pretrained(self.model_name) - model_config.vocab_size= bittensor.__vocab_size__ - self.pre_model = model if model != None else AutoModel.from_config(model_config) - self.tokenizer = bittensor.tokenizer() - - self.pre_model.eval() - - #parameters of the models - self.final_dim = bittensor.__network_dim__ - self.pre_dimension = self.pre_model.config.hidden_size - self.padding = padding if padding != None else config.neuron.padding - self.interpolate = interpolate if interpolate != None else config.neuron.interpolate - self.inter_degree = inter_degree if inter_degree != None else config.neuron.inter_degree - self.checking = checking if checking != None else config.neuron.checking - self.mapping_function= mapping_function - self.token_remap = token_remap if token_remap != None else self.remapping_token - - if self.padding == False: - self.mapping = torch.nn.Linear( self.pre_dimension, self.final_dim) - - self.decoder = torch.nn.Linear( self.final_dim, bittensor.__vocab_size__ , bias=False) - self.loss_fct = torch.nn.CrossEntropyLoss() - - self.outputs_cache = None - self.gradients_cache = None - - #checking if the parameters of the server makes sense - if self.checking and pretrained == True: - self.check() - - # -- keeps track of gradients applied - self.backward_gradients = 0 - - def forward(self, inputs,tokenizer=None): - """ - Forward pass through the whole server model. Returns the loss and decoded predictions. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer (:obj:'huggingface.tokenizer', optional): - The tokenizer which was used to tokenize the inputs - Returns: - loss (:obj:`torch.FloatTensor`): - MLM loss from the inputs - decoded_targets (:obj:`torch.FloatTensor`): - Decoded predictions of the next token in the sentence. 
- - """ - decoded_targets = self.decoder(self.encode_forward(inputs,tokenizer)) - - shift_logits = decoded_targets[..., :-1, :].contiguous() - shift_labels = inputs[..., 1:].contiguous() - loss = self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - - return loss, decoded_targets - - def encode_forward(self,inputs,tokenizer=None): - r""" Forward pass through the pretrained model and possible mappings between hidden units. - The response tensor should be the hidden units computed using the local context and with shape: [batch_size, sequence_len, __network_dim__]. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer ( huggingface.tokenizer, `optional`): - The tokenizer which was used to tokenize the inputs - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - sen_len = inputs.size() - - inputs = self.token_remap(inputs,tokenizer).to(self.device) - - with torch.no_grad(): - pre_hidden = self.pre_model(inputs).last_hidden_state - - if self.interpolate: - down= F.interpolate(pre_hidden.unsqueeze(1),size=[sen_len[1],pre_hidden.size()[2]],mode=self.inter_degree).squeeze(1) - elif self.mapping_function: - down = self.mapping_function(pre_hidden) - else: - raise Exception('interpolation off but no mapping function found. Please attach a mapping function') - - if self.padding: - padding_l = (self.final_dim-self.pre_dimension)//2 - padding_r = (self.final_dim-self.pre_dimension) - padding_l - encoded_hidden = F.pad(down, (padding_l, padding_r), "constant", 0) - else: - encoded_hidden = self.mapping(down) - - return encoded_hidden - - def remapping_token(self,input, old_tokenizer=None): - r""" Default remapping of tokenizers; decodes the message and then remaps the message using a new tokenizer - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. 
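The forward pass above computes a next-token loss by shifting logits and labels one position against each other: the prediction at position t is scored against the token at position t+1. A self-contained sketch of that shift with random tensors (the vocabulary size here is a placeholder; the server uses `bittensor.__vocab_size__`):

```python
import torch

vocab_size = 1000                       # placeholder vocabulary size
batch, seq_len = 2, 8
decoded_targets = torch.randn(batch, seq_len, vocab_size)   # decoder output logits
inputs = torch.randint(0, vocab_size, (batch, seq_len))     # input token ids

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = decoded_targets[..., :-1, :].contiguous()
shift_labels = inputs[..., 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())
```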
- old_tokenizer ( huggingface.tokenizer, `required`): - The tokenizer which was used to tokenize the input (defaults to bittensor tokenizer if none is given) - """ - if old_tokenizer == None: - old_tokenizer = bittensor.tokenizer() - new_data = [] - for i in range(input.shape[0]): - decoded = old_tokenizer.decode(input[i]) - hugging = self.tokenizer(decoded) - new_data += [torch.LongTensor(hugging.input_ids)] - new_data = pad_sequence(new_data,batch_first=True) - return new_data - - def check(self): - r"""Checks the server settings - """ - assert self.tokenizer.name_or_path == self.pre_model.name_or_path, 'incorrect model ({}) and tokenizer ({})'.format(self.pre_model.name_or_path,self.tokenizer.name_or_path) - if self.interpolate == False: - assert self.mapping_function != None, 'Incorrect Settings; needs atleast one mapping function for sequence length changes' - - def save(self, path): - try: - state_dict = { - 'model': self.pretrained, - 'pretrained_model': self.pre_model.state_dict(), - 'decoder': self.decoder.state_dict() - } - if self.padding == False: - state_dict['mapping'] = self.mapping.state_dict() - torch.save( state_dict, "{}/model.torch".format( path) ) - # bittensor.logging.success(prefix='Saved model', sufix='{}/model.torch'.format( path ) ) - print(f'Saved model: {path}/model.torch' ) - except Exception as e: - print('Failed to save model with error:{}', e) - - def load(self, path): - try: - state_dict= torch.load("{}/model.torch".format( path )) - if self.pretrained == state_dict['model']: - self.pre_model.load_state_dict(state_dict['pretrained_model'], strict=False) - self.decoder.load_state_dict(state_dict['decoder']) - if self.padding == False: - self.mapping.load_state_dict(state_dict['mapping']) - - # bittensor.logging.success( prefix = 'Reloaded model', sufix = '{}/model.torch'.format( path )) - print( f'Reloaded model {path}/model.torch') - - - except Exception as e: - print('No saved model found with error: {}', e) - - @staticmethod - def config (): - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='If set, defaults are overridden by passed file.') - parser.add_argument('--neuron.learning_rate', type=float, help='Training initial learning rate.', default=0.01) - parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8) - parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0) - parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) - parser.add_argument('--neuron.model_name', type=str, help='pretrained model from hugging face',default='gpt2') - parser.add_argument('--neuron.pretrained', action='store_false', help='if the model should be pretrained',default=True) - parser.add_argument('--neuron.padding', action='store_false', help='To pad out final dimensions',default=True) - parser.add_argument('--neuron.interpolate', action='store_false', help='To interpolate between sentence length',default=True) - parser.add_argument('--neuron.inter_degree', type=str, help='Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area)', default='nearest') - parser.add_argument('--neuron.name', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ', default='multitron_server') - parser.add_argument('--neuron.checking', action='store_false', 
help='To check if server settings are correct',default=True) - parser.add_argument('--neuron.restart', action='store_true', help='If set, train the neuron from the beginning.', default=False) - parser.add_argument('--neuron.blacklist.stake.forward', type=float, help='Amount of stake (tao) in order not to get blacklisted for forward requests', default=10) - parser.add_argument('--neuron.blacklist.stake.backward', type=float, help='Amount of stake (tao) in order not to get blacklisted for backward requests', default=100) - parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, black lists non-registered peers''', default=True) - parser.add_argument('--neuron.metagraph_sync', type=float, help='how often to sync the metagraph', default=100000) - parser.add_argument('--neuron.blocks_per_set_weights', type=float, help='how often to sync set weights', default=100) - parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch', default=2) - parser.add_argument('--neuron.blacklist.time', type=int, help='how often a peer can query you (seconds) ', default=2) - parser.add_argument('--neuron.world_size', type=int, help='The number of processes for ddp.', default=1) - parser.add_argument('--neuron.address', type=str, help='The address for multiprocess communication', default='localhost') - parser.add_argument('--neuron.port', type=str, help='The port for multiprocess communication', default='8865') - parser.add_argument('--neuron.console_log_time', type=int, help='How often to log in console, in second.', default=30) - parser.add_argument('--neuron.check_sync_time', type=int, help='How often to check metagraph sync, in second.', default=180 ) - - bittensor.wallet.add_args( parser ) - bittensor.axon.add_args( parser ) - bittensor.subtensor.add_args( parser ) - bittensor.logging.add_args( parser ) - bittensor.wandb.add_args(parser) - bittensor.prioritythreadpool.add_args( parser ) - bittensor.dataset.add_args( parser ) - bittensor.metagraph.add_args( parser ) - return bittensor.config( parser ) - diff --git a/bittensor/_neuron/text/neuron_utilities.py b/bittensor/_neuron/text/neuron_utilities.py index fd769c4a94..18eca8b906 100644 --- a/bittensor/_neuron/text/neuron_utilities.py +++ b/bittensor/_neuron/text/neuron_utilities.py @@ -2,6 +2,7 @@ import bittensor import threading import time +import math import torch import torch.nn as nn import torch.nn.functional as F @@ -9,6 +10,17 @@ import queue from threading import Thread from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy + + +def calc_loss_fct(loss_fct, logits, labels): + r""" Calculates self.loss_fct with logits and labels that are expected to be aligned already. 
+ """ + _logits = logits.contiguous() + _labels = labels.contiguous() + loss = loss_fct(_logits.view(-1, _logits.size(-1)), _labels.view(-1)) + return loss + def update_metagraph_peerweight(metagraph, nucleus, device): r""" @@ -95,7 +107,7 @@ def fisher_score_approximation(loss, peer_weights, ): validator_scores = second_order + first_order return validator_scores -def joining_context(return_ops, topk_weights, responses): +def joining_context(return_ops, topk_weights, responses, synapses): """ Joins response embbedings depending on the return codes Args: @@ -113,14 +125,22 @@ def joining_context(return_ops, topk_weights, responses): The uids used to create output """ - joining_uids= torch.where( return_ops == bittensor.proto.ReturnCode.Success )[0] - joining_weights = F.softmax( topk_weights[(return_ops == bittensor.proto.ReturnCode.Success)], dim = 0 ) - output = torch.zeros( (responses[0].shape[0], responses[0].shape[1], bittensor.__network_dim__)) - for index, joining_weight in enumerate( joining_weights ): - output += responses[joining_uids[index]]* joining_weight - return output, joining_uids - -def partial_contexts(return_ops, topk_uids, topk_weights, responses): + # TODO : Test for different modalities (currently works for casuallm) + codes = torch.stack(return_ops) + outputs = [] + for index_s, synapse in enumerate(synapses): + joining_uids= torch.where( codes[:,index_s] == bittensor.proto.ReturnCode.Success )[0] + joining_weights = F.softmax( topk_weights[(codes[:,index_s] == bittensor.proto.ReturnCode.Success)], dim = 0 ) + if len(joining_uids) != 0: + output = torch.zeros_like(responses[joining_uids[0]][index_s] ) + for index, joining_weight in enumerate( joining_weights ): + output += responses[joining_uids[index]][index_s]* joining_weight + outputs.append(output) + else: + outputs.append([]) + return outputs, joining_uids + +def partial_contexts(return_ops, topk_uids, topk_weights, responses, synapses): """ Creates the partial contexts which are used to calculate the shapley scores @@ -139,16 +159,15 @@ def partial_contexts(return_ops, topk_uids, topk_weights, responses): A dict containing all of joinned contexts with a single peer masked out """ + # TODO : Test for different modalities (currently works for casuallm) partial_context = {} with torch.no_grad(): for i, uid in enumerate(topk_uids): - partial_return_ops = return_ops.clone() + partial_return_ops = deepcopy(return_ops) # --- Only mask peers that successfully - if partial_return_ops[i] != bittensor.proto.ReturnCode.Success: - pass - else: - partial_return_ops[i] = bittensor.proto.ReturnCode.NoReturn - partial_context[uid.item()], _ = joining_context(partial_return_ops, topk_weights, responses) + partial_return_ops[i][ partial_return_ops[i] == bittensor.proto.ReturnCode.Success ] = bittensor.proto.ReturnCode.NoReturn + + partial_context[uid.item()], _ = joining_context(partial_return_ops, topk_weights, responses, synapses) return partial_context class ThreadQueue(threading.Thread): @@ -213,4 +232,39 @@ def is_empty(self): return self.queue.empty() def get(self): - return self.queue.get() \ No newline at end of file + return self.queue.get() + + +class PositionalEncoding(nn.Module): + r""" Positional Encoder which adds information based on the relative position of each token + + """ + + def __init__(self, d_model: int, dropout: float, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * 
(-math.log(10000.0) / d_model)) + + # === Create position matrix === + # Creates a positional matrix with alternating frequencies + # pe: (torch.FloatTensor) positional encoding matrix + # pe.shape: [1, max_len, network_dim] + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x: torch.tensor) -> torch.tensor: + """ + Args: + x: Tensor, shape [batch_size, seq_len, embedding_dim] + """ + # === Positional Encoding === + # Inject some information of the relative position of the token in the sequence. + # Finally, Dropout is applied to tokens + # x: (torch.FloatTensor) input sequence tokens with position information injected + # x.shape: [batch_size, seq_len, network_dim] + x = x + self.pe[0, :x.size(1)] + return self.dropout(x) diff --git a/bittensor/_neuron/text/template_miner/__init__.py b/bittensor/_neuron/text/template_miner/__init__.py deleted file mode 100644 index 5fe109f846..0000000000 --- a/bittensor/_neuron/text/template_miner/__init__.py +++ /dev/null @@ -1,709 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" The bittensor core miner - -Example: - $ python miners/core_miner.py --logging.debug - -""" -import sys -import argparse -import time -import bittensor -import torch -import os -import wandb -import math -import pandas -import traceback -from rich import print -from rich.console import Console -from rich.traceback import install -from datetime import datetime,timedelta -from ..neuron_utilities import joining_context, partial_contexts -import torch.nn as nn - -from torch.nn.utils import clip_grad_norm_ -import torch.nn.functional as F -from torch.nn import TransformerEncoder, TransformerEncoderLayer -from loguru import logger -logger = logger.opt( colors=True ) -console = Console() -install(show_locals=True) - -class neuron: - """ Neuron class which drives the training of the validator. 
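The `PositionalEncoding` module added to neuron_utilities.py builds a fixed sinusoidal table once and broadcasts it over the batch at forward time. A tiny standalone sketch of the same construction, with deliberately small `d_model` and `max_len` so the table can be inspected directly:

```python
import math
import torch

# Sketch of the sinusoidal table built by the PositionalEncoding module above.
d_model, max_len = 8, 5        # tiny sizes so the table is easy to inspect
position = torch.arange(max_len).unsqueeze(1)                                   # [max_len, 1]
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)   # even dims: sine at decreasing frequencies
pe[0, :, 1::2] = torch.cos(position * div_term)   # odd dims: cosine at the same frequencies

x = torch.randn(2, 3, d_model)                    # [batch, seq_len, d_model]
x = x + pe[0, :x.size(1)]                         # broadcast add over the batch dimension
print(x.shape)                                    # torch.Size([2, 3, 8])
```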
- """ - def __init__( self, config: 'bittensor.Config' = None ): - - # === Set up Config === - if config == None: config = neuron.config() - self.config = config - neuron.check_config( self.config ) - self.config.to_defaults() - if self.config.neuron._mock == True: - self.config.subtensor._mock = True - self.config.wallet._mock = True - self.config.dataset._mock = True - self.config.dendrite._mock = True - self.config.metagraph._mock = True - self.config.subtensor._mock = True - print ( self.config ) - - # === Create Bittensor objects === - bittensor.logging( config = self.config, logging_dir = self.config.neuron.full_path ) - self.wallet = bittensor.wallet ( config = self.config ) - self.subtensor = bittensor.subtensor ( config = self.config ) - self.metagraph = bittensor.metagraph ( config = config, subtensor = self.subtensor ) - self.dendrite = bittensor.dendrite ( config = self.config, wallet = self.wallet ) - self.device = torch.device ( device = self.config.neuron.device ) - self.axon = bittensor.axon ( - config = self.config, - wallet = self.wallet, - priority = self.priority, - forward_text = self.forward_text, - blacklist = self.blacklist, - ) - self.dataset = bittensor.dataset ( config = self.config, batch_size = self.subtensor.validator_batch_size, block_size = self.subtensor.validator_sequence_length ) - self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device ) - self.optimizer = torch.optim.SGD ( self.nucleus.parameters(), lr = self.config.neuron.learning_rate, momentum = self.config.neuron.momentum ) - self.timecheck = {} - - self.moving_avg_scores = None - - @classmethod - def add_args( cls, parser ): - parser.add_argument('--neuron.name', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ', default='core_validator') - parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) - parser.add_argument('--neuron.learning_rate', type=float, help='Training initial learning rate.', default=0.1 ) - parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8 ) - parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch, -1 value means we use the chain value.', default = -1 ) - parser.add_argument('--neuron.epochs_until_reset', type=int, help='Number of epochs before weights are reset.', default = -1 ) - parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0 ) - parser.add_argument('--neuron.restart_on_failure', action='store_true', help='''Restart neuron on unknown error.''', default=True ) - parser.add_argument('--neuron._mock', action='store_true', help='To turn on neuron mocking for testing purposes.', default=False ) - parser.add_argument('--neuron.blacklist.stake.forward', type=float, help='Amount of stake (tao) in order not to get blacklisted for forward requests', default=10) - parser.add_argument('--neuron.blacklist.stake.backward', type=float, help='Amount of stake (tao) in order not to get blacklisted for backward requests', default=100) - parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, black lists non-registered peers''', default=True) - parser.add_argument('--neuron.use_upnpc', action='store_true', help='''neuron attempts to port forward axon using upnpc.''', 
default=False) - parser.add_argument('--neuron.wait_for_finalization', action='store_true', help='''when setting weights the miner waits for trnasaction finalization.''', default=False) - - @classmethod - def config ( cls ): - parser = argparse.ArgumentParser() - cls.add_args( parser ) - nucleus.add_args( parser ) - bittensor.wallet.add_args( parser ) - bittensor.dendrite.add_args( parser ) - bittensor.subtensor.add_args( parser ) - bittensor.metagraph.add_args( parser ) - bittensor.logging.add_args( parser ) - bittensor.dataset.add_args( parser ) - bittensor.wandb.add_args( parser ) - bittensor.axon.add_args( parser ) - return bittensor.config( parser ) - - @classmethod - def check_config( cls, config: 'bittensor.Config' ): - r""" Checks/validates the config namespace object. - """ - nucleus.check_config( config ) - bittensor.logging.check_config( config ) - bittensor.wallet.check_config( config ) - bittensor.subtensor.check_config( config ) - bittensor.metagraph.check_config( config ) - bittensor.dataset.check_config( config ) - bittensor.dendrite.check_config( config ) - bittensor.wandb.check_config( config ) - bittensor.axon.check_config( config ) - full_path = os.path.expanduser('{}/{}/{}/{}'.format( config.logging.logging_dir, config.wallet.name, config.wallet.hotkey, config.neuron.name )) - config.neuron.full_path = os.path.expanduser(full_path) - config.using_wandb = config.wandb.api_key != 'default' - if not os.path.exists(config.neuron.full_path): - os.makedirs(config.neuron.full_path) - - # ---- Axon Forward call ---- - def forward_text ( self, inputs_x: torch.FloatTensor) -> torch.FloatTensor: - r""" Subscribed to an axon servicing endpoint: processes forward messages from the wire. - The arguments reflect an RPC request from another miner in the network, the response tensor - should be the hidden units computed using the local context and with shape: [batch_size, sequence_len, __network_dim__]. - - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - outputs_y = self.nucleus ( - inputs_x.to( self.device ), - self.metagraph, - self.dendrite, - is_inference = True - ).to( self.device ) - return outputs_y - - - def priority(self, pubkey:str, request_type:bittensor.proto.RequestType, inputs_x: torch.FloatTensor) -> float: - r"""Return the request priority based on stake and size of input. - Used by the Axon to order requests. - Args: - pubkey ( str, `required`): - The public ss58 address of the caller. - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - try: - # Priority = stake / request_size - priority = self.metagraph.S[ self.metagraph.hotkeys.index(pubkey) ] / sys.getsizeof(inputs_x) - except: - return 0 - - - return priority - - def blacklist(self, pubkey:str, request_type:bittensor.proto.RequestType) -> bool: - r"""Axon security blacklisting, used to blacklist message from low stake members - Args: - pubkey ( str, `required`): - The public key of the caller. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - - # Check for stake - def stake_check() -> bool: - # If we allow non-registered requests return False = not blacklisted. 
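The priority callback above orders axon requests by stake per byte: heavily staked callers with small requests are served first, and callers not found in the metagraph fall to priority zero. A toy sketch with hard-coded stand-ins for the metagraph's hotkeys and stake tensor:

```python
import sys
import torch

# Toy stand-ins for metagraph state; in the neuron these come from self.metagraph.
hotkeys = ['caller_a', 'caller_b']
stake = torch.tensor([1000.0, 10.0])      # tao staked per hotkey

def priority(pubkey: str, inputs_x: torch.Tensor) -> float:
    """Higher stake and smaller requests are served first; unknown callers get 0."""
    try:
        return stake[hotkeys.index(pubkey)].item() / sys.getsizeof(inputs_x)
    except ValueError:
        return 0.0

request = torch.randint(0, 1000, (2, 64))
print(priority('caller_a', request) > priority('caller_b', request))   # True
```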
- is_registered = pubkey in self.metagraph.hotkeys - if not is_registered: - return not self.config.neuron.blacklist_allow_non_registered - - # Check stake. - uid = self.metagraph.hotkeys.index(pubkey) - if request_type == bittensor.proto.RequestType.FORWARD: - return self.metagraph.S[uid].item() < self.config.neuron.blacklist.stake.forward - - elif request_type == bittensor.proto.RequestType.BACKWARD: - return self.metagraph.S[uid].item() < self.config.neuron.blacklist.stake.backward - - # Check for time - def time_check(): - current_time = datetime.now() - self.timecheck[pubkey] = current_time - if pubkey in self.timecheck.keys(): - prev_time = self.timecheck[pubkey] - return not current_time - prev_time >= timedelta(seconds=self.config.neuron.blacklist.time) - else: - return False - - # Black list or not - return stake_check() or time_check() - - def __exit__ ( self, exc_type, exc_value, exc_traceback ): - r""" Close down neuron. - """ - print(exc_type, exc_value, exc_traceback) - self.dataset.close() - self.dendrite.__del__() - - def __enter__(self): - r""" Sanity checks and begin miner-validator. - """ - # === Wallet === - # Connects wallett to network. - # NOTE: This registration step should likely be solved offline first. - self.wallet.register( subtensor = self.subtensor ) - - # === UID === - # Get our uid from the chain. - # At this point we should have a uid because we are already registered. - self.uid = self.wallet.get_uid( subtensor = self.subtensor ) - - # === Start the axon server ==== - self.axon.start().serve ( - use_upnpc = self.config.neuron.use_upnpc, - subtensor = self.subtensor - ) - - # === Monitoring === - # Optionally set up wandb logging. - if self.config.using_wandb: - bittensor.wandb( - config = self.config, - cold_pubkey = self.wallet.coldkeypub.ss58_address, - hot_pubkey = self.wallet.hotkey.ss58_address, - root_dir = self.config.neuron.full_path - ) - - def run ( self ): - r""" Run the miner-validator and terminate on Keyboard interrupt. - """ - - # === Setup === - # Checks wallet and starts monitoring with wandb. - with self: - - # === Run === - # Iterates through epochs. - self.epoch = 0 - self.global_step = 0 - while True: - try: - # === Epoch === - # Each epoch runs for blocks_per_epoch and resets - # the model every epochs_until_reset. - self.run_epoch() - # === Stops on interrupt otherwise restarts === - except KeyboardInterrupt: - break - except Exception as e: - console.print_exception(show_locals=False) - print( traceback.format_exc() ) - print( 'Unknown exception: {}', e ) - if not self.config.neuron.restart_on_failure: - break - - - def run_epoch( self ): - r""" Runs a miner-validator epoch. We apply batches until the epoch length is exhausted. - Occasionally the validator nucleus is completely reset to ensure we dont converge to far. - At the end of the epoch we set weights on the chain and optionally log to wandb. - """ - # === Get params for epoch === - # Pulling the latest chain parameters. 
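The blacklist above combines two tests: a stake threshold per request type and a per-caller rate limit. One caveat in the deleted version: `time_check` writes the current timestamp before reading the previous one, so the comparison is always against the value it just wrote. The sketch below keeps the intended ordering (read, then update); the thresholds are placeholders for the `--neuron.blacklist.*` settings:

```python
from datetime import datetime, timedelta

# Hypothetical thresholds mirroring the flags above; real values come from the config.
STAKE_THRESHOLD = 10.0                 # --neuron.blacklist.stake.forward
MIN_INTERVAL = timedelta(seconds=2)    # --neuron.blacklist.time

stakes = {'good_peer': 100.0, 'weak_peer': 1.0}
timecheck = {}

def blacklist(pubkey: str) -> bool:
    # Stake check: reject callers below the threshold (or unknown callers).
    if stakes.get(pubkey, 0.0) < STAKE_THRESHOLD:
        return True
    # Rate limit: reject callers who queried more recently than MIN_INTERVAL ago.
    now = datetime.now()
    last_seen = timecheck.get(pubkey)
    timecheck[pubkey] = now
    if last_seen is not None and now - last_seen < MIN_INTERVAL:
        return True
    return False

print(blacklist('weak_peer'))   # True: not enough stake
print(blacklist('good_peer'))   # False: first query, enough stake
print(blacklist('good_peer'))   # True: queried again within the rate-limit window
```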
- current_block = self.subtensor.block - batch_size = self.subtensor.validator_batch_size - sequence_length = self.subtensor.validator_sequence_length - n_topk_peer_weights = self.subtensor.min_allowed_weights - max_allowed_ratio = self.subtensor.max_allowed_min_max_ratio - blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch - epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset - # === Logs === - print ( '\nEra:', '\n\t batch_size:', batch_size, '\n\t sequence_length:', sequence_length, '\n\t n_topk_peer_weights:', n_topk_peer_weights, - '\n\t max_allowed_ratio:', max_allowed_ratio, '\n\t blocks_per_epoch:', blocks_per_epoch, '\n\t epochs_until_reset:', epochs_until_reset, - '\n\t until_reset:', self.epoch % epochs_until_reset, '\n\t current_block:', current_block, '\n') - if self.config.using_wandb: - wandb.log( { 'era/batch_size': batch_size, 'era/sequence_length': sequence_length, 'era/n_topk_peer_weights': n_topk_peer_weights, - 'era/max_allowed_ratio': max_allowed_ratio, 'era/blocks_per_epoch': blocks_per_epoch, 'era/epochs_until_reset': epochs_until_reset, - }, step = current_block ) - - # === Reset Epochs with new params. === - # Pulls new default validator training parameters and resets - # the model and dataset for the following epoch. - if self.epoch % epochs_until_reset == 0: - print ('\n\n=== Reset ===\n\n') - # === Resetting model + dataset === - if (batch_size != self.dataset.batch_size) or (sequence_length != self.dataset.block_size): - self.dataset.set_data_size(batch_size, sequence_length) - - self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device ) - self.optimizer = torch.optim.SGD ( - self.nucleus.parameters(), lr = self.config.neuron.learning_rate, momentum = self.config.neuron.momentum - ) - # === Reset Scores === - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - - # Checks if moving avg has been initiated - if self.moving_avg_scores == None: - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - - - # === Run Epoch === - # Each block length lasts blocks_per_epoch blocks. - # This gives us a consistent network wide timer. - # Here we run until blocks_per_epochs have progressed. - self.metagraph_sync() - epoch_steps = 0 - score_history = [] - start_block = self.subtensor.block - current_block = start_block - while self.subtensor.block < start_block + blocks_per_epoch: - start_time = time.time() - - # === Forward === - # Forwards inputs through the network and returns the loss - # and endpoint scores using shapely approximation of salience. - loss, scores,uids = self.nucleus( next( self.dataset ), self.metagraph, self.dendrite, is_inference = False ) - - # === Backward === - # Backwards gradients through model to train gating and remote endpoints. - loss.backward() - - # === Apply gradients === - # Applies local gradients to parameters. - clip_grad_norm_(self.nucleus.parameters(), self.config.neuron.clip_gradients) - self.optimizer.step() - self.optimizer.zero_grad() - - # === Scoring === - # Updates moving averages and history. - self.moving_avg_scores[uids] = self.moving_avg_scores[uids]*(0.99) + scores*(0.01) - - # === State update === - # Prints step logs to screen. 
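Peer scoring above is a slow exponential moving average: each step blends 1% of the new shapley scores into 99% of the history, and only for the uids actually queried that step. A small sketch with made-up scores; unqueried peers keep the sentinel value -1:

```python
import torch

# Moving average update used for peer scoring: 99% history, 1% new evidence.
n_peers = 5
moving_avg_scores = torch.ones(n_peers) * -1          # -1 marks "never scored"

uids = torch.tensor([0, 3])                           # peers queried this step
scores = torch.tensor([0.25, 0.75])                   # their scores this step

moving_avg_scores[uids] = moving_avg_scores[uids] * 0.99 + scores * 0.01
print(moving_avg_scores)   # only positions 0 and 3 move away from -1
```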
- epoch_steps += 1 - self.global_step += 1 - current_block = self.subtensor.block - step_time = time.time() - start_time - - # === Logs === - print( '\nStep:', '\n\t epoch:', self.epoch, '\n\t epoch_steps:', epoch_steps, '\n\t step:', self.global_step, '\n\t step_time:', step_time, '\n\t loss:', loss.item(), - '\n\t current_block', current_block, '\n\t blocks remaining:', current_block - start_block, '/', blocks_per_epoch, '\n') - if self.config.using_wandb: - wandb.log( { 'epoch/epoch': self.epoch, 'epoch/epoch_steps': epoch_steps, 'epoch/global_step': self.global_step, 'epoch/loss': loss.item(), 'epoch/time': step_time }, step = current_block ) - step_topk_scores, step_topk_uids = bittensor.unbiased_topk( self.moving_avg_scores, k = n_topk_peer_weights ) - step_topk_normalized = bittensor.utils.weight_utils.normalize_max_multiple( x = step_topk_scores, multiple = max_allowed_ratio ) - for i, w in list(zip(step_topk_uids.tolist(), step_topk_normalized.tolist()) ): - wandb.log( {'w_{}'.format( i ): w }, step = current_block ) - - # Iterate epochs. - self.epoch += 1 - - # === Set weights === - # Find the n_topk_peer_weights peers to set weights to. - # We use the mean of the epoch weights. - topk_scores, topk_uids = bittensor.unbiased_topk( self.moving_avg_scores, k = n_topk_peer_weights ) - topk_scores = bittensor.utils.weight_utils.normalize_max_multiple( x = topk_scores, multiple = max_allowed_ratio ) - print( '\nScores:', '\n\t weights:', topk_scores.sort()[0].tolist(), '\n\t sum:', topk_scores.sum().item(), - '\n\t min:', topk_scores.min().item(), '\n\t max:', topk_scores.max().item(), '\n\t max/min:', (topk_scores.max()/topk_scores.min()).item() ) - self.subtensor.set_weights( - uids = topk_uids.detach().to('cpu'), - weights = topk_scores.detach().to('cpu'), - wallet = self.wallet, - wait_for_finalization = self.config.neuron.wait_for_finalization, - ) - - # === Wandb Logs === - # Optionally send miner-validator logs to wandb. - if self.config.using_wandb: - # Logging history to wandb. 
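At the end of an epoch the validator keeps only the `n_topk_peer_weights` best moving-average scores and normalizes them before calling `set_weights`. The sketch below uses plain `torch.topk` and a simple sum-to-one normalization as stand-ins; the real code uses `bittensor.unbiased_topk` and `normalize_max_multiple`, which additionally caps the max/min weight ratio at the chain's `max_allowed_ratio`:

```python
import torch

moving_avg_scores = torch.tensor([-1.0, 0.10, 0.42, 0.07, 0.31, -1.0])
n_topk_peer_weights = 3

# Pick the highest-scoring peers; unqueried peers (-1) naturally fall to the bottom.
topk_scores, topk_uids = torch.topk(moving_avg_scores, k=n_topk_peer_weights)

# Simple sum-to-one normalization for the sketch only.
weights = topk_scores / topk_scores.sum()

print(topk_uids.tolist())      # [2, 4, 1]
print(weights.sum().item())    # 1.0
```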
- df = pandas.concat( [ - bittensor.utils.indexed_values_to_dataframe( prefix = 'weights', index = topk_uids, values = torch.zeros( self.metagraph.n ).scatter( dim = 0, src = topk_scores, index = topk_uids ) ), - self.dendrite.to_dataframe( metagraph = self.metagraph ), - self.axon.to_dataframe( metagraph = self.metagraph ), - ], axis = 1); df['uid'] = df.index - wandb_data = { 'stake': self.metagraph.S[ self.uid ].item(), 'dividends': self.metagraph.D[ self.uid ].item() } - wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block ) - wandb.log( { **wandb_data, **self.dendrite.to_wandb(), **self.axon.to_wandb() }, step = current_block ) - - def metagraph_sync(self): - r""" Syncing metagraph together with other metagraph-size related objects - """ - self.metagraph.sync() - - if self.moving_avg_scores == None: - self.moving_avg_scores = torch.ones_like( self.metagraph.S ) * -1 - - if self.metagraph.n > len(self.moving_avg_scores): - size_increase = self.metagraph.n - len(self.moving_avg_scores) - self.moving_avg_scores = torch.concat([self.moving_avg_scores, torch.ones(size_increase) * -1]) - -class PositionalEncoding(nn.Module): - r""" Positional Encoder which adds information based on the relative position of each token - - """ - def __init__(self, d_model: int, dropout: float, max_len: int = 5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - - # === Create position matrix === - # Creates a positional matrix with alternating frequencies - # pe: (torch.FloatTensor) positional encoding matrix - # pe.shape: [1, max_len, network_dim] - pe = torch.zeros(1, max_len, d_model) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, : , 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward(self, x: torch.tensor) -> torch.tensor: - """ - Args: - x: Tensor, shape [batch_size, seq_len, embedding_dim] - """ - # === Positional Encoding === - # Inject some information of the relative position of the token in the sequence. - # Finally, Dropout is applied to tokens - # x: (torch.FloatTensor) input sequence tokens with position information injected - # x.shape: [batch_size, seq_len, network_dim] - x = x + self.pe[0, :x.size(1)] - return self.dropout(x) - -class nucleus( torch.nn.Module ): - """ Nucleus class which holds the miner-validator model. - """ - def __init__( self, config, device, subtensor ): - super(nucleus, self).__init__() - self.config = config - self.device = device - self.max_n = subtensor.max_n - - # Token embeddings project int64 tokens onto representations. - self.token_embedding = torch.nn.Embedding( bittensor.__vocab_size__, bittensor.__network_dim__ ) - - # Routing encoder, projects token embeddings onto context for routing inputs. - self.routing_encoder_layers = TransformerEncoderLayer( bittensor.__network_dim__, config.nucleus.nhead, config.nucleus.nhid, config.nucleus.dropout, batch_first=True) - self.routing_encoder = TransformerEncoder( self.routing_encoder_layers, 1 ) - - # Encoder projects response representations onto hidden units. - self.encoder_layers = TransformerEncoderLayer( bittensor.__network_dim__, config.nucleus.nhead, config.nucleus.nhid, config.nucleus.dropout, batch_first=True) - self.encoder = TransformerEncoder( self.encoder_layers, config.nucleus.nlayers ) - - # Student distillation layer learns to emulate the network and project it back. 
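`metagraph_sync` above also keeps the score buffer in step with the network: when a sync reports more uids than the buffer holds, the new slots are appended with the sentinel value -1 (never scored). A sketch:

```python
import torch

# After a metagraph sync the network may have grown; extend the score buffer
# so new uids start at the "unscored" value of -1.
moving_avg_scores = torch.tensor([0.2, 0.5, 0.1])
new_metagraph_n = 5

if new_metagraph_n > len(moving_avg_scores):
    size_increase = new_metagraph_n - len(moving_avg_scores)
    moving_avg_scores = torch.cat([moving_avg_scores, torch.ones(size_increase) * -1])

print(moving_avg_scores)   # tensor([ 0.2000,  0.5000,  0.1000, -1.0000, -1.0000])
```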
- self.student_layers = TransformerEncoderLayer( bittensor.__network_dim__, config.nucleus.student_nhead, config.nucleus.student_nhid, config.nucleus.student_dropout, batch_first=True) - self.student = TransformerEncoder( self.student_layers, config.nucleus.student_nlayers ) - - # Decoder which projects hidden unit representations on to the token dimension. - self.decoder = torch.nn.Linear( bittensor.__network_dim__, bittensor.__vocab_size__ , bias=False) - - # Positional Encoding - self.local_pos_encoder = PositionalEncoding( bittensor.__network_dim__, self.config.nucleus.dropout ) - - # Crosss entropy loss for NTP. - self.loss_fct = torch.nn.CrossEntropyLoss() - - # SGMOE Gates: Instantiating the gates per expert. - self.gates = torch.nn.Linear( bittensor.__network_dim__, self.max_n, bias=True ).to( self.device ) - self.reset_weights() - - @classmethod - def add_args( cls, parser ): - parser.add_argument('--nucleus.topk', type=int, help='the number of peers queried during each remote forward call', default = 50 ) - parser.add_argument('--nucleus.nhid', type=int, help='Encoder: the dimension of the feedforward network model in nn.TransformerEncoder', default=200 ) - parser.add_argument('--nucleus.nhead', type=int, help='Encoder:the number of heads in the multiheadattention models', default = 2 ) - parser.add_argument('--nucleus.nlayers', type=int, help='Encoder:the number of nn.TransformerEncoderLayer in nn.TransformerEncoder', default=2 ) - parser.add_argument('--nucleus.dropout', type=float, help='Encoder:the dropout value', default=0.2) - parser.add_argument('--nucleus.student_nhid', type=int, help='Student: the dimension of the feedforward network model in nn.TransformerEncoder', default=200 ) - parser.add_argument('--nucleus.student_nhead', type=int, help='Student: the number of heads in the multiheadattention models', default = 2 ) - parser.add_argument('--nucleus.student_nlayers', type=int, help='Student: the number of nn.TransformerEncoderLayer in nn.TransformerEncoder', default=2 ) - parser.add_argument('--nucleus.student_dropout', type=float, help='Student: the dropout value', default=0.2) - parser.add_argument('--nucleus.importance', type=float, help='hyperparameter for the importance loss', default=0.1) - - @classmethod - def config ( cls ): - parser = argparse.ArgumentParser() - cls.add_args( parser ) - return bittensor.config( parser ) - - @classmethod - def check_config( cls, config: 'bittensor.Config' ): - pass - - def reset_weights ( self ): - r""" Resets the validator weights. - """ - # === Resets all the weights using xavier initialization. === - torch.nn.init.xavier_uniform_ ( self.token_embedding.weight ) - torch.nn.init.xavier_uniform_ ( self.decoder.weight ) - torch.nn.init.xavier_uniform_( self.gates.weight ) - def init_xavier( component ): - try: - torch.nn.init.xavier_uniform_( component.weight ) - except: pass - self.routing_encoder.apply( init_xavier ) - self.encoder.apply( init_xavier ) - torch.nn.init.xavier_uniform_( self.gates.weight ) - - def forward ( - self, - inputs: torch.FloatTensor, - metagraph: 'bittensor.Metagraph', - dendrite: 'bittensor.Dendrite', - is_inference: bool = False - ): - r""" Forward miner-validator pass. Selects peer to query, joins results and computes scoring. - Args: - inputs (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, *-1*)`, `required`): - Tensor inputs to distribute to neurons using query context. - metagraph (bittensor.Metagraph): - Metagraph object used to query network information. 
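`reset_weights` re-initializes the validator by walking every submodule with `Module.apply` and applying Xavier initialization wherever a 2-D weight exists, silently skipping everything else (norms, dropout, containers). A sketch of the pattern on a small stand-in encoder; the explicit exception types here replace the bare `except` in the original:

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=2, dim_feedforward=128,
                                           dropout=0.2, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

def init_xavier(component):
    # Re-initialize every submodule that exposes a 2D+ weight; skip the rest.
    try:
        torch.nn.init.xavier_uniform_(component.weight)
    except (AttributeError, ValueError):
        pass

encoder.apply(init_xavier)
print(sum(p.numel() for p in encoder.parameters()))   # parameter count unchanged by re-init
```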
- dendrite (bittensor.Dendrite): - Dendrite RPC client used to make network queries. - is_inference (bool): - If true, we only compute the student model and dont query the network. - Returns: - if not: is_inference - loss (torch.FloatTensor, [1] ): - Loss for training miner-validator nucleus. - scores (torch.FloatTensor, [ metagraph.n ]): - Scores per endpoint for this batch. - else: - student_embedding (torch.FloatTensor, [sequence_len, batch_size, bittensor.__network_dim__]): - Student embeddings from learning to emulate the network. - - """ - # === Create the local context used to select endpoints === - # The context tensor returns a hidden unit representation for the text inputs - # this context can be used as input to the gates in the next step. - # embedding: retrieve learned representation vectors for input vocabulary tokens. - # inputs.shape = [batch_size, sequence_len] - # embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] - embedding = self.token_embedding( inputs )* math.sqrt( bittensor.__network_dim__ ) - - # === Create an attention mask === - # The attention mask will mask out parts of the context - # This prevents cheating and forward-looking when predicting each token in the sequence. - # src_mask: (torch.FloatTensor) attention mask adds -inf to positions not allowed to attend - # src_mask.shape = [sequence_len, sequence_len] - src_mask = torch.triu(torch.ones(embedding.size(1), embedding.size(1)) * float('-inf'), diagonal=1) - src_mask = src_mask.to(self.config.neuron.device) - - # === Apply the positional encoding to help select endpoints === - # The positional encoder provides information based on the relative postion of each token - # embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] - # pos_embedding: (torch.FloatTensor) positional encoded embedding. - # pos_embedding.shape = [batch_size, sequence_len, bittensor.__network_dim__] - pos_embedding = self.local_pos_encoder(embedding) - - # routing_context: (torch.FloatTensor): context tensor which is used to select endpoints. - # routing_context.shape = [ batch size, __network_dim__ ] - routing_context = self.routing_encoder( pos_embedding, mask = src_mask ) - - # === Apply the distillation model which we will use in inference mode to query the miner === - # without any network calls. - # student_embedding: hidden layer encoding of sequence with local_context. - # student_embedding.shape = [sequence_len, batch_size, bittensor.__network_dim__] - student_embedding = self.student( routing_context.detach(), mask = src_mask ) * math.sqrt(bittensor.__network_dim__) - - # === Optionally return fast only the inference aspect of the model === - if is_inference: - return student_embedding - # Otherwise run the remainder of the network training, similar to the validator. - - # === Get weights for uids. === - # We iterate over each of the network uids and compute a querying score for each - # using the gating function. This returns a score per endpoint per example. - # routing_weights: (torch.FloatTensor): score per example, per endpoint. - # routing_weights.shape = [ batch size, __network_n__ ] - # The gates act over the last embedding of the routing_context. - routing_weights = self.gates( routing_context[:,-1,:] ) - - # === Normalize routing_weights across batch dimension and add noise. === - # We are summing across the batch dimension to create a per-batch score per endpoint. - # The resulting routing_weights tensor is a score per expert. 
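The attention mask built in the forward pass is the standard upper-triangular -inf mask, so that position t can attend only to positions up to t. A sketch small enough to print:

```python
import torch

seq_len = 5
# Upper-triangular -inf mask: position t may attend to positions <= t only.
src_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
print(src_mask)
# Row 0 can see only position 0; row 4 can see positions 0..4.
# Passed as `mask=` to nn.TransformerEncoder, the -inf entries are removed by the softmax.
```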
- # routing_weights: (torch.FloatTensor): normalized weights across batch dimension with noise. - # routing_weights.shape = [ n_filtered ] - batchwise_routing_weights = torch.mean(routing_weights, axis = 0)[:metagraph.n] - noisy_routing_weights = torch.normal( 0, torch.std(batchwise_routing_weights).item(), size=( batchwise_routing_weights.size())).to( self.config.neuron.device ) - noisy_routing_weights = batchwise_routing_weights + noisy_routing_weights - - # === Compute Importance loss === - # Computes the importance loss based on the stardard error of batchwise_routing_weights - # This ensures that gates do not converge onto a few experts - # importance_loss: (torch.float64) the importance loss based on the stardard error - # target_loss: (torch.float64): the total loss (global training loss + importance loss) - # target_loss.shape = [ 1 ] - importance_loss = self.config.nucleus.importance * (torch.std(batchwise_routing_weights)/torch.mean(batchwise_routing_weights))**2 - - # === Get indices and values for uids with highest scores === - # We are taking the topk routing weights and returning their uids. - # First we ensure topk is smaller than the network size then use the torch.topk. - # topk_routing_weights: (torch.float64): scores of uids with highest scores. - # topk_routing_weights.shape = [ self.config.nucleus.topk ] - # topk_routing_uids: (torch.LongTensor): uids with highest scores. - # topk_routing_uids.shape = [ self.config.nucleus.topk ] - top_k_routing_weights, routing_uids = torch.topk( noisy_routing_weights, self.config.nucleus.topk, dim=0) - - # === Get endpoint information for the highest scoring uids === - # We index into the metagraph's endpoints and return a list of the filtered set of endpoints we wish to query. - # routing_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids. - # len(neurons) == self.config.nucleus.topk - routing_endpoints = [ metagraph.endpoints[ uid ] for uid in routing_uids ] - - # === Query the endpoints === - # Makes the dendrite call into the network returning the representations - # for each of the endpoints. The return ops can be used to filter weights and outputs. - # query_responses: (List[torch.float64]): responses from each endpoint. - # query_responses.shape = self.config.nucleus.topk * [ batch_size, sequence_len, __network_dim__ ] - # return_ops: (torch.int64): Return ops. - # return_ops.shape = [ self.config.nucleus.topk ] - responses, return_ops, times = dendrite.forward_text ( - endpoints = routing_endpoints, - inputs = inputs - ) - # Send responses to device. This is required to ensure we move the responses - # Onto the correct device. - for response in responses: - response.to( self.device ) - - # === Compute global loss === - # Computes the global training loss for the nucleus by decoding all the responses - # onto the targets. - # target_loss: (torch.float64): loss after decoding all responses and a variance loss. - # target_loss.shape = [ 1 ] - responses_hidden, _ = joining_context( return_ops, batchwise_routing_weights[routing_uids], responses) - - # === Compute loss given joined responses === - # This function computes target loss for next token prediction given - # the joined responses as a hidden unit input. - # target_loss: (torch.float64): loss after decoding responses to targets. 
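The importance loss above is the squared coefficient of variation of the batchwise routing weights, scaled by `--nucleus.importance`. It is near zero when the gate spreads weight evenly and grows quickly when the gate collapses onto a few endpoints, as this sketch with made-up weight vectors shows:

```python
import torch

importance_coef = 0.1   # --nucleus.importance

# Balanced routing weights -> tiny penalty; collapsed routing -> large penalty.
balanced = torch.tensor([0.24, 0.26, 0.25, 0.25])
collapsed = torch.tensor([0.97, 0.01, 0.01, 0.01])

def importance_loss(w: torch.Tensor) -> torch.Tensor:
    # Squared coefficient of variation of the batchwise routing weights.
    return importance_coef * (torch.std(w) / torch.mean(w)) ** 2

print(importance_loss(balanced).item())    # close to 0
print(importance_loss(collapsed).item())   # much larger (~0.37)
```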
- # target_loss.shape = [ 1 ] - def get_target_loss ( hidden, targets ): - # hidden: (torch.float64): [ batch_size, sequence_len, __network_dim__ ] - # Hidden units which are encoded and decoded onto targets for loss computation. - # targets: (torch.float64): [n] - # Token targets, - src_mask = torch.triu(torch.ones(hidden.size(1), hidden.size(1)) * float('-inf'), diagonal=1) - src_mask = src_mask.to(self.config.neuron.device) - encoded_hidden = self.encoder( hidden, mask = src_mask ) - decoded_targets = self.decoder( encoded_hidden ) - shift_logits = decoded_targets[..., :-1, :].contiguous() - shift_labels = targets[..., 1:].contiguous() - return self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - - - # === Compute shapely scores === - # Computes shapely scores for each endpoint by masking the response and - # computing the change in loss induced. - # shapely_scores: (torch.float32): shapely scores per query_response - # shapely_scores.shape = [ nucleus.topk ] - masked_contexts = partial_contexts(return_ops, routing_uids, batchwise_routing_weights[routing_uids], responses) - # This sets non queried peers as if non-responsive - shapely_scores = torch.zeros( routing_uids.size()) - # Turn off gradient computation for shapely scores. - with torch.no_grad(): - self.eval() - unmasked_loss = get_target_loss(responses_hidden, inputs) - - # Iterate over all responses creating a masked context. - for i, uid in enumerate(masked_contexts): - # Create mask by zeroing out the response at index. - masked_loss = get_target_loss ( masked_contexts[uid], inputs ) - shapely_score = unmasked_loss - masked_loss - print ('Shapely\t|\tuid: {}\tweight: {}\tscore: {}\tcode: {}\tsum: {}'.format( uid, batchwise_routing_weights[routing_uids][i], -shapely_score.item(), return_ops[i], responses[i].sum())) - shapely_scores[ i ] = -shapely_score - - # Ensures that the nonresponsive peers are not rewarded - shapely_scores[return_ops != 1 ] = -1 - - # distillation_loss : distillation loss between local_context and remote_context - # distillation_loss.shape = [1] - # This trains the local_context (student) to emulate the network context. - distillation_loss = F.mse_loss( student_embedding, responses_hidden.detach() ) - - # sum losses for training. - target_loss = get_target_loss ( responses_hidden, inputs ) - loss = importance_loss + target_loss + distillation_loss - print ('Loss\t|\t{}'.format( loss.item() )) - # === Done === - return loss, shapely_scores, routing_uids \ No newline at end of file diff --git a/bittensor/_neuron/text/template_miner/main.py b/bittensor/_neuron/text/template_miner/main.py deleted file mode 100644 index 1e1bc7e0d0..0000000000 --- a/bittensor/_neuron/text/template_miner/main.py +++ /dev/null @@ -1,27 +0,0 @@ - -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. 
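The shapley scoring in the block above is a leave-one-out approximation: for each queried peer, rejoin the responses with that peer masked out, recompute the loss, and score the peer by how much the loss rises. A sketch with fabricated loss values (uids 11, 42 and 7 are placeholders):

```python
import torch

# Leave-one-out scoring sketch: a peer's score is the loss increase caused by
# masking its response out of the joined context (loss values here are made up).
unmasked_loss = torch.tensor(2.10)
masked_losses = {11: torch.tensor(2.45), 42: torch.tensor(2.12), 7: torch.tensor(2.09)}

shapley_scores = {}
for uid, masked_loss in masked_losses.items():
    # score = masked_loss - unmasked_loss: removing a useful peer raises the loss,
    # so useful peers end up with positive scores.
    shapley_scores[uid] = -(unmasked_loss - masked_loss).item()

print(shapley_scores)   # roughly {11: 0.35, 42: 0.02, 7: -0.01}
```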
- -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -""" Main Validator script. - -Example: - $ python3 bittensor/_neurons/text/template_miner/main.py ... -""" -import bittensor -if __name__ == "__main__": - bittensor.utils.version_checking() - bittensor.neurons.template_miner.neuron().run() diff --git a/bittensor/_neuron/text/template_server/nucleus_impl.py b/bittensor/_neuron/text/template_server/nucleus_impl.py deleted file mode 100644 index 29bb7bf2e7..0000000000 --- a/bittensor/_neuron/text/template_server/nucleus_impl.py +++ /dev/null @@ -1,252 +0,0 @@ -import argparse -import bittensor -import torch -import torch.nn.functional as F - -from transformers import AutoModel,AutoTokenizer,AutoConfig -from torch.nn.utils.rnn import pad_sequence - -from loguru import logger; logger = logger.opt(colors=True) - -class server(torch.nn.Module): - def __init__(self, - config: 'bittensor.config' = None, - pretrained: bool = None, - model_name: str = None, - padding: bool =None, - interpolate: bool =None, - inter_degree: str = None, - model = None, - tokenizer = None, - mapping_function = None, - token_remap = None, - checking= None): - r"""" Creates a server that serves up a pretrained miner on the bittensor network - Args: - config (:obj:`bittensor.Config`, `required`): - bittensor.server.config() - pretrained (:obj:bool , `optional`): - if the model should pretrained or not - model_name (:obj:string , `optional`): - name of the pretrained model from huggingface to use - padding (:obj:bool, `optional`): - If the server should pad out to match the hidden units that the bittensor network is using - If set to False, it will instead create a mapping layer to do the same thing. - interpolate (:obj:bool, `optional`): - If the server should interpolate between sequence length differences. 
- If set to false, there should be a mapping function that takes care of the differnces - inter_degree (:obj:str, `optional`): - The Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area) - model (:obj:torch.module, `optional`): - Overrides the huggingface pretrained model with your own pretrained model - tokenizer (:obj:huggingface.tokenizer, `optional`): - Overrides the huggingface tokenizer with your tokenizer - mapping_function (:obj:Callable, `optional`): - Custom mapping function that maps between sequence length differences between tokenizers - token_remap (:obj:Callable, `optional`): - Custom function that maps between tokenizers (defaults to self.remapping_token) - """ - super(server, self).__init__() - if config == None: config = server.config() - self.config = config;print(config) - - #setting up pretrained model - self.model_name = model_name if model_name != None else config.neuron.model_name - self.pretrained = pretrained if pretrained != None else config.neuron.pretrained - if self.pretrained == True: - self.pre_model = model if model != None else AutoModel.from_pretrained(self.model_name) - self.tokenizer = tokenizer if tokenizer != None else AutoTokenizer.from_pretrained(self.model_name) - elif self.pretrained == False: - model_config = AutoConfig.from_pretrained(self.model_name) - model_config.vocab_size= bittensor.__vocab_size__ - self.pre_model = model if model != None else AutoModel.from_config(model_config) - self.tokenizer = bittensor.tokenizer() - - if self.config.neuron.training: - self.pre_model.train() - elif self.config.neuron.autocast and self.device == 'cuda': - self.pre_model.half() - else: - self.pre_model.eval() - - #parameters of the models - self.final_dim = bittensor.__network_dim__ - self.pre_dimension = self.pre_model.config.hidden_size - self.device = config.neuron.device - self.padding = padding if padding != None else config.neuron.padding - self.interpolate = interpolate if interpolate != None else config.neuron.interpolate - self.inter_degree = inter_degree if inter_degree != None else config.neuron.inter_degree - self.checking = checking if checking != None else config.neuron.checking - self.mapping_function= mapping_function - self.token_remap = token_remap if token_remap != None else self.remapping_token - - if self.config.neuron.padding == False: - self.mapping = torch.nn.Linear( self.pre_dimension, self.final_dim) - - self.decoder = torch.nn.Linear( self.final_dim, bittensor.__vocab_size__ , bias=False) - self.loss_fct = torch.nn.CrossEntropyLoss() - - self.outputs_cache = None - self.gradients_cache = None - - #checking if the parameters of the server makes sense - if self.checking and pretrained == True: - self.check() - - # -- keeps track of gradients applied - self.backward_gradients = 0 - - def forward(self, inputs,tokenizer=None): - """ - Forward pass through the whole server model. Returns the loss and decoded predictions. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer (:obj:'huggingface.tokenizer', optional): - The tokenizer which was used to tokenize the inputs - Returns: - loss (:obj:`torch.FloatTensor`): - MLM loss from the inputs - decoded_targets (:obj:`torch.FloatTensor`): - Decoded predictions of the next token in the sentence. 
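Unlike the multitron variant, this server chooses the model's mode from config: `train()` when on-line training is requested, `half()` for fp16 inference when autocast is enabled on CUDA, and `eval()` otherwise. A sketch with hypothetical flag values standing in for `--neuron.training` and `--neuron.autocast` (gpt2 is used only because it is the document's default `model_name`):

```python
import torch
from transformers import AutoModel

# Hypothetical flag values standing in for --neuron.training / --neuron.autocast.
training, autocast = False, True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pre_model = AutoModel.from_pretrained('gpt2')
if training:
    pre_model.train()            # keep gradients: higher memory, allows backward updates
elif autocast and device == 'cuda':
    pre_model.half()             # fp16 weights: roughly halves GPU memory, inference only
else:
    pre_model.eval()             # plain fp32 inference
```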
- - """ - decoded_targets = self.decoder(self.encode_forward(inputs,tokenizer)) - - shift_logits = decoded_targets[..., :-1, :].contiguous() - shift_labels = inputs[..., 1:].contiguous() - loss = self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - - return loss, decoded_targets - - def encode_forward(self,inputs,tokenizer=None): - r""" Forward pass through the pretrained model and possible mappings between hidden units. - The response tensor should be the hidden units computed using the local context and with shape: [batch_size, sequence_len, __network_dim__]. - - Args: - inputs ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. - tokenizer ( huggingface.tokenizer, `optional`): - The tokenizer which was used to tokenize the inputs - - Returns: - outputs (:obj:`torch.FloatTensor`): - The nucleus's outputs as a torch tensor of shape [batch_size, sequence_len, __network_dim__] - """ - sen_len = inputs.size() - inputs = self.token_remap(inputs,tokenizer).to(self.device) - if self.config.neuron.training: - pre_hidden = self.pre_model(inputs).last_hidden_state - elif self.config.neuron.autocast and self.device == 'cuda': - pre_hidden = self.pre_model(inputs).last_hidden_state - else: - with torch.no_grad(): - pre_hidden = self.pre_model(inputs).last_hidden_state - - if self.interpolate and sen_len[1] != pre_hidden.size()[1]: - down= F.interpolate(pre_hidden.unsqueeze(1),size=[sen_len[1],pre_hidden.size()[2]],mode=self.inter_degree).squeeze(1) - elif self.mapping_function: - down = self.mapping_function(pre_hidden) - else: - down = pre_hidden - - - if self.padding: - padding_l = (self.final_dim-self.pre_dimension)//2 - padding_r = (self.final_dim-self.pre_dimension) - padding_l - encoded_hidden = F.pad(down, (padding_l, padding_r), "constant", 0) - else: - encoded_hidden = self.mapping(down) - return encoded_hidden - - def remapping_token(self,input, old_tokenizer=None): - r""" Default remapping of tokenizers; decodes the message and then remaps the message using a new tokenizer - Args: - inputs_x ( :obj:`torch.Tensor`, `required`): - torch inputs to be forward processed. 
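`encode_forward` interpolates the pretrained model's hidden states back to the caller's sequence length whenever re-tokenization changed it. A sketch with made-up lengths (17 tokens under the model's own tokenizer, 12 under the caller's):

```python
import torch
import torch.nn.functional as F

# After remapping to the model's own tokenizer the sequence length usually changes;
# interpolate the hidden states back to the original length.
original_len = 12                      # length under the caller's tokenizer
pre_hidden = torch.randn(2, 17, 768)   # [batch, remapped_len, hidden] from the pretrained model

down = F.interpolate(pre_hidden.unsqueeze(1),                       # add a channel dimension
                     size=[original_len, pre_hidden.size(2)],
                     mode='nearest').squeeze(1)

print(down.shape)   # torch.Size([2, 12, 768])
```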
- old_tokenizer ( huggingface.tokenizer, `required`): - The tokenizer which was used to tokenize the input (defaults to bittensor tokenizer if none is given) - """ - if old_tokenizer == None: - old_tokenizer = bittensor.tokenizer() - new_data = [] - for i in range(input.shape[0]): - decoded = old_tokenizer.decode(input[i]) - hugging = self.tokenizer(decoded) - new_data += [torch.LongTensor(hugging.input_ids)] - new_data = pad_sequence(new_data,batch_first=True) - return new_data - - def check(self): - r"""Checks the server settings - """ - assert self.tokenizer.name_or_path == self.pre_model.name_or_path, 'incorrect model ({}) and tokenizer ({})'.format(self.pre_model.name_or_path,self.tokenizer.name_or_path) - if self.interpolate == False: - assert self.mapping_function != None, 'Incorrect Settings; needs atleast one mapping function for sequence length changes' - - def save(self, path): - try: - state_dict = { - 'model': self.pretrained, - 'pretrained_model': self.pre_model.state_dict(), - 'decoder': self.decoder.state_dict() - } - if self.padding == False: - state_dict['mapping'] = self.mapping.state_dict() - torch.save( state_dict, "{}/model.torch".format( path) ) - bittensor.logging.success(prefix='Saved model', sufix='{}/model.torch'.format( path ) ) - except Exception as e: - logger.exception('Failed to save model with error:{}', e) - - def load(self, path): - try: - state_dict= torch.load("{}/model.torch".format( path )) - if self.pretrained == state_dict['model']: - self.pre_model.load_state_dict(state_dict['pretrained_model'], strict=False) - self.decoder.load_state_dict(state_dict['decoder']) - if self.padding == False: - self.mapping.load_state_dict(state_dict['mapping']) - - bittensor.logging.success( prefix = 'Reloaded model', sufix = '{}/model.torch'.format( path )) - - - except Exception as e: - logger.warning('No saved model found with error: {}', e) - - @staticmethod - def config (): - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='If set, defaults are overridden by passed file.') - parser.add_argument('--neuron.learning_rate', type=float, help='Training initial learning rate.', default=0.01) - parser.add_argument('--neuron.momentum', type=float, help='optimizer momentum.', default=0.8) - parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0) - parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu")) - parser.add_argument('--neuron.model_name', type=str, help='pretrained model from hugging face',default='gpt2') - parser.add_argument('--neuron.pretrained', action='store_false', help='if the model should be pretrained',default=True) - parser.add_argument('--neuron.padding', action='store_false', help='To pad out final dimensions',default=True) - parser.add_argument('--neuron.interpolate', action='store_false', help='To interpolate between sentence length',default=True) - parser.add_argument('--neuron.inter_degree', type=str, help='Interpolate algorithm (nearest | linear | bilinear | bicubic | trilinear | area)', default='nearest') - parser.add_argument('--neuron.name', type=str, help='Trials for this miner go in miner.root / (wallet_cold - wallet_hot) / miner.name ', default='advanced_server') - parser.add_argument('--neuron.checking', action='store_false', help='To check if server settings are correct',default=True) - 
parser.add_argument('--neuron.restart', action='store_true', help='If True, train the neuron from the beginning', default=False) - parser.add_argument('--neuron.blacklist.stake', type=float, help='Amount of stake (tao) in order not to get blacklisted', default=10) - parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch', default=10) - parser.add_argument('--neuron.blacklist.time', type=int, help='how often a peer can query you (seconds) ', default=1) - parser.add_argument('--neuron.training', action='store_true', help='if the model should be training (increases memory load)', default=False) - parser.add_argument('--neuron.autocast', action='store_true', help='(experimental) autocasts the model to float16. Must require cuda', default=False) - parser.add_argument('--neuron.blocks_per_set_weights', type=float, help='how often to set weights', default=100) - parser.add_argument('--neuron.metagraph_sync', type=float, help='how often to sync the metagraph', default=100000) - parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, allow non-registered peers''', default=False) - - - bittensor.wallet.add_args( parser ) - bittensor.axon.add_args( parser ) - bittensor.subtensor.add_args( parser ) - bittensor.logging.add_args( parser ) - bittensor.wandb.add_args(parser) - bittensor.prioritythreadpool.add_args( parser ) - bittensor.dataset.add_args( parser ) - bittensor.metagraph.add_args( parser ) - return bittensor.config( parser ) - diff --git a/bittensor/_neuron/text/template_server/run.py b/bittensor/_neuron/text/template_server/run.py deleted file mode 100644 index 2da4bddd76..0000000000 --- a/bittensor/_neuron/text/template_server/run.py +++ /dev/null @@ -1,236 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" The Exodus base client. - -Example: - $ python miners/text/template_client.py - -""" -import bittensor -import sys -import torch -import time -import wandb -import pandas -import datetime -from threading import Lock -from loguru import logger; logger = logger.opt(colors=True) -from datetime import datetime,timedelta - -def serve( - config, - model, - subtensor = None, - wallet = None, - axon= None, - metagraph = None, - ): - config.to_defaults() - model= model.to(model.device) - - # Create Subtensor connection - subtensor = bittensor.subtensor(config = config) if subtensor == None else subtensor - - # Load/Create our bittensor wallet. 
- if wallet == None: - wallet = bittensor.wallet( config = config ).create().register(subtensor=subtensor) - else: - wallet.register(subtensor=subtensor) - - - # Load/Sync/Save our metagraph. - if metagraph == None: - metagraph = bittensor.metagraph ( - subtensor = subtensor - ) - - metagraph.load().sync().save() - - # Create our optimizer. - optimizer = torch.optim.SGD( - [ {"params": model.parameters()} ], - lr = config.neuron.learning_rate, - momentum = config.neuron.momentum, - ) - mutex = Lock() - - timecheck = {} - n_topk_peer_weights = subtensor.min_allowed_weights - def forward_text ( inputs_x ): - r""" Single threaded version of the Forward function that is called when the axon recieves a forward request from other peers - """ - return model.encode_forward( inputs_x.to(model.device)) - - - def backward_text ( inputs_x, grads_dy ): - r"""Single threaded backwards function that is called when the axon recieves a backwards request from other peers. - Updates the server parameters with gradients through the chain. - """ - if config.neuron.training: - with mutex: - with torch.enable_grad(): - with torch.autograd.set_detect_anomaly(True): - outputs_y = model.encode_forward( inputs_x ) - torch.autograd.backward ( - tensors = [ outputs_y ], - grad_tensors = [ grads_dy ] - ) - optimizer.step() - optimizer.zero_grad() - - - def blacklist(pubkey:str, request_type:bittensor.proto.RequestType) -> bool: - r"""Axon security blacklisting, used to blacklist message from low stake members - Args: - pubkey ( str, `required`): - The public key of the caller. - request_type ( bittensor.proto.RequestType, `required`): - the request type ('FORWARD' or 'BACKWARD'). - """ - # Check for registrations - - def registration_check(): - # If we allow non-registered requests return False = not blacklisted. - is_registered = pubkey in metagraph.hotkeys - if not is_registered: - if config.neuron.blacklist_allow_non_registered: - - return False - raise Exception('blacklist') - - # Check for stake - def stake_check() -> bool: - - # Check stake. - uid = metagraph.hotkeys.index(pubkey) - if metagraph.S[uid].item() < config.neuron.blacklist.stake: - raise Exception('Stake blacklist') - return False - - def validator_check(): - - uid = metagraph.hotkeys.index(pubkey) - if (metagraph.W[uid] >0).sum() >= n_topk_peer_weights: - return False - - raise Exception('Validator blacklist') - - - # Check for time - def time_check(): - current_time = datetime.now() - if pubkey in timecheck.keys(): - prev_time = timecheck[pubkey] - if current_time - prev_time >= timedelta(seconds=config.neuron.blacklist.time): - timecheck[pubkey] = current_time - return False - else: - timecheck[pubkey] = current_time - raise Exception('blacklist') - else: - timecheck[pubkey] = current_time - return False - - # Black list or not - try: - registration_check() - - stake_check() - - validator_check() - - return False - - except Exception as e: - return True - - - # Create our axon server and subscribe it to the network. - if axon == None: - axon = bittensor.axon ( - config = config, - wallet = wallet, - forward_text = forward_text, - backward_text = backward_text, - blacklist = blacklist, - ).start().serve(subtensor=subtensor) - - if config.wandb.api_key != 'default': - # --- Init Wandb. - bittensor.wandb( - config = config, - cold_pubkey = wallet.coldkeypub.ss58_address, - hot_pubkey = wallet.hotkey.ss58_address, - root_dir = config.neuron.full_path - ) - - last_set_block = subtensor.get_current_block() - - - # --- Run Forever. 
- while True: - - current_block = subtensor.get_current_block() - end_block = current_block + config.neuron.blocks_per_epoch - while end_block >= current_block: - time.sleep( bittensor.__blocktime__ ) - current_block = subtensor.get_current_block() - - nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address) - uid = metagraph.hotkeys.index( wallet.hotkey.ss58_address ) - wandb_data = { - 'stake': nn.stake, - 'rank': nn.rank, - 'trust': nn.trust, - 'consensus': nn.consensus, - 'incentive': nn.incentive, - 'emission': nn.emission, - } - bittensor.__console__.print('[green]Current Status:[/green]', wandb_data) - if config.wandb.api_key != 'default': - - df = pandas.concat( [ - bittensor.utils.indexed_values_to_dataframe( prefix = 'w_i_{}'.format(nn.uid), index = metagraph.uids, values = metagraph.W[:, uid] ), - axon.to_dataframe( metagraph = metagraph ), - ], axis = 1) - df['uid'] = df.index - wandb_info_axon = axon.to_wandb() - wandb.log( { **wandb_data, **wandb_info_axon }, step = current_block ) - wandb.log( { 'stats': wandb.Table( dataframe = df ) }, step = current_block ) - - if current_block - last_set_block > config.neuron.blocks_per_set_weights: - try: - last_set_block = current_block - # Set self weights to maintain activity. - # --- query the chain for the most current number of peers on the network - chain_weights = torch.zeros(subtensor.n) - chain_weights [ uid ] = 1 - did_set = subtensor.set_weights( - uids=torch.arange(0,subtensor.n), - weights = chain_weights, - wait_for_inclusion = False, - wallet = wallet, - ) - - metagraph.sync() - if did_set: - logger.success('Successfully set weights on the chain') - else: - logger.error('Failed to set weights on chain. (Timeout)') - except Exception as e: - logger.error('Failure setting weights on chain with error: {}', e) diff --git a/bittensor/_proto/bittensor.proto b/bittensor/_proto/bittensor.proto index bc7ce96ec7..f9d5b064d7 100644 --- a/bittensor/_proto/bittensor.proto +++ b/bittensor/_proto/bittensor.proto @@ -1,3 +1,5 @@ +// python3 -m grpc.tools.protoc bittensor/_proto/bittensor.proto -I. --python_out=. --grpc_python_out=. + syntax = "proto3"; // Service definition for tensor processing servers. @@ -86,6 +88,154 @@ message TensorMessage { // Requires grad: [OPTIONAL] Does this tensor require a gradient. bool requires_grad = 8; + + // Synapses hold function information this tells the axon how to use the tensor inputs. + // i.e. where to send them + repeated Synapse synapses = 9; +} + + +message Synapse { + + enum SynapseType { + NULL_SYNAPSE = 0; + TEXT_LAST_HIDDEN_STATE = 1; + TEXT_CAUSAL_LM = 2; + TEXT_SEQ_2_SEQ = 3; + TEXT_CAUSAL_LM_NEXT = 4; + } + + // Position of Tensor inputs for corresponding synapse call. + repeated int32 tensor_pos = 1; + + // Serialized special argmument data. This is proto data which + // is packed according to the SynapseType enum. + bytes synapse_data = 2; + + // Type of Synapse. i.e. LastHidden specifies how to decode args_data + // and route information in the axon to the correct function call. + SynapseType synapse_type = 3; + + // Return codes from Backward and Forward call associated + // with this synapse call. + ReturnCode return_code = 4; + + // Message associated with the return code. + string message = 5; + + // Requires grad: [OPTIONAL] Does this synapse call require a gradient. + bool requires_grad = 6; + + message TextLastHiddenState { + // Might as well have this + SynapseType synapse_type = 1; + + // Serializer typing. 
+ Serializer forward_request_serializer_type = 2; + Serializer forward_response_serializer_type = 3; + Serializer backward_request_serializer_type = 4; + Serializer backward_response_serializer_type = 5; + + // Requires grad: [OPTIONAL] Does this synapse call require a gradient. + bool requires_grad = 6; + } + + message TextCausalLM { + // Might as well have this + SynapseType synapse_type = 1; + + // Specifies the number of Topk logits to return. + // Logit values are packed (pos, value) acocording + // to the bittensor tokenizer vocab size. + int32 topk = 2; + + // Serializer typing. + Serializer forward_request_serializer_type = 3; + Serializer forward_response_serializer_type = 4; + Serializer backward_request_serializer_type = 5; + Serializer backward_response_serializer_type = 6; + + // Requires grad: [OPTIONAL] Does this synapse call require a gradient. + bool requires_grad = 7; + } + + message TextSeq2Seq { + // Might as well have this + SynapseType synapse_type = 1; + + // Specifies the number of Topk logits to return. + // Logit values are packed (pos, value) acocording + // to the bittensor tokenizer vocab size. + int32 topk = 2; + + // Number of tokens to predict + int32 num_to_generate = 3; + + + // Serializer typing. + Serializer forward_request_serializer_type = 4; + Serializer forward_response_serializer_type = 5; + Serializer backward_request_serializer_type = 6; + Serializer backward_response_serializer_type = 7; + + //Generate Arguments + // Number of beams + int32 num_beams = 8; + + //Number of repeat words + int32 no_repeat_ngram_size = 9; + + //Early Stopping + bool early_stopping = 10; + + //Number of return seuqences + int32 num_return_sequences = 11; + + //If sampling should be used + bool do_sample = 12; + + //The probability cutoff + float top_p = 13; + + // Requires grad: [OPTIONAL] Does this synapse call require a gradient. + bool requires_grad = 14; + + //temperature of the softmax function + float temperature = 15; + + //penalty for repeated words + float repetition_penalty = 16; + + //penalty for length + float length_penalty = 17; + + //maximum amount of time + float max_time = 18; + + //groups for beam search + int32 num_beam_groups = 19; + } + + message TextCausalLMNext { + // Specifies messaging of topk server token phrases with probabilities. + // Server last position token predictions are retokenized to token phrases with the bittensor tokenizer. + // Allows for zero translation loss CausalLM next generation between different tokenizers. + + // Might as well have this + SynapseType synapse_type = 1; + + // Specifies the number of topk server token phrases to return. + int32 topk = 2; + + // Serializer typing. + Serializer forward_request_serializer_type = 3; + Serializer forward_response_serializer_type = 4; + Serializer backward_request_serializer_type = 5; + Serializer backward_response_serializer_type = 6; + + // Requires grad: [OPTIONAL] Does this synapse call require a gradient. + bool requires_grad = 7; + } } // Return codes from Backward and Forward call. @@ -114,6 +264,7 @@ enum ReturnCode { SenderUnknown = 21; // The requester is not known by the reciever. UnknownException = 22; // Unknown exception. Unauthenticated = 23; // Authentication failed. + BadEndpoint = 24; // Dummy endpoint } // A serialized tensor object created using the serializer class. @@ -158,10 +309,9 @@ message Tensor { // Requires grad: [OPTIONAL] Does this tensor require a gradient. // 1 bit. bool requires_grad = 8; + } -// Dtype: [REQUIRED] The tensor serializer type. 
-// For use between multiple serialziation deserialziation methods. enum Serializer { // PICKLE = 0; // PICKLE serializer (REMOVED for security reasons.) MSGPACK = 0; // MSGPACK serializer @@ -198,4 +348,4 @@ enum RequestType { NOTDEFINED = 0; FORWARD = 1; BACKWARD = 2; -} \ No newline at end of file +} diff --git a/bittensor/_proto/bittensor_pb2.py b/bittensor/_proto/bittensor_pb2.py index d96ba23c06..4dff530222 100644 --- a/bittensor/_proto/bittensor_pb2.py +++ b/bittensor/_proto/bittensor_pb2.py @@ -20,7 +20,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n bittensor/_proto/bittensor.proto\"\x8f\x01\n\x06Neuron\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x0b\n\x03uid\x18\x02 \x01(\x03\x12\x0e\n\x06hotkey\x18\x03 \x01(\t\x12\x0f\n\x07\x63oldkey\x18\x04 \x01(\t\x12\n\n\x02ip\x18\x05 \x01(\t\x12\x0c\n\x04port\x18\x06 \x01(\x05\x12\x0f\n\x07ip_type\x18\x07 \x01(\x05\x12\x1b\n\x08modality\x18\x08 \x01(\x0e\x32\t.Modality\"\x94\x01\n\rTensorMessage\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x0e\n\x06hotkey\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x05 \x03(\x0b\x32\x07.Tensor\x12 \n\x0breturn_code\x18\x06 \x01(\x0e\x32\x0b.ReturnCode\x12\x0f\n\x07message\x18\x07 \x01(\t\x12\x15\n\rrequires_grad\x18\x08 \x01(\x08\"\xc9\x01\n\x06Tensor\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x0e\n\x06\x62uffer\x18\x02 \x01(\x0c\x12\r\n\x05shape\x18\x03 \x03(\x03\x12\x1f\n\nserializer\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12 \n\x0btensor_type\x18\x05 \x01(\x0e\x32\x0b.TensorType\x12\x18\n\x05\x64type\x18\x06 \x01(\x0e\x32\t.DataType\x12\x1b\n\x08modality\x18\x07 \x01(\x0e\x32\t.Modality\x12\x15\n\rrequires_grad\x18\x08 \x01(\x08*\xb8\x04\n\nReturnCode\x12\x0c\n\x08NoReturn\x10\x00\x12\x0b\n\x07Success\x10\x01\x12\x0b\n\x07Timeout\x10\x02\x12\x0b\n\x07\x42\x61\x63koff\x10\x03\x12\x0f\n\x0bUnavailable\x10\x04\x12\x12\n\x0eNotImplemented\x10\x05\x12\x10\n\x0c\x45mptyRequest\x10\x06\x12\x11\n\rEmptyResponse\x10\x07\x12\x13\n\x0fInvalidResponse\x10\x08\x12\x12\n\x0eInvalidRequest\x10\t\x12\x19\n\x15RequestShapeException\x10\n\x12\x1a\n\x16ResponseShapeException\x10\x0b\x12!\n\x1dRequestSerializationException\x10\x0c\x12\"\n\x1eResponseSerializationException\x10\r\x12#\n\x1fRequestDeserializationException\x10\x0e\x12$\n ResponseDeserializationException\x10\x0f\x12\x15\n\x11NotServingNucleus\x10\x10\x12\x12\n\x0eNucleusTimeout\x10\x11\x12\x0f\n\x0bNucleusFull\x10\x12\x12\x1e\n\x1aRequestIncompatibleVersion\x10\x13\x12\x1f\n\x1bResponseIncompatibleVersion\x10\x14\x12\x11\n\rSenderUnknown\x10\x15\x12\x14\n\x10UnknownException\x10\x16\x12\x13\n\x0fUnauthenticated\x10\x17*&\n\nSerializer\x12\x0b\n\x07MSGPACK\x10\x00\x12\x0b\n\x07\x43MPPACK\x10\x01*2\n\nTensorType\x12\t\n\x05TORCH\x10\x00\x12\x0e\n\nTENSORFLOW\x10\x01\x12\t\n\x05NUMPY\x10\x02*^\n\x08\x44\x61taType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07\x46LOAT32\x10\x01\x12\x0b\n\x07\x46LOAT64\x10\x02\x12\t\n\x05INT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\x08\n\x04UTF8\x10\x05\x12\x0b\n\x07\x46LOAT16\x10\x06*+\n\x08Modality\x12\x08\n\x04TEXT\x10\x00\x12\t\n\x05IMAGE\x10\x01\x12\n\n\x06TENSOR\x10\x02*8\n\x0bRequestType\x12\x0e\n\nNOTDEFINED\x10\x00\x12\x0b\n\x07\x46ORWARD\x10\x01\x12\x0c\n\x08\x42\x41\x43KWARD\x10\x02\x32\x66\n\tBittensor\x12+\n\x07\x46orward\x12\x0e.TensorMessage\x1a\x0e.TensorMessage\"\x00\x12,\n\x08\x42\x61\x63kward\x12\x0e.TensorMessage\x1a\x0e.TensorMessage\"\x00\x62\x06proto3' + serialized_pb=b'\n bittensor/_proto/bittensor.proto\"\x8f\x01\n\x06Neuron\x12\x0f\n\x07version\x18\x01 
\x01(\x05\x12\x0b\n\x03uid\x18\x02 \x01(\x03\x12\x0e\n\x06hotkey\x18\x03 \x01(\t\x12\x0f\n\x07\x63oldkey\x18\x04 \x01(\t\x12\n\n\x02ip\x18\x05 \x01(\t\x12\x0c\n\x04port\x18\x06 \x01(\x05\x12\x0f\n\x07ip_type\x18\x07 \x01(\x05\x12\x1b\n\x08modality\x18\x08 \x01(\x0e\x32\t.Modality\"\xb0\x01\n\rTensorMessage\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x0e\n\x06hotkey\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x05 \x03(\x0b\x32\x07.Tensor\x12 \n\x0breturn_code\x18\x06 \x01(\x0e\x32\x0b.ReturnCode\x12\x0f\n\x07message\x18\x07 \x01(\t\x12\x15\n\rrequires_grad\x18\x08 \x01(\x08\x12\x1a\n\x08synapses\x18\t \x03(\x0b\x32\x08.Synapse\"\xb1\x0e\n\x07Synapse\x12\x12\n\ntensor_pos\x18\x01 \x03(\x05\x12\x14\n\x0csynapse_data\x18\x02 \x01(\x0c\x12*\n\x0csynapse_type\x18\x03 \x01(\x0e\x32\x14.Synapse.SynapseType\x12 \n\x0breturn_code\x18\x04 \x01(\x0e\x32\x0b.ReturnCode\x12\x0f\n\x07message\x18\x05 \x01(\t\x12\x15\n\rrequires_grad\x18\x06 \x01(\x08\x1a\xb4\x02\n\x13TextLastHiddenState\x12*\n\x0csynapse_type\x18\x01 \x01(\x0e\x32\x14.Synapse.SynapseType\x12\x34\n\x1f\x66orward_request_serializer_type\x18\x02 \x01(\x0e\x32\x0b.Serializer\x12\x35\n forward_response_serializer_type\x18\x03 \x01(\x0e\x32\x0b.Serializer\x12\x35\n backward_request_serializer_type\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12\x36\n!backward_response_serializer_type\x18\x05 \x01(\x0e\x32\x0b.Serializer\x12\x15\n\rrequires_grad\x18\x06 \x01(\x08\x1a\xbb\x02\n\x0cTextCausalLM\x12*\n\x0csynapse_type\x18\x01 \x01(\x0e\x32\x14.Synapse.SynapseType\x12\x0c\n\x04topk\x18\x02 \x01(\x05\x12\x34\n\x1f\x66orward_request_serializer_type\x18\x03 \x01(\x0e\x32\x0b.Serializer\x12\x35\n forward_response_serializer_type\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12\x35\n backward_request_serializer_type\x18\x05 \x01(\x0e\x32\x0b.Serializer\x12\x36\n!backward_response_serializer_type\x18\x06 \x01(\x0e\x32\x0b.Serializer\x12\x15\n\rrequires_grad\x18\x07 \x01(\x08\x1a\xd0\x04\n\x0bTextSeq2Seq\x12*\n\x0csynapse_type\x18\x01 \x01(\x0e\x32\x14.Synapse.SynapseType\x12\x0c\n\x04topk\x18\x02 \x01(\x05\x12\x17\n\x0fnum_to_generate\x18\x03 \x01(\x05\x12\x34\n\x1f\x66orward_request_serializer_type\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12\x35\n forward_response_serializer_type\x18\x05 \x01(\x0e\x32\x0b.Serializer\x12\x35\n backward_request_serializer_type\x18\x06 \x01(\x0e\x32\x0b.Serializer\x12\x36\n!backward_response_serializer_type\x18\x07 \x01(\x0e\x32\x0b.Serializer\x12\x11\n\tnum_beams\x18\x08 \x01(\x05\x12\x1c\n\x14no_repeat_ngram_size\x18\t \x01(\x05\x12\x16\n\x0e\x65\x61rly_stopping\x18\n \x01(\x08\x12\x1c\n\x14num_return_sequences\x18\x0b \x01(\x05\x12\x11\n\tdo_sample\x18\x0c \x01(\x08\x12\r\n\x05top_p\x18\r \x01(\x02\x12\x15\n\rrequires_grad\x18\x0e \x01(\x08\x12\x13\n\x0btemperature\x18\x0f \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x10 \x01(\x02\x12\x16\n\x0elength_penalty\x18\x11 \x01(\x02\x12\x10\n\x08max_time\x18\x12 \x01(\x02\x12\x17\n\x0fnum_beam_groups\x18\x13 \x01(\x05\x1a\xbf\x02\n\x10TextCausalLMNext\x12*\n\x0csynapse_type\x18\x01 \x01(\x0e\x32\x14.Synapse.SynapseType\x12\x0c\n\x04topk\x18\x02 \x01(\x05\x12\x34\n\x1f\x66orward_request_serializer_type\x18\x03 \x01(\x0e\x32\x0b.Serializer\x12\x35\n forward_response_serializer_type\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12\x35\n backward_request_serializer_type\x18\x05 \x01(\x0e\x32\x0b.Serializer\x12\x36\n!backward_response_serializer_type\x18\x06 \x01(\x0e\x32\x0b.Serializer\x12\x15\n\rrequires_grad\x18\x07 
\x01(\x08\"|\n\x0bSynapseType\x12\x10\n\x0cNULL_SYNAPSE\x10\x00\x12\x1a\n\x16TEXT_LAST_HIDDEN_STATE\x10\x01\x12\x12\n\x0eTEXT_CAUSAL_LM\x10\x02\x12\x12\n\x0eTEXT_SEQ_2_SEQ\x10\x03\x12\x17\n\x13TEXT_CAUSAL_LM_NEXT\x10\x04\"\xc9\x01\n\x06Tensor\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x0e\n\x06\x62uffer\x18\x02 \x01(\x0c\x12\r\n\x05shape\x18\x03 \x03(\x03\x12\x1f\n\nserializer\x18\x04 \x01(\x0e\x32\x0b.Serializer\x12 \n\x0btensor_type\x18\x05 \x01(\x0e\x32\x0b.TensorType\x12\x18\n\x05\x64type\x18\x06 \x01(\x0e\x32\t.DataType\x12\x1b\n\x08modality\x18\x07 \x01(\x0e\x32\t.Modality\x12\x15\n\rrequires_grad\x18\x08 \x01(\x08*\xc9\x04\n\nReturnCode\x12\x0c\n\x08NoReturn\x10\x00\x12\x0b\n\x07Success\x10\x01\x12\x0b\n\x07Timeout\x10\x02\x12\x0b\n\x07\x42\x61\x63koff\x10\x03\x12\x0f\n\x0bUnavailable\x10\x04\x12\x12\n\x0eNotImplemented\x10\x05\x12\x10\n\x0c\x45mptyRequest\x10\x06\x12\x11\n\rEmptyResponse\x10\x07\x12\x13\n\x0fInvalidResponse\x10\x08\x12\x12\n\x0eInvalidRequest\x10\t\x12\x19\n\x15RequestShapeException\x10\n\x12\x1a\n\x16ResponseShapeException\x10\x0b\x12!\n\x1dRequestSerializationException\x10\x0c\x12\"\n\x1eResponseSerializationException\x10\r\x12#\n\x1fRequestDeserializationException\x10\x0e\x12$\n ResponseDeserializationException\x10\x0f\x12\x15\n\x11NotServingNucleus\x10\x10\x12\x12\n\x0eNucleusTimeout\x10\x11\x12\x0f\n\x0bNucleusFull\x10\x12\x12\x1e\n\x1aRequestIncompatibleVersion\x10\x13\x12\x1f\n\x1bResponseIncompatibleVersion\x10\x14\x12\x11\n\rSenderUnknown\x10\x15\x12\x14\n\x10UnknownException\x10\x16\x12\x13\n\x0fUnauthenticated\x10\x17\x12\x0f\n\x0b\x42\x61\x64\x45ndpoint\x10\x18*&\n\nSerializer\x12\x0b\n\x07MSGPACK\x10\x00\x12\x0b\n\x07\x43MPPACK\x10\x01*2\n\nTensorType\x12\t\n\x05TORCH\x10\x00\x12\x0e\n\nTENSORFLOW\x10\x01\x12\t\n\x05NUMPY\x10\x02*^\n\x08\x44\x61taType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07\x46LOAT32\x10\x01\x12\x0b\n\x07\x46LOAT64\x10\x02\x12\t\n\x05INT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\x08\n\x04UTF8\x10\x05\x12\x0b\n\x07\x46LOAT16\x10\x06*+\n\x08Modality\x12\x08\n\x04TEXT\x10\x00\x12\t\n\x05IMAGE\x10\x01\x12\n\n\x06TENSOR\x10\x02*8\n\x0bRequestType\x12\x0e\n\nNOTDEFINED\x10\x00\x12\x0b\n\x07\x46ORWARD\x10\x01\x12\x0c\n\x08\x42\x41\x43KWARD\x10\x02\x32\x66\n\tBittensor\x12+\n\x07\x46orward\x12\x0e.TensorMessage\x1a\x0e.TensorMessage\"\x00\x12,\n\x08\x42\x61\x63kward\x12\x0e.TensorMessage\x1a\x0e.TensorMessage\"\x00\x62\x06proto3' ) _RETURNCODE = _descriptor.EnumDescriptor( @@ -150,11 +150,16 @@ serialized_options=None, type=None, create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='BadEndpoint', index=24, number=24, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), ], containing_type=None, serialized_options=None, - serialized_start=538, - serialized_end=1106, + serialized_start=2410, + serialized_end=2995, ) _sym_db.RegisterEnumDescriptor(_RETURNCODE) @@ -179,8 +184,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1108, - serialized_end=1146, + serialized_start=2997, + serialized_end=3035, ) _sym_db.RegisterEnumDescriptor(_SERIALIZER) @@ -210,8 +215,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1148, - serialized_end=1198, + serialized_start=3037, + serialized_end=3087, ) _sym_db.RegisterEnumDescriptor(_TENSORTYPE) @@ -261,8 +266,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1200, - serialized_end=1294, + serialized_start=3089, + serialized_end=3183, ) 
_sym_db.RegisterEnumDescriptor(_DATATYPE) @@ -292,8 +297,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1296, - serialized_end=1339, + serialized_start=3185, + serialized_end=3228, ) _sym_db.RegisterEnumDescriptor(_MODALITY) @@ -323,8 +328,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1341, - serialized_end=1397, + serialized_start=3230, + serialized_end=3286, ) _sym_db.RegisterEnumDescriptor(_REQUESTTYPE) @@ -353,6 +358,7 @@ SenderUnknown = 21 UnknownException = 22 Unauthenticated = 23 +BadEndpoint = 24 MSGPACK = 0 CMPPACK = 1 TORCH = 0 @@ -373,6 +379,46 @@ BACKWARD = 2 +_SYNAPSE_SYNAPSETYPE = _descriptor.EnumDescriptor( + name='SynapseType', + full_name='Synapse.SynapseType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='NULL_SYNAPSE', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TEXT_LAST_HIDDEN_STATE', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TEXT_CAUSAL_LM', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TEXT_SEQ_2_SEQ', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='TEXT_CAUSAL_LM_NEXT', index=4, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=2079, + serialized_end=2203, +) +_sym_db.RegisterEnumDescriptor(_SYNAPSE_SYNAPSETYPE) + _NEURON = _descriptor.Descriptor( name='Neuron', @@ -505,6 +551,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='synapses', full_name='TensorMessage.synapses', index=6, + number=9, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -518,7 +571,444 @@ oneofs=[ ], serialized_start=183, - serialized_end=331, + serialized_end=359, +) + + +_SYNAPSE_TEXTLASTHIDDENSTATE = _descriptor.Descriptor( + name='TextLastHiddenState', + full_name='Synapse.TextLastHiddenState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='synapse_type', full_name='Synapse.TextLastHiddenState.synapse_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_request_serializer_type', full_name='Synapse.TextLastHiddenState.forward_request_serializer_type', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + 
is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_response_serializer_type', full_name='Synapse.TextLastHiddenState.forward_response_serializer_type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_request_serializer_type', full_name='Synapse.TextLastHiddenState.backward_request_serializer_type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_response_serializer_type', full_name='Synapse.TextLastHiddenState.backward_response_serializer_type', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='requires_grad', full_name='Synapse.TextLastHiddenState.requires_grad', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=534, + serialized_end=842, +) + +_SYNAPSE_TEXTCAUSALLM = _descriptor.Descriptor( + name='TextCausalLM', + full_name='Synapse.TextCausalLM', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='synapse_type', full_name='Synapse.TextCausalLM.synapse_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='topk', full_name='Synapse.TextCausalLM.topk', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_request_serializer_type', full_name='Synapse.TextCausalLM.forward_request_serializer_type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_response_serializer_type', 
full_name='Synapse.TextCausalLM.forward_response_serializer_type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_request_serializer_type', full_name='Synapse.TextCausalLM.backward_request_serializer_type', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_response_serializer_type', full_name='Synapse.TextCausalLM.backward_response_serializer_type', index=5, + number=6, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='requires_grad', full_name='Synapse.TextCausalLM.requires_grad', index=6, + number=7, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=845, + serialized_end=1160, +) + +_SYNAPSE_TEXTSEQ2SEQ = _descriptor.Descriptor( + name='TextSeq2Seq', + full_name='Synapse.TextSeq2Seq', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='synapse_type', full_name='Synapse.TextSeq2Seq.synapse_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='topk', full_name='Synapse.TextSeq2Seq.topk', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_to_generate', full_name='Synapse.TextSeq2Seq.num_to_generate', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_request_serializer_type', full_name='Synapse.TextSeq2Seq.forward_request_serializer_type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + 
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_response_serializer_type', full_name='Synapse.TextSeq2Seq.forward_response_serializer_type', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_request_serializer_type', full_name='Synapse.TextSeq2Seq.backward_request_serializer_type', index=5, + number=6, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_response_serializer_type', full_name='Synapse.TextSeq2Seq.backward_response_serializer_type', index=6, + number=7, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_beams', full_name='Synapse.TextSeq2Seq.num_beams', index=7, + number=8, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='no_repeat_ngram_size', full_name='Synapse.TextSeq2Seq.no_repeat_ngram_size', index=8, + number=9, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='early_stopping', full_name='Synapse.TextSeq2Seq.early_stopping', index=9, + number=10, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_return_sequences', full_name='Synapse.TextSeq2Seq.num_return_sequences', index=10, + number=11, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='do_sample', full_name='Synapse.TextSeq2Seq.do_sample', index=11, + number=12, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='top_p', full_name='Synapse.TextSeq2Seq.top_p', index=12, + number=13, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + 
message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='requires_grad', full_name='Synapse.TextSeq2Seq.requires_grad', index=13, + number=14, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='temperature', full_name='Synapse.TextSeq2Seq.temperature', index=14, + number=15, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='repetition_penalty', full_name='Synapse.TextSeq2Seq.repetition_penalty', index=15, + number=16, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='length_penalty', full_name='Synapse.TextSeq2Seq.length_penalty', index=16, + number=17, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_time', full_name='Synapse.TextSeq2Seq.max_time', index=17, + number=18, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_beam_groups', full_name='Synapse.TextSeq2Seq.num_beam_groups', index=18, + number=19, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1163, + serialized_end=1755, +) + +_SYNAPSE_TEXTCAUSALLMNEXT = _descriptor.Descriptor( + name='TextCausalLMNext', + full_name='Synapse.TextCausalLMNext', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='synapse_type', full_name='Synapse.TextCausalLMNext.synapse_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='topk', full_name='Synapse.TextCausalLMNext.topk', index=1, + number=2, type=5, 
cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_request_serializer_type', full_name='Synapse.TextCausalLMNext.forward_request_serializer_type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='forward_response_serializer_type', full_name='Synapse.TextCausalLMNext.forward_response_serializer_type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_request_serializer_type', full_name='Synapse.TextCausalLMNext.backward_request_serializer_type', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='backward_response_serializer_type', full_name='Synapse.TextCausalLMNext.backward_response_serializer_type', index=5, + number=6, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='requires_grad', full_name='Synapse.TextCausalLMNext.requires_grad', index=6, + number=7, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1758, + serialized_end=2077, +) + +_SYNAPSE = _descriptor.Descriptor( + name='Synapse', + full_name='Synapse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='tensor_pos', full_name='Synapse.tensor_pos', index=0, + number=1, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='synapse_data', full_name='Synapse.synapse_data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='synapse_type', full_name='Synapse.synapse_type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='return_code', full_name='Synapse.return_code', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='message', full_name='Synapse.message', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='requires_grad', full_name='Synapse.requires_grad', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[_SYNAPSE_TEXTLASTHIDDENSTATE, _SYNAPSE_TEXTCAUSALLM, _SYNAPSE_TEXTSEQ2SEQ, _SYNAPSE_TEXTCAUSALLMNEXT, ], + enum_types=[ + _SYNAPSE_SYNAPSETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=362, + serialized_end=2203, ) @@ -598,19 +1088,48 @@ extension_ranges=[], oneofs=[ ], - serialized_start=334, - serialized_end=535, + serialized_start=2206, + serialized_end=2407, ) _NEURON.fields_by_name['modality'].enum_type = _MODALITY _TENSORMESSAGE.fields_by_name['tensors'].message_type = _TENSOR _TENSORMESSAGE.fields_by_name['return_code'].enum_type = _RETURNCODE +_TENSORMESSAGE.fields_by_name['synapses'].message_type = _SYNAPSE +_SYNAPSE_TEXTLASTHIDDENSTATE.fields_by_name['synapse_type'].enum_type = _SYNAPSE_SYNAPSETYPE +_SYNAPSE_TEXTLASTHIDDENSTATE.fields_by_name['forward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTLASTHIDDENSTATE.fields_by_name['forward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTLASTHIDDENSTATE.fields_by_name['backward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTLASTHIDDENSTATE.fields_by_name['backward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTLASTHIDDENSTATE.containing_type = _SYNAPSE +_SYNAPSE_TEXTCAUSALLM.fields_by_name['synapse_type'].enum_type = _SYNAPSE_SYNAPSETYPE +_SYNAPSE_TEXTCAUSALLM.fields_by_name['forward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLM.fields_by_name['forward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLM.fields_by_name['backward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLM.fields_by_name['backward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLM.containing_type = _SYNAPSE +_SYNAPSE_TEXTSEQ2SEQ.fields_by_name['synapse_type'].enum_type = _SYNAPSE_SYNAPSETYPE 
+_SYNAPSE_TEXTSEQ2SEQ.fields_by_name['forward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTSEQ2SEQ.fields_by_name['forward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTSEQ2SEQ.fields_by_name['backward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTSEQ2SEQ.fields_by_name['backward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTSEQ2SEQ.containing_type = _SYNAPSE +_SYNAPSE_TEXTCAUSALLMNEXT.fields_by_name['synapse_type'].enum_type = _SYNAPSE_SYNAPSETYPE +_SYNAPSE_TEXTCAUSALLMNEXT.fields_by_name['forward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLMNEXT.fields_by_name['forward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLMNEXT.fields_by_name['backward_request_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLMNEXT.fields_by_name['backward_response_serializer_type'].enum_type = _SERIALIZER +_SYNAPSE_TEXTCAUSALLMNEXT.containing_type = _SYNAPSE +_SYNAPSE.fields_by_name['synapse_type'].enum_type = _SYNAPSE_SYNAPSETYPE +_SYNAPSE.fields_by_name['return_code'].enum_type = _RETURNCODE +_SYNAPSE_SYNAPSETYPE.containing_type = _SYNAPSE _TENSOR.fields_by_name['serializer'].enum_type = _SERIALIZER _TENSOR.fields_by_name['tensor_type'].enum_type = _TENSORTYPE _TENSOR.fields_by_name['dtype'].enum_type = _DATATYPE _TENSOR.fields_by_name['modality'].enum_type = _MODALITY DESCRIPTOR.message_types_by_name['Neuron'] = _NEURON DESCRIPTOR.message_types_by_name['TensorMessage'] = _TENSORMESSAGE +DESCRIPTOR.message_types_by_name['Synapse'] = _SYNAPSE DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR DESCRIPTOR.enum_types_by_name['ReturnCode'] = _RETURNCODE DESCRIPTOR.enum_types_by_name['Serializer'] = _SERIALIZER @@ -634,6 +1153,45 @@ }) _sym_db.RegisterMessage(TensorMessage) +Synapse = _reflection.GeneratedProtocolMessageType('Synapse', (_message.Message,), { + + 'TextLastHiddenState' : _reflection.GeneratedProtocolMessageType('TextLastHiddenState', (_message.Message,), { + 'DESCRIPTOR' : _SYNAPSE_TEXTLASTHIDDENSTATE, + '__module__' : 'bittensor._proto.bittensor_pb2' + # @@protoc_insertion_point(class_scope:Synapse.TextLastHiddenState) + }) + , + + 'TextCausalLM' : _reflection.GeneratedProtocolMessageType('TextCausalLM', (_message.Message,), { + 'DESCRIPTOR' : _SYNAPSE_TEXTCAUSALLM, + '__module__' : 'bittensor._proto.bittensor_pb2' + # @@protoc_insertion_point(class_scope:Synapse.TextCausalLM) + }) + , + + 'TextSeq2Seq' : _reflection.GeneratedProtocolMessageType('TextSeq2Seq', (_message.Message,), { + 'DESCRIPTOR' : _SYNAPSE_TEXTSEQ2SEQ, + '__module__' : 'bittensor._proto.bittensor_pb2' + # @@protoc_insertion_point(class_scope:Synapse.TextSeq2Seq) + }) + , + + 'TextCausalLMNext' : _reflection.GeneratedProtocolMessageType('TextCausalLMNext', (_message.Message,), { + 'DESCRIPTOR' : _SYNAPSE_TEXTCAUSALLMNEXT, + '__module__' : 'bittensor._proto.bittensor_pb2' + # @@protoc_insertion_point(class_scope:Synapse.TextCausalLMNext) + }) + , + 'DESCRIPTOR' : _SYNAPSE, + '__module__' : 'bittensor._proto.bittensor_pb2' + # @@protoc_insertion_point(class_scope:Synapse) + }) +_sym_db.RegisterMessage(Synapse) +_sym_db.RegisterMessage(Synapse.TextLastHiddenState) +_sym_db.RegisterMessage(Synapse.TextCausalLM) +_sym_db.RegisterMessage(Synapse.TextSeq2Seq) +_sym_db.RegisterMessage(Synapse.TextCausalLMNext) + Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { 'DESCRIPTOR' : _TENSOR, '__module__' : 'bittensor._proto.bittensor_pb2' @@ -650,8 +1208,8 @@ index=0, 
serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1399, - serialized_end=1501, + serialized_start=3288, + serialized_end=3390, methods=[ _descriptor.MethodDescriptor( name='Forward', diff --git a/bittensor/_receptor/receptor_impl.py b/bittensor/_receptor/receptor_impl.py index f2f19c0b61..bfb72756e3 100644 --- a/bittensor/_receptor/receptor_impl.py +++ b/bittensor/_receptor/receptor_impl.py @@ -16,87 +16,28 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import traceback -import sys -import time as clock -from types import SimpleNamespace -from typing import Tuple -import threading +import bittensor +from bittensor._synapse import synapse +import bittensor.utils.stats as stat_utils import torch +import threading import uuid -import time +import sys import torch.nn as nn import grpc +import time as clock + +from types import SimpleNamespace +from typing import Tuple, List, Union from loguru import logger from grpc import _common -import bittensor -import bittensor.utils.stats as stat_utils - -logger = logger.opt(colors=True) - -# dummy tensor that triggers autograd in a RemoteExpert -DUMMY = torch.empty(0, requires_grad=True) - -# Helper function for filling nill (zero) responses on failures. -def nill_response_for(inputs): - """ Empty response - """ - if torch.numel(inputs) == 0: - return torch.tensor([]) - return torch.zeros( (inputs.size(0), inputs.size(1), bittensor.__network_dim__), dtype=torch.float32) - -class Request(): - """ Contains all of the inputs, intermediate, and output state of a forward/backward request. - """ - def __init__( - self, - inputs, - modality, - grads_dy = None, - backward = False - ): - r""" Initialize a forward/backward request. - - Args: - inputs (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): - List of tensors to send to corresponsing endpoints. Tensors are of arbitrary type and shape depending on the - modality. - - grads_dy (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`): - List of grad tensors to send to corresponsing inputs. Only needed when it is a backward request. - - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] - backward (:type:`Bool`); - True if it is a backward request. False when it is a forward request instead. - """ - # ---- Inputs ---- - self.inputs = inputs - self.grads_dy = grads_dy - self.zeros = nill_response_for(inputs) - - # ---- Setups ---- - self.modality = modality - self.backward = backward - self.start_time = clock.time() - self.end_time = None - - # ---- Intermediate states ---- - self.serialized_inputs = None - self.grpc_request = None - self.future = None - - # ---- Outputs ---- - self.code = None - self.message = None - self.outputs = None class Receptor(nn.Module): - """ Encapsulates a grpc connection to an axon endpoint as a standard auto-grad torch.nn.Module. - """ def __init__( self, @@ -123,8 +64,6 @@ def __init__( self.endpoint = endpoint # Endpoint information. self.channel = channel self.stub = stub - self.backoff = 0 # Number o queries to backoff. - self.next_backoff = 1 # Next backoff level. 
self.receptor_uid = str(uuid.uuid1()) self.semaphore = threading.Semaphore(max_processes) self.state_dict = _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY @@ -160,581 +99,534 @@ def __init__( bittensor.proto.ReturnCode.ResponseIncompatibleVersion: 0, bittensor.proto.ReturnCode.SenderUnknown: 0, bittensor.proto.ReturnCode.UnknownException: 0, + bittensor.proto.ReturnCode.Unauthenticated: 0, + bittensor.proto.ReturnCode.BadEndpoint: 0, } ) - def __str__(self): + def __str__ ( self ): return "Receptor({})".format(self.endpoint) - def __repr__(self): + def __repr__ ( self ): return self.__str__() - def __del__(self): + def __del__ ( self ): try: result = self.channel._channel.check_connectivity_state(True) if self.state_dict[result] != self.state_dict[result].SHUTDOWN: self.channel.close() except: pass - - def __exit__(self): + + def __exit__ ( self ): self.__del__() - def forward ( - self, - inputs: torch.Tensor, - modality: bittensor.proto.Modality, - timeout: int, - ) -> Tuple[torch.Tensor, int]: - r""" Torch.nn.Module forward call: Triggers the grpc call to the remote endpoint. - Call returns the output tensor and a bittensor.proto.ReturnCode. - - Args: - inputs (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`): - Single torch tensor to be sent to the remote endpoint. - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] - timeout (:obj:`int`, `required`) - Returns: - output (:obj:`Tuple[torch.FloatTensor, torch.LongTensor]`, `required`): - Result tuple from the forward call. - code (:obj:`bittensor.proto.ReturnCode`, `required`): - Return code associated with forward call. - time (:obj:`float`, `required`): - Time of call. - + def sign ( self ): + r""" Uses the wallet pubkey to sign a message containing the pubkey and the time """ - request = self.preprocess_request ( inputs = inputs, modality = modality) - request = self.make_request_call(request, timeout = timeout) - return self.handle_request_response(request) - - def backward( - self, - inputs_x: torch.Tensor, - grads_dy: torch.Tensor, - modality: bittensor.proto.Modality, - timeout: int - ) -> Tuple[ torch.Tensor, int, float, str ]: - r""" Backward call: Triggers the grpc Backward call to the associated endpoint. - - Args: - inputs_x (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`): - inputs from previous forward call. + nounce = self.nounce() + message = str(nounce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid) + spliter = 'bitxx' + signature = spliter.join([ str(nounce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ]) + return signature - grads_dy (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`): - gradient outputs. - - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] - - timeout (int): - request timeout. - - Returns: - output (:obj:`Tuple[torch.FloatTensor, torch.LongTensor]`, `required`): - Result tuple from the forward call. - - code (:obj:`bittensor.proto.ReturnCode`, `required`): - Return code associated with backward call. - - time (:obj:`float`, `required`): - Time of call. 
- """ - request = self.preprocess_request (inputs = inputs_x, modality = modality, grads_dy = grads_dy, backward = True) - request = self.make_request_call(request, timeout = timeout) - return self.handle_request_response(request) - - - def prerequisite_check(self, request): - r""" Check the input size and endpoint validity. - - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - - Returns: - success: (:type:`bool`, `required`): - True if the check has passed. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + def nounce ( self ): + r"""creates a string representation of the time """ - # ---- Check inputs size ---- - if torch.numel(request.inputs) == 0 or ( request.backward and torch.numel( request.grads_dy ) == 0): - request.code = bittensor.proto.ReturnCode.EmptyRequest - request.message = 'Empty request.' - self.request_log(request = request, is_response = False, inputs = list(request.inputs.shape)) - return False, request - - # ---- Check endpoint---- - if self.endpoint.hotkey == 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX': - request.code = bittensor.proto.ReturnCode.EmptyRequest - request.message = 'Bad endpoint.' - self.request_log(request = request, is_response = False, inputs = list(request.inputs.shape)) - return False, request + nounce = int(clock.time() * 1000) + return nounce - return True, request - - def serialization(self, request): - r""" Does the serialization to the request inputs and grads(backward request only). - The result would update request.serialized_inputs and request.serialized_grad. - - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - - Returns: - success: (:type:`bool`, `required`): - True if the serialization is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ - try: - serializer = bittensor.serializer( bittensor.proto.Serializer.MSGPACK ) - request.serialized_inputs = serializer.serialize(request.inputs, modality = request.modality, from_type = bittensor.proto.TensorType.TORCH) - - if request.backward: - request.serialized_grads = serializer.serialize (request.grads_dy, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH ) + def state ( self ): + try: + return self.state_dict[self.channel._channel.check_connectivity_state(True)] + except ValueError: + return "Channel closed" - except Exception as e: - request.code = bittensor.proto.ReturnCode.RequestSerializationException - request.message = 'Input serialization exception with error:{}'.format(str(e)) - self.request_log(request = request, is_response = False, inputs = list(request.inputs.shape)) - return False, request - - return True, request + def close ( self ): + self.__exit__() - def build_grpc_request(self, request): - r"""Build the grapc call with the serialized_inputs and serialized grad(backward request only). - The result would update request.grpc_request. + def backward ( + self, + synapses: List[ 'bittensor.Synapse' ], + inputs: torch.Tensor, + grads: List[torch.Tensor], + timeout: int + ) -> Tuple[ List[ torch.FloatTensor ], List['bittensor.proto.ReturnCode'], List[float] ]: + r""" Triggers the grpc backward call to the remote endpoint. + This triggers the synapse's backward calls with arguments. 
+ Call returns a list of output gradient tensors one per synapse with corresponding time and bittensor.proto.ReturnCode. Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. + + inputs (:obj:`torch.Tensor` of shape :obj:`(shape)`, `required`): + Single torch tensor input corresponding to the linked forward call. + TODO(const): Make this multi-forward tensor. + + grads (:obj:`List[torch.FloatTensor]` of shape :obj:`num_synapses * (shape_of_synapse_output_i)`, `required`): + List of torch tensor gradients associated with each synapse. + + timeout (:obj:`int`, `required`): + Request max timeout Returns: - success: (:type:`bool`, `required`): - True if the build is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + output (:obj:`torch.FloatTensor`, `required`): + Result tensors (likely zero) from the backward call each corresponding to a single forward input. + NOTE(const) Always zeros because responses are not waited. + TODO(const): Make this multi-forward tensor. + + codes (:obj:`bittensor.proto.ReturnCode`, `required`): + List of return codes associated with each passed synapse enum. + Connection failures return all the same code, otherwise a unique code per synapse. + + times (:obj:`float`, `required`): + List of times for each call associated with each passed synapse enum. + Success responses all get the same time. """ - try: - if not request.backward: - request.grpc_request = bittensor.proto.TensorMessage ( - version = bittensor.__version_as_int__, - hotkey = self.wallet.hotkey.ss58_address, - tensors = [request.serialized_inputs], - requires_grad = True, - ) - else: - request.grpc_request = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = self.wallet.hotkey.ss58_address, - tensors = [request.serialized_inputs, request.serialized_grads], - requires_grad = True, + # ===================== + # ==== Init params ==== + # ===================== + # These items are filled through the call and the function returns + # when all codes are non-success or the function finishes completely. 
+ synapse_messages = [ "Success" for _ in synapses ] + synapse_codes = [ bittensor.proto.ReturnCode.Success for _ in synapses ] + synapse_responses = [ synapse.nill_backward_response_tensor ( inputs ) for synapse in synapses ] + synapse_is_response = [ False for _ in synapses ] + synapse_call_times = [ 0 for _ in synapses ] + start_time = clock.time() + + # ================================================================== + # ==== Function which returns true if all codes are non success ==== + # ================================================================== + def check_if_should_return() -> bool: + for code in synapse_codes: + if code == bittensor.proto.ReturnCode.Success: + return False + return True + + # ============================================================== + # ==== Function which prints all log statements per synapse ==== + # ============================================================== + def finalize_stats_and_logs(): + for index, synapse in enumerate( synapses ): + self.stats.codes[ synapse_codes[ index ] ] += 1 + bittensor.logging.rpc_log ( + axon = False, + forward = False, + is_response = synapse_is_response [index], + code = synapse_codes[ index ], + call_time = synapse_call_times[ index ], + pubkey = self.endpoint.hotkey, + uid = self.endpoint.uid, + inputs = list(grads[index].shape), + outputs = None, + message = synapse_messages[ index ], + synapse = synapse.synapse_type ) - except Exception as e: - request.code = bittensor.proto.ReturnCode.UnknownException - request.message = str(e) - self.request_log(request = request, is_response = False, inputs = list(request.serialized_inputs.shape)) - return False, request - return True, request - - def collect_future(self, request): - r"""Get the result of the grpc request. - The result would update request.response. - - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + # ======================== + # ==== Check endpoint ==== + # ======================== + if self.endpoint.hotkey == 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX': + # Endpoint is dummy. + code = bittensor.proto.ReturnCode.BadEndpoint + call_time = clock.time() - start_time + message = "Bad endpoint." + synapse_call_times = [ call_time for _ in synapses ] + synapse_codes = [ code for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + # ================================== + # ==== Serialize inputs & grads ==== + # ================================== + serialized_forward_tensors = [] + serialized_backward_grads = [] + serialized_synapses = [] + for index, synapse in enumerate( synapses ): + try: + serialized_forward_tensors.append(synapse.serialize_forward_request_tensor( inputs )) + serialized_backward_grads.append(synapse.serialize_backward_request_gradient (inputs, grads[index] )) + serialized_synapses.append(synapse.serialize_to_wire_proto()) + except Exception as e: + # Input Serialization failed. + synapse_codes [index] = bittensor.proto.ReturnCode.RequestSerializationException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input serialization exception with error:{}'.format(str(e)) + # Check if the call can stop here. 
+ if check_if_should_return(): + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + + # ============================= + # ==== Build proto request ==== + # ============================= + try: + grpc_request = bittensor.proto.TensorMessage ( + version = bittensor.__version_as_int__, + hotkey = self.wallet.hotkey.ss58_address, + tensors = serialized_forward_tensors + serialized_backward_grads, + synapses = serialized_synapses, + requires_grad = True, + ) - Returns: - success: (:type:`bool`, `required`): - True if getting the result is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ + except Exception as e: + # Synapse request creation failed. + code = bittensor.proto.ReturnCode.UnknownException + call_time = clock.time() - start_time + message = 'Request proto creation failed with error:{}'.format(str(e)) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + + # ======================= + # ==== Make RPC Call ==== + # ======================= try: - request.response = request.future.result() - self.stats.forward_bytes_in.update(sys.getsizeof(request.response)) - self.stats.forward_elapsed_time.update((clock.time()-request.start_time)) - - # ---- Catch GRPC Errors ---- + self.stats.backward_qps.update(1) + self.stats.backward_bytes_out.update(sys.getsizeof(grpc_request)) + # Fire and forget. + self.stub.Backward( + request = grpc_request, + timeout = timeout, + metadata = ( + ('rpc-auth-header','Bittensor'), + ('bittensor-signature',self.sign()), + ('bittensor-version',str(bittensor.__version_as_int__)), + ('request_type', str(bittensor.proto.RequestType.FORWARD)), + )) + + # ==================================== + # ==== Handle GRPC Errors ==== + # ==================================== except grpc.RpcError as rpc_error_call: - request.code, request.message = self.rpc_exception_handler(request, rpc_error_call) - return False, request - - # ---- Catch Unknown Errors ---- + # Request failed with GRPC code. 
+ call_time = clock.time() - start_time + grpc_code = rpc_error_call.code() + if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED: + code = bittensor.proto.ReturnCode.Timeout + message = 'grpc.StatusCode.DEADLINE_EXCEEDED'+': '+ rpc_error_call.details() + elif grpc_code == grpc.StatusCode.UNAVAILABLE: + code = bittensor.proto.ReturnCode.Unavailable + message = 'grpc.StatusCode.UNAVAILABLE'+': '+ rpc_error_call.details() + elif grpc_code == grpc.StatusCode.UNAUTHENTICATED: + code = bittensor.proto.ReturnCode.Unauthenticated + message = 'grpc.StatusCode.UNAUTHENTICATED'+': '+ rpc_error_call.details() + else: + code = bittensor.proto.ReturnCode.UnknownException + message = 'GRPC error code: {}, details: {}'.format( grpc_code, str(rpc_error_call.details()) ) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + # ==================================== + # ==== Handle GRPC Unknown Errors ==== + # ==================================== except Exception as e: - request.code = bittensor.proto.ReturnCode.UnknownException - request.message = str(e) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - return True, request + # Request failed with unknown exception. + code = bittensor.proto.ReturnCode.UnknownException + call_time = clock.time() - start_time + message = 'GRPC request failed with unknown exception:{}'.format(str(e)) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + + + # ====================================== + # ==== Finalize backward call times ==== + # ====================================== + for index, _ in enumerate( synapses ): + if synapse_codes[index] == bittensor.proto.ReturnCode.Success: + synapse_call_times[index] = clock.time() - start_time + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times - def check_response(self, request): - r"""Check the response. - This function should not update any part of request. + def forward ( + self, + synapses: List[ 'bittensor.Synapse' ], + inputs: torch.Tensor, + timeout: int, + ) -> Tuple[ List[ torch.FloatTensor ], List['bittensor.proto.ReturnCode'], List[float] ]: + r""" Triggers the grpc call to the remote endpoint. + This triggers the synapse calls with arguments. + Call returns a list of output tensors one per synapse with corresponding time and bittensor.proto.ReturnCode. Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. - Returns: - success: (:type:`bool`, `required`): - True if the check is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ - # ---- Get response message ---- - try: - request.message = request.response.message - except Exception: - request.message = '' - - # ---- Catch non-code ---- - request.code = request.response.return_code - - if request.code == bittensor.proto.ReturnCode.NoReturn: - request.message = 'No return code.' 
- self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - # ---- Catch bittensor errors ---- - if request.code == bittensor.proto.ReturnCode.UnknownException: - request.message = 'Return code unknown exception.' - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - elif request.code != bittensor.proto.ReturnCode.Success: - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - # ---- Check for empty length ---- - if len(request.response.tensors) == 0: - request.code = bittensor.proto.ReturnCode.EmptyResponse - request.message = 'No tensors in response.' - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - return True, request - - def deserialize_forward_response(self, request): - r"""Deserialization for the forward request. - The result would update request.output. - - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + inputs (:obj:`torch.Tensor` of shape :obj:`(shape)`, `required`): + Single torch tensor to be sent to the remote endpoint. + TODO(const): Make this a multi-forward tensor. + timeout (:obj:`int`, `required`): + Request max timeout Returns: - success: (:type:`bool`, `required`): - True if the deserialization is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ + outputs (:obj:`List[ Union[torch.FloatTensor, torch.LongTensor] ]`, `required`): + outputs.shape = [batch_size, synapse_length, response] + List of result tensors from the forward call each corresponding to a passed synapse enum. - # ---- Deserialize response ---- - try: - outputs = request.response.tensors[0] - deserializer = bittensor.serializer( outputs.serializer ) - outputs = deserializer.deserialize( outputs, to_type = bittensor.proto.TensorType.TORCH ) - - except Exception as e: - request.code = bittensor.proto.ReturnCode.ResponseDeserializationException - request.message = 'Deserialziation exception with error:{}'.format(str(e)) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - # ---- Check response shape ---- - if ( - outputs.size(0) != request.inputs.size(0) or - outputs.size(1) != request.inputs.size(1) or - outputs.size(2) != bittensor.__network_dim__ - ): - request.code = bittensor.proto.ReturnCode.ResponseShapeException - request.message = "output.shape:{} does not match inputs:{}".format(outputs.shape, request.inputs.shape) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape), outputs = list(outputs.shape)) - return False, request - - # ---- Safe catch NaNs and replace with 0.0 ---- - request.outputs = torch.where(torch.isnan(outputs), torch.zeros_like(outputs), outputs).detach() - - # ---- Return ---- - request.code = request.response.return_code - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape), outputs = list(outputs.shape)) - self.stats.codes[request.code] += 1 - - return True, request - - def deserialize_backward_response(self, request): - r"""Deserialization for the backward request. - The result would update request.output. 
+ codes (:obj:`bittensor.proto.ReturnCode`, `required`): + List of return codes associated with each passed synapse enum. + Connection failures return all the same code, otherwise a unique code per synapse. - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. + times (:obj:`float`, `required`): + List of times for each call associated with each passed synapse enum. + Success responses all get the same time. - Returns: - success: (:type:`bool`, `required`): - True if the deserialization is successful. - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. """ - # ---- Post-process request ---- - try: - outputs = request.response.tensors[0] - deserializer = bittensor.serializer( outputs.serializer ) - outputs = deserializer.deserialize( outputs, to_type = bittensor.proto.TensorType.TORCH ) - except Exception as e: - request.code = bittensor.proto.ReturnCode.ResponseDeserializationException - request.message = 'deserialization exception with error:{}'.format(e) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request + # ===================== + # ==== Init params ==== + # ===================== + # These items are filled through the call and the function returns + # when all codes are non-success or the function finishes completely. + synapse_messages = [ "Success" for _ in synapses ] + synapse_codes = [ bittensor.proto.ReturnCode.Success for _ in synapses ] + synapse_responses = [ synapse.nill_forward_response_tensor( inputs ) for synapse in synapses ] + synapse_is_response = [ False for _ in synapses ] + synapse_call_times = [ 0 for _ in synapses ] + start_time = clock.time() + + # ================================================================== + # ==== Function which returns true if all codes are non success ==== + # ================================================================== + def check_if_should_return() -> bool: + for code in synapse_codes: + if code == bittensor.proto.ReturnCode.Success: + return False + return True + + # ============================================================== + # ==== Function which prints all log statements per synapse ==== + # ============================================================== + def finalize_stats_and_logs(): + self.stats.forward_elapsed_time.update( clock.time() - start_time ) + for index, synapse in enumerate( synapses ): + self.stats.codes[ synapse_codes[ index ] ] += 1 + bittensor.logging.rpc_log ( + axon = False, + forward = True, + is_response = synapse_is_response [index], + code = synapse_codes[ index ], + call_time = synapse_call_times[ index ], + pubkey = self.endpoint.hotkey, + uid = self.endpoint.uid, + inputs = list(inputs.shape), + outputs = None if synapse_codes[ index ] != bittensor.proto.ReturnCode.Success else list( synapse_responses[index].shape ), + message = synapse_messages[ index ], + synapse = synapse.synapse_type + ) - try: - # ---- Check response shape is same as inputs ---- - if outputs.size() != request.inputs.size(): - request.code = bittensor.proto.ReturnCode.ResponseShapeException - request.message = 'output shape does not match inputs shape' - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request + # =========================== + # ==== Check inputs size ==== + # =========================== + if torch.numel(inputs) == 0: + # Inputs are nill. 
+ code = bittensor.proto.ReturnCode.EmptyRequest + call_time = clock.time() - start_time + message = "Empty Request" + synapse_codes = [ code for _ in synapses ] + synapse_call_times = [ call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times - except Exception as e: - request.code = bittensor.proto.ReturnCode.UnknownException - request.message = 'Size Error: {}'.format(e) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return False, request - - # ---- Safe catch NaNs and replace with 0.0 ---- - request.outputs = torch.where(torch.isnan(outputs), torch.zeros_like(outputs), outputs).detach() - - # ---- Return ---- - request.code = bittensor.proto.ReturnCode.Success - request.message = 'Success' - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - self.stats.codes[request.code] += 1 - return False, request - - def request_log(self, request, is_response = False, inputs = None, outputs = None): - r""" rpc logging for forward/backward request - Args: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - - is_response (:type: `bool`): - True if we are logging a response from the grpc call, false if it is a request instead - - inputs (:type: `List`): - shape of the tensor input that was being handled - """ - - call_time = clock.time() - request.start_time - if bittensor.logging.__debug_on__: - bittensor.logging.rpc_log( - axon=False, - forward= not request.backward, - is_response=is_response, - code=request.code, - call_time=call_time, - pubkey=self.endpoint.hotkey, - uid = self.endpoint.uid, - inputs=inputs, - outputs=outputs, - message=request.message - ) - - def preprocess_request ( - self, - inputs: torch.Tensor, - modality: bittensor.proto.Modality, - grads_dy: torch.FloatTensor = None, - backward: str = False - ): - r""" Does all the checking and preprocessing to build the grpc request. + # ======================== + # ==== Check endpoint ==== + # ======================== + if self.endpoint.hotkey == 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX': + # Endpoint is dummy. + code = bittensor.proto.ReturnCode.BadEndpoint + call_time = clock.time() - start_time + message = "Bad endpoint." + synapse_call_times = [ call_time for _ in synapses ] + synapse_codes = [ code for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + # ========================== + # ==== Serialize inputs ==== + # ========================== + serialized_forward_tensors = [] + serialized_synapses = [] + for index, synapse in enumerate( synapses ): + try: + serialized_forward_tensors.append( synapse.serialize_forward_request_tensor ( inputs )) + serialized_synapses.append(synapse.serialize_to_wire_proto()) + except Exception as e: + synapse_codes [index] = bittensor.proto.ReturnCode.RequestSerializationException + synapse_call_times [index] = clock.time() - start_time + synapse_messages [index] = 'Input serialization exception with error:{}'.format(str(e)) + # Check if the call can stop here. + if check_if_should_return(): + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times - Args: - inputs (:obj:`List[torch.Tensor]` of shape :obj:`(shape)`, `required`): - Torch tensor to be sent to this endpoint. 
- - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality of type Enum: [TEXT, IMAGE, TENSOR] - - grads_dy (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): - List of grad tensors to send to corresponsing inputs. - - backward (:type:`Bool`, `required`); - If the request is a backward request. - - Returns: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ - # ---- Setup forward request namespace, which will hold all the objects regarding the forward request ---- - request = Request(inputs = inputs, modality = modality, grads_dy = grads_dy, backward = backward) - - preprocessing_funs = [self.prerequisite_check, self.serialization, self.build_grpc_request] - - for fun in preprocessing_funs: - check, request = fun(request) - if not check: - return request - - request.code = bittensor.proto.ReturnCode.Success - return request - - def make_request_call(self, request, timeout): - r""" Torch.nn.Module forward call: Triggers the grpc call to the remote endpoint. (calls the Forward method on an Axon terminal.) - The resulted future of forward call was stored in forward_request. - - Args: - timeout (:type:`int`, `required`): - request timeout. - - Returns: - request: (:obj:`Request`, required): - The request object holds all specifications and processing of the request. - """ - # ---- Return if the previous statue was not finished. ---- - if (request.grpc_request == None) or (request.code != bittensor.proto.ReturnCode.Success): - return request - - # ---- Make RPC call ---- + # ============================ + # ==== Build proto request ==== + # ============================ + try: + grpc_request = bittensor.proto.TensorMessage ( + version = bittensor.__version_as_int__, + hotkey = self.wallet.hotkey.ss58_address, + tensors = serialized_forward_tensors, + synapses = serialized_synapses, + requires_grad = True, + ) + except Exception as e: + # Synapse request creation failed. 
+ code = bittensor.proto.ReturnCode.UnknownException + call_time = clock.time() - start_time + message = 'Request proto creation failed with error:{}'.format(str(e)) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + # ======================= + # ==== Fire RPC Call ==== + # ======================= + grpc_response = None try: - if not request.backward: - self.sign() - self.stats.forward_qps.update(1) - self.stats.forward_bytes_out.update(sys.getsizeof(request.grpc_request)) - request.future = self.stub.Forward.future(request = request.grpc_request, - timeout = timeout, - metadata = ( - ('rpc-auth-header','Bittensor'), - ('bittensor-signature',self.sign()), - ('bittensor-version',str(bittensor.__version_as_int__)), - ('request_type', str(bittensor.proto.RequestType.FORWARD)), - )) - request.future.add_done_callback(lambda z : self.handle_request_response(request)) - else: - self.stats.backward_qps.update(1) - self.stats.backward_bytes_out.update(sys.getsizeof(request.grpc_request)) - request.future = self.stub.Backward.future(request = request.grpc_request, - timeout = timeout, - metadata = ( - ('rpc-auth-header','Bittensor'), - ('bittensor-signature',self.sign()), - ('bittensor-version',str(bittensor.__version_as_int__)), - ('request_type', str(bittensor.proto.RequestType.BACKWARD)), - )) - - request.code = bittensor.proto.ReturnCode.Success - self.request_log(request = request, is_response = False, inputs = list(request.serialized_inputs.shape)) - return request - - # ---- Catch GRPC Errors ---- + self.stats.forward_qps.update(1) + self.stats.forward_bytes_out.update( sys.getsizeof( grpc_request ) ) + finalize_stats_and_logs() + grpc_response = self.stub.Forward ( + request = grpc_request, + timeout = timeout, + metadata = ( + ('rpc-auth-header','Bittensor'), + ('bittensor-signature',self.sign()), + ('bittensor-version',str(bittensor.__version_as_int__)), + ('request_type', str(bittensor.proto.RequestType.FORWARD)), + )) + self.stats.forward_bytes_in.update( grpc_response.ByteSize() ) + synapse_is_response = [ True for _ in synapses ] + # Set successful response booleans to true + + # ==================================== + # ==== Handle GRPC Errors ==== + # ==================================== except grpc.RpcError as rpc_error_call: - request.code, request.message = self.rpc_exception_handler(request, rpc_error_call) - self.request_log(request = request, is_response = False, inputs = list(request.serialized_inputs.shape)) - return request - - # ---- Catch Unknown Errors ---- + # Request failed with GRPC code. 
+ call_time = clock.time() - start_time + grpc_code = rpc_error_call.code() + if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED: + code = bittensor.proto.ReturnCode.Timeout + message = 'grpc.StatusCode.DEADLINE_EXCEEDED'+': '+ rpc_error_call.details() + elif grpc_code == grpc.StatusCode.UNAVAILABLE: + code = bittensor.proto.ReturnCode.Unavailable + message = 'grpc.StatusCode.UNAVAILABLE'+': '+ rpc_error_call.details() + elif grpc_code == grpc.StatusCode.UNAUTHENTICATED: + code = bittensor.proto.ReturnCode.Unauthenticated + message = 'grpc.StatusCode.UNAUTHENTICATED'+': '+ rpc_error_call.details() + else: + code = bittensor.proto.ReturnCode.UnknownException + message = 'GRPC error code: {}, details: {}'.format( grpc_code, str(rpc_error_call.details()) ) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + + # ==================================== + # ==== Handle GRPC Unknown Errors ==== + # ==================================== except Exception as e: - request.code = bittensor.proto.ReturnCode.UnknownException - request.message = str(e) - self.request_log(request = request, is_response = False, inputs = list(request.serialized_inputs.shape)) - return request + # Request failed with unknown exception. + code = bittensor.proto.ReturnCode.UnknownException + call_time = clock.time() - start_time + message = 'GRPC request failed with unknown exception:{}'.format(str(e)) + synapse_codes = [code for _ in synapses ] + synapse_call_times = [call_time for _ in synapses ] + synapse_messages = [ message for _ in synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + + # ========================================== + # ==== Handle Non Success GRPC Response ==== + # ========================================== + if grpc_response.return_code != bittensor.proto.ReturnCode.Success: + # Request failed with unknown exception. + call_time = clock.time() - start_time + synapse_call_times = [call_time for _ in synapses ] + if len(grpc_response.synapses) == len(synapses): + synapse_codes = [synapse.return_code for synapse in grpc_response.synapses ] + synapse_messages = ['Remote Server Failure: '+ synapse.message for synapse in grpc_response.synapses ] + finalize_stats_and_logs() + return synapse_responses, synapse_codes, synapse_call_times + + + + # ====================================== + # ==== Check response length ==== + # ====================================== + if ( len(grpc_response.tensors) != len(grpc_response.synapses) ) or ( len(grpc_response.tensors) != len(synapses) ): + # Not enough responses per request. 
+            code = bittensor.proto.ReturnCode.ResponseShapeException
+            call_time = clock.time() - start_time
+            message = "Responses don't match synapse length"
+            synapse_codes = [code for _ in synapses ]
+            synapse_call_times = [call_time for _ in synapses ]
+            synapse_messages = [ message for _ in synapses ]
+            finalize_stats_and_logs()
+            return synapse_responses, synapse_codes, synapse_call_times
+
+        # ======================================
+        # ==== Check for non success response codes ====
+        # ======================================
+        for index, wire_synapse in enumerate( grpc_response.synapses ):
+            if wire_synapse.return_code != bittensor.proto.ReturnCode.Success:
+                synapse_codes[index] = wire_synapse.return_code
+                synapse_messages[index] = wire_synapse.message
+                synapse_call_times[index] = clock.time() - start_time
+
+        # Check if the call can stop here.
+        if check_if_should_return():
+            finalize_stats_and_logs()
+            return synapse_responses, synapse_codes, synapse_call_times
+
+        # ======================================
+        # ==== Deserialize synapse responses ====
+        # ======================================
+        for index, response_proto in enumerate(grpc_response.tensors):
+            try:
+                synapse = synapses[index]
+                if synapse_codes[index] == bittensor.proto.ReturnCode.Success:
+                    synapse_responses[index] = synapse.deserialize_forward_response_proto ( inputs, response_proto )
+            except Exception as e:
+                # Response deserialization failed.
+                synapse_codes[index] = bittensor.proto.ReturnCode.ResponseDeserializationException
+                synapse_call_times[index] = clock.time() - start_time
+                synapse_messages[index] = 'Response deserialization exception with error:{}'.format(str(e))
+
+
+        # ======================================
+        # ==== Finalize forward call times ====
+        # ======================================
+        for index, _ in enumerate( synapses ):
+            if synapse_codes[index] == bittensor.proto.ReturnCode.Success:
+                synapse_call_times[index] = clock.time() - start_time
+        finalize_stats_and_logs()
+        return synapse_responses, synapse_codes, synapse_call_times
+
 
-    def handle_request_response(self, request):
-        r""" Handle all the getting result checking, and processing the response.
+
 
-        Args:
-            request: (:obj:`Request`, required):
-                The request object holds all specifications and processing of the request.
-        Returns:
-            output (:obj:`Tuple[torch.FloatTensor`, torch.LongTensor]`, `optional`):
-                Result from forward call. May be None in the case of failure.
-
-            code (:obj:`bittensor.proto.ReturnCode`, `required`):
-                Return code associated with forward call.
-
-            time (:type:`float`, `required`):
-                Length of call in seconds.
-
-            message (:type:`str`, `required`):
-                message associated with forward call, potentially error, or 'success'.
- """ - if request.outputs != None: - if request.end_time == None: - request.end_time = 15 - return request.outputs, request.code, request.end_time - - if (request.code != bittensor.proto.ReturnCode.Success) or (request.future == None): - request.end_time = clock.time() - request.start_time - return request.zeros, request.code, request.end_time - - deserializer = self.deserialize_forward_response if not request.backward else self.deserialize_backward_response - response_handling_funs = [self.collect_future, self.check_response, deserializer] - - for fun in response_handling_funs: - check, request = fun(request) - if not check: - request.end_time = clock.time()-request.start_time - return request.zeros, request.code, request.end_time - - request.end_time = clock.time()-request.start_time - return request.outputs if check else request.zeros, request.code, request.end_time - - def rpc_exception_handler(self, request, rpc_error_call): - r""" Handle the rpc exception call according to grpc status code. - """ - grpc_code = rpc_error_call.code() - - if grpc_code == grpc.StatusCode.DEADLINE_EXCEEDED: - request.code = bittensor.proto.ReturnCode.Timeout - request.message = 'grpc.StatusCode.DEADLINE_EXCEEDED'+': '+ rpc_error_call.details() - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return request.code, request.message - - elif grpc_code == grpc.StatusCode.UNAVAILABLE: - request.code = bittensor.proto.ReturnCode.Unavailable - request.message = 'grpc.StatusCode.UNAVAILABLE'+': '+ rpc_error_call.details() - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return request.code, request.message - - elif grpc_code == grpc.StatusCode.UNAUTHENTICATED: - request.code = bittensor.proto.ReturnCode.Unauthenticated - request.message = 'grpc.StatusCode.UNAUTHENTICATED'+': '+ rpc_error_call.details() - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return request.code, request.message - else: - request.code = bittensor.proto.ReturnCode.UnknownException - request.message = 'GRPC error code: {}, details: {}'.format( grpc_code, str(rpc_error_call.details()) ) - self.request_log(request = request, is_response = True, inputs = list(request.inputs.shape)) - return request.code, request.message - - - def sign(self): - r""" Uses the wallet pubkey to sign a message containing the pubkey and the time - """ - nounce = self.nounce() - message = str(nounce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid) - spliter = 'bitxx' - signature = spliter.join([ str(nounce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ]) - return signature - - def nounce(self): - r"""creates a string representation of the time - """ - nounce = int(time.time() * 1000) - return nounce - def state(self): - try: - return self.state_dict[self.channel._channel.check_connectivity_state(True)] - except ValueError: - return "Channel closed" - def close(self): - self.__exit__() diff --git a/bittensor/_receptor/receptor_pool_impl.py b/bittensor/_receptor/receptor_pool_impl.py index ed5370a023..9a5849909d 100644 --- a/bittensor/_receptor/receptor_pool_impl.py +++ b/bittensor/_receptor/receptor_pool_impl.py @@ -18,7 +18,7 @@ # DEALINGS IN THE SOFTWARE. 
import math -from typing import Tuple, List +from typing import Tuple, List, Union from threading import Lock import torch @@ -79,96 +79,89 @@ def get_receptors_state(self): """ return {hotkey: v.state() for hotkey, v in self.receptors.items()} - def forward( + def forward ( self, - endpoints: List['bittensor.Endpoint'], - inputs: List[torch.Tensor], - modality: bittensor.proto.Modality, - timeout: int + endpoints: List [ 'bittensor.Endpoint' ], + synapses: List[ 'bittensor.Synapse' ], + inputs: List [ torch.Tensor ], + timeout: int, ) -> Tuple[List[torch.Tensor], List[int], List[float]]: r""" Forward tensor inputs to endpoints. Args: - endpoints (:obj:`List[bittensor.Endpoint]` of shape :obj:`(num_endpoints)`, `required`): - List of remote endpoints which match length of x. Tensors from x are sent forward to these endpoints. + endpoints (:obj:`List[ bittensor.Endpoint ]` of shape :obj:`(num_endpoints)`, `required`): + List of remote endpoints which match length of inputs. Tensors from x are sent forward to these endpoints. + + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. inputs (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): + TODO(const): Allow multiple tensors. List of tensors to send to corresponsing endpoints. Tensors are of arbitrary type and shape depending on the modality. - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] - timeout (int): - request timeout. + Request timeout. Returns: - forward_outputs (:obj:`List[torch.FloatTensor]` of shape :obj:`num_endpoints * (batch_size, sequence_len, bittensor.network_size)]`, `required`): + forward_outputs (:obj:`List[ List[ torch.FloatTensor ]]` of shape :obj:`(num_endpoints * (num_synapses * (shape)))`, `required`): Output encodings of tensors produced by remote endpoints. Non-responses are zeroes of common shape. - forward_codes (:obj:`List[bittensor.proto.ReturnCodes]` of shape :obj:`(num_endpoints)`, `required`): + forward_codes (:obj:`List[ List[bittensor.proto.ReturnCodes] ]` of shape :obj:`(num_endpoints * ( num_synapses ))`, `required`): dendrite backward call return ops. - forward_times (:obj:`List[float]` of shape :obj:`(num_endpoints)`, `required`): + forward_times (:obj:`List[ List [float] ]` of shape :obj:`(num_endpoints * ( num_synapses ))`, `required`): dendrite backward call times """ - if len(endpoints) != len(inputs): raise ValueError('Endpoints must have the same length as passed inputs. Got {} and {}'.format(len(endpoints), len(inputs))) - # ---- Fill calls ---- - call_args = [ - (self._get_or_create_receptor_for_endpoint( endpoint ), inputs, modality) - for (inputs, endpoint) - in list(zip( inputs, endpoints )) - ] - - # ---- Preprocessing for the forward function, get the request. ---- - requests = [] - for arg in call_args: - self.total_requests += 1 - receptor, inputs, modality = arg - requests.append(receptor.preprocess_request ( inputs = inputs, modality = modality )) - - # ---- Send the forward request to peers. ---- - request_futures = [] - for arg, request in zip(call_args, requests): - receptor = arg[0] - request_futures.append(receptor.make_request_call(request = request, timeout = timeout)) - - # ---- Collect the futures. 
---- - results = [] - for arg, request in zip(call_args, request_futures): - receptor = arg[0] - results.append(receptor.handle_request_response(request = request)) - receptor.semaphore.release() - - try: - forward_outputs, forward_codes, forward_times = zip(*results) - - except concurrent.futures._base.TimeoutError: - forward_outputs= [torch.zeros( (inputs[0].size(0), inputs[0].size(1), bittensor.__network_dim__), dtype=torch.float32)] * len(endpoints) - forward_codes= [bittensor.proto.ReturnCode.Timeout] * len(endpoints) - forward_times= [15] * len(endpoints) + # Init receptors. + receptors = [ self._get_or_create_receptor_for_endpoint( endpoint ) for endpoint in endpoints ] + + # Init argument iterables. + call_args = [] + for idx, receptor in enumerate( receptors ): + call_args.append({ + 'receptor': receptor, + 'inputs': inputs [ idx ] , + 'synapses': synapses, + 'timeout': timeout + }) + + # Init function. + def call_forward( args ): + return args['receptor'].forward( args['synapses'], args['inputs'], args['timeout'] ) + + # Submit calls to receptors. + with concurrent.futures.ThreadPoolExecutor( max_workers = len(endpoints) ) as executor: + responses = executor.map( call_forward, call_args, timeout=10*timeout) - except Exception as e: - forward_outputs= [torch.zeros( (inputs[0].size(0), inputs[0].size(1), bittensor.__network_dim__), dtype=torch.float32)] * len(endpoints) - forward_codes= [bittensor.proto.ReturnCode.UnknownException] * len(endpoints) - forward_times= [15] * len(endpoints) - logger.exception('Exception encountered: {}'.format(e)) + # Release semephore. + for receptor in receptors: + receptor.semaphore.release() + + # Unpack responses + forward_outputs = [] + forward_codes = [] + forward_times = [] + for response in responses: + forward_outputs.append( response[0] ) + forward_codes.append( response[1] ) + forward_times.append( response[2] ) # ---- Kill receptors ---- self._destroy_receptors_over_max_allowed() - # ---- Return ---- - return list(forward_outputs), list(forward_codes), list(forward_times) + return forward_outputs, forward_codes, forward_times def backward( self, - endpoints: List['bittensor.Endpoint'], - inputs_x: List[torch.Tensor], - grads_dy: List[torch.Tensor], - modality: bittensor.proto.Modality, + endpoints: List [ 'bittensor.Endpoint' ], + synapses: List[ 'bittensor.Synapse' ], + inputs: List [ torch.Tensor ], + grads: List [ List[ torch.FloatTensor ] ], timeout: int ) -> Tuple[List[torch.Tensor], List[int], List[float]]: r""" Backward tensor inputs to endpoints. @@ -177,66 +170,81 @@ def backward( endpoints (:obj:`List['bittensor.Endpoint']` of shape :obj:`(num_endpoints)`, `required`): List of remote endpoints which match length of x. Tensors from x are sent backward to these endpoints. - inputs_x (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): - List of tensors to send to corresponsing endpoints. Tensors are of arbitrary type and shape depending on the - modality. + synapses (:obj:`List[ 'bittensor.Synapse' ]` of shape :obj:`(num_synapses)`, `required`): + Bittensor synapse objects with arguments. Each corresponds to a synapse function on the axon. + Responses are packed in this ordering. - grads_dy (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): - List of grad tensors to send to corresponsing inputs. + inputs (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): + List of tensors to send to corresponsing endpoints. 
Tensors are of arbitrary type and shape depending on the + synapse. - modality (:obj:`bittensor.proto.Modality` of shape :obj:`(1)`, `required`): - Bittensor forward modality type. Enum in [TEXT, IMAGE, TENSOR] + grads (:obj:`List[torch.Tensor]` of shape :obj:`(num_endpoints * [shape])`, `required`): + List of list of grad tensors where each grad corresponds to a synapse call on an endpoint. timeout (int): request timeout. Returns: - backward_outputs (:obj:`List[torch.FloatTensor]` of shape :obj:`num_endpoints * (batch_size, sequence_len, -1)]`, `required`): - gradients of returned from backward call. + backward_outputs (:obj:`List[ List[ torch.FloatTensor] ]` of shape :obj:`num_endpoints * (batch_size, sequence_len, -1)]`, `required`): + Gradients returned from the backward call one per endpoint. - backward_codes (:obj:`List[bittensor.proto.ReturnCodes]` of shape :obj:`(num_endpoints)`, `required`): - dendrite call return ops. + backward_codes (:obj:`List[ List[ bittensor.proto.ReturnCodes ] ]` of shape :obj:`(num_endpoints)`, `required`): + List of list of Backward call return ops, one per endpoint and synapse. backward_times (:obj:`List[float]` of shape :obj:`(num_endpoints)`, `required`): - dendrite call times. + List of list of Backward call times one per endpoint and synapse. """ - if len(endpoints) != len(inputs_x): - raise ValueError('Endpoints and inputs must have the same length. Got {} and {}'.format(len(endpoints), len(inputs_x))) - - # ---- Fill calls ---- - call_args = [ - (self._get_or_create_receptor_for_endpoint( endpoint ), inputs_x, grads_dy, modality) - for (inputs_x, grads_dy, endpoint) in - list(zip( inputs_x, grads_dy, endpoints )) - ] - - # ---- Preprocessing for the forward function, get the request. ---- - requests = [] - for arg in call_args: - receptor, inputs, grads_dy, modality = arg - requests.append(receptor.preprocess_request ( inputs = inputs, modality = modality, grads_dy = grads_dy, backward = True)) - - # ---- Send the forward request to peers. ---- - request_futures = [] - for arg, request in zip(call_args, requests): - receptor = arg[0] - request_futures.append(receptor.make_request_call(request = request, timeout = timeout)) - - for request_future in request_futures: - request_future.future.cancel() + if len(endpoints) != len(inputs): + raise ValueError('Endpoints must have the same length as passed inputs. Got {} and {}'.format(len(endpoints), len(inputs))) + if len(endpoints) != len(grads): + raise ValueError('Endpoints must have the same length as passed grads_dy. Got {} and {}'.format(len(endpoints), len(grads))) + for grads_per_synapse in grads: + if len(grads_per_synapse) != len(synapses): + raise ValueError('Gradients must have the same length as passed synapses. Got {} and {}'.format(len(grads_per_synapse), len(synapses))) + + # Init receptors. + receptors = [ self._get_or_create_receptor_for_endpoint( endpoint ) for endpoint in endpoints ] + + # Init argument iterables. + call_args = [] + for idx, receptor in enumerate( receptors ): + call_args.append({ + 'receptor': receptor, + 'synapses': synapses, + 'inputs': inputs [ idx ] , + 'grads': grads [ idx ] , + 'timeout': timeout + }) + + # Init function. + def call_backward( args ): + return args['receptor'].backward ( + synapses = args['synapses'], + inputs = args['inputs'], + grads = args['grads'], + timeout = args['timeout'] + ) - for arg in call_args: - receptor = arg[0] - receptor.semaphore.release() + # Submit calls to receptors. 
+ with concurrent.futures.ThreadPoolExecutor( max_workers = len(endpoints) ) as executor: + responses = executor.map ( call_backward, call_args, timeout=10*timeout ) - # ---- Return zeros ---- - backward_outputs= [torch.zeros( (inputs_x[0].size(0), inputs_x[0].size(1), bittensor.__network_dim__), dtype=torch.float32)] * len(endpoints) - backward_codes= [bittensor.proto.ReturnCode.Timeout] * len(endpoints) - backward_times= [15] * len(endpoints) + # Release semephore. + for receptor in receptors: + receptor.semaphore.release() + + # Unpack responses + backward_outputs = [] + backward_codes = [] + backward_times = [] + for response in responses: + backward_outputs.append( response[0] ) + backward_codes.append( response[1] ) + backward_times.append( response[2] ) # ---- Kill receptors ---- self._destroy_receptors_over_max_allowed() - + # ---- Return ---- return backward_outputs, backward_codes, backward_times def _destroy_receptors_over_max_allowed( self ): diff --git a/bittensor/_serializer/__init__.py b/bittensor/_serializer/__init__.py index c573915cac..a5b2b4310e 100644 --- a/bittensor/_serializer/__init__.py +++ b/bittensor/_serializer/__init__.py @@ -20,6 +20,7 @@ import torch import numpy as np import bittensor +from typing import Tuple, List, Union, Optional from . import serializer_impl @@ -38,12 +39,12 @@ class NoSerializerForEnum (Exception): class SerializationTypeNotImplementedException (Exception): """ Raised if serialization/deserialization is not implemented for the passed object type """ - def __new__(cls, serialzer_type: bittensor.proto.Serializer = bittensor.proto.Serializer.MSGPACK ) -> 'bittensor.Serializer': + def __new__(cls, serializer_type: bittensor.proto.Serializer = bittensor.proto.Serializer.MSGPACK ) -> 'bittensor.Serializer': r"""Returns the correct serializer object for the passed Serializer enum. Args: - serialzer_type (:obj:`bittensor.proto.Serializer`, `required`): - The serialzer_type ENUM from bittensor.proto. + serializer_type (:obj:`bittensor.proto.Serializer`, `required`): + The serializer_type ENUM from bittensor.proto. Returns: Serializer: (obj: `bittensor.Serializer`, `required`): @@ -54,14 +55,14 @@ def __new__(cls, serialzer_type: bittensor.proto.Serializer = bittensor.proto.Se Raised if the passed there is no serialzier for the passed type. """ # WARNING: the pickle serializer is not safe. Should be removed in future verions. 
- # if serialzer_type == bittensor.proto.Serializer.PICKLE: + # if serializer_type == bittensor.proto.Serializer.PICKLE: # return PyTorchPickleSerializer() - if serialzer_type == bittensor.proto.Serializer.MSGPACK: + if serializer_type == bittensor.proto.Serializer.MSGPACK: return serializer_impl.MSGPackSerializer() - elif serialzer_type == bittensor.proto.Serializer.CMPPACK: + elif serializer_type == bittensor.proto.Serializer.CMPPACK: return serializer_impl.CMPPackSerializer() else: - raise bittensor.serializer.NoSerializerForEnum("No known serialzier for proto type {}".format(serialzer_type)) + raise bittensor.serializer.NoSerializerForEnum("No known serialzier for proto type {}".format(serializer_type)) @staticmethod def torch_dtype_to_bittensor_dtype(tdtype): @@ -136,3 +137,4 @@ def bittensor_dtype_np_dtype(bdtype): 'Unknown bittensor.dtype or no equivalent numpy.dtype for bittensor.dtype = {}' .format(bdtype)) return dtype + diff --git a/bittensor/_serializer/serializer_impl.py b/bittensor/_serializer/serializer_impl.py index 9cba3f8201..a17a9ae1ef 100644 --- a/bittensor/_serializer/serializer_impl.py +++ b/bittensor/_serializer/serializer_impl.py @@ -20,6 +20,8 @@ import torch import msgpack import msgpack_numpy +from typing import Tuple, List, Union, Optional + import bittensor @@ -28,18 +30,24 @@ class Serializer(object): various python tensor equivalents. i.e. torch.Tensor or tensorflow.Tensor """ - def serialize (self, tensor_obj: object, modality: bittensor.proto.Modality, from_type: int) -> bittensor.proto.Tensor: + @staticmethod + def empty(): + """Returns an empty bittensor.proto.Tensor message with the version""" + torch_proto = bittensor.proto.Tensor(version= bittensor.__version_as_int__) + return torch_proto + + def serialize (self, tensor_obj: object, modality: bittensor.proto.Modality= bittensor.proto.Modality.TEXT, from_type: int = bittensor.proto.TensorType.TORCH) -> bittensor.proto.Tensor: """Serializes a torch object to bittensor.proto.Tensor wire format. Args: tensor_obj (:obj:`object`, `required`): general tensor object i.e. torch.Tensor or tensorflow.Tensor - from_type (`obj`: bittensor.proto.TensorType, `required`): + from_type (`obj`: bittensor.proto.TensorType, `Optional`): Serialization from this type. i.e. bittensor.proto.TensorType.TORCH or bittensor.proto.TensorType.TENSORFLOW Returns: - tensor_pb2: (obj: `bittensor.proto.Tensor`, `required`): + tensor_pb2: (obj: `bittensor.proto.Tensor`, `Optional`): Serialized tensor as bittensor.proto.proto. 
Raises: @@ -122,6 +130,7 @@ def deserialize_to_numpy(self, tensor_pb2: bittensor.proto.Tensor) -> object: """ bittensor.proto.Tensor -> numpy """ raise bittensor.serializer.SerializationTypeNotImplementedException + class MSGPackSerializer( Serializer ): """ Make conversion between torch and bittensor.proto.torch """ @@ -222,3 +231,4 @@ def deserialize_to_torch(self, torch_proto: bittensor.proto.Tensor) -> torch.Ten numpy_object = msgpack.unpackb(torch_proto.buffer, object_hook=msgpack_numpy.decode).copy() torch_object = torch.as_tensor(numpy_object).view(shape).requires_grad_(torch_proto.requires_grad) return torch_object.type(dtype) + diff --git a/bittensor/_subtensor/__init__.py b/bittensor/_subtensor/__init__.py index 08d9ed88db..9819c897ed 100644 --- a/bittensor/_subtensor/__init__.py +++ b/bittensor/_subtensor/__init__.py @@ -200,7 +200,7 @@ def add_defaults(cls, defaults ): @staticmethod def check_config( config: 'bittensor.Config' ): assert config.subtensor - assert config.subtensor.network != None + #assert config.subtensor.network != None @staticmethod def determine_chain_endpoint(network: str): diff --git a/bittensor/_subtensor/subtensor_impl.py b/bittensor/_subtensor/subtensor_impl.py index 75f8aa5d23..e1f3dc1b33 100644 --- a/bittensor/_subtensor/subtensor_impl.py +++ b/bittensor/_subtensor/subtensor_impl.py @@ -26,7 +26,7 @@ from retry import retry from substrateinterface import SubstrateInterface from bittensor.utils.balance import Balance -from bittensor.utils import is_valid_destination_address +from bittensor.utils import is_valid_bittensor_address_or_public_key from types import SimpleNamespace # Mocking imports @@ -427,7 +427,7 @@ def serve_axon ( wallet = axon.wallet, ip = external_ip, port = external_port, - modality = axon.modality, + modality = 0, wait_for_inclusion = wait_for_inclusion, wait_for_finalization = wait_for_finalization, prompt = prompt @@ -759,7 +759,8 @@ def transfer( If we did not wait for finalization / inclusion, the response is true. """ # Validate destination address. - if not is_valid_destination_address( dest ): + if not is_valid_bittensor_address_or_public_key( dest ): + bittensor.__console__.print(":cross_mark: [red]Invalid destination address[/red]:[bold white]\n {}[/bold white]".format(dest)) return False if isinstance( dest, bytes): diff --git a/bittensor/_subtensor/subtensor_mock.py b/bittensor/_subtensor/subtensor_mock.py index bd9900c5f3..13e824f6e0 100644 --- a/bittensor/_subtensor/subtensor_mock.py +++ b/bittensor/_subtensor/subtensor_mock.py @@ -43,7 +43,6 @@ GLOBAL_SUBTENSOR_MOCK_PROCESS_NAME = "node-subtensor" - class mock_subtensor(): r""" Returns a subtensor connection interface to a mocked subtensor process running in the background. Optionall creates the background process if it does not exist. 
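Before the new synapse package below, a brief caller-side sketch may help tie the reworked pieces together. This is a hypothetical example, not part of the patch: the receptor-pool forward signature and the synapse factory calls mirror the hunks in receptor_pool_impl.py above and _synapse/__init__.py below, while the wallet, receptor-pool construction, endpoint list, and tensor shapes are assumptions outside this diff.

import torch
import bittensor

# Assumed setup, not introduced by this patch: a wallet, a receptor pool and
# some endpoints to query (in a real caller these come from a metagraph).
wallet = bittensor.wallet()
pool = bittensor.receptor_pool( wallet = wallet )
endpoints = []  # fill with bittensor.Endpoint objects before calling

# Synapse objects come from the factory added below; each one names a remote
# function on the axon and owns its own request/response (de)serialization.
synapses = [
    bittensor.synapse.TextLastHiddenState(),
    bittensor.synapse.TextCausalLM( topk = 512 ),
]

# One tensor of token ids per endpoint; the same synapse list is applied to
# every endpoint in the call.
inputs = [ torch.zeros( ( 2, 32 ), dtype = torch.long ) for _ in endpoints ]

if endpoints:
    # New pool interface: results are nested per endpoint, then per synapse.
    outputs, codes, times = pool.forward( endpoints = endpoints, synapses = synapses, inputs = inputs, timeout = 12 )
    first_hidden = outputs[0][0]   # endpoint 0, TextLastHiddenState response
    first_code = codes[0][1]       # endpoint 0, TextCausalLM return code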
diff --git a/bittensor/_synapse/__init__.py b/bittensor/_synapse/__init__.py new file mode 100644 index 0000000000..941c0b6baa --- /dev/null +++ b/bittensor/_synapse/__init__.py @@ -0,0 +1,228 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from multiprocessing.sharedctypes import Value +from yaml import serialize_all +import bittensor +import torch +from typing import Union, List, Tuple, Optional + +from bittensor._serializer import serializer +from .synapse_impl import Synapse, NullSynapse +from .text_causallm_impl import TextCausalLM +from .text_causallmnext_impl import TextCausalLMNext +from .text_lasthiddenstate_impl import TextLastHiddenState +from .text_seq2seq_impl import TextSeq2Seq + + + +class synapse: + """ + Factory class for the synapse objects. The synapses are designed to work the bittensor protocol and is + reponsible for the serialization and deserialization of their contents. They are expected to be included by + the forwarding neuron when making a call through the bittensor api. + + Examples: + >>> causallm_synapse = bittensor.synapse.TextCausalLM() + >>> dendrite.text(endpoints = [..], inputs = [..], synapses= [causallm_synapse] ) + + """ + __synapses_types__ = ['TextLastHiddenState', 'TextCausalLM', 'TextSeq2Seq'] + + @staticmethod + def TextLastHiddenState ( + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> TextLastHiddenState: + """ Factory function which returns a TextLastHiddenState synapse adapter given arguments. + Args: + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. 
+ backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serialzer used to pack torch tensors on backward response. + Returns: + TextLastHiddenState (:obj:`TextLastHiddenState`, `required`): + TextLastHiddenState instance adapter class. + """ + return TextLastHiddenState ( + forward_request_serializer_type = forward_request_serializer_type, + forward_response_serializer_type = forward_response_serializer_type, + backward_request_serializer_type = backward_request_serializer_type, + backward_response_serializer_type = backward_response_serializer_type, + ) + + @staticmethod + def TextCausalLM ( + topk:int = 512, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> TextCausalLM: + """ Factory function which returns a TextCausalLM synapse adapter given arguments. + Args: + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serialzer used to pack torch tensors on backward response. + Returns: + TextCausalLM (:obj:`TextCausalLM`, `required`): + TextCausalLM instance adapter class. + """ + return TextCausalLM ( + topk = topk, + forward_request_serializer_type = forward_request_serializer_type, + forward_response_serializer_type = forward_response_serializer_type, + backward_request_serializer_type = backward_request_serializer_type, + backward_response_serializer_type = backward_response_serializer_type, + ) + + @staticmethod + def TextCausalLMNext( + topk: int = 4096, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> TextCausalLMNext: + """ Factory function which returns a TextCausalLMNext synapse adapter given arguments. + Args: + topk (:obj:`int`): + Specifies the number of topk server token phrases to return. 
+ forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on backward response. + Returns: + TextCausalLMNext (:obj:`TextCausalLMNext`, `required`): + TextCausalLMNext instance adapter class. + """ + return TextCausalLMNext( + topk=topk, + forward_request_serializer_type=forward_request_serializer_type, + forward_response_serializer_type=forward_response_serializer_type, + backward_request_serializer_type=backward_request_serializer_type, + backward_response_serializer_type=backward_response_serializer_type, + ) + + @staticmethod + def TextSeq2Seq ( + topk:int = 50, + num_to_generate: int = 256, + num_beams: int = 5, + no_repeat_ngram_size: int = 2, + early_stopping: bool = False, + num_return_sequences: int = 1, + do_sample: bool = False, + top_p: float = 0.95, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + length_penalty: float = 1.0, + max_time: float = 150, + num_beam_groups: int = 1, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> TextSeq2Seq: + """ Factory function which returns a TextSeq2Seq synapse adapter given arguments. + Args: + Topk (:obj:int, :default: 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + num_to_generate (:obj: int, :default: 256): + The number of tokens to generate using the language model + num_beams (:obj: int, :default: 5): + The number of beams to keep during beam search + no_repeat_ngram_size (:obj: int, :default: 2): + The number of repeat n gram allowed + early_stopping: (:obj: bool, :default: True): + If the model should early stop if the probabilty drops a certain threshold + num_return_sequences: (:obj: int, :default: 1): + How many sequences should the model return + do_sample (:obj: bool, :default: False): + If the model should do sample its probablity during generation + top_p (:obj: float, :default: 0.95): + probability cutoff for top p sampling + temperature: (:obj: float, :default: 1.0): + The value used to module the next token probabilities for the softmax calculation + repetition_penalty (:obj: float, :default: 1.0): + The parameter for repetition penalty. 1.0 means no penalty. + length_penalty (:obj: float, :default: 1.0): + The parameter for length penalty. 0.0 means no penalty, <0 to encourage longer sequences. 
+ num_beam_groups (:obj: int, :default: 1): + Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. + max_time (:obj: float, :default: 150): + The maximum time that a server can use to generate + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serialzer used to pack torch tensors on backward response. + Returns: + TextSeq2Seq (:obj:`TextSeq2Seq`, `required`): + TextSeq2Seq instance adapter class. + """ + return TextSeq2Seq ( + topk = topk, + num_to_generate = num_to_generate, + num_beams = num_beams, + no_repeat_ngram_size = no_repeat_ngram_size, + early_stopping = early_stopping, + num_return_sequences = num_return_sequences, + do_sample = do_sample, + top_p = top_p, + temperature = temperature, + repetition_penalty = repetition_penalty, + length_penalty = length_penalty, + num_beam_groups = num_beam_groups, + max_time = max_time, + forward_request_serializer_type = forward_request_serializer_type, + forward_response_serializer_type = forward_response_serializer_type, + backward_request_serializer_type = backward_request_serializer_type, + backward_response_serializer_type = backward_response_serializer_type, + ) + + @staticmethod + def deserialize( synapse_wire_proto: bittensor.proto.Synapse ) -> Synapse: + if synapse_wire_proto.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE: + return TextLastHiddenState.deserialize_from_wire_proto ( synapse_wire_proto ) + elif synapse_wire_proto.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM: + return TextCausalLM.deserialize_from_wire_proto( synapse_wire_proto ) + elif synapse_wire_proto.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT: + return TextCausalLMNext.deserialize_from_wire_proto(synapse_wire_proto) + elif synapse_wire_proto.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ: + return TextSeq2Seq.deserialize_from_wire_proto( synapse_wire_proto ) + else: + return NullSynapse() \ No newline at end of file diff --git a/bittensor/_synapse/synapse_impl.py b/bittensor/_synapse/synapse_impl.py new file mode 100644 index 0000000000..8747211e3d --- /dev/null +++ b/bittensor/_synapse/synapse_impl.py @@ -0,0 +1,207 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall 
be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor +import torch +from typing import Union, List, Tuple, Optional + +class Synapse: + """ Proto serializable class which specifies the function to be called on a receiving neuron + as well as the method of serialization and packing of forward/backward request/responses. + """ + + # Unique proto enum. + synapse_type: bittensor.proto.Synapse.SynapseType = None + + def __init__( + self, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> 'Synapse': + """ Synapse super class initializer. + Args: + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on backward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on backward response. + Returns: + Synapse (:obj:`Synapse`, `required`): + Synapse super class. + """ + self.forward_request_serializer_type = forward_request_serializer_type + self.forward_response_serializer_type = forward_response_serializer_type + self.backward_request_serializer_type = backward_request_serializer_type + self.backward_response_serializer_type = backward_response_serializer_type + + def __repr__(self) -> str: return self.__str__() + def __str__(self) -> str: return "Synapse" + + @staticmethod + def deserialize_from_instance_proto ( instance_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserializes the instance proto to an instance class. + Args: + instance_proto (:obj:`bittensor.proto.Synapse` of shape :obj:`(1)`, `required`): + Synapse instance proto to be deserialized. + Returns: + synapse_instance_class (:obj:`Synapse`, `required`): + Deserialized instance class. + """ + raise NotImplementedError("deserialize_from_instance_proto should be implemented by the subclass.") + + @staticmethod + def deserialize_from_wire_proto ( wire_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserializes the wire proto to an instance class.
+ Args: + wire_proto (:obj:`bittensor.proto.Synapse` of shape :obj:`(1)`, `required`): + Synapse wire proto to be deserialized. + Returns: + synapse_instance_class (:obj:`Synapse`, `required`): + Deserialized instance class. + """ + raise NotImplementedError("deserialize_from_wire_proto should be implemented by the subclass.") + + def serialize_to_instance_proto( self, **kwargs ) -> 'bittensor.proto.Synapse': + """ Serializes the class instance to a Synapse instance proto. + Returns: + serialized_synapse_as_instance_proto (:obj:`bittensor.proto.Synapse`, `required`): + Instance class serialized to an instance proto. + """ + raise NotImplementedError("serialize_to_instance_proto should be implemented by the subclass.") + + def serialize_to_wire_proto( self, **kwargs ) -> 'bittensor.proto.Synapse': + """ Serializes the class instance to a Synapse wire proto. + Returns: + serialized_synapse_as_wire_proto (:obj:`bittensor.proto.Synapse`, `required`): + Instance class serialized to a wire proto. + """ + raise NotImplementedError("serialize_to_wire_proto should be implemented by the subclass.") + + def nill_forward_response_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite forward call when the call fails. + Args: + forward_request_tensor (:obj:`torch.Tensor`, `required`): + Tensor being sent as forward request. + Returns: + nill_forward_response_tensor (:obj:`torch.Tensor`, `required`): + Zeroed forward response tensor. + """ + raise NotImplementedError("nill_forward_response_tensor should be implemented by the subclass.") + + def nill_backward_response_tensor ( self, forward_request_tensor : torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite backward call when the call fails. + Args: + forward_request_tensor (:obj:`torch.Tensor`, `required`): + Tensor being sent as forward request. + Returns: + nill_backward_response_tensor (:obj:`torch.Tensor`, `required`): + Zeroed backward response gradient.
+ """ + raise NotImplementedError("nill_backward_response_tensor should be implemented by the subclass.") + + def check_forward_request_tensor ( self, forward_request_tensor ): pass + def check_forward_response_tensor ( self, forward_request_tensor, forward_response_tensor ): pass + def check_backward_request_gradient ( self, forward_request_tensor, backward_request_gradient ): pass + def check_backward_response_gradient ( self, forward_request_tensor, backward_request_gradient ): pass + def encode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + def decode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + def encode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: return forward_response_tensor + def decode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: return forward_response_tensor + def encode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + def decode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + def encode_backward_response_gradient ( self, backward_response_gradient: torch.Tensor ) -> torch.Tensor: return backward_response_gradient + def decode_backward_response_gradient ( self, backward_response_gradient: torch.Tensor ) -> torch.Tensor: return backward_response_gradient + + def serialize_forward_request_tensor( self, forward_request_tensor: torch.Tensor ) -> Tuple[ 'bittensor.proto.Tensor', 'bittensor.proto.ReturnCode', str ]: + self.check_forward_request_tensor ( forward_request_tensor ) + forward_request_tensor = self.encode_forward_request_tensor ( forward_request_tensor ) + tensor_serialzier = bittensor.serializer( serializer_type = self.forward_request_serializer_type ) + return tensor_serialzier.serialize( tensor_obj = forward_request_tensor, from_type = bittensor.proto.TensorType.TORCH ) + + def deserialize_forward_request_tensor( self, forward_request_proto: bittensor.proto.Tensor ) -> Tuple[ 'torch.Tensor', 'bittensor.proto.ReturnCode', str ]: + """ Returns a torch.Tensor from wire proto.Tensor after relevant deserialization has been applied. """ + tensor_deserialzier = bittensor.serializer( serializer_type = self.forward_request_serializer_type ) + forward_request_tensor = tensor_deserialzier.deserialize( tensor_pb2 = forward_request_proto, to_type = bittensor.proto.TensorType.TORCH ) + forward_request_tensor = self.decode_forward_request_tensor ( forward_request_tensor ) + self.check_forward_request_tensor ( forward_request_tensor ) + return forward_request_tensor + + def serialize_forward_response_tensor( self, forward_request_tensor: torch.Tensor, forward_response_tensor: torch.Tensor ) -> Tuple[ 'bittensor.proto.Tensor', 'bittensor.proto.ReturnCode', str ]: + """ Returns a bittensor.proto.Tensor to be sent on the wire after relevant serialization applied. 
""" + encoded_tensor = self.encode_forward_response_tensor ( forward_response_tensor ) + self.check_forward_response_tensor ( forward_request_tensor, encoded_tensor ) + tensor_serialzier = bittensor.serializer( serializer_type = self.forward_response_serializer_type ) + return tensor_serialzier.serialize( tensor_obj = encoded_tensor, from_type = bittensor.proto.TensorType.TORCH ) + + def deserialize_forward_response_proto( self, forward_request_tensor: torch.Tensor, forward_response_proto: bittensor.proto.Tensor ) -> Tuple[ 'torch.Tensor', 'bittensor.proto.ReturnCode', str ]: + """ Returns a torch.Tensor from wire proto.Tensor after relevant deserialization has been applied. """ + tensor_deserialzier = bittensor.serializer( serializer_type = self.forward_response_serializer_type ) + forward_response_tensor = tensor_deserialzier.deserialize( tensor_pb2 = forward_response_proto, to_type = bittensor.proto.TensorType.TORCH ) + self.check_forward_response_tensor ( forward_request_tensor, forward_response_tensor ) + forward_response_tensor = self.decode_forward_response_tensor ( forward_response_tensor ) + forward_response_tensor = torch.nan_to_num( forward_response_tensor, nan=0) + return forward_response_tensor + + def serialize_backward_request_gradient( self, forward_request_tensor: torch.Tensor, backward_request_gradient: torch.Tensor ) -> Tuple[ 'bittensor.proto.Tensor', 'bittensor.proto.ReturnCode', str ]: + """ Returns a bittensor.proto.Tensor gradient to be sent on the wire after relevant serialization applied. """ + self.check_backward_request_gradient ( forward_request_tensor, backward_request_gradient ) + encoded_tensor = self.encode_backward_request_gradient ( backward_request_gradient ) + tensor_serialzier = bittensor.serializer( serializer_type = self.forward_request_serializer_type ) + return tensor_serialzier.serialize( tensor_obj = encoded_tensor, from_type = bittensor.proto.TensorType.TORCH ) + + def deserialize_backward_request_gradient( self, forward_request_tensor: torch.Tensor, backward_request_proto: bittensor.proto.Tensor ) -> Tuple[ 'torch.Tensor', 'bittensor.proto.ReturnCode', str ]: + tensor_deserialzier = bittensor.serializer( serializer_type = self.backward_request_serializer_type ) + backward_request_gradient = tensor_deserialzier.deserialize( tensor_pb2 = backward_request_proto, to_type = bittensor.proto.TensorType.TORCH ) + backward_request_gradient = self.decode_backward_request_gradient ( backward_request_gradient ) + self.check_backward_request_gradient (forward_request_tensor, backward_request_gradient ) + return backward_request_gradient + + def empty(self): + tensor_deserialzier = bittensor.serializer( serializer_type = self.forward_request_serializer_type ) + return tensor_deserialzier.empty() + + +class NullSynapse (Synapse): + """ Null Synapse type + """ + synapse_type: bittensor.proto.Synapse.SynapseType = bittensor.proto.Synapse.SynapseType.NULL_SYNAPSE + + def __init__( + self + ) -> 'NullSynapse': + """ Null Synapse initializer. Used when a request contains synapses that has not been initalized + Returns: + NullSynapse (:obj:`NullSynapse`, `required`): + NullSynapse instance adapter class. 
+ """ + super().__init__ () + self.synapse_type = NullSynapse.synapse_type + + def __repr__(self) -> str: return self.__str__() + def __str__(self) -> str: return "Null" + + def serialize_to_wire_proto ( self, code: 'bittensor.proto.ReturnCode' = 0, message: str = '' ) -> bittensor.proto.Synapse: + return bittensor.proto.Synapse ( + synapse_type = NullSynapse.synapse_type, + return_code = code, + message = message + ) \ No newline at end of file diff --git a/bittensor/_synapse/text_causallm_impl.py b/bittensor/_synapse/text_causallm_impl.py new file mode 100644 index 0000000000..7c89dbd462 --- /dev/null +++ b/bittensor/_synapse/text_causallm_impl.py @@ -0,0 +1,180 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor +import torch +from typing import Union, List, Tuple, Optional +from .synapse_impl import Synapse + +class TextCausalLM (Synapse): + """ TextCausalLM Synapse type for next token prediction from languge models. + """ + synapse_type: bittensor.proto.Synapse.SynapseType = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM + + def __init__( + self, + topk: int = 512, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ): + """ TextCausalLM Synapse initializer. + Args: + Topk (:obj:int, :default: 512): + The top k number of logits to compress and send over the wire + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. 
+ backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on backward response. + Returns: + TextCausalLM (:obj:`TextCausalLM`, `required`): + TextCausalLM instance adapter class. + """ + super().__init__ ( + forward_request_serializer_type, + forward_response_serializer_type, + backward_request_serializer_type, + backward_response_serializer_type + ) + self.topk = topk + self.synapse_type = TextCausalLM.synapse_type + + def __repr__(self) -> str: return self.__str__() + def __str__(self) -> str: return "TextCausalLM" + + @staticmethod + def deserialize_from_instance_proto ( instance_proto: bittensor.proto.Synapse ) -> 'TextCausalLM': + return TextCausalLM ( + topk = instance_proto.topk, + forward_request_serializer_type = instance_proto.forward_request_serializer_type, + forward_response_serializer_type = instance_proto.forward_response_serializer_type, + backward_request_serializer_type = instance_proto.backward_request_serializer_type, + backward_response_serializer_type = instance_proto.backward_response_serializer_type, + ) + + @staticmethod + def deserialize_from_wire_proto ( wire_proto: bittensor.proto.Synapse ) -> 'TextCausalLM': + instance_proto = bittensor.proto.Synapse.TextCausalLM() + instance_proto.ParseFromString( wire_proto.synapse_data ) + return TextCausalLM.deserialize_from_instance_proto( instance_proto ) + + def serialize_to_instance_proto( self ) -> 'bittensor.proto.Synapse.TextCausalLM': + return bittensor.proto.Synapse.TextCausalLM ( + topk = self.topk, + forward_request_serializer_type = self.forward_request_serializer_type, + forward_response_serializer_type = self.forward_response_serializer_type, + backward_request_serializer_type = self.backward_request_serializer_type, + backward_response_serializer_type = self.backward_response_serializer_type, + ) + + def serialize_to_wire_proto ( self, code: 'bittensor.proto.ReturnCode' = 0, message: str = '' ) -> bittensor.proto.Synapse: + return bittensor.proto.Synapse ( + synapse_data = self.serialize_to_instance_proto().SerializeToString(), + synapse_type = TextCausalLM.synapse_type, + return_code = code, + message = message + ) + + def check_forward_request_tensor ( self, forward_request_tensor ): + if len( forward_request_tensor.shape ) != 2 or forward_request_tensor.shape[0] == 0 or forward_request_tensor.shape[1] == 0: + raise ValueError( "forward_request_tensor.shape must be in [-1, -1], got: {} for synapse: {}".format( list(forward_request_tensor.shape), self ) ) + + def check_forward_response_tensor ( self, forward_request_tensor, forward_response_tensor ): + if forward_response_tensor is None: + raise ValueError("Empty Response") + + if ( + len( forward_response_tensor.shape ) != 3 or + forward_response_tensor.size(0) != forward_request_tensor.size(0) or + forward_response_tensor.size(1) != forward_request_tensor.size(1) or + forward_response_tensor.size(2) != self.topk*2 + ): + raise ValueError( "forward_response_tensor.shape must be in [{}, {}, {}], got: {} for synapse: {}".format( forward_request_tensor.size(0) , forward_request_tensor.size(1), self.topk*2, list(forward_response_tensor.shape), self ) ) + + def check_backward_request_gradient ( self, forward_request_tensor, backward_request_gradient ): + if ( len( backward_request_gradient.shape ) != 3 or + backward_request_gradient.size(0) != forward_request_tensor.size(0) or + 
backward_request_gradient.size(1) != forward_request_tensor.size(1) or + backward_request_gradient.size(2) != bittensor.__vocab_size__ + ): + raise ValueError( "backward_request_gradient.shape: {} must be equivalent to forward_request_tensor.shape: {} for synapse: {}".format( list( backward_request_gradient.shape ), list(forward_request_tensor.shape), self ) ) + + def encode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + def decode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + + def encode_forward_response_tensor( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns topk tokens/probabilities given unnormalized logits as input. """ + logits = forward_response_tensor # unnormalized logit scores: [batch_size, sequence_len, vocab_size] + probs = torch.softmax(logits, dim=-1) # normalized probabilities: [batch_size, sequence_len, vocab_size] + topk_values, topk_indices = torch.topk(probs, self.topk) # topk probs and indices: [batch_size, sequence_len, topk] + encoded_probs = torch.cat((topk_values, topk_indices), dim=-1) # [batch_size, sequence_len, topk + topk] + return encoded_probs # [batch_size, sequence_len, topk + topk] + + def decode_forward_response_tensor( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns full logits by decoding topk-encoding input. """ + batch_size, sequence_len, _ = forward_response_tensor.shape + encoded_probs = forward_response_tensor # encoded probabilities: [batch_size, sequence_len, topk + topk] + topk_values = encoded_probs[..., :self.topk] # topk probs: [batch_size, sequence_len, topk] + topk_indices = encoded_probs[..., self.topk:].long() # topk probs indices: [batch_size, sequence_len, topk] + + topk_pmass = topk_values.sum(dim=-1) # topk probability mass: [batch_size, sequence_len] + remainder_pmass = torch.clamp(1 - topk_pmass, 1e-40, 1) # remainder probability mass: [batch_size, sequence_len] + remainder_floor = remainder_pmass / (bittensor.__vocab_size__ - self.topk) # divide remainder: [batch_size, sequence_len] + + logits = torch.ones((batch_size, sequence_len, bittensor.__vocab_size__)).to(topk_values.device) + logits *= torch.log(remainder_floor)[:, :, None] # set probability floor: [batch_size, sequence_len, vocab_size] + logits.scatter_(-1, topk_indices, torch.log(topk_values + 1e-40)) # insert topk probs: [batch_size, sequence_len, vocab_size] + + return logits # [batch_size, sequence_len, vocab_size] + + def encode_backward_response_gradient( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + def decode_backward_response_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + + def encode_backward_request_gradient( self, backward_response_gradient: torch.Tensor ) -> torch.Tensor: + """ Return topk most negative token grads given full logit gradients. """ + values, indices = torch.topk(backward_response_gradient, self.topk) # ascend sort to get most negative gradients - informs on ideal logits + encoded_grads = torch.cat((values, indices), dim=-1) # [batch_size, sequence_len, topk + topk] + return encoded_grads # [batch_size, sequence_len, topk + topk] + + def decode_backward_request_gradient( self, backward_response_gradient: torch.Tensor ) -> torch.Tensor: + """ Return full gradients by decoding topk-encoding input. 
""" + batch_size, sequence_len, _ = backward_response_gradient.shape + encoded_grads = backward_response_gradient # encoded gradients: [batch_size, sequence_len, topk + topk] + topk_values = encoded_grads[..., :self.topk] # topk grads: [batch_size, sequence_len, topk] + topk_indices = encoded_grads[..., self.topk:].long() # topk grads indices: [batch_size, sequence_len, topk] + + gradients = torch.zeros((batch_size, sequence_len, bittensor.__vocab_size__)).to(topk_values.device) + gradients.scatter_(-1, topk_indices, topk_values) # insert topk grads: [batch_size, sequence_len, vocab_size] + + return gradients # [batch_size, sequence_len, vocab_size] + + def nill_forward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + try: + return torch.zeros( ( forward_request_tensor.size(0), forward_request_tensor.size(1), bittensor.__vocab_size__ ), dtype=torch.float32) + except: + return torch.tensor([]) + + def nill_backward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + try: + return torch.zeros( ( forward_request_tensor.size(0), forward_request_tensor.size(1), bittensor.__vocab_size__ ), dtype=torch.float32) + except: + return torch.tensor([]) diff --git a/bittensor/_synapse/text_causallmnext_impl.py b/bittensor/_synapse/text_causallmnext_impl.py new file mode 100644 index 0000000000..284b018d9e --- /dev/null +++ b/bittensor/_synapse/text_causallmnext_impl.py @@ -0,0 +1,194 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor +import torch +from .synapse_impl import Synapse +from bittensor.utils.tokenizer_utils import compact_topk_token_phrases, unravel_topk_token_phrases + + +class TextCausalLMNext(Synapse): + """ TextCausalLMNext Synapse type for next token prediction from language models. + """ + synapse_type: bittensor.proto.Synapse.SynapseType = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT + + def __init__( + self, + topk: int = 4096, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ): + """ TextCausalLMNext Synapse initializer. 
+ Args: + topk (:obj:`int`): + Specifies the number of topk server token phrases to return. + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on backward response. + Returns: + TextCausalLMNext (:obj:`TextCausalLMNext`, `required`): + TextCausalLMNext instance adapter class. + """ + super().__init__( + forward_request_serializer_type, + forward_response_serializer_type, + backward_request_serializer_type, + backward_response_serializer_type + ) + self.topk = topk + self.synapse_type = TextCausalLMNext.synapse_type + + def __repr__(self) -> str: + return self.__str__() + + def __str__(self) -> str: + return "TextCausalLMNext" + + @staticmethod + def deserialize_from_instance_proto(instance_proto: bittensor.proto.Synapse) -> 'TextCausalLMNext': + return TextCausalLMNext( + topk=instance_proto.topk, + forward_request_serializer_type=instance_proto.forward_request_serializer_type, + forward_response_serializer_type=instance_proto.forward_response_serializer_type, + backward_request_serializer_type=instance_proto.backward_request_serializer_type, + backward_response_serializer_type=instance_proto.backward_response_serializer_type, + ) + + @staticmethod + def deserialize_from_wire_proto(wire_proto: bittensor.proto.Synapse) -> 'TextCausalLMNext': + instance_proto = bittensor.proto.Synapse.TextCausalLMNext() + instance_proto.ParseFromString(wire_proto.synapse_data) + return TextCausalLMNext.deserialize_from_instance_proto(instance_proto) + + def serialize_to_instance_proto(self) -> 'bittensor.proto.Synapse.TextCausalLMNext': + return bittensor.proto.Synapse.TextCausalLMNext( + topk=self.topk, + forward_request_serializer_type=self.forward_request_serializer_type, + forward_response_serializer_type=self.forward_response_serializer_type, + backward_request_serializer_type=self.backward_request_serializer_type, + backward_response_serializer_type=self.backward_response_serializer_type, + ) + + def serialize_to_wire_proto(self, code: 'bittensor.proto.ReturnCode' = 0, + message: str = '') -> bittensor.proto.Synapse: + return bittensor.proto.Synapse( + synapse_data=self.serialize_to_instance_proto().SerializeToString(), + synapse_type=TextCausalLMNext.synapse_type, + return_code=code, + message=message + ) + + def check_forward_request_tensor(self, forward_request_tensor): + # forward_request_tensor: [batch_size, sequence_len] + if ( + len(forward_request_tensor.shape) != 2 or + forward_request_tensor.shape[0] == 0 or + forward_request_tensor.shape[1] == 0 + ): + raise ValueError(f"forward_request_tensor.shape must be in [-1, -1], " + f"got: {list(forward_request_tensor.shape)} for synapse: {self}") + + def check_forward_response_tensor(self, forward_request_tensor, forward_response_tensor): + # forward_request_tensor: 
[batch_size, sequence_len] + # forward_response_tensor: [ >= batch_size * (2 * topk + 1)] + if forward_response_tensor is None: + raise ValueError("Empty Response") + + if ( + len(forward_response_tensor.shape) != 1 or + forward_response_tensor.size(0) < forward_request_tensor.shape[0] * (2 * self.topk + 1) + ): + raise ValueError(f"forward_response_tensor.shape must be in " + f"[>={forward_request_tensor.shape[0]} x (2 x {self.topk} + 1)], " + f"got: {forward_response_tensor.size(0)} for synapse: {self}") + + def check_backward_request_gradient(self, forward_request_tensor, backward_request_gradient): + # forward_request_tensor: [batch_size, sequence_len] + # backward_request_gradient: [batch_size, (topk + 1), max_len] + if ( + len(backward_request_gradient.shape) != 3 or + backward_request_gradient.size(0) != forward_request_tensor.shape[0] or + backward_request_gradient.size(1) != (self.topk + 1) + ): + raise ValueError(f"backward_request_gradient.shape must be in " + f"[{forward_request_tensor.shape[0]}, ({self.topk} + 1), max_len], " + f"got: {backward_request_gradient.shape} for synapse: {self}") + + def encode_forward_request_tensor(self, forward_request_tensor: torch.Tensor) -> torch.Tensor: + return forward_request_tensor + + def decode_forward_request_tensor(self, forward_request_tensor: torch.Tensor) -> torch.Tensor: + return forward_request_tensor + + def encode_forward_response_tensor(self, forward_response_tensor: torch.Tensor) -> torch.Tensor: + """ Compact [batch_size, (topk + 1), max_len] topk std_token_phrases to [ >= batch_size * (2 * topk + 1)]. """ + compact_topk = compact_topk_token_phrases(forward_response_tensor) + # compact_topk: [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1) + return compact_topk + + def decode_forward_response_tensor(self, forward_response_tensor: torch.Tensor) -> torch.Tensor: + """ Unravel [ >= batch_size * (2 * topk + 1)] into [batch_size, (topk + 1), max_len] topk std_token_phrases. """ + topk_tensor = unravel_topk_token_phrases(forward_response_tensor, topk=self.topk) + return topk_tensor # [batch_size, (topk + 1), max_len] + + def encode_backward_response_gradient(self, backward_request_gradient: torch.Tensor) -> torch.Tensor: + return backward_request_gradient + + def decode_backward_response_gradient(self, backward_request_gradient: torch.Tensor) -> torch.Tensor: + return backward_request_gradient + + def encode_backward_request_gradient(self, backward_response_gradient: torch.Tensor) -> torch.Tensor: + """ Compact gradients of [batch_size, (topk + 1), max_len] to [2 + batch_size * (topk + 1)]. """ + batch_size, topk_p1, max_len = backward_response_gradient.shape + dims = torch.tensor([batch_size, max_len]).to(backward_response_gradient.device) + prob_grads = backward_response_gradient[:, :, 0] # [batch_size, topk + 1] first column w/ prob grads + encoded_gradient = torch.hstack((dims, prob_grads.flatten())) # [2 + batch_size * (topk + 1)] + return encoded_gradient # [2 + batch_size * (topk + 1)] + + def decode_backward_request_gradient(self, backward_response_gradient: torch.Tensor) -> torch.Tensor: + """ Restructure [2 + batch_size * (topk + 1)] prob grads into [batch_size, (topk + 1), max_len]. 
""" + batch_size = int(backward_response_gradient[0].item()) + max_len = int(backward_response_gradient[1].item()) + decoded_gradient = torch.zeros((batch_size, self.topk + 1, max_len)).to(backward_response_gradient.device) + decoded_gradient[:, :, 0] = backward_response_gradient[2:].reshape(batch_size, self.topk + 1) + return decoded_gradient # [batch_size, (topk + 1), max_len] + + def nill_forward_response_tensor(self, forward_request_tensor: torch.Tensor, + encoded=False, ignore_index=-100) -> torch.Tensor: + if forward_request_tensor.dim() == 0 or forward_request_tensor.shape[0] == 0: + return torch.tensor([]) + + forward_response_tensor = torch.zeros(forward_request_tensor.shape[0], (self.topk + 1), 1 + 1) + forward_response_tensor[:, :, 1] = 2 # set 2 <= token_ids to preserve 0 <= probs <= 1 in column 0 + forward_response_tensor[:, self.topk::(self.topk + 1), 1] = ignore_index # add ignore_index padding after floor_prob + + if encoded: + return self.encode_forward_response_tensor(forward_response_tensor) + + return forward_response_tensor + + def nill_backward_response_tensor(self, forward_request_tensor: torch.Tensor) -> torch.Tensor: + if forward_request_tensor.dim() == 0 or forward_request_tensor.shape[0] == 0: + return torch.tensor([]) + return torch.zeros((forward_request_tensor.shape[0], (self.topk + 1), 1 + 1), dtype=torch.float32) diff --git a/bittensor/_synapse/text_lasthiddenstate_impl.py b/bittensor/_synapse/text_lasthiddenstate_impl.py new file mode 100644 index 0000000000..f0786ee134 --- /dev/null +++ b/bittensor/_synapse/text_lasthiddenstate_impl.py @@ -0,0 +1,153 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor +import torch +from typing import Union, List, Tuple, Optional + +from .synapse_impl import Synapse + +class TextLastHiddenState (Synapse): + """ TastHiddenState Synapse type for getting last hidden layer embeddings from languge models. 
+ """ + synapse_type: bittensor.proto.Synapse.SynapseType = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE + + def __init__( + self, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> 'TextLastHiddenState': + """ TextLastHiddenState Synapse initializer. + Args: + forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward response. + backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serializer used to pack torch tensors on forward request. + backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`): + Serialzer used to pack torch tensors on backward response. + Returns: + TextLastHiddenState (:obj:`TextLastHiddenState`, `required`): + TextLastHiddenState instance adapter class. + """ + super().__init__ ( + forward_request_serializer_type, + forward_response_serializer_type, + backward_request_serializer_type, + backward_response_serializer_type + ) + self.synapse_type = TextLastHiddenState.synapse_type + + def __repr__(self) -> str: return self.__str__() + def __str__(self) -> str: return "TextLastHiddenState" + + @staticmethod + def deserialize_from_wire_proto ( wire_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserialzied the wire proto to an instance class. + """ + instance_proto = bittensor.proto.Synapse.TextLastHiddenState() + instance_proto.ParseFromString( wire_proto.synapse_data ) + return TextLastHiddenState.deserialize_from_instance_proto( instance_proto ) + + @staticmethod + def deserialize_from_instance_proto ( instance_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserialzied the instance proto to an instance class. + Args: + isntance_proto (:obj:`bittensor.proto.Synapse` of shape :obj:`(1)`, `required`): + Synapse instance proto to be deserialized. + Returns: + synapse_instance_clasee (:obj:`torch.Tensor`, `required`): + Deserialized instance class. + """ + return TextLastHiddenState ( + forward_request_serializer_type = instance_proto.forward_request_serializer_type, + forward_response_serializer_type = instance_proto.forward_response_serializer_type, + backward_request_serializer_type = instance_proto.backward_request_serializer_type, + backward_response_serializer_type = instance_proto.backward_response_serializer_type, + ) + + def serialize_to_instance_proto( self ) -> 'bittensor.proto.Synapse.TextLastHiddenState': + """ Serializes the class instance to a Synapse instance proto. 
+ """ + return bittensor.proto.Synapse.TextLastHiddenState ( + forward_request_serializer_type = self.forward_request_serializer_type, + forward_response_serializer_type = self.forward_response_serializer_type, + backward_request_serializer_type = self.backward_request_serializer_type, + backward_response_serializer_type = self.backward_response_serializer_type, + ) + + def serialize_to_wire_proto( self, code: 'bittensor.proto.ReturnCode' = 0, message: str = '' ) -> 'bittensor.proto.Synapse': + """ Serializes the class instance to a Synapse wire proto. + """ + return bittensor.proto.Synapse ( + synapse_data = self.serialize_to_instance_proto().SerializeToString(), + synapse_type = TextLastHiddenState.synapse_type, + return_code = code, + message = message + ) + + def check_forward_request_tensor ( self, forward_request_tensor ): + if len( forward_request_tensor.shape ) != 2 or forward_request_tensor.shape[0] == 0 or forward_request_tensor.shape[1] == 0: + raise ValueError( "forward_request_tensor.shape must be in [-1, -1], got: {} for synapse: {}".format( list(forward_request_tensor.shape), self ) ) + + def check_forward_response_tensor ( self, forward_request_tensor, forward_response_tensor ): + if forward_response_tensor == None: + raise ValueError('Empty Response') + + if ( + len( forward_response_tensor.shape ) != 3 or + forward_response_tensor.size(0) != forward_request_tensor.size(0) or + forward_response_tensor.size(1) != forward_request_tensor.size(1) or + forward_response_tensor.size(2) != bittensor.__network_dim__ + ): + raise ValueError( "forward_response_tensor.shape must be in [{}, {}, {}], got: {} for synapse: {}".format( forward_request_tensor.size(0) , forward_request_tensor.size(1), bittensor.__network_dim__, list(forward_response_tensor.shape), self ) ) + + def check_backward_request_gradient ( self, forward_request_tensor, backward_request_gradient ): + if ( + len( backward_request_gradient.shape ) != 3 or + backward_request_gradient.size(0) != forward_request_tensor.size(0) or + backward_request_gradient.size(1) != forward_request_tensor.size(1) or + backward_request_gradient.size(2) != bittensor.__network_dim__ + ): + raise ValueError( "backward_request_gradient.shape must be in [{}, {}, {}], got: {} for synapse: {}".format( forward_request_tensor.size(0) , forward_request_tensor.size(1), bittensor.__network_dim__, list(backward_request_gradient.shape), self ) ) + + def encode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + def decode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: return forward_request_tensor + def encode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: return forward_response_tensor + def decode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: return forward_response_tensor + def encode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + def decode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: return backward_request_gradient + + + def nill_forward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite forward call when the call fails. 
+ """ + try: + return torch.zeros( ( forward_request_tensor.size(0), forward_request_tensor.size(1), bittensor.__network_dim__ ), dtype=torch.float32) + except: + return torch.tensor([]) + def nill_backward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite backward call when the call fails. + """ + try: + return torch.zeros( ( forward_request_tensor.size(0), forward_request_tensor.size(1), bittensor.__network_dim__ ), dtype=torch.float32) + except: + return torch.tensor([]) + \ No newline at end of file diff --git a/bittensor/_synapse/text_seq2seq_impl.py b/bittensor/_synapse/text_seq2seq_impl.py new file mode 100644 index 0000000000..a67bfdeb80 --- /dev/null +++ b/bittensor/_synapse/text_seq2seq_impl.py @@ -0,0 +1,228 @@ +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor +import torch +from typing import Union, List, Tuple, Optional + +from .synapse_impl import Synapse + +class TextSeq2Seq (Synapse): + """ TextSeq2Seq Synapse type for sequence generation from language models. + """ + synapse_type: bittensor.proto.Synapse.SynapseType = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ + def __init__( + self, + topk:int = 50, + num_to_generate: int = 256, + num_beams: int = 5, + no_repeat_ngram_size: int = 2, + early_stopping: bool = False, + num_return_sequences: int = 1, + do_sample: bool = False, + top_p: float = 0.95, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + length_penalty: float = 1.0, + max_time: float = 150, + num_beam_groups: int = 1, + forward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + forward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_request_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + backward_response_serializer_type: 'bittensor.proto.Serializer.Type' = bittensor.proto.Serializer.MSGPACK, + ) -> 'TextSeq2Seq': + """ TextSeq2Seq Synapse initializer. + Args: + Topk (:obj:int, :default: 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. 
+            num_to_generate (:obj: int, :default: 256):
+                The number of tokens to generate using the language model.
+            num_beams (:obj: int, :default: 5):
+                The number of beams to keep during beam search.
+            no_repeat_ngram_size (:obj: int, :default: 2):
+                The size of n-grams that may only occur once in the generated sequence.
+            early_stopping: (:obj: bool, :default: False):
+                Whether generation should stop early once enough complete candidate sequences are found.
+            num_return_sequences: (:obj: int, :default: 1):
+                The number of sequences the model should return.
+            do_sample (:obj: bool, :default: False):
+                Whether the model should sample from its probability distribution during generation.
+            top_p (:obj: float, :default: 0.95):
+                Probability cutoff for top-p (nucleus) sampling.
+            temperature: (:obj: float, :default: 1.0):
+                The value used to modulate the next token probabilities in the softmax calculation.
+            repetition_penalty (:obj: float, :default: 1.0):
+                The parameter for repetition penalty. 1.0 means no penalty.
+            length_penalty (:obj: float, :default: 1.0):
+                The parameter for length penalty. 0.0 means no length normalization; values > 0.0 promote longer
+                sequences and values < 0.0 encourage shorter ones.
+            num_beam_groups (:obj: int, :default: 1):
+                Number of groups to divide num_beams into in order to ensure diversity among different groups of beams.
+            max_time (:obj: float, :default: 150):
+                The maximum time that a server can use to generate.
+            forward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`):
+                Serializer used to pack torch tensors on forward request.
+            forward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`):
+                Serializer used to pack torch tensors on forward response.
+            backward_request_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`):
+                Serializer used to pack torch tensors on backward request.
+            backward_response_serializer_type (:obj:`bittensor.proto.Serializer.Type` of shape :obj:`(1)`, `optional`, :default: `bittensor.proto.Serializer.MSGPACK`):
+                Serializer used to pack torch tensors on backward response.
+        Returns:
+            TextSeq2Seq (:obj:`TextSeq2Seq`, `required`):
+                TextSeq2Seq instance adapter class.
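+
+        Example (illustrative sketch only; the argument values below are arbitrary, not recommended settings):
+            >>> synapse = TextSeq2Seq( topk = 50, num_to_generate = 64, do_sample = True, top_p = 0.95 )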
+ """ + super().__init__ ( + forward_request_serializer_type, + forward_response_serializer_type, + backward_request_serializer_type, + backward_response_serializer_type + ) + self.topk = topk + self.num_to_generate = num_to_generate + self.num_beams = num_beams + self.no_repeat_ngram_size = no_repeat_ngram_size + self.early_stopping = early_stopping + self.num_return_sequences = num_return_sequences + self.do_sample = do_sample + self.top_p = top_p + self.temperature = temperature + self.repetition_penalty = repetition_penalty + self.length_penalty = length_penalty + self.num_beam_groups = num_beam_groups + self.max_time = max_time + self.synapse_type = TextSeq2Seq.synapse_type + + def __repr__(self) -> str: return self.__str__() + def __str__(self) -> str: return "TextSeq2Seq" + + @staticmethod + def deserialize_from_instance_proto ( instance_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserialzied the instance proto to an instance class.""" + return TextSeq2Seq ( + topk = instance_proto.topk, + num_to_generate = instance_proto.num_to_generate, + num_beams = instance_proto.num_beams, + no_repeat_ngram_size = instance_proto.no_repeat_ngram_size, + early_stopping = instance_proto.early_stopping, + num_return_sequences = instance_proto.num_return_sequences, + do_sample = instance_proto.do_sample, + top_p = instance_proto.top_p, + temperature = instance_proto.temperature, + repetition_penalty = instance_proto.repetition_penalty, + length_penalty = instance_proto.length_penalty, + num_beam_groups = instance_proto.num_beam_groups, + max_time = instance_proto.max_time, + forward_request_serializer_type = instance_proto.forward_request_serializer_type, + forward_response_serializer_type = instance_proto.forward_response_serializer_type, + backward_request_serializer_type = instance_proto.backward_request_serializer_type, + backward_response_serializer_type = instance_proto.backward_response_serializer_type, + ) + + @staticmethod + def deserialize_from_wire_proto ( wire_proto: bittensor.proto.Synapse ) -> 'Synapse': + """ Deserialzied the wire proto to an instance class. """ + instance_proto = bittensor.proto.Synapse.TextSeq2Seq() + instance_proto.ParseFromString( wire_proto.synapse_data ) + return TextSeq2Seq.deserialize_from_instance_proto( instance_proto ) + + def serialize_to_instance_proto( self ) -> 'bittensor.proto.Synapse.TextSeq2Seq': + """ Serializes the class instance to a Synapse instance proto.""" + return bittensor.proto.Synapse.TextSeq2Seq ( + topk = self.topk, + num_to_generate = self.num_to_generate, + num_beams = self.num_beams, + no_repeat_ngram_size = self.no_repeat_ngram_size, + early_stopping = self.early_stopping, + num_return_sequences = self.num_return_sequences, + do_sample = self.do_sample, + top_p = self.top_p, + temperature = self.temperature, + repetition_penalty = self.repetition_penalty, + length_penalty = self.length_penalty, + num_beam_groups = self.num_beam_groups, + max_time = self.max_time, + forward_request_serializer_type = self.forward_request_serializer_type, + forward_response_serializer_type = self.forward_response_serializer_type, + backward_request_serializer_type = self.backward_request_serializer_type, + backward_response_serializer_type = self.backward_response_serializer_type, + ) + + def serialize_to_wire_proto( self, code: 'bittensor.proto.ReturnCode' = 0, message: str = '' ) -> bittensor.proto.Synapse: + """ Serializes the class instance to a Synapse wire proto. 
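+            A rough round-trip sketch (illustrative; assumes `synapse` is a TextSeq2Seq instance from this module):
+                wire_proto = synapse.serialize_to_wire_proto( code = 0, message = 'Success' )
+                restored = TextSeq2Seq.deserialize_from_wire_proto( wire_proto )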
""" + return bittensor.proto.Synapse ( + synapse_data = self.serialize_to_instance_proto().SerializeToString(), + synapse_type = TextSeq2Seq.synapse_type, + return_code = code, + message = message + ) + + def check_forward_request_tensor ( self, forward_request_tensor ): + if len( forward_request_tensor.shape ) != 2 or forward_request_tensor.shape[0] == 0 or forward_request_tensor.shape[1] == 0: + raise ValueError( "forward_request_tensor.shape must be in [-1, -1], got: {} for synapse: {}".format( list(forward_request_tensor.shape), self ) ) + + def check_forward_response_tensor ( self, forward_request_tensor, forward_response_tensor ): + if forward_response_tensor == None: + raise ValueError('Empty Response') + if ( + len( forward_response_tensor.shape ) != 2 or + forward_response_tensor.size(0) != forward_request_tensor.size(0) or + forward_response_tensor.size(1) > self.num_to_generate + ): + raise ValueError( "forward_response_tensor.shape must be in [{}, <{}], got: {} for synapse: {}".format( forward_request_tensor.size(0) , self.num_to_generate, list(forward_response_tensor.shape), self ) ) + + def check_backward_request_gradient ( self, forward_request_tensor, backward_request_gradient ): + if len(backward_request_gradient.shape) > 1 or ( torch.numel(backward_request_gradient) >= 1 ): # the gradient for seq2seq should always be torch.tensor([]) + raise ValueError( "backward_request_gradient.shape must be in [0], got: {} for synapse: {}".format( forward_request_tensor.size(0) , forward_request_tensor.size(1), bittensor.__network_dim__, list(backward_request_gradient.shape), self ) ) + + return + + def encode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + return forward_request_tensor + + def decode_forward_request_tensor ( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + return forward_request_tensor + + def encode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: + # Apply topk logit encoding. + return forward_response_tensor + + def decode_forward_response_tensor ( self, forward_response_tensor: torch.Tensor ) -> torch.Tensor: + # Decode topk logit encoding. + return forward_response_tensor # [batch_size, sequence_len] + + def encode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: + # Apply topk logit encoding for gradients. + return backward_request_gradient + + def decode_backward_request_gradient ( self, backward_request_gradient: torch.Tensor ) -> torch.Tensor: + # Decode topk logit encoding for gradients. 
+ return backward_request_gradient + + def nill_forward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite forward call when the call fails.""" + try: + if forward_request_tensor.size(0) == 0 : + return torch.tensor([]) + + return torch.zeros( ( forward_request_tensor.size(0), self.num_to_generate), dtype=torch.float32) + except: + return torch.tensor([]) + + def nill_backward_response_tensor( self, forward_request_tensor: torch.Tensor ) -> torch.Tensor: + """ Returns a zeroed tensor used as response to a dendrite backward call when the call fails.""" + + return torch.tensor([]) \ No newline at end of file diff --git a/bittensor/_threadpool/__init__.py b/bittensor/_threadpool/__init__.py index 27e01eff6e..63cc8d4a24 100644 --- a/bittensor/_threadpool/__init__.py +++ b/bittensor/_threadpool/__init__.py @@ -78,8 +78,8 @@ def add_defaults(cls, defaults): """ defaults.axon = bittensor.Config() defaults.axon.priority = bittensor.Config() - defaults.axon.priority.max_workers = os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') if os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') != None else 10 - defaults.axon.priority.maxsize = os.getenv('BT_AXON_PRIORITY_MAXSIZE') if os.getenv('BT_AXON_PRIORITY_MAXSIZE') != None else -1 + defaults.axon.priority.max_workers = os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') if os.getenv('BT_AXON_PRIORITY_MAX_WORKERS') != None else 5 + defaults.axon.priority.maxsize = os.getenv('BT_AXON_PRIORITY_MAXSIZE') if os.getenv('BT_AXON_PRIORITY_MAXSIZE') != None else 10 @classmethod def config(cls) -> 'bittensor.Config': diff --git a/bittensor/_threadpool/priority_thread_pool_impl.py b/bittensor/_threadpool/priority_thread_pool_impl.py index 5559067ff0..a162d4f680 100644 --- a/bittensor/_threadpool/priority_thread_pool_impl.py +++ b/bittensor/_threadpool/priority_thread_pool_impl.py @@ -8,14 +8,14 @@ import os import sys - +import bittensor from concurrent.futures import _base import itertools import queue import random import threading import weakref - +import time from loguru import logger # Workers are created as daemon threads. 
This is done to allow the interpreter @@ -36,16 +36,18 @@ _shutdown = False class _WorkItem(object): - def __init__(self, future, fn, args, kwargs): + def __init__(self, future, fn, start_time, args, kwargs): self.future = future self.fn = fn + self.start_time = start_time self.args = args self.kwargs = kwargs def run(self): """ Run the given work item """ - if not self.future.set_running_or_notify_cancel(): + # Checks if future is canceled or if work item is stale + if (not self.future.set_running_or_notify_cancel()) or (time.time()-self.start_time > bittensor.__blocktime__): return try: @@ -58,7 +60,7 @@ def run(self): self.future.set_result(result) -NULL_ENTRY = (sys.maxsize, _WorkItem(None, None, (), {})) +NULL_ENTRY = (sys.maxsize, _WorkItem(None, None, time.time(), (), {})) def _worker(executor_reference, work_queue, initializer, initargs): if initializer is not None: @@ -161,11 +163,12 @@ def submit(self, fn, *args, **kwargs): if priority == 0: priority = random.randint(1, 100) eplison = random.uniform(0,0.01) * priority + start_time = time.time() if 'priority' in kwargs: del kwargs['priority'] f = _base.Future() - w = _WorkItem(f, fn, args, kwargs) + w = _WorkItem(f, fn, start_time, args, kwargs) self._work_queue.put((-float(priority + eplison), w), block=False) self._adjust_thread_count() diff --git a/bittensor/_tokenizer/__init__.py b/bittensor/_tokenizer/__init__.py index 7237651a2c..7e017164f3 100644 --- a/bittensor/_tokenizer/__init__.py +++ b/bittensor/_tokenizer/__init__.py @@ -17,8 +17,10 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -from transformers import GPT2Tokenizer -import bittensor +from transformers import AutoTokenizer +import bittensor +from bittensor.utils.tokenizer_utils import prep_tokenizer + class tokenizer: """ Implementation of the bittensor tokenizer @@ -43,66 +45,6 @@ def __new__( cls, version: str = None ): def get_tokenizer_for_version( cls, version = bittensor.__version__ ): """ Return the GPT2 tokenizer with bittersor's special tokens """ - _tokenizer = GPT2Tokenizer.from_pretrained("gpt2", local_files_only=False) - _tokenizer.padding_side = "left" - _tokenizer.add_prefix_space = False - _tokenizer.add_special_tokens({'bos_token': "[BOS]"}) # A special token representing the beginning of a sentence. - _tokenizer.add_special_tokens({'eos_token': "[EOS]"}) # A special token representing the end of a sentence. - _tokenizer.add_special_tokens({'unk_token': "[UNK]"}) # A special token representing an out-of-vocabulary token. - _tokenizer.add_special_tokens({'sep_token': "[SEP]"}) # A special token separating two different sentences in the same input (used by BERT for instance) - _tokenizer.add_special_tokens({'pad_token': "[PAD]"}) # A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. - _tokenizer.add_special_tokens({'cls_token': "[CLS]"}) # A special token representing the class of the input (used by BERT for instance). - _tokenizer.add_special_tokens({'mask_token': "[MASK]"}) # A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). 
- additional_special_tokens = [ - "NOTUSED", # Used by BARThez - "NOTUSED", # Used by BARThez - "", # Used by MarianMT - "", # Used by MarianMT - "", # Used by Transformer XL - "" # Used by Pegasus - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - ] - _tokenizer.additional_special_tokens = additional_special_tokens + _tokenizer = AutoTokenizer.from_pretrained('gpt2', local_files_only=False) + _tokenizer = prep_tokenizer(_tokenizer) return _tokenizer - - @staticmethod - def prep_tokenizer(tokenizer): - tokenizer.padding_side = "left" - tokenizer.add_prefix_space = False - tokenizer.add_special_tokens({'bos_token': "[BOS]"}) # A special token representing the beginning of a sentence. - tokenizer.add_special_tokens({'eos_token': "[EOS]"}) # A special token representing the end of a sentence. - tokenizer.add_special_tokens({'unk_token': "[UNK]"}) # A special token representing an out-of-vocabulary token. - tokenizer.add_special_tokens({'sep_token': "[SEP]"}) # A special token separating two different sentences in the same input (used by BERT for instance) - tokenizer.add_special_tokens({'pad_token': "[PAD]"}) # A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. - tokenizer.add_special_tokens({'cls_token': "[CLS]"}) # A special token representing the class of the input (used by BERT for instance). - tokenizer.add_special_tokens({'mask_token': "[MASK]"}) # A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). - additional_special_tokens = [ - "NOTUSED", # Used by BARThez - "NOTUSED", # Used by BARThez - "", # Used by MarianMT - "", # Used by MarianMT - "", # Used by Transformer XL - "" # Used by Pegasus - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - "", # Used by XLM - ] - tokenizer.additional_special_tokens = additional_special_tokens - return tokenizer - diff --git a/bittensor/_wallet/__init__.py b/bittensor/_wallet/__init__.py index 48b9e76d8f..df8fb4fdf7 100644 --- a/bittensor/_wallet/__init__.py +++ b/bittensor/_wallet/__init__.py @@ -58,8 +58,8 @@ def __new__( if config == None: config = wallet.config() config = copy.deepcopy( config ) - config.wallet.name = name if name != None else config.wallet.get('name', bittensor.defaults.wallet.name) - config.wallet.hotkey = hotkey if hotkey != None else config.wallet.get('hotkey', bittensor.defaults.wallet.hotkey) + config.wallet.name = name if name != None else config.wallet.name + config.wallet.hotkey = hotkey if hotkey != None else config.wallet.hotkey config.wallet.path = path if path != None else config.wallet.path config.wallet._mock = _mock if _mock != None else config.wallet._mock wallet.check_config( config ) @@ -73,12 +73,14 @@ def __new__( hotkey = config.wallet.get('hotkey', bittensor.defaults.wallet.hotkey), path = config.wallet.path, _mock = True, + config = config ) return wallet_impl.Wallet( name = config.wallet.get('name', bittensor.defaults.wallet.name), hotkey = config.wallet.get('hotkey', bittensor.defaults.wallet.hotkey), path = config.wallet.path, + config = config ) @classmethod @@ -113,6 +115,7 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None ): 
parser.add_argument('--' + prefix_str + 'wallet.all_hotkeys', required=False, action='store_true', default=bittensor.defaults.wallet.all_hotkeys, help='''To specify all hotkeys. Specifying hotkeys will exclude them from this all.''') parser.add_argument('--' + prefix_str + 'wallet.sort_by', required=False, action='store', default=bittensor.defaults.wallet.sort_by, type=str, help='''Sort the hotkeys by the specified column title (e.g. name, uid, axon).''') parser.add_argument('--' + prefix_str + 'wallet.sort_order', required=False, action='store', default=bittensor.defaults.wallet.sort_order, type=str, help='''Sort the hotkeys in the specified ordering. (ascending/asc or descending/desc/reverse)''') + parser.add_argument('--' + prefix_str + 'wallet.reregister', required=False, action='store', default=bittensor.defaults.wallet.reregister, type=bool, help='''Whether to reregister the wallet if it is not already registered.''') except argparse.ArgumentError as e: import pdb #pdb.set_trace() @@ -133,6 +136,8 @@ def add_defaults(cls, defaults): defaults.wallet.all_hotkeys = False defaults.wallet.sort_by = "" defaults.wallet.sort_order = "ascending" + # Defaults for registration + defaults.wallet.reregister = True @classmethod def check_config(cls, config: 'bittensor.Config' ): @@ -145,3 +150,4 @@ def check_config(cls, config: 'bittensor.Config' ): assert isinstance(config.wallet.hotkeys, list) assert isinstance(config.wallet.sort_by, str) assert isinstance(config.wallet.sort_order, str) + assert isinstance(config.wallet.reregister, bool) diff --git a/bittensor/_wallet/wallet_impl.py b/bittensor/_wallet/wallet_impl.py index 091bc3dae1..158f8d211f 100644 --- a/bittensor/_wallet/wallet_impl.py +++ b/bittensor/_wallet/wallet_impl.py @@ -18,15 +18,15 @@ # DEALINGS IN THE SOFTWARE. import os -import time -import json -import requests +import sys from types import SimpleNamespace +from typing import Optional, Union -from typing import Union, Optional +import bittensor +from bittensor.utils import is_valid_bittensor_address_or_public_key from substrateinterface import Keypair from termcolor import colored -import bittensor + def display_mnemonic_msg( keypair : Keypair, key_type : str ): """ Displaying the mnemonic and warning message to keep mnemonic safe @@ -53,6 +53,7 @@ def __init__( name:str, path:str, hotkey:str, + config: 'bittensor.Config' = None, ): r""" Init bittensor wallet object containing a hot and coldkey. Args: @@ -62,6 +63,8 @@ def __init__( The name of hotkey used to running the miner. path (required=True, default='~/.bittensor/wallets/'): The path to your bittensor wallets + config (:obj:`bittensor.Config`, `optional`): + bittensor.wallet.config() """ self.name = name self.path = path @@ -69,6 +72,7 @@ def __init__( self._hotkey = None self._coldkey = None self._coldkeypub = None + self.config = config def __str__(self): return "Wallet ({}, {}, {})".format(self.name, self.hotkey_str, self.path) @@ -211,6 +215,38 @@ def get_balance( self, subtensor: 'bittensor.Subtensor' = None ) -> 'bittensor.B if subtensor == None: subtensor = bittensor.subtensor() return subtensor.get_balance(address = self.coldkeypub.ss58_address) + def reregister( + self, + subtensor: 'bittensor.Subtensor' = None, + wait_for_inclusion: bool = False, + wait_for_finalization: bool = True, + prompt: bool = False + ) -> Optional['bittensor.Wallet']: + """ Re-register this wallet on the chain. + Args: + subtensor( 'bittensor.Subtensor' ): + Bittensor subtensor connection. Overrides with defaults if None. 
+ wait_for_inclusion (bool): + if set, waits for the extrinsic to enter a block before returning true, + or returns false if the extrinsic fails to enter the block within the timeout. + wait_for_finalization (bool): + if set, waits for the extrinsic to be finalized on the chain before returning true, + or returns false if the extrinsic fails to be finalized within the timeout. + prompt (bool): + If true, the call waits for confirmation from the user before proceeding. + + Return: + wallet (bittensor.Wallet): + This wallet. + """ + if subtensor == None: + subtensor = bittensor.subtensor() + if not self.is_registered(subtensor=subtensor): + # Check if the wallet should reregister + if not self.config.wallet.get('reregister'): + sys.exit(0) + return self.register(subtensor=subtensor, wait_for_inclusion=wait_for_inclusion, wait_for_finalization=wait_for_finalization, prompt=prompt) + def register ( self, subtensor: 'bittensor.Subtensor' = None, @@ -559,6 +595,37 @@ def regen_coldkey( self, mnemonic: Optional[Union[list, str]]=None, seed: Option """ self.regenerate_coldkey(mnemonic, seed, use_password, overwrite) + def regenerate_coldkeypub( self, ss58_address: Optional[str] = None, public_key: Optional[Union[str, bytes]] = None, overwrite: bool = False ) -> 'Wallet': + """ Regenerates the coldkeypub from passed ss58_address or public_key and saves the file + Requires either ss58_address or public_key to be passed. + Args: + ss58_address: (str, optional): + Address as ss58 string. + public_key: (str | bytes, optional): + Public key as hex string or bytes. + overwrite (bool, optional) (default: False): + Will this operation overwrite the coldkeypub (if exists) under the same path //coldkeypub + Returns: + wallet (bittensor.Wallet): + newly re-generated Wallet with coldkeypub. + + """ + if ss58_address is None and public_key is None: + raise ValueError("Either ss58_address or public_key must be passed") + + if not is_valid_bittensor_address_or_public_key( ss58_address if ss58_address is not None else public_key ): + raise ValueError(f"Invalid {'ss58_address' if ss58_address is not None else 'public_key'}") + + keypair = Keypair(ss58_address=ss58_address, public_key=public_key, ss58_format=bittensor.__ss58_format__) + + # No need to encrypt the public key + self.set_coldkeypub( keypair, overwrite = overwrite) + + return self + + # Short name for regenerate_coldkeypub + regen_coldkeypub = regenerate_coldkeypub + def regenerate_coldkey( self, mnemonic: Optional[Union[list, str]]=None, seed: Optional[str]=None, use_password: bool = True, overwrite:bool = False) -> 'Wallet': """ Regenerates the coldkey from passed mnemonic, encrypts it with the user's password and save the file Args: diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 0372080c2d..8ebdb9ef95 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -294,7 +294,7 @@ def is_valid_ed25519_pubkey( public_key: Union[str, bytes] ) -> bool: except (ValueError, IndexError): return False -def is_valid_destination_address( address: Union[str, bytes] ) -> bool: +def is_valid_bittensor_address_or_public_key( address: Union[str, bytes] ) -> bool: """ Checks if the given address is a valid destination address. 
@@ -307,23 +307,13 @@ def is_valid_destination_address( address: Union[str, bytes] ) -> bool: if isinstance( address, str ): # Check if ed25519 if address.startswith('0x'): - if not is_valid_ed25519_pubkey( address ): - bittensor.__console__.print(":cross_mark: [red]Invalid Destination Public Key[/red]: {}".format( address )) - return False - # Assume ss58 address + return is_valid_ed25519_pubkey( address ) else: - if not is_valid_ss58_address( address ): - bittensor.__console__.print(":cross_mark: [red]Invalid Destination Address[/red]: {}".format( address )) - return False + # Assume ss58 address + return is_valid_ss58_address( address ) elif isinstance( address, bytes ): # Check if ed25519 - if not is_valid_ed25519_pubkey( address ): - bittensor.__console__.print(":cross_mark: [red]Invalid Destination Public Key[/red]: {}".format( address )) - return False + return is_valid_ed25519_pubkey( address ) else: - bittensor.__console__.print(":cross_mark: [red]Invalid Destination[/red]: {}".format( address )) + # Invalid address type return False - - return True - - diff --git a/bittensor/utils/codes.py b/bittensor/utils/codes.py index 9ebee68609..48c754785e 100644 --- a/bittensor/utils/codes.py +++ b/bittensor/utils/codes.py @@ -124,4 +124,17 @@ def code_to_loguru_color( code: 'bittensor.proto.ReturnCode' ) -> str: elif code == 22: return 'red' else: - return 'red' \ No newline at end of file + return 'red' + +def code_to_synapse( code: 'bittensor.proto.Synapse.SynapseType'): + """Return Code -> Synapse Type""" + if code == 1: + return 'text_last_hidden_state' + elif code == 2: + return 'text_causal_lm' + elif code == 3: + return 'text_seq_2_seq' + elif code == 4: + return 'text_causal_lm_next' + else: + return 'Null' \ No newline at end of file diff --git a/bittensor/utils/tokenizer_utils.py b/bittensor/utils/tokenizer_utils.py new file mode 100644 index 0000000000..10a82df5b0 --- /dev/null +++ b/bittensor/utils/tokenizer_utils.py @@ -0,0 +1,1322 @@ +""" Utils for tokenizer equivalence checking, logit translation, etc. +""" +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch + +from typing import List, Dict, Tuple, Any, Union +from transformers import PreTrainedTokenizerBase + +EPSILON = 1e-40 + + +def get_tokenizer_alignment_splits(offset_mapping: List[tuple], offset_mapping_std: List[tuple]) -> Dict[int, tuple]: + r""" + Calculates split depths necessary for tokens to align input offsets to standard offsets. 
+ Only input offsets may be split, not standard offsets, to create one-to-one, one-to-many, or many-to-one + token alignments between input-to-standard tokenization. + Allows for multiple depth splits on a token. + Args: + offset_mapping (:obj:`List[tuple]`, `required`): + Tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...]. + offset_mapping_std (:obj:`List[tuple]`, `required`): + Standard tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...] + + Returns: + splits (:obj:`Dict[int, tuple]`, `required`): + For tokens that have to be split, {Token index: (split depth 1, split depth 2, ...), ...}. + """ + + splits = {} + idx = 0 # index of token segment (server tokenization) + idx_std = 0 # index of token segment (standard tokenization) + + right = offset_mapping[idx][1] # first right edge + right_std = offset_mapping_std[idx_std][1] # first std right edge + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore overlapping tokens + idx += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore overlapping tokens + idx_std += 1 + + segment_count = 1 # keep count of segments traversed, + segment_count_std = 1 # to track one-to-many, many-to-one conditions + + while idx < len(offset_mapping) and idx_std < len(offset_mapping_std): + if right < right_std: + # Examples: [|] edge, [\] next edge, [.] split + # (45, 49) + # (45, 48) (48, 51) std + # | .| \ + # | | | + if segment_count == 1 and segment_count_std > 1: # prevent many-to-many + # | . | \ + # | | | | | + left = offset_mapping[idx][0] + left_std = offset_mapping_std[idx_std][0] + splits.setdefault(idx, []) + splits[idx] += [left_std - left] # server token, split depth + segment_count_std = 1 + continue + + idx += 1 + if idx < len(offset_mapping): + right = offset_mapping[idx][1] + segment_count += 1 + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore right-aligned overlapping tokens + idx += 1 + + elif right_std < right: + if segment_count_std == 1 and segment_count > 1: # prevent many-to-many + # Examples: [|] edge, [\] next edge, [.] split + # | | | | . | + # | | \ + + # (775, 778, 781, 788, 791) + # (775, 782, 785, 795) std + # | | |. . 
| | allow for multiple splits on a single token + # | | | | + left = offset_mapping[idx][0] + splits.setdefault(idx, []) + splits[idx] += [right_std - left] # server token, split depth + segment_count = 1 + segment_count_std = 0 + + idx_std += 1 + if idx_std < len(offset_mapping_std): + right_std = offset_mapping_std[idx_std][1] + segment_count_std += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore right-aligned overlapping tokens + idx_std += 1 + + else: # right == right_std + idx += 1 + if idx < len(offset_mapping): + right = offset_mapping[idx][1] + segment_count = 1 + + idx_std += 1 + if idx_std < len(offset_mapping_std): + right_std = offset_mapping_std[idx_std][1] + segment_count_std = 1 + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore right-aligned overlapping tokens + idx += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore right-aligned overlapping tokens + idx_std += 1 + + continue + + for idx in splits: + splits[idx] = tuple(splits[idx]) # to enable hashable depths for split_map_cache keying + + return splits + + +def get_tokenizer_sequence_mappings(offset_mapping: List[tuple], offset_mapping_std: List[tuple]) -> List[tuple]: + r""" + Greedily determine the one-to-one, one-to-many, or many-to-one token alignments + between input-to-standard tokenizations. + Disallow many-to-many mappings, but allow for right-aligned overlapping tokens. + Args: + offset_mapping (:obj:`List[tuple]`, `required`): + Tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...]. + offset_mapping_std (:obj:`List[tuple]`, `required`): + Standard tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...] + + Returns: + mappings (:obj:`List[tuple]`, `required`): + List of mapping tuples: + [tuple( right_idx, right_idx_std, + segment_count_base, segment_count_std_base, + segment_count_overlap, segment_count_std_overlap), ...] + """ + mappings = [] + + idx = 0 # index of token segment (server tokenization) + idx_std = 0 # index of token segment (standard tokenization) + + right = offset_mapping[idx][1] # first right edge + right_std = offset_mapping_std[idx_std][1] # first std right edge + + segment_count = 1 # keep count of segments traversed, + segment_count_std = 1 # to track one-to-many, many-to-one conditions + segment_count_overlap = 0 # keep count of overlapping segments + segment_count_std_overlap = 0 + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore overlapping tokens + idx += 1 + segment_count_overlap += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore overlapping tokens + idx_std += 1 + segment_count_std_overlap += 1 + + while idx < len(offset_mapping) and idx_std < len(offset_mapping_std): + if right < right_std: + if segment_count == 1 and segment_count_std > 1: + # Examples: [|] edge, [\] next edge, [.] split + # | . 
| \ + # | | | | | + print('Unaligned: Expected an aligned std edge.') + print('idx, idx_std, right, right_std, segment_count, segment_count_std') + print(idx, idx_std, right, right_std, segment_count, segment_count_std) + + idx += 1 + if idx < len(offset_mapping): + right = offset_mapping[idx][1] + segment_count += 1 + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore overlapping tokens + idx += 1 + segment_count_overlap += 1 + + elif right_std < right: + if segment_count_std == 1 and segment_count > 1: + # Examples: [|] edge, [\] next edge, [.] split + # | | | | . | + # | | \ + print('Unaligned: Expected an aligned edge.') + print('idx, idx_std, right, right_std, segment_count, segment_count_std') + print(idx, idx_std, right, right_std, segment_count, segment_count_std) + + idx_std += 1 + if idx_std < len(offset_mapping_std): + right_std = offset_mapping_std[idx_std][1] + segment_count_std += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore overlapping tokens + idx_std += 1 + segment_count_std_overlap += 1 + + else: # right == right_std + mappings += [(idx, idx_std, segment_count, segment_count_std, + segment_count_overlap, segment_count_std_overlap)] + + segment_count_overlap = 0 + segment_count_std_overlap = 0 + + idx += 1 + if idx < len(offset_mapping): + right = offset_mapping[idx][1] + segment_count = 1 + + idx_std += 1 + if idx_std < len(offset_mapping_std): + right_std = offset_mapping_std[idx_std][1] + segment_count_std = 1 + + while (idx + 1 < len(offset_mapping) and + offset_mapping[idx + 1][1] == right): # ignore overlapping tokens + idx += 1 + segment_count_overlap += 1 + + while (idx_std + 1 < len(offset_mapping_std) and + offset_mapping_std[idx_std + 1][1] == right_std): # ignore overlapping tokens + idx_std += 1 + segment_count_std_overlap += 1 + continue + + mappings += [(len(offset_mapping), len(offset_mapping_std), 1, 1, 0, 0)] # validation segment + + return mappings + + +def get_tokenizer_depth_split_map(tokenizer: PreTrainedTokenizerBase, + depths: tuple) -> List[Dict[str, torch.LongTensor]]: + r""" + Split individual token strings at specified depths, retokenize each resulting segment, + keep only the first token of each segment (if there is one). + Purpose is to provide targets for scattering probabilities when a single distribution requires a depth split. + Args: + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Tokenizer. + depths (:obj:`tuple`, `required`): + Tuple of depths at which tokens strings will be split. 
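+                For example, depths=(3,) splits each token string into phrase[:3] and phrase[3:] before retokenizing each part.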
+ + Returns: + split_map (:obj:`List[Dict[str, torch.LongTensor]]`, `required`): + """ + split_map = [] + + phrases = tokenizer.batch_decode(range(tokenizer.vocab_len)) # list of variable len strings (one per token) + + # first part of the phrase up to distance characters + split_phrases = [[phrase[:depths[0]] for phrase in phrases]] + for i in range(len(depths)-1): + # middle parts of the phrase from distance characters to end + split_phrases += [[phrase[depths[i]:depths[i+1]] for phrase in phrases]] + # right part of the phrase from distance characters to end + split_phrases += [[phrase[depths[-1]:] for phrase in phrases]] + + for i, phrases in enumerate(split_phrases): # loop through left, middle, right phrase collections + side_tokens = tokenizer(phrases)['input_ids'] # tokenize phrase collection + tokens_lens = [len(p) for p in side_tokens] # get token lengths of each phrase + from_idx = [i for i, l in enumerate(tokens_lens) if l > 0] # only non-zero len tokens list + first_tokens = [side_tokens[i][0] for i in from_idx] # collect first tokens of each tokenized phrase + # add dict for phrase collection, mapping from original index to first tokens of tokenized phrase substrings + split_map += [{'from': torch.tensor(from_idx, dtype=torch.long), + 'to': torch.tensor(first_tokens, dtype=torch.long)}] + + return split_map + + +def split_probs(probs: torch.FloatTensor, split_map: List[Dict[str, torch.Tensor]]) -> torch.FloatTensor: + r""" + Split a given probability distribution over a tokenizer vocabulary, given a split_map + of mappings from original tokens to target tokens at each depth of the split. + Args: + probs (:obj:`torch.FloatTensor`, `required`): + [vocab_size] Input probability distribution over a tokenizer vocabulary. + split_map (:obj:`List[Dict[str, torch.Tensor]]`, `required`): + A split_map of mappings from original tokens to target tokens at each depth of the split. + + Returns: + new_probs (:obj:`torch.FloatTensor`, `required`): + [splits, vocab_size] A new tensor with resultant probability distribution at each index + of the first dim, representing corresponding split depth. + """ + splits = len(split_map) # how many parts to the depth split map, e.g. left, middle, right parts + vocab_size = probs.shape[0] # retain input vocabulary size + new_probs = torch.zeros((splits, vocab_size)).to(probs.device) # provision prob dist for each part + + for pos in range(splits): # loop through all parts of the split + from_idx = split_map[pos]['from'] # from original string token index + to_idx = split_map[pos]['to'] # to first token index of retokenized part string + new_probs[pos].scatter_add_(0, to_idx, probs[from_idx]) # transfer probabilities to new part distributions + + return new_probs # [splits, vocab_size] + + +def align_tokenizer_sequences(probs: torch.FloatTensor, offset_mapping: List[tuple], offset_mapping_std: List[tuple], + tokenizer: PreTrainedTokenizerBase, + split_map_cache: Dict[tuple, List[Dict[str, torch.Tensor]]], + tokens: torch.LongTensor, tokens_std: torch.LongTensor) -> Tuple[torch.FloatTensor, + List[tuple], + torch.LongTensor]: + r""" + Align an input tokenization distribution to standard tokenization segments by depth-splitting + the input distribution at greedily chosen locations. Prepares the input distribution for mapping to a standard + distribution. + Args: + probs (:obj:`torch.FloatTensor`, `required`): + [sequence_len, vocab_size] Input probability distribution over a tokenizer vocabulary. 
+ offset_mapping (:obj:`List[tuple]`, `required`): + Tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...]. + offset_mapping_std (:obj:`List[tuple]`, `required`): + Standard tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...] + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Source tokenizer. + split_map_cache (:obj:`Dict[tuple, List[Dict[str, torch.Tensor]]]`, `required`): + A dictionary of depths keying split_maps of mappings from original tokens to + target tokens at each depth of the split. + tokens (:obj:`torch.LongTensor`, `required`): + [sequence_len] A sequence of tokens produced by the source tokenizer. + tokens_std (:obj:`torch.LongTensor`, `required`): + [std_sequence_len] A sequence of tokens produced by the standard tokenizer. + + Returns: + aligned_probs (:obj:`torch.FloatTensor`, `required`): + [new_sequence_len, vocab_size] Aligned probability distribution over a tokenizer vocabulary. + aligned_offset_mapping (:obj:`List[tuple]`, `required`): + Tokenizer aligned offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...]. + aligned_tokens (:obj:`torch.LongTensor`, `required`): + A sequence of aligned tokens produced by the source tokenizer. + """ + aligned_tokens = [] # to store new aligned tokens + aligned_probs = [] # to store new aligned probability distributions + aligned_offset_mapping = [] # to store new aligned offset mappings of aligned tokens + splits = get_tokenizer_alignment_splits(offset_mapping, offset_mapping_std) # get necessary token split locations + + prev_idx = 0 + for idx in splits: # each source token index that must be split + depths = splits[idx] # list of depths at which the token string must be split + aligned_probs += [probs[prev_idx:idx]] # retain preceding token probabilities + aligned_offset_mapping += offset_mapping[prev_idx:idx] # retain preceding offset mappings + aligned_tokens += [tokens[prev_idx:idx]] # retain preceding tokens + + if depths not in split_map_cache: + # add depths split to cache to reuse in future (split map calc is relatively time-consuming) + split_map_cache[depths] = get_tokenizer_depth_split_map(tokenizer, depths) + + new_probs = split_probs(probs[idx], split_map_cache[depths]) # [splits, vocab_size] new split probabilities + aligned_probs += [new_probs] + + text_idx = tokenizer.decode(tokens[idx]) + + # === Left part === + new_tokens = tokenizer(text_idx[:depths[0]], add_special_tokens=False, return_tensors='pt')['input_ids'][0] + aligned_tokens += [new_tokens[:1]] + aligned_offset_mapping += [(offset_mapping[idx][0], offset_mapping[idx][0] + depths[0])] + + # === Middle parts === + for d in range(len(depths)-1): + new_tokens = tokenizer(text_idx[depths[d]:depths[d+1]], + add_special_tokens=False, return_tensors='pt')['input_ids'][0] + aligned_tokens += [new_tokens[:1]] + aligned_offset_mapping += [(offset_mapping[idx][0] + depths[d], offset_mapping[idx][0] + depths[d+1])] + + # == Right part === + new_tokens = tokenizer(text_idx[depths[-1]:], add_special_tokens=False, return_tensors='pt')['input_ids'][0] + aligned_tokens += [new_tokens[:1]] + aligned_offset_mapping += [(offset_mapping[idx][0] + depths[-1], offset_mapping[idx][1])] + + prev_idx = idx + 1 + + aligned_probs += [probs[prev_idx:]] # retain remainder of tokens probabilities + aligned_tokens += [tokens[prev_idx:]] # retain remainder of tokens + aligned_offset_mapping += offset_mapping[prev_idx:] # retain remainder of offset mappings + + 
aligned_probs = torch.cat(aligned_probs, dim=0) # [sequence_len, vocab_size] assemble final probability tensor + aligned_tokens = torch.cat(aligned_tokens, dim=0).long() # [sequence_len] assemble final token sequence + + return aligned_probs, aligned_offset_mapping, aligned_tokens + + +def get_translation_map(from_tokenizer: PreTrainedTokenizerBase, + to_tokenizer: PreTrainedTokenizerBase) -> Dict[str, Any]: + r""" + Map individual token phrases from a tokenizer to another tokenizer. + Args: + from_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + From tokenizer. + to_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + To tokenizer. + + Returns: + translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + with source index to target indices. + """ + set_vocab_len(from_tokenizer) + set_vocab_len(to_tokenizer) + + translation_map = {'lengths': {}} + + phrases = from_tokenizer.batch_decode(range(from_tokenizer.vocab_len)) # tokens to strings + + to_tokens = to_tokenizer(phrases)['input_ids'] # convert single token from-phrases to to-tokenization + to_tokens_lens = [len(p) for p in to_tokens] + unique_lens = set(to_tokens_lens) + max_len = max(unique_lens) + counts = torch.zeros((max_len, to_tokenizer.vocab_len), dtype=torch.long) + + for l in unique_lens: # each unique one-to-many mapping length + from_idx = [i for i, k in enumerate(to_tokens_lens) if k == l] # find len l to-tokenizations + subset = [to_tokens[i] for i in from_idx] # find len l to-tokenizations + from_idx = torch.tensor(from_idx, dtype=torch.long) # [subset_size] + to_idx = torch.tensor(subset, dtype=torch.long) # [subset_size, l] + translation_map['lengths'][l] = {'from': from_idx, + 'to': to_idx} + # accumulate counts on tokens, to be used to divide probability mass over its channeled sequences + counts[:l, :].scatter_add_(1, to_idx.T, torch.ones((l, len(subset)), dtype=torch.long)) + + translation_map['counts'] = counts + return translation_map + + +def translate_one_to_many(probs_from: torch.FloatTensor, probs_to: torch.FloatTensor, + translation_map: Dict[str, Any]) -> None: + r""" + Translate a single token probability distribution from a source tokenization to a + sequence of probability distributions over a target tokenization. + Args: + probs_from (:obj:`torch.FloatTensor`, `required`): + [vocab_size] Input probability distribution over a from-tokenizer vocabulary. + probs_to (:obj:`torch.FloatTensor`, `required`): + [many, vocab_size] Output probability distributions over a to-tokenizer vocabulary. + translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + with source index to target indices. 
+ + Returns: + + """ + many_len = probs_to.shape[0] + + # === Unroll single distribution into std sequence === + for i in range(many_len): # each unrolling step + for map_len in translation_map['lengths'].keys(): # each one-to-many mapping length available + if map_len < i + 1: + continue # skip unrolling steps not available in a shorter mapping length + from_idx = translation_map['lengths'][map_len]['from'] + to_idx = translation_map['lengths'][map_len]['to'].T # [map_len, subset_size_std] + probs_to[i, :].scatter_add_(0, to_idx[i, :], probs_from[from_idx]) # add probs in-place + + +def translate_many_to_one(probs_from: torch.FloatTensor, probs_to: torch.FloatTensor, + translation_map: Dict[str, Any]) -> None: + r""" + Translate a sequence of token probability distributions from a source tokenization to a + single token probability distribution over a target tokenization. + Args: + probs_from (:obj:`torch.FloatTensor`, `required`): + [many, vocab_size] Input probability distributions over a from-tokenizer vocabulary. + probs_to (:obj:`torch.FloatTensor`, `required`): + [vocab_size] Output probability distribution over a to-tokenizer vocabulary. + translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + with source index to target indices. + + Returns: + + """ + many_len = probs_from.shape[0] + probs_from_copy = probs_from.clone() # will modify from-probabilities + + # === Spread probability mass over realized sequences === + counts = translation_map['counts'] # [max_len, vocab_size] + translation_max_len = counts.shape[0] # maximum possible many-to-one length available in translation map + + if many_len <= translation_max_len: + probs_from_copy /= counts[:many_len, :] # divide probability mass by amount of paths crossing each token + else: # limit probs_from token depth to max_len + probs_from_copy[:translation_max_len, :] /= counts + + # === Reverse map std token to source sequences, gather avg. sequence prob === + for map_len in translation_map['lengths'].keys(): # mutually exclusive over std tokens + from_idx = translation_map['lengths'][map_len]['from'] # [subset_size_std] one std token + to_idx = translation_map['lengths'][map_len]['to'].T # [map_len, subset_size_std] many server token seq + if many_len < map_len: # sequence beyond segment_count has min probability 0 + to_idx = to_idx[:many_len, :] # [segment_count, subset_size_std] + server_seq_tokens = probs_from_copy.gather(1, to_idx) # [map_len, subset_size_std] gather sequences + probs_to[from_idx] = server_seq_tokens.sum(dim=0) / map_len # [subset_size_std] in-place average approx. + + +def translate_tokenizer_probs(probs: torch.FloatTensor, probs_std: torch.FloatTensor, + offset_mapping: List[tuple], offset_mapping_std: List[tuple], + tokenizer: PreTrainedTokenizerBase, std_tokenizer: PreTrainedTokenizerBase, + split_map_cache: Dict[tuple, List[Dict[str, torch.Tensor]]], + to_translation_map: Dict[str, Any], from_translation_map: Dict[str, Any], + tokens: torch.LongTensor, tokens_std: torch.LongTensor) -> None: + r""" + Translates source token probability distributions to target probability distributions, by + aligning segments through source token splits, then greedily performing one-to-one, + one-to-many, many-to-one distribution mappings. + Args: + probs (:obj:`torch.FloatTensor`, `required`): + [sequence_len, vocab_size] Input probability distribution over a source tokenizer vocabulary. 
+ probs_std (:obj:`torch.FloatTensor`, `required`): + [std_sequence_len, std_vocab_size] Output probability distribution over a target tokenizer vocabulary. + Reference that will be written in-place. + offset_mapping (:obj:`List[tuple]`, `required`): + Tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...]. + offset_mapping_std (:obj:`List[tuple]`, `required`): + Standard tokenizer offset mappings for a specific sequence [(left_0, right_0), (left_1, right_1), ...] + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Source tokenizer. + std_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Standard/target tokenizer. + split_map_cache (:obj:`Dict[tuple, List[Dict[str, torch.Tensor]]]`, `required`): + A dictionary of depths keying split_maps of mappings from original tokens to + target tokens at each depth of the split. Adds split_maps to cache for faster future recall. + tokens (:obj:`torch.LongTensor`, `required`): + [sequence_len] A sequence of tokens produced by the source tokenizer. + tokens_std (:obj:`torch.LongTensor`, `required`): + [std_sequence_len] A sequence of tokens produced by the standard tokenizer. + to_translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + with source index to target indices. + from_translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + from target index to source indices. + + Returns: + + """ + # === Align tokenized sequences via source token splitting === + result = align_tokenizer_sequences(probs, offset_mapping, offset_mapping_std, + tokenizer, split_map_cache, tokens.cpu(), tokens_std.cpu()) + aligned_probs, aligned_offset_mapping, aligned_tokens = result + + # === Get one-to-many / many-to-one mappings === + mappings = get_tokenizer_sequence_mappings(aligned_offset_mapping, offset_mapping_std) + + # === Perform probability mappings === + for (right_idx, right_idx_std, segment_count_base, segment_count_std_base, + segment_count_overlap, segment_count_std_overlap) in mappings[1:]: # don't map start token + + segment_count = segment_count_base + segment_count_overlap # calculate effective segments length + segment_count_std = segment_count_std_base + segment_count_std_overlap # calculate effective segments length + + # === One-to-many / one-to-one mapping === + if segment_count_base == 1: + start_idx_std = right_idx_std - segment_count_std # calculate starting index + + translate_one_to_many(aligned_probs[right_idx-1], + probs_std[start_idx_std:start_idx_std+segment_count_std], + to_translation_map) + + # === Many-to-one mapping === + elif segment_count_std_base == 1: # many-to-one + start_idx = right_idx - segment_count # calculate starting index + + translate_many_to_one(aligned_probs[start_idx:right_idx], + probs_std[right_idx_std-1], + from_translation_map) + + else: + print('Undefined mapping.') + + +def get_top_probs(probs: torch.FloatTensor, tokenizer: PreTrainedTokenizerBase, amount: int = 10) -> str: + r""" + Constructs output string with top amount of highest probability token strings. + Used to display the top probabilities. + Args: + probs (:obj:`torch.FloatTensor`, `required`): + [vocab_size] Probability distribution over a tokenizer vocabulary. + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Tokenizer. 
+ amount: (:obj:`int`, `optional`): + Amount of top tokens to return + + Returns: + string (:obj:`str`, `required`): + Highest probability token strings, prob[token-string] ... + """ + string = '' + + vals, indices = probs.sort(dim=-1, descending=True) # descending sort token probabilities + + for i in range(amount): + string += '%.4f[%s] ' % (vals[i], tokenizer.decode(indices[i])) # prob[token-string] + + return string + + +def translate_logits_to_probs_std(logits: torch.FloatTensor, + offset_mapping: List[List[tuple]], offset_mapping_std: List[List[tuple]], + tokenizer: PreTrainedTokenizerBase, std_tokenizer: PreTrainedTokenizerBase, + split_map_cache: Dict[tuple, List[Dict[str, torch.Tensor]]], + to_translation_map: Dict[str, Any], from_translation_map: Dict[str, Any], + tokens: torch.LongTensor, tokens_std: torch.LongTensor, + skip_equivalent: bool = True) -> torch.FloatTensor: + r""" + Translates source token logit scores to probability distributions over the standard tokenizer. + Args: + logits (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len, vocab_size] Input source logits over a source tokenizer vocabulary. + offset_mapping (:obj:`List[List[tuple]]`, `required`): + Batch of tokenizer offset mappings + [[(left_0, right_0), (left_1, right_1), ...], ...]. + offset_mapping_std (:obj:`List[List[tuple]]`, `required`): + Batch of standard tokenizer offset mappings + [[(left_0, right_0), (left_1, right_1), ...], ...]. + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Source tokenizer. + std_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Standard/target tokenizer. + split_map_cache (:obj:`Dict[tuple, List[Dict[str, torch.Tensor]]]`, `required`): + A dictionary of depths keying split_maps of mappings from original tokens to + target tokens at each depth of the split. Adds split_maps to cache for faster future recall. + tokens (:obj:`torch.LongTensor`, `required`): + [batch_size, sequence_len] A sequence of tokens produced by the source tokenizer. + tokens_std (:obj:`torch.LongTensor`, `required`): + [batch_size, std_sequence_len] A sequence of tokens produced by the standard tokenizer. + to_translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + with source index to target indices. + from_translation_map (:obj:`Dict[str, Any]`, `required`): + Maps for each observed length, a source token to a token sequence of that length, + from target index to source indices. + skip_equivalent (:obj:`bool`, `optional`): + Skips translation if tokenizer and std_tokenizer are equivalent. + + Returns: + probs_std (:obj:`torch.FloatTensor`, `required`): + [batch_size, std_sequence_len, std_vocab_size] Output probability distribution over the + standard tokenizer vocabulary. 
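+
+        Example (illustrative call sketch; assumes the offset mappings, token batches and translation maps were
+        prepared beforehand, e.g. via get_translation_map(tokenizer, std_tokenizer)):
+            probs_std = translate_logits_to_probs_std( logits, offset_mapping, offset_mapping_std,
+                                                       tokenizer, std_tokenizer, split_map_cache,
+                                                       to_translation_map, from_translation_map,
+                                                       tokens, tokens_std )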
+ """ + set_vocab_len(tokenizer) + set_vocab_len(std_tokenizer) + + # === Check tokenizer equivalence / Skip if equivalent === + if skip_equivalent and check_tokenizer_equivalence(tokenizer, std_tokenizer): + logits = logits.to(torch.float).to('cpu') + probs = torch.softmax(logits, dim=2) + return probs + + # === Get shape sizes === + batch_size, sequence_len, vocab_size = logits.shape + std_sequence_len = tokens_std.shape[-1] + std_vocab_size = std_tokenizer.vocab_len + + if tokenizer.vocab_len < vocab_size: + logits = logits[..., :tokenizer.vocab_len] + vocab_size = tokenizer.vocab_len + + # === Convert logits to probabilities === + logits = logits.to(torch.float).to('cpu') + probs = torch.softmax(logits, dim=2) # [batch_size, sequence_len, vocab_size] + + if vocab_size < tokenizer.vocab_len: # fixes bug when model logits output is not full width + padded_probs = torch.zeros((batch_size, sequence_len, tokenizer.vocab_len)) + padded_probs[..., :vocab_size] = probs + probs = padded_probs + + # === Translate to probabilities over standard tokenizer === + probs_std = torch.zeros(batch_size, std_sequence_len, std_vocab_size) + for b in range(batch_size): + probs_b = probs[b][-len(offset_mapping[b]):] # remove left padding + tokens_b = tokens[b][-len(offset_mapping[b]):] # remove left padding + translate_tokenizer_probs(probs_b, probs_std[b], offset_mapping[b], offset_mapping_std[b], + tokenizer, std_tokenizer, + split_map_cache, to_translation_map, from_translation_map, + tokens_b, tokens_std[b]) + + # === Correct excess probability mass (haircut) === + probs_std_sum = probs_std.sum(dim=-1) # [batch_size, std_sequence_len] + over = (probs_std_sum > 1) + probs_std[over] /= probs_std_sum[over][:, None] + + # === Correct deficient probability mass (raise) === + probs_std_sum = probs_std.sum(dim=-1) # [batch_size, std_sequence_len] + under = (probs_std_sum < 1) + probs_std[under] += ((1 - probs_std_sum[under]) / probs_std[under].shape[-1])[:, None] # raise noise floor so sum 1 + + return probs_std # [batch_size, std_sequence_len, std_vocab_size] + + +def topk_token_phrases(logits: torch.Tensor, tokenizer: PreTrainedTokenizerBase, + topk: int, ignore_index: int = -100) -> torch.Tensor: + r""" + Select topk tokenizer logits/phrases and include std_token_phrases counterparts (std_tokenization of token text) + in topk_tensor output of shape [batch_size, (topk + 1), max_len], where max len of all phrase lists + (with prob in front) is max_{b,k}(len([prob_k, tok_0_k, tok_1_k, ...])). + The output topk_tensor also includes a floor_prob for each batch item. The floor probability is the + mean probability of token phrases not captured in topk, required since the tokenizer vocab_size may + not be known to the receiver. + Requires prep_tokenizer(tokenizer, std_tokenizer) to set_std_token_phrases first, to make + std_token_phrases available here. + Args: + logits (:obj:`torch.Tensor`, `required`): + [batch_size, vocab_size] Input source logits for last token over a source tokenizer vocabulary. + tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Source tokenizer (usually server tokenizer) + topk (:obj:`int`, `required`): + Amount of top phrases to expect (to check for mismatch) + ignore_index (:obj:`int`, `optional`): + Padding value to use for unfilled token positions in a shorter token phrase. 
+ 
+ Returns:
+ topk_tensor (:obj:`torch.Tensor`, `required`):
+ [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob
+ in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding.
+ Content structure:
+ [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?],
+ [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?],
+ [...],
+ [prob_floor_b=0, ignore_index, ..., ignore_index]],
+ [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?],
+ [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?],
+ [...],
+ [prob_floor_b=1, ignore_index, ..., ignore_index]],
+ [...]]
+ """
+ # Get shape sizes
+ batch_size, vocab_size = logits.shape # [batch_size, vocab_size] only last token prediction
+ 
+ # Convert logits to probabilities
+ logits = logits.float() # ensure further computations done in float32 for improved precision
+ probs = torch.softmax(logits, dim=1) # [batch_size, vocab_size]
+ 
+ # TopK phrase selection
+ topk_probs, topk_indices = torch.topk(probs, topk) # topk probs and indices: [batch_size, topk]
+ 
+ # === Calculate floor probability ===
+ topk_pmass = topk_probs.sum(dim=-1) # [batch_size] topk probability mass
+ remainder_pmass = torch.clamp(1 - topk_pmass, 1e-40, 1) # [batch_size] remainder probability mass
+ floor_probs = remainder_pmass / (vocab_size - topk) # [batch_size] divide remainder
+ 
+ # convert to list for faster iteration in list comprehension
+ topk_probs_list = topk_probs.tolist()
+ topk_indices_list = topk_indices.tolist()
+ floor_probs_list = floor_probs.tolist()
+ 
+ # === Construct topk phrases list ===
+ probs = [] # collect probability tensors with gradients attached (to be grafted into topk_tensor)
+ phrases = [] # form topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n]
+ 
+ for b in range(batch_size):
+ # collect probability tensors with gradients attached (to be grafted into topk_tensor)
+ probs += [topk_probs[b], floor_probs[b]] # [tensor(prob_k=0_b, prob_k=1_b, ...), tensor(prob_floor_b)]
+ 
+ # form topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n]
+ phrases += [[prob] + tokenizer.std_token_phrases[i]
+ for prob, i in zip(topk_probs_list[b], topk_indices_list[b])] # [prob_k, tok_0_k, tok_1_k, ...]
+ 
+ # also add prob_floor for batch item
+ phrases += [[floor_probs_list[b]]] # [prob_floor_b]
+ 
+ # determine width of topk_tensor as max len of all phrase lists (with prob in front)
+ max_len = max([len(p) for p in phrases]) # max_{b,k}(len([prob_k, tok_0_k, tok_1_k, ...]))
+ 
+ # form single 2D tensor with all phrases and probs (typically to send to axon wire encoding)
+ topk_tensor = torch.tensor([p + [ignore_index] * (max_len - len(p))
+ for p in phrases]).to(logits.device) # [batch_size * (topk + 1), max_len]
+ 
+ # grafting probability tensors into first column to attach gradients
+ topk_tensor[:, 0] = torch.hstack(probs) # tensor([prob_k=0_b, prob_k=1_b, ..., prob_floor_b])
+ 
+ topk_tensor = topk_tensor.reshape(batch_size, topk + 1, max_len) # [batch_size, (topk + 1), max_len] reshaped
+ 
+ return topk_tensor # [batch_size, (topk + 1), max_len] (probability gradients attached in first column)
+ 
+ 
+def compact_topk_token_phrases(topk_tensor: torch.Tensor):
+ r"""
+ Compact 3D topk_tensor [batch_size, (topk + 1), max_len] by removing ignore_index padding, and also offset
+ tokens by 2 to preserve [0, 1] for probabilities to allow for proper unraveling demarcated by
+ probability boundaries.
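+ For example (illustrative values), a phrase row [0.0173, 11, 257, -100] is offset to [0.0173, 13, 259, -98],
+ so that in the flattened output any value <= 1 can only be a probability marker, while the
+ (still negative) ignore_index padding is dropped by the compaction.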
+ Args: + topk_tensor (:obj:`torch.Tensor`, `required`): + [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob + in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding. + Content structure: + [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?], + [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?], + [...], + [prob_floor_b=0, ignore_index, ..., ignore_index]], + [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?], + [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?], + [...], + [prob_floor_b=1, ignore_index, ..., ignore_index]], + [...]] + + Returns: + compact_topk (:obj:`torch.Tensor`, `required`): + [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1), + since 2 * topk + 1: topk x [probability, token sequence (at least one token)] + + floor probability (rest). + Content structure: + [prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., prob_k=1_b=0, tok_0_k=1_b=0, ..., prob_floor_b=0, + prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., prob_k=1_b=1, tok_0_k=1_b=1, ..., prob_floor_b=1, + ...] + """ + topk_tensor_offset = topk_tensor.clone() # assume topk_tensor may be reused elsewhere so clone + topk_tensor_offset[:, :, 1:] += 2 # add 2 to token ids to preserve [0, 1] for probabilities (in first column) + + flattened = topk_tensor_offset.flatten() # [batch_size * (topk + 1) * max_len] 1D tensor + compact_topk = flattened[flattened > -1] # remove ignore_index < -1 padding to compact content + + return compact_topk # [>= batch_size * (2 * topk + 1)] + + +def unravel_topk_token_phrases(compact_topk: torch.Tensor, topk: int, ignore_index: int = -100) -> torch.Tensor: + r""" + Unravel topk token phrases input_tensor from 1-D to [batch_size, (topk + 1), max_len] topk_tensor, which + includes topk token probabilities (prob_k) + floor_prob in first column with gradients attached, with + std_tokens in remaining columns with ignore_index padding. + Args: + compact_topk (:obj:`torch.Tensor`, `required`): + [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1), + since 2 * topk + 1: topk x [probability, token sequence (at least one token)] + + floor probability (rest). + Content structure: + [prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., prob_k=1_b=0, tok_0_k=1_b=0, ..., prob_floor_b=0, + prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., prob_k=1_b=1, tok_0_k=1_b=1, ..., prob_floor_b=1, + ...] + topk (:obj:`int`, `required`): + Amount of top phrases to expect (to check for mismatch) + ignore_index (:obj:`int`, `optional`): + Padding value to use for unfilled token positions in a shorter token phrase. + Returns: + topk_tensor (:obj:`torch.Tensor`, `required`): + [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob + in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding. 
+ Content structure: + [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?], + [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?], + [...], + [prob_floor_b=0, ignore_index, ..., ignore_index]], + [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?], + [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?], + [...], + [prob_floor_b=1, ignore_index, ..., ignore_index]], + [...]] + """ + + # Find probability markers (per batch item: topk phrase probabilities + floor_prob) + prob_idx = torch.where(compact_topk <= 1.5)[0] # 0 <= prob <= 1 [batch_size * (topk + 1)], expect token_ids >= 2 + + batch_size = len(prob_idx) // (topk + 1) # (batch_size * (topk + floor)) / (topk + floor) + assert batch_size * (topk + 1) == len(prob_idx), f'{batch_size} * ({topk} + 1) != {len(prob_idx)}' # decoding irregularity otherwise + + # split into topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n] + phrases = [s.tolist() for s in torch.tensor_split(compact_topk, prob_idx)] # tolist for faster list comprehension + phrases = phrases[1:] # ignore first (empty) split + + # determine width of topk_tensor as max len of all phrase lists (with prob in front) + max_len = max([len(p) for p in phrases]) # max_{b,k}(len([prob_k, tok_0_k, tok_1_k, ...])) + + ignore_index_2 = ignore_index + 2 # increment with 2, as decrement with 2 follows + + # form single 2D tensor with topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n] + topk_tensor = torch.tensor([p + [ignore_index_2] * (max_len - len(p)) + for p in phrases]).to(compact_topk.device) # [batch_size * (topk + 1), max_len] + topk_tensor -= 2 # remove token offset + + # grafting probability tensors into first column to attach gradients + topk_tensor[:, 0] = compact_topk[prob_idx] # tensor([prob_k=0_b, prob_k=1_b, ..., prob_floor_b]) + + topk_tensor = topk_tensor.reshape(batch_size, topk + 1, max_len) # [batch_size, (topk + 1), max_len] reshaped + + return topk_tensor # [batch_size, (topk + 1), max_len] + + +def phrase_cross_entropy(target_phrases: Union[List[List[int]], torch.Tensor], + topk_tensor: torch.Tensor, + ignore_index: int = -100, reduce=True, reduction='mean', + vocab_size_min: int = 50257) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Calculates the cross entropy of a phrase prediction against a target phrase, so that this is a multi-token + extension of typical cross entropy calculated for next token prediction. + Args: + target_phrases (:obj:`List[List[int]]`, `required`): + [batch_size, *] Target phrases in standard token sequence list. + topk_tensor (:obj:`torch.Tensor`, `required`): + [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob + in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding. + Content structure: + [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?], + [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?], + [...], + [prob_floor_b=0, ignore_index, ..., ignore_index]], + [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?], + [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?], + [...], + [prob_floor_b=1, ignore_index, ..., ignore_index]], + [...]] + ignore_index (:obj:`int`, `optional`): + Padding value to use for unfilled token positions in a shorter token phrase. + reduce (:obj:`bool`, `optional`): + Whether to reduce the cross entropy over the batch dimension. 
+ reduction (:obj:`str`, `optional`):
+ Reduction function to perform when reduce is True.
+ vocab_size_min (:obj:`int`, `optional`):
+ Minimum server vocab_size expected, should be set to nominal 50257,
+ used to prevent the floor_probs from being too large.
+ Returns:
+ loss_val (:obj:`torch.Tensor`, `required`):
+ Validation cross entropy loss, either scalar if reduce or [batch_size].
+ loss (:obj:`torch.Tensor`, `required`):
+ Phrase cross entropy loss, either scalar if reduce or [batch_size].
+ """
+ 
+ batch_size, topk_p1, max_len = topk_tensor.shape # [batch_size, (topk + 1), max_len]
+ topk = topk_p1 - 1
+ 
+ topk_tokens = topk_tensor[:, :-1, 1:] # [batch_size, topk, max_len - 1] Phrase tokens with ignore_index token for padding.
+ topk_probs = topk_tensor[:, :-1, 0] # [batch_size, topk] Probabilities for each phrase in topk
+ floor_probs = topk_tensor[:, -1, 0] # [batch_size] Floor probabilities as mean probability for non-topk tokens
+ 
+ # === Ensure total probability is 1 ===
+ total_probs = topk_probs.sum(dim=-1) + max(0, vocab_size_min - topk) * floor_probs # [batch_size] total probs
+ n_topk_probs = topk_probs / total_probs[:, None] # [batch_size, topk] normalized topk_probs
+ n_floor_probs = floor_probs / total_probs # [batch_size] normalized floor_probs
+ 
+ val_probs = torch.zeros(batch_size).to(topk_probs.device) # accumulate probabilities when first tokens match
+ match_probs = torch.zeros(batch_size).to(topk_probs.device) # accumulate probabilities when sub target matches phrase
+ for b in range(batch_size):
+ target_phrase = target_phrases[b]
+ if not isinstance(target_phrase, torch.Tensor):
+ target_phrase = torch.tensor(target_phrases[b])
+ 
+ match = (topk_tokens[b, :, 0] == target_phrase[0].item()) # bool where first tokens match (validation token)
+ if match.sum() > 0:
+ val_probs[b] = n_topk_probs[b, match].sum() # accumulate all matches
+ else: # no matches
+ val_probs[b] = n_floor_probs[b] # assume match is in non-topk tokens with avg floor_prob
+ 
+ # === Integrate sub target matches ===
+ check_len = min(max_len - 1, len(target_phrase))
+ for c in range(1, check_len + 1): # progressively increase sub target length
+ target = ignore_index * torch.ones(check_len, dtype=torch.int32).to(topk_tensor.device) # [-100, ..., -100]
+ target[:c] = target_phrase[:c] # [tok0, tok1, ...tokc, -100, ..., -100]
+ 
+ # Find sub target matches
+ match = (topk_tokens[b, :, :check_len] == target)
+ match_idx = torch.where(match.sum(dim=-1) == check_len)[0] # phrase indices which match sub target
+ 
+ if len(match_idx): # at least one match
+ match_probs[b] += n_topk_probs[b, match_idx].sum() # accumulate all matches
+ else: # no matches
+ match_probs[b] += n_floor_probs[b] # assume match is in non-topk tokens with avg floor_prob
+ 
+ val_probs = torch.clamp(val_probs, 0, 1) # [batch_size] ensure 0 <= total probability <= 1
+ loss_val = - torch.log(val_probs + 1e-40) # [batch_size] calculate cross entropy loss
+ 
+ match_probs = torch.clamp(match_probs, 0, 1) # [batch_size] ensure 0 <= total probability <= 1
+ loss = - torch.log(match_probs + 1e-40) # [batch_size] calculate cross entropy loss
+ 
+ if reduce:
+ if not hasattr(loss_val, reduction) or not hasattr(loss, reduction):
+ raise RuntimeError(f'phrase_cross_entropy(): Reduction function {reduction} not found.')
+ loss_val = getattr(loss_val, reduction)()
+ loss = getattr(loss, reduction)()
+ if loss.numel() > 1:
+ raise ValueError(f'phrase_cross_entropy(): Expected reduction to scalar, obtained {loss.shape} instead.')
+ 
+ return loss_val, 
loss + + +def check_tokenizer_equivalence(tokenizer_to_check: PreTrainedTokenizerBase, + target_tokenizer: PreTrainedTokenizerBase) -> bool: + r""" + Is tokenizer_to_check equivalent to target_tokenizer? + Args: + tokenizer_to_check (:obj:`PreTrainedTokenizerBase`, `required`): + Tokenizer to check for equivalence. + target_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + Target tokenizer to check equivalence against. + + Returns: + result (:obj:`bool`, `required`) + """ + set_vocab_len(tokenizer_to_check) + set_vocab_len(target_tokenizer) + + if tokenizer_to_check.vocab_len != target_tokenizer.vocab_len: + return False + + to_check_vocab = tokenizer_to_check.batch_decode(range(tokenizer_to_check.vocab_len)) + target_vocab = target_tokenizer.batch_decode(range(target_tokenizer.vocab_len)) + + return to_check_vocab == target_vocab # indexed tokenizer vocabularies should match + + +def pad_offsets(offsets_batch: List[List[tuple]], source_offsets_batch: List[List[List[Any]]], + pad_offsets_batch: List[List[List[Any]]]) -> List[List[List[Any]]]: + r""" + Pads specific tuples in offsets_batch, selected by source_offsets_batch with + associated paddings in pad_offsets_batch. + Purpose is typically to add padding to align two tokenization offsets at special tokens. + Args: + offsets_batch (:obj:`List[List[tuple]]`, `required`): + Batch of full input tokenizer offset mappings to be used for alteration + [[(left_0, right_0), (left_1, right_1), ...], ...]. + source_offsets_batch (:obj:`List[List[List[Any]]]`, `required`): + Batch of tokenizer offset mappings indicating replacement tuples in offsets_batch + [[(left_0, right_0), (left_1, right_1), ...], ...]. + pad_offsets_batch (:obj:`List[List[List[Any]]]`, `required`): + Batch of offset paddings associated with each source_offsets_batch replacement tuple + [[(left_pad_0, right_pad_0), (left_pad_1, right_pad_1), ...], ...]. + + Returns: + new_offsets_batch (:obj:`List[List[List[Any]]]`, `required`): + Batch of padded full input tokenizer offset mappings + [[(left_0, right_0), (left_1, right_1), ...], ...]. + """ + new_offsets_batch = [] + batch_len = len(offsets_batch) + + for b in range(batch_len): + new_offsets = [] + pad = 0 + + idx = 0 + for left, right in offsets_batch[b]: # go through original offsets + if idx < len(source_offsets_batch[b]): + source_left, source_right = source_offsets_batch[b][idx] + if left == source_left and right == source_right: # matching offset found + pad_left, pad_right = pad_offsets_batch[b][idx] + new_offsets += [(pad_left + pad, pad_right + pad)] # replace offsets with padded + accum. pad + pad += pad_right - right + idx += 1 + continue + new_offsets += [(left + pad, right + pad)] # adjust original offsets w/ accum. pad + + new_offsets_batch += [new_offsets] + + return new_offsets_batch + + +def find_offsets(string: str, substring: str) -> List[List[int]]: + r""" + Finds all the [start, end] offsets of substring in string. + Assumes there is no overlap of substring, nor recursive overlap. + Args: + string (:obj:`str`, `required`): + Main string to find offsets in. + substring (:obj:`str`, `required`): + Substring to search for in string. + + Returns: + offsets (:obj:`List[List[int]]`, `required`): + Offsets denoting the [start, end] positions of substring in string. 
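+ 
+ Example (illustrative):
+ >>> find_offsets('the cat the dog', 'the')
+ [[0, 3], [8, 11]]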
+ """ + offsets = [] + idx = string.find(substring) # find first instance + while idx != -1: # found an instance + offsets += [[idx, idx + len(substring)]] # add offsets + idx = string.find(substring, idx + len(substring)) # find next instance + + return offsets + + +def replace_at_offsets(string: str, offsets: List[List[Any]]) -> Tuple[str, List[List[int]]]: + r""" + Replace indicated [left, right] offset positions with a new substring, by + deleting [left, right] content and adding [left, left+len(substring)] substring, + adjusting offsets incrementally. + Assumes an incremental ordered, non-overlapping list of offsets, constructing + the new string incrementally and recording new offsets. + Args: + string (:obj:`str`, `required`): + Main string to perform replacements for. + offsets (:obj:`List[List[Any]]`, `required`): + Offsets where replacements are made with replacement substring + [[left_0, right_0, substring_0], ...] + + Returns: + new_string (:obj:`str`, `required`): + New string where replacements were made. + new_offsets (:obj:`List[List[Any]]`, `required`): + New offsets where replacements are now located + [[left_0, right_0], [left_1, right_1], ...] + """ + new_string = '' + new_offsets = [] + + prev = 0 + for left, right, substring in offsets: + new_string += string[prev:left] # retain preceding string + new_left = len(new_string) # advance index + + new_string += substring # add new substring + new_right = len(new_string) + + new_offsets += [[new_left, new_right]] # add offsets + + prev = right # advance index + + new_string += string[prev:] + + return new_string, new_offsets + + +def get_special_token_pairings(from_tokenizer: PreTrainedTokenizerBase, + to_tokenizer: PreTrainedTokenizerBase) -> Dict[str, str]: + r""" + Determines a prioritized matching of special token texts between two tokenizers. + Purpose is to produce replacement pairs so special token test is correctly represented for target tokenizer. + Args: + from_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + From tokenizer. + to_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`): + To tokenizer. + + Returns: + pairings (:obj:`Dict[str, str]`, `required`): + Prioritized dictionary of From_special_token_text -> To_special_token_text. + """ + pairings = {} + + # some tokenizers e.g. GPT2 have the same text signifying BOS and EOS, while in other e.g. XGLM they differ + # so prioritize EOS token first, since this seems to be the default context separator, e.g. XGLM, GerPT2, GPT2 + if ('eos_token' in from_tokenizer.special_tokens_map) and ('eos_token' in to_tokenizer.special_tokens_map): + pairings[getattr(from_tokenizer, 'eos_token')] = getattr(to_tokenizer, 'eos_token') + + for special_token in from_tokenizer.special_tokens_map: + if special_token in to_tokenizer.special_tokens_map: + if getattr(from_tokenizer, special_token) not in pairings: # prevent priority overwrite + pairings[getattr(from_tokenizer, special_token)] = getattr(to_tokenizer, special_token) + + return pairings + + +def translate_special_token_text(text_batch: List[str], from_tokenizer: PreTrainedTokenizerBase, + to_tokenizer: PreTrainedTokenizerBase) -> Tuple[List[str], + List[List[List[int]]], + List[List[List[int]]], + List[List[List[Any]]]]: + r""" + Translates special_token signifier text in from_tokenizer to to_tokenizer special_token text, for + a given text_batch. 
The resulting to_text_batch can then be tokenized with to_tokenizer, where special_tokens should
+ map to their single corresponding tokens, despite the signifier text differing from from_tokenizer.
+ Args:
+ text_batch (:obj:`List[str]`, `required`):
+ List of strings to translate special tokens for.
+ from_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`):
+ From tokenizer.
+ to_tokenizer (:obj:`PreTrainedTokenizerBase`, `required`):
+ To tokenizer.
+ 
+ Returns:
+ to_text_batch (:obj:`List[str]`, `required`):
+ List of strings where special text has been replaced.
+ from_offsets_batch (:obj:`List[List[List[int]]]`, `required`):
+ Batch of tokenizer offset mappings selecting replacement tuples in from_tokenizer text
+ [[(left_0, right_0), (left_1, right_1), ...], ...].
+ to_offsets_batch (:obj:`List[List[List[int]]]`, `required`):
+ Batch of tokenizer offset mappings selecting replacement tuples in to_tokenizer text
+ [[(left_0, right_0), (left_1, right_1), ...], ...].
+ pad_offsets_batch (:obj:`List[List[List[Any]]]`, `required`):
+ Batch of offset paddings associated with each replacement tuple
+ [[(left_pad_0, right_pad_0), (left_pad_1, right_pad_1), ...], ...].
+ """
+ to_text_batch = []
+ from_offsets_batch = []
+ to_offsets_batch = []
+ pad_offsets_batch = []
+ 
+ # === Get special-token text replacement pairs ===
+ pairings = get_special_token_pairings(from_tokenizer, to_tokenizer)
+ 
+ for text in text_batch:
+ from_offsets = []
+ padding_offsets = []
+ for token_string in pairings:
+ offsets = find_offsets(text, token_string) # find special-token locations
+ from_offsets += [[left, right, pairings[token_string]] for left, right in offsets]
+ 
+ pad_string = token_string if len(token_string) > len(pairings[token_string]) else pairings[token_string]
+ padding_offsets += [[left, right, pad_string] for left, right in offsets]
+ 
+ from_offsets = sorted(from_offsets) # incrementally arrange locations
+ to_text, to_offsets = replace_at_offsets(text, from_offsets) # replace special-token text
+ pad_text, padding_offsets = replace_at_offsets(text, padding_offsets) # pad special-token text locations
+ 
+ to_text_batch += [to_text]
+ from_offsets_batch += [[[left, right] for left, right, _ in from_offsets]]
+ to_offsets_batch += [to_offsets]
+ pad_offsets_batch += [padding_offsets]
+ 
+ return to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch
+ 
+ 
+def set_vocab_len(tokenizer: PreTrainedTokenizerBase):
+ r"""
+ Sets the tokenizer.vocab_len if unset, to store the real vocabulary size according to the vocab or encoder.
+ Args:
+ tokenizer (:obj:`PreTrainedTokenizerBase`, `required`):
+ Tokenizer to set vocab_len for.
+ Returns:
+ 
+ """
+ if not hasattr(tokenizer, 'vocab_len'):
+ if hasattr(tokenizer, 'vocab'): # use independent vocab_len when tokenizer.vocab_size != len(tokenizer.vocab)
+ tokenizer.vocab_len = len(tokenizer.vocab)
+ elif hasattr(tokenizer, 'encoder'): # tokenizers like facebook/opt-* have encoder=vocab
+ tokenizer.vocab_len = len(tokenizer.encoder)
+ else: # revert to vocab_size
+ tokenizer.vocab_len = tokenizer.vocab_size
+ 
+ 
+def set_whitespace_preserving(tokenizer: PreTrainedTokenizerBase):
+ r"""
+ Sets the tokenizer.whitespace_preserving if unset, indicating whether the tokenizer preserves whitespace
+ (like GPT) or not (like BERT).
+ Args:
+ tokenizer (:obj:`PreTrainedTokenizerBase`, `required`):
+ Tokenizer to set whitespace_preserving for. 
+ Returns: + + """ + if not hasattr(tokenizer, 'whitespace_preserving'): + space_token = tokenizer(' ', add_special_tokens=False)['input_ids'] + space_text = tokenizer.decode(space_token) + if space_text == ' ': + tokenizer.whitespace_preserving = True + else: + tokenizer.whitespace_preserving = False + + +def set_std_token_phrases(tokenizer, std_tokenizer): + r""" + Sets std_token_phrases which are the tokenizer token strings tokenized with std_tokenizer, so + the std_tokenizer equivalent of the tokenizer token strings. + Used for converting model predictions/logits into std_tokenizer representations, for example in TextCausalLMNext. + Args: + tokenizer(:obj:`PreTrainedTokenizerBase`, `required`): + Tokenizer to set std_token_phrases for. + std_tokenizer(:obj:`PreTrainedTokenizerBase`, `required`): + Standard bittensor tokenizer to convert to. + + Returns: + + """ + # === Tokenizer phrases to memory === + if not hasattr(tokenizer, 'phrases'): + if tokenizer.whitespace_preserving: + tokenizer.phrases = tokenizer.batch_decode(range(tokenizer.vocab_len)) # server tokens to strings + else: + tokenizer.phrases = [' ' + phrase for phrase in + tokenizer.batch_decode(range(tokenizer.vocab_len))] # server tokens to strings + + if not hasattr(tokenizer, 'std_token_phrases'): + # Retokenize phrases to new tokenizer + tokenizer.std_token_phrases = std_tokenizer(tokenizer.phrases)['input_ids'] # [topk, max_len] convert phrases to tokens sequences + + +def prep_tokenizer(tokenizer, std_tokenizer=None): + tokenizer.padding_side = "left" # Generative default expects most recent token on right-hand side with padding on left. https://github.com/huggingface/transformers/pull/10552 + # tokenizer.add_prefix_space = False + # tokenizer.add_special_tokens({'bos_token': "[BOS]"}) # A special token representing the beginning of a sentence. + # tokenizer.add_special_tokens({'eos_token': "[EOS]"}) # A special token representing the end of a sentence. + # tokenizer.add_special_tokens({'unk_token': "[UNK]"}) # A special token representing an out-of-vocabulary token. + # tokenizer.add_special_tokens({'sep_token': "[SEP]"}) # A special token separating two different sentences in the same input (used by BERT for instance) + # tokenizer.add_special_tokens({'pad_token': "[PAD]"}) # A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. + # tokenizer.add_special_tokens({'cls_token': "[CLS]"}) # A special token representing the class of the input (used by BERT for instance). + # tokenizer.add_special_tokens({'mask_token': "[MASK]"}) # A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). 
+ # additional_special_tokens = [ + # "NOTUSED", # Used by BARThez + # "NOTUSED", # Used by BARThez + # "", # Used by MarianMT + # "", # Used by MarianMT + # "", # Used by Transformer XL + # "" # Used by Pegasus + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # "", # Used by XLM + # ] + # tokenizer.additional_special_tokens = additional_special_tokens + + # Define PAD Token = EOS Token (GPT2 generate convention, when PAD Token is None) + # https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + if tokenizer.pad_token is None and tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + set_vocab_len(tokenizer) + set_whitespace_preserving(tokenizer) + + if std_tokenizer is not None: + set_std_token_phrases(tokenizer, std_tokenizer) + + return tokenizer diff --git a/miners/advanced_server.py b/miners/advanced_server.py deleted file mode 100644 index 2b466238f2..0000000000 --- a/miners/advanced_server.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" Run the advanced server - -Example: - $ python3 miners/advanced_server.py ... args - -""" - -import bittensor -if __name__ == "__main__": - template = bittensor.neurons.advanced_server.neuron().run() \ No newline at end of file diff --git a/miners/template_miner.py b/miners/core_server.py similarity index 89% rename from miners/template_miner.py rename to miners/core_server.py index 49c5590586..60e4b4579a 100644 --- a/miners/template_miner.py +++ b/miners/core_server.py @@ -15,13 +15,13 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -""" Run the template miner +""" Run the template server Example: - $ python3 miners/template_miner.py + $ python3 miners/core_server.py ... 
args """ import bittensor if __name__ == "__main__": - template = bittensor.neurons.template_miner.neuron().run() \ No newline at end of file + template = bittensor.neurons.core_server.neuron().run() diff --git a/miners/multitron_server.py b/miners/multitron_server.py deleted file mode 100644 index 8e9d3615ba..0000000000 --- a/miners/multitron_server.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/python3 -# The MIT License (MIT) -# Copyright © 2021 Yuma Rao - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. -""" Run the multitron server - -Example: - $ python3 miners/multitron_server.py ... args - -Testing: - $ python3 miners/multitron_server.py --subtensor.network mock --neuron.model_name albert-base-v1 --logging.debug --wallet.name mock --dataset._mock - -""" - -import bittensor -if __name__ == "__main__": - template = bittensor.neurons.multitron_server.neuron().run() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1d52cc96c3..1e2367e28e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,9 +35,10 @@ rich retry requests>=2.25.0 scalecodec>=1.0.35 +sentencepiece termcolor -torch==1.11 -transformers>=4.5.0 +torch>=1.11 +transformers>=4.20.1 numpy wheel codecov diff --git a/sample_configs/advanced_server_sample_config.txt b/sample_configs/advanced_server_sample_config.txt deleted file mode 100644 index 40f6453894..0000000000 --- a/sample_configs/advanced_server_sample_config.txt +++ /dev/null @@ -1,55 +0,0 @@ -axon.backward_timeout: 20 -axon.forward_timeout: 10 -axon.ip: '[::]' -axon.max_workers: 10 -axon.maximum_concurrent_rpcs: 400 -axon.port: 8091 -axon.priority.max_workers: 10 -axon.priority.maxsize: -1 - -dataset.batch_size: 10 -dataset.block_size: 20 -dataset.data_dir: ~/.bittensor/data/ -dataset.dataset_name: default -dataset.num_batches: 500 -dataset.max_datasets: 3 -dataset.no_tokenizer: false -dataset.num_workers: 0 -dataset.save_dataset: false - -logging.debug: false -logging.logging_dir: ~/.bittensor/miners -logging.record_log: false -logging.trace: false - -neuron.blacklist.stake.backward: 100 -neuron.blacklist.stake.forward: 10 -neuron.blacklist.time: 2 -neuron.blocks_per_epoch: 2 -neuron.checking: true -neuron.clip_gradients: 1.0 -neuron.device: cpu -neuron.inter_degree: nearest -neuron.interpolate: true -neuron.learning_rate: 0.01 -neuron.model_name: gpt2 -neuron.momentum: 0.8 -neuron.name: advanced_server -neuron.restart: false -neuron.padding: true -neuron.pretrained: true - -subtensor.chain_endpoint: null -subtensor.network: nakamoto - 
-wallet.hotkey: default -wallet.name: default -wallet.path: ~/.bittensor/wallets/ - -wandb.api_key: default -wandb.directory: default -wandb.name: default -wandb.offline: false -wandb.project: default -wandb.run_group: default -wandb.tags: default diff --git a/sample_configs/template_server_sample_config.txt b/sample_configs/core_server_sample_config.txt similarity index 100% rename from sample_configs/template_server_sample_config.txt rename to sample_configs/core_server_sample_config.txt diff --git a/sample_configs/core_validator_sample_config.txt b/sample_configs/core_validator_sample_config.txt index 24ffc2bef7..c96d614389 100644 --- a/sample_configs/core_validator_sample_config.txt +++ b/sample_configs/core_validator_sample_config.txt @@ -18,16 +18,16 @@ logging.logging_dir: ~/.bittensor/miners logging.record_log: false logging.trace: false -neuron.name = core_validator -neuron.learning_rate = 0.1 -neuron.momentum = 0.8 -neuron.blocks_per_epoch = 1000 -neuron.epochs_until_reset = 10 -neuron.n_topk_peer_weights = 250 -neuron.device = cpu -neuron.clip_gradients = 1.0 -neuron.restart_on_failure = True -neuron._mock = False +neuron.name: core_validator +neuron.learning_rate: 0.1 +neuron.momentum: 0.8 +neuron.blocks_per_epoch: 1000 +neuron.epochs_until_reset: 10 +neuron.n_topk_peer_weights: 250 +neuron.device: cpu +neuron.clip_gradients: 1.0 +neuron.restart_on_failure: True +neuron._mock: False nucleus.dropout: 0.2 nucleus.importance: 0.01 diff --git a/sample_configs/template_miner_sample_config.txt b/sample_configs/template_miner_sample_config.txt deleted file mode 100644 index 935e6b4e2e..0000000000 --- a/sample_configs/template_miner_sample_config.txt +++ /dev/null @@ -1,72 +0,0 @@ -axon.backward_timeout: 20 -axon.forward_timeout: 10 -axon.ip: '[::]' -axon.max_workers: 10 -axon.maximum_concurrent_rpcs: 400 -axon.port: 8091 -axon.priority.max_workers: 10 -axon.priority.maxsize: -1 - -dataset.batch_size: 10 -dataset.block_size: 20 -dataset.data_dir: ~/.bittensor/data/ -dataset.dataset_name: default -dataset.num_batches: 500 -dataset.max_datasets: 3 -dataset.no_tokenizer: false -dataset.num_workers: 0 -dataset.save_dataset: false - -dendrite.max_active_receptors: 500 -dendrite.max_worker_threads: 150 -dendrite.requires_grad: true -dendrite.timeout: 12 - -logging.debug: false -logging.logging_dir: ~/.bittensor/miners -logging.record_log: false -logging.trace: false - -neuron.accumulate_remote_gradients: false -neuron.batch_size_train: 2 -neuron.blacklist: 0 -neuron.blacklist_allow_non_registered: true -neuron.clip_gradients: 1.0 -neuron.compute_remote_gradients: false -neuron.device: cpu -neuron.epoch_length: 100 -neuron.learning_rate: 1 -neuron.learning_rate_chain: 1 -neuron.momentum: 0.8 -neuron.n_epochs: 9223372036854775807 -neuron.n_topk_peer_weights: 100 -neuron.name: template_neuron -neuron.restart: false -neuron.restart_on_failure: true -neuron.sync_block_time: 100 -neuron.timeout: 10 -neuron.use_upnpc: false -neuron.use_wandb: false -neuron.weight_decay: 0.25 - -nucleus.dropout: 0.2 -nucleus.nhead: 2 -nucleus.nhid: 200 -nucleus.nlayers: 2 -nucleus.punishment: 0.001 -nucleus.topk: 20 - -subtensor.chain_endpoint: null -subtensor.network: nakamoto - -wallet.hotkey: default -wallet.name: default -wallet.path: ~/.bittensor/wallets/ - -wandb.api_key: default -wandb.directory: default -wandb.name: default -wandb.offline: false -wandb.project: default -wandb.run_group: default -wandb.tags: default \ No newline at end of file diff --git a/tests/integration_tests/constant.py 
b/tests/integration_tests/constant.py index e60e05a6d9..b03365ea4e 100644 --- a/tests/integration_tests/constant.py +++ b/tests/integration_tests/constant.py @@ -5,4 +5,10 @@ 'dataset_name': ["Books3"], 'num_batches': 10 } +) + +synapse = Munch().fromDict( + { + 'num_to_generate': 70, + } ) \ No newline at end of file diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index d05ae77632..6c147338b3 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -114,7 +114,7 @@ def test_check_configs(self): "set_weights", "inspect"] config = self.config config.no_prompt = True - config.model = "template_miner" + config.model = "core_server" config.dest = "no_prompt" config.amount = 1 config.mnemonic = "this is a mnemonic" @@ -1102,7 +1102,7 @@ def test_stake( self ): config.stake_all = False config.no_password = True - config.model = "template_miner" + config.model = "core_server" cli = bittensor.cli(config) cli.run() @@ -1116,7 +1116,7 @@ def test_new_coldkey( self ): config.amount = 1 config.dest = "no_prompt" config.subtensor._mock = True - config.model = "template_miner" + config.model = "core_server" config.n_words = 12 config.use_password = False config.no_prompt = True @@ -1135,7 +1135,7 @@ def test_new_hotkey( self ): config.subtensor.network = "mock" config.dest = "no_prompt" config.subtensor._mock = True - config.model = "template_miner" + config.model = "core_server" config.n_words = 12 config.use_password = False config.no_prompt = True @@ -1152,7 +1152,7 @@ def test_regen_coldkey( self ): config.subtensor.network = "mock" config.dest = "no_prompt" config.subtensor._mock = True - config.model = "template_miner" + config.model = "core_server" config.mnemonic = "faculty decade seven jelly gospel axis next radio grain radio remain gentle" config.seed = None config.n_words = 12 @@ -1163,6 +1163,21 @@ def test_regen_coldkey( self ): cli = bittensor.cli(config) cli.run() + def test_regen_coldkeypub( self ): + config = self.config + config.wallet.name = "regen_coldkeypub_testwallet" + config.command = "regen_coldkeypub" + config.subtensor.network = "mock" + config.subtensor._mock = True + config.ss58_address = "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm" + config.public_key = None + config.use_password = False + config.no_prompt = True + config.overwrite_coldkeypub = True + + cli = bittensor.cli(config) + cli.run() + def test_regen_hotkey( self ): config = self.config config.wallet.name = "regen_hotkey_testwallet" @@ -1170,7 +1185,7 @@ def test_regen_hotkey( self ): config.amount = 1 config.subtensor.network = "mock" config.subtensor._mock = True - config.model = "template_miner" + config.model = "core_server" config.mnemonic = "faculty decade seven jelly gospel axis next radio grain radio remain gentle" config.n_words = 12 config.use_password = False diff --git a/tests/integration_tests/test_dendrite.py b/tests/integration_tests/test_dendrite.py index d766c2ffb1..df0920441b 100644 --- a/tests/integration_tests/test_dendrite.py +++ b/tests/integration_tests/test_dendrite.py @@ -18,7 +18,8 @@ import torch import pytest import bittensor -from multiprocessing import Pool +from bittensor._proto.bittensor_pb2 import UnknownException +from . 
import constant wallet = bittensor.wallet.mock() dendrite = bittensor.dendrite( wallet = wallet ) @@ -33,11 +34,24 @@ modality = 0 ) +synapses = [bittensor.synapse.TextLastHiddenState(), + bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextSeq2Seq(num_to_generate=constant.synapse.num_to_generate)] +dataset = bittensor.dataset(num_batches=20, dataset_name = ['Books3']) + +def check_resp_shape(resp, num_resp, block_size, seq_len): + assert len(resp) == num_resp + assert list(resp[0][0].shape) == [block_size, seq_len, bittensor.__network_dim__] + assert list(resp[0][1].shape) == [block_size, seq_len, bittensor.__vocab_size__] + assert list(resp[0][2].shape) == [block_size, (synapses[2].topk + 1), 1 + 1] + assert list(resp[0][3].shape) == [block_size, constant.synapse.num_to_generate] + def test_dendrite_forward_text_endpoints_tensor(): endpoints = neuron_obj.to_tensor() x = torch.tensor( [[ 1,2,3 ], [ 1,2,3 ]] ) - resp1, _, _ = dendrite.forward_text( endpoints, x ) - assert list(torch.stack(resp1, dim=0).shape) == [1, 2, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = endpoints, inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 2, seq_len = 3 ) assert dendrite.stats.total_requests == 1 dendrite.to_wandb() @@ -46,8 +60,8 @@ def test_dendrite_forward_text_multiple_endpoints_tensor(): endpoints_2 = neuron_obj.to_tensor() endpoints = torch.stack( [endpoints_1, endpoints_2], dim=0) x = torch.tensor( [[ 1,2,3 ], [ 1,2,3 ]] ) - resp1, _, _ = dendrite.forward_text( endpoints, x ) - assert list(torch.stack(resp1, dim=0).shape) == [2, 2, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = endpoints, inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 2, seq_len = 3 ) def test_dendrite_forward_text_multiple_endpoints_tensor_list(): endpoints_1 = neuron_obj.to_tensor() @@ -55,190 +69,120 @@ def test_dendrite_forward_text_multiple_endpoints_tensor_list(): endpoints_3 = neuron_obj.to_tensor() endpoints = [torch.stack( [endpoints_1, endpoints_2], dim=0), endpoints_3] x = torch.tensor( [[ 1,2,3 ], [ 1,2,3 ]] ) - resp1, _, _ = dendrite.forward_text( endpoints, x ) - assert list(torch.stack(resp1, dim=0).shape) == [3, 2, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = endpoints, inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 3, block_size = 2, seq_len = 3 ) def test_dendrite_forward_text_singular(): x = torch.tensor( [[ 1,2,3 ], [ 1,2,3 ]] ) - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [1, 2, 3, bittensor.__network_dim__] - resp2, _, _ = dendrite.forward_text( [neuron_obj], [x] ) - assert list(torch.stack(resp2, dim=0).shape) == [1, 2, 3, bittensor.__network_dim__] - resp3, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp3, dim=0).shape) == [2, 2, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 2, seq_len = 3 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 2, seq_len = 3 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 2, seq_len = 3 ) + with pytest.raises(ValueError): - 
dendrite.forward_text( [neuron_obj, neuron_obj], [x] ) + dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = [x], synapses = synapses ) def test_dendrite_forward_text_singular_no_batch_size(): x = torch.tensor( [ 1,2,3 ] ) - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [1, 1, 3, bittensor.__network_dim__] - resp2, _, _ = dendrite.forward_text( [neuron_obj], [x] ) - assert list(torch.stack(resp2, dim=0).shape) == [1, 1, 3, bittensor.__network_dim__] - resp3, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp3, dim=0).shape) == [2, 1, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 1, seq_len = 3 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 1, seq_len = 3 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 1, seq_len = 3 ) + with pytest.raises(ValueError): - dendrite.forward_text( [neuron_obj, neuron_obj], [x] ) + dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = [x], synapses = synapses ) def test_dendrite_forward_text_tensor_list_singular(): x = [ torch.tensor( [ 1,2,3 ] ) for _ in range(2) ] with pytest.raises(ValueError): - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - resp1, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [2, 1, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 1, seq_len = 3 ) def test_dendrite_forward_text_tensor_list(): x = [ torch.tensor( [[ 1,2,3 ], [ 1,2,3 ]] ) for _ in range(2) ] with pytest.raises(ValueError): - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - resp1, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [2, 2, 3, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 2, seq_len = 3 ) def test_dendrite_forward_text_singular_string(): x = "the cat" - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [1, 1, 2, bittensor.__network_dim__] - resp2, _, _ = dendrite.forward_text( [neuron_obj], [x] ) - assert list(torch.stack(resp2, dim=0).shape) == [1, 1, 2, bittensor.__network_dim__] - resp3, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp3, dim=0).shape) == [2, 1, 2, bittensor.__network_dim__] - resp4, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], [x] ) - assert list(torch.stack(resp4, dim=0).shape) == [2, 1, 2, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 1, seq_len = 2 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses ) + check_resp_shape(resp, num_resp = 1, 
block_size = 1, seq_len = 2 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 1, seq_len = 2 ) + + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = [x], synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 1, seq_len = 2 ) def test_dendrite_forward_text_list_string(): x = ["the cat", 'the dog', 'the very long sentence that needs to be padded'] - resp1, _, _ = dendrite.forward_text( [neuron_obj], x ) - assert list(torch.stack(resp1, dim=0).shape) == [1, 3, 9, bittensor.__network_dim__] - resp2, _, _ = dendrite.forward_text( [neuron_obj, neuron_obj], x ) - assert list(torch.stack(resp2, dim=0).shape) == [2, 3, 9, bittensor.__network_dim__] + resp, _, _ = dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 1, block_size = 3, seq_len = 9 ) -def test_dendrite_forward_tensor_shape_error(): - x = torch.rand(3, 3, 3, dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_tensor( [neuron_obj], [x]) + resp, _, _ = dendrite.text( endpoints = [neuron_obj, neuron_obj], inputs = x, synapses = synapses ) + check_resp_shape(resp, num_resp = 2, block_size = 3, seq_len = 9 ) -def test_dendrite_forward_image_shape_error(): +def test_dendrite_forward_tensor_shape_error(): x = torch.rand(3, 3, 3, dtype=torch.float32) with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], [x]) - -def test_dendrite_forward_text_shape_error(): - x = torch.zeros((3, 3, 3), dtype=torch.int64) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], [x]) + dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses) def test_dendrite_forward_tensor_type_error(): x = torch.zeros(3, 3, bittensor.__network_dim__, dtype=torch.int32) with pytest.raises(ValueError): - dendrite.forward_tensor( [neuron_obj], x) - -def test_dendrite_forward_image_type_error(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.int64) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], x) - -def test_dendrite_forward_text_type_error(): - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], x) + dendrite.text( endpoints = [neuron_obj], inputs = x, synapses = synapses) def test_dendrite_forward_tensor_endpoint_type_error(): x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) with pytest.raises(ValueError): - dendrite.forward_tensor( [dict()], [x]) - -def test_dendrite_forward_image_endpoint_type_error(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_image( [dict()], [x]) - -def test_dendrite_forward_text_endpoint_type_error(): - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - with pytest.raises(ValueError): - dendrite.forward_image( [dict()], [x]) + dendrite.text( endpoints = [dict()], inputs = [x], synapses = synapses) def test_dendrite_forward_tensor_endpoint_len_error(): x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) with pytest.raises(ValueError): - dendrite.forward_tensor( [], [x]) - -def test_dendrite_forward_image_endpoint_len_error(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_image( [], [x]) - -def test_dendrite_forward_text_endpoint_len_error(): - x = 
torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - with pytest.raises(ValueError): - dendrite.forward_image( [], [x]) + dendrite.text( endpoints = [], inputs = [x], synapses = synapses) def test_dendrite_forward_tensor_input_len_error(): x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) with pytest.raises(ValueError): - dendrite.forward_tensor( [neuron_obj], []) - -def test_dendrite_forward_image_input_len_error(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], []) - -def test_dendrite_forward_text_input_len_error(): - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], []) - + dendrite.text( endpoints = [neuron_obj], inputs = [], synapses = synapses) def test_dendrite_forward_tensor_mismatch_len_error(): x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) with pytest.raises(ValueError): - dendrite.forward_tensor( [neuron_obj], [x,x]) - -def test_dendrite_forward_image_mismatch_len_error(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], [x,x]) - -def test_dendrite_forward_text_mismatch_len_error(): - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - with pytest.raises(ValueError): - dendrite.forward_image( [neuron_obj], [x,x]) + dendrite.text( endpoints = [neuron_obj], inputs = [x, x], synapses = synapses) def test_dendrite_forward_text_non_list(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - out, ops, times = dendrite.forward_text( neuron_obj, x) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [2, 4, bittensor.__network_dim__] - -def test_dendrite_forward_image_non_list(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - out, ops, times = dendrite.forward_image( neuron_obj, x) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [1, bittensor.__network_dim__] - -def test_dendrite_forward_tensor_non_list(): - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) - out, ops, times = dendrite.forward_tensor( neuron_obj, x) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [3, bittensor.__network_dim__] - + out, ops, times = dendrite.text( endpoints = neuron_obj, inputs = x, synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) + check_resp_shape(out, 1,2,4) def test_dendrite_forward_text(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - out, ops, times = dendrite.forward_text( [neuron_obj], [x]) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [2, 4, bittensor.__network_dim__] - -def test_dendrite_forward_image(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - out, ops, times = dendrite.forward_image( [neuron_obj], [x]) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [1, 1, bittensor.__network_dim__] + out, ops, times = dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) + check_resp_shape(out, 1,2,4) def test_dendrite_forward_tensor(): - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) - out, ops, times = 
dendrite.forward_tensor( [neuron_obj], [x]) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [3, 3, bittensor.__network_dim__] + x = torch.rand(3, 3, dtype=torch.float32) + out, ops, times = dendrite.text( endpoints = [neuron_obj], inputs = [x], synapses = synapses) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) + check_resp_shape(out, 1, 3, 3) def test_dendrite_backoff(): _dendrite = bittensor.dendrite( wallet = wallet ) @@ -255,10 +199,10 @@ def test_dendrite_backoff(): print (_endpoint_obj) # Normal call. - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) - out, ops, times = _dendrite.forward_tensor( [_endpoint_obj], [x]) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable - assert list(out[0].shape) == [3, 3, bittensor.__network_dim__] + x = torch.rand(3, 3, dtype=torch.float32) + out, ops, times = _dendrite.text( endpoints = [_endpoint_obj], inputs = [x], synapses = synapses) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) + check_resp_shape(out, 1, 3, 3) del _dendrite def test_dendrite_multiple(): @@ -290,34 +234,26 @@ def test_dendrite_multiple(): dend3 = bittensor.dendrite( wallet = wallet, multiprocess=True) dend4 = bittensor.dendrite( wallet = wallet, multiprocess=True) - out, ops, times = dend1.forward_text( endpoint_obj, x ) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable + out, ops, times = dend1.text( endpoints = endpoint_obj, inputs = x, synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) - out, ops, times = dend2.forward_text( endpoint_obj, x ) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable + out, ops, times = dend2.text( endpoints = endpoint_obj, inputs = x, synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) - out, ops, times = dend3.forward_text( endpoint_obj, x ) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable + out, ops, times = dend3.text( endpoints = endpoint_obj, inputs = x, synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) - out, ops, times = dend4.forward_text( endpoint_obj, x ) - assert ops[0].item() == bittensor.proto.ReturnCode.Unavailable + out, ops, times = dend4.text( endpoints = endpoint_obj, inputs = x, synapses = synapses ) + assert list(ops[0]) == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) assert len(receptor_pool.receptors) == 1 - assert manager_server.connected_count == 4 - dend4.__del__() - assert manager_server.connected_count == 3 - dend3.__del__() - assert manager_server.connected_count == 2 - dend2.__del__() - assert manager_server.connected_count == 1 - dend1.__del__() @@ -326,7 +262,171 @@ def test_dendrite_to_df(): def test_dend_del(): dendrite.__del__() + +def test_successful_synapse(): + wallet = bittensor.wallet() + def forward_generate( inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], synapse.num_to_generate) + + def forward_hidden_state( inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + def forward_casual_lm(inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + + def forward_casual_lm_next(inputs_x, synapse, model_output=None): + return None, None, 
synapse.nill_forward_response_tensor(inputs_x) + + axon = bittensor.axon ( + port = 8096, + ip = '0.0.0.0', + wallet = wallet, + ) + + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + axon.start() + + endpoint = bittensor.endpoint( + version = bittensor.__version_as_int__, + uid = 0, + hotkey = wallet.hotkey.ss58_address, + ip = '0.0.0.0', + ip_type = 4, + port = 8096, + modality = 0, + coldkey = wallet.coldkeypub.ss58_address + ) + + dendrite = bittensor.dendrite() + inputs = next(dataset) + synapses = [bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), bittensor.synapse.TextSeq2Seq(num_to_generate=20)] + + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + axon.stop() + + print(codes) + assert list(codes[0]) == [bittensor.proto.ReturnCode.Success] * len(synapses) + +def test_failing_synapse(): + wallet = bittensor.wallet() + def faulty( inputs_x, synapse, model_output = None): + raise UnknownException + + def forward_hidden_state( inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + def forward_casual_lm(inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + + def forward_casual_lm_next(inputs_x, synapse, model_output=None): + return None, None, synapse.nill_forward_response_tensor(inputs_x) + + axon = bittensor.axon ( + port = 8097, + ip = '0.0.0.0', + wallet = wallet, + ) + + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + axon.attach_synapse_callback(faulty, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + axon.start() + + endpoint = bittensor.endpoint( + version = bittensor.__version_as_int__, + uid = 0, + hotkey = wallet.hotkey.ss58_address, + ip = '0.0.0.0', + ip_type = 4, + port = 8097, + modality = 0, + coldkey = wallet.coldkeypub.ss58_address + ) + + dendrite = bittensor.dendrite() + inputs = next(dataset) + synapses = [bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), bittensor.synapse.TextSeq2Seq(num_to_generate=20)] + + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == 
[bittensor.proto.ReturnCode.UnknownException, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.UnknownException, bittensor.proto.ReturnCode.UnknownException, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback(faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + return_tensors, codes, times = dendrite.text(endpoints=endpoint, inputs=inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) + +def test_missing_synapse(): + wallet = bittensor.wallet() + def forward_hidden_state( inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + def forward_casual_lm(inputs_x, synapse, model_output = None): + return None, None, torch.rand(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + + def forward_casual_lm_next(inputs_x, synapse, model_output=None): + return None, None, synapse.nill_forward_response_tensor(inputs_x) + + axon = bittensor.axon ( + port = 8098, + ip = '0.0.0.0', + wallet = wallet, + ) + + axon.start() + + endpoint = bittensor.endpoint( + version = bittensor.__version_as_int__, + uid = 0, + hotkey = wallet.hotkey.ss58_address, + ip = '0.0.0.0', + ip_type = 4, + port = 8098, + modality = 0, + coldkey = wallet.coldkeypub.ss58_address + ) + + dendrite = bittensor.dendrite() + inputs = next(dataset) + synapses = [bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), bittensor.synapse.TextSeq2Seq(num_to_generate=20)] + + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.NotImplemented] * len(synapses) + + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.NotImplemented, + bittensor.proto.ReturnCode.NotImplemented, bittensor.proto.ReturnCode.NotImplemented] + + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs = inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.NotImplemented, bittensor.proto.ReturnCode.NotImplemented] + + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + return_tensors, codes, times = dendrite.text( endpoints=endpoint, inputs=inputs, synapses=synapses) + assert list(codes[0]) == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.NotImplemented] + +def test_clear(): + dataset.close() if __name__ == "__main__": bittensor.logging(debug = True) - 
test_dendrite_multiple() \ No newline at end of file + test_dendrite_forward_tensor() \ No newline at end of file diff --git a/tests/integration_tests/test_server_compression.py b/tests/integration_tests/test_server_compression.py index c65f73da5f..ebc5f405e6 100644 --- a/tests/integration_tests/test_server_compression.py +++ b/tests/integration_tests/test_server_compression.py @@ -227,7 +227,7 @@ def sign(wallet): inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) -serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) +serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) diff --git a/tests/integration_tests/test_wallet.py b/tests/integration_tests/test_wallet.py index 2c82d8d6a7..505b1cb631 100644 --- a/tests/integration_tests/test_wallet.py +++ b/tests/integration_tests/test_wallet.py @@ -36,28 +36,32 @@ def init_wallet(): return the_wallet -def check_keys_exists(the_wallet = None): - - # --- test file and key exists - assert os.path.isfile(the_wallet.coldkey_file.path) - assert os.path.isfile(the_wallet.hotkey_file.path) - assert os.path.isfile(the_wallet.coldkeypub_file.path) - - assert the_wallet._hotkey != None - assert the_wallet._coldkey != None - - # --- test _load_key() - the_wallet._hotkey = None - the_wallet._coldkey = None - the_wallet._coldkeypub = None - - the_wallet.hotkey - the_wallet.coldkey - the_wallet.coldkeypub - - assert the_wallet._hotkey != None - assert the_wallet._coldkey != None - assert the_wallet._coldkeypub != None +def check_keys_exists(the_wallet = None, coldkey_exists = True, hotkey_exists = True, coldkeypub_exists = True): + if coldkey_exists: + # --- test file and key exists + assert os.path.isfile(the_wallet.coldkey_file.path) + assert the_wallet._coldkey != None + # --- test _load_key() + the_wallet._coldkey = None + the_wallet.coldkey + assert the_wallet._coldkey != None + + if hotkey_exists: + # --- test file and key exists + assert os.path.isfile(the_wallet.hotkey_file.path) + assert the_wallet._hotkey != None + # --- test _load_key() + the_wallet._hotkey = None + the_wallet.hotkey + assert the_wallet._hotkey != None + + if coldkeypub_exists: + # --- test file and key exists + assert os.path.isfile(the_wallet.coldkeypub_file.path) + # --- test _load_key() + the_wallet._coldkeypub = None + the_wallet.coldkeypub + assert the_wallet._coldkeypub != None def test_create_wallet(): the_wallet = init_wallet().create(coldkey_use_password = False, hotkey_use_password = False) @@ -108,6 +112,19 @@ def test_wallet_mnemonic_create(): the_wallet.regen_hotkey( mnemonic = "solve arrive guilt syrup dust sea used phone flock vital narrow endorse".split(), use_password=False, overwrite = True ) check_keys_exists(the_wallet) +def test_wallet_coldkeypub_create(): + the_wallet = init_wallet() + public_key_hex_str = "0x32939b6abc4d81f02dff04d2b8d1d01cc8e71c5e4c7492e4fa6a238cdca3512f" + the_wallet.regenerate_coldkeypub( public_key = public_key_hex_str, overwrite = True ) + check_keys_exists(the_wallet, coldkey_exists=False, hotkey_exists=False) # Don't check the coldkey or hotkey + assert the_wallet.coldkeypub.ss58_address == "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm" + + the_wallet = init_wallet() + ss58_address = "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm" + the_wallet.regenerate_coldkeypub( ss58_address = ss58_address, overwrite = True ) 
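+        # Only the coldkeypub should be restored here; regenerating from an ss58 address cannot recreate the coldkey or hotkey files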
+ check_keys_exists(the_wallet, coldkey_exists=False, hotkey_exists=False) # Don't check the coldkey or hotkey + assert the_wallet.coldkeypub.ss58_address == "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm" + def test_wallet_add_stake(): subtensor = bittensor.subtensor(network = 'nobunaga') diff --git a/tests/unit_tests/bittensor_tests/test_axon.py b/tests/unit_tests/bittensor_tests/test_axon.py index 7777d6d32e..20040e0b3f 100644 --- a/tests/unit_tests/bittensor_tests/test_axon.py +++ b/tests/unit_tests/bittensor_tests/test_axon.py @@ -25,9 +25,18 @@ import bittensor from bittensor.utils.test_utils import get_random_unused_port +import concurrent wallet = bittensor.wallet.mock() axon = bittensor.axon(wallet = wallet) +bittensor.logging(debug = True) +""" +TODO: Tests that need to be added + - Different synapses in combination + - Different errors for different synapses + - Correct Messages when only a single synapse fails +""" + def sign(wallet): nounce = str(int(time.time() * 1000)) @@ -43,458 +52,739 @@ def test_sign(): def test_forward_wandb(): inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, tensors=[inputs_serialized] ) - response, code, call_time, message = axon._forward( request ) - axon.update_stats_for_request( request, response, call_time, code ) + response, code, synapses = axon._forward( request ) + #axon.update_stats_for_request( request, response, call_time, code ) print( axon.to_wandb() ) def test_forward_not_implemented(): - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - + synapses = [bittensor.synapse.TextLastHiddenState()] + request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - tensors=[inputs_serialized] + tensors=[inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ], + hotkey = axon.wallet.hotkey.ss58_address, ) - response, code, call_time, message = axon._forward( request ) - assert code == bittensor.proto.ReturnCode.NotImplemented + response, code, synapses = axon._forward( request ) + assert synapses[0].return_code == bittensor.proto.ReturnCode.NotImplemented + +def test_forward_last_hidden_success(): + def forward( inputs_x: torch.FloatTensor, synapse , model_output = None): + return None, dict(), torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__]) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = 
bittensor.__version_as_int__, + tensors=[inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ], + hotkey = axon.wallet.hotkey.ss58_address, + ) + response, code, synapses = axon._forward( request ) + assert code == bittensor.proto.ReturnCode.Success + assert synapses[0].return_code == bittensor.proto.ReturnCode.Success -def test_forward_tensor_success(): - def forward( inputs_x: torch.FloatTensor): - return torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__]) - axon.attach_forward_callback( forward, modality=2) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) +def test_forward_causallm_success(): + def forward( inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, dict(), torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__]) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextCausalLM()] + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - tensors=[inputs_serialized] + tensors=[inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ], + hotkey = axon.wallet.hotkey.ss58_address, ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.Success -def test_forward_tensor_success_image(): - def forward( inputs_x: torch.FloatTensor): - return torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__]) - axon.attach_forward_callback( forward, modality=1) - inputs_raw = torch.rand(1,1,1, 1, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) +def test_forward_causallmnext_success(): + def forward(inputs_x: torch.FloatTensor, synapse, model_output=None): # [batch_size, (topk + 1), max_len] + return None, dict(), torch.zeros([inputs_x.shape[0], (synapse.topk + 1), 1 + 1]) + axon.attach_synapse_callback(forward, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer(serializer_type=bittensor.proto.Serializer.MSGPACK) + synapses = [bittensor.synapse.TextCausalLMNext()] + inputs_serialized = serializer.serialize(inputs_raw, from_type=bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version=bittensor.__version_as_int__, + tensors=[inputs_serialized], + synapses=[syn.serialize_to_wire_proto() for syn in synapses], + hotkey=axon.wallet.hotkey.ss58_address, + ) + response, code, synapses = axon._forward(request) + assert code == bittensor.proto.ReturnCode.Success + +def test_forward_seq_2_seq_success(): + def forward( inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, dict(), torch.zeros( [inputs_x.shape[0], 
synapse.num_to_generate]) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextSeq2Seq()] + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - tensors=[inputs_serialized] + tensors=[inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ], + hotkey = axon.wallet.hotkey.ss58_address, ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.Success def test_forward_empty_request(): - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, tensors=[] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.EmptyRequest def test_forward_deserialization_error(): x = dict() # Not tensors that can be deserialized. 
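+    # The request must also carry a wire-serialized synapse list; the un-deserializable tensor should be rejected before any synapse callback runs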
+ synapses = [bittensor.synapse.TextLastHiddenState()] request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ x ] + tensors=[ x ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestDeserializationException def test_forward_batch_shape_error(): - inputs_raw = torch.rand(0, 1, 1) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) + inputs_raw = torch.rand(0, 1) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized ] + tensors=[ inputs_serialized ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException def test_forward_seq_shape_error(): inputs_raw = torch.rand(1, 0, 1) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized ] + tensors=[ inputs_serialized ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException -def test_forward_text_shape_error(): +def test_forward_last_hidden_shape_error(): inputs_raw = torch.rand(1, 1, 1) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized ] + tensors=[ inputs_serialized ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException -def test_forward_image_shape_error(): +def test_forward_causallm_shape_error(): inputs_raw = torch.rand(1, 1, 1) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = 
serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) + synapses = [bittensor.synapse.TextCausalLM()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized ] + tensors=[ inputs_serialized ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException -def test_forward_tensor_shape_error(): - inputs_raw = torch.rand(1, 1, 1, 1) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) +def test_forward_causallmnext_shape_error(): + inputs_raw = torch.rand(1, 1, 1) + synapses = [bittensor.synapse.TextCausalLMNext()] + serializer = bittensor.serializer(serializer_type=bittensor.proto.Serializer.MSGPACK) + inputs_serialized = serializer.serialize(inputs_raw, modality=bittensor.proto.Modality.TEXT, + from_type=bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version=bittensor.__version_as_int__, + hotkey=axon.wallet.hotkey.ss58_address, + tensors=[inputs_serialized], + synapses=[syn.serialize_to_wire_proto() for syn in synapses] + ) + response, code, synapses = axon._forward(request) + assert code == bittensor.proto.ReturnCode.RequestShapeException + +def test_forward_seq_2_seq_shape_error(): + inputs_raw = torch.rand(1, 1, 1) + synapses = [bittensor.synapse.TextSeq2Seq()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized ] + tensors=[ inputs_serialized ], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException + def test_forward_deserialization_empty(): - def forward( inputs_x: torch.Tensor): - return None - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + def forward( inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, dict(), None + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized 
= serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[inputs_serialized] + tensors=[inputs_serialized], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.EmptyResponse def test_forward_response_deserialization_error(): - def forward( inputs_x: torch.Tensor): - return dict() - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + def forward( inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, dict(), dict() + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - + request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[inputs_serialized] + tensors=[inputs_serialized], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) - assert code == bittensor.proto.ReturnCode.ResponseDeserializationException -def test_forward_tensor_exception(): - def forward( inputs_x: torch.FloatTensor): + def check(a, b, c): + pass + + with mock.patch('bittensor.TextLastHiddenState.check_forward_response_tensor', new=check): + response, code, synapses = axon._forward( request ) + assert code == bittensor.proto.ReturnCode.ResponseSerializationException + +def test_forward_last_hidden_state_exception(): + def forward( inputs_x: torch.FloatTensor , synapse , model_output = None): if inputs_x.size() == (1,1,1): return None else: raise Exception('Mock') - axon.attach_forward_callback( forward, modality=2) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, tensors=[inputs_serialized], - hotkey= '123' + hotkey= '123', + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.UnknownException -def test_forward_tensor_timeout(): - def forward( inputs_x: torch.FloatTensor): +def test_forward_causal_lm_state_exception(): + def forward( inputs_x: torch.FloatTensor , synapse, model_output 
= None): if inputs_x.size() == (1,1,1): return None else: - raise TimeoutError('Timeout') - axon.attach_forward_callback( forward, modality=2) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + raise Exception('Mock') + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextCausalLM()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, tensors=[inputs_serialized], - hotkey= '123' + hotkey= '123', + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) + response, code, synapses = axon._forward( request ) + assert code == bittensor.proto.ReturnCode.UnknownException - response, code, call_time, message = axon._forward( request ) +def test_forward_causal_lm_next_state_exception(): + def forward(inputs_x: torch.FloatTensor, synapse, model_output=None): + if inputs_x.size() == (1, 1, 1): + return None + else: + raise Exception('Mock') + axon.attach_synapse_callback(forward, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextCausalLMNext()] + serializer = bittensor.serializer(serializer_type=bittensor.proto.Serializer.MSGPACK) + inputs_serialized = serializer.serialize(inputs_raw, modality=bittensor.proto.Modality.TENSOR, + from_type=bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version=bittensor.__version_as_int__, + tensors=[inputs_serialized], + hotkey='123', + synapses=[syn.serialize_to_wire_proto() for syn in synapses] + ) + response, code, synapses = axon._forward(request) + assert code == bittensor.proto.ReturnCode.UnknownException + +def test_forward_seq_2_seq_state_exception(): + def forward( inputs_x: torch.FloatTensor , synapse, model_output = None): + if inputs_x.size() == (1,1,1): + return None + else: + raise Exception('Mock') + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextSeq2Seq()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized], + hotkey= '123', + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] + ) + response, code, synapses = axon._forward( request ) + assert code == bittensor.proto.ReturnCode.UnknownException + +def test_forward_seq_2_seq_success(): + def forward( inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, dict(), torch.zeros( [inputs_x.shape[0], synapse.num_to_generate]) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextSeq2Seq()] + inputs_serialized = 
serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ], + hotkey = axon.wallet.hotkey.ss58_address, + ) + response, code, synapses = axon._forward( request ) + assert code == bittensor.proto.ReturnCode.Success + +def test_forward_joint_success(): + def forward_generate( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros( (inputs_x.shape[0], synapse.num_to_generate) ) + def forward_causal_lm( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + def forward_causal_lm_next(inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], synapse.topk + 1, 1 + 1) + def forward_hidden_state( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros( inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + axon.attach_synapse_callback( forward_causal_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + axon.attach_synapse_callback(forward_causal_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextCausalLM(), bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextSeq2Seq()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized] * len(synapses), + hotkey= axon.wallet.hotkey.ss58_address, + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] + ) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success] * len(synapses) + +def test_forward_joint_missing_synapse(): + def forward_generate( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros( (inputs_x.shape[0], synapse.num_to_generate) ) + def forward_causal_lm( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + def forward_causal_lm_next(inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], synapse.topk + 1, 1 + 1) + def forward_hidden_state( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros( inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + axon.attach_synapse_callback( forward_causal_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + axon.attach_synapse_callback(forward_causal_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + 
axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextCausalLM(), bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextSeq2Seq()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized] * len(synapses), + hotkey= axon.wallet.hotkey.ss58_address, + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] + ) + axon.attach_synapse_callback( None, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.NotImplemented] + + axon.attach_synapse_callback( None, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.NotImplemented, bittensor.proto.ReturnCode.NotImplemented] + + axon.attach_synapse_callback(None, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + response, code, synapses = axon._forward(request) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.NotImplemented, + bittensor.proto.ReturnCode.NotImplemented, bittensor.proto.ReturnCode.NotImplemented] + + axon.attach_synapse_callback( None, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.NotImplemented] * len(synapses) + +def test_forward_joint_faulty_synapse(): + def faulty( inputs_x: torch.FloatTensor , synapse, model_output = None): + raise Exception + def forward_causal_lm( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], inputs_x.shape[1], bittensor.__vocab_size__) + def forward_causal_lm_next(inputs_x: torch.FloatTensor, synapse, model_output = None): + return None, None, torch.zeros(inputs_x.shape[0], synapse.topk + 1, 1 + 1) + def forward_hidden_state( inputs_x: torch.FloatTensor , synapse, model_output = None): + return None, None, torch.zeros( inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__) + + axon.attach_synapse_callback( forward_causal_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + axon.attach_synapse_callback(forward_causal_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + + inputs_raw = torch.rand(3, 3) + synapses = [bittensor.synapse.TextCausalLM(), bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextLastHiddenState(), bittensor.synapse.TextSeq2Seq()] + serializer = bittensor.serializer( 
serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized] * len(synapses), + hotkey= axon.wallet.hotkey.ss58_address, + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] + ) + + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.UnknownException, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback(faulty, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) + response, code, synapses = axon._forward(request) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException, + bittensor.proto.ReturnCode.UnknownException, bittensor.proto.ReturnCode.UnknownException] + + axon.attach_synapse_callback( faulty, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM) + response, code, synapses = axon._forward( request ) + assert [syn.return_code for syn in synapses] == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) + +def test_forward_timeout(): + def forward( inputs_x: torch.FloatTensor, synapses, hotkey): + if inputs_x[0].size() == (3,3): + return None + else: + raise concurrent.futures.TimeoutError('Timeout') + + axon.attach_forward_callback( forward) + + inputs_raw = torch.rand(1,1) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + request = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + tensors=[inputs_serialized], + hotkey = axon.wallet.hotkey.ss58_address, + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] + ) + + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.Timeout def test_forward_unknown_error(): - def forward( inputs_x: torch.FloatTensor,modality): + def forward( inputs_x: torch.FloatTensor,modality, model_output = None): raise Exception('Unknown') - with mock.patch.object(axon, '_call_forward', new=forward): - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + + with mock.patch.object(axon, 'forward_callback', new=forward): + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, 
tensors=[inputs_serialized], - hotkey= '123' + hotkey= '123', + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._forward( request ) + response, code, synapses = axon._forward( request ) assert code == bittensor.proto.ReturnCode.UnknownException #--- backwards --- def test_backward_invalid_request(): - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, tensors=[inputs_serialized] ) - response, code, call_time, message = axon._backward( request ) + response, code, synapses = axon._backward( request ) assert code == bittensor.proto.ReturnCode.InvalidRequest -def test_backward_response_not_implemented(): - inputs_raw = torch.rand(1, 1, 1) - grads_raw = torch.rand(1, 1, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - request = bittensor.proto.TensorMessage( - version=bittensor.__version_as_int__, - hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized, grads_serialized] - ) - response, code, call_time, message = axon._backward( request ) - assert code == bittensor.proto.ReturnCode.NotImplemented - def test_backward_deserialization_error(): x = dict() # Not tensors that can be deserialized. 
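+    # Both the inputs and the gradients below are garbage, so backward-request deserialization should fail before any callback runs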
g = dict() + synapses = [bittensor.synapse.TextLastHiddenState()] request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ x, g] + tensors=[ x, g], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._backward( request ) + response, code, synapses = axon._backward( request ) assert code == bittensor.proto.ReturnCode.RequestDeserializationException + assert synapses[0].return_code == bittensor.proto.ReturnCode.RequestDeserializationException -def test_backward_text_shape_error(): +def test_backward_last_hidden_shape_error(): inputs_raw = torch.rand(1, 1, 1) grads_raw = torch.rand(1, 1, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) + synapses = [bittensor.synapse.TextLastHiddenState()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized, grads_serialized] + tensors=[ inputs_serialized, grads_serialized], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._backward( request ) + response, code, synapses = axon._backward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException + assert synapses[0].return_code == bittensor.proto.ReturnCode.RequestShapeException -def test_backward_image_shape_error(): +def test_backward_causal_lm_shape_error(): inputs_raw = torch.rand(1, 1, 1) - grads_raw = torch.rand(1, 1, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) + grads_raw = torch.rand(1, 1, bittensor.__vocab_size__) + synapses = [bittensor.synapse.TextCausalLM()] + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH) + grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized, grads_serialized] + tensors=[ inputs_serialized, grads_serialized], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._backward( request ) + response, code, synapses = axon._backward( request ) assert code == bittensor.proto.ReturnCode.RequestShapeException + assert synapses[0].return_code == bittensor.proto.ReturnCode.RequestShapeException -def 
test_backward_tensor_shape_error():
-    inputs_raw = torch.rand(1, 1, 1, 1)
-    grads_raw = torch.rand(1, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    request = bittensor.proto.TensorMessage(
-        version=bittensor.__version_as_int__,
-        hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
-    )
-    response, code, call_time, message = axon._backward( request )
-    assert code == bittensor.proto.ReturnCode.RequestShapeException

-def test_backward_grads_shape_error():
+def test_backward_causal_lm_next_shape_error():
+    synapses = [bittensor.synapse.TextCausalLMNext()]
     inputs_raw = torch.rand(1, 1, 1)
-    grads_raw = torch.rand(1, 1, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+    grads_raw = torch.rand(1, synapses[0].topk + 1, 1 + 1)
+    serializer = bittensor.serializer(serializer_type=bittensor.proto.Serializer.MSGPACK)
+    inputs_serialized = serializer.serialize(inputs_raw, from_type=bittensor.proto.TensorType.TORCH)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
-        hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        hotkey=axon.wallet.hotkey.ss58_address,
+        tensors=[inputs_serialized, grads_serialized],
+        synapses=[syn.serialize_to_wire_proto() for syn in synapses]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward(request)
     assert code == bittensor.proto.ReturnCode.RequestShapeException
+    assert synapses[0].return_code == bittensor.proto.ReturnCode.RequestShapeException

-def test_backward_grad_inputs_shape_error():
+
+def test_backward_seq_2_seq_shape_error():
     inputs_raw = torch.rand(1, 1, 1)
-    grads_raw = torch.rand(2, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+    grads_raw = torch.tensor([])
+    synapses = [bittensor.synapse.TextSeq2Seq()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = serializer.serialize(inputs_raw, from_type = bittensor.proto.TensorType.TORCH)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward( request )
     assert code == bittensor.proto.ReturnCode.RequestShapeException
+    assert synapses[0].return_code == bittensor.proto.ReturnCode.RequestShapeException

-def test_backward_response_serialization_error():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ):
-        return dict()
-    axon.attach_backward_callback( backward, modality=bittensor.proto.Modality.TENSOR)
-    inputs_raw = torch.rand(1, 1, 1)
-    grads_raw = torch.rand(1, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    request = bittensor.proto.TensorMessage(
-        version=bittensor.__version_as_int__,
-        hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
-    )
-    response, code, call_time, message = axon._backward( request )
-    assert code == bittensor.proto.ReturnCode.ResponseSerializationException

-def test_backward_response_empty_error():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor ):
-        return None
-    axon.attach_backward_callback( backward,modality=bittensor.proto.Modality.TENSOR)
-    inputs_raw = torch.rand(1, 1, 1)
-    grads_raw = torch.rand(1, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+def test_backward_grads_shape_error():
+    inputs_raw = torch.rand(1, 1)
+    grads_raw = torch.rand(1, 1, 1, bittensor.__network_dim__)
+    synapses = [bittensor.synapse.TextLastHiddenState()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = serializer.serialize(grads_raw, from_type = bittensor.proto.TensorType.TORCH)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
-    assert code == bittensor.proto.ReturnCode.EmptyResponse
+    response, code, synapses = axon._backward( request )
+    assert code == bittensor.proto.ReturnCode.RequestShapeException

-def test_backward_response_success_text():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        return torch.zeros( [1, 1])
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.TEXT )
-    inputs_raw = torch.ones((1, 1))
-    grads_raw = torch.zeros((1, 1, bittensor.__network_dim__))
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH)
+
+def test_backward_response_success_hidden():
+    def forward( inputs_x:torch.FloatTensor, synapse, model_output = None):
+        return None, dict(), torch.zeros( [1, 1, bittensor.__network_dim__], requires_grad=True)
+    axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE)
+    inputs_raw = torch.ones(1, 1)
+    grads_raw = torch.zeros(1, 1, bittensor.__network_dim__)
+    synapses = [bittensor.synapse.TextLastHiddenState()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward( request )
     assert code == bittensor.proto.ReturnCode.Success

-def test_backward_response_success_image():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        return torch.zeros( [1, 1])
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.IMAGE )
-    inputs_raw = torch.ones((1, 1,1,1,1))
-    grads_raw = torch.zeros((1, 1, bittensor.__network_dim__))
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH)
+def test_backward_response_success_causal_lm():
+    def forward( inputs_x:torch.FloatTensor, synapse, model_output = None):
+        return None, dict(), torch.zeros( [1, 1, bittensor.__vocab_size__], requires_grad=True)
+
+    axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM)
+    inputs_raw = torch.ones(1, 1)
+    grads_raw = torch.zeros(1, 1, bittensor.__vocab_size__)
+    synapses = [bittensor.synapse.TextCausalLM()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw,grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward( request )
     assert code == bittensor.proto.ReturnCode.Success

-def test_backward_response_success():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        return torch.zeros( [1, 1, 1])
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.TENSOR )
-    inputs_raw = torch.rand(1, 1, 1)
-    grads_raw = torch.rand(1, 1, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+def test_backward_response_success_causal_lm_next():
+    def forward(inputs_x: torch.FloatTensor, synapse, model_output=None):  # [batch_size, (topk + 1), max_len]
+        return None, dict(), torch.zeros([1, (synapses[0].topk + 1), 1 + 1], requires_grad=True)
+
+    axon.attach_synapse_callback(forward, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT)
+    synapses = [bittensor.synapse.TextCausalLMNext()]
+
+    inputs_raw = torch.ones(1, 1)
+    grads_raw = torch.zeros([1, (synapses[0].topk + 1), 1 + 1])
+
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
-        hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        hotkey=axon.wallet.hotkey.ss58_address,
+        tensors=[inputs_serialized, grads_serialized],
+        synapses=[syn.serialize_to_wire_proto() for syn in synapses]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward(request)
     assert code == bittensor.proto.ReturnCode.Success

 def test_backward_response_timeout():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        if inputs_x.size() == (1,1,1):
-            return None
-        else:
-            raise TimeoutError('Timeout')
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.TENSOR )
-    inputs_raw = torch.rand(2, 2, 2)
+    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, synapses):
+        raise concurrent.futures.TimeoutError('Timeout')
+
+    axon.attach_backward_callback( backward)
+    inputs_raw = torch.rand(2, 2)
     grads_raw = torch.rand(2, 2, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+    synapses = [bittensor.synapse.TextLastHiddenState()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward( request )
     assert code == bittensor.proto.ReturnCode.Timeout

 def test_backward_response_exception():
-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        if inputs_x.size() == (1,1,1):
-            return None
-        else:
-            raise Exception('Timeout')
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.TENSOR )
-    inputs_raw = torch.rand(2, 2, 2)
+    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, synapses):
+        raise Exception('Timeout')
+
+    axon.attach_backward_callback( backward)
+    inputs_raw = torch.rand(2, 2)
+    synapses = [bittensor.synapse.TextLastHiddenState()]
     grads_raw = torch.rand(2, 2, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
-    grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
+    grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
     request = bittensor.proto.TensorMessage(
         version=bittensor.__version_as_int__,
         hotkey = axon.wallet.hotkey.ss58_address,
-        tensors=[ inputs_serialized, grads_serialized]
+        tensors=[ inputs_serialized, grads_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ]
     )
-    response, code, call_time, message = axon._backward( request )
+    response, code, synapses = axon._backward( request )
     assert code == bittensor.proto.ReturnCode.UnknownException

 # -- axon priority:
@@ -506,17 +796,20 @@ def priority(pubkey:str, request_type:str, inputs_x):
     axon = bittensor.axon(wallet = wallet, priority= priority)

-    def forward( inputs_x: torch.FloatTensor):
-        return torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__])
-    axon.attach_forward_callback( forward, modality=2)
-    inputs_raw = torch.rand(3, 3, bittensor.__network_dim__)
-    serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK )
-    inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH)
+    def forward( inputs_x: torch.FloatTensor, synapses , model_output = None):
+        return None, dict(), torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__])
+    axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE)
+    inputs_raw = torch.rand(3, 3)
+    synapses = [bittensor.synapse.TextLastHiddenState()]
+    serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK )
+    inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw)
     request = bittensor.proto.TensorMessage(
         version = bittensor.__version_as_int__,
-        tensors=[inputs_serialized]
+        tensors=[inputs_serialized],
+        synapses= [ syn.serialize_to_wire_proto() for syn in synapses ],
+        hotkey = axon.wallet.hotkey.ss58_address,
     )
-    response, code, call_time, message = axon._forward( request )
+    response, code, synapses = axon._forward( request )
     assert code == bittensor.proto.ReturnCode.Success

 def test_backward_response_success_text_priority():
@@ -526,32 +819,36 @@ def priority(pubkey:str, request_type:str, inputs_x):
     axon = bittensor.axon(wallet = wallet, priority= priority)

-    def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor):
-        return torch.zeros( [1, 1])
-    axon.attach_backward_callback( backward,modality = bittensor.proto.Modality.TEXT )
+    def forward( inputs_x: 
torch.FloatTensor, synapses, model_output = None): + return None, dict(), torch.zeros( [inputs_x.shape[0], inputs_x.shape[1], bittensor.__network_dim__]) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) + inputs_raw = torch.ones((1, 1)) grads_raw = torch.zeros((1, 1, bittensor.__network_dim__)) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) + synapses = [bittensor.synapse.TextLastHiddenState()] + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw) + grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version=bittensor.__version_as_int__, hotkey = axon.wallet.hotkey.ss58_address, - tensors=[ inputs_serialized, grads_serialized] + tensors=[ inputs_serialized, grads_serialized], + synapses= [ syn.serialize_to_wire_proto() for syn in synapses ] ) - response, code, call_time, message = axon._backward( request ) + response, code, synapses = axon._backward( request ) assert code == bittensor.proto.ReturnCode.Success def test_grpc_forward_works(): - def forward( inputs_x:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) + def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): + return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) axon = bittensor.axon ( port = 7084, ip = '127.0.0.1', wallet = wallet, ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() channel = grpc.insecure_channel( @@ -560,93 +857,17 @@ def forward( inputs_x:torch.FloatTensor): ('grpc.max_receive_message_length', -1)]) stub = bittensor.grpc.BittensorStub( channel ) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - request = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = '1092310312914', - tensors = [inputs_serialized] - ) - response = stub.Forward(request, - metadata = ( - ('rpc-auth-header','Bittensor'), - ('bittensor-signature',sign(axon.wallet)), - ('bittensor-version',str(bittensor.__version_as_int__)), - )) - - outputs = serializer.deserialize(response.tensors[0], to_type=bittensor.proto.TensorType.TORCH) - assert outputs.tolist() == [[[0]]] - axon.stop() - assert axon.stats.total_requests == 1 - axon.to_wandb() - - -def test_grpc_forward_works_gzip(): - def forward( inputs_x:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) - axon = bittensor.axon ( - port = 7082, - ip = '127.0.0.1', - wallet = wallet, - compression= 'gzip' - ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) - axon.start() - - channel = grpc.insecure_channel( - '127.0.0.1:7082', - options=[('grpc.max_send_message_length', -1), - 
('grpc.max_receive_message_length', -1)]) - stub = bittensor.grpc.BittensorStub( channel ) - - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - request = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = '1092310312914', - tensors = [inputs_serialized] - ) - response = stub.Forward(request, - metadata = ( - ('rpc-auth-header','Bittensor'), - ('bittensor-signature',sign(axon.wallet)), - ('bittensor-version',str(bittensor.__version_as_int__)), - )) + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] - outputs = serializer.deserialize(response.tensors[0], to_type=bittensor.proto.TensorType.TORCH) - assert outputs.tolist() == [[[0]]] - axon.stop() - assert axon.stats.total_requests == 1 - axon.to_wandb() - - -def test_grpc_forward_works_deflate(): - def forward( inputs_x:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) - axon = bittensor.axon ( - port = 7083, - ip = '127.0.0.1', - wallet = wallet, - compression= 'deflate' - ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) - axon.start() + inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw) - channel = grpc.insecure_channel( - '127.0.0.1:7083', - options=[('grpc.max_send_message_length', -1), - ('grpc.max_receive_message_length', -1)]) - stub = bittensor.grpc.BittensorStub( channel ) - - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - hotkey = '1092310312914', - tensors = [inputs_serialized] + hotkey = axon.wallet.hotkey.ss58_address, + tensors = [inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) response = stub.Forward(request, metadata = ( @@ -655,23 +876,22 @@ def forward( inputs_x:torch.FloatTensor): ('bittensor-version',str(bittensor.__version_as_int__)), )) - outputs = serializer.deserialize(response.tensors[0], to_type=bittensor.proto.TensorType.TORCH) - assert outputs.tolist() == [[[0]]] + outputs = synapses[0].deserialize_forward_response_proto (inputs_raw, response.tensors[0]) + assert outputs.size(2) == bittensor.__network_dim__ + assert response.return_code == bittensor.proto.ReturnCode.Success axon.stop() - assert axon.stats.total_requests == 1 - axon.to_wandb() def test_grpc_backward_works(): - def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) + def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): + return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__], requires_grad=True) axon = bittensor.axon ( port = 7086, ip = '127.0.0.1', wallet = wallet, ) - axon.attach_backward_callback( backward , modality = bittensor.proto.Modality.TENSOR) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() channel = grpc.insecure_channel( @@ -679,16 +899,17 @@ def 
backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor): options=[('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)]) stub = bittensor.grpc.BittensorStub( channel ) - - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) + synapses = [bittensor.synapse.TextLastHiddenState()] + inputs_raw = torch.rand(3, 3) grads_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw) + grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = '1092310312914', - tensors = [inputs_serialized, grads_serialized] + tensors = [inputs_serialized, grads_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) response = stub.Backward(request, metadata = ( @@ -696,34 +917,37 @@ def backward( inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor): ('bittensor-signature',sign(axon.wallet)), ('bittensor-version',str(bittensor.__version_as_int__)), )) - outputs = serializer.deserialize(response.tensors[0], to_type=bittensor.proto.TensorType.TORCH) - assert outputs.tolist() == [[[0]]] + assert response.return_code == bittensor.proto.ReturnCode.Success axon.stop() def test_grpc_forward_fails(): - def forward( inputs_x:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) + def forward( inputs_x:torch.FloatTensor, synapse, model_output = None): + return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) axon = bittensor.axon ( - port = 7081, + port = 7084, ip = '127.0.0.1', wallet = wallet, ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() channel = grpc.insecure_channel( - '127.0.0.1:7081', + '127.0.0.1:7084', options=[('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)]) stub = bittensor.grpc.BittensorStub( channel ) - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + inputs_raw = torch.rand(3, 3) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + synapses = [bittensor.synapse.TextLastHiddenState()] + + inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw) + request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = '1092310312914', - tensors = [inputs_serialized] + tensors = [inputs_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) try: response = stub.Forward(request) @@ -734,32 +958,33 @@ def forward( inputs_x:torch.FloatTensor): axon.stop() def test_grpc_backward_fails(): - def backward( inputs_x:torch.FloatTensor, 
grads_dy:torch.FloatTensor): - return torch.zeros( [1, 1, 1]) + def forward( inputs_x:torch.FloatTensor, synapse): + return torch.zeros( [3, 3, bittensor.__network_dim__], requires_grad=True) axon = bittensor.axon ( - port = 7085, + port = 7086, ip = '127.0.0.1', - wallet = wallet + wallet = wallet, ) - axon.attach_backward_callback( backward , modality = bittensor.proto.Modality.TENSOR) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() channel = grpc.insecure_channel( - '127.0.0.1:7085', + '127.0.0.1:7086', options=[('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)]) stub = bittensor.grpc.BittensorStub( channel ) - - inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) + synapses = [bittensor.synapse.TextLastHiddenState()] + inputs_raw = torch.rand(3, 3) grads_raw = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - inputs_serialized = serializer.serialize(inputs_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - grads_serialized = serializer.serialize(grads_raw, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + inputs_serialized = synapses[0].serialize_forward_request_tensor(inputs_raw) + grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = '1092310312914', - tensors = [inputs_serialized, grads_serialized] + tensors = [inputs_serialized, grads_serialized], + synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) try: @@ -818,7 +1043,6 @@ def test_axon_is_destroyed(): if __name__ == "__main__": - #test_backward_response_serialization_error() - #test_axon_is_destroyed() - test_forward_wandb() - test_grpc_forward_works() \ No newline at end of file + # test_forward_joint_success() + test_forward_joint_missing_synapse() + # test_forward_joint_faulty_synapse() \ No newline at end of file diff --git a/tests/unit_tests/bittensor_tests/test_forward_backward.py b/tests/unit_tests/bittensor_tests/test_forward_backward.py index 9605f232dd..48296f2cc0 100644 --- a/tests/unit_tests/bittensor_tests/test_forward_backward.py +++ b/tests/unit_tests/bittensor_tests/test_forward_backward.py @@ -44,131 +44,166 @@ coldkey = wallet.coldkey.ss58_address ) -def test_dendrite_forward_tensor_shape_error(): +def test_dendrite_forward_causal_lm_shape_error(): x = torch.rand(3, 3, 3) + synapses = [bittensor.synapse.TextCausalLM()] with pytest.raises(ValueError): - dendrite_mock.forward_tensor( endpoints=[endpoint], inputs=[x]) + dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) -def test_dendrite_forward_image_shape_error(): +def test_dendrite_forward_causal_lm_next_shape_error(): x = torch.rand(3, 3, 3) + synapses = [bittensor.synapse.TextCausalLMNext()] with pytest.raises(ValueError): - dendrite_mock.forward_image( endpoints=[endpoint], inputs=[x]) + dendrite_mock.text(endpoints=[endpoint], inputs=[x], synapses=synapses) -def test_dendrite_forward_text_shape_error(): +def test_dendrite_forward_last_hidden_shape_error(): x = torch.rand(3, 3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] with pytest.raises(ValueError): - dendrite_mock.forward_image( endpoints=[endpoint], 
inputs=[x]) + dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) -def test_dendrite_forward_text(): +def test_dendrite_forward_seq_2_seq_shape_error(): + x = torch.rand(3, 3, 3) + synapses = [bittensor.synapse.TextSeq2Seq()] + with pytest.raises(ValueError): + dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) + +def test_dendrite_forward_text_causal_lm(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [torch.zeros([2, 4, bittensor.__network_dim__])], [1], [0]]) - tensors, codes, times = dendrite_mock.forward_text( endpoints=[endpoint], inputs=[x]) + synapses = [bittensor.synapse.TextCausalLM()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[torch.zeros([2, 4, bittensor.__network_dim__])]], [[1]], [[0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert list(tensors[0].shape) == [2, 4, bittensor.__network_dim__] + assert list(tensors[0][0].shape) == [2, 4, bittensor.__network_dim__] -def test_dendrite_forward_image(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ], dtype=torch.float32) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [torch.zeros([1, 1, bittensor.__network_dim__])] , [1], [0]]) - tensors, codes, times = dendrite_mock.forward_image( endpoints=[endpoint], inputs=[x]) +def test_dendrite_forward_text_causal_lm_next(): + x = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]) # [2, 4] + synapses = [bittensor.synapse.TextCausalLMNext()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value=[[[torch.zeros([2, (synapses[0].topk + 1), 1 + 1])]], [[1]], [[0]]]) + tensors, codes, times = dendrite_mock.text(endpoints=[endpoint], inputs=[x], synapses=synapses) assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert list(tensors[0].shape) == [1, 1, bittensor.__network_dim__] + assert list(tensors[0][0].shape) == [2, (synapses[0].topk + 1), 1 + 1] # [batch_size, (topk + 1), max_len] -def test_dendrite_forward_tensor(): - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [torch.zeros([3, 3, bittensor.__network_dim__])], [1], [0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints=[endpoint], inputs=[x]) +def test_dendrite_forward_text_last_hidden(): + x = torch.tensor([[1],[8]]) + synapses = [bittensor.synapse.TextLastHiddenState()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[torch.zeros([1, 1, bittensor.__network_dim__])]], [[1]], [[0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert list(tensors[0].shape) == [3, 3, bittensor.__network_dim__] + assert list(tensors[0][0].shape) == [1, 1, bittensor.__network_dim__] -def test_dendrite_forward_tensor_pass_through_text(): - x = torch.ones((3, 3), dtype=torch.int64) - y = torch.zeros([3, 3, bittensor.__network_dim__]) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y, y, y] , [1, 1, 1], [0,0,0]]) - tensors, codes, times = dendrite_mock.forward_text( endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x]) +def test_dendrite_forward_text_seq_2_seq(): + x = torch.rand(3, 3) + synapses = [bittensor.synapse.TextSeq2Seq()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ 
[[torch.zeros([3, 3, bittensor.__network_dim__])]], [[1]], [[0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint], inputs=[x], synapses=synapses) assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert codes[1].item() == bittensor.proto.ReturnCode.Success - assert codes[2].item() == bittensor.proto.ReturnCode.Success - assert tensors[0].shape == y.shape - assert tensors[1].shape == y.shape - assert tensors[2].shape == y.shape - -def test_dendrite_forward_tensor_pass_through_image(): - x = torch.rand(3, 3, 3, 3, 3) + assert list(tensors[0][0].shape) == [3, 3, bittensor.__network_dim__] + +def test_dendrite_forward_tensor_pass_through_text_causal_lm(): + x = torch.ones((3, 3), dtype=torch.int64) y = torch.zeros([3, 3, bittensor.__network_dim__]) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y, y, y] , [1, 1, 1], [0,0,0]]) - tensors, codes, times = dendrite_mock.forward_image( endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x]) - assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert codes[1].item() == bittensor.proto.ReturnCode.Success - assert codes[2].item() == bittensor.proto.ReturnCode.Success - assert tensors[0].shape == y.shape - assert tensors[1].shape == y.shape - assert tensors[2].shape == y.shape - -def test_dendrite_forward_tensor_pass_through_tensor(): - x = torch.rand(3, 3, bittensor.__network_dim__) + synapses = [bittensor.synapse.TextCausalLM()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y, y, y]] , [[1, 1, 1]], [[0,0,0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x], synapses=synapses) + assert codes[0][0].item() == bittensor.proto.ReturnCode.Success + assert codes[1][0].item() == bittensor.proto.ReturnCode.Success + assert codes[2][0].item() == bittensor.proto.ReturnCode.Success + assert tensors[0][0].shape == y.shape + assert tensors[1][0].shape == y.shape + assert tensors[2][0].shape == y.shape + +def test_dendrite_forward_tensor_pass_through_text_causal_lm_next(): + x = torch.ones((3, 3), dtype=torch.int64) + synapses = [bittensor.synapse.TextCausalLMNext()] + y = torch.zeros([3, (synapses[0].topk + 1), 1 + 1]) + dendrite_mock.receptor_pool.forward = MagicMock(return_value=[[[y, y, y]], [[1, 1, 1]], [[0, 0, 0]]]) + tensors, codes, times = dendrite_mock.text(endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x], synapses=synapses) + assert codes[0][0].item() == bittensor.proto.ReturnCode.Success + assert codes[1][0].item() == bittensor.proto.ReturnCode.Success + assert codes[2][0].item() == bittensor.proto.ReturnCode.Success + assert tensors[0][0].shape == y.shape + assert tensors[1][0].shape == y.shape + assert tensors[2][0].shape == y.shape + +def test_dendrite_forward_tensor_pass_through_text_last_hidden(): + x = torch.ones((3, 3), dtype=torch.int64) y = torch.zeros([3, 3, bittensor.__network_dim__]) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y, y, y] , [1, 1, 1], [0,0,0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints = [endpoint, endpoint, endpoint], inputs=[x, x, x]) - assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert codes[1].item() == bittensor.proto.ReturnCode.Success - assert codes[2].item() == bittensor.proto.ReturnCode.Success - assert tensors[0].shape == y.shape - assert tensors[1].shape == y.shape - assert tensors[2].shape == y.shape - -def test_dendrite_forward_tensor_stack(): - x = torch.rand(3, 3, bittensor.__network_dim__) 
+ synapses = [bittensor.synapse.TextLastHiddenState()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y, y, y]] , [[1, 1, 1]], [[0,0,0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x], synapses=synapses) + assert codes[0][0].item() == bittensor.proto.ReturnCode.Success + assert codes[1][0].item() == bittensor.proto.ReturnCode.Success + assert codes[2][0].item() == bittensor.proto.ReturnCode.Success + assert tensors[0][0].shape == y.shape + assert tensors[1][0].shape == y.shape + assert tensors[2][0].shape == y.shape + +def test_dendrite_forward_tensor_pass_through_text_seq_2_seq(): + x = torch.ones((3, 3), dtype=torch.int64) y = torch.zeros([3, 3, bittensor.__network_dim__]) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y, y, y] , [1, 1, 1], [0,0,0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints = [endpoint, endpoint, endpoint], inputs = [x, x, x]) - stacked = torch.stack(tensors, dim=2) - assert stacked.shape == torch.zeros([3, 3, 3, bittensor.__network_dim__ ]).shape - averaged = torch.mean(stacked, dim=2) - assert averaged.shape == torch.zeros([3, 3, bittensor.__network_dim__ ]).shape + synapses = [bittensor.synapse.TextSeq2Seq()] + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y, y, y]] , [[1, 1, 1]], [[0,0,0]]]) + tensors, codes, times = dendrite_mock.text( endpoints=[endpoint, endpoint, endpoint], inputs=[x, x, x], synapses=synapses) + assert codes[0][0].item() == bittensor.proto.ReturnCode.Success + assert codes[1][0].item() == bittensor.proto.ReturnCode.Success + assert codes[2][0].item() == bittensor.proto.ReturnCode.Success + assert tensors[0][0].shape == y.shape + assert tensors[1][0].shape == y.shape + assert tensors[2][0].shape == y.shape def test_dendrite_backward(): - x = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - y = torch.ones((1, 1, bittensor.__network_dim__)) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y], [0], [0]]) - dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [y], [0], [0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints = [ endpoint ], inputs=[ x ]) - tensors[0].sum().backward() + x = Variable(torch.rand((2, 2), dtype=torch.float32), requires_grad=True) + y = torch.ones((2, 2)) + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y]], [[0]], [[0]]]) + dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [[y]], [[0]], [[0]]]) + dendrite_mock.format_text_inputs = MagicMock(return_value = ( [ endpoint ], [ x ] )) + synapses = [bittensor.synapse.TextCausalLM()] + tensors, codes, times = dendrite_mock.text( endpoints = [ endpoint ], inputs=[ x ], synapses=synapses) + tensors[0][0].sum().backward() assert x.grad.shape == y.shape def test_dendrite_backward_large(): - x = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - y = torch.ones((1, 1, bittensor.__network_dim__)) - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y], [0], [0]]) - dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [y], [0], [0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints = [ endpoint ], inputs=[ x ]) - tensors[0].sum().backward() + x = Variable(torch.rand((1, 1), dtype=torch.float32), requires_grad=True) + y = torch.ones((1, 1)) + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y]], [[0]], 
[[0]]]) + dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [[y]], [[0]], [[0]]]) + dendrite_mock.format_text_inputs = MagicMock(return_value = ( [ endpoint ], [ x ] )) + synapses = [bittensor.synapse.TextCausalLM()] + tensors, codes, times = dendrite_mock.text( endpoints = [ endpoint ], inputs=[ x ], synapses=synapses) + tensors[0][0].sum().backward() assert x.grad.shape == y.shape assert x.grad.tolist() == y.tolist() def test_dendrite_backward_no_grad(): - x = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - y = torch.ones((1, 1, bittensor.__network_dim__)) - nill_response = torch.zeros((1, 1, bittensor.__network_dim__)) - dendrite_no_grad.receptor_pool.forward = MagicMock(return_value = [ [y], [0], [0]]) - dendrite_no_grad.receptor_pool.backward = MagicMock(return_value = [ [y], [0], [0]]) - tensors, codes, times = dendrite_no_grad.forward_tensor( endpoints = [ endpoint ], inputs=[ x ]) - tensors[0].sum().backward() + x = Variable(torch.rand((1, 1), dtype=torch.float32), requires_grad=True) + y = torch.ones((1, 1)) + nill_response = torch.zeros((1, 1)) + dendrite_no_grad.receptor_pool.forward = MagicMock(return_value = [ [[y]], [[0]], [[0]]]) + dendrite_no_grad.receptor_pool.backward = MagicMock(return_value = [ [[y]], [[0]], [[0]]]) + dendrite_no_grad.format_text_inputs = MagicMock(return_value = ( [ endpoint ], [ x ] )) + synapses = [bittensor.synapse.TextCausalLM()] + tensors, codes, times = dendrite_no_grad.text( endpoints = [ endpoint ], inputs=[ x ], synapses=synapses) + tensors[0][0].sum().backward() assert x.grad.shape == y.shape assert x.grad.tolist() == nill_response.tolist() def test_dendrite_backward_multiple(): - x1 = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - x2 = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - x3 = Variable(torch.rand((1, 1, bittensor.__network_dim__), dtype=torch.float32), requires_grad=True) - y1 = torch.ones(1, 1, bittensor.__network_dim__) - y2 = torch.ones(1, 1, bittensor.__network_dim__) - y3 = torch.ones(1, 1, bittensor.__network_dim__) - - dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [y1, y2, y3], [1,1,1], [0,0,0]]) - dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [y1, y2, y3], [1,1,1], [0,0,0]]) - tensors, codes, times = dendrite_mock.forward_tensor( endpoints = [endpoint, endpoint, endpoint], inputs=[ x1, x2, x3 ]) - tensors[0].sum().backward() + x1 = Variable(torch.rand((1, 1), dtype=torch.float32), requires_grad=True) + x2 = Variable(torch.rand((1, 1), dtype=torch.float32), requires_grad=True) + x3 = Variable(torch.rand((1, 1), dtype=torch.float32), requires_grad=True) + y1 = torch.ones(1, 1) + y2 = torch.ones(1, 1) + y3 = torch.ones(1, 1) + + dendrite_mock.receptor_pool.forward = MagicMock(return_value = [ [[y1], [y2], [y3]], [[1],[1],[1]], [[0],[0],[0]]]) + dendrite_mock.receptor_pool.backward = MagicMock(return_value = [ [[y1], [y2], [y3]], [[1],[1],[1]], [[0],[0],[0]]]) + dendrite_mock.format_text_inputs = MagicMock(return_value = ( [ endpoint, endpoint, endpoint ], [ x1, x2, x3 ] )) + + synapses = [bittensor.synapse.TextCausalLM()] + tensors, codes, times = dendrite_mock.text( endpoints = [endpoint, endpoint, endpoint], inputs=[ x1, x2, x3 ], synapses=synapses) + tensors[0][0].sum().backward() assert x1.grad.shape == y1.shape assert x2.grad.shape == y2.shape assert x3.grad.shape == y3.shape @@ -177,9 +212,9 @@ def 
test_dendrite_backward_multiple(): assert x3.grad.tolist() == y3.tolist() def test_axon_receptor_forward_works(): - def forward( inputs_x:torch.FloatTensor): - time.sleep(0.2) - return torch.zeros([3, 3, bittensor.__network_dim__]) + def forward( inputs_x: torch.FloatTensor, synapse , model_output = None): + return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) + axon_port = get_random_unused_port() axon = bittensor.axon ( @@ -187,7 +222,7 @@ def forward( inputs_x:torch.FloatTensor): ip = '0.0.0.0', wallet = wallet, ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() endpoints = [] for i in range(20): @@ -203,31 +238,33 @@ def forward( inputs_x:torch.FloatTensor): coldkey = wallet.coldkey.ss58_address ) endpoints += [endpoint] - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) - tensors, codes, times = dendrite.forward_tensor( endpoints=endpoints, inputs=[x for i in endpoints]) + x = torch.zeros(3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] + + tensors, codes, times = dendrite.text( endpoints=endpoints, inputs=[x for i in endpoints], synapses=synapses) receptors_states = dendrite.receptor_pool.get_receptors_state() # TODO: Fails locally independent of multiprocessing. assert receptors_states[endpoint.hotkey] == receptors_states[endpoint.hotkey].READY - assert codes[0].item() == bittensor.proto.ReturnCode.Success - assert list(tensors[0].shape) == [3, 3, bittensor.__network_dim__] + assert codes[0][0].item() == bittensor.proto.ReturnCode.Success + assert list(tensors[0][0].shape) == [3, 3, bittensor.__network_dim__] print('assertions passed') axon.stop() def test_dendrite_call_time(): - def forward( inputs_x:torch.FloatTensor): + def forward( inputs_x: torch.FloatTensor, synapse , model_output = None): time.sleep(12) - return torch.zeros([3, 3, bittensor.__network_dim__]) - + return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) + axon_port = get_random_unused_port() axon = bittensor.axon ( port = axon_port, ip = '0.0.0.0', wallet = wallet, ) - axon.attach_forward_callback( forward, modality = bittensor.proto.Modality.TENSOR ) + axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() endpoints = [] - for i in range(100): + for i in range(10): wallet.create_new_hotkey( use_password=False, overwrite = True) endpoint = bittensor.endpoint( version = bittensor.__version_as_int__, @@ -240,9 +277,11 @@ def forward( inputs_x:torch.FloatTensor): coldkey = wallet.coldkey.ss58_address ) endpoints += [endpoint] - x = torch.rand(3, 3, bittensor.__network_dim__, dtype=torch.float32) + x = torch.zeros(3, 3) + synapses = [bittensor.synapse.TextLastHiddenState()] start_time = time.time() - tensors, codes, times = dendrite.forward_tensor( endpoints=endpoints, inputs=[x for i in endpoints]) + + tensors, codes, times = dendrite.text( endpoints=endpoints, inputs=[x for i in endpoints], synapses=synapses) total_time = time.time() - start_time axon.stop() @@ -253,5 +292,5 @@ def test_dendrite_del(): del dendrite_mock if __name__ == "__main__": - test_dendrite_call_time() + test_dendrite_backward_multiple() diff --git a/tests/unit_tests/bittensor_tests/test_neuron.py b/tests/unit_tests/bittensor_tests/test_neuron.py index e400e1125f..02ee588d22 100644 --- a/tests/unit_tests/bittensor_tests/test_neuron.py +++ 
b/tests/unit_tests/bittensor_tests/test_neuron.py @@ -1,9 +1,16 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from more_itertools import side_effect + +import pytest + import bittensor -import torch +import torch import torch.nn as nn +from bittensor._subtensor import subtensor +from bittensor._subtensor.subtensor_mock import mock_subtensor from torch.nn import TransformerEncoder, TransformerEncoderLayer - def test_set_fine_tuning_params(): class Model(nn.Module): def __init__(self): @@ -15,33 +22,33 @@ def __init__(self): self.encoder2 = TransformerEncoder( self.encoder_layers, nlayers_2 ) self.decoder = torch.nn.Linear( network_dim, vocab_size , bias=False) - adv_server = bittensor._neuron.text.advanced_server.server() + core_server = bittensor._neuron.text.core_server.server() # test for the basic default gpt2 case - assert adv_server.set_fine_tuning_params() == (True, 'h.11') + assert core_server.set_fine_tuning_params() == (True, 'transformer.h.11') # test for the case when there are 2 modulelists - adv_server.pre_model = Model() - assert adv_server.set_fine_tuning_params() == (True, 'encoder2.layers.2') + core_server.pre_model = Model() + assert core_server.set_fine_tuning_params() == (True, 'encoder2.layers.2') # test for user specification of the number of layers - adv_server.config.neuron.finetune.num_layers = 3 - assert adv_server.set_fine_tuning_params() == (True, 'encoder2.layers.0') + core_server.config.neuron.finetune.num_layers = 3 + assert core_server.set_fine_tuning_params() == (True, 'encoder2.layers.0') # test for user specification of the number of layers - adv_server.config.neuron.finetune.num_layers = 4 - assert adv_server.set_fine_tuning_params() == (True, 'encoder.layers.0') + core_server.config.neuron.finetune.num_layers = 4 + assert core_server.set_fine_tuning_params() == (True, 'encoder.layers.0') # test for user specification of the number of layers set too large - adv_server.config.neuron.finetune.num_layers = 5 - assert adv_server.set_fine_tuning_params() == (False, None) + core_server.config.neuron.finetune.num_layers = 5 + assert core_server.set_fine_tuning_params() == (False, None) # test for user specification of the layer name - adv_server.config.neuron.finetune.layer_name = 'encoder2.layers.1' - assert adv_server.set_fine_tuning_params() == (True, 'encoder2.layers.1') + core_server.config.neuron.finetune.layer_name = 'encoder2.layers.1' + assert core_server.set_fine_tuning_params() == (True, 'encoder2.layers.1') # test for user specification of a non-existing layer name - adv_server.config.neuron.finetune.layer_name = 'non_existing_layer' - assert adv_server.set_fine_tuning_params() == (False, 'non_existing_layer') + core_server.config.neuron.finetune.layer_name = 'non_existing_layer' + assert core_server.set_fine_tuning_params() == (False, 'non_existing_layer') class Model(nn.Module): @@ -51,10 +58,185 @@ def __init__(self): self.decoder = torch.nn.Linear( network_dim, vocab_size , bias=False) # test for a non-existing modulelist - adv_server.pre_model = Model() - adv_server.config.neuron.finetune.layer_name = None - assert adv_server.set_fine_tuning_params() == (False, None) + core_server.pre_model = Model() + core_server.config.neuron.finetune.layer_name = None + assert core_server.set_fine_tuning_params() == (False, None) + +def test_coreserver_reregister_flag_false_exit(): + config = bittensor.Config() + config.wallet = bittensor.Config() + config.wallet.reregister = False # don't reregister the wallet + + 
mock_wallet = bittensor.wallet.mock() + mock_wallet.config = config + + class MockException(Exception): + pass + + def exit_early(*args, **kwargs): + raise MockException('exit_early') + + mock_register = MagicMock(side_effect=exit_early) + + mock_self_neuron=MagicMock( + wallet=mock_wallet, + model=MagicMock(), + axon=MagicMock(), + metagraph=MagicMock(), + spec=bittensor.neurons.core_server.neuron, + subtensor=MagicMock( + network="mock" + ), + config=config, + ) + + with patch.multiple( + 'bittensor.Wallet', + register=mock_register, + is_registered=MagicMock(return_value=False), # mock the wallet as not registered + ): + + # Should exit without calling register + with pytest.raises(SystemExit) as pytest_wrapped_e: + # Should not raise MockException + bittensor.neurons.core_server.neuron.run( + self=mock_self_neuron + ) + + # Should not try to register the neuron + mock_register.assert_not_called() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 # No error + +def test_coreserver_reregister_flag_true(): + config = bittensor.Config() + config.wallet = bittensor.Config() + config.wallet.reregister = True # try to reregister the wallet + + mock_wallet = bittensor.wallet.mock() + mock_wallet.config = config + + class MockException(Exception): + pass + + def exit_early(*args, **kwargs): + raise MockException('exit_early') + + mock_register = MagicMock(side_effect=exit_early) + + mock_self_neuron=MagicMock( + wallet=mock_wallet, + model=MagicMock(), + axon=MagicMock(), + metagraph=MagicMock(), + spec=bittensor.neurons.core_server.neuron, + subtensor=MagicMock( + network="mock" + ), + config=config, + ) + + with patch.multiple( + 'bittensor.Wallet', + register=mock_register, + is_registered=MagicMock(return_value=False), # mock the wallet as not registered + ): + + # Should not exit + with pytest.raises(MockException): + # Should raise MockException + bittensor.neurons.core_server.neuron.run( + self=mock_self_neuron + ) + + # Should try to register the neuron + mock_register.assert_called_once() + +def test_corevalidator_reregister_flag_false_exit(): + config = bittensor.Config() + config.wallet = bittensor.Config() + config.wallet.reregister = False # don't reregister the wallet + + mock_wallet = bittensor.wallet.mock() + mock_wallet.config = config + + class MockException(Exception): + pass + + def exit_early(*args, **kwargs): + raise MockException('exit_early') + + mock_register = MagicMock(side_effect=exit_early) + + mock_self_neuron=MagicMock( + wallet=mock_wallet, + spec=bittensor.neurons.core_validator.neuron, + subtensor=MagicMock( + network="mock" + ), + config=config, + ) + + with patch.multiple( + 'bittensor.Wallet', + register=mock_register, + is_registered=MagicMock(return_value=False), # mock the wallet as not registered + ): + + # Should exit without calling register + with pytest.raises(SystemExit) as pytest_wrapped_e: + # Should not raise MockException + bittensor.neurons.core_validator.neuron.__enter__( + self=mock_self_neuron + ) + + # Should not try to register the neuron + mock_register.assert_not_called() + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 # No error + +def test_corevalidator_reregister_flag_true(): + config = bittensor.Config() + config.wallet = bittensor.Config() + config.wallet.reregister = True # try to reregister the wallet + + mock_wallet = bittensor.wallet.mock() + mock_wallet.config = config + + class MockException(Exception): + pass + + def exit_early(*args, **kwargs): + 
raise MockException('exit_early') + + mock_register = MagicMock(side_effect=exit_early) + + mock_self_neuron=MagicMock( + wallet=mock_wallet, + spec=bittensor.neurons.core_validator.neuron, + subtensor=MagicMock( + network="mock" + ), + config=config, + ) + + with patch.multiple( + 'bittensor.Wallet', + register=mock_register, + is_registered=MagicMock(return_value=False), # mock the wallet as not registered + ): + + # Should not exit + with pytest.raises(MockException): + # Should raise MockException + bittensor.neurons.core_validator.neuron.__enter__( + self=mock_self_neuron + ) + + # Should try to register the neuron + mock_register.assert_called_once() + if __name__ == '__main__': - test_set_fine_tuning_params() + pass diff --git a/tests/unit_tests/bittensor_tests/test_receptor.py b/tests/unit_tests/bittensor_tests/test_receptor.py index 2d54ab92bf..d1d3458be2 100644 --- a/tests/unit_tests/bittensor_tests/test_receptor.py +++ b/tests/unit_tests/bittensor_tests/test_receptor.py @@ -22,8 +22,10 @@ from unittest.mock import MagicMock import unittest.mock as mock import asyncio +from types import SimpleNamespace +import time as clock -logging = bittensor.logging() +logging = bittensor.logging(debug = True) wallet = bittensor.wallet.mock() @@ -46,6 +48,13 @@ ('grpc.max_receive_message_length', -1)]) stub = bittensor.grpc.BittensorStub(channel) +synapses = [ + bittensor.synapse.TextLastHiddenState(), + bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextSeq2Seq(num_to_generate=70) +] + def test_print(): print(receptor) print(str(receptor)) @@ -57,438 +66,319 @@ def test_dummy_forward(): dummy_receptor = bittensor.receptor ( endpoint= endpoint, wallet=wallet) assert dummy_receptor.endpoint.uid == 0 x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - out, ops, time = dummy_receptor.forward( x, bittensor.proto.Modality.TEXT, timeout=1) - assert ops == bittensor.proto.ReturnCode.EmptyRequest - assert list(out.shape) == [2, 4, bittensor.__network_dim__] + out, ops, time = dummy_receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.BadEndpoint for _ in synapses] + assert [list(o.shape) for o in out] == [[2, 4,bittensor.__network_dim__], + [2, 4, bittensor.__vocab_size__], + [2, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [2, 70]] + def test_dummy_backward(): endpoint = bittensor.endpoint.dummy() dummy_receptor = bittensor.receptor ( endpoint= endpoint, wallet=wallet) assert dummy_receptor.endpoint.uid == 0 - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - grads = torch.ones((x.size(0),x.size(1),bittensor.__network_dim__)) - out, ops, time = dummy_receptor.backward( x,grads,bittensor.proto.Modality.TEXT , timeout=1) - print (out, ops, time) - assert ops == bittensor.proto.ReturnCode.EmptyRequest - assert list(out.shape) == [2, 4, bittensor.__network_dim__] + grads = [torch.ones((x.size(0),x.size(1),bittensor.__network_dim__))]*len(synapses) + out, ops, time = dummy_receptor.backward( synapses, x,grads, timeout=1) + assert ops == [bittensor.proto.ReturnCode.BadEndpoint for _ in synapses] + assert [list(o.shape) for o in out] == [[2,4,bittensor.__network_dim__], + [2,4, bittensor.__vocab_size__], + [2, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [0]] # -- request serialization -- def test_receptor_forward_request_serialize_error(): - x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - out, ops, time = receptor.forward( x, dict(), timeout=1) - assert ops == 
bittensor.proto.ReturnCode.RequestSerializationException + x = torch.tensor([[[1,2,3,4]]], dtype=torch.long) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.RequestSerializationException]*len(synapses) def test_receptor_backward_request_serialize_error(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - grads = torch.ones((x.size(0),x.size(1),bittensor.__network_dim__)) - out, ops, time = receptor.backward( x,grads, dict(), timeout=1) - assert ops == bittensor.proto.ReturnCode.RequestSerializationException + grads = [torch.ones((x.size(0))), + torch.ones((x.size(0))), + torch.ones((x.size(0))), + torch.ones((x.size(0))) ] + out, ops, time = receptor.backward( synapses, x,grads, timeout=1) + assert ops == [bittensor.proto.ReturnCode.RequestSerializationException]*len(synapses) # -- forward testing -- def test_receptor_neuron_text(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TEXT, timeout=1) - print (out, ops, time) - assert ops == bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [2, 4, bittensor.__network_dim__] - -def test_receptor_neuron_image(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ]) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.IMAGE, timeout=1) - assert ops == bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [1, 1, bittensor.__network_dim__] - -def test_receptor_neuron_tensor(): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [3, 3, bittensor.__network_dim__] + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Unavailable]*len(synapses) + assert [list(o.shape) for o in out] == [[2, 4,bittensor.__network_dim__], + [2, 4, bittensor.__vocab_size__], + [2, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [2, 70]] def test_receptor_neuron_request_empty(): x = torch.tensor([]) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TEXT, timeout=1) - assert ops == bittensor.proto.ReturnCode.EmptyRequest - assert list(out.shape) == [0] + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.EmptyRequest]*len(synapses) + assert [list(o.shape) for o in out] == [[0]]*len(synapses) def test_receptor_neuron_mock_server(): - y = torch.rand(3, 3, bittensor.__network_dim__) + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, 
from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], return_code = bittensor.proto.ReturnCode.Success, - tensors = [y_serialized]) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + tensors=[y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + print([list(o.shape) for o in out]) + assert ops == [bittensor.proto.ReturnCode.Success] * len(synapses) + assert [list(o.shape) for o in out] == [[3, 3, bittensor.__network_dim__], + [3, 3, bittensor.__vocab_size__], + [3, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [3, 70]] def test_receptor_neuron_serve_timeout(): - y = torch.rand(3, 3, bittensor.__network_dim__) + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, - return_code = bittensor.proto.ReturnCode.Timeout, - tensors = [y_serialized]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) - receptor.stub = stub - - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Timeout - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - - -def test_receptor_neuron_serve_empty(): - mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = wallet.hotkey.ss58_address, - return_code = bittensor.proto.ReturnCode.Success, - tensors = []) + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Timeout, message= 'Timeout' ) for synapse in synapses], + tensors=[y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized], + return_code 
= bittensor.proto.ReturnCode.Timeout + ) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.EmptyResponse - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Timeout] * len(synapses) + assert [list(o.shape) for o in out] == [[3, 3, bittensor.__network_dim__], + [3, 3, bittensor.__vocab_size__], + [3, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [3, 70]] def test_receptor_neuron_mock_server_deserialization_error(): y = dict() # bad response mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], return_code = bittensor.proto.ReturnCode.Success, - tensors = [y]) + tensors=[y, y, y, y] + ) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.ResponseDeserializationException - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.ResponseDeserializationException] * len(synapses) + assert [list(o.shape) for o in out] == [[3, 3, bittensor.__network_dim__], + [3, 3, bittensor.__vocab_size__], + [3, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [3, 70]] def test_receptor_neuron_mock_server_shape_error(): - y = torch.rand(1, 3, bittensor.__network_dim__) + y = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, return_code = bittensor.proto.ReturnCode.Success, - tensors = [y_serialized]) + tensors = [y_serialized], + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], + ) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.ResponseShapeException - assert list(out.shape) == [3, 3, bittensor.__network_dim__] + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + print(ops, 
bittensor.proto.ReturnCode.ResponseShapeException) + assert ops == [bittensor.proto.ReturnCode.ResponseShapeException] * len(synapses) + assert [list(o.shape) for o in out] == [[3, 3, bittensor.__network_dim__], + [3, 3, bittensor.__vocab_size__], + [3, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1], + [3, 70]] def test_receptor_neuron_server_response_with_nans(): import numpy as np - y = torch.rand(3, 3, bittensor.__network_dim__) - y[0][0][0] = np.nan - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) + + y_hidden[0][0][0] = np.nan + y_causallm[0][0][0] = np.nan + y_causallmnext[0] = np.nan # unravel fails because demarcating probability is replaced by nan, ResponseDeserializationException + y_seq_2_seq[0][0] = np.nan + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, return_code = bittensor.proto.ReturnCode.Success, - tensors = [y_serialized]) + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success - assert out[0][0][0] == 0 + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, + bittensor.proto.ReturnCode.ResponseDeserializationException, bittensor.proto.ReturnCode.Success] + assert out[0][0][0][0] != np.nan + assert out[1][0][0][0] != np.nan + assert out[3][0][0] != np.nan # -- backwards testing -- def test_receptor_neuron_text_backward(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - grads = torch.ones((x.size(0),x.size(1),bittensor.__network_dim__)) - out, ops, time = receptor.backward( x,grads, bittensor.proto.Modality.TEXT, timeout=1) - print (out, ops, time) - assert ops == bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [2, 4, bittensor.__network_dim__] - -def test_receptor_neuron_image_backward(): - x = torch.tensor([ [ [ [ [ 1 ] ] ] ] ]) - out, ops, time = receptor.backward( x,x, bittensor.proto.Modality.IMAGE, timeout=1) - assert ops 
== bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [1, 1, bittensor.__network_dim__] - -def test_receptor_neuron_tensor_backward(): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward( x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Unavailable - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - -def test_receptor_neuron_request_empty_backward(): - x = torch.tensor([]) - out, ops, time = receptor.backward( x,x, bittensor.proto.Modality.TEXT, timeout=1) - assert ops == bittensor.proto.ReturnCode.EmptyRequest - assert list(out.shape) == [0] + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.Unavailable] * len(synapses) def test_receptor_neuron_grads_misshape(): x = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.long) - grads = torch.zeros([0]) - out, ops, time = receptor.backward( x,grads, bittensor.proto.Modality.TEXT, timeout=1) - print (out, ops, time) - assert ops == bittensor.proto.ReturnCode.EmptyRequest + grads = torch.zeros([0,1,2,3,4]) + out, ops, time = receptor.backward( synapses, x, [grads, grads, grads, grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.RequestSerializationException] * len(synapses) -def test_receptor_neuron_backward_empty_response(): - - mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = "0x" + wallet.hotkey.public_key.hex(), - return_code = bittensor.proto.ReturnCode.Success, - tensors = []) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) - - receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.EmptyResponse def test_receptor_neuron_mock_server_backward(): y = torch.rand(3, 3, bittensor.__network_dim__) - - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = "0x" + wallet.hotkey.public_key.hex(), return_code = bittensor.proto.ReturnCode.Success, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], tensors = [y_serialized]) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) - receptor.stub = stub - - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - - -def test_receptor_neuron_mock_server_deserialization_error_backward(): - y = dict() # bad response - 
mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = "0x" + wallet.hotkey.public_key.hex(), - return_code = bittensor.proto.ReturnCode.Success, - tensors = [y]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) - receptor.stub = stub - - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.ResponseDeserializationException - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - - -def test_receptor_neuron_mock_server_shape_error_backward(): - y = torch.rand(1, 3, bittensor.__network_dim__) - - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - - mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = "0x" + wallet.hotkey.public_key.hex(), - return_code = bittensor.proto.ReturnCode.Success, - tensors = [y_serialized]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) + stub.Backward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.ResponseShapeException - assert list(out.shape) == [3, 3, bittensor.__network_dim__] - -def test_receptor_neuron_server_response_with_nans_backward(): - import numpy as np - y = torch.rand(3, 3, bittensor.__network_dim__) - y[0][0][0] = np.nan - - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - - mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = "0x" + wallet.hotkey.public_key.hex(), - return_code = bittensor.proto.ReturnCode.Success, - tensors = [y_serialized]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) - receptor.stub = stub + x = torch.rand(3, 3) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.Success] * len(synapses) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success - assert out[0][0][0] == 0 # -- no return code -- def test_receptor_forward_no_return(): y = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - - mock_return_val = bittensor.proto.TensorMessage( - 
version = bittensor.__version_as_int__, - hotkey = wallet.hotkey.ss58_address, - tensors = [y_serialized]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) - receptor.stub = stub - - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.NoReturn - -def test_receptor_backward_no_return(): - y = torch.rand(3, 3, bittensor.__network_dim__) - - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, - tensors = [y_serialized]) + synapses = [synapse.serialize_to_wire_proto(message= 'NoReturn' ) for synapse in synapses], + tensors = [y_serialized] + ) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.NoReturn + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.NoReturn] * len(synapses) # -- no exception in response -- def test_receptor_forward_exception(): y = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, return_code = bittensor.proto.ReturnCode.UnknownException, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.UnknownException, message= 'Success' ) for synapse in synapses], tensors = [y_serialized]) - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Forward.future = MagicMock( return_value = future ) + stub.Forward = MagicMock( return_value = mock_return_val ) receptor.stub = stub - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException - -def test_receptor_backward_exception(): - y = torch.zeros(3, 3, bittensor.__network_dim__) - - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) - - mock_return_val = bittensor.proto.TensorMessage( - version = bittensor.__version_as_int__, - hotkey = wallet.hotkey.ss58_address, - return_code = bittensor.proto.ReturnCode.UnknownException, - tensors = [y_serialized]) - - future = asyncio.Future() - future.set_result(mock_return_val) - stub.Backward.future = MagicMock( return_value = future ) - receptor.stub = 
stub + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException # -- stub erorr -- def test_receptor_forward_stub_exception(): - def forward_break(): raise Exception('Mock') with mock.patch.object(receptor.stub, 'Forward', new=forward_break): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) def test_receptor_backward_stub_exception(): def backward_break(): raise Exception('Mock') with mock.patch.object(receptor.stub, 'Backward', new=backward_break): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException + x = torch.rand(3, 3) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) def test_receptor_forward_endpoint_exception(): @@ -502,9 +392,9 @@ def forward_break(): raise Exception('Mock') with mock.patch.object(bittensor.proto, 'TensorMessage', new=forward_break): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward(x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) def test_receptor_backward_endpoint_exception(): @@ -516,21 +406,38 @@ def backward_break(): raise Exception('Mock') with mock.patch.object(bittensor.proto, 'TensorMessage', new=backward_break): - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.UnknownException + x = torch.rand(3, 3) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) #-- axon receptor connection -- def test_axon_receptor_connection_forward_works(): - def forward(inputs_x:torch.FloatTensor): - return torch.zeros( [3, 3, bittensor.__network_dim__]) + def forward_generate( input, synapse, model_output = None): + return None, None, torch.zeros( [3, 70]) + + def 
forward_hidden_state( input, synapse, model_output = None): + return None, None, torch.zeros( [3, 3, bittensor.__network_dim__]) + + def forward_casual_lm( input, synapse, model_output = None): + return None, None, torch.zeros( [3, 3, bittensor.__vocab_size__]) + + def forward_casual_lm_next(input, synapse, model_output=None): + return None, None, torch.zeros([3, (synapse.topk + 1), 1 + 1]) + axon = bittensor.axon ( - forward_tensor= forward, port = 8081, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -549,21 +456,34 @@ def forward(inputs_x:torch.FloatTensor): wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Success] * len(synapses) axon.stop() def test_axon_receptor_connection_forward_unauthenticated(): - def forward(inputs_x:torch.FloatTensor): - return torch.zeros( [3, 3, bittensor.__network_dim__]) + def forward_generate( input, synapse, model_output = None ): + return None, None, torch.zeros( [3, 70]) + + def forward_hidden_state( input, synapse, model_output = None ): + return None, None, torch.zeros( [3, 3, bittensor.__network_dim__]) + + def forward_casual_lm( input, synapse, model_output = None ): + return None, None, torch.zeros( [3, 3, bittensor.__vocab_size__]) + + def forward_casual_lm_next(input, synapse, model_output=None): + return None, None, torch.zeros([3, (synapse.topk + 1), 1 + 1]) + axon = bittensor.axon ( - forward_tensor= forward, - port = 8082, + port = 8081, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -571,7 +491,7 @@ def forward(inputs_x:torch.FloatTensor): uid = 0, ip = '127.0.0.1', ip_type = 4, - port = 8082, + port = 8081, hotkey = wallet.hotkey.ss58_address, coldkey = wallet.coldkey.ss58_address, modality = 2 @@ -582,22 +502,34 @@ def forward(inputs_x:torch.FloatTensor): wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) + x = torch.rand(3, 3) receptor.sign = MagicMock( return_value='mock' ) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Unauthenticated + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Unauthenticated] * len(synapses) axon.stop() def 
test_axon_receptor_connection_backward_works(): - def backward( inputs_x:torch.FloatTensor, grads): - return torch.zeros( [ 3,3,bittensor.__network_dim__]) - + def forward_generate( input, synapse ): + return torch.zeros( [3, 70]) + + def forward_hidden_state( input, synapse ): + return torch.zeros( [3, 3, bittensor.__network_dim__]) + + def forward_casual_lm( input, synapse ): + return torch.zeros( [3, 3, bittensor.__vocab_size__]) + + def forward_casual_lm_next(input, synapse): + return torch.zeros([3, (synapse.topk + 1), 1 + 1]) + axon = bittensor.axon ( - backward_tensor = backward, - port = 8083, + port = 8082, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -605,7 +537,7 @@ def backward( inputs_x:torch.FloatTensor, grads): uid = 0, ip = '127.0.0.1', ip_type = 4, - port = 8083, + port = 8082, hotkey = wallet.hotkey.ss58_address, coldkey = wallet.coldkey.ss58_address, modality = 2 @@ -615,20 +547,38 @@ def backward( inputs_x:torch.FloatTensor, grads): endpoint = endpoint, wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Success + x = torch.rand(3, 3) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + assert ops == [bittensor.proto.ReturnCode.Success] * len(synapses) axon.stop() def test_axon_receptor_connection_backward_unauthenticated(): - def backward( inputs_x:torch.FloatTensor, grads): + def forward_generate( input, synapse ): + return torch.zeros( [3, 70]) + + def forward_hidden_state( input, synapse ): return torch.zeros( [3, 3, bittensor.__network_dim__]) + + def forward_casual_lm( input, synapse ): + return torch.zeros( [3, 3, bittensor.__vocab_size__]) + + def forward_casual_lm_next(input, synapse): + return torch.zeros([3, (synapse.topk + 1), 1 + 1]) + axon = bittensor.axon ( - backward_tensor= backward, port = 8090, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -647,20 +597,23 @@ def backward( inputs_x:torch.FloatTensor, grads): wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) + x = torch.rand(3, 3) + 
hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + receptor.sign = MagicMock( return_value='mock' ) - out, ops, time = receptor.backward( x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Unauthenticated + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + + assert ops == [bittensor.proto.ReturnCode.Unauthenticated] * len(synapses) axon.stop() ## --unimplemented error def test_axon_receptor_connection_forward_unimplemented(): - def forward( inputs_x:torch.FloatTensor): - return torch.zeros( [3, 3, bittensor.__network_dim__]) axon = bittensor.axon ( - forward_tensor= forward, - port = 8085, + port = 8081, ip = '127.0.0.1', wallet = wallet, ) @@ -671,7 +624,7 @@ def forward( inputs_x:torch.FloatTensor): uid = 0, ip = '127.0.0.1', ip_type = 4, - port = 8085, + port = 8081, hotkey = wallet.hotkey.ss58_address, coldkey = wallet.coldkey.ss58_address, modality = 2 @@ -683,57 +636,39 @@ def forward( inputs_x:torch.FloatTensor): ) x = torch.rand(3, 3) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TEXT, timeout=1) - assert ops == bittensor.proto.ReturnCode.NotImplemented + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.NotImplemented] * len(synapses) axon.stop() +## -- timeout error -def test_axon_receptor_connection_backward_unimplemented(): - def backward( inputs_x:torch.FloatTensor, grads): - return torch.zeros( [3, 3, bittensor.__network_dim__]) - axon = bittensor.axon ( - backward_tensor= backward, - port = 8086, - ip = '127.0.0.1', - wallet = wallet, - ) - axon.start() - endpoint = bittensor.endpoint( - version = bittensor.__version_as_int__, - uid = 0, - ip = '127.0.0.1', - ip_type = 4, - port = 8086, - hotkey = wallet.hotkey.ss58_address, - coldkey = wallet.coldkey.ss58_address, - modality = 2 - ) +def test_axon_receptor_connection_forward_timeout(): - receptor = bittensor.receptor ( - endpoint = endpoint, - wallet = wallet, - ) + def forward_generate( inputs, synapse, model_output = None): + clock.sleep(5) + raise TimeoutError('Timeout') - x = torch.rand(3, 3) - grads = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward( x,grads, bittensor.proto.Modality.TEXT, timeout=1) - assert ops == bittensor.proto.ReturnCode.NotImplemented - axon.stop() + def forward_hidden_state( inputs, synapse, model_output = None ): + clock.sleep(5) + raise TimeoutError('Timeout') -## -- timeout error + def forward_casual_lm( inputs, synapse, model_output = None ): + clock.sleep(5) + raise TimeoutError('Timeout') + + def forward_casual_lm_next(inputs, synapse, model_output=None): + clock.sleep(5) + raise TimeoutError('Timeout') -def test_axon_receptor_connection_forward_timeout(): - def forward(inputs_x:torch.FloatTensor): - if inputs_x.size() == (1,1,1): - return None - else: - raise TimeoutError('Timeout') axon = bittensor.axon ( - forward_tensor= forward, - port = 8087, + port = 8085, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = 
bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -741,7 +676,7 @@ def forward(inputs_x:torch.FloatTensor): uid = 0, ip = '127.0.0.1', ip_type = 4, - port = 8087, + port = 8085, hotkey = wallet.hotkey.ss58_address, coldkey = wallet.coldkey.ss58_address, modality = 2 @@ -752,24 +687,37 @@ def forward(inputs_x:torch.FloatTensor): wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.forward( x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Timeout + x = torch.rand(3, 3) + out, ops, time = receptor.forward( synapses, x, timeout=1) + assert ops == [bittensor.proto.ReturnCode.Timeout] * len(synapses) axon.stop() def test_axon_receptor_connection_backward_timeout(): - def backward( inputs_x:torch.FloatTensor, grads): - if inputs_x.size() == (1,1,1): - return None - else: - raise TimeoutError('Timeout') - + def forward_generate( inputs, synapse ): + clock.sleep(5) + raise TimeoutError('Timeout') + + def forward_hidden_state( inputs, synapse ): + clock.sleep(5) + raise TimeoutError('Timeout') + + def forward_casual_lm( inputs, synapse ): + clock.sleep(5) + raise TimeoutError('Timeout') + + def forward_casual_lm_next(inputs, synapse): + clock.sleep(5) + raise TimeoutError('Timeout') + axon = bittensor.axon ( - backward_tensor = backward, port = 8088, ip = '127.0.0.1', wallet = wallet, ) + axon.attach_synapse_callback( forward_hidden_state, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE ) + axon.attach_synapse_callback( forward_generate, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_SEQ_2_SEQ ) + axon.attach_synapse_callback( forward_casual_lm, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM ) + axon.attach_synapse_callback(forward_casual_lm_next, synapse_type=bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT) axon.start() endpoint = bittensor.endpoint( @@ -787,11 +735,47 @@ def backward( inputs_x:torch.FloatTensor, grads): endpoint = endpoint, wallet = wallet, ) - x = torch.rand(3, 3, bittensor.__network_dim__) - out, ops, time = receptor.backward(x,x, bittensor.proto.Modality.TENSOR, timeout=1) - assert ops == bittensor.proto.ReturnCode.Timeout + + x = torch.rand(3, 3) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) + + assert ops == [bittensor.proto.ReturnCode.Timeout] * len(synapses) axon.stop() if __name__ == "__main__": - test_axon_receptor_connection_backward_unauthenticated() + # test_dummy_forward() + # test_dummy_backward() + # test_receptor_forward_request_serialize_error() + # test_receptor_backward_request_serialize_error() + # test_receptor_neuron_text() + # test_receptor_neuron_image() + # test_receptor_neuron_request_empty() + # test_receptor_neuron_mock_server() + # test_receptor_neuron_serve_timeout() + # test_axon_receptor_connection_backward_unauthenticated() + # 
test_receptor_neuron_mock_server_deserialization_error() + # test_receptor_neuron_mock_server_shape_error() + # test_receptor_neuron_server_response_with_nans() + # test_receptor_neuron_text_backward() + # test_receptor_neuron_grads_misshape() + # test_receptor_neuron_mock_server_deserialization_error_backward() + # test_receptor_neuron_backward_empty_response() + # test_receptor_forward_no_return() + # test_receptor_forward_exception() + # test_axon_receptor_connection_forward_works() + # test_receptor_neuron_mock_server() + # test_receptor_neuron_server_response_with_nans() + # test_axon_receptor_connection_forward_works() + # test_axon_receptor_connection_forward_unauthenticated() + # test_axon_receptor_connection_forward_timeout() + # test_axon_receptor_connection_backward_works() + # test_axon_receptor_connection_backward_unimplemented() + test_axon_receptor_connection_forward_works() + # test_receptor_neuron_mock_server() + # test_receptor_neuron_mock_server_backward() + # test_receptor_neuron_server_response_with_nans() diff --git a/tests/unit_tests/bittensor_tests/test_receptor_pool.py b/tests/unit_tests/bittensor_tests/test_receptor_pool.py index 28417e159f..337be0923d 100644 --- a/tests/unit_tests/bittensor_tests/test_receptor_pool.py +++ b/tests/unit_tests/bittensor_tests/test_receptor_pool.py @@ -25,7 +25,7 @@ import unittest.mock as mock import asyncio -logging = bittensor.logging() +logging = bittensor.logging(debug = True) # --- Receptor Pool --- wallet = bittensor.wallet.mock() @@ -47,16 +47,30 @@ receptor_pool = bittensor.receptor_pool(wallet=wallet) +synapses = [ + bittensor.synapse.TextLastHiddenState(), + bittensor.synapse.TextCausalLM(), + bittensor.synapse.TextCausalLMNext(), + bittensor.synapse.TextSeq2Seq(num_to_generate=70) +] + def test_receptor_pool_forward(): endpoints = [neuron_obj] - x = torch.ones( (1,2,2) ) - resp1, _, _ = receptor_pool.forward( endpoints, x, bittensor.proto.Modality.TENSOR, timeout=1) - assert list(torch.stack(resp1, dim=0).shape) == [1, 2, 2, bittensor.__network_dim__] + x = torch.ones( (1, 2 ,2) ) + resp1, _, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert list(resp1[0][0].shape) == [2, 2, bittensor.__network_dim__] + assert list(resp1[0][1].shape) == [2, 2, bittensor.__vocab_size__] + assert list(resp1[0][2].shape) == [2, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1] + assert list(resp1[0][3].shape) == [2, 70] def test_receptor_pool_backward(): endpoints = [neuron_obj] x = torch.ones( (1,2,2) ) - receptor_pool.backward( endpoints, x,x, bittensor.proto.Modality.TENSOR, timeout=1) + grads = [[torch.ones(2, 2, bittensor.__network_dim__), + torch.ones(2, 2, bittensor.__vocab_size__), + torch.ones(1, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1), + torch.tensor([])]] + receptor_pool.backward( endpoints, synapses, x, grads, timeout=1) def test_receptor_pool_max_workers_forward(): @@ -73,28 +87,190 @@ def test_receptor_pool_max_workers_forward(): receptor_pool = bittensor.receptor_pool(wallet=wallet,max_active_receptors=1) endpoints = [neuron_obj,neuron_obj2] x = torch.ones( (2,2,2) ) - resp1, _, _ = receptor_pool.forward( endpoints, x, bittensor.proto.Modality.TENSOR, timeout=1) - assert list(torch.stack(resp1, dim=0).shape) == [2, 2, 2, bittensor.__network_dim__] + resp1, _, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert list(resp1[0][0].shape) == [2, 2, bittensor.__network_dim__] + assert list(resp1[0][1].shape) == [2, 2, bittensor.__vocab_size__] + assert 
list(resp1[0][2].shape) == [2, (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1] + assert list(resp1[0][3].shape) == [2, 70] -def test_receptor_pool_forward_hang(): +def test_receptor_pool_forward_success(): endpoints = [neuron_obj,neuron_obj] - x = torch.ones( (2,2,2) ) - y = torch.rand(3, 3, bittensor.__network_dim__) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) - y_serialized = serializer.serialize(y, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) mock_return_val = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], + return_code = bittensor.proto.ReturnCode.Success, + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [[bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success], + [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success]] + +def test_receptor_pool_forward_timeout(): + endpoints = [neuron_obj,neuron_obj] + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + + mock_return_val = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Timeout, message= 'Timeout' ) for synapse in synapses], return_code = 
bittensor.proto.ReturnCode.Timeout, - tensors = []) + tensors=[y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + + + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [ + [bittensor.proto.ReturnCode.Timeout, bittensor.proto.ReturnCode.Timeout, bittensor.proto.ReturnCode.Timeout, + bittensor.proto.ReturnCode.Timeout], + [bittensor.proto.ReturnCode.Timeout, bittensor.proto.ReturnCode.Timeout, bittensor.proto.ReturnCode.Timeout, + bittensor.proto.ReturnCode.Timeout]] + +def test_receptor_pool_forward_num_synapse_mismatch(): + endpoints = [neuron_obj,neuron_obj] + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) - future = asyncio.Future() - future.set_result(mock_return_val) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + + mock_return_val = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Timeout' ) for synapse in synapses], + return_code = bittensor.proto.ReturnCode.Success, + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized] + ) + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) - receptor_pool.receptors[neuron_obj.hotkey].stub.Forward.future = MagicMock( return_value = future ) - resp1, codes, _ = receptor_pool.forward( endpoints, x, bittensor.proto.Modality.TENSOR, timeout=1) - assert codes == [bittensor.proto.ReturnCode.Timeout,bittensor.proto.ReturnCode.Timeout] + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [[bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException], + [bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException]] + +def test_receptor_pool_forward_response_partial_shape_error(): + endpoints = [neuron_obj,neuron_obj] + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(2, 70) + + serializer = bittensor.serializer( serializer_type = 
bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + + mock_return_val = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses], + return_code = bittensor.proto.ReturnCode.Success, + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [[bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.ResponseDeserializationException], + [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.ResponseDeserializationException]] + +def test_receptor_pool_partial_remote_success_return_code(): + endpoints = [neuron_obj,neuron_obj] + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(2, 70) + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + + mock_return_val = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses[:-1]] + + [synapses[-1].serialize_to_wire_proto(code = bittensor.proto.ReturnCode.UnknownException, message= 'UnknownException' )], + return_code = bittensor.proto.ReturnCode.Success, + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [[bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.UnknownException], + [bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, bittensor.proto.ReturnCode.Success, 
bittensor.proto.ReturnCode.UnknownException]] + +def test_receptor_pool_missing_synapse(): + endpoints = [neuron_obj,neuron_obj] + x = torch.ones( (2, 3, 3) ) + + y_hidden = torch.rand(3, 3, bittensor.__network_dim__) + y_causallm = torch.rand(3, 3, bittensor.__network_dim__) + y_causallmnext = bittensor.synapse.TextCausalLMNext().nill_forward_response_tensor(torch.ones(3), encoded=True) + y_seq_2_seq = torch.rand(3, 70) + + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) + y_hidden_serialized = serializer.serialize(y_hidden, from_type = bittensor.proto.TensorType.TORCH) + y_causallm_serialized = serializer.serialize(y_causallm, from_type = bittensor.proto.TensorType.TORCH) + y_causallmnext_serialized = serializer.serialize(y_causallmnext, from_type=bittensor.proto.TensorType.TORCH) + y_seq_2_seq_serialized = serializer.serialize(y_seq_2_seq, from_type = bittensor.proto.TensorType.TORCH) + + mock_return_val = bittensor.proto.TensorMessage( + version = bittensor.__version_as_int__, + hotkey = wallet.hotkey.ss58_address, + synapses = [synapse.serialize_to_wire_proto(code = bittensor.proto.ReturnCode.Success, message= 'Success' ) for synapse in synapses[:2]], + return_code = bittensor.proto.ReturnCode.Success, + tensors = [y_hidden_serialized, y_causallm_serialized, y_causallmnext_serialized, y_seq_2_seq_serialized] + ) + + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) + receptor_pool.receptors[neuron_obj.hotkey].stub.Forward = MagicMock( return_value = mock_return_val ) + resp1, codes, _ = receptor_pool.forward( endpoints, synapses, x, timeout=1) + assert codes == [[bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException], + [bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException, bittensor.proto.ReturnCode.ResponseShapeException]] def test_receptor_pool_backward_hang(): endpoints = [neuron_obj,neuron_obj] @@ -105,11 +281,16 @@ def test_receptor_pool_backward_hang(): return_code = bittensor.proto.ReturnCode.Timeout, tensors = []) - future = asyncio.Future() - future.set_result(mock_return_val) + hidden_grads = torch.ones((x.size(0), x.size(1), bittensor.__network_dim__)) + causal_grads = torch.ones((x.size(0), x.size(1), bittensor.__vocab_size__)) + causallmnext_grads = torch.ones((x.size(0), (bittensor.synapse.TextCausalLMNext().topk + 1), 1 + 1)) + seq_2_seq_grads = torch.tensor([]) + receptor_pool._get_or_create_receptor_for_endpoint(neuron_obj) - receptor_pool.receptors[neuron_obj.hotkey].stub.Backward.future = MagicMock( return_value = future ) - receptor_pool.backward( endpoints, x,x, bittensor.proto.Modality.TENSOR, timeout=1) + receptor_pool.receptors[neuron_obj.hotkey].stub.Backward = MagicMock( return_value = mock_return_val ) + receptor_pool.backward(endpoints, synapses, x, [[hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], + [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads]], timeout=1) if __name__ == "__main__": - test_receptor_pool_backward_hang() \ No newline at end of file + test_receptor_pool_missing_synapse() + pass \ No newline at end of file diff --git a/tests/unit_tests/bittensor_tests/test_serialization.py b/tests/unit_tests/bittensor_tests/test_serialization.py index 5fa52309c9..2b3a8b9f12 100644 --- 
a/tests/unit_tests/bittensor_tests/test_serialization.py +++ b/tests/unit_tests/bittensor_tests/test_serialization.py @@ -25,7 +25,7 @@ class TestSerialization(unittest.TestCase): def test_serialize(self): for _ in range(10): tensor_a = torch.rand([12, 23]) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) content = serializer.serialize(tensor_a, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) tensor_b = serializer.deserialize(content, to_type = bittensor.proto.TensorType.TORCH) torch.all(torch.eq(tensor_a, tensor_b)) @@ -34,14 +34,14 @@ def test_serialize_object_type_exception(self): # Let's grab a random image, and try and de-serialize it incorrectly. image = torch.ones( [1, 28, 28] ) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) with pytest.raises(bittensor.serializer.SerializationTypeNotImplementedException): serializer.serialize(image, modality = bittensor.proto.Modality.IMAGE, from_type = 11) def test_deserialization_object_type_exception(self): data = torch.rand([12, 23]) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) with pytest.raises(bittensor.serializer.SerializationTypeNotImplementedException): @@ -52,7 +52,7 @@ def test_serialize_deserialize_image(self): # Let's grab a random image, and give it a crazy type to break the system image = torch.ones( [1, 28, 28] ) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) serialized_image_tensor_message = serializer.serialize(image, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) assert image.requires_grad == serialized_image_tensor_message.requires_grad @@ -81,7 +81,7 @@ def test_serialize_deserialize_text(self): for i, ts in enumerate(ts_list): data[i, 0:ts.size()[0]] = ts - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) serialized_data_tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) assert data.requires_grad == serialized_data_tensor_message.requires_grad @@ -100,7 +100,7 @@ def test_serialize_deserialize_text(self): def test_serialize_deserialize_tensor(self): data = torch.rand([12, 23]) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.MSGPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.MSGPACK ) serialized_tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) assert data.requires_grad == serialized_tensor_message.requires_grad @@ -126,7 +126,7 @@ class TestCMPSerialization(unittest.TestCase): def test_serialize(self): for _ in range(10): tensor_a = torch.rand([12, 23]) - serializer = bittensor.serializer( serialzer_type = 
bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) content = serializer.serialize(tensor_a, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) tensor_b = serializer.deserialize(content, to_type = bittensor.proto.TensorType.TORCH) torch.all(torch.eq(tensor_a, tensor_b)) @@ -135,14 +135,14 @@ def test_serialize_object_type_exception(self): # Let's grab a random image, and try and de-serialize it incorrectly. image = torch.ones( [1, 28, 28] ) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) with pytest.raises(bittensor.serializer.SerializationTypeNotImplementedException): serializer.serialize(image, modality = bittensor.proto.Modality.IMAGE, from_type = 11) def test_deserialization_object_type_exception(self): data = torch.rand([12, 23]) - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) with pytest.raises(bittensor.serializer.SerializationTypeNotImplementedException): @@ -154,7 +154,7 @@ def test_serialize_deserialize_image(self): image = torch.ones( [1, 28, 28] ) data_size = image.element_size()*image.nelement() - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) serialized_image_tensor_message = serializer.serialize(image, modality = bittensor.proto.Modality.IMAGE, from_type = bittensor.proto.TensorType.TORCH) assert image.requires_grad == serialized_image_tensor_message.requires_grad @@ -185,7 +185,7 @@ def test_serialize_deserialize_text(self): data[i, 0:ts.size()[0]] = ts data_size = data.element_size()*data.nelement() - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) serialized_data_tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TEXT, from_type = bittensor.proto.TensorType.TORCH) assert data.requires_grad == serialized_data_tensor_message.requires_grad @@ -206,7 +206,7 @@ def test_serialize_deserialize_tensor(self): data = torch.rand([12, 23]) data_size = data.element_size()*data.nelement() - serializer = bittensor.serializer( serialzer_type = bittensor.proto.Serializer.CMPPACK ) + serializer = bittensor.serializer( serializer_type = bittensor.proto.Serializer.CMPPACK ) serialized_tensor_message = serializer.serialize(data, modality = bittensor.proto.Modality.TENSOR, from_type = bittensor.proto.TensorType.TORCH) assert data.requires_grad == serialized_tensor_message.requires_grad diff --git a/tests/unit_tests/bittensor_tests/test_wallet.py b/tests/unit_tests/bittensor_tests/test_wallet.py new file mode 100644 index 0000000000..6415966a5b --- /dev/null +++ b/tests/unit_tests/bittensor_tests/test_wallet.py @@ -0,0 +1,66 @@ +# The MIT License (MIT) +# Copyright © 2022 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without 
limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import unittest +from unittest.mock import patch +import pytest +import bittensor + +class TestWallet(unittest.TestCase): + def setUp(self): + self.mock_wallet = bittensor.wallet( _mock = True ) + + def test_regen_coldkeypub_from_ss58_addr(self): + ss58_address = "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm" + with patch.object(self.mock_wallet, 'set_coldkeypub') as mock_set_coldkeypub: + self.mock_wallet.regenerate_coldkeypub( ss58_address=ss58_address ) + + mock_set_coldkeypub.assert_called_once() + keypair: bittensor.Keypair = mock_set_coldkeypub.call_args_list[0][0][0] + self.assertEqual(keypair.ss58_address, ss58_address) + + ss58_address_bad = "5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zx" # 1 character short + with pytest.raises(ValueError): + self.mock_wallet.regenerate_coldkeypub(ss58_address=ss58_address_bad) + + def test_regen_coldkeypub_from_hex_pubkey_str(self): + pubkey_str = "0x32939b6abc4d81f02dff04d2b8d1d01cc8e71c5e4c7492e4fa6a238cdca3512f" + with patch.object(self.mock_wallet, 'set_coldkeypub') as mock_set_coldkeypub: + self.mock_wallet.regenerate_coldkeypub(public_key=pubkey_str) + + mock_set_coldkeypub.assert_called_once() + keypair: bittensor.Keypair = mock_set_coldkeypub.call_args_list[0][0][0] + self.assertEqual('0x' + keypair.public_key.hex(), pubkey_str) + + pubkey_str_bad = "0x32939b6abc4d81f02dff04d2b8d1d01cc8e71c5e4c7492e4fa6a238cdca3512" # 1 character short + with pytest.raises(ValueError): + self.mock_wallet.regenerate_coldkeypub(ss58_address=pubkey_str_bad) + + def test_regen_coldkeypub_from_hex_pubkey_bytes(self): + pubkey_str = "0x32939b6abc4d81f02dff04d2b8d1d01cc8e71c5e4c7492e4fa6a238cdca3512f" + pubkey_bytes = bytes.fromhex(pubkey_str[2:]) # Remove 0x from beginning + with patch.object(self.mock_wallet, 'set_coldkeypub') as mock_set_coldkeypub: + self.mock_wallet.regenerate_coldkeypub(public_key=pubkey_bytes) + + mock_set_coldkeypub.assert_called_once() + keypair: bittensor.Keypair = mock_set_coldkeypub.call_args_list[0][0][0] + self.assertEqual(keypair.public_key, pubkey_bytes) + + def test_regen_coldkeypub_no_pubkey(self): + with pytest.raises(ValueError): + # Must provide either public_key or ss58_address + self.mock_wallet.regenerate_coldkeypub(ss58_address=None, public_key=None) diff --git a/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.pt b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.pt new file mode 100644 index 0000000000..a2c194ce2c Binary files /dev/null and b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.pt differ diff --git a/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py new file 
mode 100644 index 0000000000..110b274132 --- /dev/null +++ b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py @@ -0,0 +1,596 @@ +""" Unit test for tokenizer utilities. +""" +# The MIT License (MIT) +# Copyright © 2021 Yuma Rao + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import bittensor + +from transformers import AutoTokenizer, AutoModelForCausalLM +from torch import nn +from bittensor.utils.tokenizer_utils import * + +EPSILON = 1e-40 +encodings_cache_file = "tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.pt" + +sample_text = {'English-1': ['''The Three Laws of Robotics (often shortened to The Three Laws or known as Asimov's Laws) are a set of rules devised by science fiction author Isaac Asimov. The rules were introduced in his 1942 short story "Runaround" (included in the 1950 collection I, Robot), although they had been foreshadowed in some earlier stories. The Three Laws, quoted from the "Handbook of Robotics, 56th Edition, 2058 A.D.", are:''', + + '''(Zeroth Law: A robot may not harm humanity, or, by inaction, allow humanity to come to harm.) +First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm. +Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law. +Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.''' + ], + + + 'German-1': ['''Die Drei Gesetze der Robotik (oft abgekürzt als Die Drei Gesetze oder bekannt als Asimovs Gesetze) sind eine reihe von regeln, die vom Science-Fiction-Autor Isaac Asimov entwickelt wurden. Die regeln wurden in seiner kurzgeschichte "Runaround" von 1942 (in der sammlung I, Robot von 1950 enthalten) eingeführt, obwohl sie in einigen früheren geschichten angedeutet worden waren. Die Drei Gesetze, zitiert aus dem "Handbook of Robotics, 56th Edition, 2058 A.D.", sind:''', + + '''(Nulltes Gesetz: Ein roboter darf der menschheit keinen schaden zufügen oder durch untätigkeit zulassen, dass der menschheit schaden zugefügt wird.) +Erstes Gesetz: Ein roboter darf einen menschen nicht verletzen oder durch untätigkeit zulassen, dass einem menschen schaden zugefügt wird. +Zweites Gesetz: Ein roboter muss den ihm von menschen erteilten befehlen gehorchen, es sei denn, solche befehle würden im widerspruch zum Ersten Gesetz stehen. 
+Drittes Gesetz: Ein roboter muss seine eigene existenz schützen, solange dieser schutz nicht im widerspruch zum Ersten oder Zweiten Gesetz steht.''' + ]} + + +def test_tokenizer_equivalence(): + r""" + Checks if two tokenizers are equivalent w.r.t. their vocabularies. + Equivalent tokenizers should always produce the same tokenization for the same text. + Returns: + Asserts expected result for list of tokenizer pairs. + """ + test_pairs = [('gpt2', 'gpt2', True), + ('gpt2', 'EleutherAI/gpt-neo-125M', True), + ('gpt2', 'EleutherAI/gpt-neo-2.7B', True), + ('gpt2', 'EleutherAI/gpt-j-6B', False), + ('gpt2', 'KoboldAI/fairseq-dense-2.7B', False), + ('gpt2', 'bert-base-uncased', False), + ('gpt2', 'xlnet-base-cased', False), + ('gpt2', 'facebook/xglm-564M', False), + ('gpt2', 'benjamin/gerpt2-large', False)] + + for target, to_check, expected_result in test_pairs: + tokenizer_to_check = AutoTokenizer.from_pretrained(to_check) + target_tokenizer = AutoTokenizer.from_pretrained(target) + assert check_tokenizer_equivalence(tokenizer_to_check, target_tokenizer) == expected_result + + +def get_loss_fct(logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor: + """ + Calculate loss_fct, CausalLM loss, next-token prediction loss. + Args: + logits (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len, bittensor.__network_dim__] + labels (:obj:`torch.LongTensor`, `required`): + [batch_size, sequence_len] + + Returns: + loss (:obj:`torch.FloatTensor`): + scalar + """ + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss + + +def encode_forward_response_tensor(forward_response_tensor: torch.Tensor, topk: int = 512) -> torch.FloatTensor: + """ Returns topk tokens/probabilities given unnormalized logits as input. """ + logits = forward_response_tensor # unnormalized logit scores: [batch_size, sequence_len, vocab_size] + probs = torch.softmax(logits, dim=-1) # normalized probabilities: [batch_size, sequence_len, vocab_size] + + values, indices = probs.sort(dim=-1, descending=True) # descend sort probs + topk_values = values[..., :topk] # topk probs: [batch_size, sequence_len, topk] + topk_indices = indices[..., :topk] # topk probs indices: [batch_size, sequence_len, topk] + encoded_probs = torch.cat((topk_values, topk_indices), dim=-1) # [batch_size, sequence_len, topk + topk] + + return encoded_probs # [batch_size, sequence_len, topk + topk] + + +def decode_forward_response_tensor(forward_response_tensor: torch.Tensor, + vocab_size=bittensor.__vocab_size__, topk: int = 512) -> torch.FloatTensor: + """ Returns full logits by decoding topk-encoding input. 
""" + batch_size, sequence_len, _ = forward_response_tensor.shape + encoded_probs = forward_response_tensor # encoded probabilities: [batch_size, sequence_len, topk + topk] + topk_values = encoded_probs[..., :topk] # topk probs: [batch_size, sequence_len, topk] + topk_indices = encoded_probs[..., topk:].long() # topk probs indices: [batch_size, sequence_len, topk] + + topk_pmass = topk_values.sum(dim=-1) # topk probability mass: [batch_size, sequence_len] + remainder_pmass = torch.clamp(1 - topk_pmass, 1e-40, 1) # remainder probability mass: [batch_size, sequence_len] + remainder_floor = remainder_pmass / (vocab_size - topk) # divide remainder: [batch_size, sequence_len] + + logits = torch.ones((batch_size, sequence_len, vocab_size)).to(topk_values.device) + logits *= torch.log(remainder_floor)[:, :, None] # set probability floor: [batch_size, sequence_len, vocab_size] + logits.scatter_(-1, topk_indices, torch.log(topk_values + 1e-40)) # insert topk probs: [batch_size, sequence_len, vocab_size] + + return logits # [batch_size, sequence_len, vocab_size] + + +def tokenizer_translation(text_batch: List[str], model_name: str, max_length: int, + enc_pre_logits: torch.FloatTensor = None, + device: str = 'cuda', topk: int = 512) -> Tuple[torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor]: + r""" + Emulates validator -> server -> validator interaction where the server-side logit translation + to standard token probabilities allow the validator to calculate standard loss without + having to know any server tokenizer/model/decoder particulars. + Topk encoding is only used to save the server model response to avoid CUDA-device requirement + when routinely running the unit test. + Args: + text_batch (:obj:`List[str]`, `required`): + Input text_batch to test tokenizer translation with. + model_name (:obj:`str`, `required`): + Name of transformer model to use as template server. + max_length (:obj:`int`, `required`): + Specific tokenization max length, small enough to prevent padding, + since GPT2 tokenization doesn't have padding. + enc_pre_logits (:obj:`torch.FloatTensor`, `optional`): + [batch_size, sequence_len, vocab_size] Encoded pre_logits from saved source, to + bypass server model forward call. + device (:obj:`str`, `optional`): + CUDA device for server model forward call. + topk (:obj:`int`, `optional`): + Amount of top logits to encode the server model pre_logits with (for saving purposes). + + Returns: + original_loss (:obj:`torch.FloatTensor`, `required`): + Original server model loss, before any encoding/compression. + encoded_loss (:obj:`torch.FloatTensor`, `required`): + Loss after server model logits have been topk encoded/compressed. + translated_loss (:obj:`torch.FloatTensor`, `required`): + Standard loss after logit translation to standard probabilities. + enc_pre_logits (:obj:`torch.FloatTensor`, `required`): + [batch_size, sequence_len, vocab_size] Encoded pre_logits. + """ + # ============================================= + # ==== Validator-side: CausalLM task setup ==== + # ============================================= + + std_tokenizer = AutoTokenizer.from_pretrained('gpt2') + std_tokenizer.pad_token = std_tokenizer.eos_token # Define PAD Token = EOS Token = 50256. https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + std_tokenizer.padding_side = "left" # Generative default expects most recent token on right-hand side with padding on left. 
https://github.com/huggingface/transformers/pull/10552 + + input_batch = std_tokenizer(text_batch, return_offsets_mapping=True, add_special_tokens=False, + max_length=max_length, truncation=True, return_tensors='pt') + + token_batch = input_batch['input_ids'] + + # ============================ + # ==== Server-side: Setup ==== + # ============================ + + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Generative default expects most recent token on right-hand side with padding on left. https://github.com/huggingface/transformers/pull/10552 + tokenizer.padding_side = "left" + + # Define PAD Token = EOS Token (GPT2 generate convention, when PAD Token is None) + # https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + if tokenizer.pad_token is None and tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + + to_translation_map = get_translation_map(tokenizer, std_tokenizer) + from_translation_map = get_translation_map(std_tokenizer, tokenizer) + split_map_cache = {} + + # ================================================ + # ==== Server-side: CausalLM task translation ==== + # ================================================ + + text_batch = std_tokenizer.batch_decode(token_batch) # decode tokens to original text + result = translate_special_token_text(text_batch, std_tokenizer, tokenizer) # translate special tokens + to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch = result + + tokens = tokenizer(to_text_batch, padding=True, truncation=True, return_tensors='pt', + add_special_tokens=False) # assume tokenizer.padding_side = 'left' + + # get offsets_mapping in tokenization to delineate token segment positions + server_tokens = tokenizer(to_text_batch, return_offsets_mapping=True, add_special_tokens=False) + std_tokens = std_tokenizer(text_batch, return_offsets_mapping=True) # encode again to get offsets mapping + + # pad offsets so that special token offset widths match for continued correct alignment + tokens['offset_mapping'] = pad_offsets(server_tokens['offset_mapping'], to_offsets_batch, pad_offsets_batch) + tokens['offset_mapping_std'] = pad_offsets(std_tokens['offset_mapping'], from_offsets_batch, pad_offsets_batch) + + # ============================================== + # ==== Server-side: CausalLM task execution ==== + # ============================================== + + original_loss = None + + if enc_pre_logits is None: + server_model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + if server_model.config.pad_token_id is None and server_model.config.eos_token_id is not None: + server_model.config.pad_token_id = server_model.config.eos_token_id + + with torch.no_grad(): + token_batch = input_batch['input_ids'].to(device) + # transformer models like gerpt2 typically perform worse with left-side attention mask, so turning it off + pre_model_output = server_model(input_ids=tokens['input_ids'].to(device), + # attention_mask=tokens['attention_mask'].to(device), + output_hidden_states=True) + pre_logits = pre_model_output.logits + + original_loss = get_loss_fct(pre_logits.cpu(), tokens['input_ids']) + enc_pre_logits = encode_forward_response_tensor(pre_logits, topk=topk).cpu() + + dec_pre_logits = decode_forward_response_tensor(enc_pre_logits, len(tokenizer.vocab), topk=topk) + encoded_loss = get_loss_fct(dec_pre_logits, tokens['input_ids']) + + # ============================================ + # ==== Server-side: Tokenizer translation 
==== + # ============================================ + + with torch.no_grad(): + probs_std = translate_logits_to_probs_std(dec_pre_logits.cpu(), + tokens['offset_mapping'], tokens['offset_mapping_std'], + tokenizer, std_tokenizer, + split_map_cache, to_translation_map, from_translation_map, + tokens['input_ids'].cpu(), token_batch.cpu(), + skip_equivalent=False) + + logits_std = torch.log(probs_std + EPSILON) + translated_loss = get_loss_fct(logits_std, token_batch.cpu()) + + return original_loss, encoded_loss, translated_loss, enc_pre_logits + + +def test_tokenizer_translation(): + r""" + Unit test for tokenizer translation. + + Returns: + Asserts that tokenizer translation produces previous encoded and translated losses. + """ + test_pairs = [('English-1', 'EleutherAI/gpt-j-6B', 95), + ('English-1', 'benjamin/gerpt2-large', 95), + ('German-1', 'benjamin/gerpt2-large', 172)] + + try: + encodings = torch.load(encodings_cache_file) + + except FileNotFoundError as e: + print('FileNotFoundError: Server model results not yet saved to', encodings_cache_file) + raise + + # # === Run server models to obtain encoded logits === + # print('Will first run server models (requires CUDA)...') + # + # encodings = {} + # for text_name, model_name, max_length in test_pairs: + # result = tokenizer_translation(sample_text[text_name], model_name, max_length, topk=128) + # original_loss, encoded_loss, translated_loss, enc_pre_logits = result + # encodings[(text_name, model_name)] = (encoded_loss, translated_loss, enc_pre_logits) + # + # print(text_name, model_name, original_loss, encoded_loss, translated_loss) + # + # # English-1 EleutherAI/gpt-j-6B tensor(1.2530) tensor(1.3275) tensor(1.3275) + # # English-1 benjamin/gerpt2-large tensor(3.7216) tensor(4.1541) tensor(4.6420) + # # German-1 benjamin/gerpt2-large tensor(4.2805) tensor(4.4141) tensor(4.7391) + # + # torch.save(encodings, encodings_cache_file) + # encodings = torch.load(encodings_cache_file) + + # === Run token translations on saved encoded logits === + for text_name, model_name, max_length in test_pairs: + _encoded_loss, _translated_loss, _enc_pre_logits = encodings[(text_name, model_name)] + result = tokenizer_translation(sample_text[text_name], model_name, max_length, _enc_pre_logits, topk=128) + original_loss, encoded_loss, translated_loss, enc_pre_logits = result + + assert torch.isclose(encoded_loss, _encoded_loss, rtol=1e-2) + assert torch.isclose(translated_loss, _translated_loss, rtol=1e-2) + + +def tokenizer_topk_phrases(text_batch: List[str], model_name: str, max_length: int, + enc_pre_logits: torch.FloatTensor = None, + device: str = 'cuda', topk: int = 128): + r""" + Emulates validator -> server -> validator interaction where the server-side logits phrases are + standard tokenized to token sequences / phrase with associated probabilities. + This allows the validator to receive full server continuation possibilities consisting of multiple tokens + per phrase, and not just a single token, without having to know any server tokenizer/model/decoder particulars. + Topk logit encoding is only used to save the server model response to avoid CUDA-device requirement + when routinely running the unit test. + Args: + text_batch (:obj:`List[str]`, `required`): + Input text_batch to test tokenizer translation with. + model_name (:obj:`str`, `required`): + Name of transformer model to use as template server. 
+ max_length (:obj:`int`, `required`): + Specific tokenization max length, small enough to prevent padding, + since GPT2 tokenization doesn't have padding. + enc_pre_logits (:obj:`torch.FloatTensor`, `optional`): + [batch_size, sequence_len, vocab_size] Encoded pre_logits from saved source, to + bypass server model forward call. + device (:obj:`str`, `optional`): + CUDA device for server model forward call. + topk (:obj:`int`, `optional`): + Amount of top logits to encode the server model pre_logits with (for saving purposes). + + Returns: + + """ + # ============================================= + # ==== Validator-side: CausalLM task setup ==== + # ============================================= + + std_tokenizer = AutoTokenizer.from_pretrained('gpt2') + std_tokenizer.pad_token = std_tokenizer.eos_token # Define PAD Token = EOS Token = 50256. https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + std_tokenizer.padding_side = "left" # Generative default expects most recent token on right-hand side with padding on left. https://github.com/huggingface/transformers/pull/10552 + + input_batch = std_tokenizer(text_batch, return_offsets_mapping=True, add_special_tokens=False, + max_length=max_length, truncation=True, return_tensors='pt') + + token_batch = input_batch['input_ids'] + + # ============================ + # ==== Server-side: Setup ==== + # ============================ + + tokenizer = AutoTokenizer.from_pretrained(model_name) + prep_tokenizer(tokenizer, std_tokenizer) + + # ================================================ + # ==== Server-side: CausalLM task translation ==== + # ================================================ + + text_batch = std_tokenizer.batch_decode(token_batch) # decode tokens to original text + result = translate_special_token_text(text_batch, std_tokenizer, tokenizer) # translate special tokens + to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch = result + + tokens = tokenizer(to_text_batch, padding=True, truncation=True, return_tensors='pt', + add_special_tokens=False) # assume tokenizer.padding_side = 'left' + + # get offsets_mapping in tokenization to delineate token segment positions + server_tokens = tokenizer(to_text_batch, return_offsets_mapping=True, add_special_tokens=False) + std_tokens = std_tokenizer(text_batch, return_offsets_mapping=True) # encode again to get offsets mapping + + # pad offsets so that special token offset widths match for continued correct alignment + tokens['offset_mapping'] = pad_offsets(server_tokens['offset_mapping'], to_offsets_batch, pad_offsets_batch) + tokens['offset_mapping_std'] = pad_offsets(std_tokens['offset_mapping'], from_offsets_batch, pad_offsets_batch) + + # ============================================== + # ==== Server-side: CausalLM task execution ==== + # ============================================== + + if enc_pre_logits is None: + server_model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + if server_model.config.pad_token_id is None and server_model.config.eos_token_id is not None: + server_model.config.pad_token_id = server_model.config.eos_token_id + + with torch.no_grad(): + # transformer models like gerpt2 typically perform worse with left-side attention mask, so turning it off + pre_model_output = server_model(input_ids=tokens['input_ids'].to(device), + # attention_mask=tokens['attention_mask'].to(device), + output_hidden_states=True) + pre_logits = pre_model_output.logits + + 
enc_pre_logits = encode_forward_response_tensor(pre_logits, topk=topk).cpu() + + dec_pre_logits = decode_forward_response_tensor(enc_pre_logits, len(tokenizer.vocab), topk=topk) + # dec_pre_logits.shape = [batch_size, sequence_len, vocab_size] + + last_logits = dec_pre_logits[:, -1, :] # last token predictions: [batch_size, vocab_size] + + _topk_tensor = topk_token_phrases(last_logits, tokenizer, topk=topk) # [batch_size, (topk + 1), max_len] + compact_topk = compact_topk_token_phrases(_topk_tensor) + # compact_topk: [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1) + + topk_tensor = unravel_topk_token_phrases(compact_topk, topk=topk) + + assert (_topk_tensor - topk_tensor).abs().sum() < 1e-9 + + +def test_topk_token_phrases(): + r""" + Unit test for topk token phrases raveling and unraveling. + + Returns: + Asserts that compact tensor of topk token phrases can be unraveled to recover original topk tensors. + """ + test_pairs = [('English-1', 'EleutherAI/gpt-j-6B', 95), + ('English-1', 'benjamin/gerpt2-large', 95), + ('German-1', 'benjamin/gerpt2-large', 172)] + + try: + encodings = torch.load(encodings_cache_file) + + except FileNotFoundError as e: + print('FileNotFoundError: Server model results not yet saved to', encodings_cache_file) + raise + + # # === Run server models to obtain encoded logits === + # print('Will first run server models (requires CUDA)...') + # + # encodings = {} + # for text_name, model_name, max_length in test_pairs: + # result = tokenizer_translation(sample_text[text_name], model_name, max_length, topk=128) + # original_loss, encoded_loss, translated_loss, enc_pre_logits = result + # encodings[(text_name, model_name)] = (encoded_loss, translated_loss, enc_pre_logits) + # + # print(text_name, model_name, original_loss, encoded_loss, translated_loss) + # + # torch.save(encodings, encodings_cache_file) + # encodings = torch.load(encodings_cache_file) + + # === Run test on saved encoded logits === + for text_name, model_name, max_length in test_pairs: + _encoded_loss, _translated_loss, _enc_pre_logits = encodings[(text_name, model_name)] + tokenizer_topk_phrases(sample_text[text_name], model_name, max_length, _enc_pre_logits, topk=128) + + +def topk_phrases_crossentropy(text_batch: List[str], model_name: str, max_length: int, + last_indices: List[int], + enc_pre_logits: torch.FloatTensor = None, + device: str = 'cpu', topk: int = 128): + r""" + Tests the phrase cross entropy calculation to support loss calculation not just for next token + but also for next phrase consisting of standard tokenized token sequence that should be matched. + Emulates validator -> server -> validator interaction where the server-side logits phrases are + standard tokenized to token sequences / phrase with associated probabilities. + This allows the validator to receive full server continuation possibilities consisting of multiple tokens + per phrase, and not just a single token, without having to know any server tokenizer/model/decoder particulars. + Topk logit encoding is only used to save the server model response to avoid CUDA-device requirement + when routinely running the unit test. + Args: + text_batch (:obj:`List[str]`, `required`): + Input text_batch to test tokenizer translation with. + model_name (:obj:`str`, `required`): + Name of transformer model to use as template server. + max_length (:obj:`int`, `required`): + Specific tokenization max length, small enough to prevent padding, + since GPT2 tokenization doesn't have padding. 
+ last_indices (:obj:`int`, `required`): + Sequence indices to use as last token indicator, with continuation forming target phrase. + enc_pre_logits (:obj:`torch.FloatTensor`, `optional`): + [batch_size, sequence_len, vocab_size] Encoded pre_logits from saved source, to + bypass server model forward call. + device (:obj:`str`, `optional`): + CUDA device for server model forward call. + topk (:obj:`int`, `optional`): + Amount of top logits to encode the server model pre_logits with (for saving purposes). + + Returns: + + """ + # ============================================= + # ==== Validator-side: CausalLM task setup ==== + # ============================================= + + std_tokenizer = AutoTokenizer.from_pretrained('gpt2') + std_tokenizer.pad_token = std_tokenizer.eos_token # Define PAD Token = EOS Token = 50256. https://github.com/huggingface/transformers/blob/49c8c67fb815a277405f84dea4a66353e19fb347/tests/models/gpt2/test_modeling_gpt2.py#L532 + std_tokenizer.padding_side = "left" # Generative default expects most recent token on right-hand side with padding on left. https://github.com/huggingface/transformers/pull/10552 + + input_batch = std_tokenizer(text_batch, return_offsets_mapping=True, add_special_tokens=False, + max_length=max_length, truncation=True, return_tensors='pt') + + token_batch = input_batch['input_ids'] + + # ============================ + # ==== Server-side: Setup ==== + # ============================ + + tokenizer = AutoTokenizer.from_pretrained(model_name) + prep_tokenizer(tokenizer, std_tokenizer) + + # ================================================ + # ==== Server-side: CausalLM task translation ==== + # ================================================ + + text_batch = std_tokenizer.batch_decode(token_batch) # decode tokens to original text + result = translate_special_token_text(text_batch, std_tokenizer, tokenizer) # translate special tokens + to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch = result + + tokens = tokenizer(to_text_batch, padding=True, truncation=True, return_tensors='pt', + add_special_tokens=False) # assume tokenizer.padding_side = 'left' + + # get offsets_mapping in tokenization to delineate token segment positions + server_tokens = tokenizer(to_text_batch, return_offsets_mapping=True, add_special_tokens=False) + std_tokens = std_tokenizer(text_batch, return_offsets_mapping=True) # encode again to get offsets mapping + + # pad offsets so that special token offset widths match for continued correct alignment + tokens['offset_mapping'] = pad_offsets(server_tokens['offset_mapping'], to_offsets_batch, pad_offsets_batch) + tokens['offset_mapping_std'] = pad_offsets(std_tokens['offset_mapping'], from_offsets_batch, pad_offsets_batch) + + # ============================================== + # ==== Server-side: CausalLM task execution ==== + # ============================================== + + if enc_pre_logits is None: + server_model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + if server_model.config.pad_token_id is None and server_model.config.eos_token_id is not None: + server_model.config.pad_token_id = server_model.config.eos_token_id + + with torch.no_grad(): + # transformer models like gerpt2 typically perform worse with left-side attention mask, so turning it off + pre_model_output = server_model(input_ids=tokens['input_ids'].to(device), + # attention_mask=tokens['attention_mask'].to(device), + output_hidden_states=True) + pre_logits = pre_model_output.logits + + enc_pre_logits = 
encode_forward_response_tensor(pre_logits, topk=topk).cpu() + + dec_pre_logits = decode_forward_response_tensor(enc_pre_logits, len(tokenizer.vocab), topk=topk) + # dec_pre_logits.shape = [batch_size, sequence_len, vocab_size] + + recorded_losses = [] + for last_idx in last_indices: + last_logits = dec_pre_logits[:, last_idx, :] # last token predictions: [batch_size] + target_phrases = tokenizer.batch_decode(tokens['input_ids'][:, last_idx+1:]) + target_phrases = std_tokenizer(target_phrases)['input_ids'] + + _topk_tensor = topk_token_phrases(last_logits, tokenizer, topk=topk) # [batch_size, (topk + 1), max_len] + compact_topk = compact_topk_token_phrases(_topk_tensor) + # compact_topk: [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1) + + topk_tensor = unravel_topk_token_phrases(compact_topk, topk=topk) + + assert (_topk_tensor - topk_tensor).abs().sum() < 1e-9 + + loss_val, loss = phrase_cross_entropy(target_phrases, topk_tensor) + recorded_losses += [loss.item()] + + return recorded_losses + + +def test_topk_phrases_crossentropy(): + r""" + Unit test for calculating topk token phrases cross entropy with target phrases. + + Returns: + Asserts that phrase cross entropy calculation yields previously observed value. + """ + test_pairs = [('German-1', 'benjamin/gerpt2-large', 172, list(range(50, 110, 5)), + [1.33, 4.07, 6.99, 5.11, 5.60, 2.30, 1.50, 1.51, 4.67, 9.75, 4.83, 3.28])] + + try: + encodings = torch.load(encodings_cache_file) + + except FileNotFoundError as e: + print('FileNotFoundError: Server model results not yet saved to', encodings_cache_file) + raise + + # # === Run server models to obtain encoded logits === + # print('Will first run server models (requires CUDA)...') + # + # encodings = {} + # for text_name, model_name, max_length in test_pairs: + # result = tokenizer_translation(sample_text[text_name], model_name, max_length, topk=128) + # original_loss, encoded_loss, translated_loss, enc_pre_logits = result + # encodings[(text_name, model_name)] = (encoded_loss, translated_loss, enc_pre_logits) + # + # print(text_name, model_name, original_loss, encoded_loss, translated_loss) + # + # torch.save(encodings, encodings_cache_file) + # encodings = torch.load(encodings_cache_file) + + # === Run test on saved encoded logits === + for text_name, model_name, max_length, last_indices, _recorded_losses in test_pairs: + _encoded_loss, _translated_loss, _enc_pre_logits = encodings[(text_name, model_name)] + recorded_losses = topk_phrases_crossentropy(sample_text[text_name], model_name, max_length, + last_indices, _enc_pre_logits, topk=128) + + recorded_losses = [round(r, 2) for r in recorded_losses] + # print(', '.join([f'{loss:.2f}' for loss in recorded_losses])) + assert _recorded_losses == recorded_losses + + +if __name__ == '__main__': + test_tokenizer_equivalence() + test_tokenizer_translation() + test_topk_token_phrases() + test_topk_phrases_crossentropy() diff --git a/tests/unit_tests/config_tests/template_server_sample_config.txt b/tests/unit_tests/config_tests/core_server_sample_config.txt similarity index 100% rename from tests/unit_tests/config_tests/template_server_sample_config.txt rename to tests/unit_tests/config_tests/core_server_sample_config.txt diff --git a/tests/unit_tests/config_tests/test_sample_config.py b/tests/unit_tests/config_tests/test_sample_config.py index b4aef5d2ae..9d87f59c84 100644 --- a/tests/unit_tests/config_tests/test_sample_config.py +++ b/tests/unit_tests/config_tests/test_sample_config.py @@ -1,21 +1,11 @@ 
import sys -from bittensor._neuron.text.template_miner import neuron as template_miner from bittensor._neuron.text.core_validator import neuron as core_validator -from bittensor._neuron.text.template_server import server as template_server -from bittensor._neuron.text.advanced_server import server as advanced_server +from bittensor._neuron.text.core_server import server as core_server -def test_run_template_miner_config(): +# TODO: Fix pathing issues in this file so it actually does something. +# These tests were not running on github actions, and most of them just work without reading the config files. - PATH = 'sample_configs/template_miner_sample_config.txt' - sys.argv = [sys.argv[0], '--config', PATH] - config = template_miner.config() - - assert config['axon']['ip'] == '[::]' - assert config['dataset']['data_dir'] == '~/.bittensor/data/' - assert config['dendrite']['requires_grad'] == True - - assert config['nucleus']['punishment'] == 0.001 def test_run_core_validator_config(): @@ -28,31 +18,16 @@ def test_run_core_validator_config(): assert config['logging']['logging_dir'] == '~/.bittensor/miners' assert config['neuron']['clip_gradients'] == 1.0 -def test_run_template_server_config(): +def test_run_core_server_config(): - PATH = 'sample_configs/template_server_sample_config.txt' + PATH = 'tests/unit_tests/config_tests/core_server_sample_config.txt' sys.argv = [sys.argv[0], '--config', PATH] - config = template_server.config() - + config = core_server.config() + assert config['axon']['backward_timeout'] == 20 assert config['dataset']['data_dir'] == '~/.bittensor/data/' assert config['logging']['debug'] == False assert config['wandb']['api_key'] == 'default' - -def test_run_advanced_server_config(): - - PATH = 'sample_configs/advanced_server_sample_config.txt' - sys.argv = [sys.argv[0], '--config', PATH] - config = advanced_server.config() - - assert config['axon']['backward_timeout'] == 20 - assert config['dataset']['data_dir'] == '~/.bittensor/data/' - assert config['logging']['debug'] == False - assert config['neuron']['blacklist']['stake']['backward'] == 100 - - if __name__ == "__main__": - test_run_template_miner_config() - test_run_template_server_config() - test_run_advanced_server_config() \ No newline at end of file + test_run_core_server_config() \ No newline at end of file
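
The new tokenizer-utils tests above all revolve around one idea: a server returns only its top-k token probabilities plus their indices, and the validator rebuilds full log-probabilities by spreading the leftover probability mass uniformly over the rest of the vocabulary before computing a shifted next-token loss. The stand-alone sketch below mirrors `encode_forward_response_tensor`, `decode_forward_response_tensor`, and `get_loss_fct` from the diff so the roundtrip can be exercised without loading any server model; `vocab_size`, `topk`, and the random inputs are illustrative stand-ins rather than values bittensor actually uses.

```python
# Stand-alone sketch (not part of the PR): topk logit compression roundtrip
# and the shifted next-token loss, mirroring the helpers in the new tests.
import torch
from torch import nn

def encode_topk(logits: torch.Tensor, topk: int) -> torch.Tensor:
    """Compress unnormalized logits to topk probabilities plus their indices."""
    probs = torch.softmax(logits, dim=-1)                          # [batch, seq, vocab]
    values, indices = probs.sort(dim=-1, descending=True)
    return torch.cat((values[..., :topk], indices[..., :topk].float()), dim=-1)  # [batch, seq, 2*topk]

def decode_topk(encoded: torch.Tensor, vocab_size: int, topk: int) -> torch.Tensor:
    """Rebuild full log-probabilities, spreading the leftover mass uniformly."""
    batch, seq, _ = encoded.shape
    topk_values = encoded[..., :topk]                              # [batch, seq, topk]
    topk_indices = encoded[..., topk:].long()                      # [batch, seq, topk]
    remainder = torch.clamp(1 - topk_values.sum(dim=-1), 1e-40, 1)
    floor = remainder / (vocab_size - topk)                        # uniform floor per position
    logits = torch.ones((batch, seq, vocab_size)) * torch.log(floor)[:, :, None]
    logits.scatter_(-1, topk_indices, torch.log(topk_values + 1e-40))
    return logits                                                  # [batch, seq, vocab]

def next_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Shift by one so position t predicts token t+1, as in get_loss_fct."""
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

if __name__ == '__main__':
    vocab_size, topk = 1000, 64
    raw_logits = torch.randn(2, 16, vocab_size)        # stand-in for server model logits
    tokens = torch.randint(0, vocab_size, (2, 16))     # stand-in for input_ids

    encoded = encode_topk(raw_logits, topk)            # what a server would send over the wire
    decoded = decode_topk(encoded, vocab_size, topk)   # what a validator reconstructs

    # The topk probabilities survive the roundtrip (up to the epsilon floor).
    before = torch.softmax(raw_logits, dim=-1).topk(topk, dim=-1).values
    after = torch.softmax(decoded, dim=-1).topk(topk, dim=-1).values
    assert torch.allclose(before, after, atol=1e-4)

    # Compare the next-token loss before and after compression.
    print('original loss:', next_token_loss(raw_logits, tokens).item())
    print('encoded  loss:', next_token_loss(decoded, tokens).item())
```

On real model outputs the compressed loss lands slightly above the original loss (the commented-out reference run in the diff records 1.2530 vs 1.3275 for `EleutherAI/gpt-j-6B` on `English-1`), and `test_tokenizer_translation` asserts that the cached encoded and translated losses are reproduced within a 1e-2 relative tolerance.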