# Test of log one hot positional encoding

## Common premable

In [2]:
import os, sys
sys.path.append(os.path.join(os.path.abspath(''), '../'))

import peewee as pw
from toyDb.databases import ExperimentDb, ShaderDb
from toyDb.utils.Directory import getToyDbRootDir

import matplotlib.pyplot as plt
import numpy as np
import json
from tqdm import tqdm
import torch

ExperimentDb.init_from_default_db()

In [3]:
def log_one_hot_pytorch(tensor: 'torch.Tensor', d_model: int, base=10):
    bsz, seq_len = tensor.size()
    tensor = tensor.view(bsz, seq_len, 1)

    # sequence [base^(d_model-1), base^(d_model-2), ..., base^0]
    # (d_model,)
    div = base ** torch.arange(d_model-1, -1, -1).to(tensor.device)
    
    # (bsz, seq_len, 1)
    # div broadcasts to (1, 1, d_model) and tensor to (bsz, seqlen, d_model)
    # https://pytorch.org/docs/stable/notes/broadcasting.html
    temp = tensor // div

    # mod with base results in each digit alone in range(0, base) in each of d_model dimension
    # shape (bsz, seq_len, d_model)
    tensor_d_model = temp % base

    return tensor_d_model.flip(2)

# Usage
def test(bsz, seq_len, d_model, base):
  tensor = torch.randint(0, 10000000000, (bsz, seq_len))  # random tensor between 0 and 9999 (fits in d_model=4 digits)
  print(tensor)
  tensor = tensor.float()
  print(tensor)
  # tensor = tensor.to(torch.int64)
  # print(tensor)
  result = log_one_hot_pytorch(tensor, d_model, base)
  print(result)

test(2, 5, 16, 2)
test(2, 5, 16, 10)

tensor([[ 350310311, 4240118520, 9333777395, 3387237762, 5541553840],
        [4940163684, 2745820344, 3627624808, 1872476548, 2258461142]])
tensor([[3.5031e+08, 4.2401e+09, 9.3338e+09, 3.3872e+09, 5.5416e+09],
        [4.9402e+09, 2.7458e+09, 3.6276e+09, 1.8725e+09, 2.2585e+09]])
