# Easily export Jupyter cells to a Python module
https://github.com/fastai/course-v3/blob/master/nbs/dl2/notebook2script.py

In [5]:
! python /tf/src/scripts/notebook2script.py conditional_model.ipynb

Converted conditional_model.ipynb to exp/nb_conditional.py


In [None]:
!pip3 install -r requirements.txt

In [None]:
#export
import json
import os
import numpy as np
import tensorflow as tf
import time
import tqdm
from encoder import *

tf.__version__

In [None]:
cd /tf/src/data/gpt-2/

# Model

In [None]:
#export
class HParams():
    """Container for the transformer hyperparameters.

    The class-level attributes hold the default values. Every constructor
    argument falls back to the matching default, so ``HParams()`` yields the
    default configuration while positional/keyword overrides still work as
    before.

    Attributes:
        n_vocab: stored vocabulary-size setting.
        n_ctx: stored context-length setting.
        n_embd: stored embedding-width setting.
        n_head: stored attention-head-count setting.
        n_layer: stored layer-count setting.
    """
    n_vocab = 50257
    n_ctx = 1024
    n_embd = 768
    n_head = 12
    n_layer = 12

    # Defaults are evaluated in the class body, so they pick up the values above.
    def __init__(self, n_vocab=n_vocab, n_ctx=n_ctx, n_embd=n_embd,
                 n_head=n_head, n_layer=n_layer):
        self.n_vocab = n_vocab
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer

In [None]:
#export
def default_hparams():
    """Return an ``HParams`` instance built from the stock settings below."""
    settings = {
        'n_vocab': 50257,
        'n_ctx': 1024,
        'n_embd': 768,
        'n_head': 12,
        'n_layer': 12,
    }
    return HParams(**settings)

def shape_list(x):
    """Return the shape of ``x`` as a list, one entry per dimension.

    Dimensions whose static size is known are returned as-is; dimensions
    whose static size is ``None`` are replaced by the matching element of the
    runtime shape tensor.
    """
    known_dims = x.shape.as_list()
    runtime_shape = tf.shape(input=x)
    dims = []
    for axis, size in enumerate(known_dims):
        dims.append(runtime_shape[axis] if size is None else size)
    return dims

