In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time
import json

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from pathlib import Path
from os import listdir
import re

In [3]:
config = 'configs/crypto_subset_8.json'
params = json.load(open(config))

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('my_tokenizer')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344


In [5]:
with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    network.state = read_ckpt(network.state, "model_checkpoints/step_26001/step_26001/", devices.shape[1], load_opt=False)

    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

read from disk/gcs in 4871.7s


AssertionError: Incompatible checkpoints (8,) vs (8, 4096)

In [6]:
initial_ckpt_path = f"gs://gpt-j_records_eur/models_cp_crypto_subset"
meta_path = f"{initial_ckpt_path}/meta.json"

with open(meta_path, "r") as f:
    meta = json.load(f)
ckpt_step = meta["checkpoints"][-1]
initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
    
network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=False)

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

FileNotFoundError: [Errno 2] No such file or directory: 'gs://gpt-j_records_eur/models_cp_crypto_subset/meta.json'

In [7]:
if __name__ == "__main__":
    config = 'configs/crypto_subset_8.json'
    fine_tuning = False
    params = json.load(open(config))

    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    per_replica_batch = params["per_replica_batch"]
    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    val_batches = params["val_batches"]
    val_every = params["val_every"]
    ckpt_every = params["ckpt_every"]
    keep_every = params["keep_every"]
    eval_tasks = params["eval_harness_tasks"]
    total_steps = params["total_steps"]

    pe = params["pe"]
    assert pe in ["fixed", "rotary", "t5"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    opt = None

    params["optimizer"] = opt

    start = time.time()
    tpu_size = jax.device_count()
    if tpu_size < cores_per_replica:
        msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
        raise ValueError(msg)
    print(f"jax devices: {tpu_size}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    step = 0

    print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)')
    initial_ckpt_path = f"gs://gpt-j_records_eur/models_cp_crypto_subset"
    meta_path = f"{initial_ckpt_path}/meta.json"

    with open(meta_path, "r") as f:
        meta = json.load(f)
    ckpt_step = meta["checkpoints"][-1]
    initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
    print(f"state will be restored from checkpoint {ckpt_step}")

    step = ckpt_step

    if initial_ckpt_state_path:
        print(f"path to load checkpoint from: {initial_ckpt_state_path}")
    else:
        print("not loading from a checkpoint")

    # load + run
    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        print("initializing network")
        print(params)
        network = CausalTransformer(params)

        if initial_ckpt_state_path:
            print("loading network")

            start = time.time()
            network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=False)

            print(f"network loaded in {time.time() - start:.06}s")

jax devices: 8
jax runtime initialized in 0.000291586s
`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)


FileNotFoundError: [Errno 2] No such file or directory: 'gs://gpt-j_records_eur/models_cp_crypto_subset/meta.json'

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    length = np.ones(context.shape[0], dtype=np.uint32) * context.shape[1]

    start = time.time()
    output = network.generate(context, length, gen_len, 
                              {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp}, return_logits=True)


    return output

In [None]:
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

end_token = 26

val_files = [f for f in listdir('strings/val/')]

results = {}
for f in val_files:
    if f != 'BTC_1min_string.txt':
        continue
    print(f)
        
    words = Path(f'strings/val/{f}').read_text()
    tokens = tokenizer.encode(words)
    i = 2048
    contexts = []
    outs = []
    while i < 3000:
        print(i)
        context = tokens[:i]
        while context[-1] != end_token:
            i += 1
            context = tokens[i-2048:i]
        context = np.array([context])
        contexts.append(context)
        contexts = np.array(contexts)
        i += 1
        outs.append(infer(top_p=top_p, temp=temp, gen_len=100, context=contexts))
    results[f] = {'outs': outs, 'contexts': contexts}

In [None]:
trues = []
preds = []
for f, v in results.items():
    print(f)
    contexts = v['contexts']
    targets = []
    for context in contexts:
        context = list(context)
        target = context[:context.index(end_token)+1]
        targets.append(target)
    
    for context, out, target in zip(contexts[:-1], outs[:-1], targets[1:]):
        out = out[:context.index(end_token)]
        out = tokenizer.decode(out[1][0])
        target = tokenizer.decode(target)
        
        target_values = re.findall('\d+\_\d+', target) + re.findall('\d+\:\d+', target) + re.findall('\d+\.\d+', target)
        out_values = re.findall('\d+\_\d+', out) + re.findall('\d+\:\d+', out) + re.findall('\d+\.\d+', out)
        if target_values[0] == out_values[0] and target_values[1] == out_values[1]:
            trues.append(target_values[2:])
            preds.append(out_values[2:])
        else:
            print('skipped')

In [23]:
re.findall('crypt(\d+)_(\d+)Thu(\d+):(\d+)O([+-]?([0-9]*[.])?[0-9]+)H(\d+)L(\d+)C(\d+)V(\d+);', 'crypt7_1Thu00:00O1.385H1.387L1.365C1.367V2564233')

[]

In [37]:
print(re.findall('\d+\_\d+', 'crypt7_1Thu00:00O1.385H1.387L1.365C1.367V2564233')+
      re.findall('\d+\:\d+', 'crypt7_1Thu00:00O1.385H1.387L1.365C1.367V2564233')+ 
      re.findall('\d+\.\d+', 'crypt7_1Thu00:00O1.385H1.387L1.365C1.367V2564233'))

['7_1', '00:00', '1.385', '1.387', '1.365', '1.367']


In [18]:
l = [1, 2, 3, 4, 3, 5, 4, 5]
l.reverse()
l.index(3)

3