In [4]:
import torch 
import torch.nn as nn 


In [5]:
# Module-level memo used by `_add_coordinate_encoding`: maps a shape/device
# string key to a precomputed coordinate tensor so it is not rebuilt per call.
coordinate_cache = dict()

In [6]:
def _add_coordinate_encoding(x):
    """Append a [-1, 1] coordinate channel to a (B, C, T) tensor.

    A ramp of ``T`` evenly spaced values from -1 to 1 is concatenated along the
    channel dimension. The ramp is memoized in the module-level
    ``coordinate_cache`` dict under the key ``f"{T}_{device}"`` as a single
    ``(1, T)`` row that is broadcast to the batch on every call, so one cache
    entry serves every batch size.

    Args:
        x: Tensor of shape ``(B, C, T)``.

    Returns:
        Tensor of shape ``(B, C + 1, T)`` with the same dtype and device as
        ``x``; the last channel (index ``C``) holds the coordinate ramp.

    Raises:
        ValueError: if ``x`` is not 3-D (raised by the shape unpacking).
    """
    b, _, t = x.shape
    cache_key = f"{t}_{x.device}"
    coords_row = coordinate_cache.get(cache_key)
    if coords_row is None:
        # Keep the cached entry batch-independent: caching a tensor expanded to
        # the first caller's B breaks any later call with the same T but other B.
        coords_row = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0)
        coordinate_cache[cache_key] = coords_row

    # (1, T) -> (B, 1, T) as a stride-0 view. Cast to x.dtype so torch.cat does
    # not promote the result (e.g. float16 input would otherwise become float32).
    expanded_coords = coords_row.to(dtype=x.dtype).unsqueeze(1).expand(b, -1, -1)
    return torch.cat((x, expanded_coords), dim=1)
    

In [7]:
x = torch.randn(256, 3, 10)
x_with_coords = _add_coordinate_encoding(x)
print(f"Input shape: {x.shape}")
print(f"With Coordinate Encoding shape: {x_with_coords.shape}")

print(f"Coordinate Cache Size: {len(coordinate_cache)}")
print(f"Coordinate Cache Keys: {list(coordinate_cache.keys())}")

# Build the lookup key the same way the encoder does ("{T}_{device}") instead of
# hard-coding "10_cpu", so the lookup follows any change to x's length or device.
cache_key = f"{x.shape[-1]}_{x.device}"
print(f"coordinate cache: {coordinate_cache[cache_key][0, ]}")

# The appended coordinate channel is the LAST channel (index x.shape[1] == 3),
# not the first input channel.
print(f"x_with_coords: {x_with_coords[0, -1, :]}")
assert torch.allclose(x_with_coords[0, -1], torch.linspace(-1, 1, x.shape[-1]))

Input shape: torch.Size([256, 3, 10])
With Coordinate Encoding shape: torch.Size([256, 4, 10])
Coordinate Cache Size: 1
Coordinate Cache Keys: ['10_cpu']
coordinate cache: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])
x_with_coords: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])


In [11]:
coordinate_cache = {}


def _add_coordinate_encoding(x):
    """Append a [-1, 1] coordinate channel to a (B, C, T) tensor.

    The coordinate ramp (``T`` evenly spaced values from -1 to 1) is built once
    per ``(B, T, device)`` and memoized in ``coordinate_cache`` under the key
    ``f"{B}_{T}_{device}"`` as a ``(B, 1, T)`` tensor. That tensor is an
    expanded view of a single row, so no per-sample copy is stored.

    Args:
        x: Tensor of shape ``(B, C, T)``.

    Returns:
        Tensor of shape ``(B, C + 1, T)``; the last channel is the ramp. The
        ramp is cast to ``x.dtype`` so the result keeps the input dtype rather
        than being promoted by ``torch.cat`` (e.g. float16 -> float32).

    Raises:
        ValueError: if ``x`` is not 3-D (raised by the shape unpacking).
    """
    b, _, t = x.shape
    # The batch size is part of the key, so every distinct B adds one entry.
    cache_key = f"{b}_{t}_{x.device}"
    expanded_coords = coordinate_cache.get(cache_key)
    if expanded_coords is None:
        coords_row = torch.linspace(start=-1, end=1, steps=t, device=x.device)
        # (T,) -> (1, 1, T) -> (B, 1, T) with a single stride-0 expand.
        expanded_coords = coords_row.view(1, 1, t).expand(b, -1, -1)
        coordinate_cache[cache_key] = expanded_coords

    return torch.cat((x, expanded_coords.to(dtype=x.dtype)), dim=1)
    

