Generator component for nostalgebraist-autoresponder.

Needs a fine-tuned GPT-2 model appropriate for use with the rest of the nostalgebraist-autoresponder codebase, and a running instance of the bridge service to talk to.

In [None]:
# make sure you have a GPU with ~16GB memory, for 1558M
# (the command below lists the GPUs visible to this runtime)

!nvidia-smi

In [None]:
# Base URL of the bridge service; poll() POSTs to its "/pollgenerator" endpoint.
BRIDGE_SERVICE_URL = ""  # fill yours in -- make sure it's accessible from wherever this is running
generator_url = BRIDGE_SERVICE_URL + "/pollgenerator"

In [None]:
# cell for Colab-specific stuff, if you're on Colab (skip it anywhere else)

# if you're using Google Colab, need to tell it not to use tf 2.x
%tensorflow_version 1.x

# if you're using Google Colab + Google Drive, mount and change dir
from google.colab import drive
drive.mount('/content/drive')

# the next cell does a relative `%cd "gpt-2"` from wherever this leaves us
%cd "/content/drive/My Drive/"

In [None]:
# i assume you're now somewhere with a directory called "gpt-2" holding my gpt-2 fork
# and your fine-tuned model
%cd "gpt-2"

import os, sys
# presumably where the fork's modules imported below (model, sample, encoder, ...) live -- TODO confirm
sys.path.append("src")

# install the fork's requirements into this kernel's environment
%pip install -r "requirements.txt"

In [None]:
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder
from load_dataset import load_dataset, Sampler

# ---------------------------------------------------------------------------
# Generator configuration.
#
# Each setting is assigned exactly once.  An earlier revision of this cell
# assigned batch_size / nsamples / length twice with different numbers
# (825/625 first, then 800/700); only the last assignment ever took effect,
# so those effective values are the ones kept here.
# ---------------------------------------------------------------------------

model_name = ""  # fill in -- should be a directory under /models
dataset = ""  # fill in -- should be a directory under /data

better_length = True

# sets max context size, for long prompts we want to cut off to allow bot to write at least this many tokens
required_continuation_room = 100  # 385 #500

# number of continuations sampled per forward pass
batch_size = 4
nsamples = batch_size

# passed to np.random.seed / tf.set_random_seed when the session is built
seed = None

# `length` is handed to sample.sample_sequence; when better_length is on, the
# prompt budget is computed as length - required_continuation_room
if better_length:
    length = 800
else:
    length = 700

EXPERIMENTAL_TOP_P = False
EXPERIMENTAL_MIDDLE_P_TWEAK = False

# sampling parameters forwarded to sample.sample_sequence
temperature = 0.95
top_k = 0
top_p = 0
middle_p = 0.85

In [None]:
enc = encoder.get_encoder(model_name)

# start from the default hyperparameters, then apply the model's own hparams.json on top
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

# set both dropout hyperparameters to 0
hparams.set_hparam("attn_dropout", 0)
hparams.set_hparam("res_dropout", 0)

# Always bind data_sampler: get_prompt_from_dataset() tests `data_sampler is None`
# to decide whether to build it lazily, which raised NameError when the name had
# never been assigned (i.e. whenever the `else` branch below ran).
data_sampler = None

if dataset is not None:
    # prompts are fed through a placeholder, so no fixed start token
    chunks = load_dataset(enc, dataset, 50000)
    data_sampler = Sampler(chunks)
    start_token = None
else:
    # no dataset: generation starts from the end-of-text token
    context_tokens = None
    start_token = enc.encoder['<|endoftext|>']

# default to the model's window size, and refuse anything longer than it
if length is None:
    length = hparams.n_ctx
elif length > hparams.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

In [None]:
import tflex

# `CPU` is not assigned anywhere earlier in this notebook.  Without a default,
# the `if CPU` test below raises NameError, which the blanket `except` swallows,
# so the loop prints "retrying..." forever and never loads the model.
# Define CPU = True in an earlier cell to use a plain tf.Session instead.
CPU = globals().get("CPU", False)

