In [4]:
import argparse
import json
import re
import time
from os import listdir
from pathlib import Path

import jax
import numpy as np
import optax

import wandb
from tqdm import tqdm
import transformers


from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt
from mesh_transformer.transformer_shard import CausalTransformer
from tfrecord_loader import TFRecordNewInputs
from smart_open import open
from google.cloud import storage
from google.cloud.exceptions import NotFound

from mesh_transformer.util import clip_by_global_norm, additive_weight_decay

In [5]:
# Pretrained GPT-2 fast tokenizer; later cells use it to encode the validation
# text into token ids and to decode generated token ids back into strings.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

In [6]:
def save(network, step, bucket, path, mp, aux=None, keep_n=3, delete_old=True):
    """Write a sharded checkpoint to GCS and update the run's ``meta.json``.

    The checkpoint is written under ``gs://{bucket}/{path}/step_{step}/`` (one
    ``write_ckpt`` call per shard). ``meta.json`` in ``gs://{bucket}/{path}/``
    tracks the latest step, the list of retained checkpoint steps and the
    per-checkpoint ``aux`` payloads; it is created if it cannot be read.

    Args:
        network: object exposing the ``state`` to checkpoint.
        step: training step number of this checkpoint.
        bucket: GCS bucket name (without the ``gs://`` prefix).
        path: non-empty directory inside the bucket.
        mp: number of shards to write.
        aux: JSON-serialisable metadata stored alongside this step
            (defaults to an empty dict).
        keep_n: maximum number of checkpoints kept in ``meta.json``.
        delete_old: when True, blobs of checkpoints dropped from the list are
            deleted from the bucket; otherwise they are left in place.

    Raises:
        AssertionError: if ``path`` is empty, or a listed blob unexpectedly
            does not contain ``path`` in its name.
    """
    assert path
    client = storage.Client()

    if aux is None:
        aux = {}

    meta_uri = f"gs://{bucket}/{path}/meta.json"

    try:
        with open(meta_uri, "r") as f:
            json.load(f)
    except Exception:
        # meta.json is missing or unreadable: start from a fresh metadata file.
        # `Exception` (not a bare except) so Ctrl-C is no longer swallowed here.
        with open(meta_uri, "w") as f:
            json.dump({
                "step": 0,
                "checkpoints": [],
                "aux": {}
            }, f)

    # do sharded checkpoint writing
    start = time.time()
    ckpt_dir = f"gs://{bucket}/{path}/step_{step}/"
    for shard_id in range(mp):
        write_ckpt(network.state, ckpt_dir, shard_id)

    print(f"Wrote checkpoint in {time.time() - start:.06}s")

    with open(meta_uri, "r") as f:
        meta = json.load(f)

    meta["step"] = step
    meta["checkpoints"].append(step)
    all_aux = meta.get("aux", {})

    while len(meta["checkpoints"]) > keep_n:
        ckpt_to_delete = meta["checkpoints"].pop(0)

        # Keys of the aux mapping are strings after the JSON round trip.
        aux_key = str(ckpt_to_delete)
        if aux_key in all_aux:
            del all_aux[aux_key]
        else:
            # Report the checkpoint actually being pruned, not the current step.
            print(f"failed to delete the aux state for {ckpt_to_delete}")

        if delete_old:
            print(f"deleting checkpoint {ckpt_to_delete}")
            for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
                # Safety net: never delete anything outside this run's directory.
                assert path in blob.name
                blob.delete()
        else:
            print(f"keeping checkpoint {ckpt_to_delete}")

    # Store under a string key so it matches the keys loaded from JSON and a
    # re-saved step overwrites its entry instead of emitting a duplicate key.
    all_aux[str(step)] = aux
    meta["aux"] = all_aux

    with open(meta_uri, "w") as f:
        json.dump(meta, f)

def eval_step(network, data):
    """Evaluate ``network`` on next-token pairs built from ``data``.

    ``data`` is indexed as a 2-D array: the observation is every column but
    the last, the target is every column but the first (inputs shifted by one).
    Returns whatever ``network.eval`` returns for that input dict.
    """
    shifted_pair = dict(obs=data[:, :-1], target=data[:, 1:])
    return network.eval(shifted_pair)

