In [1]:
import os
import sys

# Make the repository root (the parent of the current working directory)
# importable so the local `brainle` package can be imported below.
# It is appended, so a copy already earlier on sys.path would still take precedence.
PROJECT_ROOT = os.path.abspath('..')
# Guard against piling up duplicate entries when this cell is re-run in the same kernel.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat, reduce

def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count the scalar parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are summed.
        trainable_only: If True (the default, matching the previous behaviour),
            count only parameters with ``requires_grad=True``; if False, frozen
            parameters are included as well.

    Returns:
        Total number of elements across the selected parameter tensors.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [3]:
from brainle.models.architectures.quantizer import QuantizerBase, MQBlock

# Run a bare QuantizerBase on 2 random 8-d vectors against 1000 random 8-d vectors.
# The recorded one-hot output has shape (2, 1000): one row per query vector,
# one column per candidate vector.
quantizer_base = QuantizerBase()
queries = torch.rand(2, 8)
values = torch.rand(1000, 8)
base_out = quantizer_base(queries, values)
base_out['onehot'].shape

torch.Size([2, 1000])

In [4]:
# Seed first so the random input (and any random initialisation inside MQBlock,
# presumably) are the same on every re-run; the printed perplexity depends on them.
torch.manual_seed(0)

mq_block = MQBlock(
    features=64,
    memory_size=1000,
)
mq_block.train()

block_in = torch.rand(1, 100, 64)  # [b, n, c]
block_out = mq_block(block_in)

print(block_out.keys())
print(
    block_out['embedding'].shape,
    block_out['indices'].shape,
    block_out['onehot'].shape,
    block_out['perplexity'],
)

dict_keys(['embedding', 'indices', 'onehot', 'perplexity'])
torch.Size([1, 100, 64]) torch.Size([1, 100]) torch.Size([1, 100, 1000]) tensor(5.4572)


In [7]:
from brainle.models.architectures.quantizer import MultiMQ

# Seed so the random input (and any random init inside MultiMQ) is reproducible;
# the printed per-quantizer perplexities depend on it.
torch.manual_seed(0)

# Stand-in for an image encoder output: [b, c, h, w]
encoder_out = torch.rand([1, 128, 2, 2])
# Flatten the spatial grid into a token axis for the quantizer: [b, h*w, c]
tokens = rearrange(encoder_out, 'b c h w -> b (h w) c')

channels_list = [32, 32, 32, 32]
# NOTE(review): presumably MultiMQ splits the last (channel) dim across one
# sub-quantizer per entry; the 4 perplexities and 16 (= 4 tokens x 4) indices in
# the recorded output are consistent with that. Keep the split covering all channels.
assert sum(channels_list) == tokens.shape[-1]

multi_quantizer = MultiMQ(
    channels_list=channels_list,
    memory_size=1024,
)

print(tokens.shape)
quantized = multi_quantizer(tokens)
print(
    quantized['embedding'].shape,
    quantized['indices'].shape,
    quantized['onehot'].shape,
    quantized['perplexity'],
)

torch.Size([1, 4, 128])
torch.Size([1, 4, 128]) torch.Size([1, 16]) torch.Size([1, 16, 1024]) tensor([2.0000, 4.0000, 1.0000, 1.7548])


In [14]:
# Scratch check of dict.update ordering: the new integer keys are appended after
# the existing 'a' key in insertion order (see the printed dict below).
# NOTE(review): unrelated to the quantizer demo above — consider moving or removing.
x = {'a': 1}
x.update(dict.fromkeys(range(5), 2))

print(x)

{'a': 1, 0: 2, 1: 2, 2: 2, 3: 2, 4: 2}
