In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from haiku import LayerNorm
from einops import rearrange, repeat
from functools import lru_cache, partial

# there is no limit for lenght of text, because our countext windows is always 75 tokens, 
# but we have to choose for RoPE
MAX_SEQ_LEN = 8192

@lru_cache()
def fixed_pos_embedding(rotary_dims):
    inv_freq = 1. / (10000 ** (np.arange(0, rotary_dims, 2) / rotary_dims))
    sinusoid_inp = np.einsum('i , j -> i j', np.arange(MAX_SEQ_LEN), inv_freq)
    return np.sin(sinusoid_inp), np.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = jnp.stack((-x2, x1), axis=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, seq_dim):
    sincos = map(lambda t: repeat(t, '... b n -> ... b (n j)', j=2)[-x.shape[seq_dim]:], sincos)
    
    # (n_seq, dim_per_head) -> (n_seq, 1, 1, dim_per_head), so we can do mult
    # in case "x" is something like (n_seq, bs, n_heads, dim_per_head)
    add_dims = set(np.arange(x.ndim-1)) - set([np.arange(x.ndim)[seq_dim]])
    sin, cos = map(lambda t: jnp.expand_dims(t, tuple(add_dims)), sincos)
    
    return (x * cos) + (rotate_every_two(x) * sin)

@partial(jax.jit, static_argnums=(1,2))
def apply_rope(x, rotary_dims, seq_dim):
    x_rot = x[..., :rotary_dims]
    x_pass = x[..., rotary_dims:]
    sincos = fixed_pos_embedding(rotary_dims)
    x_rot = apply_rotary_pos_emb(x_rot, sincos, seq_dim)
    return jnp.concatenate([x_rot, x_pass], axis=-1)

def rope_tests():
    rotary_dims = 32
    vectors = np.random.random(size=(2,75))
    def test_pos(pos1, pos2, f):
        q = np.zeros(shape=(1,64,75))
        v = np.zeros(shape=(1,64,75))
        q[0,pos1] = vectors[0]
        v[0,pos2] = vectors[1]
        res = f(q,rotary_dims)@f(v,rotary_dims).transpose(0, 2, 1)
        return res[0,pos1,pos2]
    
    pos0 = test_pos(3,17, lambda x,y: x)
    pos1 = test_pos(3,17, apply_rope)
    pos2 = test_pos(5,19, apply_rope)
    pos3 = test_pos(5,20, apply_rope)
    assert not jnp.isclose(pos0, pos1)
    assert jnp.isclose(pos1, pos2)
    assert not jnp.isclose(pos2, pos3)

#rope_tests()

def rope_tests2():
    rotary_dims = 32
    vectors = np.random.random(size=(2,75))
    def test_pos(pos1, pos2, f):
        q = np.zeros(shape=(64,75))
        v = np.zeros(shape=(64,75))
        q[pos1] = vectors[0]
        v[pos2] = vectors[1]
        q = q[:,None,None,:]
        v = v[:,None,None,:]
        q = f(q,rotary_dims,0).transpose(1,2,0,3)
        v = f(v,rotary_dims,0).transpose(1,2,0,3)
        q = jnp.squeeze(q)
        q = jnp.squeeze(q)
        v = jnp.squeeze(v)
        v = jnp.squeeze(v)
        res = q@v.transpose()
        return res[pos1,pos2]
    
    pos0 = test_pos(3,17, lambda x,y,z: x)
    pos1 = test_pos(3,17, apply_rope)
    pos2 = test_pos(5,19, apply_rope)
    pos3 = test_pos(5,20, apply_rope)
    assert not jnp.isclose(pos0, pos1)
    assert jnp.isclose(pos1, pos2)
    assert not jnp.isclose(pos2, pos3)

#rope_tests2()

