# Transformers XL, Masked Word Completion

In [1]:
import torch
from pytorch_transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

## Set Device

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Pre-trained Transformer XL Model Tokenizer (Vocabulary)

In [3]:
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

## Encode Text Inputs

In [4]:
text = 'Hello, my dog is cute'
encoded_text = tokenizer.encode(text)
tensor_text = torch.tensor([encoded_text])

In [5]:
tensor_text = tensor_text.to(device) 

## Load Pre-trained [Transformer XL](https://huggingface.co/transformers/model_doc/transformerxl.html) Model Weights

In [6]:
transformerxl = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
transformerxl.to(device)

TransfoXLLMHeadModel(
  (transformer): TransfoXLModel(
    (word_emb): AdaptiveEmbedding(
      (emb_layers): ModuleList(
        (0): Embedding(20000, 1024)
        (1): Embedding(20000, 256)
        (2): Embedding(160000, 64)
        (3): Embedding(67735, 16)
      )
      (emb_projs): ParameterList(
          (0): Parameter containing: [torch.FloatTensor of size 1024x1024]
          (1): Parameter containing: [torch.FloatTensor of size 1024x256]
          (2): Parameter containing: [torch.FloatTensor of size 1024x64]
          (3): Parameter containing: [torch.FloatTensor of size 1024x16]
      )
    )
    (drop): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0): RelPartialLearnableDecoderLayer(
        (dec_attn): RelPartialLearnableMultiHeadAttn(
          (qkv_net): Linear(in_features=1024, out_features=3072, bias=False)
          (drop): Dropout(p=0.1, inplace=False)
          (dropatt): Dropout(p=0.0, inplace=False)
          (o_net): Linear(in_features=1024, o

## Evaluate Transformer XL Model

In [7]:
transformerxl.eval()

with torch.no_grad():
    outputs = transformerxl(tensor_text)
    predictions, mems = outputs[:2]

In [8]:
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(encoded_text + [predicted_index])

In [9]:
print('Prediction is:', predicted_text)

Prediction is: <unk> my dog is cute.


---