In [1]:
import torch
import torch.nn as nn
from reformer_pytorch import ReformerLM
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [8]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS,
# and fall back to the CPU when neither accelerator is present.
if not torch.cuda.is_available():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
else:
    device = torch.device("cuda")

In [2]:
# torchtext's "basic_english" tokenizer, shared by vocab building and inference.
tokenizer = get_tokenizer("basic_english")

# Architecture hyper-parameters shared by the Transformer and the Reformer.
NUM_ENCODER_LAYERS = 3
NUM_HEADS = 4
EMEBED_DIM = 256  # (sic) model/embedding width; spelling kept because later cells use it

# Task / training settings.
NUM_CLASSES = 2  # NOTE: the classification heads below emit a single logit
TRAIN_EPOCHS = 6
seq_length = 256  # fixed sequence length; also passed as the Reformer's max_seq_len


def yield_tokens(data_iter, tokenize=None):
    """Yield the token list of each text in ``data_iter``.

    Args:
        data_iter: Iterable of ``(label, text)`` pairs (e.g. the IMDB
            datapipe). Labels are ignored.
        tokenize: Callable mapping a string to a list of tokens. Defaults to
            the notebook-level ``tokenizer``, so existing calls behave the same.

    Yields:
        One list of tokens per input text.
    """
    if tokenize is None:
        # Backward-compatible default: the tokenizer defined in the config cell.
        tokenize = tokenizer
    for _label, text in data_iter:
        yield tokenize(text)


# Build the vocabulary from the IMDB training split; the test split is discarded.
# Index order matters because the saved model weights were trained against it,
# so the specials list (and its order) must not change.
train_iter, _ = IMDB(split=("train", "test"))
special_tokens = ["<unk>", "<pad>"]
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=special_tokens)
# Out-of-vocabulary tokens map to the "<unk>" index.
vocab.set_default_index(vocab["<unk>"])
NUM_TOKENS = len(vocab)

In [3]:
def text_pipeline(x):
    """Convert raw text ``x`` into a list of vocabulary indices.

    Uses the notebook-level ``tokenizer`` and ``vocab``. Tokens missing from
    the vocabulary map to the "<unk>" index (the vocab's default index).
    """
    token_strings = tokenizer(x)
    return vocab(token_strings)

## Transformer


In [4]:
class EncoderOnlyTransformer(nn.Module):
    """Binary sentiment classifier built from stacked Transformer encoder layers.

    Tokens are embedded, passed through a batch-first ``nn.TransformerEncoder``,
    max-pooled over the sequence axis and projected to a single logit.

    NOTE(review): no positional encoding is added, so the model is insensitive
    to token order. Adding one now would change the behaviour of the already
    trained weights.

    Args:
        d_model: Embedding / model width.
        nhead: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        num_tokens: Vocabulary size of the embedding table. When omitted it
            falls back to the notebook-global ``NUM_TOKENS``.
    """

    def __init__(
        self,
        d_model=256,
        nhead=4,
        num_encoder_layers=1,
        dim_feedforward=256,
        num_tokens=None,
    ):
        super().__init__()
        if num_tokens is None:
            # Backward-compatible fallback to the vocab size computed earlier.
            num_tokens = NUM_TOKENS
        # Attribute names must stay as-is: they are the state_dict key prefixes
        # the saved checkpoint is loaded against.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
        )
        self.dropout = nn.Dropout(0.2)
        self.embedding = nn.Embedding(num_tokens, d_model)
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, src):
        """Return one raw logit per sequence.

        Args:
            src: Integer token ids, presumably shaped ``(batch, seq_len)``
                because the encoder is ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, 1)``. Apply ``sigmoid`` for a probability.
        """
        embedded = self.embedding(src)
        encoded = self.transformer_encoder(embedded)
        encoded = self.dropout(encoded)
        # Max-pool over the sequence axis (dim=1 because batch_first=True).
        pooled = encoded.max(dim=1).values
        return self.classifier(pooled)

In [11]:
# Load the trained encoder-only Transformer and switch it to inference mode.
model_path = "transformer_model_weights.pth"
transformer = EncoderOnlyTransformer(
    d_model=EMEBED_DIM,
    nhead=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    dim_feedforward=EMEBED_DIM,
).to(device)
# map_location remaps tensors onto the selected `device`, so a checkpoint saved
# on CUDA still loads on an MPS- or CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
transformer.load_state_dict(torch.load(model_path, map_location=device))
transformer.eval()

EncoderOnlyTransformer(
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (embedding): Embedding(100684, 256)
  (classifier): Linear(in_features=256, out_features=1, bias=True)
)

## Reformer


