In [1]:
import bittensor
import torch
from tqdm import tqdm
from nuclei.gpt2 import GPT2Nucleus
from types import SimpleNamespace
from loguru import logger
bittensor.__debug_on__ = True
%load_ext autoreload

In [11]:
%autoreload 2

In [12]:

class Miner:
    """A minimal bittensor miner wired up for local, multi-node experiments.

    The miner serves its nucleus over an axon (RPC server) bound to
    ``endpoint.port`` and, when it has a ``child`` endpoint, queries that child
    through a dendrite (RPC client) while training.

    Args:
        dataset: Text dataloader that supplies training batches.
        endpoint: This miner's own endpoint; only ``endpoint.port`` is read here.
        child: Endpoint queried by ``route``; pass ``None`` for a leaf miner,
            which answers routing calls with zeros instead of calling out.
    """

    def __init__( self, dataset: bittensor.Dataloader, endpoint: bittensor.Endpoint, child: bittensor.Endpoint ):

        # Local info.
        self.endpoint = endpoint

        # Child to call forward on (None for a leaf miner).
        self.child = child

        # Text dataloader.
        self.dataset = dataset

        # Axon RPC server.
        # The forward and backward methods of this miner are attached to it, so
        # when this miner receives a Forward/Backward request it calls them.
        # NOTE(review): assuming the axon keeps the attached bound methods, this
        # forms a reference cycle (self -> axon -> bound methods -> self), so
        # __del__ may not run as soon as the last user-held name is deleted.
        self.axon = bittensor.axon( local_port = self.endpoint.port )
        self.axon.attach_forward_function( self.forward )
        self.axon.attach_backward_function( self.backward )

        # Dendrite RPC client.
        # Differentiable RPC function which calls Forward and Backward
        # on the endpoints passed to it.
        self.dendrite = bittensor.dendrite()

        # Torch NN module with remote_forward and local_forward functions.
        # The miner itself is registered as the routing object; the nucleus is
        # expected to call back into self.route (TODO confirm against GPT2Nucleus).
        self.nucleus = GPT2Nucleus()
        self.nucleus.attach_routing_function( self )

        # Base torch optimizer over the nucleus parameters.
        self.optimizer = torch.optim.AdamW( self.nucleus.parameters(), lr = 0.01, betas = (0.9, 0.95) )

    def route( self, inputs: torch.Tensor, query: torch.Tensor ) -> torch.Tensor:
        """Query the child with ``inputs`` and return its response (called by the nucleus).

        Args:
            inputs: Batch sent to the child; its first two dimensions
                (``inputs.shape[0]``, ``inputs.shape[1]``) size the leaf response.
                Presumably (batch, sequence) token ids — TODO confirm.
            query: Routing query from the nucleus; currently unused.

        Returns:
            The child's response tensor or, for a leaf miner (``child is None``),
            a zero tensor of shape
            ``[inputs.shape[0], inputs.shape[1], bittensor.__network_dim__]``.
        """
        # Leaf node: nobody to query, so answer with zeros of the expected shape.
        # (Previously this branch assigned `response` but returned `responses[0]`,
        # raising NameError for every leaf miner.)
        if self.child is None:
            return torch.zeros( [inputs.shape[0], inputs.shape[1], bittensor.__network_dim__] )

        # Otherwise make a differentiable call: the dendrite takes a list of
        # endpoints and a parallel list of inputs, sends each input to its endpoint,
        # and returns one response (plus a return code) per endpoint.
        responses, _return_codes = self.dendrite.forward_text(
            endpoints = [self.child],
            x = [inputs],
        )
        return responses[0]

    def forward( self, pubkey: str, inputs: torch.Tensor, modality: int ) -> torch.Tensor:
        """Handle a forward request received by this miner's axon.

        Runs the nucleus locally (the distillation path, without calling the
        child) and returns its last hidden layer.

        Args:
            pubkey: Public key of the caller (unused here).
            inputs: Request payload passed to ``nucleus.local_forward``.
            modality: Request modality (unused here).
        """
        output = self.nucleus.local_forward( inputs = inputs )
        return output.local_hidden

    def backward( self, pubkey: str, inputs_x: torch.Tensor, grads_dy: torch.Tensor, modality: int ) -> None:
        """Handle a backward request received by the axon. Disabled for now: always returns ``None``."""
        return None

    def start( self ):
        """Start the axon serving endpoint."""
        self.axon.start()

    def __del__( self ):
        """Tear down the axon serving endpoint when the miner is collected.

        Python also calls ``__del__`` on an object whose ``__init__`` raised,
        e.g. if ``bittensor.axon(...)`` itself failed. In that case
        ``self.axon`` was never assigned, so the lookup is guarded instead of
        raising AttributeError during cleanup.
        """
        axon = getattr( self, 'axon', None )
        if axon is not None:
            axon.stop()

    def epoch( self, epoch_length: int = 100 ):
        """Run a single training epoch over ``self.dataset``.

        For every batch the nucleus runs a remote forward pass, the three loss
        terms are summed and backpropagated, and one optimizer step is applied.

        Args:
            epoch_length: Passed straight to ``self.dataset.dataloader()``. The
                previously hard-coded value was 100. Presumably the amount of data
                per epoch — TODO confirm against bittensor.Dataloader.
        """
        # Drop any gradients left behind by an interrupted previous epoch
        # (e.g. a KeyboardInterrupt after backward() but before zero_grad()),
        # so they are not folded into this epoch's first step.
        self.optimizer.zero_grad()

        # ---- Next batch ----
        for inputs in self.dataset.dataloader( epoch_length ):

            # ---- Forward pass ----
            output = self.nucleus.remote_forward(
                inputs = inputs,
                training = True,
            )

            # ---- Backward pass ----
            output.loss = output.local_target_loss + output.distillation_loss + output.remote_target_loss
            output.loss.backward()      # Accumulates gradients on the nucleus.
            self.optimizer.step()       # Applies the accumulated gradients.
            self.optimizer.zero_grad()  # Zeros out gradients for the next accumulation.

        


