Train the selector model on Google Colab (more recent approach)

Unlike `train_selector.ipynb`, this trains a model whose inputs are layer activations from the generator model, rather than a model whose inputs are the text which the generator has created.  Thus it needs access to the generator model, and is coupled to it.  (If you re-train the generator, you must re-train this model on the new one.)

Similar assumptions about your environment to those in `generator.ipynb`

In [None]:
# make sure you have a GPU with ~16GB memory, for 1558M

!nvidia-smi

In [None]:
# cell for Colab-specific stuff, if you're on Colab

# if you're using Google Colab, need to tell it not to use tf 2.x
%tensorflow_version 1.x

# if you're using Google Colab + Google Drive, mount and change dir
from google.colab import drive
drive.mount('/content/drive')

%cd "/content/drive/My Drive/"

# i assume you're now somewhere with a directory called "gpt-2" holding my gpt-2 fork
# and your fine-tuned model
%cd "gpt-2"

In [None]:
import numpy as np, pandas as pd
import pickle

import matplotlib.pyplot as plt
%matplotlib inline

load/prep training corpus -- same as in `train_selector.ipynb` (TODO: DRY)

In [None]:
# you should have created this file by scraping posts/note counts, and put it on google drive
# NOTE(review): the path ends in ".gz" but the file is read with a plain open();
# pickle.load would fail on gzip-compressed bytes -- confirm the file is stored uncompressed.
# NOTE(security): pickle.load can execute arbitrary code; only load a file you produced yourself.
data_path = "reward/reward.pkl.gz"
with open(data_path, "rb") as f:
    # the pickle holds a dict; only its "ids_to_reward_data" entry is used here
    ids_to_reward_data = pickle.load(f)["ids_to_reward_data"]

In [None]:
import re
def inverse_format_post_for_api(post):
    """Undo the HTML wrapping that was applied to a post for the API.

    Strips one leading "<p>" and one trailing "</p>" when present, then turns
    every "</p><p>" paragraph break and every "<br>" into a newline.
    """
    open_tag, close_tag = "<p>", "</p>"

    head = len(open_tag) if post.startswith(open_tag) else 0
    post = post[head:]
    if post.endswith(close_tag):
        post = post[:len(post) - len(close_tag)]

    # both patterns are plain literals, so str.replace is enough
    return post.replace("</p><p>", "\n").replace("<br>", "\n")

def make_train_data(ids_to_reward_data, continuation_only=True, prompt_as_col=True):
    """Turn the scraped reward dict into a DataFrame of training rows.

    Entries whose "note_count" or "continuation" is missing/None are skipped.

    Args:
      ids_to_reward_data: dict mapping an id to a dict with at least the keys
        "note_count", "continuation" and "prompt".
      continuation_only: if True the "text" column is the continuation alone;
        otherwise it is the last 64 space-separated pieces of the prompt
        followed directly by the continuation.
      prompt_as_col: if True, add the full prompt as a "prompt" column.

    Returns:
      DataFrame with columns id, text, note_count (and prompt), where text has
      been passed through inverse_format_post_for_api and had leading
      newlines removed.
    """
    column_names = ["id", "text", "note_count"]
    if prompt_as_col:
        column_names.append("prompt")

    rows = []
    for post_id, entry in ids_to_reward_data.items():
        note_count = entry.get("note_count")
        continuation = entry.get("continuation")
        if note_count is None or continuation is None:
            continue

        if continuation_only:
            text = continuation
        else:
            prompt_tail = " ".join(entry["prompt"].split(" ")[-64:])
            text = prompt_tail + continuation

        row = [post_id, text, note_count]
        if prompt_as_col:
            row.append(entry["prompt"])
        rows.append(row)

    train_data = pd.DataFrame(rows, columns=column_names)
    train_data.text = train_data.text.apply(
        lambda s: inverse_format_post_for_api(s).lstrip("\n"))
    return train_data

In [None]:
train_data = make_train_data(ids_to_reward_data, continuation_only=True)
train_data.note_count.describe()

In [None]:
temporally_ordered_train_data = train_data.sort_values(by="id").reset_index()

In [None]:
def non_overlapping_ma(array, width=31):
  """Average `array` over consecutive, non-overlapping windows of `width` items.

  The last window may hold fewer than `width` items. Returns a pd.Series with
  one mean per window.
  """
  window_means = []
  for start in range(0, len(array), width):
    window_means.append(np.average(array[start:start+width]))
  return pd.Series(window_means)

# Score every post against a window of its neighbours in id order, producing
# a rolling quantile and a rolling "advantage" of its note count.
window_width = 140
window_halfw = window_width//2

# the last rows (highest ids) are never scored themselves
skip_n_most_recent = 40

allow_partial_windows = False
# fraction of the window placed before the current row; None gives a centered window
window_frac_left = 0.8 # None

rolling_quantiles = {}
rolling_advantages = {}

# offsets (relative to the current row) of the window's first and last row
if window_frac_left is not None:
  window_shift_left = -1*int(window_frac_left*window_width)
  window_shift_right = window_width + window_shift_left
else:
  window_shift_left = -window_halfw
  window_shift_right = window_halfw

last_ix_allowed = len(temporally_ordered_train_data) - skip_n_most_recent

if allow_partial_windows:
  ixs = temporally_ordered_train_data.index[:last_ix_allowed]
else:
  # only rows whose whole window fits inside [0, last_ix_allowed)
  ixs = temporally_ordered_train_data.index[(0-window_shift_left):(last_ix_allowed-window_shift_right)]

print(f"using ({ixs.min()} to {ixs.max()}) of (0 to {len(temporally_ordered_train_data)-1})")

for ix in ixs:
  point = temporally_ordered_train_data.loc[ix, 'note_count']
  # .loc label slices include both endpoints, so the window holds window_width+1
  # rows and contains the current row itself
  window = temporally_ordered_train_data.loc[ix+window_shift_left:ix+window_shift_right, 'note_count']
  # share of the window whose note count is <= this row's note count
  rolling_quantiles[ix] = (point>=window).mean()
  # mean difference between this row's note count and the window's note counts
  rolling_advantages[ix] = (point-window).mean()

rolling_quantiles = pd.Series(rolling_quantiles)
rolling_advantages = pd.Series(rolling_advantages)

# sanity plot: quantiles averaged over blocks of 21 rows
non_overlapping_ma(rolling_quantiles, width=21).plot(lw=1, ls='--', marker='.', markersize=5, figsize=(10, 6));

In [None]:
# Build the "target" column the selector is trained on. The flags below pick
# one of four labelling schemes; the result is `model_inputs` (text + target).
use_mov_avg = True
notes_key = "rolling_quantile" if use_mov_avg else "note_count"

if use_mov_avg:
  # keep only the rows that received rolling statistics
  train_data_ = temporally_ordered_train_data.loc[rolling_quantiles.index]
  train_data_["rolling_quantile"] = rolling_quantiles
  train_data_["rolling_advantage"] = rolling_advantages
else:
  # no copy is made: columns added below also land on temporally_ordered_train_data
  train_data_ = temporally_ordered_train_data

regression = False
drop_midrange = True
smaller_midrange_dropped = False

reg_log = False
reg_cutoff = 30

continuation_only = True

#train_data = make_train_data(ids_to_reward_data, continuation_only=continuation_only)

if regression:
  # regression target: rolling advantage (or raw note count), optionally transformed
  notes_key = "rolling_advantage" if use_mov_avg else "note_count"
  if reg_log:
    # shift so the smallest value maps to log(1) == 0
    logstart = 1-train_data_[notes_key].min()
    train_data_["target"] = train_data_[notes_key].apply(lambda x: np.log(x+logstart))
  elif reg_cutoff:
    # cap the target at reg_cutoff
    train_data_["target"] = train_data_[notes_key].apply(lambda x: min(x, reg_cutoff))
  else:
    train_data_["target"] = train_data_[notes_key]

  stratify = None
elif drop_midrange and not use_mov_avg:
  # binary target on raw note counts: positive if >= 4; rows with 2 or 3 notes are dropped
  train_data_["target"] = (train_data_[notes_key]>=4).astype(int)
  train_data_ = train_data_[(train_data_[notes_key] <= 1) | (train_data_[notes_key] >=4)]
  stratify = train_data_["target"]
elif drop_midrange and use_mov_avg:
  # binary target on rolling quantiles: keep only the bottom and top tails
  if smaller_midrange_dropped:
    MIDRANGE_BOTTOM = np.percentile(train_data_[notes_key], 30)
    MIDRANGE_TOP = np.percentile(train_data_[notes_key], 70)
  else:
    MIDRANGE_BOTTOM = np.percentile(train_data_[notes_key], 24)
    MIDRANGE_TOP = np.percentile(train_data_[notes_key], 76)

  train_data_["target"] = (train_data_[notes_key] >= MIDRANGE_TOP).astype(int)
  train_data_ = train_data_[(train_data_[notes_key] <= MIDRANGE_BOTTOM) | (train_data_[notes_key] >= MIDRANGE_TOP)]
  stratify = train_data_["target"]
else:
# split at middle: positive if the value is > 2, no rows dropped
  train_data_["target"] = (train_data_[notes_key] > 2).astype(int)
  train_data_ = train_data_
  stratify = train_data_["target"]


model_inputs = train_data_[["text", "target"]]

In [None]:
model_inputs.target.describe()

In [None]:
def baserate_loss(target):
  """Cross-entropy of always predicting the base rate of a binary `target`.

  This is the entropy (in nats) of a Bernoulli variable whose probability is
  mean(target); it is the loss a model must beat to be useful.

  Args:
    target: array-like of 0/1 labels.

  Returns:
    The base-rate loss as a float. When every label is the same (base rate
    exactly 0 or 1) the loss is 0.0; the naive formula would evaluate
    0*log(0) and return NaN.
  """
  baserate = np.mean(target)

  # degenerate case: a constant target is predicted perfectly by its base rate
  if baserate == 0 or baserate == 1:
    return 0.0

  return -1 * (baserate*np.log(baserate) + (1-baserate)*np.log(1-baserate))

print(f"baserate_loss (all): {baserate_loss(model_inputs.target):.3f}")

generic tensorflow utility code

In [None]:
def initialize_uninitialized(sess, print_names=True):
    """Initialize every global variable that `sess` reports as uninitialized.

    Variables that already hold a value are left alone. When `print_names`
    is true, the name of each variable about to be initialized is printed.
    """
    all_vars = tf.global_variables()
    init_flags = sess.run([tf.is_variable_initialized(v) for v in all_vars])
    pending = [v for v, is_init in zip(all_vars, init_flags) if not is_init]

    if print_names:
        for var in pending:
            print(var.name)

    if pending:
        sess.run(tf.variables_initializer(pending))

def re_initialize(sess, var_names):
  """Re-run the initializers for `var_names` in `sess`.

  `var_names` is handed straight to tf.variables_initializer.
  NOTE(review): that function takes variable objects, so despite the
  parameter name this presumably expects variables, not strings -- confirm
  against callers.
  """
  reset_op = tf.variables_initializer(var_names)
  sess.run(reset_op)

selector/generator model setup

In [None]:
%pip install -r "requirements.txt"

In [None]:
import os, sys

sys.path.append("src")

In [None]:
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder
from load_dataset import load_dataset, Sampler

import json

In [None]:
model_name = ""  # fill in -- should be a directory under /models

enc = encoder.get_encoder(model_name, eot_workaround=True)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

hparams.set_hparam("attn_dropout", 0)
hparams.set_hparam("res_dropout", 0)

In [None]:
length=825
required_continuation_room = 100
max_context_size = length - required_continuation_room

# "for_h" doesn't mean anything here ("it's historical"), this is just the batch size
batch_size_for_h = 8

In [None]:
SELECTION_CHAR = "<|endoftext|>"
print(enc.encode(SELECTION_CHAR))
SELECTION_TOK = enc.encode(SELECTION_CHAR)[-1]
print(SELECTION_TOK)

define model architecture

much of this either reuses gpt2 code (imported in `from model import *`), or defines slightly modified equivalents of it

In [None]:
from model import *

def extract_selection_ix(tokens, extract_from):
  """Pick, for each batch row, the entry of `extract_from` at the last SELECTION_TOK.

  A boolean mask marks every position where `tokens` equals SELECTION_TOK;
  the masked entries of `extract_from` are collected per row and the last one
  in each row is gathered. The batch size is read from the global
  `batch_size_for_h`.

  Returns:
    dict with "extracted" (the gathered entries) and "selection_ix" (the
    [row, index-within-masked-row] pairs used for the gather).
  """
  is_selection_tok = tf.equal(tf.dtypes.cast(tokens, tf.int32), SELECTION_TOK)
  selected = tf.ragged.boolean_mask(extract_from, is_selection_tok)

  # index of the last masked entry in each row
  last_hit_ixs = selected.row_lengths() - 1
  batch_ixs = tf.range(0, batch_size_for_h, dtype=tf.int64)
  selection_ix = tf.stack([batch_ixs, last_hit_ixs], axis=1)

  return {
      "extracted": tf.gather_nd(selected.to_tensor(), selection_ix),
      "selection_ix": selection_ix,
  }

def model_activations(hparams, X, hparams_select, 
                      layer_nums: list, 
                      norm_layers_after: bool=False,
                      past=None, past_select=None,
                      scope='model', reuse=tf.AUTO_REUSE):
    """Run the generator transformer on `X` and collect selected layer outputs.

    The hidden state `h` is recorded after every layer whose index is in
    `layer_nums`.

    Args:
      hparams: generator hyperparameters (n_ctx, n_embd, n_vocab, n_layer, dtype).
      X: token ids; unpacked as (batch, sequence).
      hparams_select: hyperparameters passed to `norm` when
        `norm_layers_after` is true; otherwise unused here.
      layer_nums: indices of the layers whose output should be collected.
      norm_layers_after: if true, each collected activation is passed through
        `norm` under the scope name "ln_after_<layer name>".
      past: optional cached keys/values, unstacked along axis 1 into one entry
        per layer.
      past_select: accepted but not used in this function.
      scope, reuse: variable scope settings for the generator variables.

    Returns:
      dict with
        'present': per-layer `present` tensors stacked along axis 1,
        'logits': language-model logits of shape (batch, sequence, n_vocab),
        'activations': list of (layer name, activation) pairs, names "h<layer>".

    `get_variable`, `shape_list`, `positions_for`, `block` and `norm` are not
    defined in this notebook -- presumably they come from `from model import *`.
    """
    activations = []
    h_names = []
    
    dtype = hparams.dtype if hparams else tf.float32
    with tf.variable_scope(scope, reuse=reuse, dtype=dtype):
        results = {}
        batch, sequence = shape_list(X)

        # position and token embedding tables
        wpe = get_variable('wpe') or tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                             initializer=tf.random_normal_initializer(stddev=0.01, dtype=dtype))
        wte = get_variable('wte') or tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                             initializer=tf.random_normal_initializer(stddev=0.02, dtype=dtype))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        # note: the loop variable rebinds the `past` parameter
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
            if layer in layer_nums:
              # keep this layer's output for the selector
              h_name = f'h{layer}'
              print(f'{h_name} found')
              h_names.append(h_name)
              activations.append(h)

        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f', hparams=hparams)

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits

        # activations
        if norm_layers_after:
          activations = [norm(act, f'ln_after_{act_name}', hparams=hparams_select)
                         for act_name, act in zip(h_names, activations)]

        results['activations'] = list(zip(h_names, activations))

        return results

In [None]:
def get_initializer(hparams, scope):
  """Pick the weight-initializer class for variables created under `scope`.

  Args:
    hparams: hyperparameters exposing `.get`; may be None, which is how
      `conv1d` calls its helpers when no hparams are given.
    scope: scope name, used only in the log message.

  Returns:
    tf.compat.v1.orthogonal_initializer when hparams has a truthy "orth_init"
    entry, otherwise tf.random_normal_initializer. The class itself is
    returned, not an instance.
  """
  # hparams is optional for the callers, so guard before reading from it
  if hparams is not None and hparams.get("orth_init"):
    print(f"orth init in scope {scope}")
    return tf.compat.v1.orthogonal_initializer
  return tf.random_normal_initializer

def conv1d(x, scope, nf, *, w_init_stdev=0.02, hparams=None):
    """Position-wise linear layer: project the last axis of `x` to `nf` features.

    Args:
      x: input tensor; all leading axes are preserved.
      scope: variable scope holding the weight 'w' and bias 'b'.
      nf: number of output features.
      w_init_stdev: scale of the weight initializer -- the standard deviation
        for the random-normal initializer, the gain for the orthogonal one.
      hparams: optional hyperparameters (dtype, orth_init). When omitted,
        float32 and the random-normal initializer are used.

    Returns:
      Tensor with the shape of `x` except that the last axis has size `nf`.
    """
    dtype = hparams.dtype if hparams else tf.float32

    # get_initializer reads from hparams, so only consult it when hparams is given
    initializer = get_initializer(hparams, scope) if hparams else tf.random_normal_initializer
    if initializer is tf.random_normal_initializer:
        # the first positional argument of random_normal_initializer is the
        # mean, so the scale has to be passed as `stddev` explicitly
        w_initializer = initializer(stddev=w_init_stdev, dtype=dtype)
    else:
        # orthogonal_initializer takes the scale (gain) as its first argument
        w_initializer = initializer(w_init_stdev, dtype=dtype)

    with tf.variable_scope(scope, dtype=dtype):
        *start, nx = shape_list(x)
        w = get_variable('w') or tf.get_variable('w', [1, nx, nf], initializer=w_initializer)
        b = get_variable('b') or tf.get_variable('b', [nf], initializer=tf.constant_initializer(0, dtype=dtype))
        # flatten the leading axes, apply the affine map, restore the leading axes
        c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
        return c

# this is a copy/paste -- we need to redefine "attn" so the "conv1d" defn'd above
# is used

def attn(x, scope, n_state, *, past, hparams):
    """Masked multi-head self-attention layer.

    Redefined in this notebook so that the `conv1d` defined above (with its
    configurable initializer) is the one used for the projections.

    Args:
      x: input of rank 3, [batch, sequence, features].
      scope: variable scope for the 'c_attn' and 'c_proj' projections.
      n_state: feature width of the layer; must be divisible by hparams.n_head.
      past: optional cached keys/values of rank 5,
        [batch, 2, heads, sequence, features].
      hparams: hyperparameters (n_head, attn_dropout, res_dropout, dtype).

    Returns:
      (a, present): the attention output and the stacked [k, v] of this call
      (computed before any `past` is prepended).
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # masked-out positions get a large negative score (a smaller constant
        # is used for dtypes other than float32) so softmax sends them to ~0
        w = w*b - tf.cast(65500 if w.dtype != tf.float32 else 1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        # scale scores by 1/sqrt(per-head feature size)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        w = dropout(w, hparams.attn_dropout)
        a = tf.matmul(w, v)
        return a

    dtype = hparams.dtype if hparams else tf.float32
    with tf.variable_scope(scope, dtype=dtype):
        # one projection produces q, k and v side by side
        c = conv1d(x, 'c_attn', n_state*3, hparams=hparams)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            # prepend cached keys/values along the sequence axis
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state, hparams=hparams)
        a = dropout(a, hparams.res_dropout)
        return a, present

def attn_only_block(x, scope, *, past, hparams, do_input_norm=True):
    """Transformer block that contains only the attention sub-layer.

    Args:
      x: block input; its last axis gives the attention width.
      scope: variable scope for the block.
      past: cached keys/values handed to `attn` (or None).
      hparams: hyperparameters; `do_resid` decides whether the attention
        output is added to `x` or returned on its own.
      do_input_norm: whether to apply the 'ln_1' norm to `x` before attention.

    Returns:
      (output, present) where present comes from `attn`.
    """
    dtype = hparams.dtype if hparams else tf.float32
    do_resid = hparams.do_resid if hparams else True
    print(f"do_resid: {do_resid}")
    print(f"do_input_norm: {do_input_norm}")

    with tf.variable_scope(scope, dtype=dtype):
        n_features = x.shape[-1].value

        attn_input = norm(x, 'ln_1', hparams=hparams) if do_input_norm else x
        attn_out, present = attn(attn_input, 'attn', n_features, past=past, hparams=hparams)

        block_out = x + attn_out if do_resid else attn_out
        return block_out, present

In [None]:
def mlp_no_proj(x, scope, n_state, *, hparams, is_expansion=False):
    """Single feed-forward layer: linear map to `n_state` features, GELU, dropout.

    Unlike the usual transformer MLP there is no second projection back to
    the input width.

    Args:
      x: input tensor.
      scope: variable scope; the linear layer lives under 'c_fc'.
      n_state: number of output features.
      hparams: hyperparameters (dtype, res_dropout, initializer settings).
      is_expansion: currently unused; kept so existing callers keep working.

    Returns:
      The activated tensor, last axis of size `n_state`.
    """
    dtype = hparams.dtype if hparams else tf.float32
    with tf.variable_scope(scope, dtype=dtype):
        h = gelu(conv1d(x, 'c_fc', n_state,
                        w_init_stdev=0.02,
                        hparams=hparams))
        h = dropout(h, hparams.res_dropout)
        return h


In [None]:
def selector(hparams, X, hparams_select, 
             layer_nums: list, 
             scope="model", 
             reuse=tf.AUTO_REUSE,
             norm_layers_after: bool=False,
             use_mlp: bool=True,
             resid_mlp: bool=True,
             ):
  """Build the selector head on top of generator activations.

  For every generator layer in `layer_nums`, the layer output goes through an
  attention-only block and a norm. The per-layer results are concatenated
  along the feature axis, the entry at the last SELECTION_TOK of each row is
  extracted, optionally passed through an MLP, and mapped to 2 logits.

  Args:
    hparams: generator hyperparameters.
    X: token ids fed to the generator.
    hparams_select: hyperparameters of the selector layers.
    layer_nums: indices of the generator layers to read from.
    scope, reuse: variable scope settings for the generator and for the
      MLP / output layer variables.
    norm_layers_after: normalize the activations when they are collected
      instead of at the input of each attention block.
    use_mlp: insert the MLP before the output layer.
    resid_mlp: add the MLP output to its input instead of replacing it.

  Returns:
    dict with 'logits_select', one pair of logits per batch row.

  Raises:
    ValueError: if none of `layer_nums` matched a generator layer.
  """
  results = {}

  activations = model_activations(
      hparams=hparams, hparams_select=hparams_select,
       X=X, layer_nums=layer_nums,
        norm_layers_after=norm_layers_after,
        scope=scope, reuse=reuse,
      )['activations']

  if not activations:
    raise ValueError(f"no generator layers matched layer_nums={layer_nums}")

  hs_select = []
  for act_name, act in activations:
    h_select, _ = attn_only_block(act, f'h_select_{act_name}', 
                                  hparams=hparams_select,
                                  past=None,
                                  do_input_norm=(not norm_layers_after))
    h_select = norm(h_select, f'ln_2_select_{act_name}', hparams=hparams_select,)
    hs_select.append(h_select)

  # Concatenate and extract once, after all layers are processed; doing it
  # inside the loop only added unused ops to the graph for every layer but
  # the last.
  h_select_in = tf.concat(hs_select, axis=-1)
  h_select_in_at_selection_ix = extract_selection_ix(X, h_select_in)['extracted']

  with tf.variable_scope(scope, reuse=reuse, dtype=hparams_select.dtype):
    if use_mlp:
      m = mlp_no_proj(h_select_in_at_selection_ix, "select_mlp__", len(layer_nums)*hparams.n_embd, hparams=hparams_select)
      if resid_mlp:
        h_select_in_at_selection_ix = m + h_select_in_at_selection_ix
      else:
        h_select_in_at_selection_ix = m

    w_select = get_variable('w_select_')
    if w_select is None:
      initializer = get_initializer(hparams_select, scope)
      if initializer is tf.random_normal_initializer:
        # first positional argument of random_normal_initializer is the mean,
        # so the scale must be given as `stddev`
        w_select_init = initializer(stddev=0.02, dtype=hparams.dtype)
      else:
        # orthogonal_initializer takes the scale (gain) first
        w_select_init = initializer(0.02, dtype=hparams.dtype)
      w_select = tf.get_variable('w_select_', [len(layer_nums)*hparams.n_embd, 2],
                                initializer=w_select_init)

    b_select = get_variable('b_select')
    if b_select is None:
      b_select = tf.get_variable('b_select', [2],
                                initializer=tf.constant_initializer(0, dtype=hparams.dtype))

    select_logits = tf.matmul(h_select_in_at_selection_ix, w_select) + b_select

  results['logits_select'] = select_logits

  return results

define selector model hparams, add selector model to tf graph

In [None]:
from tensorflow.contrib.training import HParams

# Hyperparameters for the selector layers, used when building the training graph.
# Sizes are copied from the generator's hparams.
hparams_select_train = HParams(
        n_vocab=hparams.n_vocab,
        n_ctx=hparams.n_ctx,
        n_embd=hparams.n_embd,
        n_head=hparams.n_head,
        n_layer=hparams.n_layer,
        res_dropout=0.,
        attn_dropout=0.,
        dtype=tf.float32,
        do_resid=False,  # read by attn_only_block: output the attention result without adding the input
        orth_init=True,  # read by get_initializer: use the orthogonal initializer
    )

# Same settings for the evaluation graph. The only textual difference from the
# training set is that the dropout rates are written as ints (0) rather than
# floats (0.).
hparams_select_eval = HParams(
        n_vocab=hparams.n_vocab,
        n_ctx=hparams.n_ctx,
        n_embd=hparams.n_embd,
        n_head=hparams.n_head,
        n_layer=hparams.n_layer,
        res_dropout=0,
        attn_dropout=0,
        dtype=tf.float32,
        do_resid=False,
        orth_init=True,
    )

In [None]:
# NOTE: this cell reads `sess` and `context_for_h`. In this notebook both are
# first assigned in the later "load generator part of model" cell, which also
# calls tf.reset_default_graph(). On a fresh kernel, run that cell before this one.

# generator layers (0-based) whose activations feed the selector
layer_nums = [24-1, 36-1]
norm_layers_after = False
use_mlp = True
resid_mlp = True

# Build the selector twice on the same input placeholder: once with the
# training hparams and once with the eval hparams. selector() defaults to
# reuse=tf.AUTO_REUSE for its variable scope.
with sess.as_default():
  selection_step_train = selector(
      hparams=hparams, hparams_select=hparams_select_train,
       X=context_for_h, layer_nums=layer_nums,
        norm_layers_after=norm_layers_after,
        use_mlp=use_mlp, resid_mlp=resid_mlp
      )
  selection_step_eval = selector(
    hparams=hparams, hparams_select=hparams_select_eval,
      X=context_for_h, layer_nums=layer_nums,
      norm_layers_after=norm_layers_after,
      use_mlp=use_mlp, resid_mlp=resid_mlp
    )

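load generator part of model into session from checkpoint

NOTE: run this cell *before* the cell above that adds the selector to the graph. This cell creates `sess` and `context_for_h`, which that cell needs, and it resets the default graph, which would discard a selector that had already been built.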

In [None]:
import tflex

seed = None

# Upper bound on restore attempts. Retrying helps with transient failures, but
# a persistent error (wrong model_name, missing checkpoint, undefined name)
# would otherwise keep this cell looping forever.
max_load_attempts = 5

load_done = False
load_attempt = 0

while not load_done:
  load_attempt += 1
  try:
    # start from an empty graph and a fresh session on every attempt
    tf.reset_default_graph()
    sess = tflex.Session()

    with sess.as_default():
      np.random.seed(seed)
      tf.set_random_seed(seed)

      # token-id input shared by the generator and the selector
      context_for_h = tf.placeholder(tf.int32, [batch_size_for_h, None])
      do_step_with_h = step_with_h(hparams, tokens=context_for_h, batch_size_=batch_size_for_h)

      saver = tflex.Saver()
      ckpt = tflex.latest_checkpoint(os.path.join('models', model_name))
      print(f"restoring checkpoint: {ckpt}")
      saver.restore(sess, ckpt)
    load_done = True
  except Exception as e:
    if load_attempt >= max_load_attempts:
      # give up and surface the real error instead of retrying indefinitely
      raise
    print(f"encountered {e}, retrying...")

initialize weights of selector part

In [None]:
initialize_uninitialized(sess)

more data munging -- this is unique to this modeling approach, not shared w/ the BERT one

doing it at this point because we need to know how much data we have to set up AdamW correctly

In [None]:
CONTINUATION_ONLY = True
NO_TAGS = False

UNAME_CHAR = "友"
T_CHAR = "职"

def strip_uname_tags(s, prompt, verbose=False):
  """Drop tags from `s` that repeat a username appearing in `prompt`.

  `s` is split at the first T_CHAR into the post and its tag section. The tag
  section is split on "#"; a tag is removed when UNAME_CHAR followed by the
  tag (trailing spaces stripped) occurs in `prompt`. Kept tags are re-joined
  with their "#" prefix.

  Returns the post, the T_CHAR separator (if it was present) and the kept tags.
  """
  body, separator, raw_tags = s.partition(T_CHAR)

  kept = []
  for tag in filter(None, raw_tags.split("#")):
    mention = UNAME_CHAR + tag.rstrip(" ")
    if mention in prompt:
      if verbose:
        print(f"found {tag} as {mention}")
      continue
    kept.append("#" + tag)

  return body + separator + "".join(kept)

# Build the text the selector sees for each row ("selector_input").
train_data_for_selection = train_data_.copy()
# drop a dangling "<|" at the end of the text
# (presumably a cut-off "<|endoftext|>" marker -- TODO confirm)
selector_input_continuation = train_data_for_selection.text.apply(lambda s: (s[:-2] if s.endswith("<|") else s))

if CONTINUATION_ONLY:
  selector_input = selector_input_continuation.copy()
else:
  # prepend the prompt, truncated to its last max_context_size tokens
  selector_input_prompt = train_data_for_selection.prompt.apply(lambda s: enc.decode(enc.encode(s)[-max_context_size:]))
  selector_input = selector_input_prompt + selector_input_continuation

if NO_TAGS:
  # keep only the part before T_CHAR, trim trailing filler characters, re-append T_CHAR
  selector_input = selector_input.apply(lambda s: s.partition(T_CHAR)[0].rstrip("\n\ufffa\ufffb ") + T_CHAR)
else:
  # keep the tags, except those that repeat a username from the prompt
  selector_input.iloc[:] = [strip_uname_tags(s, prompt) for s, prompt in zip(selector_input, train_data_for_selection.prompt)]
  selector_input = selector_input.apply(lambda s: s.rstrip("\n\ufffa\ufffb "))
  
# truncate to the last (length-1) tokens, leaving one position for the
# SELECTION_TOK that is appended when batches are built
selector_input = selector_input.apply(lambda s: enc.decode(enc.encode(s)[-(length-1):]))
train_data_for_selection["selector_input"] = selector_input

In [None]:
TEST_SIZE = 0.175

In [None]:
from sklearn.model_selection import train_test_split

train_data_for_selection, test_data_for_selection = train_test_split(train_data_for_selection, 
                                                                     test_size=TEST_SIZE, 
                                                                     stratify=train_data_for_selection.target)

In [None]:
train_data_for_selection["n_tokens"] = train_data_for_selection["selector_input"].apply(lambda s: len(enc.encode(s)))
train_data_for_selection = train_data_for_selection.sort_values(by="n_tokens")

batches = [train_data_for_selection.iloc[row_ix:row_ix + batch_size_for_h, :]
           for row_ix in range(0, len(train_data_for_selection), batch_size_for_h)]

np.random.shuffle(batches)

train_data_for_selection_final = pd.concat(batches, ignore_index=True)

In [None]:
def reshuffle_batches(train_data_for_selection, batch_size=None):
  """Shuffle the order of batches while keeping each batch's rows together.

  The frame is cut into consecutive chunks of `batch_size` rows, the chunks
  are shuffled with np.random.shuffle, and the result is concatenated with a
  fresh 0..n-1 index. Rows that were adjacent inside a chunk (e.g. rows of
  similar token length after sorting) stay adjacent.

  Args:
    train_data_for_selection: DataFrame to reorder; it is not modified.
    batch_size: rows per batch. Defaults to the notebook-level
      `batch_size_for_h` when omitted.

  Returns:
    A new DataFrame with the same rows in batch-shuffled order. An empty
    input yields an empty frame with the same columns.
  """
  if batch_size is None:
    batch_size = batch_size_for_h

  batches = [train_data_for_selection.iloc[row_ix:row_ix + batch_size, :]
             for row_ix in range(0, len(train_data_for_selection), batch_size)]

  if not batches:
    # pd.concat raises ValueError on an empty list
    return train_data_for_selection.reset_index(drop=True)

  np.random.shuffle(batches)

  return pd.concat(batches, ignore_index=True)

In [None]:
test_data_for_selection["n_tokens"] = test_data_for_selection["selector_input"].apply(lambda s: len(enc.encode(s)))
test_data_for_selection = test_data_for_selection.sort_values(by="n_tokens")

batches = [test_data_for_selection.iloc[row_ix:row_ix + batch_size_for_h, :]
           for row_ix in range(0, len(test_data_for_selection), batch_size_for_h)]

np.random.shuffle(batches)

test_data_for_selection_final = pd.concat(batches, ignore_index=True)

In [None]:
print(train_data_for_selection_final.shape)
print(test_data_for_selection_final.shape)

set up optimizer

In [None]:
import tflex_sgdr

n_batches_per_epoch = len(train_data_for_selection_final)//batch_size_for_h

# NOTE(review): the (1-0.25)/(1-0.175) factor looks like a rescaling between two
# train fractions (0.175 matches TEST_SIZE) -- confirm the intent.
base_lr = 0.0001 * (1-0.25)/(1-0.175)
min_lr = base_lr/20
# one warm-restart period per epoch
initial_period_steps=int(n_batches_per_epoch)

# AUTO_REUSE creates the step counter on the first run and returns the existing
# variable when this cell is re-run, without swallowing unrelated errors.
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
  global_step = tf.get_variable('global_step_', shape=(), dtype=tf.int32, trainable=False, )
# (re)start the schedule from step 0
global_step.load(0, session=sess)

lr = tflex_sgdr.sgdr_decay(base_lr, global_step, 
    initial_period_steps=initial_period_steps, 
    t_mul=1., m_mul=0.5
    )
# never let the learning rate fall below min_lr
lr = tf.maximum(lr, tf.constant(min_lr, dtype=lr.dtype))

In [None]:
with sess.as_default():
  select_logits_train = selection_step_train['logits_select']
  select_logits_eval = selection_step_eval['logits_select']
  select_target = tf.placeholder(tf.int32, [batch_size_for_h], )

  select_loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(labels=select_target, logits=select_logits_train)
      )

In [None]:
from tensorflow.train import AdamOptimizer
from tensorflow.contrib.opt import AdamWOptimizer

weight_decay = 0.025
# the decay coefficient is multiplied by the scheduled learning rate, so it
# follows the same schedule
opt = AdamWOptimizer(weight_decay=weight_decay*lr, learning_rate=lr)

# Train only variables whose name contains "select", and skip those whose name
# contains "ln_2_" (the norms applied after each selector attention block).
train_vars = [var for var in tf.trainable_variables() if "select" in var.name and "ln_2_" not in var.name]
# Weight decay is not applied to "_scalars" variables, the output bias, or any norm ("ln_") variables.
decay_vars = [var for var in train_vars if "_scalars" not in var.name and "b_select" not in var.name and "ln_" not in var.name]

opt_gradients, opt_variables = zip(*opt.compute_gradients(select_loss, train_vars))
# clip the global gradient norm to 1.0
opt_gradients, _ = tf.clip_by_global_norm(opt_gradients, 1.0)
opt_apply = opt.apply_gradients(zip(opt_gradients, opt_variables), decay_var_list=decay_vars)

# the optimizer created new (slot) variables; initialize just those
initialize_uninitialized(sess)

define model train and eval loops

In [None]:
import time
import scipy.special

def train_selection(data, steps=None, start_ix=0, avg_loss_beta=0.98, show_mix=False, show_lr=False):
  """Train the selector for one pass over `data`, one batch per step.

  Each batch is tokenized, a SELECTION_TOK is appended to every row, rows are
  right-padded with token 0 to a common length, and one optimizer step is run.
  After each successful step the `global_step` variable is incremented.

  Args:
    data: DataFrame with "selector_input" and "target" columns, already
      arranged in batch order.
    steps: index of the step to stop before; defaults to the number of full
      batches in `data` (a trailing partial batch is not used).
    start_ix: index of the first step (batch) to run.
    avg_loss_beta: smoothing factor of the running-average loss that is
      printed; the average starts once step index > 3.
    show_mix: every 10 steps, print the "w_select_scalars" /
      "g_select_scalars" variables. NOTE(review): the selector defined in
      this notebook does not create these variables -- confirm before enabling.
    show_lr: print the current learning rate after every step.

  Returns:
    List with the loss of every batch that was actually trained on.
  """
  all_losses = []
  running_loss = None

  if show_mix:
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE, dtype=hparams.dtype):
      latest_mix_var = tf.unstack(tf.get_variable("w_select_scalars"), axis=-1)
      latest_mix_var_g = tf.get_variable("g_select_scalars")

  if steps is None:
    steps = len(data)//batch_size_for_h

  for step_ix in range(start_ix, steps):
    # Derive the row offset from the step index. Advancing a counter at the
    # end of the loop body would be bypassed by the `continue` below, and
    # every later step would then re-feed the same failing batch.
    row_ix = step_ix*batch_size_for_h
    data_batch = data.iloc[row_ix:row_ix + batch_size_for_h, :]

    batch_context = [enc.encode(text) + [SELECTION_TOK] for text in data_batch.selector_input.values]
    max_tokens = max([len(toks) for toks in batch_context])
    # right-pad every row with token 0 up to the longest row
    batch_context = [toks + [0 for _ in range(max_tokens-len(toks))] for toks in batch_context]

    batch_target = data_batch.target.values

    print(f"{step_ix}/{steps} | {sum(batch_target)}/{batch_size_for_h} pos | {batch_size_for_h} rows of {max_tokens} tokens")

    t1 = time.time()
    with sess.as_default():
      try:
        batch_loss, _ = sess.run([select_loss, opt_apply], feed_dict={context_for_h.name: batch_context, select_target.name: batch_target})
      except tf.errors.InvalidArgumentError:
        # the graph rejected this batch; move on to the next one
        print("skipping")
        continue
    t2 = time.time()
    tdiff = t2 - t1

    all_losses.append(batch_loss)

    if running_loss is None:
      if step_ix > 3:
        running_loss = batch_loss
        avg_display = f"{running_loss:.4f}"
      else:
        avg_display = "None"
    else:
      running_loss = (avg_loss_beta * running_loss) + ((1-avg_loss_beta) * batch_loss)
      avg_display = f"{running_loss:.4f}"

    print(f"{step_ix}/{steps} | {tdiff:.2f}s  | loss={batch_loss:.4f}, avg={avg_display}")
    if show_mix and step_ix % 10 == 0:
      latest_mixes_raw, latest_g = sess.run([latest_mix_var, latest_mix_var_g])
      nmix=len(latest_mixes_raw)

      latest_g_formatted = [f"{x: .4f}" for x in latest_g]
      latest_mixes_raw_formatted = [np.asarray([f"{x: .4f}" for x in mix]) for mix in latest_mixes_raw]

      latest_mixes_normed = [scipy.special.softmax(mix, axis=0) for mix in latest_mixes_raw]
      latest_mixes_normed_formatted = [np.asarray([f"{x: .4f}" for x in mix]) for mix in latest_mixes_normed]

      for mix_ix in range(nmix):
        print(f"mix {mix_ix}:")
        print(f"latest mix gain: {latest_g_formatted}")
        print(f'latest_mix_raw:\n{latest_mixes_raw_formatted[mix_ix]}')
        print(f'latest_mix_normed:\n{latest_mixes_normed_formatted[mix_ix]}')

    if show_lr:
      print(f"latest lr: {sess.run(lr):.4e}")
    print("\n--------------\n")

    # advance the schedule's step counter by one
    current_step = global_step.eval(sess)
    global_step.load(current_step+1, session=sess)

  return all_losses

In [None]:
import scipy.special

def predict_select(text_batch, threshold=0.5):
  """Run the eval selector on exactly one batch of texts.

  Each text is tokenized, truncated to its last (length-1) tokens, given a
  trailing SELECTION_TOK, and right-padded with token 0 to the batch maximum.

  Args:
    text_batch: sequence of exactly `batch_size_for_h` strings.
    threshold: a row is predicted positive when its class-1 probability
      exceeds this value.

  Returns:
    dict with "logits", "probs" (softmax over the two classes) and "preds"
    (boolean array).

  Raises:
    ValueError: if the batch does not contain `batch_size_for_h` texts.
  """
  if len(text_batch) != batch_size_for_h:
    raise ValueError("badlength")

  token_rows = [enc.encode(text)[-(length-1):] + [SELECTION_TOK] for text in text_batch]
  longest = max(len(row) for row in token_rows)
  padded_rows = [row + [0] * (longest - len(row)) for row in token_rows]

  with sess.as_default():
    logits = sess.run(select_logits_eval, feed_dict={context_for_h.name: padded_rows})

  probs = scipy.special.softmax(logits, axis=1)
  return {"logits": logits, "probs": probs, "preds": probs[:, 1]>threshold}

In [None]:
import time 

def eval_selection(data, steps=None, start_ix=0):
  """Run the selector over `data` batch by batch and collect its predictions.

  Args:
    data: DataFrame with "selector_input" and "target" columns.
    steps: index of the step to stop before; defaults to the number of full
      batches in `data` (a trailing partial batch is not evaluated).
    start_ix: index of the first step (batch) to evaluate.

  Returns:
    (all_preds, all_probs, all_targets, all_row_ix): predicted labels, class-1
    probabilities (numpy array), true targets and positional row indices of
    the evaluated rows. Batches that raised during prediction are left out.
    If nothing was evaluated, all_probs is an empty array.
  """
  all_preds = []
  all_probs = []
  all_targets = []
  all_row_ix = []

  if steps is None:
    steps = len(data)//batch_size_for_h

  for step_ix in range(start_ix, steps):
    # Derive the row offset from the step index. Advancing a counter at the
    # end of the loop body would be bypassed by the `continue` below, and
    # every later step would then re-evaluate the same failing batch.
    row_ix = step_ix*batch_size_for_h
    data_batch = data.iloc[row_ix:row_ix + batch_size_for_h, :]

    t1 = time.time()
    try:
      results_batch = predict_select(data_batch.selector_input.values)
    except Exception as e:
      print(f"skipping batch ({e})")
      continue
    t2 = time.time()
    tdiff = t2 - t1

    all_targets.extend(data_batch.target.values)
    all_preds.extend(results_batch["preds"])
    all_probs.extend(results_batch["probs"])
    all_row_ix.extend(list(range(row_ix, row_ix + batch_size_for_h)))

    # running accuracy over everything evaluated so far
    accs = np.array(all_preds) == np.array(all_targets)
    avg_acc = accs.mean()

    tp = (np.array(all_targets)>0).sum()
    pp = (np.array(all_preds)>0).sum()

    assert len(all_targets) == len(all_preds)

    print(f"{step_ix}/{steps} | {tdiff:.2f}s | acc={avg_acc:.4f} | {tp}/{len(all_targets)} true pos | {pp}/{len(all_targets)} pred pos")
    print("\n--------------\n")

  if all_probs:
    # keep only the probability of the positive class
    all_probs = np.stack(all_probs)[:, 1]
  else:
    # np.stack raises on an empty list
    all_probs = np.array([])
  return all_preds, all_probs, all_targets, all_row_ix

train

In [None]:
# train for several epochs

n_epochs = 3

for epoch_ix in range(n_epochs):
  train_data_for_selection_final = reshuffle_batches(train_data_for_selection)
  train_selection(train_data_for_selection_final, start_ix=0, show_mix=False, show_lr=True)

evaluate

In [None]:
from sklearn.metrics import brier_score_loss, average_precision_score, accuracy_score

all_preds, all_probs, all_targets, all_row_ix = eval_selection(test_data_for_selection_final, start_ix=0)

In [None]:
baserate_acc = max(test_data_for_selection_final.target.mean(), 1.-test_data_for_selection_final.target.mean())
baserate_brier = brier_score_loss(test_data_for_selection_final.target, [test_data_for_selection_final.target.mean() for _ in range(len(test_data_for_selection_final.target))])
baserate_AP = average_precision_score(test_data_for_selection_final.target, [test_data_for_selection_final.target.mean() for _ in range(len(test_data_for_selection_final.target))])

print(f"acc:   {accuracy_score(all_targets, all_preds):.3f} (vs {baserate_acc:.3f})")
print(f"brier: {brier_score_loss(all_targets, all_probs):.3f} (vs {baserate_brier:.3f})")
print(f"AP:    {average_precision_score(all_targets, all_probs):.3f} (vs {baserate_AP:.3f})")

In [None]:
from sklearn.metrics import precision_recall_curve

ps, rs, ts = precision_recall_curve(all_targets, all_probs)

# precision_recall_curve returns one more precision/recall value than
# thresholds: point i belongs to threshold i, and the final point
# (precision 1, recall 0) has no threshold. The filler colour therefore goes
# at the END; putting it first would shift every colour by one point.
plt.scatter(ps, rs, s=1, c=ts.tolist() + [1.0])
plt.xlim([0, 1])
plt.xlabel("precision")
plt.ylabel("recall")
plt.colorbar();

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 6))

bins=np.linspace(0, 1, 9)

ax[0].hist(all_probs, bins=bins, alpha=0.5, density=False, edgecolor='k')
ax[0].set_xticks(bins)

_, bins, _ = ax[1].hist(all_probs[pd.Series(all_targets) < 0.5], bins=bins, alpha=0.5, density=False, edgecolor='k')
ax[1].hist(all_probs[pd.Series(all_targets) > 0.5], bins=bins, alpha=0.5, density=False, edgecolor='k');
ax[1].set_xticks(bins);

In [None]:
calib_real = []
calib_goal = []
calib_width = []

for e1, e2 in zip(bins[:-1], bins[1:]):
  bin_probs = all_probs[(all_probs>=e1) & (all_probs < e2)]
  bin_targs = pd.Series(all_targets)[(all_probs>=e1) & (all_probs < e2)].values

  calib_real.append(bin_targs.mean())
  calib_goal.append((e2+e1)/2)
  calib_width.append(e2-e1)

  print(f"{e1:.0%} - {e2:.0%}:\t{bin_targs.mean():.1%} (vs {(e2+e1)/2:.1%},  \t{len(bin_targs)} examples)")

In [None]:
plt.figure(figsize=(6, 6))
plt.bar(calib_goal, calib_real, width=calib_width, edgecolor='k', alpha=0.5)
plt.plot(calib_goal, calib_goal, marker='o', c='r', ls='--');
plt.axis([0, 1, 0, 1]);
plt.xticks(calib_goal)
plt.yticks(calib_goal)
plt.gca().set_aspect(1);

In [None]:
print(f"base rate: {pd.Series(all_targets).mean():.2%}")

pd.Series(all_probs).describe(percentiles=np.linspace(0, 1, 11))

save

In [None]:
save_path = "" # wherever you want to put the .hdf5 checkpoint -- make sure directory exists already

var_list = [var for var in tf.trainable_variables() if "select" in var.name]

saver_tflex = tflex.Saver(
            var_list=var_list,
            max_to_keep=20,
            keep_checkpoint_every_n_hours=20,
            reshape=False)

display(var_list)
saver_tflex.save(sess, save_path)