def gelu(x):
    """Apply the tanh-based GELU approximation to ``x`` element-wise."""
    half_x = 0.5 * x
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    return half_x * (1 + tf.tanh(inner))

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Standardise ``x`` along ``axis`` and apply a learned gain and bias.

    The mean and (biased) variance are reduced over ``axis``; ``epsilon`` is
    added to the variance before the reciprocal square root. The gain ``g``
    (initialised to 1) and bias ``b`` (initialised to 0) are created under
    ``scope`` and sized from the last dimension of ``x``.
    """
    with tf.compat.v1.variable_scope(scope):
        width = x.shape[-1]
        gain = tf.compat.v1.get_variable(
            'g', [width],
            initializer=tf.compat.v1.constant_initializer(1),
            use_resource=False)
        bias = tf.compat.v1.get_variable(
            'b', [width],
            initializer=tf.compat.v1.constant_initializer(0),
            use_resource=False)
        mean = tf.reduce_mean(input_tensor=x, axis=axis, keepdims=True)
        variance = tf.reduce_mean(input_tensor=tf.square(x - mean), axis=axis, keepdims=True)
        normalized = (x - mean) * tf.math.rsqrt(variance + epsilon)
        return normalized * gain + bias

def split_states(x, n):
    """Reshape the trailing dimension ``m`` of ``x`` into ``[n, m // n]``."""
    *leading, last = shape_list(x)
    target_shape = leading + [n, last // n]
    return tf.reshape(x, target_shape)

def merge_states(x):
    """Collapse the final two dimensions of ``x`` into one of size ``a * b``."""
    *leading, rows, cols = shape_list(x)
    target_shape = leading + [rows * cols]
    return tf.reshape(x, target_shape)

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Project the last dimension of ``x`` to ``nf`` features.

    Creates a weight ``w`` of shape ``[1, nx, nf]`` (normal init with
    ``w_init_stdev``) and a zero-initialised bias ``b`` of shape ``[nf]``
    under ``scope``. The input is flattened to 2-D for the matmul and the
    leading dimensions are restored afterwards.
    """
    with tf.compat.v1.variable_scope(scope):
        *leading, n_in = shape_list(x)
        kernel = tf.compat.v1.get_variable(
            'w', [1, n_in, nf],
            initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev),
            use_resource=False)
        bias = tf.compat.v1.get_variable(
            'b', [nf],
            initializer=tf.compat.v1.constant_initializer(0),
            use_resource=False)
        flat_x = tf.reshape(x, [-1, n_in])
        flat_kernel = tf.reshape(kernel, [-1, nf])
        projected = tf.matmul(flat_x, flat_kernel) + bias
        return tf.reshape(projected, leading + [nf])

def attention_mask(nd, ns, *, dtype):
    """Build an ``[nd, ns]`` mask with ones in the lower triangle.

    The triangle is anchored at the lower-right corner. Equivalent to
    tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce
    garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    visible = rows >= cols - ns + nd
    return tf.cast(visible, dtype)


def attn(x, scope, n_state, *, past, hparams):
    """Masked multi-head self-attention over ``x``.

    Args:
        x: rank-3 input, expected as [batch, sequence, features].
        scope: variable scope name for the two projections.
        n_state: projection width; must be divisible by ``hparams.n_head``.
        past: optional rank-5 tensor, expected as
            [batch, 2, heads, sequence, features] where the 2 is [k, v];
            when given, its keys/values are prepended to the new ones.
        hparams: supplies ``n_head``.

    Returns:
        ``(a, present)``: the projected attention output and the stacked
        keys/values computed for ``x`` alone (before ``past`` is prepended).
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def to_heads(t):
        # [batch, sequence, features] -> [batch, heads, sequence, features]
        per_head = split_states(t, hparams.n_head)
        return tf.transpose(a=per_head, perm=[0, 2, 1, 3])

    def from_heads(t):
        # Inverse of to_heads.
        swapped = tf.transpose(a=t, perm=[0, 2, 1, 3])
        return merge_states(swapped)

    def apply_mask(scores):
        # scores is [batch, heads, dst_sequence, src_sequence]; information flows src -> dst.
        _, _, n_dst, n_src = shape_list(scores)
        mask = attention_mask(n_dst, n_src, dtype=scores.dtype)
        mask = tf.reshape(mask, [1, 1, n_dst, n_src])
        # Masked-out positions get a large negative value so softmax drives them to ~0.
        return scores * mask - tf.cast(1e10, scores.dtype) * (1 - mask)

    def attend(query, key, value):
        # query, key, value are [batch, heads, sequence, features]
        scores = tf.matmul(query, key, transpose_b=True)
        scores = scores * tf.math.rsqrt(tf.cast(value.shape[-1], scores.dtype))

        scores = apply_mask(scores)
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, value)

    with tf.compat.v1.variable_scope(scope):
        qkv = conv1d(x, 'c_attn', n_state * 3)
        query, key, value = map(to_heads, tf.split(qkv, 3, axis=2))
        present = tf.stack([key, value], axis=1)
        if past is not None:
            past_key, past_value = tf.unstack(past, axis=1)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)
        attended = attend(query, key, value)
        merged = from_heads(attended)
        projected = conv1d(merged, 'c_proj', n_state)
        return projected, present


def mlp(x, scope, n_state, *, hparams):
    """Two-layer feed-forward block: project to ``n_state``, GELU, project back.

    The output width equals the last dimension of ``x``. ``hparams`` is
    accepted for signature parity with the other layers but is not read.
    """
    with tf.compat.v1.variable_scope(scope):
        n_out = x.shape[-1]
        hidden = conv1d(x, 'c_fc', n_state)
        activated = gelu(hidden)
        return conv1d(activated, 'c_proj', n_out)

def block(x, scope, *, past, hparams):
    """One transformer layer: pre-norm attention and pre-norm MLP, each with a residual add.

    Returns ``(x, present)`` where ``present`` is whatever ``attn`` reports
    for this layer.
    """
    with tf.compat.v1.variable_scope(scope):
        width = x.shape[-1]
        attn_out, present = attn(norm(x, 'ln_1'), 'attn', width, past=past, hparams=hparams)
        after_attn = x + attn_out
        mlp_out = mlp(norm(after_attn, 'ln_2'), 'mlp', width * 4, hparams=hparams)
        return after_attn + mlp_out, present

def past_shape(*, hparams, batch_size=None, sequence=None):
    """Return the 6-element shape list used for the cached keys/values.

    Layout: ``[batch_size, n_layer, 2, n_head, sequence, n_embd // n_head]``.
    ``batch_size`` and ``sequence`` stay ``None`` when not supplied.
    """
    per_head_width = hparams.n_embd // hparams.n_head
    return [
        batch_size,
        hparams.n_layer,
        2,
        hparams.n_head,
        sequence,
        per_head_width,
    ]

def expand_tile(value, size):
    """Prepend a new leading axis to ``value`` and repeat it ``size`` times along that axis."""
    tensor = tf.convert_to_tensor(value=value, name='value')
    rank = tensor.shape.ndims
    with_new_axis = tf.expand_dims(tensor, axis=0)
    multiples = [size] + [1] * rank
    return tf.tile(with_new_axis, multiples)

def positions_for(tokens, past_length):
    """Return position indices ``past_length + [0..nsteps)`` tiled across the batch.

    The batch size and step count are read from the first two dimensions of
    ``tokens`` at runtime.
    """
    n_batch = tf.shape(input=tokens)[0]
    n_steps = tf.shape(input=tokens)[1]
    step_positions = past_length + tf.range(n_steps)
    return expand_tile(step_positions, n_batch)


def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE):
    """Build the transformer graph for token ids ``X``.

    Args:
        hparams: supplies ``n_ctx``, ``n_embd``, ``n_vocab`` and ``n_layer``.
        X: rank-2 tensor of token ids (unpacked as ``batch, sequence``).
        past: optional cached keys/values; unstacked along axis 1 into one
            entry per layer. ``None`` means no cache.
        scope: variable scope name.
        reuse: forwarded to ``variable_scope``.

    Returns:
        dict with ``'present'`` (per-layer presents stacked on axis 1) and
        ``'logits'`` of shape ``[batch, sequence, n_vocab]``.
    """
    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        outputs = {}
        batch, sequence = shape_list(X)

        position_table = tf.compat.v1.get_variable(
            'wpe', [hparams.n_ctx, hparams.n_embd],
            initializer=tf.compat.v1.random_normal_initializer(stddev=0.01),
            use_resource=False)
        token_table = tf.compat.v1.get_variable(
            'wte', [hparams.n_vocab, hparams.n_embd],
            initializer=tf.compat.v1.random_normal_initializer(stddev=0.02),
            use_resource=False)
        if past is None:
            past_length = 0
        else:
            past_length = tf.shape(input=past)[-2]
        token_emb = tf.gather(token_table, X)
        position_emb = tf.gather(position_table, positions_for(X, past_length))
        h = token_emb + position_emb

        # Transformer stack: one cache slice (or None) per layer.
        if past is not None:
            layer_pasts = tf.unstack(past, axis=1)
        else:
            layer_pasts = [None] * hparams.n_layer
        assert len(layer_pasts) == hparams.n_layer
        presents = []
        for index, layer_past in enumerate(layer_pasts):
            h, present = block(h, 'h%d' % index, past=layer_past, hparams=hparams)
            presents.append(present)
        outputs['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Output logits reuse the token embedding matrix (transposed).
        flat_h = tf.reshape(h, [batch * sequence, hparams.n_embd])
        flat_logits = tf.matmul(flat_h, token_table, transpose_b=True)
        outputs['logits'] = tf.reshape(flat_logits, [batch, sequence, hparams.n_vocab])
        return outputs

# Sample from Model

In [None]:
#export
def sample_sequence(*, hparams, start_token=None, batch_size=None, context=None, past=None, temperature=1):
    """Run the model once on the context and sample a single next token.

    Exactly one of ``start_token`` and ``context`` must be given. With
    ``start_token`` the context becomes a ``[batch_size, 1]`` tensor filled
    with that token.

    Args:
        hparams: hyperparameters forwarded to ``model`` / ``past_shape``.
        start_token: optional token id used to seed the context.
        batch_size: batch dimension used for the fill and for ``set_shape``.
        context: optional tensor of token ids to condition on.
        past: optional cache forwarded to ``model``.
        temperature: divisor applied to the last-position logits.

    Returns:
        ``(output, probs)``: the context with one sampled token appended on
        axis 1, and the softmax of the temperature-scaled logits for the
        last position.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        """Run ``model`` on ``tokens`` and return its logits and presents."""
        lm_output = model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

        # Slice the vocabulary axis to n_vocab entries.
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        # Pin the static shape; the sequence dimension is left as None.
        presents.set_shape(past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents
        }

    def body(past, prev, output, embedding):
        """Sample one token after ``prev`` and append it to ``output``.

        ``embedding`` is accepted but not used.
        """
        next_outputs = step(hparams, prev, past=past)
        # Keep only the last position and apply the temperature.
        logits = next_outputs['logits'][:, -1, :]  / tf.cast(temperature, dtype=tf.float32)
        samples = tf.random.categorical(logits=logits, num_samples=1, dtype=tf.int32)
        return [
            tf.concat([output, samples], axis=1),
            logits,
        ]

    # Single step only: the context serves as prev, output and embedding.
    # NOTE: despite its name, `logprobs` holds the scaled logits returned by body.
    output, logprobs = body(past, context, context, context)

    probs = tf.math.softmax(logprobs)
    return output, probs

In [None]:
from pathlib import Path
def load_dataset(enc, path, combine):
    """Read the raw text of the files designated by ``path``.

    ``path`` may be a single file, a directory (walked recursively), or
    anything else, which is treated as a glob pattern. Files that cannot be
    read are reported on stdout and skipped. Reading stops after the file
    at index 1000 has been processed.

    Args:
        enc: unused here; kept so existing callers keep working.
        path: file path, directory path, or glob pattern.
        combine: unused here; kept so existing callers keep working.

    Returns:
        list of str: one entry per successfully read file, holding the
        file's full (un-tokenised) text.
    """
    # Local import: the glob fallback previously raised NameError because
    # the module was never imported in this file.
    import glob

    if os.path.isfile(path):
        # Simple file
        paths = [path]
    elif os.path.isdir(path):
        # Directory
        paths = [
            os.path.join(dirpath, fname)
            for dirpath, _, fnames in os.walk(path)
            for fname in fnames
        ]
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    for i, file_path in enumerate(tqdm.tqdm(paths)):
        try:
            with open(file_path, 'r') as fp:
                token_chunks.append(fp.read())
        except (OSError, UnicodeError) as e:
            # Best effort: report the unreadable file and carry on.
            print(e)
        if i >= 1000:
            break
    return token_chunks

In [None]:
#export
class Args():
    """Plain attribute bag: every constructor argument is stored under the same name."""

    def __init__(
        self,
        dataset,
        model_name,
        combine,
        batch_size,
        learning_rate,
        optimizer,
        noise,
        top_k,
        top_p,
        run_name,
        sample_every,
        sample_length,
        sample_num,
        save_every,
        val_dataset,
        val_batch_size,
        val_batch_count,
        val_every,
        pretrained,
        iterations,
    ):
        # Data / model selection
        self.dataset = dataset
        self.model_name = model_name
        self.combine = combine

        # Optimisation settings
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.noise = noise

        # Sampling settings
        self.top_k = top_k
        self.top_p = top_p
        self.run_name = run_name
        self.sample_every = sample_every
        self.sample_length = sample_length
        self.sample_num = sample_num
        self.save_every = save_every

        # Validation settings
        self.val_dataset = val_dataset
        self.val_batch_size = val_batch_size
        self.val_batch_count = val_batch_count
        self.val_every = val_every

        # Run control
        self.pretrained = pretrained
        self.iterations = iterations

In [None]:
#export
def get_default_args():
    """Return an ``Args`` instance populated with this notebook's default settings."""
    defaults = dict(
        dataset="/tf/src/data/methods/DATA00M_[god-r]/train",
        model_name="117M",
        combine=50000,
        batch_size=1,  # DO NOT TOUCH. INCREASING THIS WILL RAIN DOWN HELL FIRE ONTO YOUR COMPUTER.
        learning_rate=0.00002,
        optimizer="sgd",
        noise=0.0,
        top_k=1,
        top_p=0.0,
        run_name="run1",
        sample_every=100,
        sample_length=1023,
        sample_num=1,
        save_every=1000,
        val_dataset=None,
        val_batch_size=1,
        val_batch_count=40,
        val_every=100,
        pretrained=True,
        iterations=200000,
    )
    return Args(**defaults)

In [None]:
# Build the default settings, load the encoder for the configured model name
# from the "models" directory, then read the dataset files into memory.
args = get_default_args()
enc = get_encoder(args.model_name, "models")
# load_dataset returns one raw-text string per file that could be read.
data_set = load_dataset(enc, args.dataset, args.combine)
# Trailing expression: the notebook displays the number of loaded files.
len(data_set)

In [None]:
def calc_probs(chkpt_path):
    """Placeholder that is not implemented yet; ignores ``chkpt_path`` and returns ``None``."""
    return None

# Self-Supervised Pre-Experiment

In [None]:
import math
MAX_CHUNK = 1024

def interact_model(
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    models_dir='models',
    ds=None
):
    """Replay texts token by token and collect probabilities of correct samples.

    For each of the first two entries of ``ds`` the text is encoded, and for
    every prefix (capped to the last ``MAX_CHUNK`` tokens) one token is
    sampled from the restored model. Whenever the sampled token equals the
    actual next token, its probability is accumulated per decoded token.
    Progress is printed; nothing is returned.

    Args:
        model_name: sub-directory of ``models_dir`` holding the checkpoint.
        seed: seed for NumPy and the TensorFlow graph.
        nsamples: only checked for divisibility by ``batch_size``.
        batch_size: batch dimension of the placeholder; ``None`` means 1.
        length: only validated against ``hparams.n_ctx``; otherwise unused.
        temperature: forwarded to ``sample_sequence``.
        top_k: accepted but not used.
        models_dir: directory containing checkpoints (``~`` and environment
            variables are expanded).
        ds: sequence of texts to evaluate; ``None`` means empty.

    Raises:
        ValueError: if ``length`` exceeds the context window or no
            checkpoint is found under ``models_dir/model_name``.
    """
    from collections import deque

    # Avoid a shared mutable default argument.
    ds = [] if ds is None else ds

    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    # NOTE(review): the encoder is always loaded from "models"/"117M",
    # independent of model_name / models_dir; confirm this is intended.
    enc = get_encoder("117M", "models")
    hparams = default_hparams()

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.compat.v1.set_random_seed(seed)
        output, logits = sample_sequence(
            hparams=hparams,
            context=context,
            past=None,
            batch_size=batch_size,
            temperature=temperature
        )

        saver = tf.compat.v1.train.Saver()
        ckpt_dir = os.path.join(models_dir, model_name)
        ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt is None:
            # latest_checkpoint returns None when nothing is found; fail
            # with a message that names the directory.
            raise ValueError("No checkpoint found in %s" % ckpt_dir)
        saver.restore(sess, ckpt)

        # Loop-invariant, so built once.
        run_options = tf.compat.v1.RunOptions(report_tensor_allocations_upon_oom=True)

        # Maps decoded token -> (sum of probabilities, number of hits).
        mean_probs = {}
        print(mean_probs)
        for method in tqdm.tqdm(ds[:2]):
            enc_meth = enc.encode(method)
            inputs = enc_meth[:-1]
            targets = enc_meth[1:]
            # Bounded window: keeps only the most recent MAX_CHUNK tokens.
            window = deque(maxlen=MAX_CHUNK)
            for tok, target in zip(inputs, targets):
                window.append(tok)

                out, probs = sess.run([output, logits], feed_dict={
                    context: [list(window)] * batch_size
                }, options=run_options)
                predicted = out[0, -1]
                prob = probs[0][predicted]
                predicted_text = enc.decode([predicted])

                print(repr(predicted_text), repr(enc.decode([target])), prob)
                if predicted == target:
                    if predicted_text in mean_probs:
                        total, count = mean_probs[predicted_text]
                        mean_probs[predicted_text] = (total + prob, count + 1)
                    else:
                        mean_probs[predicted_text] = (prob, 1)
        print(mean_probs.keys())

# Run the experiment against the latest checkpoint in checkpoint/run1, using
# the dataset loaded earlier in the notebook.
# NOTE: interact_model accepts top_k but never reads it, so top_k=40 has no
# effect on sampling here.
interact_model(model_name='run1',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=40,
    models_dir='checkpoint', ds=data_set)