# GPT Model Inference

Welcome! This notebook is a tutorial on how to load the model you've just trained on the Bittensor network and sample text from it.

In [1]:
import torch
import bittensor
from synapses.gpt2 import GPT2Synapse
from torch.nn import functional as F

## Load the trained model
You can find the model under `~/.bittensor/miners/gpt2-genesis/<miner_trial_id>/model.torch`, the default location where miners store their models. A new trial ID is generated each time you run your miner, so if you don't know yours, look for the most recent trial ID directory and use the `model.torch` inside it; that will be your latest run.
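
If you are unsure which trial directory is the latest, a small helper along these lines may help. This is only a sketch: `latest_trial_dir` is a hypothetical name, not part of Bittensor, and it assumes that "latest" means the trial whose `model.torch` was modified most recently.

```python
from pathlib import Path
from typing import Optional


def latest_trial_dir(miner_root: Path) -> Optional[Path]:
    """Return the trial directory whose model.torch was modified most recently.

    Hypothetical helper: assumes every trial lives in its own subdirectory of
    miner_root and that file modification time reflects the order of runs.
    """
    candidates = [
        d for d in miner_root.iterdir()
        if d.is_dir() and (d / "model.torch").is_file()
    ]
    if not candidates:
        return None
    # Compare checkpoints by the modification time of the saved model file
    return max(candidates, key=lambda d: (d / "model.torch").stat().st_mtime)


if __name__ == "__main__":
    # Default miner location mentioned above; adjust if your miner stores models elsewhere
    root = Path.home() / ".bittensor" / "miners" / "gpt2-genesis"
    if root.is_dir():
        print(latest_trial_dir(root))
```

The returned path can then be used as `model_path` in the next cell.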

In [2]:
model_path = '../../miners/gpt2-genesis/gpugpt04'

# Pick the device available on this machine, in case we're not loading the model on the same machine that trained it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("{}/model.torch".format(model_path), map_location=device)


# Let's load up a Bittensor config
config = bittensor.neuron.Neuron.default_config()
config = GPT2Synapse.default_config()

# Let's load up the same synapse config we trained our model with
config.synapse.n_head = 32
config.synapse.n_layer = 12
config.synapse.block_size = 20
config.synapse.device = device

# Load up the model
model = GPT2Synapse(config)
model.load_state_dict(checkpoint['model_state_dict'])
print("Combined loss (local, remote, and distilled) of preloaded model: {}:".format(checkpoint['loss']))
# Load up the huggingface tokenizer
tokenizer = bittensor.__tokenizer__()

Combined loss (local, remote, and distilled) of preloaded model: 8.745189072843083:


## Inference function
In essence, the GPT model works on token IDs produced by the HuggingFace tokenizer that Bittensor uses. To get text back, we sample token IDs from the model's output and decode them with the same tokenizer.
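
As a rough picture of the overall flow (token IDs in, predicted IDs appended one at a time, IDs decoded back to text), here is a toy sketch in plain Python. Everything in it is a hypothetical stand-in: `VOCAB`, `decode`, `toy_next_token` and `generate` are illustrative names, and the "model" simply predicts the next ID in the vocabulary rather than anything learned.

```python
from typing import Callable, List

# Hypothetical toy vocabulary; the real notebook uses the Bittensor tokenizer
VOCAB = ["The", " cat", " sat", " on", " the", " mat"]


def decode(ids: List[int]) -> str:
    """Map token IDs back to text (the role tokenizer.decode plays in the notebook)."""
    return "".join(VOCAB[i] for i in ids)


def toy_next_token(context: List[int]) -> int:
    """Hypothetical 'model': predicts the ID following the last one, wrapping around."""
    return (context[-1] + 1) % len(VOCAB)


def generate(ids: List[int], steps: int, block_size: int,
             next_token: Callable[[List[int]], int] = toy_next_token) -> List[int]:
    """Autoregressive loop: predict one token, append it, repeat."""
    ids = list(ids)
    for _ in range(steps):
        # The model only ever sees the last block_size tokens
        context = ids[-block_size:]
        ids.append(next_token(context))
    return ids


out = generate([0], steps=5, block_size=3)
print(out)          # [0, 1, 2, 3, 4, 5]
print(decode(out))  # The cat sat on the mat
```

The real sampling function follows the same loop, with the toy predictor replaced by a forward pass through the trained model.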

In [23]:
def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    Take a conditioning sequence of indices in x (of shape (b, t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Sampling has quadratic
    complexity in the sequence length, unlike an RNN, which is linear, and a finite context window
    of block_size, unlike an RNN, whose context window is unbounded.
    """
    block_size = model.get_block_size()-1
    model.eval()
    for k in range(steps):
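        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        # run the local model on the cropped context (x_cond, not the full x)
        output = model.local_forward(x_cond, training=False)
        # turn the hidden states into logits over the vocabulary
        logits = model.target_layer(output.local_hidden)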
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
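
To see what `temperature` and `top_k` do to the distribution before a token is drawn, the following plain-Python sketch applies the same two steps to a small list of made-up logits. It is an illustration only: `softmax` and `top_k_filter` are hypothetical helpers that imitate `F.softmax` and `top_k_logits` on a single row, and the logit values are invented.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax over temperature-scaled logits (one row only)."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(logits: List[float], k: int) -> List[float]:
    """Keep the k largest logits and set the rest to -inf, as top_k_logits does."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [l if l >= threshold else -math.inf for l in logits]


logits = [2.0, 1.0, 0.5, -1.0]  # made-up values for four candidate tokens
filtered = top_k_filter(logits, k=2)
print(filtered)  # [2.0, 1.0, -inf, -inf]

# Tokens outside the top k get probability 0 after the softmax
print([round(p, 3) for p in softmax(filtered)])                   # [0.731, 0.269, 0.0, 0.0]
# Lower temperature sharpens the distribution, higher temperature flattens it
print([round(p, 3) for p in softmax(filtered, temperature=0.5)])  # [0.881, 0.119, 0.0, 0.0]
print([round(p, 3) for p in softmax(filtered, temperature=2.0)])  # [0.622, 0.378, 0.0, 0.0]
```

With `sample=True`, the next token is drawn from these probabilities; with `sample=False`, the most likely token is always chosen, so `temperature` and `top_k` no longer change the result.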

In [31]:
context = "The cat "

# Tokenize the input
x = tokenizer(context, padding=True, truncation=True)['input_ids']
# Turn it into a tensor
x = torch.tensor(x, dtype=torch.long)
# Give it an extra dimension for the network's sake (expects a 2D tensor input)
x = x.unsqueeze(0)

# Let's sample the network for some output
y = sample(model, x, 15, temperature=1.0, sample=True, top_k=10)

# Decode the output
completion = ''.join([tokenizer.decode(i, skip_special_tokens=True) for i in y])

# Print what the model has predicted
print(completion)

The cat  up broke the window.The mouse was out for the this cheese.Did