In [13]:
# Text corpus pulled from IPFS (the files fetched through the gateway are listed in the log below).
# NOTE(review): the unit of max_corpus_size (bytes vs. characters) is defined by
# bittensor.dataloader — confirm before tuning.
MAX_CORPUS_SIZE = 1_000_000
dataset = bittensor.dataloader( max_corpus_size = MAX_CORPUS_SIZE )

INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:149 - Retrieving a dataset file from the IPFS gateway...
INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:168 - Added: naruda.txt
INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:168 - Added: self_reliance.txt
INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:168 - Added: pgp.txt
INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:168 - Added: animal_farm.txt
INFO    |bittensor._dataloader.dataloader_impl:construct_text_corpus:168 - Added: prophet.txt
[1mINFO   

In [14]:
# Two fake local bittensor endpoints, one per miner. They differ only in
# uid / hotkey / port; every other field is shared.
endpoint_A, endpoint_B = (
    bittensor.endpoint(
        uid = uid, hotkey = hotkey, ip = '0.0.0.0', ip_type = 4,
        port = port, modality = 0, coldkey = 'N/A',
    )
    for uid, hotkey, port in ( (0, '0', 8080), (1, '1', 8081) )
)

In [19]:
# Create miner A: serves on endpoint_A (port 8080) and uses endpoint_B as its child.
# When this cell is re-run, stop the previous axon explicitly before rebinding the name.
# `del` alone is not reliable: the old axon presumably holds bound methods of the old
# miner (attached in Miner.__init__), so Miner.__del__ may not run promptly and port
# 8080 could stay bound. NOTE(review): assumes axon.stop() tolerates a second call from
# Miner.__del__ if the old object is collected later — confirm.
if "miner_A" in globals():
    miner_A.axon.stop()
    del miner_A
miner_A = Miner( dataset = dataset, endpoint = endpoint_A, child = endpoint_B )
miner_A.start()

[32m[1mSUCCESS [0m|[36mbittensor._axon.axon_impl[0m:[36m_serve[0m:[36m465[0m - [32m[1mAxon is serving on: 127.0.0.1:8080[0m


In [20]:
# Create miner B: serves on endpoint_B (port 8081) as a leaf miner (child = None).
# When this cell is re-run, stop the previous axon explicitly before rebinding the name.
# `del` alone is not reliable: the old axon presumably holds bound methods of the old
# miner (attached in Miner.__init__), so Miner.__del__ may not run promptly and port
# 8081 could stay bound. NOTE(review): assumes axon.stop() tolerates a second call from
# Miner.__del__ if the old object is collected later — confirm.
if "miner_B" in globals():
    miner_B.axon.stop()
    del miner_B
miner_B = Miner( dataset = dataset, endpoint = endpoint_B, child = None )
miner_B.start()

[32m[1mSUCCESS [0m|[36mbittensor._axon.axon_impl[0m:[36m_serve[0m:[36m465[0m - [32m[1mAxon is serving on: 127.0.0.1:8081[0m


In [21]:
# Train miner A for one epoch. This runs synchronously in this cell; it is not a
# background job. Each batch goes through miner A's nucleus.remote_forward. Per the
# log below, this sends dendrite forward requests to miner B on port 8081.
# The recorded run was stopped by hand (KeyboardInterrupt) after miner B
# received its first forward request.
miner_A.epoch()

[34m[1mDEBUG   [0m|[36mbittensor._dendrite.dendrite_impl[0m:[36m_get_or_create_receptor_for_endpoint[0m:[36m266[0m - [34m[1m[37mCreate receptor for endpoint:[0m[34m[1m <endpoint uid: 1 hotkey: 1 ip: /ipv4/0.0.0.0:8081 modality: 0 coldkey: N/A>[0m
[32m[1mSUCCESS [0m|[36mbittensor._wallet.wallet_impl[0m:[36m_load_hotkey[0m:[36m242[0m - [32m[1mLoaded hotkey: [36m0x76a02b5371044ee1c7d031e5018e872acfcf6d3673f346a030e9bb1e167e621b[0m[32m[1m[0m
[34m[1mDEBUG   [0m|[36mbittensor._receptor.receptor_impl[0m:[36mforward[0m:[36m246[0m - [34m[1m[37mDendrite[0m[34m[1m [32mForward Request[0m[34m[1m ---> [37mto[0m[34m[1m:/ipv4/0.0.0.0:8081, [37minputs[0m[34m[1m:torch.Size([10, 20]), [37mmode[0m[34m[1m:0[0m
<bound method Miner.forward of <__main__.Miner object at 0x143279670>>
[34m[1mDEBUG   [0m|[36mbittensor._axon.axon_impl[0m:[36mForward[0m:[36m88[0m - [34m[1m-> Got Forward request: 0x76a02b5371044ee1c7d031e5018e872acfcf6d3673f

KeyboardInterrupt: 