In [None]:
!pip install -r requirements.txt

Collecting vector-quantize-pytorch>=1.6.28 (from -r requirements.txt (line 2))
  Downloading vector_quantize_pytorch-1.18.5-py3-none-any.whl.metadata (28 kB)
Collecting einx>=0.3.0 (from vector-quantize-pytorch>=1.6.28->-r requirements.txt (line 2))
  Downloading einx-0.3.0-py3-none-any.whl.metadata (6.9 kB)
Downloading vector_quantize_pytorch-1.18.5-py3-none-any.whl (41 kB)
Downloading einx-0.3.0-py3-none-any.whl (102 kB)
Installing collected packages: einx, vector-quantize-pytorch
Successfully installed einx-0.3.0 vector-quantize-pytorch-1.18.5


In [None]:
import torch
from discrete_kv_bottleneck import DiscreteKeyValueBottleneck, DiscreteKeyValueBottleneckConfig

In [None]:
config = DiscreteKeyValueBottleneckConfig(
    dim=128,                    # input dimension
    num_memories=256,           # number of memories (the output dimension is dim_memory, which defaults to the input dimension)
    num_memory_codebooks=4,     # number of memory codebooks
    average_pool_memories=True  # average the memories retrieved from the codebooks
)

kvbottleneck = DiscreteKeyValueBottleneck(config)
input_tensor = torch.randn(2, 64, 128)  # example input: (batch, sequence length, dim)
memories = kvbottleneck(input_tensor)


In [None]:
!pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-1.8.5-py3-none-any.whl.metadata (68 kB)
Downloading vit_pytorch-1.8.5-py3-none-any.whl (133 kB)
Installing collected packages: vit-pytorch
Successfully installed vit-pytorch-1.8.5


In [None]:
from vit_pytorch import SimpleViT
from vit_pytorch.extractor import Extractor
import torch
from discrete_kv_bottleneck import DiscreteKeyValueBottleneck, DiscreteKeyValueBottleneckConfig

# Initialize SimpleViT
vit = SimpleViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=2048
)

# Train vit, or load pretrained weights
# Assuming vit is pretrained, extract only embeddings
vit = Extractor(vit, return_embeddings_only=True)

# Configure the DiscreteKeyValueBottleneck
config = DiscreteKeyValueBottleneckConfig(
    encoder=vit,         # pass the frozen encoder into the bottleneck
    dim=512,             # input dimension
    num_memories=256,    # number of memories
    dim_memory=2048,     # dimension of the output memories
    average_pool_memories=True
)

# Initialize the bottleneck module
enc_with_bottleneck = DiscreteKeyValueBottleneck(config, decay=0.9)

# Example input images
images = torch.randn(1, 3, 256, 256)  # input to encoder

# Process the images through the encoder with bottleneck
memories = enc_with_bottleneck(images)  # Output: (1, 64, 2048)

print(memories.shape)  # Should print: torch.Size([1, 64, 2048])

torch.Size([1, 64, 2048])


Understanding the output:
- Input images: the input tensor has shape (1, 3, 256, 256), a batch of one RGB image of 256x256 pixels.
- SimpleViT embeddings: SimpleViT splits the image into (256 / 32)^2 = 64 patches and embeds each one as a 512-dimensional vector. Wrapping the model in Extractor with return_embeddings_only=True makes it return these per-patch embeddings, of shape (1, 64, 512), instead of class logits, and the bottleneck uses the wrapped model as a frozen encoder.
- Discrete Key-Value Bottleneck output: the bottleneck maps each patch embedding to a memory vector, giving an output of shape (1, 64, 2048). Here 64 is the number of input patches and 2048 is the memory dimension set by dim_memory. It is larger than the 512-dimensional input, so the bottleneck comes from the discrete lookup, not from a smaller vector size.

Reading the shape (1, 64, 2048) axis by axis:
- Batch size: 1. One input image was processed.
- Number of patches: 64. With image_size=256 and patch_size=32, the ViT splits the image into an 8x8 grid of patches.
- Memory dimension: 2048. The Discrete Key-Value Bottleneck turns each patch into a memory vector of this size.
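The shape bookkeeping above can be checked with a few lines of arithmetic. The helper below is a hypothetical illustration (`expected_shapes` is not part of vit-pytorch or discrete_kv_bottleneck). It assumes square images, non-overlapping patches and no extra class token, which matches the (1, 64, 512) feature map reported above.

```python
def expected_shapes(batch, image_size, patch_size, dim, dim_memory):
    """Return (encoder_output_shape, bottleneck_output_shape) for square images."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 256 // 32 = 8
    num_patches = patches_per_side ** 2          # 8 * 8 = 64
    return (batch, num_patches, dim), (batch, num_patches, dim_memory)


# Values taken from the SimpleViT example in this notebook.
enc_shape, mem_shape = expected_shapes(
    batch=1, image_size=256, patch_size=32, dim=512, dim_memory=2048
)
print(enc_shape)  # (1, 64, 512)
print(mem_shape)  # (1, 64, 2048)
```

Changing patch_size to 16 would give 256 patches per image, so the second axis of the output tracks the patch grid, not any bottleneck setting.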

### Interpretation: Analysis of Results
Objective:
The aim was to place a discrete key-value bottleneck on top of a Vision Transformer (ViT) encoder, so that the encoder's continuous patch embeddings are converted into discrete memory representations. Representations of this kind can make downstream processing and storage more efficient.

Key Observations:
Effective Feature Compression:

SimpleViT splits the input image into 64 patches and represents each patch as a 512-dimensional embedding. The Discrete Key-Value Bottleneck then replaces each embedding with a memory vector of size 2048, selected through a discrete key lookup.
The increase in dimension might seem counterintuitive for a bottleneck. The compression lies in the discrete selection: each patch can only pick from the num_memories=256 entries of a codebook, so the information passed on is limited by that choice, while the retrieved memory vector is free to be wider than the input. This structure may help feature representation and retrieval.
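To make the lookup concrete, here is a minimal sketch of a single-codebook key-value lookup in plain PyTorch. It is not the library's implementation: `keys` and `values` are hypothetical random stand-ins for the bottleneck's parameters, and nearest-key search by Euclidean distance is an assumption about how keys are matched.

```python
import torch

torch.manual_seed(0)

dim, num_memories, dim_memory = 512, 256, 2048

# Hypothetical stand-ins for the bottleneck's parameters (one codebook only).
keys = torch.randn(num_memories, dim)           # keys the embeddings are matched against
values = torch.randn(num_memories, dim_memory)  # one memory vector stored per key

embeddings = torch.randn(1, 64, dim)            # stand-in for the encoder output

# Nearest-key search: distance from every patch embedding to every key.
distances = torch.cdist(embeddings, keys.unsqueeze(0))  # (1, 64, 256)
indices = distances.argmin(dim=-1)                       # (1, 64), integers in [0, 256)

# Lookup: each patch is replaced by the value stored under its selected key.
memories = values[indices]                               # (1, 64, 2048)

print(indices.shape, memories.shape)  # torch.Size([1, 64]) torch.Size([1, 64, 2048])
```

Only `indices` depends on the input, and each entry is one of 256 choices, which is where the bottleneck sits. The install log earlier in the notebook shows the package pulling in vector-quantize-pytorch, so the real key matching is presumably handled by that library.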

Average Pooling:

With average_pool_memories=True, the memories retrieved for a patch from the different codebooks are averaged into a single vector, which aggregates information and can reduce noise. The pooling does not run across patches, which is why the output keeps all 64 patch positions.
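The pooling step can be sketched as follows. The sizes mirror the first example in this notebook (4 codebooks, 256 memories, dim 128, with dim_memory assumed to default to 128). The `values` table and the `indices` are hypothetical random stand-ins, and pooling over the codebook axis is an assumption about what average_pool_memories does.

```python
import torch

torch.manual_seed(0)

batch, num_patches, dim_memory = 2, 64, 128
num_codebooks, num_memories = 4, 256

# Hypothetical per-codebook value tables and already-selected key indices.
values = torch.randn(num_codebooks, num_memories, dim_memory)
indices = torch.randint(0, num_memories, (batch, num_codebooks, num_patches))

# Gather one memory per (codebook, patch) pair.
codebook_ids = torch.arange(num_codebooks).view(1, num_codebooks, 1)
per_codebook = values[codebook_ids, indices]  # (batch, codebooks, patches, dim_memory)

# Average over the codebook axis only; the patch axis is untouched.
pooled = per_codebook.mean(dim=1)             # (batch, patches, dim_memory)

print(per_codebook.shape)  # torch.Size([2, 4, 64, 128])
print(pooled.shape)        # torch.Size([2, 64, 128])
```

Without pooling, a downstream module would have to handle the extra codebook axis itself, for example by concatenating the per-codebook memories.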

Memory and Efficiency:

The bottleneck quantizes the input features into a finite set of discrete memories. This can benefit downstream tasks such as classification or image generation, where discrete representations may generalize better.
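As an illustration of the classification case, the sketch below puts a small head on top of the memories. `MemoryClassifier` is a hypothetical module, not part of discrete_kv_bottleneck, and mean-pooling over patches followed by one linear layer is just one simple choice of decoder. The random tensor stands in for the output of enc_with_bottleneck(images).

```python
import torch
from torch import nn


class MemoryClassifier(nn.Module):
    """Hypothetical decoder: mean-pool the patch memories, then apply a linear layer."""

    def __init__(self, dim_memory, num_classes):
        super().__init__()
        self.head = nn.Linear(dim_memory, num_classes)

    def forward(self, memories):
        # memories: (batch, patches, dim_memory)
        pooled = memories.mean(dim=1)  # (batch, dim_memory)
        return self.head(pooled)       # (batch, num_classes)


memories = torch.randn(1, 64, 2048)  # stand-in for enc_with_bottleneck(images)
classifier = MemoryClassifier(dim_memory=2048, num_classes=10)
logits = classifier(memories)
print(logits.shape)  # torch.Size([1, 10])
```

The number of classes (10) is arbitrary here; only dim_memory has to match the bottleneck's output.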

Potential Applications:
- Efficient Storage and Retrieval: By transforming continuous features into discrete memory tokens, this setup can be used in scenarios that require efficient storage and retrieval of information, such as language models or image generation tasks.
- Robust Representations: The process can also aid in creating more robust representations that are less prone to noise, improving model performance in tasks like classification.
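A rough back-of-the-envelope calculation shows what the storage argument could look like. It assumes that only the selected key indices are stored per token, that the value tables are kept once and shared, and that continuous embeddings are float32. The function name is hypothetical, and the numbers use the settings of the first example (dim=128, num_memories=256, num_memory_codebooks=4).

```python
import math


def bits_per_token(dim, num_memories, num_codebooks, float_bits=32):
    """Compare storing a raw embedding with storing only its key indices."""
    continuous = dim * float_bits                                   # one float per dimension
    discrete = num_codebooks * math.ceil(math.log2(num_memories))   # one index per codebook
    return continuous, discrete


continuous, discrete = bits_per_token(dim=128, num_memories=256, num_codebooks=4)
print(continuous)  # 4096
print(discrete)    # 32
```

Under these assumptions each token needs 32 bits instead of 4096, at the cost of keeping the value tables around to turn indices back into memories.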