In [6]:
class ReformerForClassification(nn.Module):
    """Binary sentiment classifier on top of a ``reformer_pytorch.ReformerLM``.

    The ReformerLM is built with ``return_embeddings=True`` and fixed
    positional embeddings. Its per-token outputs (presumably of width ``dim``,
    matching the classifier input) are passed through dropout, max-pooled over
    the sequence axis and projected to a single logit.
    """

    def __init__(self, num_tokens, emb_dim, dim, depth, heads, max_seq_len):
        super().__init__()
        encoder_config = dict(
            num_tokens=num_tokens,
            emb_dim=emb_dim,
            dim=dim,
            depth=depth,
            heads=heads,
            max_seq_len=max_seq_len,
            fixed_position_emb=True,
            return_embeddings=True,
        )
        self.encoder = ReformerLM(**encoder_config)

        # Classification head: dropout -> max-pool (in forward) -> linear.
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(dim, 1)

    def forward(self, x):
        """Map token ids ``x`` (presumably ``(batch, seq_len)``) to raw logits of shape ``(batch, 1)``."""
        hidden = self.dropout(self.encoder(x))
        pooled, _indices = hidden.max(dim=1)
        return self.classifier(pooled)

In [9]:
# Load the trained Reformer classifier and switch it to inference mode.
model_path = "reformer_model_weights.pth"
reformer = ReformerForClassification(
    num_tokens=NUM_TOKENS,
    emb_dim=EMEBED_DIM,
    dim=EMEBED_DIM,
    depth=NUM_ENCODER_LAYERS,
    heads=NUM_HEADS,
    max_seq_len=seq_length,
).to(device)
# map_location remaps tensors onto the selected `device`, so a checkpoint saved
# on CUDA still loads on an MPS- or CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
reformer.load_state_dict(torch.load(model_path, map_location=device))
reformer.eval()

ReformerForClassification(
  (encoder): ReformerLM(
    (token_emb): Embedding(100684, 256)
    (to_model_dim): Identity()
    (pos_emb): FixedPositionalEmbedding()
    (layer_pos_emb): Always()
    (reformer): Reformer(
      (layers): ReversibleSequence(
        (blocks): ModuleList(
          (0-2): 3 x ReversibleBlock(
            (f): Deterministic(
              (net): PreNorm(
                (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
                (fn): LSHSelfAttention(
                  (toqk): Linear(in_features=256, out_features=256, bias=False)
                  (tov): Linear(in_features=256, out_features=256, bias=False)
                  (to_out): Linear(in_features=256, out_features=256, bias=True)
                  (lsh_attn): LSHAttention(
                    (dropout): Dropout(p=0.0, inplace=False)
                    (dropout_for_hash): Dropout(p=0.0, inplace=False)
                  )
                  (full_attn): FullQKAttention(
            

## Gradio Inference


In [12]:
import gradio as gr

In [15]:
def inference(text, model):
    """Classify the sentiment of ``text`` with the selected model.

    Args:
        text: Raw review text from the Gradio textbox.
        model: Model selection from the UI. Either a single name ("Reformer" /
            "Transformer", as a ``gr.Radio`` returns), or a list of names (as a
            ``gr.CheckboxGroup`` returns). For a list the first entry is used.
            An empty selection, or any name other than "Reformer", falls back
            to the Transformer.

    Returns:
        A "Positive ..." or "Negative ..." label string.
    """
    # A CheckboxGroup yields a list, and ["Reformer"] == "Reformer" is False, so
    # without this normalisation the Reformer could never be selected.
    if isinstance(model, (list, tuple)):
        model = model[0] if model else "Transformer"
    net = reformer if model == "Reformer" else transformer

    # Truncate / pad to the fixed length the models were built for. The
    # Reformer was constructed with max_seq_len=seq_length, and its LSH
    # attention needs a length compatible with its bucketing. Padding also
    # keeps empty input from producing a zero-length tensor.
    # NOTE(review): assumes training padded with "<pad>" to seq_length — confirm.
    token_ids = text_pipeline(text)[:seq_length]
    token_ids += [vocab["<pad>"]] * (seq_length - len(token_ids))
    tokens = torch.tensor(token_ids, dtype=torch.int64, device=device).unsqueeze(0)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        logit = net(tokens)
    probability = torch.sigmoid(logit).item()
    return "Positive 🥳 🤩 👍" if probability > 0.5 else "Negative 🤮 🙅 👎"

In [16]:
# Model choices shown in the UI; selecting "Reformer" routes to the Reformer in
# `inference`, anything else to the Transformer.
models = ["Reformer", "Transformer"]
demo = gr.Interface(
    fn=inference,
    inputs=[
        "text",
        # Exactly one model is used per prediction, so use a single-choice
        # Radio (yields a string). A CheckboxGroup yields a list, which never
        # equals "Reformer", making the Reformer option unreachable.
        gr.Radio(
            models,
            value="Transformer",
            label="Model Selection",
            info="Which model would you like to use?",
        ),
    ],
    outputs=gr.Textbox(
        label="Results",
        lines=1,
        show_copy_button=True,
        show_label=True,
        placeholder="📢❗🚨Results will be displayed here.",
    ),
    # Gradio expects "never" / "auto" / "manual"; a boolean is deprecated.
    allow_flagging="never",
)

demo.launch()



Running on local URL:  http://127.0.0.1:7861
IMPORTANT: You are using gradio version 3.48.0, however version 4.29.0 is available, please upgrade.
--------

To create a public link, set `share=True` in `launch()`.


