In [1]:
import pickle

In [2]:
import zstandard

In [4]:
w = 'asdfasdf'

In [5]:
def z_write(data, fname):
    cctx = zstandard.ZstdCompressor(level=10)
    with open(fname, 'wb') as f, cctx.stream_writer(f) as z:
        pickle.dump(data, z)

In [6]:
def z_read(fname):
    dctx = zstandard.ZstdDecompressor()
    with open(fname, 'rb') as f, dctx.stream_reader(f) as z:
        return pickle.load(z)

In [7]:
z_write(w, 'data.pickle.zst')

In [8]:
z_read('data.pickle.zst')

'asdfasdf'

In [10]:
cctx = zstandard.ZstdCompressor(level=10)
with open('data.pickle.zst', 'wb') as f, cctx.stream_writer(f) as z:
    pickle.dump(w, z)

In [11]:
dctx = zstandard.ZstdDecompressor()
with open('data.pickle.zst', 'rb') as f, dctx.stream_reader(f) as z:
    w = pickle.load(z)

In [12]:
w

'some d'

In [1]:
import os
#os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

os.environ["XLA_FLAGS--xla_gpu_cuda_data_dir"] = "/etc/alternatives/cuda"

In [419]:
import numpy as np
from PIL import Image
import jax
import clip_jax
import jax.numpy as jnp
from einops import rearrange, repeat
from functools import lru_cache


In [3]:
_, text_fn, jax_params, _ = clip_jax.load('ViT-B/32', "cpu")



In [4]:
v,_ = jax.tree_util.tree_flatten({k:v for k,v in jax_params.items() if 'visual' not in k})

In [5]:
sum([a.size for a in v])

63428097

In [6]:
text = clip_jax.tokenize(["a diagram", "a dog", "a cat"])
text_embed = text_fn(text)

In [7]:
text.shape

(3, 77)

In [8]:
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}

In [9]:
from clip_jax.simple_tokenizer import SimpleTokenizer as Tokenizer
tokenizer = Tokenizer()

In [10]:
sot_token = jnp.array([tokenizer.encoder["<|startoftext|>"]])
eot_token = jnp.array([tokenizer.encoder["<|endoftext|>"]])

In [11]:
sot_token, eot_token

(DeviceArray([49406], dtype=int32), DeviceArray([49407], dtype=int32))

In [12]:
texts = [v for k,v in descriptions.items()]

In [13]:
tokens = [tokenizer.encode(t)[:7] for t in texts]

In [14]:
tokens

[[320, 2504, 539, 4160, 781, 10551, 9512],
 [320, 11909, 1125, 539, 320, 36145, 2368],
 [320, 5352, 539, 550, 18376, 593, 518],
 [320, 8383, 2862, 525, 320, 31168, 7601],
 [320, 736, 10297, 2862, 530, 320, 8474],
 [320, 2533, 1312, 536, 320, 3934, 525],
 [320, 1449, 268, 537, 268, 1579, 26149],
 [320, 1937, 539, 2453, 525, 320, 42272]]

In [15]:
d_model = 2048
clip_text_context_lenght = 77
bs = 64

In [16]:
x = np.random.randint(eot_token, size=(bs,d_model)) 

In [17]:
x

array([[24913, 44619,  3472, ..., 26234, 45531, 39767],
       [26249,  5637, 44718, ..., 27296,  8147, 10318],
       [17890, 33893, 14524, ..., 15978, 42917, 35486],
       ...,
       [21018, 41703, 29179, ..., 10023, 25975, 42948],
       [40082, 30374, 31740, ..., 19863, 22775, 13138],
       [14923, 24015, 42782, ..., 12991,  3664, 22301]])

In [18]:
def clip_row(row):
    row = np.trim_zeros(row)
    # pad with the start and end tokens, the clip network was trained to use them
    row = np.concatenate((sot_token, row, eot_token))
    row = np.pad(row, (0, clip_text_context_lenght-len(row)))
    return row

