Generator component for nostalgebraist-autoresponder.

Needs a fine-tuned GPT-2 model appropriate for use with the rest of the nostalgebraist-autoresponder codebase, and a running instance of the bridge service to talk to.

In [None]:
# make sure you have a GPU with ~16GB memory, for 1558M
# (`nvidia-smi` prints the GPUs visible to this runtime and their memory)

!nvidia-smi

In [None]:
# Base URL of the bridge service this notebook talks to.
BRIDGE_SERVICE_URL = ""  # fill yours in -- make sure it's accessible from wherever this is running
# Endpoint that `poll` POSTs results to and reads pending prompts from.
generator_url = BRIDGE_SERVICE_URL + "/pollgenerator"

In [None]:
# cell for Colab-specific stuff, if you're on Colab
# (skip this cell entirely when running anywhere else: the magic and the
# `google.colab` import below only exist on Colab)

# if you're using Google Colab, need to tell it not to use tf 2.x
%tensorflow_version 1.x

# if you're using Google Colab + Google Drive, mount and change dir
from google.colab import drive
drive.mount('/content/drive')

# later cells use paths relative to this directory
%cd "/content/drive/My Drive/"

In [None]:
# i assume you're now somewhere with a directory called "gpt-2" holding my gpt-2 fork
# and your fine-tuned model
%cd "gpt-2"

import os, sys
# make the fork's modules under src/ (model, sample, encoder, ...) importable
sys.path.append("src")

%pip install -r "requirements.txt"

In [None]:
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder
from load_dataset import load_dataset, Sampler

# --- model / data locations ---
model_name = ""  # fill in -- should be a directory under /models
dataset = ""  # fill in -- should be a directory under /data

# --- end-of-text handling ---
EOT_WORKAROUND = True
EOT_PREPEND = True  # when set, dataset prompts get "<|endoftext|>" prepended (see get_prompt_from_dataset)

# marker a finished continuation is expected to end with
eot_end_segment = "<|endoftext|>" if EOT_WORKAROUND else "<|"

# --- length settings ---
better_length = True

# sets max context size, for long prompts we want to cut off to allow bot to write at least this many tokens
required_continuation_room = 100 # 385 #500 

# --- sampling settings ---
# batch_size / nsamples / length were previously assigned twice in this cell
# with conflicting values (length 825/625, then 800/700).  Only the later
# assignments ever took effect, so those are the ones kept here.
batch_size = 4
nsamples = batch_size

seed = None

length = 800 if better_length else 700

EXPERIMENTAL_TOP_P = False
EXPERIMENTAL_MIDDLE_P_TWEAK = False

temperature=0.95
top_k=0
top_p=0
middle_p=0.85

In [None]:
# is the selector a model using activations from the generator?
SELECT_VIA_GENERATOR = True

# if so, it will run in this notebook -- need a checkpoint for it, and some hparams

ckpt_select = "" # fill this in, should be a .hdf5 file you saved using `train_generator_to_select.ipynb`

# fill these in -- should match hparams you used in `train_generator_to_select.ipynb`
#
# TODO: DRY (really these should go in a json file with the checkpoint or something)
# 0-based indices of the generator layers whose activations the selector reads
layer_nums = [24-1, 36-1]
# stored on `hparams_select` and read by `attn_only_block`
do_resid = False
# the remaining three are passed straight through to `selector(...)`
norm_layers_after = False
use_mlp = True
resid_mlp = True

In [None]:
# Build the tokenizer and load the fine-tuned model's hyperparameters.
enc = encoder.get_encoder(model_name, eot_workaround=EOT_WORKAROUND)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

# this notebook only samples, so dropout is switched off
hparams.set_hparam("attn_dropout", 0)
hparams.set_hparam("res_dropout", 0)

if dataset is not None:
    # a dataset is available: keep a sampler around so prompts can be drawn from it
    chunks = load_dataset(enc, dataset, 50000)
    data_sampler = Sampler(chunks)
    start_token = None
else:
    # Bind `data_sampler` in this branch too: `get_prompt_from_dataset` tests
    # `data_sampler is None`, which would otherwise raise NameError.
    data_sampler = None
    context_tokens = None
    start_token = enc.encoder['<|endoftext|>']

# default to the full context window, and refuse lengths that exceed it
if length is None:
    length = hparams.n_ctx
elif length > hparams.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

In [None]:
import tflex

# `CPU` is an optional flag that no other cell of this notebook defines.
# Referencing it directly raised NameError inside the `try` below, which the
# broad `except` swallowed, so the loop retried forever.  Treat "not defined"
# as "use the default tflex session".
use_cpu = bool(globals().get("CPU", False))

load_done = False

# Build the sampling graph and restore the checkpoint, retrying on any error.
# NOTE(review): the retry is unbounded; a persistent problem (e.g. a wrong
# `model_name`) will loop until the cell is interrupted.
while not load_done:
  try:
    tf.reset_default_graph()
    if use_cpu:
      sess = tf.Session()
    else:
      sess = tflex.Session()

    with sess.as_default():
        np.random.seed(seed)
        tf.set_random_seed(seed)

        # with no start token, the prompt is fed in through this placeholder
        if start_token is None:
            context = tf.placeholder(tf.int32, [batch_size, None])
        else:
            context = None

        # when sampling from a start token, drop that token from the output
        start_ix = 1 if start_token is not None else 0
        output = sample.sample_sequence(stop_at_EOT=True, better_length=better_length,
                                        eot_workaround=EOT_WORKAROUND,
                                        enc=enc, 
            hparams=hparams, length=length,
            start_token=start_token,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p, 
            middle_p=middle_p
        )[:, start_ix:]

        saver = tflex.Saver()
        ckpt = tflex.latest_checkpoint(os.path.join('models', model_name))

        print(f"restoring checkpoint: {ckpt}")
        saver.restore(sess, ckpt)
    load_done = True
  except Exception as e:
    print(f"encountered {e}, retrying...")

