In [None]:
import bittensor
import pandas as pd
import json
import substrateinterface as pysub

from typing import List, Tuple, Optional, Dict


In [None]:
# Every query below targets a specific historical block, so this presumably needs an
# archive endpoint rather than a pruned node.
WS_URL = "wss://archive.chain.opentensor.ai:443"
sub = bittensor.subtensor(WS_URL)

# Chain timing: BLOCK_TIME is seconds per block, so one hour spans 300 blocks.
BLOCK_TIME = 12
BLOCKS_PER_HOUR = (60 * 60) / BLOCK_TIME

In [None]:
# Block window under review.
START_BLOCK, END_BLOCK = (
    3_791_351,  # first upgrade block
    3_811_908,  # block of the NEW fix
)

In [None]:
# TODO: replace the placeholder with the owner's ss58 coldkey address.
# NOTE: this script assumes this *coldkey* was not swapped between START_BLOCK and END_BLOCK.
OWNER_KEY: str = "OWNER KEY"

In [None]:
# Fail fast: the placeholder is not an ss58 address. Passing it to the chain query
# would most likely only surface as an opaque address-encoding error.
if OWNER_KEY == "OWNER KEY":
    raise ValueError("OWNER_KEY is still the placeholder; set it to the owner's ss58 coldkey")

# Hotkeys owned by OWNER_KEY at each end of the window.
# Assume only ever ONE hotkey swap happened in between.
owned_hotkeys_start = sub.query_subtensor("OwnedHotkeys", START_BLOCK, params=[OWNER_KEY])
owned_hotkeys_end = sub.query_subtensor("OwnedHotkeys", END_BLOCK, params=[OWNER_KEY])

In [None]:
# Pick the start hotkey: the owned hotkey with the largest pending emission
# shortly after START_BLOCK (START_BLOCK + 10).
# NOTE: the "PendingdHotkeyEmission" spelling is used consistently in this notebook and
# presumably matches the chain's storage item name. Verify before "fixing" it.
all_w_pending = []
for hk in owned_hotkeys_start:
    pending_em = sub.query_subtensor("PendingdHotkeyEmission", START_BLOCK + 10, params=[hk.value])
    if pending_em:
        all_w_pending.append((hk.value, pending_em.value))

if not all_w_pending:
    # Without this guard the `[0][0]` lookup below fails with a bare IndexError.
    raise ValueError(f"No owned hotkey with pending emission found at block {START_BLOCK + 10}")

all_w_pending = sorted(all_w_pending, key=lambda x: x[1], reverse=True)
START_HOTKEY = all_w_pending[0][0]
print(f"Start hotkey: {START_HOTKEY}")

In [None]:
# Pick the end hotkey: the owned hotkey with the largest pending emission
# shortly before END_BLOCK (END_BLOCK - 120).
# NOTE: the "PendingdHotkeyEmission" spelling is used consistently in this notebook and
# presumably matches the chain's storage item name. Verify before "fixing" it.
all_w_pending = []
for hk in owned_hotkeys_end:
    pending_em = sub.query_subtensor("PendingdHotkeyEmission", END_BLOCK - 120, params=[hk.value])
    if pending_em:
        all_w_pending.append((hk.value, pending_em.value))

if not all_w_pending:
    # Without this guard the `[0][0]` lookup below fails with a bare IndexError.
    raise ValueError(f"No owned hotkey with pending emission found at block {END_BLOCK - 120}")

all_w_pending = sorted(all_w_pending, key=lambda x: x[1], reverse=True)
END_HOTKEY = all_w_pending[0][0]
print(f"End hotkey: {END_HOTKEY}")

In [None]:
# ss58 address derived from an all-zero 32-byte key: the "null" account whose stake
# on the hotkey is inspected in the cells below.
ZERO_KEY = bittensor.u8_key_to_ss58([0 for _ in range(32)])

