In [None]:
# Load the dataset
df = pd.read_feather("Data/Data1.feather")

# Fail fast with a readable message if a column that the dataset-preparation
# code indexes is absent, instead of hitting a less specific error further down.
required_columns = ['text', 'base_url', 'company_names', 'sentiment']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise KeyError(f"Data/Data1.feather is missing expected columns: {missing_columns}")

# Preview the first ten rows.
print(df.head(10))

# Initialize BERT tokenizer
# The local cache directory must NOT share its name with the Hub id:
# transformers resolves an existing local directory before contacting the Hub,
# so a folder called 'bert-base-uncased' that only holds tokenizer files would
# also be picked up by BertForSequenceClassification.from_pretrained(
# 'bert-base-uncased') and fail for lack of a config / weights.
# If such a folder was left behind by an earlier run, delete it.
tokenizer_name = 'bert-base-uncased'
tokenizer_path = './7_BERT/tokenizer'
# vocab.txt is written by BertTokenizer.save_pretrained; checking for the file
# (not just the directory) avoids loading from a partially written cache.
if os.path.isfile(os.path.join(tokenizer_path, 'vocab.txt')):
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
else:
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    tokenizer.save_pretrained(tokenizer_path)

# English stop-word list extended with two extra tokens to ignore.
# NOTE(review): 'rt' is presumably a retweet marker and 'ep' a dataset-specific
# token - confirm. The list is not consumed by any code visible in this section.
stop_words_appended = stopwords.words('english') + ['rt', 'ep']

# Prepare datasets
# Build one input string per row from the three text-bearing columns.
# Every cell is cast to str and missing values are replaced by '' so that
# ' '.join does not raise TypeError on NaN/None or other non-string entries.
# Rows that already contain only strings produce exactly the same text as a
# plain ' '.join over the row.
text_columns = ['text', 'base_url', 'company_names']
text_parts = df[text_columns]
X = text_parts.astype(str).where(text_parts.notna(), '').apply(' '.join, axis=1)
Y = df['sentiment']

# 80/20 train/test split; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# FinancialSentimentDataset is defined outside this section; it receives the
# raw texts, the labels and the tokenizer.
train_dataset = FinancialSentimentDataset(X_train.tolist(), y_train.tolist(), tokenizer)
test_dataset = FinancialSentimentDataset(X_test.tolist(), y_test.tolist(), tokenizer)

In [None]:
# Load or train the model
model_path = './7_BERT/Results/1'
# Test for the saved config rather than the bare directory: the Trainer creates
# `output_dir` for its checkpoints, so the directory can exist (e.g. after an
# interrupted run) without containing a loadable model at its root.
if os.path.isfile(os.path.join(model_path, 'config.json')):
    model = BertForSequenceClassification.from_pretrained(model_path)
else:
    # Start from the pretrained encoder with a 3-way classification head.
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
    training_args = TrainingArguments(
        output_dir=model_path,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset
    )
    trainer.train()
    # Persist the final weights and config at the root of model_path so the
    # load branch above succeeds on the next run instead of retraining/failing.
    trainer.save_model(model_path)

In [None]:
# Evaluate the model
# The evaluation Trainer gets its own name so that a `trainer` created while
# training (together with its logged loss history) is not overwritten.
eval_trainer = Trainer(model=model)
results = eval_trainer.evaluate(test_dataset)
print(results)

In [None]:
# Plot the results
# `trainer` may be undefined (model loaded from disk) or may be an
# evaluation-only Trainer without training logs; in both cases there is
# nothing to plot, which is reported instead of raising.
try:
    log_history = trainer.state.log_history
except NameError:
    log_history = []

# log_history mixes several kinds of records (periodic training logs, eval
# metrics, the end-of-training summary); only records carrying both 'loss'
# and 'epoch' are training-loss points. Indexing every record blindly raises
# KeyError.
loss_logs = [entry for entry in log_history if 'loss' in entry and 'epoch' in entry]

sns.set(style='whitegrid')
if loss_logs:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(
        [entry['epoch'] for entry in loss_logs],
        [entry['loss'] for entry in loss_logs],
        label='Loss',
    )
    # Only the loss is available here, so the title no longer mentions accuracy.
    ax.set_title('Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Value')
    ax.legend()
    plt.show()
else:
    print('No training loss was logged in this session; nothing to plot.')