In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Checkpoint file holding the trained weights.
model_path = 'bi_lstm_model.pt'
# Class index -> label name; the position in the list is the class index.
label_map = dict(enumerate(['entailment', 'neutral', 'contradiction']))

In [4]:
# Model dimensions; the checkpoint is loaded into a model built from these,
# so they have to agree with the saved weights.
embed_dim = 300      # width of each token embedding
hidden_dim = 512     # LSTM hidden size per direction
num_classes = 3      # number of output logits
# Token length every encoded sentence pair is padded / truncated to.
max_length = 256

In [5]:
# Tokenizer loaded from the 'bert-base-uncased' checkpoint; its vocab_size is
# what sizes the model's embedding table when the model is built below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [6]:
class LSTM(nn.Module):
    """Single-layer bidirectional LSTM sequence classifier.

    Token ids are embedded and run through a bidirectional LSTM; the final
    hidden states of the two directions are concatenated, passed through
    dropout, and projected to ``output_dim`` logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        """Create the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Width of each token embedding.
            hidden_dim: Hidden size of the LSTM, per direction.
            output_dim: Number of logits produced per example.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence the doubled input width.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        No padding mask or sequence packing is applied, so every position of
        ``x`` (padding included) is fed through the LSTM.

        Args:
            x: Integer tensor of token ids with the batch dimension first.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with unnormalised logits.
        """
        embedded = self.embedding(x)
        # Only the final hidden state is used; the per-step outputs and the
        # final cell state are discarded.
        _, (hidden, _) = self.lstm(embedded)
        first_direction, second_direction = hidden[0], hidden[1]
        features = torch.cat([first_direction, second_direction], dim=1)
        return self.fc(self.dropout(features))

In [7]:
# Build the classifier with an embedding row per tokenizer vocabulary entry,
# then move its parameters onto the selected device.
model = LSTM(tokenizer.vocab_size, embed_dim, hidden_dim, num_classes)
model = model.to(device)

In [8]:
# The loaded object is handed straight to load_state_dict, so the checkpoint
# only needs to contain tensors. weights_only = True stops torch.load from
# unpickling arbitrary Python objects out of the file.
model.load_state_dict(
    torch.load(model_path, map_location = device, weights_only = True)
)
# Put the model in evaluation mode (dropout off). As the cell's last
# expression this also displays the module structure.
model.eval()

LSTM(
  (embedding): Embedding(30522, 300)
  (lstm): LSTM(300, 512, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=1024, out_features=3, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [9]:
def predict(premise, hypothesis):
    """Classify the relation between a premise and a hypothesis.

    The two sentences are encoded together as a single pair, padded or
    truncated to ``max_length`` tokens, and scored by the module-level
    ``model``.

    Args:
        premise: Premise sentence.
        hypothesis: Hypothesis sentence.

    Returns:
        The ``label_map`` entry for the highest-scoring class.
    """
    batch = tokenizer(
        premise,
        hypothesis,
        truncation = True,
        padding = 'max_length',
        max_length = max_length,
        return_tensors = 'pt',
    )
    # Only the token ids are passed on; the model takes no attention mask.
    token_ids = batch['input_ids'].to(device)

    with torch.no_grad():
        scores = model(token_ids)

    best_class = scores.argmax(dim = 1).item()
    return label_map[best_class]

In [10]:
# Interactive demo: keep asking for premise / hypothesis pairs and print the
# predicted label until the user types 'exit' at the premise prompt.
while True:
    try:
        premise = input("\nPremise (or type 'exit' to quit): ")
        # Ignore surrounding whitespace so input such as ' exit ' still quits.
        if premise.strip().lower() == 'exit':
            break
        hypothesis = input('Hypothesis : ')
    except (EOFError, KeyboardInterrupt):
        # A closed stdin or an interrupted prompt ends the session cleanly
        # instead of leaving a traceback in the cell output.
        break
    result = predict(premise, hypothesis)
    print('Prediction :', result)


Premise (or type 'exit' to quit):  The woman is drinking coffee at the cafe.
Hypothesis :  The woman is consuming a beverage.


Prediction : entailment



Premise (or type 'exit' to quit):  The dog is sleeping on the couch.
Hypothesis :  The dog is resting indoors.


Prediction : entailment



Premise (or type 'exit' to quit):  A woman is painting a portrait.
Hypothesis :  The portrait was painted by an artist.


Prediction : entailment



Premise (or type 'exit' to quit):  The child is reading a book in the library.
Hypothesis :  The child enjoys watching movies at home.


Prediction : neutral



Premise (or type 'exit' to quit):  exit
