# This notebook contains examples of how to use the finetune API.

In [None]:
"""In this example, we load and print the Hugging Face URL for the best-model in the SN9_MODEL competition."""

import finetune as ft
from competitions.data import CompetitionId

# Each model competes in a single competition. Find the best top performing miner UID for the 
# competition we care about (SN9_MODEL, in this example).
top_model_uid = ft.graph.best_uid(competition_id=CompetitionId.SN9_MODEL)

# Get the HuggingFace URL for this model.
repo_url = await ft.mining.get_repo(top_model_uid)
print(f"The best-performing model for SN9_MODEL competition is {repo_url}")

In [None]:
"""In this example, we load and print the actual metadata on the chain for the current best-model in the SN9_MODEL competition."""

# NOTE: Run the first cell to get the top model uid first.
import bittensor as bt

import constants
from model.storage.chain.chain_model_metadata_store import \
    ChainModelMetadataStore

# Create a subtensor to communicate to the main chain.
subtensor = bt.subtensor(network="finney")

# Find the hotkey of the current best uid from the metagraph.
metagraph = subtensor.metagraph(constants.SUBNET_UID, lite=True)
hotkey = metagraph.hotkeys[top_model_uid]

# Instantiate the store that handles reading/writing metadata to the chain.
metadata_store = ChainModelMetadataStore(subtensor)

# Fetch the metadata, parsing from the chain payload into the ModelMetadata class.
# NOTE: This may need to be retried due to transient chain failures.
metadata = await metadata_store.retrieve_model_metadata(hotkey)

# Breaking that down we have the following components.
print(f"HuggingFace namespace and repo name:  {metadata.id.namespace} and {metadata.id.name}")
print(f"Exact commit for that HuggingFace repo: {metadata.id.commit}")
print(f"ID of the competition this model is competing in: {metadata.id.competition_id.name}")
print(f"Secure Hash of the model directory and the hotkey of the miner: {metadata.id.secure_hash}")
print(f"block number that the metadata was committed to the chain: {metadata.block}")

In [None]:
"""In this example, we load the top model for the SN9_MODEL competition and converse with it."""

import bittensor as bt
import torch
from taoverse.model.competition import utils as competition_utils
from transformers import GenerationConfig

import constants
import finetune as ft
from competitions import utils as competition_utils
from competitions.data import CompetitionId

# The device to run the model on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download the top model to the specified directory.
download_dir = "./finetune-example"
model = await ft.mining.load_best_model(
        download_dir=download_dir, competition_id=CompetitionId.SN9_MODEL
    )

# Load the competition so we can load the right tokenizer.
metagraph = bt.metagraph(constants.SUBNET_UID)
competition = competition_utils.get_competition_for_block(CompetitionId.SN9_MODEL, metagraph.block)
tokenizer = ft.model.load_tokenizer(competition.constraints)

# Decide on a prompt.
prompt = "How much wood could a woodchuck chuck if a woodchuck could chuck wood?"

# Tokenize it.
conversation = [{"role": "user", "content": prompt}]
input_ids = tokenizer.apply_chat_template(
    conversation,
    truncation=True,
    return_tensors="pt",
    max_length=competition.constraints.sequence_length,
    add_generation_prompt=True,
)

# Generate the output.
# You may wish to customize the generation config.
generation_config = GenerationConfig(
    max_length=competition.constraints.sequence_length,
    do_sample=True,
    temperature=0.8,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
with torch.inference_mode():
    model.to(device)
    model.eval()
    input_ids = input_ids.to(device)
    output = model.generate(
        input_ids=input_ids, generation_config=generation_config
    )
    response = tokenizer.decode(
        output[0][len(input_ids[0]) :], skip_special_tokens=True
    )
    
    print(response)