In [3]:
class MultiHeadAttention(hk.Module):
    def __init__(
            self,
            num_heads: int,
            head_size: int,
            rotary_dims: int, 
            w_init_scale: float,
            attn_mask: jnp.ndarray = None,
            name: str = "mha",
    ):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.model_size = head_size * num_heads
        self.w_init = hk.initializers.VarianceScaling(w_init_scale)
        self.attn_mask = attn_mask
        self.rotary_dims = rotary_dims

        self.in_proj_weight = hk.get_parameter("in_proj_weight", shape=[self.model_size * 3, self.model_size], init=self.w_init)
        self.in_proj_bias = hk.get_parameter("in_proj_bias", shape=[self.model_size * 3], init=self.w_init)
        self.out_proj = hk.Linear(self.model_size, name="out_proj")

    def __call__(
            self,
            x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Compute (optionally masked) MHA with queries, keys & values."""
        all_out = jnp.dot(x, self.in_proj_weight.transpose())
        all_out += self.in_proj_bias

        q, k, v = jnp.array_split(all_out, 3, axis=-1)

        query_heads = self._split(q)
        key_heads = self._split(k)
        value_heads = self._split(v)
        # RoPE
        query_heads = apply_rope(query_heads, self.rotary_dims, seq_dim=0)
        key_heads = apply_rope(key_heads, self.rotary_dims, seq_dim=0)
        
        attention_logits = jnp.einsum("tbhd,Tbhd->bhtT", query_heads, key_heads)
        sqrt_key_size = np.sqrt(self.model_size//self.num_heads).astype(k.dtype)
        attention_logits = attention_logits / sqrt_key_size

        if self.attn_mask is not None:
            attention_logits += self.attn_mask

        attention_weights = jax.nn.softmax(attention_logits)
        attention = jnp.einsum("bhtT,Tbhd->tbhd", attention_weights, value_heads)
        # Concatenate attention matrix of all heads into a single vector.
        attention_vec = jnp.reshape(attention, (*q.shape[:2], -1))

        return self.out_proj(attention_vec)

    def _split(
            self,
            x: jnp.ndarray,
    ) -> jnp.ndarray:
        return x.reshape((*x.shape[:2], self.num_heads, self.model_size//self.num_heads))


class QuickGELU(hk.Module):
    def __call__(self, x: jnp.ndarray):
        return x * jax.nn.sigmoid(1.702 * x)


class ResidualAttentionBlock(hk.Module):
    def __init__(self, d_model: int, n_head: int, rotary_dims:int, attn_mask: jnp.ndarray, name: str):
        super().__init__(name=name)
        self.attn = MultiHeadAttention(n_head, d_model // n_head, rotary_dims, 1, attn_mask, name="attn")
        self.ln_1 = LayerNorm(-1, create_scale=True, create_offset=True, name="ln_1")
        with hk.experimental.name_scope("mlp"):
            self.mlp = [hk.Linear(d_model * 4, name="c_fc"),
                        QuickGELU(),
                        hk.Linear(d_model, name="c_proj")]

        self.ln_2 = LayerNorm(-1, create_scale=True, create_offset=True, name="ln_2")

    def run_mlp(self, x: jnp.ndarray):
        for f in self.mlp:
            x = f(x)
        return x

    def __call__(self, x: jnp.ndarray):
        x = x + self.attn(self.ln_1(x))
        x = x + self.run_mlp(self.ln_2(x))
        return x


class Transformer(hk.Module):
    def __init__(self, width: int, layers: int, heads: int, rotary_dims: int, name: str, attn_mask=None):
        super().__init__(name=name)
        self.width = width
        self.layers = layers
        self.resblocks = [ResidualAttentionBlock(width, heads, rotary_dims, attn_mask, name=f"resblocks{i}") for i in range(layers)]
        self.attn_mask = attn_mask

    def __call__(self, x: jnp.ndarray):
        for b in self.resblocks:
            x = b(x)
        return x

In [4]:
import clip_jax

_, text_fn, jax_params, _ = clip_jax.load('ViT-B/32', "cpu")



In [5]:
class TextCLIP(hk.Module):
    @hk.transparent
    def __init__(self,
                 embed_dim: int,
                 context_length: int,
                 vocab_size: int,
                 rotary_dims: int, 
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 seq_length:int = None
                 ):
        super().__init__()

        self.context_length = context_length
        if seq_length is None:
            seq_length = context_length
        self.seq_length = seq_length
            
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            rotary_dims=rotary_dims,
            attn_mask=self.build_attention_mask(),
            name="transformer"
        )

        self.vocab_size = vocab_size
        self.token_embedding = hk.Embed(vocab_size, transformer_width, name="token_embedding")

        scale = transformer_width ** -0.5
        w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(scale))
        self.ln_final = LayerNorm(-1, create_scale=True, create_offset=True, name="ln_final")

        self.text_projection = hk.get_parameter("text_projection", shape=[transformer_width, embed_dim], init=w_init)
        self.logit_scale = hk.get_parameter("logit_scale", shape=[], init=hk.initializers.Constant(1))

    def build_attention_mask(self):
        # we use additive attention mask; fill with -inf
        mask = jnp.zeros((self.seq_length, self.seq_length))
        mask -= 10e10
        # make zeroes in place of context windows, -inf otherwise
        mask = jnp.triu(mask, self.context_length).transpose() + jnp.triu(mask, 1)
        return mask

    def encode(self, text):
        x = self.token_embedding(text)  # [batch_size, d_input, d_model]

        x = x.transpose((1, 0, 2))  # NLD -> LND

        x = self.transformer(x)
        x = x.transpose((1, 0, 2))  # LND -> NLD
        x = self.ln_final(x) @ self.text_projection
        return x
    
    def encode_text(self, text):
        x = self.encode(text)
        # x.shape == [batch_size, n_ctx, transformer.width]
        # take features from the last non-zero token 
        pos = jnp.cumsum(text, axis=-1).argmax(axis=-1)
        x = x[jnp.arange(x.shape[0]), pos] 
        return x

In [6]:
def encode_text(config, tokens):
    clip = TextCLIP( # the same as orig, except context_length = 77 - 2 (special tokens)
                     embed_dim = 512, 
                     context_length = 75, 
                     vocab_size = 49408,
                     # can possibly vary
                     rotary_dims = config["rotary_dims"], 
                     transformer_width = config["d_model"],
                     transformer_heads = config["n_heads"],
                     transformer_layers = config["layers"])
    return clip.encode_text(tokens)


In [7]:
def cfg_encode_text(config, tokens):
    clip = TextCLIP( # the same as orig, except context_length = 77 - 2 (special tokens)
                     embed_dim = 512, 
                     context_length = 75, 
                     vocab_size = 49408,
                     # can possibly vary
                     rotary_dims = config["rotary_dims"], 
                     transformer_width = config["d_model"],
                     transformer_heads = config["n_heads"],
                     transformer_layers = config["layers"])
    return clip.encode_text(tokens)

In [8]:
default_config = {    "layers": 12,
    "d_model": 512,
    "n_heads": 8,
    "rotary_dims": 32,}

In [9]:
config = default_config

In [10]:
        encode_text = partial(cfg_encode_text,config)


In [11]:
clip_init_fn = hk.transform(hk.experimental.optimize_rng_use(encode_text)).init

In [12]:
key = jax.random.PRNGKey(42)


In [13]:
input_shape = (1,75)

In [14]:
        key = hk.PRNGSequence(42)
        x = jax.random.randint(next(key), (jax.local_device_count(),75), 0, 49408)
        
        clip_apply_fn = hk.without_apply_rng(hk.transform(encode_text)).apply


In [15]:
params = clip_init_fn(next(key), x)

In [16]:
xs = jax.random.randint(next(key), (64,2048), 0, 49408)

In [17]:
context_len = 75

def reshape_data(data):
    # trim data to be multiple of context_len
    chunks = data.shape[-1] // context_len
    data = data[...,:chunks*context_len]
    # bs, chunks*context_len -> bs*chunks, context_len
    data = data.flatten()
    data = data.reshape(-1, context_len)

    # trim data to be multiple of (context_len, context_len)
    chunks = data.shape[0] // context_len
    data = data[:chunks*context_len]
    # chunks*context_len, context_len -> chunks, context_len, context_len
    data = data.reshape((-1, context_len, context_len))
    # make all kind of input length - from 1 to 75 by zeroing the rest
    return np.tril(data)

def add_sot_eot(data):
    sot_token, eot_token = 49406, 49407
    # add sot_token to the start of each sample
    sots = np.full((data.shape[0],context_len,1), sot_token)
    data = np.concatenate((sots, data),axis=-1) 
    # add zeros column at end, so we can add eot_token for the last sample
    zeroes = np.full((data.shape[0],context_len,1), 0)
    data = np.concatenate((data, zeroes),axis=-1) 
    # make diag eot_token matrix
    eots = np.diagflat(np.full(context_len, eot_token))
    # move eot_token two position from diagonal
    zeroes = np.full((context_len,2), 0)
    eots = np.concatenate((zeroes, eots),axis=-1) 
    # place eot_token
    return data + eots

def align_to_devices(data):
    data = data.reshape(-1, data.shape[-1])
    chunks = data.shape[0] // jax.device_count()
#    return data[:chunks*jax.device_count()]
    return data[:jax.device_count()]



In [18]:
    data = reshape_data(xs)    
    orig_data = add_sot_eot(data.copy())
    
    data, orig_data = map(align_to_devices, (data, orig_data))

In [19]:
data.shape

(4, 75)

In [20]:
x = data
xs = x

In [21]:
clip_apply_fn(params, xs)

DeviceArray([[  62.17662  , -261.82703  ,  -53.46753  , ...,
                 3.7723002, -105.044365 ,   15.762376 ],
             [  84.91298  ,  -52.361774 ,  -51.79109  , ...,
                36.110058 ,   87.18306  ,  -28.854849 ],
             [ 106.10378  ,    5.4214125,   -1.3591211, ...,
               -64.286476 , -159.97821  ,  -91.17958  ],
             [-138.82695  ,  147.98439  ,  -17.922964 , ...,
               -32.472694 , -248.58313  ,  -43.378807 ]], dtype=float32)

In [22]:
from jax.experimental.maps import mesh
from jax.experimental.pjit import PartitionSpec, pjit, with_sharding_constraint

In [23]:
PS=PartitionSpec('devices')

In [25]:
import optax

In [26]:
    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1),
    )



In [27]:
 def init(key, xs):
            params = clip_init_fn(key, tokens = xs)
            opt_state = optimizer.init(params)
            
            return {
                "params": params,
                "step": np.array(0),
                "opt_state": opt_state
            }


In [28]:
x.shape

(4, 75)

In [29]:
init_pjit = pjit(init,
                               in_axis_resources=(None, PS),
                               out_axis_resources=(None))

  warn("pjit is an experimental feature and probably has bugs!")


In [30]:
with mesh(jax.devices(), ('devices',)):
    state = init_pjit(next(key), x)

In [31]:
        _, clip_target, _, _ = clip_jax.load('ViT-B/32', "cpu")


In [32]:

        def train_loss(params, x, y):
            return jnp.mean(jnp.square(clip_apply_fn(params, x) - clip_target(y)))
            
        def train(state, x, y):
            val_grad_fn = jax.value_and_grad(train_loss)
            loss, grad = val_grad_fn(state["params"], x, y)
            updates, new_opt_state = optimizer.update(grad, state["opt_state"], state["params"])
            prms = optax.apply_updates(state["params"], updates)
            #stp = state["step"] + 1
            return loss, prms
            '''
            {
                "params": optax.apply_updates(state["params"], updates),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }
            ''' 
        PS=PartitionSpec('devices')
        train_pjit = pjit(train,
                               in_axis_resources=(None, PS, PS),
                               out_axis_resources=(None))
    

In [34]:
with mesh(jax.devices(), ('devices',)):
        loss, state = train(state, data, orig_data)


In [35]:
loss

DeviceArray(9353.357, dtype=float32)

In [42]:
        loss = np.array(loss)

In [43]:
x,y = data, orig_data

In [44]:
train_loss_pjit = pjit(train_loss,
                               in_axis_resources=(None, PS, PS),
                               out_axis_resources=(None))

In [45]:
            val_grad_fn = jax.value_and_grad(train_loss)

In [40]:
with mesh(jax.devices(), ('devices',)):
    res = train_loss_pjit(state["params"], x, y)

KeyError: 'params'

In [46]:
res

NameError: name 'res' is not defined

In [33]:
train_pjit = pjit(train,
                   in_axis_resources=(None, PS, PS),
                   out_axis_resources=(None),
                   donate_argnums=(0,)
                 )

In [1]:
with mesh(jax.devices(), ('devices',)):
        loss, state = train_pjit(state, data, orig_data)


NameError: name 'mesh' is not defined

In [None]:
train_loss = pjit(clip_target,
                               in_axis_resources=(None, PS, None),
                               out_axis_resources=(None))

In [None]:
train_loss = clip_target(y)

In [38]:
with mesh(jax.devices(), ('devices',)):
    loss, grad = val_grad_fn(state["params"], x, y)


AssertionError: 

In [None]:
        train_pjit = pjit(train,
                               in_axis_resources=(None, PS, None),
                               out_axis_resources=(None))


In [None]:
with mesh(jax.devices(), ('devices',)):
        #print(f"shapes {obs.shape} {target.shape}" )
        loss, state = train_pjit(state, data, orig_data)

In [66]:
loss

ShardedDeviceArray(311684.38, dtype=float32)

In [40]:
loss

ShardedDeviceArray(615782.3, dtype=float32)

In [138]:
clip_apply_pjit = pjit(clip_apply_fn, in_axis_resources=(None, PS), out_axis_resources=PS)

In [139]:
with mesh(jax.devices(), ('devices',)):
    result = clip_apply_pjit(params, xs)
print(result)

[[ -45.10905   -11.829655  -10.228348 ...  -30.686451   45.81359
    28.876823]
 [  -9.172699   21.597015   76.27464  ... -178.04343    64.722885
   -40.204914]
 [  97.46905   -39.03131    66.35701  ...   97.62146  -145.99475
   114.75131 ]
 [ -71.22487    17.125637   84.63946  ...  -99.49795   -37.36209
    51.03698 ]]


In [63]:
        PS=PartitionSpec('devices')
        train_pjit = pjit(train,
                               in_axis_resources=(None, PS, PS),
                               out_axis_resources=(None))



In [None]:
loss, self.state = self.train_pjit(self.state, obs, target)

In [None]:
with mesh(jax.devices(), ('devices',)):
    result = train_pjit(result, key, xs)
print(result)

In [56]:
class ClipTrainer:
    def __init__(self, config):
        self.config = config
        optimizer = config["optimizer"]
        
        _, clip_target, _, _ = clip_jax.load('ViT-B/32', "cpu")
        
        clip_init_fn = hk.transform(hk.experimental.optimize_rng_use(partial(cfg_encode_text,config))).init
            
        def init(key, xs):
            params = clip_init_fn(key, xs = xs)
            opt_state = optimizer.init(params)
            
            return {
                "params": params,
                "step": np.array(0),
                "opt_state": opt_state
            }

        key = hk.PRNGSequence(42)
        x = jax.random.randint(key, (jax.local_device_count(),75), 0, 49408)
        
        clip_apply_fn = hk.without_apply_rng(hk.transform(encode_text)).apply

        def train_loss(params, x, y):
            return jnp.mean(jnp.square(clip_apply_fn(params, x) - clip_target(y)))
            
        def train(state, x, y):
            val_grad_fn = jax.value_and_grad(train_loss)
            loss, grad = val_grad_fn(state["params"], x, y)
            updates, new_opt_state = optimizer.update(grad, state["opt_state"], state["params"])
            
            return loss, {
                "params": optax.apply_updates(state["params"], updates),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }
        
        PS=PartitionSpec('devices')
        self.train_pjit 
        self.train_pjit = pjit(train,
                               in_axis_resources=(None, PS, PS),
                               out_axis_resources=(None))

        self.eval_pjit = pjit(train_loss,
                              in_axis_resources=(None, PS, PS),
                              out_axis_resources=(None))

        self.state = self.init(next(key), x)
        self.eval_weights = None

        param_count = hk.data_structures.tree_size(self.state['params'])
        head_print(f"Total parameters: {param_count * dp}")

    def write_ckpt(self, path, _):
        write_ckpt_v2(self.state, path)

    def load_ckpt(self, path):
        self.state = load_ckpt_v2(self.state, path)

    def train(self, sample):
        obs = sample["obs"]
        target = sample["target"]
        
        loss, self.state = self.train_pjit(self.state, obs, target)
        loss = np.array(loss)

        return loss.mean()

    def eval(self, sample):
        out = self.eval_pjit(self.state["params"], sample["obs"], sample["target"])
        return out

In [57]:
import pickle
from smart_open import open

def write_ckpt(x, ckpt_dir):
    pickle.dump(x['params'], open(ckpt_dir + '/params.pickle', "wb"))
    pickle.dump(x['opt_state'], open(ckpt_dir + '/opt_state.pickle', "wb"))

def read_ckpt(ckpt_dir, load_opt=True):
    result = {'params': pickle.load(open(ckpt_dir + '/params.pickle', 'rb'))}
    if load_opt:
        result['opt_state'] = pickle.load(open(ckpt_dir + '/opt_state.pickle', 'rb'))
    return result

In [None]:
import argparse
import json
import time

import jax
import numpy as np
import optax

import wandb
from tqdm import tqdm

from text_clip import ClipTrainer
from tfrecord_loader import TFRecordNewInputs
from smart_open import open
from google.cloud import storage
from google.cloud.exceptions import NotFound


def parse_args():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="""
    To use, download the full checkpoint archive, extract and upload to a GCS bucket, and set that as --tune-model-path
    Modify the config file:
        - set `model_dir` to where the checkpoints should be written during training
        - set `train_set`, `val_set` to index files for your data
        - set `warmup_steps`, `anneal_steps`, `lr`, `end_lr` to the lr schedule for your finetuning run
        - the global step will reset to 0, keep that in mind when writing your lr schedule
        - set `name` to specify the name of the Weights & Biases run
        - set `wandb_project` to specify the Weights & Biases project to log to
    To prepare data in the expected data format:
        - use the script `create_finetune_tfrecords.py` in this repo to create data in the expected format
        - upload the .tfrecords files to GCS
        - save their GCS paths to a index file under `data/`, see existing files for examples
    """,
    formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--config", type=str, default=None, help="Config file location")
    parser.add_argument("--tune-model-path", type=str, default=None, help="Base model to finetune")
    parser.add_argument("--fresh-opt", default=False, action="store_true", help="Use a newly initialized optimizer, ignoring any optimizer state saved in the base checkpoint")

    args = parser.parse_args()
    return args

import pickle
from smart_open import open

def write_ckpt(x, ckpt_dir):
    pickle.dump(x['params'], open(ckpt_dir + '/params.pickle', "wb"))
    pickle.dump(x['opt_state'], open(ckpt_dir + '/opt_state.pickle', "wb"))

def read_ckpt(ckpt_dir, load_opt=True):
    result = {'params': pickle.load(open(ckpt_dir + '/params.pickle', 'rb'))}
    if load_opt:
        result['opt_state'] = pickle.load(open(ckpt_dir + '/opt_state.pickle', 'rb'))
    return result

def save(network, step, bucket, path, aux=None, keep_n=3, delete_old=True):
    assert path
    client = storage.Client()

    if aux is None:
        aux = {}

    try:
        with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
            meta = json.load(f)
    except:
        # create metadata file
        with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
            json.dump({
                "step": 0,
                "checkpoints": [],
                "aux": {}
            }, f)

    # do sharded checkpoint writing
    start = time.time()
    res = []
    write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/")

    print(f"Wrote checkpoint in {time.time() - start:.06}s")

    with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
        meta = json.load(f)

    meta["step"] = step
    meta["checkpoints"].append(step)
    all_aux = meta.get("aux", {})

    while len(meta["checkpoints"]) > keep_n:
        ckpt_to_delete = meta["checkpoints"].pop(0)

        try:
            del all_aux[str(ckpt_to_delete)]
        except:
            print(f"failed to delete the aux state for {step}")

        if delete_old:
            print(f"deleting checkpoint {ckpt_to_delete}")
            for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
                # print(f"deleting {blob.name}")
                assert path in blob.name
                blob.delete()
        else:
            print(f"keeping checkpoint {ckpt_to_delete}")

    all_aux[step] = aux
    meta["aux"] = all_aux

    with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
        json.dump(meta, f)

def gpt3_schedule(warmup_steps,
                  total_steps,
                  peak_lr,
                  end_lr):
    def sch(step):
        warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps
        anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps

        return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2

    return sch

context_len = 75

def reshape_data(data):
    # trim data to be multiple of context_len
    chunks = data.shape[-1] // context_len
    data = data[...,:chunks*context_len]
    # bs, chunks*context_len -> bs*chunks, context_len
    data = data.flatten()
    data = data.reshape(-1, context_len)

    # trim data to be multiple of (context_len, context_len)
    chunks = data.shape[0] // context_len
    data = data[:chunks*context_len]
    # chunks*context_len, context_len -> chunks, context_len, context_len
    data = data.reshape((-1, context_len, context_len))
    # make all kind of input length - from 1 to 75 by zeroing the rest
    return np.tril(data)

def add_sot_eot(data):
    sot_token, eot_token = 49406, 49407
    # add sot_token to the start of each sample
    sots = np.full((data.shape[0],context_len,1), sot_token)
    data = np.concatenate((sots, data),axis=-1) 
    # add zeros column at end, so we can add eot_token for the last sample
    zeroes = np.full((data.shape[0],context_len,1), 0)
    data = np.concatenate((data, zeroes),axis=-1) 
    # make diag eot_token matrix
    eots = np.diagflat(np.full(context_len, eot_token))
    # move eot_token two position from diagonal
    zeroes = np.full((context_len,1), 0)
    eots = np.concatenate((zeroes, eots),axis=-1) 
    eots = np.concatenate((zeroes, eots),axis=-1) 
    # place eot_token
    return data + eots

def network_step(network, data):
    data = reshape_data(data)    
    orig_data = add_sot_eot(data.copy())
    
    data = data.reshape(-1, context_len)
    orig_data = orig_data.reshape(-1, context_len+2)
    
    inputs = {
        "obs": data,
        "target": orig_data,
    }

    return network.train(inputs)

if __name__ == "__main__":
    args = parse_args()
    params = json.load(open(args.config))

    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    batch_size = params["batch_size"]

    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]

    val_batches = params["val_batches"]
    val_every = params["val_every"]
    ckpt_every = params["ckpt_every"]
    keep_every = params["keep_every"]
    total_steps = params["total_steps"]

    warmup_steps = params["warmup_steps"]
    anneal_steps = params["anneal_steps"]
    lr = params["lr"]
    end_lr = params["end_lr"]
    weight_decay = params["weight_decay"]
   
    # alpha parameter for the exponential moving averages used to compute B_simple
    noise_scale_alpha = params.get("noise_scale_alpha", 0.01)

    scheduler = gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)
    
    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        optax.clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.add_decayed_weights(weight_decay),
        optax.scale(-1),
        optax.scale_by_schedule(scheduler)
    )

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    devices = jax.devices()

    # pick initial ckpt - based on tuning vs train from scratch

    step = 0
    initial_ckpt_state_path = None
    train_loader = None

    if args.tune_model_path:
        print('`--tune_model_path` passed: we are beginning a fine-tuning run')
        fine_tuning = True
        initial_ckpt_state_path = args.tune_model_path
    else:
        print('`--tune_model_path` not passed: we are continuing a fine-tuning run from a checkpoint (or we are not fine-tuning)')
        fine_tuning = False
        initial_ckpt_model_dir = model_dir
        initial_ckpt_path = f"gs://{bucket}/{initial_ckpt_model_dir}"
        meta_path = f"{initial_ckpt_path}/meta.json"

        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            ckpt_step = meta["checkpoints"][-1]
            initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
            print(f"state will be restored from checkpoint {ckpt_step}")

            step = ckpt_step
            train_loader = meta['aux'][str(ckpt_step)].get("train_loader", None)
        except NotFound:
            # no checkpoint, start at zero
            print(f"No checkpoint to load at {initial_ckpt_path}. Training from scratch.")

    if initial_ckpt_state_path:
        print(f"path to load checkpoint from: {initial_ckpt_state_path}")
    else:
        print("not loading from a checkpoint")

    # set up datasets
    print("setting up datasets")

    train_dataset = TFRecordNewInputs(f"data/{params['train_set']}",
                                      batch_size=(
                                          gradient_accumulation_steps,
                                          batch_size),
                                      sample_size=2048,
                                      restore_state=train_loader)

    val_sets = {}

    for k, v in params["val_set"].items():
        val_sets[k] = TFRecordNewInputs(
            f"data/{v}", batch_size=(batch_size,), sample_size=seq
        )

    # tok/sec metrics
    sequences_per_step = gradient_accumulation_steps * batch_size * (2048//context_len)

    # load + run
    with jax.experimental.maps.mesh(devices, ('devices',)):
        print("initializing network")
        network = ClipTrainer(params)

        if initial_ckpt_state_path:
            print("loading network")
            if fine_tuning:
                # get the scheduler step stored in the just-initialized optimizer
                # should be zero
                init_sched_state = network.state["opt_state"][-1]

            start = time.time()
            network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1], load_opt=(not args.fresh_opt))

            if fine_tuning:
                # overwrite the loaded scheduler step with zeros
                # this makes fine-tuning use the lr schedule in
                network.state["opt_state"][-1] = init_sched_state

            print(f"network loaded in {time.time() - start:.06}s")

        print('compiling train fn')
        start = time.time()
        loss = network_step(network, train_dataset.get_samples())
        step += 1
        print(f"Train fn compiled in {time.time() - start:.06}s")

        print('compiling eval fn')
        start = time.time()
        for val_set in val_sets.values():
            network_step(network, val_set.get_samples())
            val_set.reset()
        print(f"Eval fn compiled in {time.time() - start:.06}s")
        '''
        project = params.get("wandb_project", "text-clip")
        wandb.init(project=project, name=params["name"], config=params)
        '''
        while True:
            if (step % ckpt_every == 1) or step == total_steps:
                print(f"saving a checkpoint for step {step}")
                save(network, step, bucket, model_dir,
                     aux={"train_loader": train_dataset.get_state()},
                     delete_old=True,
                     )

            if step % val_every == 1:  # 1 because we've already taken a step to compile train fn
                for name, val_set in val_sets.items():
                    val_loss = []
                    for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)),
                                     desc=f"validation for step {step}, set {name}",
                                     total=val_batches):
                        val_loss.append(network_step(network, i))
                    val_set.reset()

                    val_loss = np.array(val_loss).mean()
                    print(f"validation loss for step {step}, set {name}: {val_loss}")

                    #wandb.log({f'val/loss_{name}': float(val_loss)}, step)

            if step == total_steps:
                print("training completed!")
                exit()

            start = time.time()
            loss = network_step(network, train_dataset.get_samples())
            step += 1

            steps_per_sec = 1 / (time.time() - start)
            sequences_processed = sequences_per_step * step

            ### compute summary stats about the gradient

            # converts from grads-summed-over-microbatch (what `CasualTransformer.train` computes)
            # to grads-averaged-over-microbatch (what we want)
            #
            # (when taking gradient steps, the same conversion happens inside the optimizer
            #  via optax.scale(1 / gradient_accumulation_steps))
            
            '''
            wandb_stats = {
                "train/loss": loss,
                "train/steps_per_sec": steps_per_sec,
                "train/tokens_per_sec": tokens_per_sec,
                "train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())),
                "sequences_processed": sequences_processed,
            }
            wandb_stats.update(noise_scale_stats)

            wandb.log(wandb_stats, step)
            '''