load_done = False

# Building the graph / restoring the checkpoint is retried until it succeeds.
# NOTE: every exception is retried, without a limit or a delay.
while not load_done:
    try:
        tf.reset_default_graph()
        sess = tf.Session() if CPU else tflex.Session()

        with sess.as_default():
            np.random.seed(seed)
            tf.set_random_seed(seed)

            # with no fixed start token, prompts are fed through this placeholder
            if start_token is None:
                context = tf.placeholder(tf.int32, [batch_size, None])
            else:
                context = None

            # drop the start token column from the output when one was used
            start_ix = 1 if start_token is not None else 0
            output = sample.sample_sequence(
                stop_at_EOT=True, better_length=better_length,
                enc=enc,
                hparams=hparams, length=length,
                start_token=start_token,
                context=context,
                batch_size=batch_size,
                temperature=temperature, top_k=top_k, top_p=top_p,
                middle_p=middle_p
            )[:, start_ix:]

            saver = tflex.Saver()
            ckpt = tflex.latest_checkpoint(os.path.join('models', model_name))
            print(f"restoring checkpoint: {ckpt}")
            saver.restore(sess, ckpt)
        load_done = True
    except Exception as e:
        print(f"encountered {e}, retrying...")

Business logic -- really this should be in a .py file somewhere

TODO: clean up this code. The switch to "V5" happened long ago, so the blocks that check whether we are in "V5" can be simplified to assume it.

In [None]:
import re
from textwrap import fill, wrap

# Single-character markers looked for in the generated text.
# Q_CHAR, A_CHAR, ORIG_POST_CHAR and UNAME_CHAR are treated as "control chars":
# a continuation containing one of them is considered finished / rejected / split there.
Q_CHAR = "会"
A_CHAR = "域"
# separates the post body from its tag text (see parse_continuation)
T_CHAR = "职"
# stripped from the start of a post; dataset prompts must begin with it
ORIG_POST_CHAR = "翰"
UNAME_CHAR = "友"

