This notebook is mostly a scratchpad with little standalone use, but it contains some usage examples for the `transformers` `EncoderDecoderModel`.

In [1]:
from pprint import pprint

import torch
import torch.nn as nn
import transformers

In [2]:
# Fixed seed so that the randomly initialised decoder weights and the random
# token ids drawn further down are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Pretrained checkpoint used for both the tokenizer and the encoder.
ENCODER_NAME = 'distilbert-base-uncased'
# Size of the decoder's output vocabulary (also the last dim of its logits).
OUT_VOCAB_SIZE = 100
# Decoder hidden size; chosen equal to the encoder's hidden size (768 for this checkpoint).
HIDDEN = 768

In [3]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    ENCODER_NAME,
    use_fast=True,
)
encoder = transformers.AutoModel.from_pretrained(
    ENCODER_NAME,
)

In [4]:
# Display the encoder's embedding sub-module to see its vocabulary size and hidden size.
encoder.embeddings

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [5]:
# Tokenize one sentence into PyTorch tensors and show what the tokenizer produced.
x = tokenizer.encode_plus('this is base text', return_tensors='pt')
pprint(x)

# What exactly the model returns depends on the architecture, but element 0
# is always the hidden states of the last layer.
encoder_out = encoder(**x)
y = encoder_out[0]
y.shape

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1]]),
 'input_ids': tensor([[ 101, 2023, 2003, 2918, 3793,  102]])}


torch.Size([1, 6, 768])

In [6]:
# BertConfig describes a generic transformer stack; when this notebook was written
# it was the only architecture `transformers` supported as a decoder.
decoder_config = transformers.BertConfig(
    vocab_size=OUT_VOCAB_SIZE,
    hidden_size=HIDDEN,
    is_decoder=True,  # enables causal masking (and, on older releases, also adds cross-attention modules)
    # Newer `transformers` releases only build the cross-attention layers when this
    # flag is set explicitly; older releases just store it as an extra config attribute.
    add_cross_attention=True,
)

# 'MaskedLM' only means an additional projection from hidden to vocab, but does not affect causal masking
# NOTE(review): recent `transformers` releases ship `BertLMHeadModel` as the dedicated
# causal-LM head and may warn about `is_decoder=True` here; confirm against the installed version.
decoder = transformers.BertForMaskedLM(decoder_config)

In [7]:
# Smoke-test the decoder on its own: feed a single random sequence of 5 output-vocabulary ids.
x = torch.randint(low=0, high=OUT_VOCAB_SIZE, size=(1, 5))
pprint(x)

# Element 0 of the result holds the per-token scores over the output vocabulary.
y = decoder(input_ids=x)
pprint(y[0].shape)

tensor([[10, 38, 78, 14, 42]])
torch.Size([1, 5, 100])


In [8]:
# Combine the pretrained encoder and the freshly initialised decoder into one seq2seq model.
seq2seq = transformers.EncoderDecoderModel(encoder=encoder, decoder=decoder)

In [12]:
# Random token ids are sufficient here: this cell only verifies output shapes.
x_enc = torch.randint(0, tokenizer.vocab_size, size=(3, 7))
x_dec = torch.randint(0, OUT_VOCAB_SIZE, size=(3, 5))

outputs = seq2seq(input_ids=x_enc, decoder_input_ids=x_dec)

# Do not tuple-unpack the result: depending on the `transformers` release it is either a
# plain tuple or a dict-like output object, which iterates over its keys and may carry
# extra fields between the logits and the encoder state.
# Element 0 is the decoder's vocabulary logits in both cases (despite the name `y_ids`,
# these are scores, not token ids).
y_ids = outputs[0]
# Prefer the named field when it exists; fall back to position 1 for plain-tuple outputs.
enc_hidden = getattr(outputs, 'encoder_last_hidden_state', None)
if enc_hidden is None:
    enc_hidden = outputs[1]

# Decoder logits: (batch, decoder_seq_len, OUT_VOCAB_SIZE)
assert x_dec.shape == y_ids.shape[:2]
assert y_ids.shape[2] == OUT_VOCAB_SIZE
# Encoder hidden states: (batch, encoder_seq_len, encoder hidden size)
assert x_enc.shape == enc_hidden.shape[:2]
assert enc_hidden.shape[2] == encoder.config.hidden_size