In [None]:
# Find where swap happened
def find_swap(curr_hk: str) -> Optional[int]:
    """Walk back through `curr_hk`'s emission drains to locate the first tempo after the swap.

    Starting at END_BLOCK, look up `LastHotkeyEmissionDrain` for the hotkey, then check
    the null account's (ZERO_KEY) stake on it one block before that drain. The first drain
    encountered (walking backwards) at which that stake is 0 is returned. Otherwise the
    walk continues from the block just before that drain.

    Args:
        curr_hk: ss58 address of the hotkey to inspect.

    Returns:
        The drain block number, or None if the walk reaches START_BLOCK or runs out of
        recorded drains without finding one.
    """
    curr_block = END_BLOCK
    while curr_block > START_BLOCK:
        # Get last emission drain as of `curr_block`.
        last_emission_drain = sub.query_subtensor("LastHotkeyEmissionDrain", curr_block, params=[curr_hk])
        # A value of 0 means no drain was recorded. Querying Stake at block 0 - 1 would be
        # invalid, so treat 0 the same as a missing result.
        if not last_emission_drain or last_emission_drain.value == 0:
            print(f"No last emission drain found at block {curr_block}")
            break

        # Check whether the null account already had stake before this drain.
        zero_key_stake = sub.query_subtensor("Stake", last_emission_drain.value - 1, params=[curr_hk, ZERO_KEY])
        if zero_key_stake.value == 0:  # Swap happened before this tempo ran
            print(f"Hotkey {curr_hk} swapped AFTER block {curr_block}")
            return last_emission_drain.value  # This is first tempo after swap

        # Step back to just before this drain and repeat.
        curr_block = last_emission_drain.value - 1

    return None

In [None]:
target_hk = START_HOTKEY

swap_block = find_swap(target_hk)
if swap_block is None:
    print("No swap found")
    # In an IPython/Jupyter kernel `exit()` only requests a shutdown and does not stop
    # the running cell. Execution would continue and crash on `swap_block - 1`
    # (None - int). Raise instead, like the other checks in this cell.
    raise ValueError("No swap found")

print(f"Swap happened before block {swap_block}")

# Verify the swap had not yet occurred the block *before* the swap_block:
# the null account must not have stake on the hotkey yet.
null_stake_now = sub.query_subtensor("Stake", swap_block - 1, params=[target_hk, ZERO_KEY])
if null_stake_now.value > 0:
    print(f"Swap already happened at block {swap_block - 1}")
    raise ValueError("Swap already happened")

print(f"Swap not yet happened at block {swap_block - 1}")

# The emission accounting in the following cells is anchored on this block.
target_block = swap_block


In [None]:
 
# Split the hotkey's pending emission for the tempo ending at `target_block` among its
# stakers (coldkeys), pro rata to their stake at the start of that tempo.
TO_EMIT = {}    # coldkey -> owed emission (RAO; float after the pro-rata split)
EMISSIONS = {}  # block -> gross pending emission of the hotkey (RAO)

# Get the pending emission. Check for a missing result *before* reading `.value`;
# otherwise a None result fails with AttributeError instead of the message below.
pending_emission = sub.query_subtensor("PendingdHotkeyEmission", target_block - 1, params=[target_hk])
if not pending_emission:
    print(f"No pending emission found at block {target_block - 1}")
    raise Exception("No pending emission found")
EMISSIONS[target_block] = pending_emission.value

delegate_take = sub.query_subtensor("Delegates", target_block - 1, params=[target_hk])
if not delegate_take:
    print(f"No delegate found at block {target_block - 1}")
    raise Exception("No delegate found")

# Normalize to 1.0 by dividing by 2**16 - 1. Presumably the take is stored as a
# u16 fraction; confirm against the chain's Delegates storage.
TAKE = delegate_take.value / (2**16 - 1)

hk_take = pending_emission.value * TAKE
to_emit_from_pending = pending_emission.value - hk_take

# Find the start of this tempo: the hotkey's previous emission drain.
# A value of 0 means no drain was recorded, so fall back to START_BLOCK.
last_emission_drain = sub.query_subtensor("LastHotkeyEmissionDrain", target_block - 1, params=[target_hk]).value
if last_emission_drain == 0:
    last_emission_drain = START_BLOCK
start_of_tempo = last_emission_drain

# Stake of every coldkey on the hotkey at the start of the tempo.
stake_map = sub.query_map_subtensor("Stake", start_of_tempo, params=[target_hk])
stake_dict = {ck.value: st.value for ck, st in stake_map} if stake_map else {}
if not stake_dict:
    print(f"No stake map found at block {start_of_tempo}")
    raise Exception("No stake map found")

stake_sum = sum(stake_dict.values())
if stake_sum == 0:
    # Guard the pro-rata division below against ZeroDivisionError.
    print(f"Total stake is zero at block {start_of_tempo}")
    raise Exception("Total stake is zero")

