In [1]:
# Load (or reload, if already loaded) IPython's autoreload extension and switch it
# to mode 2, so imported modules are re-imported before each cell runs and edits
# to them are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from htm_pytorch import HTMAttention

# Smoke test of the attention layer: one query token against 128 memory tokens
# (128 tokens / 16 tokens per chunk = 8 chunks, so topk_mems = 8 covers all of them).
attn = HTMAttention(
    dim = 512,
    heads = 8,               # number of heads for within-memory attention
    dim_head = 64,           # dimension per head for within-memory attention
    topk_mems = 8,           # how many memory chunks to select
    mem_chunk_size = 16,     # number of tokens in each memory chunk
    add_pos_enc = True       # whether to add positional encoding to the memories
)

queries = torch.randn(1, 1, 512)                 # queries
memories = torch.randn(1, 128, 512)              # memories, of any size
mask = torch.ones(1, 128, dtype = torch.bool)    # memory mask, all True

# NOTE(review): the result is expected to follow the shape of `queries`, i.e.
# (1, 1, 512) — TODO confirm; this cell has no recorded output.
attended = attn(queries, memories, mask = mask)

In [4]:
from htm_pytorch import HTMBlock

# Same idea with the full block and a much longer memory (20000 tokens).
block = HTMBlock(dim = 512, heads = 8, topk_mems = 8, mem_chunk_size = 32)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000, dtype = torch.bool)   # all-True memory mask

out = block(queries, memories, mask = mask)       # expected (1, 128, 512), same as queries

In [8]:
from htm_pytorch import HTMBlock

# Batched variant: 16 samples, each with a single query token and 128 memory tokens.
block = HTMBlock(dim = 384, heads = 8, topk_mems = 8, mem_chunk_size = 16)

queries = torch.randn(16, 1, 384)
memories = torch.randn(16, 128, 384)
mask = torch.ones(16, 128, dtype = torch.bool)    # all-True memory mask

out = block(queries, memories, mask = mask)
print(out.shape)    # recorded output: torch.Size([16, 1, 384]), the shape of `queries`

torch.Size([16, 1, 384])


In [None]:
from htm_pytorch import HTMBlockReLU

# HTMBlockReLU driven with the same arguments and shapes as the HTMBlock cell
# that uses 20000 memory tokens.
blockrelu = HTMBlockReLU(dim = 512, heads = 8, topk_mems = 8, mem_chunk_size = 32)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000, dtype = torch.bool)     # all-True memory mask

out = blockrelu(queries, memories, mask = mask)     # expected (1, 128, 512), same as queries

In [None]:
from htm_pytorch import HTMBlockReLU

# Batched HTMBlockReLU run: 16 samples, one query token each, 128 memory tokens.
blockrelu = HTMBlockReLU(dim = 384, heads = 8, topk_mems = 8, mem_chunk_size = 16)

queries = torch.randn(16, 1, 384)
memories = torch.randn(16, 128, 384)
mask = torch.ones(16, 128, dtype = torch.bool)    # built, but not passed to the call below

# This call omits the mask. The masked form, kept for reference, would be:
#   out = blockrelu(queries, memories, mask = mask)
# NOTE(review): the output is expected to be (16, 1, 384), matching `queries`
# — TODO confirm; this cell has no recorded output.
out = blockrelu(queries, memories)

In [None]:
# Shape of `out` from whichever cell assigned it last (notebook-global state).
print(out.shape)

In [6]:
import torch
from transformer_htm import HTMTransformerBlock, HTMTransformer

# Settings handed to a single HTM transformer block.
config = {
    "num_blocks": 1,
    "embed_dim": 384,
    "num_heads": 8,
    "layer_norm": "pre",
    "identity_map_reordering": True,
    "topk_mems": 8,
    "mem_chunk_size": 16,
}

htmtransformerblock = HTMTransformerBlock(
    embed_dim=config["embed_dim"],
    num_heads=config["num_heads"],
    config=config,
)

