Welcome to the Language-Branch GPT-2 notebook! Here you can set a prompt and GPT-2 will generate many, many variations of how the sentence can go.

In [None]:
!pip install transformers

In [2]:
import json
import math
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [24]:
#@title Select your GPT model, you can find more gpt-2 models for other langs and stuff at huggingface.co
model_name = "distilgpt2" #@param ["gpt2", "distilgpt2", "gpt2-large", "gpt2-xl", "sshleifer/tiny-gpt2"] {allow-input: true}

In [25]:
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Load pre-trained model (weights)
model = GPT2LMHeadModel.from_pretrained(model_name)

In [26]:
# The prompt that the AI generates from 
starting_prompt = "The"

# Set this to a higher number if you want more leaves and less branches, vice versa.
# If you are using the tiny-gpt2 model, set this really close to 1
multiplier = 5

In [27]:
#@title Define the functions
parsed_prompt = starting_prompt.split(" ")
print(parsed_prompt)

forbidden_tokens = ["<|endoftext|>", "\n", "!", ".", "?"]

class Tree:
    def __init__(self, token, probability, depth=1, total_nodes=1):
        self.val = token
        self.depth = depth
        self.total_nodes = total_nodes
        self.probability = probability
        self.nodes = []
    def increase_node_count(self):
      self.total_nodes+=1

    def add_node(self, token, probability):
        self.nodes.append(Tree(token, probability))

    def __repr__(self):
        return f"Tree({self.val}, {self.probability}): {self.nodes}"


a = Tree(parsed_prompt[0], 100)

def create_tree_prompt(node, index):
  if index < len(parsed_prompt):
    node.add_node(" " + parsed_prompt[index], 100)
    a.increase_node_count()
    a.depth+=1
    index+=1
    create_tree_prompt(node.nodes[0], index)

create_tree_prompt(a, 1)

# Display the tree

def display_tree(node, level=0):
    print('\t' * level + (node.val + " - " + str(round(node.probability, 2)) + "%"))
    for child in node.nodes:
        display_tree(child, level + 1)


def return_token_and_probability(sentence):
    # Encode a text inputs
    text = "<|endoftext|>" + sentence
    indexed_tokens = tokenizer.encode(text)

    # Convert indexed tokens in a PyTorch tensor
    tokens_tensor = torch.tensor([indexed_tokens])

    # Set the model in evaluation mode to deactivate the DropOut modules
    model.eval()

    # If you have a GPU, put everything on cuda
    tokens_tensor = tokens_tensor.to('cuda')
    model.to('cuda')

    # Predict all tokens
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    # Get the predicted next sub-word

    predictions = torch.sort(predictions[0, -1, :], descending=True)
    # You can change the range but typically only the first 20ish values matter

    predicted_ids = predictions.indices[:100]
    predicted_probabilities = predictions.values[:100]
    #predicted_index = torch.argmax(predictions[0, -1, :]).item()
    predicted_texts = []
    for token, probability in zip(predicted_ids, predicted_probabilities):
      predicted_texts.append( (tokenizer.decode([token]), math.e ** probability.item()))

    total_prob = 0
    for log_prob in predicted_texts:
      total_prob += log_prob[1]
    for index, log_prob in enumerate(predicted_texts):
      predicted_texts[index] = (predicted_texts[index][0], (log_prob[1] / total_prob) * 100)

    # Print the predicted word
    return predicted_texts

multiplier = 5

def tree_last_layer_add(node, prompt = ''):
    prompt = f"{prompt}{node.val}"
    depth = len(prompt.split())
    if node.nodes == [] and depth == a.depth:
        to_add = return_token_and_probability(prompt)
        highest_probability = max([element[1] for element in to_add]) if to_add else 0
        for tup in to_add:
            if highest_probability < tup[1] * multiplier and tup[0] not in forbidden_tokens: #Change the 5 to lower numbers for less leaves
                node.add_node(tup[0], tup[1])
                a.increase_node_count()
    else:
        for child in node.nodes:
            tree_last_layer_add(child, prompt)

def tree_generate_n_layers(max_layers):
    for i in range(max_layers):
        if(i == 0):
          multiplier = 5
        else:
          multiplier = 2
        tree_last_layer_add(a)
        a.depth+=1
        print("Total nodes:", a.total_nodes)
        tree_save_as_json(a, ('sentence_tree' + str(a.depth) + ".json"))
def tree_save_as_json(tree, filename):
    with open(filename, 'w') as f:
        json.dump(to_dict(tree), f)

def to_dict(tree):
    if tree.nodes:
        return {"name": tree.val, "probability": round(tree.probability, 2), "size": 1, "children": [to_dict(n) for n in tree.nodes]}
    else:
        return {"name": tree.val, "probability": round(tree.probability, 2), "size": 1}

['The']


Run the generation! (it will save as sentence_tree(depth layers).json)

---

The higher the number for the input, the longer it will take, but it will have more levels.


In [28]:
tree_generate_n_layers(5)

Total nodes: 38
Total nodes: 319
Total nodes: 1385
Total nodes: 6045
Total nodes: 29528


Congratulations! It finished generating!
You now have several JSON files named "sentence_tree2.json" or "sentence_tree3.json" etc. If you would like to display the JSON in a cool way, check out https://codesandbox.io/s/charming-rubin-unfnm4 then change the "flare.json" file to your json (keep the filename)