def get_prompted_continuation(prompt: str, 
                              continue_if_cut_off=False, 
                              max_continue_steps=12):
    """Sample one batch of continuations for `prompt` with the global session.

    Literal backslash-n sequences in the prompt are turned into newlines, the
    prompt is encoded, and if it exceeds the context budget only its last
    `max_context_size` tokens are kept.  The same context is fed for every row
    of the batch.

    Args:
        prompt: text to continue.
        continue_if_cut_off: when true, rows are re-fed (prompt plus everything
            generated so far) until every row contains "<|" or one of the
            control characters, or until the step budget runs out.
        max_continue_steps: maximum number of extra sampling rounds.

    Returns:
        A list of `batch_size` strings, without the prompt.  A text containing
        "<|endoftext|>" that does not already end in "<|" is cut at the first
        "<|endoftext|>" and terminated with "<|".
    """
    raw_text = re.sub(r"\\n", "\n", prompt)
    context_tokens = enc.encode(raw_text)

    if better_length:
        max_context_size = length - required_continuation_room
    else:
        max_context_size = hparams.n_ctx - length - 10
    if len(context_tokens) > max_context_size:
        orig_len = len(context_tokens)
        # keep the END of the prompt -- that is the part being continued
        context_tokens = context_tokens[-(max_context_size):]
        print(f"truncated {orig_len} to {len(context_tokens)}, max_context_size={max_context_size}")
    else:
        print(f"{len(context_tokens)} tokens can fit in max_context_size {max_context_size}")

    # sampled output is sliced from here so only new tokens are decoded
    token_start_ix = len(context_tokens)

    batch_context_tokens = [context_tokens for _ in range(batch_size)]
    # per row: [prompt, chunk 1, chunk 2, ...]
    continuations = [[prompt] for _ in batch_context_tokens]
    this_batch_continue_steps = 0
    control_chars = {Q_CHAR, A_CHAR, ORIG_POST_CHAR, UNAME_CHAR}

    done = False
    while not done:
        with sess.as_default():
            out = sess.run(output, feed_dict={
                context: batch_context_tokens
            })[:, token_start_ix:]
        for i in range(batch_size):
            continuations[i].append(enc.decode(out[i]))

        if continue_if_cut_off:
            next_prompts = ["".join(subtexts) for subtexts in continuations]
            # NOTE(review): rows can end up with different token counts here (each is
            # only truncated, never padded) while token_start_ix is taken from row 0;
            # callers fall back to continue_if_cut_off=False when this raises -- confirm.
            batch_context_tokens = [enc.encode(text)[-(max_context_size):] for text in next_prompts]
            token_start_ix = len(batch_context_tokens[0])

            # a row is unfinished while it has neither an end marker nor a control char
            next_prompts_contonly = ["".join(subtexts[1:]) for subtexts in continuations]
            not_finished = [c for c in next_prompts_contonly
                            if ("<|" not in c) and
                            (not any(control_char in c for control_char in control_chars))]
            n_not_finished = len(not_finished)
            more_needed = n_not_finished > 0
            more_permitted = this_batch_continue_steps < max_continue_steps

            done = (not more_needed) or (not more_permitted)
            if not done:
                print("continuing within batch:")
                print(f"\t{n_not_finished}/{len(next_prompts)} unfinished")
                print(f"\t{this_batch_continue_steps}/{max_continue_steps} continue steps used")

                print("Using prompts:")
                # (loop variable used to be called `np`, shadowing the numpy alias)
                for unfinished_text in not_finished:
                    print("\t" + "\n\t".join(wrap(unfinished_text, width=90)) + "\n")

                this_batch_continue_steps += 1
        else:
            done = True

    # cleanup
    continuations_ = []
    for subtexts in continuations:
        text = "".join(subtexts[1:])  # don't return prompt as part of these
        if not text.endswith("<|") and "<|endoftext|>" in text:
            continuations_.append(text.split("<|endoftext|>")[0] + "<|")
        else:
            continuations_.append(text)

    return continuations_

