# Instructions

Go through the cells in this notebook one by one. For each cell, read the options and descriptions, then press the play button to its left. You can skip the cell marked optional, but don't skip any of the others. After you run the "Play" cell, a small form appears underneath it, which you use to actually play.

To reset the state of your game, run the "Setup" cell again. Closing the notebook loses your progress, so if you want to keep your story, use the "history" action and copy the story into a text editor. You can also copy your author's note and memory from the output of the "info" action.
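
If you would rather write the story to a file than copy it by hand, a minimal sketch along these lines could be run in a separate cell. It assumes the `actions` list built by the "Play" cell, where each entry is a `(text, tokens)` pair; the helper name `export_story` and the default file name are made up for illustration and are not part of the notebook.

```python
from pathlib import Path


def export_story(actions, path="story.txt"):
    """Join the text half of every (text, tokens) action and write it to a file.

    `export_story` is a hypothetical helper, not part of the notebook.
    """
    # Only the text is needed; the token tensors are ignored.
    story = "".join(text for text, _tokens in actions)
    Path(path).write_text(story, encoding="utf-8")
    return story
```

With Google Drive mounted, passing a path under `/content/drive/MyDrive/` would keep the file after the Colab instance is gone.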

The most reliable way to load the models is to download them and store them in your Google Drive. If your Colab instance happens to be able to download from Mega, the model setup step can download the models automatically. If that succeeds, you can copy the model into your Drive in the optional step that follows. Otherwise, download the files and upload them to your Drive yourself.

* [gpt-neo-2.7B-horni](https://mega.nz/file/6BNykLJb#B6gxK3TnCKBpeOF1DJMXwaLc_gcTcqMS0Lhzr1SeJmc) (5 GB), for NSFW-styled output
* [gpt-neo-2.7B-horni-ln](https://mega.nz/file/rQcWCTZR#tCx3Ztf_PMe6OtfgI95KweFT5fFTcMm7Nx9Jly_0wpg) (5 GB), for light-novel-styled output
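
Since an upload of this size can be interrupted, it may be worth checking the archive in your Drive before running the model setup. The sketch below assumes the tar holds a `<model_name>/pytorch_model.bin` entry, which is the path the model setup cell loads after extraction. The helper name `archive_has_checkpoint` is hypothetical.

```python
import tarfile


def archive_has_checkpoint(tar_path, model_name):
    """Return True if the tar lists <model_name>/pytorch_model.bin.

    Hypothetical helper: it only inspects the member names and does not
    verify that the checkpoint data itself is intact.
    """
    wanted = model_name + "/pytorch_model.bin"
    with tarfile.open(tar_path, "r") as tar:
        names = tar.getnames()
    # Some archives store members with a leading "./"; treat both forms alike.
    normalized = [n[2:] if n.startswith("./") else n for n in names]
    return wanted in normalized
```

For example, `archive_has_checkpoint("/content/drive/MyDrive/gpt-neo-2.7B-horni.tar", "gpt-neo-2.7B-horni")` would be expected to return `True` for a complete upload.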

In [None]:
#@title Setup
#@markdown Run this for setting up dependencies or resetting actions
!pip install git+https://github.com/finetuneanon/transformers@gpt-neo-dungeon
!wget -c http://ftp.us.debian.org/debian/pool/main/m/megatools/megatools_1.11.0~git20200404-1_amd64.deb -O megatools.deb
!dpkg -i megatools.deb
!nvidia-smi

import os

from transformers import GPTNeoForCausalLM, AutoTokenizer
import tarfile
import codecs
import torch
import subprocess

from IPython.display import HTML, display
import ipywidgets as widgets

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

try:
  initialized += 1
except NameError:
  get_ipython().events.register('pre_run_cell', set_css)
  initialized = 0

actions = []
memory = ("", torch.zeros((1, 0)).long())
lmi = ["", torch.zeros((1, 0)).long()]
an = ("", torch.zeros((1, 0)).long())
an_depth = 3

In [None]:
#@title Model setup
#@markdown horni was finetuned for one epoch on about 800MB worth of random blocks of text from the one dataset distributed by EleutherAI that is excluded from the pile dataset. Do not use the horni model if you dislike NSFW outputs. horni-ln uses horni as a base and was finetuned for one epoch on 579MB of text from a light novel dataset.

model_name = "gpt-neo-2.7B-horni" #@param ["gpt-neo-2.7B-horni", "gpt-neo-2.7B-horni-ln", "EleutherAI/gpt-neo-2.7B"]
model_gdrive = "/content/drive/MyDrive/gpt-neo-2.7B-horni.tar" #@param {type:"string"}
use_gdrive = False #@param {type:"boolean"}
#@markdown If you get "chunk download failed" errors, the IP of your Colab instance is over the Mega quota. In that case, right-click the cell and select "interrupt execution", download the checkpoint from Mega yourself and upload it to your Google Drive. Then tick use_gdrive, set model_gdrive to the correct filename, e.g. `gpt-neo-2.7B-horni-ln.tar`, and run the cell again.
#@markdown
#@markdown If you get a crash here, you may need to upgrade to Colab Pro and switch the runtime to high memory

custom_models = ["gpt-neo-2.7B-horni", "gpt-neo-2.7B-horni-ln"]

if use_gdrive:
  from google.colab import drive
  drive.mount('/content/drive')

model_types = {"gpt-neo-2.7B-horni": "https://mega.nz/file/6BNykLJb#B6gxK3TnCKBpeOF1DJMXwaLc_gcTcqMS0Lhzr1SeJmc",
               "gpt-neo-2.7B-horni-ln": "https://mega.nz/file/rQcWCTZR#tCx3Ztf_PMe6OtfgI95KweFT5fFTcMm7Nx9Jly_0wpg"}

model = None
tokenizer = None
pipeline = None
checkpoint = None

if not os.path.isdir(model_name) and model_name in custom_models:
  if use_gdrive:
    tar = tarfile.open(model_gdrive, "r")
  else:
    model_url = model_types[model_name]
    print("Downloading:", model_url)
    !megadl $model_url --no-ask-password
    tar = tarfile.open(model_name + ".tar", "r")
  tar.extractall()
  tar.close()

if model_name in custom_models:
  checkpoint = torch.load(model_name + "/pytorch_model.bin", map_location="cuda:0")
  model = GPTNeoForCausalLM.from_pretrained(model_name, state_dict=checkpoint).half().to("cuda").eval()
  for k in list(checkpoint.keys()):
    del checkpoint[k]
  del checkpoint
else:
  from transformers.file_utils import cached_path, WEIGHTS_NAME, hf_bucket_url
  archive_file = hf_bucket_url(model_name, filename=WEIGHTS_NAME)
  resolved_archive_file = cached_path(archive_file)
  checkpoint = torch.load(resolved_archive_file, map_location="cuda:0")
  for k in checkpoint.keys():
    checkpoint[k] = checkpoint[k].half()
  model = GPTNeoForCausalLM.from_pretrained(model_name, state_dict=checkpoint).half().to("cuda").eval()
  for k in list(checkpoint.keys()):
    del checkpoint[k]
  del checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
#@title Copy downloaded model to google drive (optional)
#@markdown If the model checkpoint was downloaded from mega in the previous step, you can copy it to your google drive here for more reliable access in the future
gdrive_target = "/content/drive/MyDrive/gpt-neo-2.7B-horni.tar" #@param {type:"string"}
copy_model_file = False #@param {type:"boolean"}

if copy_model_file:
  from google.colab import drive
  drive.mount('/content/drive')
  model_tar = '/content/' + model_name + ".tar"
  !cp -v $model_tar $gdrive_target

In [None]:
#@title Sampling settings
#@markdown You can modify the sampling settings here. Don't forget to run the cell again after changing them. The number of generated tokens is subtracted from the context window size, so don't set it high.
top_k = 60 #@param {type:"number"}
top_p = 0.9 #@param {type:"number"}
temperature = 0.6 #@param {type:"number"}
number_generated_tokens = 40 #@param {type:"integer"}
repetition_penalty = 1.25 #@param {type:"number"}
repetition_penalty_range = 300 #@param {type:"number"}
repetition_penalty_slope = 3.33 #@param {type:"number"}
number_show_last_actions = 15 #@param {type:"integer"}

#@markdown Temperature seems to give results different from those in AIDG, so play around with it. Even 0.5 can give good results.
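
To see why a high `number_generated_tokens` is costly, the arithmetic from the "Play" cell's `assemble()` function can be reproduced on its own. The sketch below copies that formula; the function name `story_token_budget` and the example token counts for memory and author's note are illustrative assumptions.

```python
CONTEXT_BUDGET = 1850  # constant used by assemble() in the "Play" cell


def story_token_budget(number_generated_tokens, memory_tokens=0, authors_note_tokens=0):
    """Mirror the `remaining` value computed in assemble().

    Story actions are only included while their total token count stays
    below this number. The function name is hypothetical.
    """
    return (CONTEXT_BUDGET - number_generated_tokens + 1) - memory_tokens - authors_note_tokens


# Assumed example: 100 tokens of memory and a 20-token author's note.
for n in (40, 200, 800):
    print(n, story_token_budget(n, memory_tokens=100, authors_note_tokens=20))
```

Every extra generated token is one token less of story that the model gets to see.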

# Using gpt-neo dungeon's play function

If your prompt starts with a letter, try putting a space or newline in front.
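
The likely reason is that each prompt is appended to the story as-is, and GPT-2 style tokenizers (the notebook loads the `gpt2` tokenizer) generally encode a word together with the space before it. A prompt without a leading separator can therefore fuse with the previous word. A small sketch of the advice, using a hypothetical helper name:

```python
def with_leading_space(prompt):
    """Prepend a space when the prompt starts with a letter.

    Hypothetical helper illustrating the advice above; prompts that already
    start with whitespace, punctuation or a newline are returned unchanged.
    """
    if prompt and prompt[0].isalpha():
        return " " + prompt
    return prompt
```

Whether a space or a newline reads better depends on whether the prompt continues the current paragraph or starts a new one.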

* **generate** adds your prompt as an action and generates more output
* **continue** generates more output
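* **replace** replaces the last output with the prompt and generates more, use this to edit
* **undo** removes the last action
* **retry** removes the last action and generates a new output in its place
* **info** outputs LMI (the context most recently assembled for the model), memory, the author's note and its depth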
* **history** outputs all actions so far
* **memory** sets memory to the text in the prompt field
* **authorsnote** sets author's note to the text in the prompt field
* **andepth** sets the depth of the author's note to the number in the prompt
* **tokenize** tokenizes the text in the prompt field and outputs the number of tokens
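
How **memory**, **authorsnote** and **andepth** interact is easier to see on plain strings. The sketch below follows the placement logic of the "Play" cell's `assemble()` function, with two simplifications that are assumptions of this example: it works on text only (no token tensors), and it ignores the trimming of old actions that no longer fit the context. The function name `place_authors_note` is hypothetical.

```python
def place_authors_note(memory, actions, note, depth):
    """Build the model input from memory, story actions and an author's note.

    Simplified, text-only version of assemble(): memory comes first, and the
    note is inserted in front of the action that has `depth` actions after it.
    """
    n = len(actions)
    out = memory
    # With too few actions to reach the requested depth, the note goes
    # right after the memory, ahead of the whole story.
    pending = n - 1 < depth
    for i, text in enumerate(actions):
        if pending or n - i - 1 == depth:
            out += note
            pending = False
        out += text
    return out


print(place_authors_note("M|", ["a", "b", "c", "d", "e"], "[N]", 3))
```

A smaller depth puts the note closer to the end of the context. In this sketch, as in `assemble()`, no note is added while the story has no actions at all.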

In [None]:
#@title Play

action_type = "generate"
prompt = ""
need_refresh = True

action_types = ["generate", "continue", "replace", "undo", "retry", "memory", "authorsnote", "andepth", "info", "history", "tokenize"]

def assemble():
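  # Tokens left for story actions: a 1850-token budget (below the model's 2048-token window),
  # minus the tokens to generate, the memory and the author's note.
  remaining = (1850 - number_generated_tokens + 1) - memory[1].shape[1] - an[1].shape[1]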
  n_actions = len(actions)
  n_ctx = 0
  back_i = n_actions
  for i in range(n_actions):
      i_action = n_actions - i - 1
      n_tok = actions[i_action][1].shape[1]
      if remaining > n_ctx + n_tok:
        n_ctx += n_tok
        back_i = i_action
      else:
        break
  lmi[0], lmi[1] = memory[0], memory[1]
  start = False
  if n_actions - back_i - 1 < an_depth:
    start = True
  while back_i < n_actions:
    if start or n_actions - back_i - 1 == an_depth:
      lmi[0] += an[0]
      lmi[1] = torch.cat([lmi[1].cpu(), an[1].cpu()], 1).long()
      start = False
    lmi[0] += actions[back_i][0]
    lmi[1] = torch.cat([lmi[1].cpu(), actions[back_i][1].cpu()], 1).long()
    back_i += 1

def clear_output():
  with out:
    IPython.display.clear_output()

def set_action(change):
  global action_type
  action_type = change.new

def set_prompt(change):
  global prompt
  prompt = change.new

@torch.no_grad()
def play():
  global memory, need_refresh, an, an_depth
  action = action_type
  with out:
    if prompt in action_types:
      pass  # a prompt that is exactly an action name is ignored, not sent to the model
    else:
      if action == "replace":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "generate"
      if action == "generate":
        text = prompt
        if len(text) > 0:
          for line in text.splitlines(True):
            tokens = tokenizer(line, return_tensors="pt").input_ids.to("cpu")
            actions.append((line, tokens))
        action = "continue"
      if action == "info":
        clear_output()
        print("LMI: " + lmi[0])
        print("LMI tokens: " + str(lmi[1].shape[1]))
        print("Memory: " + memory[0])
        print("Author's note: " + an[0])
        print("Author's note depth: " + str(an_depth))
        need_refresh = True
      if action == "history":
        clear_output()
        print("".join([action[0] for action in actions]), end="")
        need_refresh = False
      if action == "retry":
        if len(actions) > 0:
          actions.pop()
        need_refresh = True
        action = "continue"
      if action == "undo":
        if len(actions) > 0:
          actions.pop()
        assemble()
        clear_output()
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        need_refresh = False
      if action == "memory":
        text = codecs.decode(prompt + "\n", "unicode-escape")
        tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
        memory = (text, tokens)
        clear_output()
        print("Memory: " + text)
      if action == "authorsnote":
        text = "\n[Author's note: " + codecs.decode(prompt, "unicode-escape") + "]\n"
        tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
        an = (text, tokens)
        clear_output()
        print("Author's note: " + text)
      if action == "andepth":
        clear_output()
        try:
          an_depth = int(codecs.decode(prompt + "\n", "unicode-escape"))
        except ValueError:  # prompt is not a valid integer: keep the current depth
          pass
        print("Author's note depth: " + str(an_depth))
      if action == "tokenize":
        text = codecs.decode(prompt, "unicode-escape")
        tokens = tokenizer(text, return_tensors="pt").input_ids.to("cpu")
        clear_output()
        print("Tokens: " + str(tokens.shape[1]))
        print(tokens[0])
        need_refresh = True
      if action == "continue":
        assemble()
        ids = lmi[1].cuda()
        n_ids = ids.shape[1]
        if n_ids < 1:
          n_ids = 1
          ids = torch.tensor([[tokenizer.eos_token_id]])
        max_length = number_generated_tokens + n_ids
        torch.cuda.empty_cache()
        gen_tokens = model.generate(
            ids.long().cuda(),
            do_sample=True,
            min_length=max_length,
            max_length=max_length,
            temperature=temperature,
            top_k = top_k,
            top_p = top_p,
            repetition_penalty = repetition_penalty,
            repetition_penalty_range = repetition_penalty_range,
            repetition_penalty_slope = repetition_penalty_slope,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id
        ).long()
        stop_tokens = [0, 13, 30, 526, 764, 1701, 2474, 5145, 5633]
        for i in reversed(range(len(gen_tokens[0]))):
          if i < n_ids:
            gen_tokens = gen_tokens[0]
            break
          if gen_tokens[0][i] in stop_tokens:
            gen_tokens = gen_tokens[0][:i+1]
            break
        gen_text = tokenizer.decode(gen_tokens[n_ids:])
        if len(gen_text) > 0:
          actions.append((gen_text, gen_tokens[n_ids:].unsqueeze(0).cpu()))
        clear_output()
        print("".join([action[0] for action in actions[-number_show_last_actions:]]), end="")
        torch.cuda.empty_cache()
        need_refresh = False

import ipywidgets as widgets
import IPython.display
out = widgets.Output(layout={'border': '1px solid black', "width": "1280px"})
dropdown = widgets.Dropdown(options=action_types, value=action_type, description='Action:', disabled=False)
dropdown.observe(set_action, 'value')
button = widgets.Button(description='>', disabled=False)
button.on_click(lambda _: play())
# Named input_box so that Python's built-in input() is not shadowed.
input_box = widgets.Textarea(value='', placeholder='', description='Input:', disabled=False, rows=4, layout={"width": "1280px"})
input_box.observe(set_prompt, 'value')

display(out, dropdown, button, input_box)