In [3]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import TensorDataset, DataLoader
torch.set_default_device('cpu')

import numpy as np

from transformers_simple.transformer import MultiHeadAttention, TransformerBlock, GPT

In [4]:
# --- Shapes of the random test inputs built in the cells below ---
N = 32               # leading dimension of every input array (presumably the batch size)
input_size = 10      # second input dimension; handed to every model as block_size
input_vec_size = 7   # last dimension of the float inputs (per-position vector length)

# --- Hyper-parameters shared by the model constructors below ---
num_heads = 6
hidden_size = 5      # MultiHeadAttention's hidden_size; attn_hidden_size for the other models
output_size = 4
vocab_size = 8       # exclusive upper bound of the random integer inputs; GPT's vocab_size

In [5]:
# Stand-alone attention layer used for the first shape check.
# Naming gotcha: the layer's `input_size` keyword receives the per-position
# vector length (input_vec_size), while the notebook's `input_size` (the
# second input dimension) is passed as `block_size`.
net = MultiHeadAttention(
    input_size=input_vec_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_heads=num_heads,
    block_size=input_size,
)

In [6]:
# Random float64 inputs of shape (N, input_size, input_vec_size), drawn from a
# standard normal distribution, for the shape checks in the following cells.
# A fixed-seed Generator is created inline so that re-running this cell (or the
# whole notebook on a fresh kernel) always yields the same array; the legacy
# np.random.randn draws from unseeded global state and differs on every run.
X = np.random.default_rng(0).standard_normal((N, input_size, input_vec_size))

In [7]:
# Forward pass through the attention layer. The float64 NumPy array is copied
# into a float32 tensor first, which is what the legacy torch.Tensor(...)
# constructor produces as well.
o = net(torch.tensor(X, dtype=torch.float32))

In [8]:
# Display the output shape; the recorded result is torch.Size([32, 10, 4]),
# i.e. (N, input_size, output_size).
o.shape

torch.Size([32, 10, 4])

In [9]:
# Rebind `net` to a single transformer block. output_size is set to
# input_vec_size, so the block is configured with matching input and output
# vector lengths. The block's own hidden_size is the literal 20; the
# notebook-level hidden_size goes to attn_hidden_size.
net = TransformerBlock(
    block_size=input_size,
    vec_size=input_vec_size,
    hidden_size=20,
    attn_hidden_size=hidden_size,
    output_size=input_vec_size,
    num_heads=num_heads,
)

In [10]:
# Run the float inputs (copied to float32) through the transformer block and
# display the output shape; the recorded result is torch.Size([32, 10, 7]),
# i.e. the same shape as X.
net(torch.tensor(X, dtype=torch.float32)).shape

torch.Size([32, 10, 7])

In [11]:
# Rebind `net` to the full GPT model with three transformer blocks.
# Note that the `hidden_size` keyword here is the literal 10; the
# notebook-level hidden_size variable is passed as attn_hidden_size instead.
net = GPT(
    vocab_size=vocab_size,
    block_size=input_size,
    embed_size=15,
    hidden_size=10,
    attn_hidden_size=hidden_size,
    output_size=output_size,
    num_transformer_blocks=3,
    num_heads=num_heads,
)

In [12]:
# Rebind X to random integer inputs for the GPT model: shape (N, input_size),
# values drawn uniformly from [0, vocab_size).
# A fixed-seed Generator is created inline so the same ids are produced on
# every run; the legacy np.random.randint draws from unseeded global state.
# Generator.integers returns int64 by default on every platform, so the
# original trailing .astype(int) cast is no longer needed.
X = np.random.default_rng(0).integers(0, vocab_size, size=(N, input_size))

In [13]:
# Forward pass of the GPT model on the integer inputs, copied into an int32
# tensor (the dtype the legacy torch.IntTensor constructor produces).
# The whole output tensor is displayed; in the recorded output each innermost
# row has 4 values, matching output_size.
net(torch.tensor(X, dtype=torch.int32))

tensor([[[-0.8119, -0.1798, -0.6446,  0.0348],
         [-0.8119, -0.1798, -0.6446,  0.0348],
         [ 0.2825, -0.4206,  0.0581,  0.2651],
         ...,
         [-0.7709, -0.1832, -0.6729,  0.0308],
         [-0.1002, -0.0596,  0.1096,  0.4100],
         [-0.0668,  0.0986, -0.7228,  0.0859]],

        [[ 0.2277, -0.4700,  0.1812,  0.2434],
         [-0.1163, -0.1401,  0.2037,  0.3811],
         [-0.0579,  0.0825, -0.7218,  0.0654],
         ...,
         [ 0.7911, -0.7016, -1.0886, -0.5459],
         [ 0.1840, -0.0432, -0.2201, -0.1755],
         [ 0.7872, -0.7054, -1.0912, -0.5477]],

        [[ 0.6461,  0.5759,  0.2482, -0.3920],
         [-0.0420,  0.0965, -0.7498,  0.0734],
         [-0.0550,  0.1011, -0.7449,  0.0722],
         ...,
         [ 0.5793,  0.6237,  0.2578, -0.3776],
         [ 0.2993, -0.4274, -0.0021,  0.2669],
         [ 0.5899,  0.6082,  0.2649, -0.3636]],

        ...,

        [[ 0.2277, -0.4700,  0.1812,  0.2434],
         [-0.5850, -0.3053,  0.5086, -0.1199]