In [6]:
if __name__ == "__main__":
    # Notebook driver: read the run configuration, build the (dp, mp) device
    # mesh, construct the model and restore its weights from a checkpoint.
    config = 'configs/crypto_subset_8.json'
    fine_tuning = False
    # Use a context manager so the config file handle is closed after parsing.
    with open(config) as config_file:
        params = json.load(config_file)

    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    per_replica_batch = params["per_replica_batch"]
    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    val_batches = params["val_batches"]
    val_every = params["val_every"]
    ckpt_every = params["ckpt_every"]
    keep_every = params["keep_every"]
    eval_tasks = params["eval_harness_tasks"]
    total_steps = params["total_steps"]

    pe = params["pe"]
    assert pe in ["fixed", "rotary", "t5"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]

    # Zero-scaled gradient transformation: the network is only loaded for
    # inference here (the checkpoint is read with load_opt=False below).
    opt = optax.scale(0)

    params["optimizer"] = opt

    start = time.time()
    tpu_size = jax.device_count()
    if tpu_size < cores_per_replica:
        msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
        raise ValueError(msg)
    print(f"jax devices: {tpu_size}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    # Arrange the devices as (data-parallel replicas, model-parallel shards).
    mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    step = 0

    print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)')
    # Alternative checkpoint root used previously:
    #   gs://gpt-j_records_eur/models_cp_crypto_subset
    initial_ckpt_path = "gs://gpt-j_records_eur/records"

    # The checkpoint step is pinned by hand. To follow the newest checkpoint
    # instead, read f"{initial_ckpt_path}/meta.json" and take
    # meta["checkpoints"][-1].
    ckpt_step = 383500
    initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
    print(f"state will be restored from checkpoint {ckpt_step}")

    step = ckpt_step

    if initial_ckpt_state_path:
        print(f"path to load checkpoint from: {initial_ckpt_state_path}")
    else:
        print("not loading from a checkpoint")

    # load + run
    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        print("initializing network")
        print(params)
        network = CausalTransformer(params)

        if initial_ckpt_state_path:
            print("loading network")

            start = time.time()
            # devices.shape[1] is the model-parallel shard count of the mesh.
            network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=False)

            print(f"network loaded in {time.time() - start:.06}s")

