In [1]:
import torch

In [2]:
t = torch.randn(30,30)
print(t[1].sum())
print(t[1:2,:].sum())
print(t[1].shape)
print(t[1:2,:].shape)

tensor(2.4750)
tensor(2.4750)
torch.Size([30])
torch.Size([1, 30])


In [3]:
import torch

# Inputs
x = torch.randn(12,12,12)
b = torch.randn(12,12)
p = torch.rand((12,12))
B,T,C = x.shape

# Check what indices to keep
keep_ix = b > 0.5 # (B,T), bool
counts = keep_ix.sum(dim=1) # (B,) => How many items to keep per row, max will be padding
Tds = int(counts.max().item()) # Padding index

# Bringing keep_ix==True to the front to discard all other ix
perm = torch.argsort(keep_ix.to(torch.int8), dim=1, descending=True, stable=True)  # (B, T), with index we want to keep in front
sel = perm[:, :Tds]                       # (B, Tds), padded

# Dechunk mask for position & gather chunks
mask_ds = torch.arange(Tds, device=x.device).unsqueeze(0) < counts.unsqueeze(1) # (B, Tds), masking range above count
x_chunks = torch.take_along_dim(x, sel.unsqueeze(-1).expand(-1, -1, C), dim=1) # (B, Tds, E)
P_chunks = torch.take_along_dim(p, sel, dim=1) # (B, Tds)

# Zero-Out padding
x_chunks = x_chunks.masked_fill(~mask_ds.unsqueeze(-1), 0)
P_chunks = P_chunks.masked_fill(~mask_ds, 0)
gather_idx = sel.masked_fill(~mask_ds, 0)     

# State for dechunking
state = {
    "p_full": p, # (B, T)
    "b_full": b, # (B, T)
    "gather_idx": gather_idx, # (B, Tds)
    "mask_ds": mask_ds, # (B, Tds)
}


In [4]:
att_mask=None
N = 3
B,T = p.shape
att_mask = att_mask if att_mask else torch.ones((B,T))
L = att_mask.sum(dim=1).clamp_min(1).float() # (B,)
F = ((b > 0.5).float() * att_mask.float()).sum(dim=1) / L # ""
G = (p * att_mask.float()).sum(dim=1) / L   # ""
ratio = (N/(N-1)) * (((N - 1.0) * F * G) + ((1.0 - F) * (1.0 - G))) # ""
print(ratio.mean())

tensor(0.9839)


In [5]:
import torch
from hnet import DynChunking, HNetConfig

In [6]:
d = DynChunking(512,0.5,HNetConfig)
x = torch.randn((512,512,512))
p,bt=d.route(x)
print(f"bt: {bt.shape}, p: {p.shape}")
x_chunks, P_chunks, mask_ds, state = d.downsample(x,p,bt)
print(f"xchunks: {x_chunks.shape}, pchunks: {P_chunks.shape}, mask_ds: {mask_ds.shape}")

bt: torch.Size([512, 512]), p: torch.Size([512, 512])
xchunks: torch.Size([512, 289, 512]), pchunks: torch.Size([512, 289]), mask_ds: torch.Size([512, 289])


In [7]:
# EMA
z_t = x_chunks
pt = P_chunks 
zt_hat = z_t @ pt
zt_hat_inv = torch.roll(z_t, shifts=1, dims=1) @ (pt.new_ones(pt.shape) - pt) 
(zt_hat+zt_hat_inv).shape

torch.Size([512, 289, 289])

In [8]:
a = torch.arange(10)
b = torch.tensor([1])
torch.concat((b,a))

tensor([1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [9]:
# Upsampling (6)
ct = (p ** bt) * (p.new_zeros(p.shape) - p) ** bt
ct.shape

torch.Size([512, 512])

In [13]:
B,T,C = x.shape
pt, bt, gather_idx, mask_ds = state["p_full"], state["b_full"], state["gather_idx"], state["mask_ds"]
bt_int = (bt > 0.5).long() # (B,T)
idx = torch.cumsum(bt_int, dim=1) - 1 # (B,T), determine what index to expand
idx = idx.clamp(min=0, max=Tds - 1) # (B,T), cut to padding length
i_arange = torch.arange(B, device=x_chunks.device).unsqueeze(1).expand(B, T) # 
z_up = x_chunks[i_arange, idx] 

In [17]:
print(i_arange.shape)

torch.Size([512, 512])


In [None]:
a = torch.randn(32,32,32)
b = torch.randn(32,32,32)
c = torch.bmm(a,b)
d = a @ b
torch.all(c == d)

tensor(True)

In [None]:
S[0]

tensor([ -1.2715,  -3.2769,  -4.8626,  -8.4185,  -9.8799, -10.1357, -10.3183,
        -11.7566])

In [None]:
(-1.2715+3.2769)

2.0054

In [None]:
delta[0]

tensor([[  0.0000,   2.0053,   3.5911,   7.1469,   8.6083,   8.8642,   9.0468,
          10.4851],
        [ -2.0053,   0.0000,   1.5857,   5.1416,   6.6030,   6.8589,   7.0414,
           8.4798],
        [ -3.5911,  -1.5857,   0.0000,   3.5559,   5.0173,   5.2732,   5.4557,
           6.8940],
        [ -7.1469,  -5.1416,  -3.5559,   0.0000,   1.4614,   1.7173,   1.8998,
           3.3382],
        [ -8.6083,  -6.6030,  -5.0173,  -1.4614,   0.0000,   0.2559,   0.4384,
           1.8768],
        [ -8.8642,  -6.8589,  -5.2732,  -1.7173,  -0.2559,   0.0000,   0.1826,
           1.6209],
        [ -9.0468,  -7.0414,  -5.4557,  -1.8998,  -0.4384,  -0.1826,   0.0000,
           1.4383],
        [-10.4851,  -8.4798,  -6.8940,  -3.3382,  -1.8768,  -1.6209,  -1.4383,
           0.0000]])

In [23]:
zt = torch.randn(12,64,128)
ct = torch.randn(12,64)
(ct.unsqueeze(-1)*zt).shape

torch.Size([12, 64, 128])