Business logic

TODO: DRY (really this should be in a .py file somewhere)

TODO: clean up this stuff. The change to "V5" happened long ago, so the blocks that check whether we're in "V5" can be changed to simply assume it.

In [None]:
# code related to generation, prompting, etc

import re
from textwrap import fill, wrap

# Single-character control markers used in the model's text format.
# (These were previously declared twice in a row with identical values;
# one declaration is enough.)
# NOTE(review): the meanings of Q/A/ORIG_POST/UNAME are inferred from the
# names only -- confirm against the data-preparation code.
Q_CHAR = "会"
A_CHAR = "域"
T_CHAR = "职"  # separates the post body from its tags (see parse_continuation)
ORIG_POST_CHAR = "翰"
UNAME_CHAR = "友"

def get_prompted_continuation(prompt: str, 
                              continue_if_cut_off=False, 
                              max_continue_steps=12,
                              verbose=False):
    """Sample one batch of continuations of `prompt` from the generator.

    Literal backslash-n sequences in `prompt` are turned into real newlines,
    and the prompt is truncated from the left if it exceeds the maximum
    context size.

    Args:
        prompt: text to continue.
        continue_if_cut_off: if true, keep sampling (feeding prompt plus the
            text so far back in) while any batch entry is unfinished, up to
            `max_continue_steps` extra rounds.  An entry counts as finished
            once it contains `eot_end_segment`, any of the Q/A/ORIG_POST/UNAME
            control characters, two or more `T_CHAR`s, or its latest sample
            was rejected as repetitive.
        max_continue_steps: cap on the number of extra sampling rounds.
        verbose: if true, print the unfinished continuations on each round.

    Returns:
        A list of `batch_size` strings holding only the generated text (never
        the prompt), each cut just after its first `eot_end_segment` if it
        contains one.
    """
    # prompts may carry literal "\n" escapes; the model should see real newlines
    raw_text = re.sub(r"\\n", "\n", prompt)
    context_tokens = enc.encode(raw_text)

    if better_length:
        max_context_size = length - required_continuation_room
    else:
        max_context_size = hparams.n_ctx - length - 10
    if len(context_tokens) > max_context_size:
        orig_len = len(context_tokens)
        # keep the most recent tokens
        context_tokens = context_tokens[-(max_context_size):]
        print(f"truncated {orig_len} to {len(context_tokens)}, max_context_size={max_context_size}")
    else:
        print(f"{len(context_tokens)} tokens can fit in max_context_size {max_context_size}")

    # sampled rows are sliced from here so only newly generated tokens are decoded
    token_start_ix = len(context_tokens)

    batch_context_tokens = [context_tokens for _ in range(batch_size)]
    # Element 0 of each entry is the prompt as the model saw it.  Using
    # `raw_text` (not the un-normalised `prompt`) keeps later rounds, which
    # re-encode the joined text, consistent with the first round.
    continuations = [[raw_text] for _ in batch_context_tokens]
    is_repeating = [False for _ in batch_context_tokens]
    this_batch_continue_steps = 0
    control_chars = {Q_CHAR, A_CHAR, ORIG_POST_CHAR, UNAME_CHAR}

    done = False
    while not done:
        with sess.as_default():
            out = sess.run(output, feed_dict={
                context: batch_context_tokens
            })[:, token_start_ix:]
        for i in range(batch_size):
            text = enc.decode(out[i])

            # crude repetition filter: discard samples where fewer than 20%
            # of the tokens are distinct
            if len(set(out[i])) >= 0.2*len(out[i]):
                continuations[i].append(text)
                is_repeating[i] = False
            else:
                continuations[i].append("")
                is_repeating[i] = True

        if continue_if_cut_off:
            next_prompts = ["".join(subtexts) for subtexts in continuations]
            batch_context_tokens = [enc.encode(text)[-(max_context_size):] for text in next_prompts]
            # NOTE(review): this uses the first row's length for every row and
            # feeds rows of possibly different lengths -- presumably callers
            # rely on the fallback in serve_answer/serve_textpost when that
            # fails; confirm.
            token_start_ix = len(batch_context_tokens[0])

            next_prompts_contonly = ["".join(subtexts[1:]) for subtexts in continuations]
            not_finished = [c for c, rep in zip(next_prompts_contonly, is_repeating)
                            if (eot_end_segment not in c) and
                            (not any(control_char in c for control_char in control_chars)) and
                            (c.count(T_CHAR) < 2) and
                            not rep
                            ]
            n_not_finished = len(not_finished)
            more_needed = n_not_finished > 0
            more_permitted = this_batch_continue_steps < max_continue_steps

            done = (not more_needed) or (not more_permitted)
            if not done:
                print("continuing within batch:")
                print(f"\t{n_not_finished}/{len(next_prompts)} unfinished")
                print(f"\t{this_batch_continue_steps}/{max_continue_steps} continue steps used")

                if verbose:
                    print("Using prompts:")
                    # loop variable deliberately not called `np`, which would
                    # shadow numpy inside this function
                    for unfinished_text in not_finished:
                        print("\t" + "\n\t".join(wrap(unfinished_text, width=90)) + "\n")

                this_batch_continue_steps += 1
        else:
            done = True

    # cleanup
    continuations_ = []
    for subtexts in continuations:
        text = "".join(subtexts[1:])  # don't return prompt as part of these
        if eot_end_segment in text:
            # Cut just after the FIRST end marker.  Previously text that
            # happened to also END with the marker was returned untouched,
            # including everything after the first marker.
            text = text.split(eot_end_segment)[0] + eot_end_segment
        continuations_.append(text)

    return continuations_

