In [15]:
from transformers import BertTokenizer, BertModel, AdamW

In [20]:
# Location of the saved checkpoint file (a relative path).
load_path = "model_checkpoint.pth"

# map_location="cpu" makes the load succeed even when the checkpoint was saved
# from CUDA tensors and this machine has no GPU; tensors are copied onto the
# target device later by load_state_dict / .to(device).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
checkpoint = torch.load(load_path, map_location="cpu")

class FakeNewsClassifier(nn.Module):
    """Classification head on top of a BERT-style encoder.

    The encoder's ``pooler_output`` is passed through dropout and a single
    linear layer that produces one logit per class.

    Args:
        bert_model: Encoder module. It is called with ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` keyword arguments and
            must return an object exposing ``pooler_output``.
        num_labels: Number of output classes (default 2).
        dropout_prob: Dropout probability applied to the pooled output
            (default 0.3).
    """

    def __init__(self, bert_model, num_labels=2, dropout_prob=0.3):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout_prob)

        # Take the head's input width from the encoder's config instead of
        # hard-coding it, so encoders with another hidden size also work.
        # Falls back to 768 when the encoder exposes no config.hidden_size.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return the raw (un-normalised) class logits for a batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Segment ids forwarded to the encoder; optional,
                passed through as ``None`` when omitted.

        Returns:
            The output of the linear head; its last dimension has
            ``num_labels`` entries.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

# Pick the compute device first: the model must already sit on its final
# device when the optimizer is created and its saved state is restored,
# otherwise the optimizer state stays on the CPU while the parameters move
# to the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `model` is presumably the BERT encoder created in an earlier cell -- confirm.
classifier_model = FakeNewsClassifier(model).to(device)

# Rebuild optimizer and scheduler with the same construction arguments before
# restoring their saved state.
optimizer = AdamW(classifier_model.parameters(), lr=2e-5)
total_steps = 10000  # 10 epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-5, total_steps=total_steps)

# Load model state dict and switch to evaluation mode (disables dropout)
classifier_model.load_state_dict(checkpoint["model_state_dict"])
classifier_model.eval()

# Load optimizer and scheduler state dicts
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

# Load metrics stored alongside the weights
validation_accuracy = checkpoint["metrics"]["validation_accuracy"]
test_accuracy = checkpoint["metrics"]["test_accuracy"]

print("Model and metrics loaded successfully!")

Model and metrics loaded successfully!


In [21]:
# Read the text to classify from the user
user_text = input("Enter the text to classify: ")

# Encode the text as fixed-length (512 token) PyTorch tensors
encoded_input = tokenizer.encode_plus(
    user_text,
    add_special_tokens=True,
    truncation=True,
    padding="max_length",
    max_length=512,
    return_token_type_ids=True,
    return_attention_mask=True,
    return_tensors="pt",
)

# Move the three encoder inputs onto the model's device
input_ids, attention_mask, token_type_ids = (
    encoded_input[key].to(device)
    for key in ("input_ids", "attention_mask", "token_type_ids")
)

# Forward pass without gradient tracking
with torch.no_grad():
    logits = classifier_model(input_ids, attention_mask, token_type_ids)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = probabilities.argmax(dim=1).item()

# Translate the class index into its human-readable label
label_map = {0: "Fake News", 1: "Real News"}
predicted_label = label_map[predicted_class]

# Report the result
print("Predicted Label:", predicted_label)

Enter the text to classify: North Korea blew up the moon.
Predicted Label: Fake News
