In [2]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

class SentimentAnalysisModel(nn.Module):
    """Pretrained BERT encoder with dropout and a one-unit linear head.

    The head emits a single raw value per input row (no activation is
    applied here); callers are expected to post-process it themselves.

    NOTE: the attribute names ``bert``, ``dropout`` and ``classifier`` are
    part of the state-dict key names, so they must stay as they are for
    previously saved checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        # Head width follows the encoder's configured hidden size.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output for the encoder's pooled output.

        Args:
            input_ids: Token id tensor passed to the BERT encoder.
            attention_mask: Attention mask tensor passed to the encoder.

        Returns:
            The output of the linear head applied to the (dropout-regularised)
            ``pooler_output`` of the encoder.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        return self.classifier(pooled)

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the network and restore the saved weights onto the chosen device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model = SentimentAnalysisModel()
model.load_state_dict(
    torch.load('bert_stock_sentiment_model.pth', map_location=device)
)
# Move to the device and switch to evaluation mode (disables dropout)
model.to(device).eval()

# Tokenizer for the same 'bert-base-uncased' checkpoint the encoder uses
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define a function to prepare the input data and make predictions
def prepare_input(text, threshold=0.5):
    """Tokenize *text*, run the module-level model and report its sentiment.

    The text is encoded with the module-level ``tokenizer``, passed through
    the module-level ``model`` on ``device``, and the sigmoid of the model
    output is compared against ``threshold``.

    Args:
        text: The text to classify.
        threshold: Sigmoid value above which the text is labelled as good
            sentiment. Defaults to 0.5.

    Returns:
        ``1`` when the sigmoid of the model output is greater than
        ``threshold``, otherwise ``-1``. The label is also printed.
    """
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        # max_length alone does not shorten the input; without explicit
        # truncation a longer text would be passed through unshortened and
        # exceed the 512 positions the BERT encoder supports.
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Map the sigmoid score to 1 (good sentiment) or -1 (bad sentiment)
    sentiment = torch.sigmoid(outputs).item()
    if sentiment > threshold:
        print("Good sentiment!")
        sentiment_output = 1
    else:
        print("Bad sentiment!")
        sentiment_output = -1

    print("Sentiment output:", sentiment_output)
    # Hand the label back so callers can use it, not just read the printout.
    return sentiment_output

# Classify a sample news sentence
text = (
    "Moneycontrol had reported in April that KKR is in talks with "
    "Torrent Pharma to sell its majority stake in JB Chemicals."
)
prepare_input(text)

ModuleNotFoundError: No module named 'transformers'