In [19]:
def clip_tokens(tokens):
    # N -> NxN
    matrix = np.repeat(tokens[None,], tokens.shape[0], axis=0)
    # use triag matrix, so we dont look forward during the training
    matrix = np.tril(matrix).transpose()
    # zero all tokens that not fit into the clip context window. -2 so we can pad with 2 special tokens after.
    matrix = matrix*((matrix!=0).cumsum(1)<=clip_text_context_lenght-2)
    matrix = matrix.transpose()
    return np.apply_along_axis(clip_row, 1, matrix)

In [20]:
x.shape

(64, 2048)

In [21]:
from multiprocessing import Pool

%%time
with Pool(64) as p:
    clip_input = jnp.array(list(p.map(clip_tokens, x)))

In [22]:
def clip_row(row):
    row = np.trim_zeros(row)
    # pad with the start and end tokens, the clip network was trained to use them
    row = np.concatenate((sot_token, row, eot_token))
    row = np.pad(row, (0, clip_text_context_lenght-len(row)))
    return row

In [23]:
def clip_tokens(tokens):
    max_len = clip_text_context_lenght-2
    matrix = [tokens[max(0, i+1 - max_len):i+1] for i in range(tokens.size)]
    matrix = [clip_row(row) for row in matrix]
    return np.array(matrix)

In [24]:
clip_tokens(x[0])

array([[49406, 24913, 49407, ...,     0,     0,     0],
       [49406, 24913, 44619, ...,     0,     0,     0],
       [49406, 24913, 44619, ...,     0,     0,     0],
       ...,
       [49406, 24276,  4039, ..., 28252, 26234, 49407],
       [49406,  4039, 43296, ..., 26234, 45531, 49407],
       [49406, 43296, 12280, ..., 45531, 39767, 49407]])

In [25]:
%%time
with Pool(64) as p:
    clip_input = np.array(list(p.map(clip_tokens, x)))

CPU times: user 80.2 ms, sys: 2.01 s, total: 2.09 s
Wall time: 2.28 s


In [26]:
%%time
clip_input = np.array([clip_tokens(tokens) for tokens in x])

CPU times: user 4.84 s, sys: 40.5 ms, total: 4.88 s
Wall time: 4.87 s


In [27]:
clip_input.shape

(64, 2048, 77)

In [28]:
b, l, c = clip_input.shape
res = clip_input.reshape(b*l, c)

In [29]:
%%time
text_fn(res)

2021-08-21 16:49:18.661780: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:462] Allocator (GPU_0_bfc) ran out of memory trying to allocate 19.25GiB (rounded to 20669530112)requested by op 
2021-08-21 16:49:18.662498: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:473] *************************************************************************************************___
2021-08-21 16:49:18.662555: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: Resource exhausted: Out of memory while trying to allocate 20669530112 bytes.


RuntimeError: Resource exhausted: Out of memory while trying to allocate 20669530112 bytes.

In [39]:
from jax.experimental.maps import xmap, mesh

In [40]:
devices = jax.devices()

In [41]:
devices

[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0)]

In [42]:
d = xmap(text_fn,
         in_axes=(['batch', ...]),
         out_axes=(['batch', ...]),
         axis_resources={'batch': 'b'}) 

In [43]:
%%time
with mesh(jax.devices(), 'b'):
    clip_embedding = d(clip_input)
clip_embedding.block_until_ready()
pass

2021-08-21 17:29:57.475219: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:275] Allocator (GPU_0_bfc) ran out of memory trying to allocate 14.45GiB with freed_by_count=0. The caller indicates that this is not a failure, but may mean that there could be performance gains if more memory were available.


RuntimeError: Resource exhausted: Out of memory while trying to allocate 15518924800 bytes.

In [44]:
clip_embedding.shape

NameError: name 'clip_embedding' is not defined

In [45]:
def scanned(_, x):
    return _, text_fn(x)

In [46]:
%%time
_, clip_embedding = jax.lax.scan(scanned, None, xs=clip_input)
clip_embedding.block_until_ready()
pass

