In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
from htm_pytorch import HTMAttention

attn = HTMAttention(
    dim = 512,
    heads = 8,               # number of heads for within-memory attention
    dim_head = 64,           # dimension per head for within-memory attention
    topk_mems = 8,           # how many memory chunks to select for
    mem_chunk_size = 32,     # number of tokens in each memory chunk
    add_pos_enc = True       # whether to add positional encoding to the memories
)

queries = torch.randn(1, 128, 512)     # queries
memories = torch.randn(1, 20000, 512)  # memories, of any size
mask = torch.ones(1, 20000).bool()     # memory mask

attended = attn(queries, memories, mask = mask) # (1, 128, 512)

In [None]:
from htm_pytorch import HTMBlock

block = HTMBlock(
    dim = 512,
    topk_mems = 8,
    mem_chunk_size = 32,
    heads = 8
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

out = block(queries, memories, mask = mask) # (1, 128, 512)

In [None]:
from htm_pytorch import HTMBlockReLU

blockrelu = HTMBlockReLU(
    dim = 512,
    topk_mems = 8,
    mem_chunk_size = 32,
    heads = 8
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

out = blockrelu(queries, memories, mask = mask) # (1, 128, 512)

In [None]:
from htm_pytorch import HTMBlockReLU

blockrelu = HTMBlockReLU(
    dim = 512,
    topk_mems = 8,
    mem_chunk_size = 16,
    heads = 8
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 128, 512)
mask = torch.ones(1, 128).bool()

out = blockrelu(queries, memories, mask = mask) # (1, 128, 512)

In [None]:
import torch
from transformer_htm import HTMTransformerBlock, HTMTransformer

config = {
    "num_blocks":4, 
    "embed_dim": 512, 
    "num_heads":8, 
    "layer_norm":"pre", 
    "identity_map_reordering":True, 
    "topk_mems":8, 
    "mem_chunk_size":16
    }

htmtransformerblock = HTMTransformerBlock(
    embed_dim=config["embed_dim"],
    num_heads=config["num_heads"],  
    config=config
)

config = {
    "num_blocks":4, 
    "embed_dim": 512, 
    "num_heads":8, 
    "layer_norm":"pre", 
    "identity_map_reordering":False, 
    "topk_mems":8, 
    "mem_chunk_size":16
    }

htmtransformerblock_norelu = HTMTransformerBlock(
    embed_dim=config["embed_dim"],
    num_heads=config["num_heads"],  
    config=config
)

queries = torch.randn(1, 128, 512)
values = queries
keys = queries
mask = torch.ones(1, 128).bool()

out, attn_weights = htmtransformerblock(values, keys, queries, mask)
out2, attn_weights2 = htmtransformerblock_norelu(values, keys, queries, mask)

In [62]:
import torch
from transformer_htm import HTMTransformerBlock, HTMTransformer

config = {
    "num_blocks":4, 
    "embed_dim": 512, 
    "num_heads":8, 
    "layer_norm":"pre", 
    "identity_map_reordering":True, 
    "topk_mems":8, 
    "mem_chunk_size":16,
    "positional_encoding":"relative"
    }

htmtransformer = HTMTransformer(
    input_dim = 512,
    max_episode_steps = 128,
    config = config
)

queries = torch.randn(1, 512) # flattened input
memories = torch.randn(1, 4, 512)
mask = torch.ones(1, 128).bool()
memory_indices = torch.randperm(128, dtype=torch.long).unsqueeze(0)

h, out_memories = htmtransformer(queries, memories, mask, memory_indices)

In [63]:
print(h.shape)
print(out_memories.shape)

torch.Size([1, 512])
torch.Size([1, 4, 512])


In [3]:
print(h)
print(out_memories)

tensor([[-7.6566e-01,  6.4929e-02, -5.2239e-01, -6.3507e-01, -9.6712e-02,
         -1.9467e-01, -7.2196e-03,  2.4872e+00, -9.9464e-01,  4.6380e-01,
         -7.4534e-01, -6.4544e-01, -4.4616e-01, -2.8897e-01, -7.3916e-01,
          1.1044e+00,  1.5525e+00, -3.1247e-01, -9.2064e-01,  1.0516e+00,
         -2.3508e-01,  4.5104e-01, -1.0103e+00, -2.1380e-04, -3.5806e-01,
          3.0700e-01, -8.2619e-01, -7.1348e-01,  2.1025e+00,  2.9316e-01,
         -1.1990e-01,  7.6287e-01, -6.8626e-01,  7.0067e-01, -9.9076e-01,
         -1.3339e+00, -5.4188e-01,  5.0785e-01, -7.9272e-01,  1.7046e+00,
         -4.2457e-01, -3.9765e-01, -7.8544e-01,  5.2377e-01, -1.1061e+00,
         -3.8341e-01, -3.4125e-01, -3.9336e-01,  1.0108e+00,  1.4276e-01,
          6.9876e-01,  4.7560e-01, -3.0044e-01,  1.7225e+00,  5.7128e-01,
         -2.0701e-01,  3.3593e-01, -4.8509e-01, -2.6782e-01,  4.6120e-01,
          1.2577e+00, -3.3717e-01, -7.5854e-01, -6.9199e-02,  1.1319e+00,
          1.7431e-01, -6.6581e-01,  2.

In [56]:
import torch
from torch import nn
from torch.nn import functional as F

in_features_next_layer = 576
memory_layer_size = 256
obs = torch.randn(1, 3, 56, 56) 

conv1 = nn.Conv2d(obs.shape[1], 32, 8, 4,)
conv2 = nn.Conv2d(32, 64, 4, 2, 0)
conv3 = nn.Conv2d(64, 64, 3, 1, 0)
lin_hidden = nn.Linear(in_features_next_layer, memory_layer_size)

h = obs
batch_size = h.shape[0]
# Propagate input through the visual encoder
h = F.relu(conv1(h))
h = F.relu(conv2(h))
h = F.relu(conv3(h))
# Flatten the output of the convolutional layers
h = h.reshape((batch_size, -1))

h2 = F.relu(lin_hidden(h))


In [57]:
print(h.shape)
print(h2.shape)

torch.Size([1, 576])
torch.Size([1, 256])


In [17]:
obs = torch.randn(1, 384) 
print(obs.shape[0])

1
