In [4]:
import pandas as pd
import torch
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from collections import Counter

# Hugging Face checkpoint used for both the tokenizer and the classifier.
MODEL_NAME = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"

# Articles are split into fixed-size windows of this many *characters*
# (not tokens) and each window is classified independently.
CHUNK_SIZE = 512

# Upper bound on the number of tokens per window handed to the model;
# longer tokenizations are truncated.
MAX_TOKENS = 512

# Label written to the 'Sentiment' column for each predicted class index.
# Any index other than 0 or 1 is reported as 'Positive'.
SENTIMENT_LABELS = {0: 'Negative', 1: 'Neutral'}
DEFAULT_SENTIMENT_LABEL = 'Positive'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Check if a GPU is available and if not, use a CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device and make sure it is in inference mode
# (dropout disabled).
model = model.to(device)
model.eval()

df = pd.read_csv('pq_metadata.csv')

# Per-row results are collected here and assigned as whole columns once the
# loop finishes, instead of writing cell by cell into the frame.
row_sentiments = []
row_probabilities = []

# Gradients are never needed for scoring; disabling them avoids building an
# autograd graph for every chunk.
with torch.no_grad():
    for text in tqdm(df['Full text'], total=df.shape[0]):
        # Missing values (NaN after read_csv) and empty strings yield no
        # chunks to vote on, so record the row as unscored instead of crashing.
        if not isinstance(text, str) or not text:
            row_sentiments.append(None)
            row_probabilities.append(float('nan'))
            continue

        chunks = [text[i:i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]

        sentiments = []
        confidences = []
        for chunk in chunks:
            inputs = tokenizer(
                chunk, return_tensors="pt", truncation=True, max_length=MAX_TOKENS
            ).to(device)
            outputs = model(**inputs)
            probabilities = F.softmax(outputs.logits, dim=-1)
            # One chunk per forward pass, so each result holds a single value:
            # the winning class index and its softmax probability.
            max_prob, predicted = torch.max(probabilities, dim=-1)
            sentiments.append(predicted.item())
            confidences.append(max_prob.item())

        # Majority vote for sentiment (ties go to the class seen first).
        majority_vote_sentiment = Counter(sentiments).most_common(1)[0][0]

        # Average confidence over only the chunks that voted for the majority
        # sentiment, so the reported probability describes the reported label.
        majority_confidences = [
            confidence
            for sentiment, confidence in zip(sentiments, confidences)
            if sentiment == majority_vote_sentiment
        ]
        average_confidence = sum(majority_confidences) / len(majority_confidences)

        row_sentiments.append(
            SENTIMENT_LABELS.get(majority_vote_sentiment, DEFAULT_SENTIMENT_LABEL)
        )
        row_probabilities.append(average_confidence)

df['Sentiment'] = row_sentiments
df['Probability'] = row_probabilities

100%|██████████| 7070/7070 [09:29<00:00, 12.41it/s]


In [7]:
# df.to_csv('pq_metadata_sentiment.csv', index=False, encoding='utf-8')

# real_estate_df = df[df['Section'] == 'Real Estate']
# real_estate_df.to_csv('real_estate_sentiment.csv', index=False, encoding='utf-8')