queries = torch.randn(1, 1, 384)
memories = torch.randn(1, 128, 384)
mask = torch.ones(1, 128, dtype=torch.bool)    # all-True memory mask

# The call returns a pair, unpacked here as the block output and attention weights.
out, attn_weights = htmtransformerblock(queries, memories, mask)

In [11]:
import torch
from transformer_htm import HTMTransformerBlock, HTMTransformer

input_dim = 384
num_blocks = 1
num_workers = 1
mem_length = 128

# topk_mems must always be <= max episode length // chunk size
config = {
    "num_blocks": num_blocks,
    "embed_dim": input_dim,
    "num_heads": 8,
    "layer_norm": "pre",
    "positional_encoding": "",
    "identity_map_reordering": True,
    "topk_mems": 8,
    "mem_chunk_size": 16,
}

htmtransformer = HTMTransformer(
    input_dim=config["embed_dim"],
    max_episode_steps=256,
    config=config,
)

queries = torch.randn(num_workers, input_dim)  # flattened input
memories = torch.randn(num_workers, mem_length, num_blocks, input_dim)
mask = torch.ones(num_workers, mem_length, dtype=torch.bool)  # all-True memory mask
# One random permutation of 0..mem_length-1, repeated on every worker row.
memory_indices = torch.randperm(mem_length, dtype=torch.long).repeat(num_workers, 1)

h, out_memories = htmtransformer(queries, memories, mask, memory_indices)
# NOTE(review): the returned out_memories prints as (num_workers, num_blocks, input_dim)
# in the next cell, which is not the 4-D layout of `memories` above, so feeding it
# straight back in as `memories` would presumably need reshaping first — TODO confirm.

In [12]:
# Shapes of the transformer output and of the returned memories, one per line.
print(h.shape, out_memories.shape, sep="\n")

torch.Size([1, 384])
torch.Size([1, 1, 384])


In [None]:
import torch
from transformer import Transformer

# Configuration for the Transformer from the `transformer` module, with four blocks.
config = {
    "num_blocks": 4,
    "embed_dim": 512,
    "num_heads": 8,
    "layer_norm": "pre",
    "positional_encoding": "relative",
}

transformer = Transformer(
    input_dim=512,
    max_episode_steps=256,
    config=config,
)

queries = torch.randn(1, 512)  # flattened input
# NOTE(review): memories previously had shape (1, 4, 512), which has no
# memory-length axis even though `mask` and `memory_indices` below describe 256
# slots. It now uses the (workers, mem_length, num_blocks, embed_dim) layout of
# the HTMTransformer cells; this assumes Transformer expects the same layout —
# TODO confirm.
memories = torch.randn(1, 256, config["num_blocks"], config["embed_dim"])
mask = torch.ones(1, 256, dtype=torch.bool)  # all-True memory mask
memory_indices = torch.randperm(256, dtype=torch.long).unsqueeze(0)  # shape (1, 256)

h, out_memories = transformer(queries, memories, mask, memory_indices)

In [None]:
# Shapes of the transformer output and of the returned memories, one per line.
print(h.shape, out_memories.shape, sep="\n")

In [None]:
import torch

# Indexing check: taking index 1 on the third axis of a 4-D tensor drops that axis.
memories = torch.randn((16, 32, 2, 384))
print(memories.shape)             # torch.Size([16, 32, 2, 384])
print(memories[:, :, 1].shape)    # torch.Size([16, 32, 384])


In [52]:
import torch
from transformer_htm import HTMTransformerBlock, HTMTransformer
from stable_baselines3.common.utils import set_random_seed
# Seed the random generators first so everything drawn in this cell is repeatable.
set_random_seed(1)
# topk_mems must always be <= max episode length // chunk size
config = {
    "num_blocks":1, 
    "embed_dim": 384, 
    "num_heads":8, 
    "layer_norm":"pre", 
    "identity_map_reordering":True, 
    "topk_mems":8, 
    "mem_chunk_size":16,
    "positional_encoding":"learned"
    }

