In [1]:
import torch

from models.transformer import GPT, BERT, T5

# IPython autoreload, mode 2: re-import edited Python modules before each cell
# executes, so changes to the project source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Run on the GPU automatically when one is visible to torch, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so the torch.randint dummy inputs below (and presumably the
# model weight init / dropout inside models.transformer) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

# GPT Model
* Decoder-only transformer

In [2]:
# Decoder-only model. Hyperparameters are kept in one dict so they can be inspected or
# logged alongside the model; they are passed to GPT unchanged as keyword arguments.
gpt_hparams = dict(
    # sizes
    vocab_size=10_000,
    features_dim=384,
    num_heads=6,
    ff_dim=384,
    num_decoder_layers=6,
    # dropout for embeddings, attention and feed-forward
    emb_dropout_prob=0.1,
    attn_dropout_prob=0.1,
    ff_dropout_prob=0.1,
    # bias switches
    attn_use_bias=False,
    ff_use_bias=False,
    vocab_projection_bias=False,
)
gpt_model = GPT(**gpt_hparams).to(device)

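### Dummy Prediction
Run a forward pass on random token ids, using a pad mask whose second half is 0.0, and check the output shape.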

In [3]:
batch_size = 32
seq_len = 20

# Dummy pad mask: 1.0 for the first seq_len // 2 positions of every row, 0.0 for the rest.
# The zero part takes the remainder so the mask always has exactly seq_len columns
# (using seq_len // 2 for both parts would drop a column whenever seq_len is odd).
num_ones = seq_len // 2
_pad_mask = torch.cat(
    (torch.ones(batch_size, num_ones), torch.zeros(batch_size, seq_len - num_ones)),
    dim=1,
)

# Shape check only: no gradients are needed, so don't let autograd record the graph.
with torch.no_grad():
    pred = gpt_model(
        x_input=torch.randint(low=0, high=10_000, size=(batch_size, seq_len)).to(device),
        pad_mask=_pad_mask.to(device),
    )

# one vocab_size (10_000)-wide vector per input position
assert pred.shape == (batch_size, seq_len, 10_000), pred.shape
pred.shape

torch.Size([32, 20, 10000])

In [4]:
# Pad mask of the first sequence in the batch (row 0 of _pad_mask).
# Presumably 1.0 marks a real token and 0.0 padding — TODO confirm against models.transformer.
_pad_mask[0]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

# BERT Model
* Encoder-only transformer

In [5]:
# Encoder-only model. Hyperparameters are kept in one dict so they can be inspected or
# logged alongside the model; they are passed to BERT unchanged as keyword arguments.
bert_hparams = dict(
    # sizes
    vocab_size=10_000,
    features_dim=384,
    num_heads=6,
    ff_dim=384,
    num_encoder_layers=6,
    # dropout for embeddings, attention and feed-forward
    emb_dropout_prob=0.1,
    attn_dropout_prob=0.1,
    ff_dropout_prob=0.1,
    # bias switches
    attn_use_bias=False,
    ff_use_bias=False,
    vocab_projection_bias=False,
)
bert_model = BERT(**bert_hparams).to(device)

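### Dummy Prediction
Run a forward pass on random token ids, using a pad mask whose second half is 0.0, and check the output shape.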

In [6]:
batch_size = 32
seq_len = 512

# Dummy pad mask: 1.0 for the first seq_len // 2 positions of every row, 0.0 for the rest.
# The zero part takes the remainder so the mask always has exactly seq_len columns
# (using seq_len // 2 for both parts would drop a column whenever seq_len is odd).
num_ones = seq_len // 2
_pad_mask = torch.cat(
    (torch.ones(batch_size, num_ones), torch.zeros(batch_size, seq_len - num_ones)),
    dim=1,
)

# Shape check only: with 32 x 512 tokens, autograd would otherwise record the graph and
# keep intermediate activations alive for gradients that are never used.
with torch.no_grad():
    pred = bert_model(
        x_input=torch.randint(low=0, high=10_000, size=(batch_size, seq_len)).to(device),
        pad_mask=_pad_mask.to(device),
    )

# one vocab_size (10_000)-wide vector per input position
assert pred.shape == (batch_size, seq_len, 10_000), pred.shape
pred.shape

torch.Size([32, 512, 10000])

# T5 Model
* Full encoder-decoder transformer

In [7]:
# Encoder-decoder model with separate encoder / decoder vocabulary sizes.
# Hyperparameters are kept in one dict and passed to T5 unchanged as keyword arguments.
# NOTE(review): 1 encoder layer vs 6 decoder layers — confirm this asymmetry is intended.
t5_hparams = dict(
    # vocabularies
    vocab_size_enc=10_000,
    vocab_size_dec=10_000,
    # sizes
    features_dim=384,
    num_heads=6,
    ff_dim=384,
    num_encoder_layers=1,
    num_decoder_layers=6,
    # dropout for embeddings, attention and feed-forward
    emb_dropout_prob=0.1,
    attn_dropout_prob=0.1,
    ff_dropout_prob=0.1,
    # bias switches
    attn_use_bias=False,
    ff_use_bias=False,
    vocab_projection_bias=False,
)
t5_model = T5(**t5_hparams).to(device)

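### Dummy Prediction
Run a forward pass on random encoder-side (`x_input`) and decoder-side (`x_cross`) token ids, each with a pad mask whose second half is 0.0, and check the output shape.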

In [8]:
batch_size = 32
seq_len_enc = 512
seq_len_dec = 10

# Dummy pad masks: 1.0 for the first len // 2 positions of every row, 0.0 for the rest.
# The zero part takes the remainder so each mask always matches its sequence length
# (using len // 2 for both parts would drop a column whenever a length is odd).
num_ones_enc = seq_len_enc // 2
num_ones_dec = seq_len_dec // 2
_pad_mask_enc = torch.cat(
    (torch.ones(batch_size, num_ones_enc), torch.zeros(batch_size, seq_len_enc - num_ones_enc)),
    dim=1,
)
_pad_mask_dec = torch.cat(
    (torch.ones(batch_size, num_ones_dec), torch.zeros(batch_size, seq_len_dec - num_ones_dec)),
    dim=1,
)

# call signature: x_input, x_cross, pad_mask=None, pad_mask_cross=None
# Shape check only: no gradients are needed, so don't let autograd record the graph.
with torch.no_grad():
    pred_t5 = t5_model(
        x_input=torch.randint(low=0, high=10_000, size=(batch_size, seq_len_enc)).to(device),
        x_cross=torch.randint(low=0, high=10_000, size=(batch_size, seq_len_dec)).to(device),
        pad_mask=_pad_mask_enc.to(device),
        pad_mask_cross=_pad_mask_dec.to(device),
    )

# the output has one vocab_size_dec (10_000)-wide vector per x_cross position
assert pred_t5.shape == (batch_size, seq_len_dec, 10_000), pred_t5.shape
pred_t5.shape

torch.Size([32, 10, 10000])