def get_prompted_continuation_with_retries_for_length(prompt: str, retry_if_under: int=60, best_of: int=3, prompt_from_dataset=False, verbose=False):
    """Generate batches until one continuation is long enough, returning the longest.

    A continuation only qualifies if it ends with `eot_end_segment` and
    contains neither `Q_CHAR` nor `A_CHAR`.  Length is the space-separated
    word count of the parsed post.

    Args:
        prompt: text to continue (replaced by a dataset prompt on every try
            when `prompt_from_dataset` is true).
        retry_if_under: stop retrying once the best word count reaches this.
        best_of: roughly how many samples to consider; rounded up to whole
            batches of `batch_size`.
        prompt_from_dataset: draw the prompt from the dataset instead.
        verbose: print progress.

    Returns:
        A one-element list holding the best qualifying continuation, or
        `[""]` if none qualified.
    """
    # Round up to whole batches and always allow at least one.  Floor division
    # gave 0 for the defaults (3 // 4), so nothing was ever generated.
    max_tries = max(1, -(-best_of // batch_size))

    n_words = 0
    n_retries = 0
    # must be a str: callers hand the result to parse_continuation, which
    # calls str methods on it (this used to start out as a dict)
    best = ""

    while (n_words < retry_if_under) and (n_retries < max_tries):
        n_retries += 1
        if prompt_from_dataset:
            prompt = get_prompt_from_dataset(dataset)
            if verbose:
                print(f"Using prompt: {prompt}")
        if verbose:
            print(f"try #{n_retries} of {max_tries}...")
        continuations = get_prompted_continuation(prompt)
        for continuation in continuations:
            if prompt_from_dataset:
                continuation = prompt + continuation
                if continuation.startswith("endoftext|>"):
                    continuation = continuation[len("endoftext|>"):]

            parsed = parse_continuation(continuation, verbose=verbose)
            n_words_latest = len(parsed['post'].split(" "))
            if n_words_latest > n_words and continuation.endswith(eot_end_segment) and Q_CHAR not in continuation and A_CHAR not in continuation:
                if verbose:
                    print(f"n_words {n_words_latest} beats previous high of {n_words}")
                n_words = n_words_latest
                best = continuation

    return [best]

def get_prompted_continuation_with_length_proportional_sampling(prompt: str, avoid_if_under: int=20, best_of: int=3, verbose=False, prompt_from_dataset=False, return_all=False):
    """Generate candidates and pick one at random, weighted by word count.

    Candidates that are at least `avoid_if_under` words long, end with
    `eot_end_segment`, and contain neither `Q_CHAR` nor `A_CHAR` are preferred;
    if no candidate meets all three conditions, every candidate stays in the
    pool.

    Args:
        prompt: text to continue (replaced by a dataset prompt on every try
            when `prompt_from_dataset` is true).
        avoid_if_under: minimum word count for a preferred candidate.
        best_of: roughly how many samples to draw; rounded up to whole
            batches of `batch_size`.
        verbose: print progress and the sampling probabilities.
        prompt_from_dataset: draw the prompt from the dataset instead.
        return_all: return the whole candidate pool instead of one choice.

    Returns:
        The candidate pool if `return_all`, else a one-element list with the
        sampled candidate.
    """
    # Round up to whole batches and always run at least one.  Floor division
    # gave 0 for the defaults (3 // 4), leaving no candidates and making
    # np.random.choice below fail on an empty list.
    n_batches = max(1, -(-best_of // batch_size))

    n_words_by_try = []
    not_cut_off_by_try = []
    no_control_chars = []
    tries = []

    for n_retries in range(n_batches):
        if prompt_from_dataset:
            prompt = get_prompt_from_dataset(dataset)
            if verbose:
                print(f"Using prompt: {prompt}")
        if verbose:
            print(f"try #{n_retries} of {n_batches}...")
        continuations = get_prompted_continuation(prompt)
        for continuation in continuations:
            if prompt_from_dataset:
                continuation = prompt + continuation
                if continuation.startswith("endoftext|>"):
                    continuation = continuation[len("endoftext|>"):]
            parsed = parse_continuation(continuation, verbose=verbose)
            n_words_latest = len(parsed['post'].split(" "))
            if verbose:
                print(f"n_words {n_words_latest}")
            n_words_by_try.append(n_words_latest)
            not_cut_off_by_try.append(continuation.endswith(eot_end_segment))
            no_control_chars.append(Q_CHAR not in continuation and A_CHAR not in continuation)
            tries.append(continuation)

    # restrict the pool to preferred candidates, but only if that leaves any
    keep = [(nw >= avoid_if_under) and (nco) and (ncc)
            for nw, nco, ncc in zip(n_words_by_try, not_cut_off_by_try, no_control_chars)]
    if any(keep):
        n_words_by_try = [nw for nw, k in zip(n_words_by_try, keep) if k]
        tries = [t for t, k in zip(tries, keep) if k]

    # str.split(" ") never returns an empty list, so every count is >= 1 and
    # the sum below is non-zero
    probs = np.asarray(n_words_by_try) / sum(n_words_by_try)

    if verbose:
        print(f"choosing between word counts {n_words_by_try}\nprobs {probs}")
    choice_ix = np.random.choice(list(range(len(probs))), p=probs)

    best = tries[choice_ix]
    if verbose:
        print(f"chose #{choice_ix} ({n_words_by_try[choice_ix]} words, prob {probs[choice_ix]})")

    if return_all:
        return tries
    return [best]

def parse_continuation(continuation: str, verbose=True):
    """Split raw generator output into a post body and a list of tags.

    Everything before the first `T_CHAR` is the post; everything after it is
    split on "#" into tags, each with trailing spaces removed.  Leading
    `ORIG_POST_CHAR`s are stripped from the post.

    Returns:
        dict with keys "post" (str) and "tags" (list of str; empty when there
        is no tag text).
    """
    if verbose:
        print(f"parsing the following raw output:\n------------------\n{fill(continuation)}\n------------------\n")

    # split out tags, if present
    body, _sep, raw_tags = continuation.partition(T_CHAR)
    tag_list = [piece.rstrip(" ") for piece in raw_tags.split("#")] if raw_tags else []

    # TODO: fix this in get_prompted_continuation_with_length_proportional_sampling
    return {"post": body.lstrip(ORIG_POST_CHAR), "tags": tag_list}

def get_prompt_from_dataset(dataset):
    """Build a short prompt from the opening tokens of a random dataset segment.

    Samples 1024-token windows from the dataset, splits them on
    "<|endoftext|>", and takes a segment that starts with `ORIG_POST_CHAR`.
    The prompt is that segment's first few tokens, decoded back to text.

    Args:
        dataset: dataset name, used to build the sampler if none exists yet.

    Returns:
        The prompt string.
    """
    global data_sampler
    if data_sampler is None:
        print("getting data sampler...")
        chunks = load_dataset(enc, dataset, 50000)
        data_sampler = Sampler(chunks)

    segment = ""
    segments = []
    # V4: keep drawing until a segment starts with ORIG_POST_CHAR.
    # `startswith` rather than `segment[0]`: splitting on "<|endoftext|>" can
    # yield empty segments, and indexing those raised IndexError.
    while not segment.startswith(ORIG_POST_CHAR):
        while len(segments) == 0:
            # drop the first piece of each window
            segments = enc.decode(data_sampler.sample(1024)).split("<|endoftext|>")[1:]
        segment = segments.pop()
    if EOT_WORKAROUND:
        if EOT_PREPEND:
            segment = "<|endoftext|>" + segment
            context_tokens = enc.encode(segment)[:5]
        else:
            context_tokens = enc.encode(segment)[:4]
        print(f'using context_tokens {context_tokens} = {enc.decode(context_tokens)}')
    else:
        context_tokens = enc.encode("endoftext|>" + segment)[:len(enc.encode("endoftext|>"))+3]
    prompt = enc.decode(context_tokens)

    return prompt

def basic_n_continuations(prompt, N, 
                          avoid_if_under=20, 
                          avoid_if_cut_off=True, 
                          split_on_control_char=False,
                          prompt_from_dataset=False,
                          avoid_initial_blockquote=False,
                          continue_if_cut_off=False,
                          max_continue_steps=12,
                          verbose=False):
  """Keep generating batches until at least `N` continuations pass the filters.

  A candidate containing a Q/A/ORIG_POST/UNAME control character is either
  cut just before the first such character (`split_on_control_char`) or
  rejected.  It is then rejected if it has fewer than `avoid_if_under`
  space-separated words, is cut off (when `avoid_if_cut_off`), starts with
  "<blockquote" (when `avoid_initial_blockquote`), or has two or more
  `T_CHAR`s.

  With `prompt_from_dataset`, one prompt is drawn from the dataset up front
  and prepended to every returned continuation (minus a leading
  "<|endoftext|>" and `ORIG_POST_CHAR`s when `EOT_PREPEND` is set).

  Returns:
      A list of accepted continuations; it may hold more than `N` entries
      because whole batches are processed.
  """
  if prompt_from_dataset:
      prompt = get_prompt_from_dataset(dataset)

  control_chars = {Q_CHAR, A_CHAR, ORIG_POST_CHAR, UNAME_CHAR}
  accepted = []

  # NOTE: there is no cap on the number of batches; this loops until N pass
  while len(accepted) < N:
    print(f"\ncontinuing, have {len(accepted)} of {N}\n")      

    batch = get_prompted_continuation(prompt, 
                                      continue_if_cut_off=continue_if_cut_off,
                                      max_continue_steps=max_continue_steps,
                                      verbose=verbose,
                                      )

    for candidate in batch:
      control_positions = [ix for ix, ch in enumerate(candidate) if ch in control_chars]
      if control_positions:
        if not split_on_control_char:
          print(f"rejecting because control char: \n{fill(candidate)}\n")
          continue
        # keep only the text before the earliest control character
        shortened = candidate[:control_positions[0]]
        print(f"splitting on control char:")
        print(f"\t{len(candidate)} chars, {len(candidate.split(' '))} words-->\n\t{len(shortened)} chars, {len(shortened.split(' '))} words")
        candidate = shortened

      # first matching reason wins, in this order
      if len(candidate.split(" ")) < avoid_if_under:
        reason = f"length under {avoid_if_under}"
      elif avoid_if_cut_off and not candidate.endswith(eot_end_segment):
        reason = "cut off"
      elif avoid_initial_blockquote and candidate.startswith("<blockquote"):
        reason = "initial blockquote"
      elif candidate.count(T_CHAR) >= 2:
        reason = "multiple T_CHAR"
      else:
        reason = None

      if reason is None:
        accepted.append(candidate)
      else:
        print(f"rejecting because {reason}: \n{fill(candidate)}\n")

  if not prompt_from_dataset:
    return accepted

  # put the dataset prompt back on the front of each continuation
  with_prompt = []
  for candidate in accepted:
    candidate = prompt + candidate
    if EOT_PREPEND and candidate.startswith("<|endoftext|>"):
      candidate = candidate[len("<|endoftext|>"):].lstrip(ORIG_POST_CHAR)
    with_prompt.append(candidate)
  return with_prompt

If the selector model uses generator activations (and thus runs in this notebook), we need to declare the code defining it.  

TODO: DRY (this, too, *really* should be in a .py source file)

In [None]:
# static declarations for the selector model -- this is copy/paste ie repeating myself :(

from textwrap import fill, wrap
import scipy.special

def show_note_probas(texts, probas, continuation_sentiments=None):
    """Print each text with its predicted probability (and sentiment, if given).

    Args:
        texts: the texts to show.
        probas: one probability per text, printed as a percentage.
        continuation_sentiments: optional per-text sentiment objects; when
            given, `pos_sent(sent)` is printed next to the probability.
            NOTE(review): `pos_sent` is not defined anywhere in this notebook,
            so that branch presumably needs it supplied from elsewhere.
    """
    divider = "\n~_~_~_~_~_\n"

    def _print_wrapped(text):
        # wrapped body between two divider lines
        print(divider)
        print("\n".join(wrap(text, replace_whitespace=False)))
        print(divider)

    if continuation_sentiments is not None:
        for tpe, proba, sent in zip(texts, probas, continuation_sentiments):
            print(f"\tpredicted prob: {proba:.1%}, pos_sent {pos_sent(sent):.1%}\n")
            _print_wrapped(tpe)
        return

    for tpe, proba in zip(texts, probas):
        print(f"\tpredicted prob: {proba:.1%}\n")
        _print_wrapped(tpe)

if SELECT_VIA_GENERATOR:
  # pulls the gpt-2 fork's model helpers into this namespace; the functions
  # below call names such as shape_list, get_variable, norm and block
  from model import *

  # The selector reads its features at the position of this token:
  # the last token of the encoded end-of-text marker.
  SELECTION_CHAR = "<|endoftext|>"
  SELECTION_TOK = enc.encode(SELECTION_CHAR)[-1]

  def extract_selection_ix(tokens, extract_from):
    """Pick, for each row, the entry of `extract_from` at the last SELECTION_TOK.

    Returns a dict with "extracted" (the gathered entries) and "selection_ix"
    (the [row, position-within-masked-row] indices used for the gather).
    Uses the global `batch_size` for the row indices.
    NOTE(review): a row with no SELECTION_TOK yields index -1 here -- confirm
    callers always append the token (single_batch_predict_select does).
    """
    # True wherever the token equals SELECTION_TOK
    mask = tf.equal(tf.dtypes.cast(tokens, tf.int32), SELECTION_TOK)
    # per row, keep only the entries at masked positions
    extracted_ragged = tf.ragged.boolean_mask(extract_from, mask)
    
    # index of the last kept entry in each row
    row_lengths = extracted_ragged.row_lengths()
    row_ixs = row_lengths-1
    selection_ix = tf.stack([tf.range(0, batch_size, dtype=tf.int64), row_ixs], axis=1,)

    extracted = tf.gather_nd(extracted_ragged.to_tensor(), selection_ix,  )

    return {"extracted": extracted, "selection_ix": selection_ix}

  def model_activations(hparams, X, hparams_select, 
                      layer_nums: list, 
                      norm_layers_after: bool=False,
                      past=None, past_select=None,
                      scope='model', reuse=tf.AUTO_REUSE):
    """Run the generator's transformer on `X` and collect selected layer outputs.

    Args:
        hparams: generator hyperparameters.
        X: token ids; unpacked below as (batch, sequence).
        hparams_select: only used for the optional post-hoc layer norms.
        layer_nums: 0-based indices of the layers whose outputs to collect.
        norm_layers_after: if true, apply a `ln_after_<name>` norm to each
            collected activation.
        past: optional cached state, unstacked per layer along axis 1.
        past_select: not used in this function.
        scope, reuse: variable scope settings.

    Returns:
        dict with "present", "logits", and "activations" -- a list of
        (name, tensor) pairs, where name is "h<layer>".
    """
    activations = []
    h_names = []
    
    dtype = hparams.dtype if hparams else tf.float32
    with tf.variable_scope(scope, reuse=reuse, dtype=dtype):
        results = {}
        batch, sequence = shape_list(X)

        # position and token embedding tables: reuse existing variables if any
        wpe = get_variable('wpe') or tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                             initializer=tf.random_normal_initializer(stddev=0.01, dtype=dtype))
        wte = get_variable('wte') or tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                             initializer=tf.random_normal_initializer(stddev=0.02, dtype=dtype))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
            # remember this layer's output if the selector wants it
            if layer in layer_nums:
              h_name = f'h{layer}'
              print(f'{h_name} found')
              h_names.append(h_name)
              activations.append(h)

        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f', hparams=hparams)

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits

        # activations
        if norm_layers_after:
          activations = [norm(act, f'ln_after_{act_name}', hparams=hparams_select)
                         for act_name, act in zip(h_names, activations)]

        results['activations'] = list(zip(h_names, activations))

        return results

  def get_initializer(hparams, scope):
    """Return the initializer class to use for weights created under `scope`.

    Args:
        hparams: hyperparameters object, or None.  `conv1d` accepts
            `hparams=None`, so None is tolerated here as well instead of
            failing on `hparams.get`.
        scope: scope name, used only in the log message.

    Returns:
        `tf.compat.v1.orthogonal_initializer` when `hparams` has a truthy
        "orth_init" entry, otherwise `tf.random_normal_initializer`.
    """
    if hparams is not None and hparams.get("orth_init"):
      print(f"orth init in scope {scope}")
      return tf.compat.v1.orthogonal_initializer
    return tf.random_normal_initializer

  def conv1d(x, scope, nf, *, w_init_stdev=0.02, hparams=None):
      """Position-wise linear layer: maps the last axis of `x` to `nf` features.

      Variables "w" and "b" live under `scope`; existing ones are reused when
      `get_variable` finds them.  The weight initializer comes from
      `get_initializer`.
      """
      dtype = hparams.dtype if hparams else tf.float32

      initializer = get_initializer(hparams, scope)
      with tf.variable_scope(scope, dtype=dtype):
          *start, nx = shape_list(x)
          # NOTE(review): w_init_stdev is passed positionally; for
          # tf.random_normal_initializer the first positional parameter
          # appears to be `mean`, not `stddev` -- confirm.  Only matters when
          # the variable is not restored from a checkpoint.
          w = get_variable('w') or tf.get_variable('w', [1, nx, nf], initializer=initializer(w_init_stdev, dtype=dtype))
          b = get_variable('b') or tf.get_variable('b', [nf], initializer=tf.constant_initializer(0, dtype=dtype))
          # flatten the leading axes, apply x @ w + b, then restore them
          c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
          return c

  # this is a copy/paste -- we need to redefine "attn" so the "conv1d" defn'd above
  # is used

  def attn(x, scope, n_state, *, past, hparams):
      """Masked multi-head self-attention over `x`.

      Redefined in this notebook so that the `conv1d` declared here is the
      one used for the projections.

      Args:
          x: rank-3 input, [batch, sequence, features].
          scope: variable scope name.
          n_state: feature width; must be divisible by `hparams.n_head`.
          past: optional rank-5 cache of earlier keys/values, or None.
          hparams: hyperparameters (n_head, dropout rates, dtype).

      Returns:
          (a, present): the attention output and the stacked [k, v] of this
          call.
      """
      assert x.shape.ndims == 3  # Should be [batch, sequence, features]
      assert n_state % hparams.n_head == 0
      if past is not None:
          assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

      def split_heads(x):
          # From [batch, sequence, features] to [batch, heads, sequence, features]
          return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

      def merge_heads(x):
          # Reverse of split_heads
          return merge_states(tf.transpose(x, [0, 2, 1, 3]))

      def mask_attn_weights(w):
          # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
          _, _, nd, ns = shape_list(w)
          b = attention_mask(nd, ns, dtype=w.dtype)
          b = tf.reshape(b, [1, 1, nd, ns])
          # masked-out positions get a large negative value; a smaller
          # constant is used for dtypes other than float32
          w = w*b - tf.cast(65500 if w.dtype != tf.float32 else 1e10, w.dtype)*(1-b)
          return w

      def multihead_attn(q, k, v):
          # q, k, v have shape [batch, heads, sequence, features]
          w = tf.matmul(q, k, transpose_b=True)
          # scale by 1/sqrt(per-head feature size)
          w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

          w = mask_attn_weights(w)
          w = softmax(w)
          w = dropout(w, hparams.attn_dropout)
          a = tf.matmul(w, v)
          return a

      dtype = hparams.dtype if hparams else tf.float32
      with tf.variable_scope(scope, dtype=dtype):
          # one projection produces q, k and v together
          c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
          q, k, v = map(split_heads, tf.split(c, 3, axis=2))
          present = tf.stack([k, v], axis=1)
          if past is not None:
              # prepend cached keys/values along the sequence axis
              pk, pv = tf.unstack(past, axis=1)
              k = tf.concat([pk, k], axis=-2)
              v = tf.concat([pv, v], axis=-2)
          a = multihead_attn(q, k, v)
          a = merge_heads(a)
          a = conv1d(a, 'c_proj', n_state, hparams=hparams)
          a = dropout(a, hparams.res_dropout)
          return a, present

  def attn_only_block(x, scope, *, past, hparams, do_input_norm=True):
      """Transformer block with only the attention sub-layer (no MLP).

      Args:
          x: input activations.
          scope: variable scope name.
          past: passed through to `attn`.
          hparams: hyperparameters; `hparams.do_resid` decides whether the
              attention output is added to `x` or replaces it.
          do_input_norm: apply the `ln_1` norm to `x` before attention.

      Returns:
          (x, present) where `present` comes from `attn`.
      """
      dtype = hparams.dtype if hparams else tf.float32
      do_resid = hparams.do_resid if hparams else True
      print(f"do_resid: {do_resid}")
      print(f"do_input_norm: {do_input_norm}")
      with tf.variable_scope(scope, dtype=dtype):
          nx = x.shape[-1].value

          if do_input_norm:
            x_attn_in = norm(x, 'ln_1', hparams=hparams)
          else:
            x_attn_in = x
          a, present = attn(x_attn_in, 'attn', nx, past=past, hparams=hparams)
          if do_resid:
            x = x + a
          else:
            x = a

          return x, present

  def mlp_no_proj(x, scope, n_state, *, hparams, is_expansion=False):
    """Single `c_fc` layer with GELU and dropout, without an output projection.

    `is_expansion` and the local `nx` are not used in this function.
    """
    dtype = hparams.dtype if hparams else tf.float32
    with tf.variable_scope(scope, dtype=dtype):
        nx = x.shape[-1].value
        h = gelu(conv1d(x, 'c_fc', n_state,
                        w_init_stdev=0.02,
                        hparams=hparams))
        h = dropout(h, hparams.res_dropout)
        return h

  def selector(hparams, X, hparams_select, 
             layer_nums: list, 
             scope="model", 
             reuse=tf.AUTO_REUSE,
             norm_layers_after: bool=False,
             use_mlp: bool=True,
             resid_mlp: bool=True,
             ):
    """Build the selector head on top of the generator's activations.

    For each layer in `layer_nums`, the generator activation goes through an
    attention-only block and a layer norm.  The per-layer results are
    concatenated on the last axis, read at each row's last SELECTION_TOK,
    optionally passed through an MLP, and projected to two logits.

    Args:
        hparams: generator hyperparameters.
        X: token ids placeholder/tensor.
        hparams_select: hyperparameters for the selector layers.
        layer_nums: 0-based generator layer indices to read.
        scope, reuse: variable scope settings for the MLP and output weights.
        norm_layers_after: norm the activations when collecting them, and
            skip the input norm of the attention block.
        use_mlp: apply `mlp_no_proj` before the output projection.
        resid_mlp: add the MLP output to its input instead of replacing it.

    Returns:
        dict with key "logits_select".

    Raises:
        ValueError: if none of `layer_nums` matched a generator layer.
    """
    results = {}

    activations = model_activations(
        hparams=hparams, hparams_select=hparams_select,
        X=X, layer_nums=layer_nums,
        norm_layers_after=norm_layers_after,
        scope=scope, reuse=reuse,
    )['activations']

    hs_select = []
    for act_name, act in activations:
      h_select, _ = attn_only_block(act, f'h_select_{act_name}', 
                                    hparams=hparams_select,
                                    past=None,
                                    do_input_norm=(not norm_layers_after))
      h_select = norm(h_select, f'ln_2_select_{act_name}', hparams=hparams_select,)
      hs_select.append(h_select)

    if not hs_select:
      raise ValueError(f"no activations found for layer_nums={layer_nums}")

    # Built once, after all layers are collected.  These two statements used
    # to sit inside the loop, which created throwaway concat/gather ops on
    # every iteration and left the names unbound when there were no layers.
    h_select_in = tf.concat(hs_select, axis=-1)
    h_select_in_at_selection_ix = extract_selection_ix(X, h_select_in)['extracted']

    with tf.variable_scope(scope, reuse=reuse, dtype=hparams_select.dtype):
      if use_mlp:
        m = mlp_no_proj(h_select_in_at_selection_ix, "select_mlp__", len(layer_nums)*hparams.n_embd, hparams=hparams_select)
        if resid_mlp:
          h_select_in_at_selection_ix = m + h_select_in_at_selection_ix
        else:
          h_select_in_at_selection_ix = m

      # output projection: reuse existing variables if present
      w_select = get_variable('w_select_')
      if w_select is None:
        initializer = get_initializer(hparams_select, scope)
        w_select = tf.get_variable('w_select_', [len(layer_nums)*hparams.n_embd, 2],
                                  initializer=initializer(0.02, dtype=hparams.dtype))

      b_select = get_variable('b_select')
      if b_select is None:
        b_select = tf.get_variable('b_select', [2],
                                  initializer=tf.constant_initializer(0, dtype=hparams.dtype))

      select_logits = tf.matmul(h_select_in_at_selection_ix, w_select) + b_select

    results['logits_select'] = select_logits

    return results

  def single_batch_predict_select(text_batch, threshold=0.5, debug=False):
    """Run the selector on exactly one batch of texts.

    Each text has a trailing end marker removed, is encoded and truncated to
    its last `length - 1` tokens, and gets SELECTION_TOK appended.  Rows are
    right-padded with token 0 to a common length.

    Args:
        text_batch: list of exactly `batch_size` strings.
        threshold: probability above which "preds" is True.
        debug: print the decoded model input for each text.

    Returns:
        dict with "logits", "probs" (softmax probability of class 1) and
        "preds" (`probs > threshold`).

    Raises:
        ValueError: if `text_batch` does not hold `batch_size` texts.
    """
    if len(text_batch) != batch_size:
      raise ValueError("badlength")

    # Suffixes to strip, in a fixed order: the current end marker first, then
    # the old "<|" form (explicitly supported, for now).  A list rather than
    # a set, because set iteration order is not stable and the result can
    # depend on which suffix is tried first.
    end_segments = [eot_end_segment]
    if '<|' not in end_segments:
      end_segments.append('<|')

    batch_context = []
    for text in text_batch:
      for end_segment in end_segments:
        if text.endswith(end_segment):
          text = text[:-len(end_segment)]
      batch_context.append(enc.encode(text)[-(length-1):] + [SELECTION_TOK])
      if debug:
        print(f"predicting on:\n{enc.decode(batch_context[-1])}\n")

    # right-pad every row to the longest one
    max_tokens = max(len(toks) for toks in batch_context)
    batch_context = [toks + [0] * (max_tokens - len(toks)) for toks in batch_context]

    with sess.as_default():
      logits = sess.run(select_logits, feed_dict={context_for_h.name: batch_context})

    probs = scipy.special.softmax(logits, axis=1)[:, 1]
    results = {"logits": logits, "probs": probs, "preds": probs>threshold}
    return results

  def predict_select(texts, threshold=0.5, debug=False):
    """Run the selector over any number of texts, `batch_size` at a time.

    A short final batch is filled up by repeating its last text; the padded
    results are dropped again before returning.

    Returns:
        dict with the same keys as `single_batch_predict_select`, each value
        concatenated over batches and cut to `len(texts)` entries.
    """
    per_batch = []
    for start in range(0, len(texts), batch_size):
      chunk = texts[start:start+batch_size]
      shortfall = batch_size - len(chunk)
      if shortfall:
        chunk = chunk + [chunk[-1]] * shortfall
      per_batch.append(single_batch_predict_select(chunk, threshold=threshold, debug=debug))

    n_texts = len(texts)
    return {key: np.concatenate([res[key] for res in per_batch])[:n_texts]
            for key in per_batch[0].keys()}

In [None]:
# selector: add tf ops to current session, load checkpoint

if SELECT_VIA_GENERATOR:
  # hyperparameters for the selector layers: same sizes as the generator,
  # dropout off, plus the selector-specific switches
  hparams_select = HParams(
        n_vocab=hparams.n_vocab,
        n_ctx=hparams.n_ctx,
        n_embd=hparams.n_embd,
        n_head=hparams.n_head,
        n_layer=hparams.n_layer,
        res_dropout=0,
        attn_dropout=0,
        dtype=tf.float32,
        do_resid=do_resid,
        orth_init=True,
    )

  with sess.as_default():
    # token ids for the selector are fed in through this placeholder
    context_for_h = tf.placeholder(tf.int32, [batch_size, None])
    
    selection_step = selector(
    hparams=hparams, hparams_select=hparams_select,
      X=context_for_h, layer_nums=layer_nums,
      norm_layers_after=norm_layers_after,
      use_mlp=use_mlp, resid_mlp=resid_mlp
    )
    
    select_logits = selection_step['logits_select']

  # only variables with "select" in their name are loaded from ckpt_select
  var_list = [var for var in tf.trainable_variables() if "select" in var.name]

  # NOTE(review): retries without limit on any error, so a wrong or empty
  # `ckpt_select` loops until the cell is interrupted
  done = False
  while not done:
    try:
      tflex.load_variables(ckpt_select, session=sess, var_list=var_list)
      done=True
    except Exception as e:
      print(f"encountered {e}, retrying...")

  display(var_list)

In [None]:
# code setting up this notebook's interactions with the bridge service

import requests, time
# results waiting to be delivered to the bridge service, keyed by prompt id
RESULT_STACK = {}

def serve_answer(data):
  """Generate continuations for an "answer" request from the bridge service.

  Reads `data["prompt"]` and `data["kwargs"]`.  With a truthy "V5" kwarg, the
  result is a copy of `data` with a "continuations" list added; otherwise it
  is the parsed {"post", "tags"} dict of a single sampled continuation.
  """
  prompt = data["prompt"]
  kwargs = data["kwargs"]
  avoid_if_under = kwargs.get("avoid_if_under", 20)
  avoid_if_cut_off = kwargs.get("avoid_if_cut_off", True)
  split_on_control_char = kwargs.get("split_on_control_char", False)
  avoid_initial_blockquote = kwargs.get("avoid_initial_blockquote", True)

  # cut-off texts get continued rather than rejected when continuation is on
  continue_if_cut_off = kwargs.get("continue_if_cut_off", True)
  if continue_if_cut_off:
    avoid_if_cut_off = False

  if not kwargs.get("V5"):
    # legacy path: every kwarg except "V5" is forwarded as-is
    legacy_kwargs = {k: v for k, v in kwargs.items() if k != "V5"}
    candidates = get_prompted_continuation_with_length_proportional_sampling(prompt, **legacy_kwargs)
    return parse_continuation(candidates[0])

  def _generate(allow_continue):
    # one place for the shared arguments of both attempts below
    return basic_n_continuations(prompt, N=kwargs['best_of'], 
                                 avoid_if_under=avoid_if_under, 
                                 avoid_if_cut_off=avoid_if_cut_off,
                                 prompt_from_dataset=kwargs.get("prompt_from_dataset"),
                                 split_on_control_char=split_on_control_char,
                                 avoid_initial_blockquote=avoid_initial_blockquote,
                                 continue_if_cut_off=allow_continue)

  try:
    continuations = _generate(continue_if_cut_off)
  except Exception as e:
    # second attempt with within-batch continuation switched off
    print(f"got {e}, trying without continue_if_cut_off")
    continuations = _generate(False)

  response = data.copy()
  response["continuations"] = continuations
  return response

def serve_textpost(data):
  """Generate continuations for a "textpost" request from the bridge service.

  Always starts from an empty prompt and reads options from
  `data["kwargs"]`.  With a truthy "V5" kwarg, the result is a copy of `data`
  with a "continuations" list added; otherwise it is the parsed
  {"post", "tags"} dict of a single continuation.
  """
  prompt = ""
  kwargs = data["kwargs"]
  avoid_if_under = kwargs.get("avoid_if_under", 30)
  avoid_if_cut_off = kwargs.get("avoid_if_cut_off", True)
  split_on_control_char = kwargs.get("split_on_control_char", True)
  avoid_initial_blockquote = kwargs.get("avoid_initial_blockquote", False)

  # cut-off texts get continued rather than rejected when continuation is on
  continue_if_cut_off = kwargs.get("continue_if_cut_off", True)
  if continue_if_cut_off:
    avoid_if_cut_off = False

  if not kwargs.get("V5"):
    # legacy path: every kwarg except "V5" is forwarded as-is
    legacy_kwargs = {k: v for k, v in kwargs.items() if k != "V5"}
    candidates = get_prompted_continuation_with_retries_for_length(prompt, **legacy_kwargs)
    return parse_continuation(candidates[0])

  def _generate(allow_continue):
    # one place for the shared arguments of both attempts below
    return basic_n_continuations(prompt, 
                                 N=kwargs['best_of'], 
                                 avoid_if_under=avoid_if_under,
                                 avoid_if_cut_off=avoid_if_cut_off,
                                 split_on_control_char=split_on_control_char,
                                 prompt_from_dataset=kwargs.get("prompt_from_dataset"),
                                 avoid_initial_blockquote=avoid_initial_blockquote,
                                 continue_if_cut_off=allow_continue)

  try:
    continuations = _generate(continue_if_cut_off)
  except Exception as e:
    # second attempt with within-batch continuation switched off
    print(f"got {e}, trying without continue_if_cut_off")
    continuations = _generate(False)

  response = data.copy()
  response["continuations"] = continuations
  return response


def poll():
  """Do one exchange with the bridge service.

  Posts the current results, receives the pending prompts, generates a
  result for every "answer" / "textpost" prompt, and posts again if there
  was anything pending.
  """
  global RESULT_STACK

  response = requests.post(generator_url, json={"results": RESULT_STACK})
  pending = response.json()

  # clean out already used results: keep only ids the bridge still lists
  RESULT_STACK = {key: RESULT_STACK[key] for key in RESULT_STACK if key in pending}

  for prompt_id, data in pending.items():
    print("generating...")
    kind = data["type"]
    if kind == "answer":
      RESULT_STACK[prompt_id] = serve_answer(data)
    elif kind == "textpost":
      RESULT_STACK[prompt_id] = serve_textpost(data)
    # any other type is left without a result

  print("done generating for this poll")

  if pending:
    # deliver the fresh results right away
    requests.post(generator_url, json={"results": RESULT_STACK})
    time.sleep(1)

import time

def loop_poll(period=60):
  """Call `poll` forever.

  Args:
      period: seconds to sleep after a round that leaves no results queued;
          after a failed round the sleep is ten times this.
  """
  global RESULT_STACK
  while True:
    try:
      poll()
    except Exception as err:
      # report and back off, but keep the loop alive
      print(f"{type(err)}: {err}")
      time.sleep(period*10)
    if not RESULT_STACK:
      time.sleep(period)

Main loop of the generator

In [None]:
# start with nothing queued for delivery
RESULT_STACK = {}

# runs until interrupted; see loop_poll for how `period` sets the sleeps
loop_poll(period=10)