In [None]:
import cv2
import numpy as np
import torch
from matplotlib import cm
from PIL import Image

ZOOM = 1  # scale factor handed to cv2.resize for both axes (fx and fy)

def convert_to_colormap(arr: np.ndarray):
    """Render a 2-D array as a colour-mapped ``uint8`` image with a white frame.

    The array is min-max normalised, passed through matplotlib's
    ``gist_earth`` colormap, brightened with a fixed gamma of 0.2, reduced to
    its first three channels, rescaled by ``ZOOM`` with nearest-neighbour
    interpolation and finally surrounded by a 1-pixel border of value 255.

    Args:
        arr: 2-D array; any other rank fails at the shape unpacking below.

    Returns:
        ``uint8`` array of shape ``(rows + 2, cols + 2, 3)`` where rows/cols
        are the dimensions after resizing. Channel order is whatever the
        colormap emits (presumably RGB — note this if writing with OpenCV).
    """
    _n_rows, _n_cols = arr.shape  # only here to enforce that arr is 2-D
    lo, hi = np.min(arr), np.max(arr)
    # The epsilon keeps a constant-valued input from dividing by zero.
    unit_scaled = (arr - lo) / (hi - lo + 1e-12)
    rgba = cm.gist_earth(unit_scaled)
    gamma = 0.2
    rgba = (rgba / np.max(rgba)) ** gamma
    pil_image = Image.fromarray((rgba * 255).astype(np.uint8))
    rgb = np.asarray(pil_image)[:, :, :3]
    rgb = cv2.resize(rgb, None, fx=ZOOM, fy=ZOOM, interpolation=cv2.INTER_NEAREST)
    height, width, channels = rgb.shape
    framed = np.full((height + 2, width + 2, channels), 255, dtype=np.uint8)
    framed[1:-1, 1:-1, :] = rgb
    return framed

: 

In [None]:
# visualize one tensor
import os

PLOT_DIR = "./plots/poc/grid_sampling"
os.makedirs(PLOT_DIR, exist_ok=True)

def visualize_one_tensor(tensor, name):
    """Save a 2-D tensor as a colour-mapped PNG under ``PLOT_DIR``.

    Args:
        tensor: 2-D torch tensor. It is detached and moved to the CPU before
            conversion, so tensors that track gradients or live on another
            device are accepted as well.
        name: File stem; the image is written to ``{PLOT_DIR}/{name}.png``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    img = convert_to_colormap(tensor.detach().cpu().numpy())
    path = f"{PLOT_DIR}/{name}.png"
    # The colormap image is in RGB order while cv2.imwrite interprets the
    # channels as BGR; convert first so the saved colours match the colormap.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(path, bgr):
        raise OSError(f"cv2.imwrite failed to write {path}")
    print('processed', path)

: 

In [None]:
# import torch

# a = (torch.rand((1, 1, 203, 128)) ).float() # > 0.5

# visualize_one_tensor(a[0][0],'a')

: 

In [None]:
import torch.nn.functional as F

def grid_sample_bf16(input, grid, mode='nearest', align_corners=False, padding_mode='zeros'):
    """Run ``F.grid_sample`` in float32 when the GPU autocast dtype is bfloat16.

    When ``torch.get_autocast_gpu_dtype()`` is ``torch.bfloat16`` both
    ``input`` and ``grid`` are cast to float32 for the sampling call and the
    result is cast back to the dtype ``input`` had on entry. Otherwise the
    tensors are passed through unchanged. Note that the check looks at the
    configured autocast dtype only, not at whether autocast is enabled.

    Args:
        input: Tensor to sample from, forwarded to ``F.grid_sample``.
        grid: Sampling grid, forwarded to ``F.grid_sample``.
        mode: Interpolation mode.
        align_corners: Forwarded unchanged.
        padding_mode: How out-of-range grid locations are handled. This is
            now forwarded; previously ``'zeros'`` was always used regardless
            of the argument.

    Returns:
        The sampled tensor, in the original dtype of ``input``.
    """
    input_dtype = input.dtype
    op_dtype = torch.float32 if torch.get_autocast_gpu_dtype() == torch.bfloat16 else input_dtype
    if op_dtype != input_dtype:
        input = input.to(op_dtype)
        grid = grid.to(op_dtype)
    y = F.grid_sample(
        input=input,
        grid=grid,
        mode=mode,
        align_corners=align_corners,
        padding_mode=padding_mode,
    )
    if y.dtype != input_dtype:
        y = y.to(input_dtype)
    return y

: 

In [None]:
# Toy additive attention mask: a single row with 203 entries, all set to 1.
attention_mask = torch.ones((1, 1, 1, 203))
# NOTE(review): `[10:]` slices the first axis, whose size is 1, so the slice is
# empty and nothing is written. If the last 193 tokens were meant to be
# changed this should index the last axis instead — confirm intent.
attention_mask[10:] = 0
# 1.0 wherever the mask value is above -1, else 0.0. Every entry here is above
# -1, so the result is all ones.
zero_one_attention_mask = attention_mask.gt(-1).to(torch.float32)


def resize_from_m_to_t(x, masked_fill_value, target_width=None):
    """Resample the last axis of ``x`` to ``T2`` columns with nearest-neighbour ``grid_sample``.

    The horizontal sampling positions are derived from the module-level
    ``zero_one_attention_mask`` (reshaped to ``(N, 1, T2)``), so that tensor
    must hold exactly ``N * T2`` elements.

    Args:
        x: 4-D tensor unpacked as ``(N, H, T1, T_M)``.
        masked_fill_value: If not ``None``, ``x`` is right-padded with two
            columns of 0 followed by one column of this value, and positions
            where the mask is 0 are pushed to the right edge of the grid
            (where that last padded column sits). If ``None``, ``x`` is
            sampled unpadded and positions are the mask's running sum scaled
            by ``(mask.sum(-1) - 1)``.
        target_width: Output width ``T2``; defaults to ``T1``.

    Returns:
        The result of ``grid_sample_bf16`` called with ``mode='nearest'``,
        ``align_corners=True`` and ``padding_mode='border'`` on a grid of
        shape ``(N, T1, T2, 2)``.

    Notes:
        An unreachable alternative branch that built the grid from a
        per-query ``causal_attention_mask`` (a name that is not defined in
        this notebook) has been removed; the live code path is unchanged.
    """
    N, _H, T1, T_M = x.shape  # unpacking also enforces a 4-D input
    T2 = T1 if target_width is None else target_width

    mask = zero_one_attention_mask.view(N, 1, T2)
    if masked_fill_value is not None:
        mask_cs = mask.cumsum(-1)
        valid_count = mask_cs[:, :, -1].unsqueeze(-1)
        # The "+ 3 * (valid_count / T_M)" term presumably compensates for the
        # three columns padded onto the input below — TODO confirm.
        token_length = (valid_count - 1) + 3 * (valid_count / T_M)
        # (1 - mask) * 5000 adds a large offset at masked positions so that
        # they are intended to saturate at the clamp's upper bound of 1.
        token_index_x = torch.clamp(
            (((mask_cs - 1) + (1 - mask) * 5000) / (token_length + 1e-8)) * 2 - 1,
            -1,
            1,
        )
        token_index_x = token_index_x.expand(N, T1, T2)
    else:
        token_index_x = mask.cumsum(-1)
        token_index_x = (
            token_index_x / ((zero_one_attention_mask.sum(-1) - 1).view(N, 1, 1) + 1e-8) * 2 - 1
        ).expand(N, T1, T2)

    # NOTE(review): with align_corners=True the usual row normalisation divides
    # by (T1 - 1); dividing by T1 may make later rows sample the previous input
    # row under nearest-neighbour rounding. Left as-is because this notebook
    # studies the existing behaviour — confirm before changing.
    token_index_y = (
        torch.arange(T1, dtype=torch.long, device=token_index_x.device).view(1, T1, 1) / T1 * 2 - 1
    ).expand(N, T1, T2)
    # Last axis is (x, y), the order grid_sample expects.
    token_index = torch.cat(
        [token_index_x.unsqueeze(-1), token_index_y.unsqueeze(-1)],
        dim=-1,
    )

    if masked_fill_value is not None:
        grid_input = F.pad(F.pad(x, pad=(0, 2), value=0), pad=(0, 1), value=masked_fill_value)
    else:
        grid_input = x
    if grid_input.dtype != x.dtype:
        grid_input = grid_input.to(x.dtype)
    if token_index.dtype != x.dtype:
        token_index = token_index.to(x.dtype)

    return grid_sample_bf16(
        input=grid_input,
        grid=token_index,
        mode='nearest',
        align_corners=True,
        padding_mode='border'
    )

: 

In [None]:
# b = resize_from_m_to_t(a, masked_fill_value=0)

# visualize_one_tensor(b[0][0],'b')

: 

In [None]:
# Number of entries in a 203 x 128 map.
128 * 203

: 

In [None]:
import torch

# Random 203 x 128 map, flattened so that top-k ranks every entry at once.
# A dedicated seeded generator makes the cell reproducible across kernel
# restarts without touching the global RNG state.
rand_gen = torch.Generator().manual_seed(0)
a = torch.rand((1, 1, 203, 128), generator=rand_gen).float().view(1, 1, 203 * 128)
a_inx = torch.topk(input=a, k=25000, dim=-1)
print(a.shape)
print(a_inx[1].shape)
# Zero the 25000 largest entries in place; the remaining 203 * 128 - 25000
# entries keep their random value.
a.scatter_(dim=-1, index=a_inx[1], value=0)
a = a.view(1, 1, 203, 128)

: 

In [None]:
# Copy of `a` in which six columns are set to 1 in every one of the 203 rows.
marked_columns = [3, 6, 30, 47, 67, 125]
inx = torch.tensor(marked_columns).expand(1, 1, 203, len(marked_columns))
a_l = a.scatter(dim=-1, index=inx, value=1)

# Save the plain map and the marked map as PNGs.
for tag, plotted in (('a', a), ('a_l', a_l)):
    visualize_one_tensor(plotted[0][0], tag)

: 

In [None]:
# Resize both maps with a fill value of 0, then save one PNG per result.
b, b_l = (resize_from_m_to_t(source, masked_fill_value=0) for source in (a, a_l))

for tag, resized in (('b', b), ('b_l', b_l)):
    visualize_one_tensor(resized[0][0], tag)

: 

In [None]:
# Shapes before (a, a_l) and after (b, b_l) the resize, one per line.
for _tensor in (a, a_l, b, b_l):
    print(_tensor.shape)

: 

In [None]:
# Count the columns of a_l whose entries, summed over the row axis, equal exactly 203.
sum_a_l = a_l.sum(-2)
sum_a_l.eq(203).sum()

: 

In [None]:
# Same count for the resized map: columns of b_l whose row-axis sum equals exactly 203.
sum_b_l = b_l.sum(-2)
sum_b_l.eq(203).sum()

: 

In [None]:
# Six columns scaled by 203 / 128 (presumably the output/input width ratio of the resize).
6 * 203 / 128

: 

In [None]:
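'''
partial attention mask
In practice, start from an n^2 matrix that is mostly filled with -inf and has only some entries set to 0,
then fill certain columns entirely with 0.
After that, apply resize_from_m_to_t.

After resize_from_m_to_t, do the all-zero columns survive or disappear? How thick are they?
>> Right now some of them became thicker and some did not change.
>> Understanding the grid_sample mechanism is probably required to get past this...
>> Also, the case where the attention mask is taken into account still needs to be considered.
'''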

: 

In [None]:
import torch

# Sparse map: 0 at 900 randomly ranked positions and -1000 everywhere else.
# A dedicated seeded generator makes the cell reproducible across kernel
# restarts without touching the global RNG state.
rand_gen = torch.Generator().manual_seed(0)
a = torch.rand((1, 1, 203, 128), generator=rand_gen).float().view(1, 1, 203 * 128)
# The random values are only used to pick 900 positions (the 900 largest).
a_inx = torch.topk(input=a, k=900, dim=-1)
print(a.shape)
print(a_inx[1].shape)
a.fill_(-1000)
a.scatter_(dim=-1, index=a_inx[1], value=0)

a = a.view(1, 1, 203, 128)

: 

In [None]:
# Print how many entries of `a` are exactly 0, then show the tensor as the cell output.
print(a.eq(0).sum())
a

: 

In [None]:
# Copy of `a` in which six columns are set to 0 in every one of the 203 rows.
marked_columns = [3, 6, 30, 47, 67, 125]
inx = torch.tensor(marked_columns).expand(1, 1, 203, len(marked_columns))
a_l = a.scatter(dim=-1, index=inx, value=0)

# Save the plain map and the marked map as PNGs.
for tag, plotted in (('a', a), ('a_l', a_l)):
    visualize_one_tensor(plotted[0][0], tag)

: 

In [None]:
# Resize both maps with a fill value of 0, then save one PNG per result.
b, b_l = (resize_from_m_to_t(source, masked_fill_value=0) for source in (a, a_l))

for tag, resized in (('b', b), ('b_l', b_l)):
    visualize_one_tensor(resized[0][0], tag)

: 

In [None]:
# Number of columns whose row-axis sum is exactly 0, before (a_l) and after (b_l) the resize.
for _scores in (a_l, b_l):
    print(_scores.sum(-2).eq(0).sum())

: 

In [None]:
# Thickness estimate: 6 marked columns, each widened by floor(203 / 128)
# clamped to the range [1, 128].
# Integer floor division yields the same value as math.floor(203 / 128) and
# avoids the `math` module, which no visible cell imports (the original
# expression raised NameError on a fresh kernel).
expected_zero_columns = 6 * min(max(203 // 128, 1), 128)
expected_zero_columns

: 

: 

In [None]:
'''
thickness
>>> (min(max(math.ceil(6*203/128),1), 128))
10
>>> 6*203/128
9.515625
>>> (min(6*(max(round(203/128),1)), 128))
12
>>> (min(max(round(6*203/128),1), 128))
10
'''

: 