jax devices: 8
jax runtime initialized in 3.42035s
`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)
state will be restored from checkpoint 383500
path to load checkpoint from: gs://gpt-j_records_eur/records/step_383500/
initializing network
{'layers': 28, 'd_model': 4096, 'n_heads': 16, 'n_vocab': 50400, 'norm': 'layernorm', 'pe': 'rotary', 'pe_rotary_dims': 64, 'seq': 2048, 'cores_per_replica': 8, 'per_replica_batch': 1, 'gradient_accumulation_steps': 32, 'warmup_steps': 3000, 'anneal_steps': 27221, 'lr': 5e-05, 'end_lr': 1e-05, 'weight_decay': 0.1, 'total_steps': 30221, 'tpu_size': 8, 'bucket': 'gpt-j_records_eur', 'model_dir': 'models_cp_crypto_subset', 'train_set': 'crypto_subset.train.index', 'val_set': {'crypto_scratch': 'crypto_subset.val.index'}, 'eval_harness_tasks': [], 'val_batches': 100, 'val_every': 100, 'ckpt_every': 1000, 'keep_every': 5000, 'name': 'GPT3_crypto_subset', 'comment': '', 'optimizer': Gradient

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
loading network


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 851, in next
    item = self._items.popleft()
IndexError: pop from an empty deque

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/Mcian/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_202421/1881186042.py", line 83, in <module>
    network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=False)
  File "/home/Mcian/mesh-transformer-jax/mesh_transformer/checkpoint.py", line 135, in read_ckpt
    shards = list((p.imap(read_shard, [f"{dir}shard_{i}/" for i in range(shards_in)])))
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 856, in next
    self._cond.wait(timeout)
  File "/usr/lib/python3.8/threading.py", line 302, in wait
    waiter.acquire()
KeyboardInterrup

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Generate ``gen_len`` tokens for every row of ``context``.

    Args:
        context: 2-D array of token ids, one row per sequence; every row is
            treated as fully populated (length = number of columns).
        top_p: nucleus-sampling threshold applied to every row.
        temp: sampling temperature applied to every row.
        gen_len: number of tokens to generate.

    Returns:
        The raw result of ``network.generate(..., return_logits=True)``.
        NOTE(review): its exact structure is defined by the model code, not
        visible here; confirm before unpacking.
    """
    # Size the per-row options from the context itself rather than from a
    # separately maintained global, so any batch size works.
    batch_size = context.shape[0]
    length = np.ones(batch_size, dtype=np.uint32) * context.shape[1]
    sampler_options = {
        "top_p": np.ones(batch_size) * top_p,
        "temp": np.ones(batch_size) * temp,
    }

    return network.generate(context, length, gen_len, sampler_options, return_logits=True)

In [None]:
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

# Token id that marks the end of a record; context windows are aligned so
# that they always end on this token.
end_token = 26

context_len = 2048  # tokens per context window
batch_size = 100    # context windows sent to the model per infer() call

val_files = [f for f in listdir('strings/val/')]

results = {}
for f in val_files:
    # Only this one validation file is evaluated.
    if f != 'BTC_1min_string.txt':
        continue
    print(f)

    words = Path(f'strings/val/{f}').read_text()
    tokens = tokenizer.encode(words)
    n_tokens = len(tokens)

    # `i` is the exclusive end index of the candidate window tokens[i-context_len:i].
    i = context_len
    context_batches = []
    outs = []
    while i <= n_tokens:
        contexts = []
        while len(contexts) < batch_size:
            # Slide forward until the window ends on an end-of-record token,
            # stopping at the end of the token stream instead of running off it.
            while i <= n_tokens and tokens[i - 1] != end_token:
                i += 1
            if i > n_tokens:
                break
            contexts.append(tokens[i - context_len:i])
            # Step past this boundary so the next window ends on a later record.
            i += 1

        if not contexts:
            break

        # Every window has exactly context_len tokens, so this is a 2-D array.
        contexts = np.array(contexts)
        context_batches.append(contexts)
        # Keep the raw infer() result, one entry per batch.
        # NOTE(review): the structure returned by infer() is not visible here;
        # downstream code must unpack per-sample tokens — confirm.
        outs.append(infer(top_p=top_p, temp=temp, gen_len=100, context=contexts))

    if context_batches:
        all_contexts = np.concatenate(context_batches)
    else:
        all_contexts = np.empty((0, context_len), dtype=int)
    results[f] = {'outs': outs, 'contexts': all_contexts}

In [None]:
# Compare generated values against target values for every evaluated file.
# `trues` / `preds` collect the value lists (minus the first two matches)
# for pairs whose first two extracted values agree.
trues = []
preds = []

# Same three patterns as before, as raw strings so the backslashes are regex
# escapes rather than (invalid) string escapes. Order matters: matches are
# concatenated pattern by pattern.
value_patterns = (r'\d+\_\d+', r'\d+\:\d+', r'\d+\.\d+')


def _find_values(text):
    """Return all matches of each value pattern in ``text``, pattern by pattern."""
    return [match for pattern in value_patterns for match in re.findall(pattern, text)]


for f, v in results.items():
    print(f)
    contexts = v['contexts']
    # Use this file's own outputs rather than whatever `outs` was left in the
    # kernel by the generation cell.
    file_outs = v['outs']

    # Target of a context = its tokens up to and including the first end token.
    targets = []
    for context in contexts:
        context = list(context)
        target = context[:context.index(end_token)+1]
        targets.append(target)

    # Each context/output is paired with the target taken from the NEXT context.
    # NOTE(review): assumes `file_outs` holds one entry per context and that
    # out[1][0] yields decodable token ids — confirm against the cell that
    # fills `results`.
    for context, out, target in zip(contexts[:-1], file_outs[:-1], targets[1:]):
        # list(...) so this works whether the row is a list or an ndarray.
        # NOTE(review): slicing `out` by a position found in `context` looks
        # unintended (truncating the generated tokens at their own first
        # end_token seems more likely) — kept as written, please verify.
        out = out[:list(context).index(end_token)]
        out = tokenizer.decode(out[1][0])
        target = tokenizer.decode(target)

        target_values = _find_values(target)
        out_values = _find_values(out)
        # Require two leading values on both sides before comparing, so a
        # short or malformed generation is skipped instead of raising.
        comparable = len(target_values) >= 2 and len(out_values) >= 2
        if comparable and target_values[:2] == out_values[:2]:
            trues.append(target_values[2:])
            preds.append(out_values[2:])
        else:
            print('skipped')