def get_prompted_continuation_with_retries_for_length(prompt: str, retry_if_under: int=60, best_of: int=3, prompt_from_dataset=False, verbose=False):
    """Return the longest acceptable continuation found within a few batches.

    A candidate is acceptable when it ends with "<|" and contains neither
    Q_CHAR nor A_CHAR.  Length is the number of space-separated pieces of the
    parsed post.  Sampling stops as soon as the best candidate has at least
    `retry_if_under` words, or when the batch budget is used up.

    Args:
        prompt: text to continue (replaced by a dataset prompt on every try
            when `prompt_from_dataset` is true).
        retry_if_under: keep sampling while the best candidate is shorter.
        best_of: number of candidates wanted; converted to a number of batches.
        prompt_from_dataset: draw the prompt from the dataset and prepend it
            to each continuation.
        verbose: print progress.

    Returns:
        A one-element list with the best continuation, or [""] if no candidate
        was acceptable.
    """
    n_words = 0
    n_retries = 0
    # Start from an empty string: the result is fed to parse_continuation(), which
    # calls str.partition on it.  The former dict placeholder crashed there whenever
    # no candidate was accepted.
    best = ""
    # Each batch yields batch_size candidates.  Round up and run at least one batch;
    # plain floor division gave 0 tries whenever best_of < batch_size (the defaults).
    max_tries = max(1, -(-best_of // batch_size))

    while (n_words < retry_if_under) and (n_retries < max_tries):
        n_retries += 1
        if prompt_from_dataset:
            prompt = get_prompt_from_dataset(dataset)
            if verbose:
                print(f"Using prompt: {prompt}")
        if verbose:
            print(f"try #{n_retries} of {max_tries}...")
        continuations = get_prompted_continuation(prompt)
        for continuation in continuations:
            if prompt_from_dataset:
                # dataset prompts start mid-token with "endoftext|>"; drop that prefix
                continuation = prompt + continuation
                if continuation.startswith("endoftext|>"):
                    continuation = continuation[len("endoftext|>"):]

            parsed = parse_continuation(continuation, verbose=verbose)
            n_words_latest = len(parsed['post'].split(" "))
            if n_words_latest > n_words and continuation.endswith("<|") and Q_CHAR not in continuation and A_CHAR not in continuation:
                if verbose:
                    print(f"n_words {n_words_latest} beats previous high of {n_words}")
                n_words = n_words_latest
                best = continuation

    return [best]

def get_prompted_continuation_with_length_proportional_sampling(prompt: str, avoid_if_under: int=20, best_of: int=3, verbose=False, prompt_from_dataset=False, return_all=False):
    """Sample candidates and pick one with probability proportional to its length.

    Length is the number of space-separated pieces of the parsed post.  The
    pool is restricted to candidates that have at least `avoid_if_under`
    words, end with "<|" and contain neither Q_CHAR nor A_CHAR; if no
    candidate meets all three, every candidate stays in the pool.

    Args:
        prompt: text to continue (replaced by a dataset prompt on every batch
            when `prompt_from_dataset` is true).
        avoid_if_under: minimum word count for the preferred pool.
        best_of: number of candidates wanted; converted to a number of batches.
        verbose: print progress.
        prompt_from_dataset: draw the prompt from the dataset and prepend it
            to each continuation.
        return_all: return the whole pool instead of the sampled choice.

    Returns:
        The pool (list of str) when `return_all` is true, otherwise a
        one-element list with the sampled continuation.
    """
    n_words_by_try = []
    not_cut_off_by_try = []
    no_control_chars = []

    tries = []

    # Each batch yields batch_size candidates.  Round up and run at least one batch:
    # floor division gave 0 batches for best_of < batch_size (the defaults), leaving
    # an empty pool that np.random.choice refuses to sample from.
    n_batches = max(1, -(-best_of // batch_size))

    for n_retries in range(n_batches):
        if prompt_from_dataset:
            prompt = get_prompt_from_dataset(dataset)
            if verbose:
                print(f"Using prompt: {prompt}")
        if verbose:
            print(f"try #{n_retries} of {n_batches}...")
        continuations = get_prompted_continuation(prompt)
        for continuation in continuations:
            if prompt_from_dataset:
                # dataset prompts start mid-token with "endoftext|>"; drop that prefix
                continuation = prompt + continuation
                if continuation.startswith("endoftext|>"):
                    continuation = continuation[len("endoftext|>"):]
            parsed = parse_continuation(continuation, verbose=verbose)
            n_words_latest = len(parsed['post'].split(" "))
            if verbose:
                print(f"n_words {n_words_latest}")
            n_words_by_try.append(n_words_latest)
            not_cut_off_by_try.append(continuation.endswith("<|"))
            no_control_chars.append(Q_CHAR not in continuation and A_CHAR not in continuation)
            tries.append(continuation)

    # prefer candidates that are long enough, complete, and free of Q/A markers;
    # if none qualifies, fall back to sampling from everything
    keep = [(nw >= avoid_if_under) and (nco) and (ncc)
            for nw, nco, ncc in zip(n_words_by_try, not_cut_off_by_try, no_control_chars)]
    if any(keep):
        n_words_by_try = [nw for nw, k in zip(n_words_by_try, keep) if k]
        tries = [t for t, k in zip(tries, keep) if k]

    # "".split(" ") has one element, so every count is >= 1 and the sum is non-zero
    probs = np.asarray(n_words_by_try) / sum(n_words_by_try)

    if verbose:
        print(f"choosing between word counts {n_words_by_try}\nprobs {probs}")
    choice_ix = np.random.choice(list(range(len(probs))), p=probs)

    best = tries[choice_ix]
    if verbose:
        print(f"chose #{choice_ix} ({n_words_by_try[choice_ix]} words, prob {probs[choice_ix]})")

    if return_all:
        return tries
    return [best]

def parse_continuation(continuation: str, verbose=True):
    """Split raw generated text into a post body and a list of tags.

    Everything before the first T_CHAR is the post; whatever follows it is
    split on "#" into tags, each with trailing spaces removed.  Leading
    ORIG_POST_CHAR characters are stripped from the post.

    Returns:
        dict with keys "post" (str) and "tags" (list of str, empty when there
        is no tag text).
    """
    if verbose:
        print(f"parsing the following raw output:\n------------------\n{fill(continuation)}\n------------------\n")

    # tags, when present, come after the first T_CHAR
    body, _sep, raw_tags = continuation.partition(T_CHAR)
    if raw_tags:
        tag_list = [piece.rstrip(" ") for piece in raw_tags.split("#")]
    else:
        tag_list = []

    # TODO: fix this in get_prompted_continuation_with_length_proportional_sampling
    return {"post": body.lstrip(ORIG_POST_CHAR), "tags": tag_list}

def get_prompt_from_dataset(dataset):
    """Build a short prompt from the opening of a random dataset document.

    Samples text from the dataset until it finds a document (the text after an
    "<|endoftext|>" separator) that starts with ORIG_POST_CHAR, then keeps
    "endoftext|>" plus the first 3 tokens of what follows as the prompt.

    Args:
        dataset: passed to load_dataset if the global sampler has to be built.

    Returns:
        The prompt string, beginning with "endoftext|>".
    """
    global data_sampler
    # look the name up defensively: it is only bound at load time when a dataset
    # was configured, and a bare reference would raise NameError otherwise
    if globals().get("data_sampler") is None:
        print("getting data sampler...")
        chunks = load_dataset(enc, dataset, 50000)
        data_sampler = Sampler(chunks)

    segment = ""
    segments = []
    # V4: only documents that begin with the original-post marker are usable.
    # startswith() also skips empty segments, which segment[0] crashed on.
    while not segment.startswith(ORIG_POST_CHAR):
        while len(segments) == 0:
            # drop the first piece: it precedes the first separator
            segments = enc.decode(data_sampler.sample(1024)).split("<|endoftext|>")[1:]
        segment = segments.pop()

    # keep the "endoftext|>" prefix tokens plus 3 tokens of the document itself
    context_tokens = enc.encode("endoftext|>" + segment)[:len(enc.encode("endoftext|>"))+3]
    prompt = enc.decode(context_tokens)

    return prompt

def basic_n_continuations(prompt, N, 
                          avoid_if_under=20, 
                          avoid_if_cut_off=True, 
                          split_on_control_char=False,
                          prompt_from_dataset=False,
                          avoid_initial_blockquote=False,
                          continue_if_cut_off=False,
                          max_continue_steps=12):
    """Keep sampling batches until at least N continuations pass the filters.

    Args:
        prompt: text to continue (ignored when `prompt_from_dataset` is true).
        N: minimum number of accepted continuations to collect.
        avoid_if_under: reject texts with fewer space-separated words.
        avoid_if_cut_off: reject texts that do not end with "<|".
        split_on_control_char: cut a text at its first control character
            instead of rejecting it.
        prompt_from_dataset: draw the prompt from the dataset (once per call)
            and prepend it to every returned text.
        avoid_initial_blockquote: reject texts starting with "<blockquote".
        continue_if_cut_off, max_continue_steps: forwarded to
            get_prompted_continuation.

    Returns:
        List of accepted texts.  It can hold more than N entries, because
        whole batches are accepted at a time.
    """
    if prompt_from_dataset:
        prompt = get_prompt_from_dataset(dataset)

    control_chars = {Q_CHAR, A_CHAR, ORIG_POST_CHAR, UNAME_CHAR}

    continuations = []
    while len(continuations) < N:
        print(f"\ncontinuing, have {len(continuations)} of {N}\n")

        this_batch_continuations = get_prompted_continuation(prompt,
                                                             continue_if_cut_off=continue_if_cut_off,
                                                             max_continue_steps=max_continue_steps)

        for c in this_batch_continuations:
            if any(control_char in c for control_char in control_chars):
                if split_on_control_char:
                    # keep only the text before the earliest control character
                    min_ix = min(i for i, char in enumerate(c) if char in control_chars)
                    csub = c[:min_ix]
                    print("splitting on control char:")
                    print(f"\t{len(c)} chars, {len(c.split(' '))} words-->\n\t{len(csub)} chars, {len(csub.split(' '))} words")
                    c = csub
                else:
                    print(f"rejecting because control char: \n{fill(c)}\n")
                    continue

            if len(c.split(" ")) < avoid_if_under:
                print(f"rejecting because length under {avoid_if_under}: \n{fill(c)}\n")
            elif (not c.endswith("<|")) and avoid_if_cut_off:
                print(f"rejecting because cut off: \n{fill(c)}\n")
            elif (c.startswith("<blockquote")) and avoid_initial_blockquote:
                print(f"rejecting because initial blockquote: \n{fill(c)}\n")
            else:
                continuations.append(c)

    # Re-attach the dataset prompt (minus its "endoftext|>" prefix and leading
    # post marker).  This used to rebind only the loop variable, so the result was
    # thrown away and the returned texts were missing their opening tokens.
    finished = []
    for continuation in continuations:
        if prompt_from_dataset:
            continuation = prompt + continuation
            if continuation.startswith("endoftext|>"):
                continuation = continuation[len("endoftext|>"):].lstrip(ORIG_POST_CHAR)
        finished.append(continuation)

    return finished

In [None]:
import requests, time
# results waiting to be sent to the bridge, keyed by prompt id (filled and pruned by poll())
RESULT_STACK = {}

def serve_answer(data):
    """Generate the result for one "answer" request.

    `data` must carry "prompt" and "kwargs".  With a truthy kwargs["V5"], the
    return value is a copy of `data` with a "continuations" list added; the
    generation is first attempted with the requested continue_if_cut_off
    setting and, if that raises, repeated with it disabled.  Without "V5", a
    single continuation is sampled proportionally to length and returned in
    parsed form ({"post": ..., "tags": ...}).
    """
    prompt = data["prompt"]
    kwargs = data["kwargs"]

    min_words = kwargs.get("avoid_if_under", 20)
    reject_cut_off = kwargs.get("avoid_if_cut_off", True)
    split_at_control = kwargs.get("split_on_control_char", False)
    reject_blockquote = kwargs.get("avoid_initial_blockquote", True)

    # continuing cut-off texts makes rejecting them pointless
    continue_if_cut_off = kwargs.get("continue_if_cut_off", True)
    if continue_if_cut_off:
        reject_cut_off = False

    if not kwargs.get("V5"):
        legacy_kwargs = {k: v for k, v in kwargs.items() if k != "V5"}
        candidates = get_prompted_continuation_with_length_proportional_sampling(prompt, **legacy_kwargs)
        return parse_continuation(candidates[0])

    def generate(allow_continue):
        # one generation pass; only the continue_if_cut_off flag varies between attempts
        return basic_n_continuations(prompt, N=kwargs['best_of'],
                                     avoid_if_under=min_words,
                                     avoid_if_cut_off=reject_cut_off,
                                     prompt_from_dataset=kwargs.get("prompt_from_dataset"),
                                     split_on_control_char=split_at_control,
                                     avoid_initial_blockquote=reject_blockquote,
                                     continue_if_cut_off=allow_continue)

    try:
        candidates = generate(continue_if_cut_off)
    except Exception as e:
        print(f"got {e}, trying without continue_if_cut_off")
        candidates = generate(False)

    result = data.copy()
    result["continuations"] = candidates
    return result

def serve_textpost(data):
    """Generate the result for one "textpost" request.

    The prompt is always empty; `data` must carry "kwargs".  With a truthy
    kwargs["V5"], the return value is a copy of `data` with a "continuations"
    list added; the generation is first attempted with the requested
    continue_if_cut_off setting and, if that raises, repeated with it
    disabled.  Without "V5", the longest acceptable continuation is selected
    and returned in parsed form ({"post": ..., "tags": ...}).
    """
    prompt = ""
    kwargs = data["kwargs"]

    min_words = kwargs.get("avoid_if_under", 30)
    reject_cut_off = kwargs.get("avoid_if_cut_off", True)
    split_at_control = kwargs.get("split_on_control_char", True)
    reject_blockquote = kwargs.get("avoid_initial_blockquote", False)

    # continuing cut-off texts makes rejecting them pointless
    continue_if_cut_off = kwargs.get("continue_if_cut_off", True)
    if continue_if_cut_off:
        reject_cut_off = False

    if not kwargs.get("V5"):
        legacy_kwargs = {k: v for k, v in kwargs.items() if k != "V5"}
        candidates = get_prompted_continuation_with_retries_for_length(prompt, **legacy_kwargs)
        return parse_continuation(candidates[0])

    def generate(allow_continue):
        # one generation pass; only the continue_if_cut_off flag varies between attempts
        return basic_n_continuations(prompt,
                                     N=kwargs['best_of'],
                                     avoid_if_under=min_words,
                                     avoid_if_cut_off=reject_cut_off,
                                     split_on_control_char=split_at_control,
                                     prompt_from_dataset=kwargs.get("prompt_from_dataset"),
                                     avoid_initial_blockquote=reject_blockquote,
                                     continue_if_cut_off=allow_continue)

    try:
        candidates = generate(continue_if_cut_off)
    except Exception as e:
        print(f"got {e}, trying without continue_if_cut_off")
        candidates = generate(False)

    result = data.copy()
    result["continuations"] = candidates
    return result


def poll():
    """Run one exchange with the bridge service.

    Posts the current results to `generator_url`, reads back the pending
    prompts as JSON (prompt id -> request data), drops stored results whose id
    is no longer pending, generates a result for each pending "answer" or
    "textpost" request, and posts the results again if anything was pending.
    """
    global RESULT_STACK

    response = requests.post(generator_url, json={"results": RESULT_STACK})
    pending = response.json()

    # clean out already used results
    RESULT_STACK = {
        result_id: result
        for result_id, result in RESULT_STACK.items()
        if result_id in pending
    }

    for prompt_id, data in pending.items():
        print("generating...")
        if data["type"] == "answer":
            RESULT_STACK[prompt_id] = serve_answer(data)
        elif data["type"] == "textpost":
            RESULT_STACK[prompt_id] = serve_textpost(data)

    print("done generating for this poll")

    # hand fresh results over right away instead of waiting for the next poll
    if len(pending) > 0:
        requests.post(generator_url, json={"results": RESULT_STACK})
        time.sleep(1)

import time

def loop_poll(period=60):
    """Call poll() forever.

    Any exception from poll() is printed and followed by a pause of ten
    periods.  Whenever no results are held after an iteration, the loop sleeps
    for one `period` (seconds) before polling again.
    """
    global RESULT_STACK
    error_pause = period * 10
    while True:
        try:
            poll()
        except Exception as err:
            print(f"{type(err)}: {err}")
            time.sleep(error_pause)
        if not len(RESULT_STACK):
            time.sleep(period)

Main loop of the generator

In [None]:
# start from an empty result set, then poll the bridge forever
# (period=10: sleep 10 seconds between polls while no results are held)
RESULT_STACK = {}

loop_poll(period=10)