TO_EMIT_THIS_TEMPO = {}
for ck, st in stake_dict.items():
    proportion = st / stake_sum
    owed_emission = proportion * to_emit_from_pending
    TO_EMIT_THIS_TEMPO[ck] = owed_emission  # tracked for the assertion below
    TO_EMIT[ck] = TO_EMIT.get(ck, 0) + owed_emission

# Verify that the sum of the emissions does not exceed the total to emit,
# allowing an epsilon of 1_000 RAO for float rounding.
assert sum(TO_EMIT_THIS_TEMPO.values()) <= to_emit_from_pending + 1_000


In [None]:

stake_to_distribute = sum(TO_EMIT.values())
emissions_since = sum(EMISSIONS.values())
# The owner may hold no stake on the hotkey. Default to 0 so the subtraction and the
# prints below do not fail on None.
owner_key_em = TO_EMIT.get(OWNER_KEY, 0)
# Whatever was not split among stakers is the validator (hotkey) take.
vali_take = emissions_since - stake_to_distribute

stake_to_distribute_no_owner = stake_to_distribute - owner_key_em

null_stake_after = sub.query_subtensor("Stake", target_block, params=[target_hk, ZERO_KEY])
if not null_stake_after:
    print(f"No null stake found at block {target_block}")
    raise Exception("No null stake found")

# The null account's stake must equal exactly the emissions accounted for above.
assert null_stake_after.value == emissions_since, f"Expected {emissions_since}, got {null_stake_after.value}"

# Avoid ZeroDivisionError when nothing was emitted.
vali_take_pct = vali_take / emissions_since * 100 if emissions_since else 0.0

print(f"Total emissions earned: {emissions_since/1e9} TAO")
print(f"Stake on the null account: {null_stake_after.value/1e9} TAO")
print(f"Validator take: {vali_take/1e9} TAO -> {vali_take_pct:.2f}%")
print(f"Earned by owner key: {owner_key_em/1e9} TAO")
print(f"Stake to distribute (minus owner key): {stake_to_distribute_no_owner/1e9} TAO")

In [None]:
# Add the vali take to the owner key's emission.
# Assign from `owner_key_em` (the owner's staker share, read before this cell) rather
# than using `+=`. Re-running this cell then no longer adds the take twice, and an
# owner with no stake on the hotkey no longer raises KeyError.
TO_EMIT[OWNER_KEY] = (owner_key_em or 0) + vali_take


In [None]:
# Build the CSV columns: one row per coldkey owed a non-zero amount, in TAO.
to_emit_formatted = dict(address=[], amount=[])
total_in_csv = 0
for coldkey, amount in TO_EMIT.items():
    in_tao = round(amount / 1e9, 9)  # RAO -> TAO, keeping 9 decimal places
    total_in_csv += in_tao
    if in_tao != 0:  # coldkeys owed nothing are left out of the file
        to_emit_formatted["address"].append(coldkey)
        to_emit_formatted["amount"].append(in_tao)
print(f"Total to emit recorded in CSV: {total_in_csv} TAO")

In [None]:
def to_csv(emission_map: Dict[str, List], filename: str):
    """Write column-oriented emission data to a CSV file, without an index column.

    Args:
        emission_map: Mapping of column name to a list of column values, e.g.
            ``{"address": [...], "amount": [...]}``. All lists must have the same
            length. A dict of scalar values is not accepted: ``DataFrame.from_dict``
            raises for it without an explicit index.
        filename: Destination path. An existing file is overwritten.
    """
    df = pd.DataFrame.from_dict(emission_map)
    df.to_csv(filename, index=False)

In [None]:
# Persist the per-coldkey TAO amounts for the migration.
to_csv(to_emit_formatted, filename="emit_map_migration.csv")

In [None]:
def to_json(emission_map: Dict[str, int], filename: str):
    """Serialize ``emission_map`` to ``filename`` as JSON.

    The file is opened in text write mode, so an existing file is replaced.
    """
    with open(filename, mode="w") as handle:
        json.dump(emission_map, fp=handle)

In [None]:
# Integer RAO per coldkey for the JSON map. Round straight to the nearest RAO rather
# than round-tripping through TAO (`round(v/1e9, 9) * 1e9`). That float product can land
# just below the integer, and int() then truncates it by 1 RAO.
# Entries that round to 0 RAO are dropped, matching the CSV, which skips zero amounts.
TO_EMIT_COPY = {}
for k, v in TO_EMIT.items():
    rao = int(round(v))
    if rao > 0:
        TO_EMIT_COPY[k] = rao

to_json(TO_EMIT_COPY, "emit_map_migration.json")
print("Done")