In [1]:
import gradio as gr
import torch
import torch.nn as nn
from transformers import AdamW, BertModel, BertTokenizer

# Pretrained checkpoint name shared by the encoder and its tokenizer, so the
# tokenizer's vocabulary lines up with the encoder's embedding table.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
# Bare encoder without the pre-training heads: the hub checkpoint's "cls.*"
# weights are intentionally discarded, which is what the loader warning reports.
model = BertModel.from_pretrained(model_name)

# Define the FakeNewsClassifier model
class FakeNewsClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The head is applied to the encoder's ``pooler_output``.

    Args:
        bert_model: Encoder module whose forward accepts ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` keyword arguments and
            returns an object exposing a ``pooler_output`` tensor.
        num_labels: Number of output classes (2 = fake / real in this notebook).
        dropout_prob: Dropout probability applied to the pooled output.
        hidden_size: Width of ``pooler_output``. When ``None`` it is read from
            ``bert_model.config.hidden_size``, falling back to 768 (BERT-base)
            for encoders without a ``config`` attribute.

    Submodule attribute names (``bert``, ``dropout``, ``classifier``) define the
    ``state_dict`` keys, so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, bert_model, num_labels=2, dropout_prob=0.3, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            # Avoid hard-coding 768 so larger/smaller BERT variants also work.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return unnormalized class logits of shape ``(batch, num_labels)``.

        ``token_type_ids`` may be omitted; the encoder is then called with
        ``None`` for it.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

# Set device first so the checkpoint tensors can be mapped straight onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model checkpoint.
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; without it torch.load tries to restore tensors onto CUDA and fails.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load trusted checkpoints (on torch>=1.13 consider weights_only=True if
# the checkpoint contains only tensors and plain containers).
load_path = "model_checkpoint.pth"
checkpoint = torch.load(load_path, map_location=device)

# Wrap the pretrained BERT encoder with the classification head.
classifier_model = FakeNewsClassifier(model)

# Restore the fine-tuned weights, move to the target device, and switch to
# eval mode so dropout is disabled during inference.
classifier_model.load_state_dict(checkpoint["model_state_dict"])
classifier_model = classifier_model.to(device)
classifier_model.eval()

# Prediction function wrapped by the Gradio interface
def classify_fake_news(text, max_length=512):
    """Classify ``text`` as fake or real news with the fine-tuned BERT model.

    Relies on the notebook-level ``tokenizer``, ``classifier_model`` and
    ``device``.

    Args:
        text: Article text from the Gradio textbox. ``None`` or blank input
            returns a prompt instead of running the model.
        max_length: Maximum number of tokens; longer inputs are truncated
            (512 is the position limit of bert-base-uncased).

    Returns:
        ``"Real News"`` or ``"Fake News"``, or a prompt string for empty input.
    """
    if text is None or not text.strip():
        return "Please enter some text to classify."

    # No padding: a single, unbatched sequence needs none, and padding every
    # request to max_length made BERT process 512 tokens even for short inputs.
    encoded_input = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)
    token_type_ids = encoded_input["token_type_ids"].to(device)

    with torch.no_grad():
        logits = classifier_model(input_ids, attention_mask, token_type_ids)
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    predicted_class = torch.argmax(logits, dim=1).item()

    # Assumes the head was trained with 0 = fake, 1 = real — TODO confirm
    # against the training notebook.
    label_map = {1: "Real News", 0: "Fake News"}
    return label_map[predicted_class]

# Gradio 3.x deprecated (and Gradio 4 removed) the gr.inputs / gr.outputs
# namespaces; the top-level components accept the same arguments and also
# silence the deprecation warnings printed by this cell.
inputs = gr.Textbox(label="Enter the text to classify")
output = gr.Textbox(label="Predicted Label")




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  super().__init__(
  super().__init__(


In [2]:
demo = gr.Interface(
    fn=classify_fake_news,
    inputs=inputs,
    outputs=output,
    title="Fake News Classifier",
)
# share=True additionally opens a temporary public *.gradio.live URL; anyone
# holding the link can query the model while it is live.
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://1451fbc67844a8497c.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