In [14]:
x = torch.randn(64, 3, 10)
x_with_coords = _add_coordinate_encoding(x)
print(f"Input shape: {x.shape}")
print(f"With Coordinate Encoding shape: {x_with_coords.shape}")

# NOTE: coordinate_cache persists across cells, so these lines also report
# entries created by earlier calls in the same kernel session.
print(f"Coordinate Cache Size: {len(coordinate_cache)}")
print(f"Coordinate Cache Keys: {list(coordinate_cache.keys())}")

# Build the lookup key the same way the encoder does ("{B}_{T}_{device}")
# instead of hard-coding "64_10_cpu".
cache_key = f"{x.shape[0]}_{x.shape[-1]}_{x.device}"
print(f"coordinate cache: {coordinate_cache[cache_key][0, ]}")

# The appended coordinate channel is the LAST channel (index x.shape[1] == 3),
# not the first input channel.
print(f"x_with_coords: {x_with_coords[0, -1, :]}")
assert torch.allclose(x_with_coords[0, -1], torch.linspace(-1, 1, x.shape[-1]))

Input shape: torch.Size([64, 3, 10])
With Coordinate Encoding shape: torch.Size([64, 4, 10])
Coordinate Cache Size: 2
Coordinate Cache Keys: ['256_10_cpu', '64_10_cpu']
coordinate cache: tensor([[-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
          0.7778,  1.0000]])
x_with_coords: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])


## ConvNN 

In [None]:
"""
ConvNN
Total parameters: 185,150
Trainable parameters: 185,150

Input shape: torch.Size([64, 197, 64])
After Permute: torch.Size([64, 64, 197])
After Split_head: torch.Size([64, 4, 197, 16])
After Batch_Combine: torch.Size([256, 16, 197]) ### [256, 17, 197] added coordinate encoding
After Conv1d: torch.Size([256, 16, 197]) ###    [256, 17, 197] added coordinate encoding
After batch_split: torch.Size([64, 4, 197, 16]) ### [64, 4, 197, 17] added coordinate encoding
After Combine_Heads: torch.Size([64, 64, 197]) ### [64, 68, 197] added coordinate encoding

Output shape: torch.Size([64, 100])
"""

''

## ConvNNAttention

In [None]:
"""
ConvNNAttention
Total parameters: 83,716
Trainable parameters: 83,716

Input shape: torch.Size([64, 197, 64])
After Split_head: torch.Size([64, 4, 197, 16])
After Batch_Combine: torch.Size([256, 16, 197])
After Conv1d: torch.Size([256, 16, 197])
After permute: torch.Size([256, 197, 16])
After Batch_Split: torch.Size([64, 4, 197, 16])
After Combine_Heads: torch.Size([64, 197, 64])

Output shape: torch.Size([64, 100])

"""

╰─$ python -u ~/Developer/Convolutional-Nearest-Neighbor/vit.py
Regular Attention
Total parameters: 70,228
Trainable parameters: 70,228
Output shape: torch.Size([64, 100])

ConvNN
Total parameters: 218,546
Trainable parameters: 218,546
Output shape: torch.Size([64, 100])

ConvNNAttention
Total parameters: 71,930
Trainable parameters: 71,930
Output shape: torch.Size([64, 100])

╰─$ python -u ~/Developer/Convolutional-Nearest-Neighbor/vit.py
Regular Attention
Total parameters: 70,228
Trainable parameters: 70,228
Output shape: torch.Size([64, 100])

ConvNN
Total parameters: 218,295
Trainable parameters: 218,295
Output shape: torch.Size([64, 100])

ConvNNAttention
Total parameters: 71,679
Trainable parameters: 71,679
Output shape: torch.Size([64, 100])