# How to use this notebook
This notebook allows you to generate text with a language model on Google's
spare resources, including older GPUs. GPU availability is not guaranteed, but
this notebook works without it; it will just be slow. Please use these compute resources responsibly, as people who have better things to do than make fun of Pigroach also rely on them.

To get started...
1.   Click the little down arrow next to "RAM" and "Disk" in the top right,
and click "Change runtime type" if electing to use a GPU. Select T4.
2.   Click the same arrow and choose "Connect to a hosted runtime".
3.   Run all the setup by clicking "Runtime > Run All" or pressing Ctrl+F9.
4.   You can repeatedly generate new text by changing the settings in the last cell and then rerunning it, either by clicking the Play arrow that appears when you mouse over the cell or by pressing Ctrl+Enter.
5.   When you're done, just close the tab. You'll need to redo this process when you come back.
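
Since a GPU is not guaranteed, it can help to check what the runtime actually provides before running everything. A minimal sketch, assuming PyTorch is available in the runtime (the `pick_device` helper name is hypothetical and not part of this notebook):

```python
def pick_device():
    """Return 'cuda' if PyTorch can see a GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed at all: fall back to CPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

print("Running on:", pick_device())
```

If this prints `cpu`, the notebook should still work, just more slowly.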

The text used to train this model includes...
*   Old SF Google group
*   Subset of SRK forums archive
*   KoH forum up to ~mid 2023
*   DSP's top-haters personal website
*   Subset of Discord leaks, including both the WWE Champions and Mod Discords.


The training set does not include...
*   Stream chat messages (too short)
*   Tweets (usually too short)
*   Any transcriptions of spoken word (stylistically different from written text)




In [1]:
!pip install transformers tqdm ipywidgets

from IPython.display import HTML, display
import ipywidgets as widgets

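def set_css(*args):
  # Wrap long lines in cell output instead of scrolling horizontally.
  # *args absorbs the info object that some IPython versions pass to this callback.
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)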



# First do some setup behind the scenes...
This cell downloads a pretrained model (~800MB for 410M, ~2.8GB for 1.4B) and configures it for text generation. 
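
The download sizes quoted above are roughly what two bytes per parameter would give. A back-of-the-envelope sketch, assuming the weights are stored as 16-bit floats (an assumption; the notebook does not state the storage format) and ignoring the tokenizer and config files:

```python
def approx_weights_gb(num_params, bytes_per_param=2):
    """Estimate checkpoint size in GB, taking 1 GB as 1e9 bytes."""
    return num_params * bytes_per_param / 1e9

for name, n in [("410M", 410e6), ("1.4B", 1.4e9)]:
    print(f"{name}: ~{approx_weights_gb(n):.2f} GB")
# prints:
# 410M: ~0.82 GB
# 1.4B: ~2.80 GB
```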

In [2]:
import transformers as tfs
import torch

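class StopAtTok(tfs.StoppingCriteria):
  """Stop generation once the most recent token matches the stop token."""
  def __init__(self, stoptok):
    # stoptok is the tokenizer output for the stop string (a newline here).
    self.stoptok = stoptok["input_ids"]

  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    # Compare the last generated token against the stop token ID.
    return input_ids.flatten()[-1] == self.stoptok.flatten()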

class Inference():
  def __init__(self, modelname, blocksize, device):
    print("Implementing direct capture...")
    self.device = device
    self.blocksize = blocksize
    self.modelname = modelname
    self.tokenizer = tfs.AutoTokenizer.from_pretrained(
      modelname,
    )

    self.tokenizer.pad_token = self.tokenizer.eos_token
    print("Checking that the camera's not on...")
    self.config = tfs.AutoConfig.from_pretrained(modelname)
    self.config.use_cache=True

    print("Setting up Green Screen...")
    self.model = tfs.AutoModelForCausalLM.from_pretrained(
      modelname,
      config=self.config
    ).to(self.device)
    self.model.eval()
    print("Loaded model.")


  def infer(self, opt):
    prompt = opt['prompt']

    # Tokenize the prompt and move the tensors to the model's device.
    inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True).to(self.device)
    prompt_len_tok = inputs['input_ids'].shape[-1]

    if opt.get('length') is not None:
        length = opt['length']
    else:
        length = self.blocksize

    print("Prompt has size {}, leaving {} tokens for generation".format(
      prompt_len_tok,
      length - prompt_len_tok
    ))

    generation_cfg = tfs.GenerationConfig(
      do_sample=True,
      eos_token_id=self.model.config.eos_token_id,
      bos_token_id=self.model.config.bos_token_id,
      pad_token_id=self.model.config.eos_token_id,
      use_cache=True,
      max_new_tokens=(length - prompt_len_tok),
      temperature=opt['temperature'],
      top_k=opt['top_k'],
      top_p=opt['top_p'],
      repetition_penalty=opt['repetition_penalty'],
      length_penalty=1.0,
      num_return_sequences=1
    )

    stopper = StopAtTok(self.tokenizer("\n", return_tensors='pt').to(self.device))

    output = []
    with torch.no_grad():
      # generate() returns token IDs (prompt plus continuation), not logits.
      token_ids = self.model.generate(
        **inputs,
        generation_config=generation_cfg,
        stopping_criteria=[stopper]
      )
      output = self.tokenizer.batch_decode(token_ids)

    return "\n".join(output)


In [4]:
model_obj = Inference(
  modelname = 'oddlyshapedfn/YouCompleteRe',
  # modelname = 'ycr-chat',
  blocksize = 512,
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
)

Implementing direct capture...
Checking that the camera's not on...
Setting up Green Screen...
[2023-09-02 02:11:57,060] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Loaded model.


# Use the cell below to generate text!
Here are some prompt ideas. Try to capture the tone of the response you want, but don't complete the sentence. You can change the prompt by replacing the text after `"prompt":` in the next cell. Include the `<YCR>:` tag in your prompt; it helps the model recognize that it should respond in the desired way.
```
<YCR>: I don't have time to explain how wrong you are, but
<YCR>: This weekend's tournament was an utter disappointment because
<YCR>: The thing my detractors don't get, and never will, is
```
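
If you build prompts programmatically, a small helper can make sure the tag is always present. This is only a sketch; `with_tag` is a hypothetical name and is not defined anywhere else in this notebook:

```python
TAG = "<YCR>:"

def with_tag(text):
    """Prefix text with the <YCR>: tag unless it already starts with it."""
    text = text.strip()
    if text.startswith(TAG):
        return text
    # An empty prompt becomes just the tag, letting the model start from scratch.
    return f"{TAG} {text}" if text else TAG

print(with_tag("This weekend's tournament was an utter disappointment because"))
```

The result can be pasted into the prompt box of the next cell.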

In [5]:
# Here are your knobs to influence the model output. Start with just the
# prompt. The initial settings allow for a lot of randomness, so turn down
# temperature and top_k if results become incoherent.
#
# After making your changes, rerun this cell to regenerate.
opts = {
    # Total token budget: prompt tokens plus generated tokens.
    "length": 200,
    # Set higher for more randomness. Don't go too far over 1.0
    "temperature": 0.9,
    # Set higher for less repetitive responses. Stay < 1.5 to avoid nonsensical generations.
    "repetition_penalty": 1.15,
    # Set higher for more random responses.
    "top_k": 75,
    # Set higher for more random responses. Value should be between 0 and 1.
    "top_p": 0.65,
    # The beginning of the text to complete. Start with "<YCR>:" for best results.
    "prompt": "<YCR>:"
}
@widgets.interact(prompt="<YCR>:")
def f(prompt):
    newopts = opts.copy()
    newopts['prompt'] = prompt
    print("INPUT: {}".format(prompt))
    print("=========================")
    response = model_obj.infer(newopts)
    print("RESPONSE: {}".format(response))

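
To build some intuition for the `temperature`, `top_k`, and `top_p` knobs above, here is a toy sketch of how such filters narrow down the next-token candidates. It is a simplified illustration in plain Python, not the `transformers` implementation; the function name, the toy vocabulary, and its scores are all made up for the example, and `repetition_penalty` is left out.

```python
import math

def filter_candidates(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token: probability} after temperature, top-k, and top-p filtering."""
    # Temperature: divide scores before softmax; <1 sharpens, >1 flattens.
    scaled = {tok: score / temperature for tok, score in logits.items()}
    # Top-k: keep only the k highest-scoring tokens (0 disables the filter).
    ranked = sorted(scaled.items(), key=lambda kv: kv[1], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]
    # Softmax over the survivors (subtract the peak for numerical stability).
    peak = ranked[0][1]
    weights = [(tok, math.exp(score - peak)) for tok, score in ranked]
    total = sum(w for _, w in weights)
    probs = [(tok, w / total) for tok, w in weights]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalize so the kept probabilities sum to 1.
    mass = sum(p for _, p in kept)
    return {tok: p / mass for tok, p in kept}

# Hypothetical scores for four candidate next tokens.
toy_logits = {"salt": 2.0, "detractors": 1.5, "tournament": 0.5, "pizza": -1.0}
print(filter_candidates(toy_logits, temperature=0.9, top_k=75, top_p=0.65))
```

With the default settings from the cell above, only the two most likely toy tokens survive the `top_p` cut. Raising `top_p` or `temperature` lets more of the tail through, which is why the output gets more random.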