# NOTE(review): this instance is replaced when the following cell rebuilds the model
# from the same config. It is left in place because removing it could change the
# random state seen by the torch.randn / torch.randperm calls below — TODO confirm.
htmtransformer = HTMTransformer(
    input_dim = 384,
    max_episode_steps = 256,
    config = config
)

queries = torch.randn(16, 384) # flattened input
# Third axis has size 1, in line with "num_blocks":1 in the config above.
memories = torch.randn(16, 128, 1, 384)
# All-True memory mask.
mask = torch.ones(16, 128).bool()
# One random permutation of 0..127, repeated on each of the 16 rows.
memory_indices = torch.randperm(128, dtype=torch.long).repeat(16,1)

In [53]:
# Re-seed and rebuild the model so its construction starts from a known random state.
set_random_seed(1)
htmtransformer = HTMTransformer(
    input_dim = 384,
    max_episode_steps = 256,
    config = config
)

# Forward pass on the inputs (queries, memories, mask, memory_indices) that are
# created in the setup cell and read here as notebook globals.
h, out_memories = htmtransformer(queries, memories, mask, memory_indices)

print(h.shape)             # recorded output: torch.Size([16, 384])
print(out_memories.shape)  # recorded output: torch.Size([16, 1, 384])
print(h)

torch.Size([16, 384])
torch.Size([16, 1, 384])
tensor([[-0.6750,  0.4172,  0.0656,  ..., -0.4660,  1.6029, -0.4685],
        [ 1.7512,  0.1096, -0.2575,  ..., -0.6927,  0.2866, -0.5759],
        [ 0.5399,  0.5540, -0.5765,  ..., -0.4698,  0.1769, -0.5949],
        ...,
        [-0.6666, -0.1259,  1.2321,  ..., -0.6991,  1.1065, -0.5994],
        [-0.6462, -0.2878,  0.9402,  ...,  1.9376, -0.6890, -0.5007],
        [-0.6855,  2.9800, -0.5774,  ..., -0.6631, -0.6974, -0.5417]],
       grad_fn=<SqueezeBackward0>)


In [51]:
# NOTE(review): this cell repeats the previous cell's code, and its recorded output
# matches that cell's value for value, so it presumably serves as a check that
# re-seeding makes the run repeatable. Its execution count (51) is lower than those
# of the two cells above (52, 53), i.e. it was last run before they were; re-run the
# notebook top to bottom before relying on the comparison.
set_random_seed(1)
htmtransformer = HTMTransformer(
    input_dim = 384,
    max_episode_steps = 256,
    config = config
)

# Forward pass on the notebook-global inputs (queries, memories, mask, memory_indices).
h, out_memories = htmtransformer(queries, memories, mask, memory_indices)

print(h.shape)             # recorded output: torch.Size([16, 384])
print(out_memories.shape)  # recorded output: torch.Size([16, 1, 384])
print(h)

torch.Size([16, 384])
torch.Size([16, 1, 384])
tensor([[-0.6750,  0.4172,  0.0656,  ..., -0.4660,  1.6029, -0.4685],
        [ 1.7512,  0.1096, -0.2575,  ..., -0.6927,  0.2866, -0.5759],
        [ 0.5399,  0.5540, -0.5765,  ..., -0.4698,  0.1769, -0.5949],
        ...,
        [-0.6666, -0.1259,  1.2321,  ..., -0.6991,  1.1065, -0.5994],
        [-0.6462, -0.2878,  0.9402,  ...,  1.9376, -0.6890, -0.5007],
        [-0.6855,  2.9800, -0.5774,  ..., -0.6631, -0.6974, -0.5417]],
       grad_fn=<SqueezeBackward0>)