CPU times: user 48.9 s, sys: 20 s, total: 1min 8s
Wall time: 49.7 s


In [47]:
b, l, c = clip_input.shape
dc = jax.local_device_count()
res = clip_input.reshape(dc, b//dc*64, l//64, c)
res.shape

(4, 1024, 32, 77)

In [48]:
def scan(x):
    _, clip_embedding = jax.lax.scan(scanned, None, xs=x)
    return clip_embedding

In [49]:
%%time
clip_embedding = jax.pmap(scan)(res)
clip_embedding = clip_embedding.reshape(b, l, 512)
clip_embedding.block_until_ready()
pass

CPU times: user 52.3 s, sys: 18.8 s, total: 1min 11s
Wall time: 15.5 s


In [50]:
clip_embedding.shape

(64, 2048, 512)

In [51]:
from functools import partial

In [52]:
mapped = partial(jax.lax.map, text_fn)

In [53]:
%%time
clip_embedding = jax.pmap(mapped)(res)
clip_embedding = clip_embedding.reshape(b, l, 512)
clip_embedding.block_until_ready()
pass

CPU times: user 52.9 s, sys: 17.9 s, total: 1min 10s
Wall time: 13.6 s


In [54]:
max_seq_len = 2048

In [55]:
context_length = 75

In [56]:
transformer_width = 512

In [57]:
clip_embedding = clip_embedding[:,:75,:]

In [58]:
clip_embedding.shape

(64, 75, 512)

In [59]:
k = clip_embedding
q = clip_embedding
v = clip_embedding

In [60]:
clip_embedding

DeviceArray([[[-1.27166688e-01,  1.97072387e-01, -1.00960121e-01, ...,
               -4.07567397e-02,  8.89264867e-02, -5.05541377e-02],
              [-3.04048061e-01,  1.42297626e-01, -2.12778747e-01, ...,
                6.25961661e-01,  2.60466561e-02,  2.03563780e-01],
              [-2.54633695e-01,  1.61029398e-01, -2.61954725e-01, ...,
                5.58265090e-01, -1.81519408e-02,  4.72949669e-02],
              ...,
              [-3.66377890e-01,  1.48722440e-01, -7.17692636e-03, ...,
                2.59172529e-01,  3.92375477e-02,  3.31961364e-02],
              [-3.61875117e-01,  1.25733733e-01,  1.78049318e-03, ...,
                2.38077253e-01,  5.48923276e-02,  2.24524215e-02],
              [-2.90850163e-01,  1.95316732e-01, -4.30019833e-02, ...,
                3.00663948e-01, -3.73106450e-03, -1.06333271e-01]],

             [[ 3.99326608e-02, -6.54429048e-02, -5.32164238e-02, ...,
                1.52062416e-01, -2.04652101e-01, -4.92503792e-02],
             

In [None]:
Last non zero token - output
NOT resudial vs resudial

In [435]:
rotary_dims = 32 # a half of head output dim, not something like RoPE in GPT-J

In [607]:
MAX_SEQ_LEN = 8192

@lru_cache()
def fixed_pos_embedding(rotary_dims):
    inv_freq = 1. / (10000 ** (np.arange(0, rotary_dims, 2) / rotary_dims))
    sinusoid_inp = np.einsum('i , j -> i j', np.arange(MAX_SEQ_LEN), inv_freq)
    return np.sin(sinusoid_inp), np.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = jnp.stack((-x2, x1), axis=-1)
    return rearrange(x, '... d j -> ... (d j)')

#@partial(jax.jit, static_argnums=(2,))
def apply_rotary_pos_emb(x, sincos, seq_dim):
    sincos = map(lambda t: repeat(t, '... b n -> ... b (n j)', j=2)[-x.shape[seq_dim]:], sincos)
    
    # (n_seq, dim_per_head) -> (n_seq, 1, 1, dim_per_head), so we can do mult
    # in case "x" is something like (n_seq, bs, n_heads, dim_per_head)
    add_dims = set(np.arange(x.ndim-1)) - set([np.arange(x.ndim)[seq_dim]])
    sin, cos = map(lambda t: jnp.expand_dims(t, tuple(add_dims)), sincos)
    
    return (x * cos) + (rotate_every_two(x) * sin)

In [615]:
@partial(jax.jit, static_argnums=(1,2))
def apply_rope(x, rotary_dims, seq_dim=1):
    x_rot = x[..., :rotary_dims]
    x_pass = x[..., rotary_dims:]
    sincos = fixed_pos_embedding(rotary_dims)
    x_rot = apply_rotary_pos_emb(x_rot, sincos, seq_dim)
    return jnp.concatenate([x_rot, x_pass], axis=-1)

In [616]:
def rope_tests():
    rotary_dims = 32
    vectors = np.random.random(size=(2,75))
    def test_pos(pos1, pos2, f):
        q = np.zeros(shape=(1,64,75))
        v = np.zeros(shape=(1,64,75))
        q[0,pos1] = vectors[0]
        v[0,pos2] = vectors[1]
        res = f(q,rotary_dims)@f(v,rotary_dims).transpose(0, 2, 1)
        return res[0,pos1,pos2]
    
    pos0 = test_pos(3,17, lambda x,y: x)
    pos1 = test_pos(3,17, apply_rope)
    pos2 = test_pos(5,19, apply_rope)
    pos3 = test_pos(5,20, apply_rope)
    assert not jnp.isclose(pos0, pos1)
    assert jnp.isclose(pos1, pos2)
    assert not jnp.isclose(pos2, pos3)

rope_tests()

In [617]:
def rope_tests2():
    rotary_dims = 32
    vectors = np.random.random(size=(2,75))
    def test_pos(pos1, pos2, f):
        q = np.zeros(shape=(64,75))
        v = np.zeros(shape=(64,75))
        q[pos1] = vectors[0]
        v[pos2] = vectors[1]
        q = q[:,None,None,:]
        v = v[:,None,None,:]
        q = f(q,rotary_dims,0).transpose(1,2,0,3)
        v = f(v,rotary_dims,0).transpose(1,2,0,3)
        q = jnp.squeeze(q)
        q = jnp.squeeze(q)
        v = jnp.squeeze(v)
        v = jnp.squeeze(v)
        res = q@v.transpose()
        return res[pos1,pos2]
    
    pos0 = test_pos(3,17, lambda x,y,z: x)
    pos1 = test_pos(3,17, apply_rope)
    pos2 = test_pos(5,19, apply_rope)
    pos3 = test_pos(5,20, apply_rope)
    assert not jnp.isclose(pos0, pos1)
    assert jnp.isclose(pos1, pos2)
    assert not jnp.isclose(pos2, pos3)

rope_tests2()

In [418]:
jnp.expand_dims(jnp.array([1,2]), ())

DeviceArray([1, 2], dtype=int32)

In [542]:
pos1 = 3

In [543]:
pos2 = pos1+17
q = np.zeros(shape=(1,64,75))
v = np.zeros(shape=(1,64,75))
q[0,pos1] = vectors[0]
v[0,pos2] = vectors[1]

In [544]:
res0 = q@v.transpose(0, 2, 1)
res0[0,pos1,pos2]

17.220902311676884

In [545]:
res1 = apply_rope(q,rotary_dims)@apply_rope(v,rotary_dims).transpose(0, 2, 1)
res1[0,pos1,pos2]

DeviceArray(9.897591, dtype=float32)

In [546]:
pos1 = 5

In [547]:
pos2 = pos1+17
q = np.zeros(shape=(1,64,75))
v = np.zeros(shape=(1,64,75))
q[0,pos1] = vectors[0]
v[0,pos2] = vectors[1]

In [548]:
res2 = apply_rope(q,rotary_dims)@apply_rope(v,rotary_dims).transpose(0, 2, 1)
res2=res2[0,pos1,pos2]
res2

DeviceArray(9.897591, dtype=float32)

In [549]:
pos2 = pos1+16
q = np.zeros(shape=(1,64,75))
v = np.zeros(shape=(1,64,75))
q[0,pos1] = vectors[0]
v[0,pos2] = vectors[1]

In [550]:
res3 = apply_rope(q,rotary_dims)@apply_rope(v,rotary_dims).transpose(0, 2, 1)
res3[0,pos1,pos2]

DeviceArray(9.807338, dtype=float32)

In [551]:
res3[0,pos1,pos2]-res2

DeviceArray(-0.09025288, dtype=float32)

In [538]:
res3[0,pos1,pos2]-res2

DeviceArray(-0.09025383, dtype=float32)

In [78]:
embed_dim: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int

SyntaxError: invalid syntax (4034806244.py, line 1)

In [652]:
context_length = 5
seq_lenght = 5

In [653]:
mask = jnp.zeros((seq_lenght, seq_lenght))
mask -= 10e10
# make zeroes in place of context windows
mask = jnp.triu(mask, context_length).transpose() + jnp.triu(mask, 1)  


In [654]:
mask

DeviceArray([[ 0.e+00, -1.e+11, -1.e+11, -1.e+11, -1.e+11],
             [ 0.e+00,  0.e+00, -1.e+11, -1.e+11, -1.e+11],
             [ 0.e+00,  0.e+00,  0.e+00, -1.e+11, -1.e+11],
             [ 0.e+00,  0.e+00,  0.e+00,  0.e+00, -1.e+11],
             [ 0.e+00,  0.e+00,  0.e+00,  0.e+00,  0.e+00]],            dtype=float32)

DeviceArray([[ 0.e+00, -1.e+11, -1.e+11, -1.e+11, -1.e+11],
             [ 0.e+00,  0.e+00, -1.e+11, -1.e+11, -1.e+11],
             [ 0.e+00,  0.e+00,  0.e+00, -1.e+11, -1.e+11],
             [-1.e+11,  0.e+00,  0.e+00,  0.e+00, -1.e+11],
             [-1.e+11, -1.e+11,  0.e+00,  0.e+00,  0.e+00]],            dtype=float32)

In [None]:
    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = jnp.zeros((seq_lenght, seq_lenght))
        mask -= 10e10
        mask = jnp.triu(mask, 1)  # zero out the lower diagonal
        return mask


In [645]:
import torch
import torch.nn as nn

In [646]:
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
output = cos(input1, input2)

In [648]:
output.shape

torch.Size([100])

In [649]:
output

tensor([ 0.0155,  0.0529,  0.1901, -0.0201,  0.0419,  0.0010,  0.0252, -0.0221,
         0.1344,  0.2112,  0.0273,  0.0225,  0.1488, -0.1109,  0.0742, -0.0657,
         0.0241,  0.1762, -0.0318, -0.0286,  0.1075, -0.0912, -0.0371,  0.1421,
        -0.0579,  0.0132,  0.0162, -0.0640, -0.0344, -0.1271,  0.0717,  0.0575,
        -0.0399,  0.1130,  0.0802,  0.0140,  0.0429, -0.1490, -0.0763, -0.0017,
         0.0865,  0.0190,  0.0124, -0.0139, -0.0402, -0.0498,  0.1039,  0.0720,
        -0.0385, -0.1752,  0.0201,  0.1232, -0.0860,  0.0839, -0.1284,  0.0702,
         0.0830, -0.1514,  0.0038,  0.0643, -0.0735,  0.1572,  0.2567,  0.1676,
         0.0304,  0.0659,  0.1055,  0.0475, -0.1839, -0.0283, -0.0106,  0.0377,
         0.1506, -0.0467, -0.2100, -0.1163,  0.1028, -0.0598, -0.0339, -0.0810,
         0.0007, -0.0107, -0.1102,  0.1285, -0.0298, -0.0173, -0.0733,  0.0044,
        -0.0856, -0.0936,  0.1001,  0.0335, -0.0664,  0.0400,  0.2534, -0.2347,
         0.0064, -0.0910, -0.0844, -0.00