tensor([[[0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.,

In [4]:
a_l = torch.asarray([7575728212])
print(a_l)
a_f = a_l.float()
a_f_l = a_f.long()
print(a_f)
print(a_f_l)

log_one_hot_pytorch(a_l.view(1, 1), 10, 10)

# There are minor inaccuracy with this method, but overall good
# and it can handle float input
log_one_hot_pytorch(a_f.view(1, 1), 10, 10)

tensor([7575728212])
tensor([7.5757e+09])
tensor([7575728128])


tensor([[[8., 0., 0., 8., 2., 7., 5., 7., 5., 7.]]])

Application tip: use `-base` to get a balanced code?

The above method have serious bug for large d_model. Use bit ops instead (for base2 only)

In [5]:
bsz, seq_len, d_model = 2, 5, 768
shifted = True

# (bsz, seqlen)
trace_labels = torch.randint(0, 2**63-1, (bsz, seq_len), dtype=torch.int64)
print(trace_labels)

ones = torch.ones((d_model,), dtype=torch.int64, device=trace_labels.device)
# (d_model)
masks = torch.bitwise_left_shift(ones, torch.arange(d_model, device=trace_labels.device))

# (bsz, seq_len, d_model)
trace_embed = (torch.bitwise_and(trace_labels.view(bsz, seq_len, 1), masks.view(1, 1, d_model)) > 0).float()
if shifted:
    trace_embed -= 0.5
print(trace_embed)

# test for recovery
recovered = torch.sum((trace_embed > 0).int() * masks.view(1, 1, d_model), dim=-1)
print(recovered)

assert(torch.equal(recovered, trace_labels))

tensor([[1397778103705324293, 1593878440054493545, 7747133411190494865,
         6651259306569629801, 2456639253156878187],
        [ 608122149875274741, 3540975803657400709, 2898410917480899212,
         7727577357813226273,  272090792944917241]])
tensor([[[ 0.5000, -0.5000,  0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000,  0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]],

        [[ 0.5000, -0.5000,  0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000,  0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [-0.5000, -0.5000,  0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
         [ 0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]])
tensor([[1397778103705324293, 1593878

> Test for 豆豆's improvement thoughts on this

In [44]:
trace_label_binary_embedding_max_length = 8
hidden_size = 11

trace_label_binary_embedding = torch.zeros((2 * trace_label_binary_embedding_max_length, hidden_size), dtype=torch.float32)
for idx in range(0, trace_label_binary_embedding_max_length):
    trace_label_binary_embedding[2 * idx, idx] = 0
    trace_label_binary_embedding[2 * idx + 1, idx] = int(2 ** (idx))

print(trace_label_binary_embedding)

def _calculate_trace_label_embeddings_binary_learnable(trace_labels: 'torch.Tensor') -> 'torch.Tensor':
    d_model = hidden_size

    # this assumes all input number < 2**bits_max-1
    bits_max = 64
    
    bsz, seq_len = trace_labels.size()
    trace_labels = trace_labels.to(dtype=torch.int64, device=trace_labels.device)
    ones = torch.ones((bits_max,), dtype=torch.int64, device=trace_labels.device)

    # (bits_max, )
    masks = torch.bitwise_left_shift(ones, torch.arange(bits_max, device=trace_labels.device))

    # shape (bsz, seq_len, bits_max)
    # bits_max dim is from LSB first
    binary_form = (torch.bitwise_and(trace_labels.view(bsz, seq_len, 1), masks.view(1, 1, bits_max)) > 0).int()

    # check if some trace_label have gone beyond the limit
    d_emb_max_len = trace_label_binary_embedding_max_length

    # (bsz, seq_len)
    is_beyond_limit_mask = (
        torch.sum(binary_form[:, :, d_emb_max_len:], dim=2) > 0
    )

    # shape (bsz, seq_len, d_emb_max_len)
    binary_form_masked = torch.where(
        # (bsz, seq_len, 1)
        is_beyond_limit_mask.unsqueeze(-1),
        # (1, 1, 1)
        torch.ones((1, 1, 1), dtype=torch.int64, device=trace_labels.device),
        # (bsz, seq_len, d_emb_max_len)
        binary_form[:, :, :d_emb_max_len]
    )
    
    # Manually unwraps the first loop
    assert(trace_label_binary_embedding_max_length >= 1)
    # (bsz, seq_len, d_model = d_model)
    trace_embeds = torch.where(
        # (bsz, seq_len) -> (bsz, seq_len, 1)
        (binary_form_masked[:, :, 0] == 1).view(bsz, seq_len, 1),
        # The embedding of 1, (d_model,) -> (1, 1, d_model)
        trace_label_binary_embedding[1].view(1, 1, d_model),
        # The embedding of 0, (d_model,) -> (1, 1, d_model)
        trace_label_binary_embedding[0].view(1, 1, d_model)
    )

    for idx in range(1, d_emb_max_len):
        trace_embeds += (
            torch.where(
                (binary_form_masked[:, :, idx] == 1).view(bsz, seq_len, 1),
                trace_label_binary_embedding[2 * idx + 1].view(1, 1, d_model),
                trace_label_binary_embedding[2 * idx].view(1, 1, d_model)
            )
        )
    return trace_embeds

# test clamping
trace_labels = torch.as_tensor([[0, 1, 2, 3, 255, 256, 512, 100000, 1 << 62]], dtype=torch.int64)
_calculate_trace_label_embeddings_binary_learnable(trace_labels)

tensor([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   4.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   8.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,  16.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,  32.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0

tensor([[[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [  1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [  0.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [  1.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [  1.,   2.,   4.,   8.,  16.,  32.,  64., 128.,   0.,   0.,   0.],
         [  1.,   2.,   4.,   8.,  16.,  32.,  64., 128.,   0.,   0.,   0.],
         [  1.,   2.,   4.,   8.,  16.,  32.,  64., 128.,   0.,   0.,   0.],
         [  1.,   2.,   4.,   8.,  16.,  32.,  64., 128.,   0.,   0.,   0.],
         [  1.,   2.,   4.,   8.,  16.,  32.,  64., 128.,   0.,   0.,   0.]]])