In [10]:
from model_with_feature_classes_1 import GPTModel
import torch
from torch import nn
import tiktoken

In [11]:
# Load tiktoken's "gpt2" encoding; classify_review calls tokenizer.encode(text)
# on it to turn raw text into token ids.
tokenizer = tiktoken.get_encoding("gpt2")

In [12]:
# Device string handed to classify_review, which creates the input tensor on it.
device = "cpu"

In [13]:
# Selected model size (must be a key of `model_configs`) and a sample prompt.
# NOTE(review): INPUT_PROMPT is not referenced by the cells visible here.
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"

# Size-specific architecture settings, keyed by model name.
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

# Settings shared by every size, followed by the entries of the selected size.
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.1,
    "qkv_bias": True,
    **model_configs[CHOOSE_MODEL],
}

In [14]:
# Rebuild the network with the same structure the checkpoint is loaded into:
# a GPTModel whose output head is replaced by a 2-class linear layer.
model = GPTModel(BASE_CONFIG)

num_classes = 2
model.out_head = torch.nn.Linear(BASE_CONFIG["emb_dim"], num_classes)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location follows the notebook-level `device` instead of a hard-coded "cpu".
model.load_state_dict(
    torch.load("review_classifier.pt", map_location=device, weights_only=True)
)
# Keep the model on the same device classify_review creates its inputs on.
model.to(device)
# Trailing semicolon suppresses the module repr in the notebook output.
model.eval();

In [15]:
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    """Classify ``text`` as ``"spam"`` or ``"not spam"`` from the last-token logits.

    The text is tokenized, truncated or right-padded to a fixed length, and run
    through ``model``; the class with the highest logit at the final position
    decides the label.

    Args:
        text: Raw string to classify.
        model: Module exposing ``pos_emb.weight`` (its first dimension is used
            as the supported context length) that, when called on a
            ``(1, seq_len)`` tensor of token ids, returns logits indexable as
            ``[:, -1, :]``.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        device: Device on which the input tensor is created.
        max_length: Target sequence length. Shorter inputs are right-padded
            with ``pad_token_id``; longer inputs are truncated. Values larger
            than the model's context length are clamped to it. ``None`` uses
            the full context length.
        pad_token_id: Token id used for padding.

    Returns:
        ``"spam"`` if the predicted class index is 1, otherwise ``"not spam"``.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    model.eval()

    supported_context_length = model.pos_emb.weight.shape[0]

    # Resolve the effective length once so truncation and padding agree.
    # Previously a None default crashed in min(), and a max_length above the
    # context length padded the input past what the positional embedding holds.
    if max_length is None:
        max_length = supported_context_length
    elif max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    else:
        max_length = min(max_length, supported_context_length)

    input_ids = tokenizer.encode(text)[:max_length]
    input_ids = input_ids + [pad_token_id] * (max_length - len(input_ids))

    # unsqueeze(0) adds the batch dimension expected by the model call.
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # logits at the last position only
    predicted_label = torch.argmax(logits, dim=-1).item()

    return "spam" if predicted_label == 1 else "not spam"

In [16]:
# Example 1: prize-style message (the misspellings are part of the sample input).
text1 = (
    "You are a winner, you have been specialy selected"
    " to recieve $1000 cash or a $2000 award."
)

# Pad/truncate to 120 tokens before classifying.
print(classify_review(text1, model, tokenizer, device, max_length=120))

spam


In [17]:
# Example 2: casual personal message.
text2 = (
    "hey mate how are you ? I am coming over later, "
    "are you there already ?"
)

# Pad/truncate to 120 tokens before classifying.
print(classify_review(text2, model, tokenizer, device, max_length=120))

not spam


In [19]:
# Example 3: time-pressured prize offer.
text3 = (
    "Answer in the next hour and you can win a great price of $1000, "
    "do not miss this chance or you will get another one tomorrow"
)

# Pad/truncate to 120 tokens before classifying.
print(classify_review(text3, model, tokenizer, device, max_length=120))

spam


In [20]:
# Example 4: personal message that mentions a dollar amount.
text4 = (
    "hey there just want to remind you about the $500 for the ticket, "
    "bring it tomorrow when we meet at the arena."
)

# Pad/truncate to 120 tokens before classifying.
print(classify_review(text4, model, tokenizer, device, max_length=120))

not spam
