In [1]:
import torch
from torch.utils.data import DataLoader, random_split, Subset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from sklearn.metrics import accuracy_score, classification_report

from model import MultimodalModel, StockDataset

import pandas as pd

[nltk_data] Downloading package words to
[nltk_data]     C:\Users\bbala_n314ugx\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\bbala_n314ugx\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\bbala_n314ugx\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\bbala_n314ugx\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\bbala_n314ugx\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [2]:
def label(curr_close, next_close):
    """Tag the move between two consecutive closing prices.

    The tag is based on the relative change
    ``(next_close - curr_close) / curr_close``.

    Args:
        curr_close: Closing price of the current row.
        next_close: Closing price of the following row.

    Returns:
        'bearish' when the change is <= -1%, 'neutral' when it lies strictly
        between -1% and +1%, 'bullish' when it is >= +1%, and None when the
        change is undefined (either price is NaN, or ``curr_close`` is zero).
    """
    import math  # local import so the function does not rely on another cell

    # A zero base price makes the relative change undefined.
    if curr_close == 0:
        return None

    pct_delta = (next_close - curr_close) / curr_close

    # NaN fails every comparison below and would otherwise fall through to
    # 'bullish'; report it as "no label" so the caller's dropna() removes it.
    if math.isnan(pct_delta):
        return None

    if pct_delta <= -0.01:
        return 'bearish'
    if pct_delta < 0.01:
        return 'neutral'
    return 'bullish'
# Load the AAPL news/price table, order the rows by date, and index by date.
df = (
    pd.read_csv('data/stock_news_data_AAPL.csv', on_bad_lines='skip', low_memory=False)
    .sort_values(by=['date'])
    .set_index('date')
)

# Label every row from the move between its close and the next row's close.
# The last row has no successor, so it receives None and is removed by
# dropna() below (which also drops any other row holding a missing value).
close_prices = df['Stock Close'].to_numpy()
labels = [
    label(today_close, following_close)
    for today_close, following_close in zip(close_prices[:-1], close_prices[1:])
]
labels.append(None)
df['label'] = labels
df = df.dropna()


In [3]:
# Preview the first rows of the labelled frame.
df.head()

Unnamed: 0_level_0,Article Headline,Article URL,Article Text,overall_sentiment_score,overall_sentiment_label,Stock Open,Stock Close,Stock High,Stock Low,volume,num_trades,adj_close,label
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
2022-03-01,"US stocks fall, oil tops $105 as Ukraine crisi...",https://www.aljazeera.com/economy/2022/3/1/us-...,A surge in oil sent shivers through risky asse...,-0.27746,Somewhat-Bearish,164.695,163.2,166.6,161.97,79455454.0,701957.0,164.167482,bullish
2022-03-02,"Rich Russians turn to luxury jewellery, watche...",https://www.aljazeera.com/economy/2022/3/2/ric...,With sanctions on Russia sending the ruble plu...,-0.118323,Neutral,164.39,166.56,167.36,162.95,76135254.0,631927.0,165.810466,neutral
2022-03-03,Are You an Investor Needing Some Calm Guidance?,https://www.fool.com/investing/2022/03/03/are-...,Read this.,-0.04104,Neutral,168.47,166.23,168.91,165.55,73779442.0,622341.0,166.927454,bearish
2022-03-04,Marvell's ( MRVL ) Q4 Earnings and Revenues ...,https://www.zacks.com/stock/news/1877623/marve...,Marvell's (MRVL) Q4 top and bottom lines refle...,0.136708,Neutral,164.49,163.17,165.55,162.1,80761684.0,710586.0,163.402599,bearish
2022-03-07,EPAM Shares Continue to Fall on Ukraine Crisis...,https://www.zacks.com/stock/news/1878525/epam-...,EPAM Systems' (EPAM) share price has plunged s...,-0.046687,Neutral,163.36,159.3,165.02,159.04,92893526.0,803961.0,161.40379,bearish


In [4]:
# Wrap the labelled frame in the project's dataset class.
# NOTE(review): the meaning of `extraction_type` and `lookback` is defined in
# model.StockDataset, which is not visible here — confirm before changing them.
dataset = StockDataset(
    df,
    extraction_type='bert',
    lookback=3,
)

In [5]:
# Inspect the table stored on the dataset object (its `data` attribute).
dataset.data

Unnamed: 0,Article Headline,Article URL,Article Text,overall_sentiment_score,overall_sentiment_label,Stock Open,Stock Close,Stock High,Stock Low,volume,num_trades,adj_close,label,SMA,EMA,RSI,MACD,ATR
0,EXCLUSIVE from Bitcoin 2022: TradeZing CEO Jor...,https://www.benzinga.com/markets/cryptocurrenc...,"detroit-based benzinga , a medium and data pro...",-0.055666,Neutral,163.92,165.07,166.5984,163.57,68843424.0,574580.0,164.928106,bullish,166.920000,166.177407,27.659200,2.103217,4.206650
1,"China Lockdowns, Factory Closures Seen Hurting...",https://www.investors.com/news/technology/appl...,"apple stock : china lockdown , factory closure...",-0.051895,Neutral,165.02,167.40,167.8200,163.91,67566065.0,546660.0,166.548430,neutral,165.920000,166.788704,51.523573,1.789685,4.107767
2,EXCLUSIVE From Bitcoin 2022: How To Increase L...,https://www.benzinga.com/markets/cryptocurrenc...,"detroit-based benzinga , a medium and data pro...",-0.053154,Neutral,168.76,167.23,168.8800,166.10,67515353.0,590666.0,167.349497,neutral,166.566667,167.009352,49.728204,1.510084,3.665178
3,Garmin ( GRMN ) Strengthens Fitness Segment ...,https://www.zacks.com/stock/news/1904502/garmi...,garmin ( grmn ) expand fitness offering with t...,0.151760,Somewhat-Bullish,168.91,166.42,171.5300,165.91,86895494.0,748146.0,168.808299,bearish,167.016667,166.714676,39.813029,1.209199,4.316785
4,"SNAP Q1 Earnings Miss Estimates, User Growth A...",https://www.zacks.com/stock/news/1905492/snap-...,snap 's ( snap ) first-quarter result reflect ...,-0.051587,Neutral,166.46,161.79,167.8699,161.50,84550744.0,725782.0,164.231066,neutral,165.146667,164.252338,14.693530,0.590338,5.001157
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
249,Democratic presidential longshot Marianne Will...,https://www.marketwatch.com/story/democratic-p...,democrat be close rank behind president biden ...,0.107952,Neutral,164.59,165.21,166.3200,163.83,48280881.0,476387.0,165.025855,neutral,163.623333,164.262721,69.202059,2.973583,3.293847
250,Apple Offers High-Yield Savings Accounts in La...,https://www.barrons.com/articles/apple-card-sa...,apple card saving account be the company 's la...,0.114923,Neutral,165.09,165.23,165.3900,164.03,41531918.0,484303.0,164.804537,neutral,165.333333,164.746361,69.363014,2.963389,2.649232
251,Apple Card Savings Account Launched! Check min...,https://www.financialexpress.com/business/inve...,apple card saving account launch ! check minim...,0.254145,Somewhat-Bullish,166.10,166.47,167.4100,165.65,49948656.0,495451.0,166.353505,neutral,165.636667,165.608180,79.383367,3.020549,2.492821
252,Committed to investing across the country: App...,https://www.business-standard.com/india-news/c...,commit to invest across the country : apple ce...,0.424460,Bullish,165.80,167.63,168.1600,165.54,47848162.0,475096.0,167.278824,neutral,166.443333,166.619090,85.868847,3.123446,2.535214


In [6]:
# 80/20 train/validation split.
# Contiguous index ranges are used rather than random_split, so the first
# `train_size` samples train the model and the remaining ones validate it.
# TODO: padding/truncate here
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset = Subset(dataset, range(0, train_size))
val_dataset = Subset(dataset, range(train_size, train_size + val_size))

In [7]:
# Sanity check: the training subset holds exactly train_size samples.
len(train_dataset) == train_size

True

In [8]:
# Sanity check: the validation subset holds exactly val_size samples.
len(val_dataset)==val_size

True

In [9]:
# Show how many samples ended up in the training subset.
train_size

203

In [10]:
# Show how many samples ended up in the validation subset.
val_size

51

In [11]:
# Batch both subsets in their stored order (no shuffling), 16 samples at a time.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)

In [12]:
# Build the multimodal classifier with a GRU branch and BERT text features.
# NOTE(review): label() in this notebook emits three classes
# (bearish/neutral/bullish) while num_classes is 5 — confirm which target
# StockDataset actually yields.
model = MultimodalModel(
    hidden_size=64,
    num_classes=5,
    extraction_type='bert',
    lstm_or_gru='gru',
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
device  # last expression: displayed as the cell output

device(type='cpu')

In [14]:
# Cross-entropy loss over the class logits, optimised with Adam at lr = 1e-4.
criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [15]:
# Train the model and report validation metrics after every epoch.
#
# Relies on names defined in earlier cells: model, device, criterion,
# optimizer, train_loader, val_loader, and classification_report (imported in
# the first cell).


def _to_device(batch, device):
    """Move `batch` to `device`, recursing into dicts, lists and tuples.

    Anything exposing a callable `.to` (tensors and similar containers) is
    moved with it; leaves without `.to` (e.g. strings) are returned unchanged,
    so this is safe whatever structure the dataset yields.
    """
    if callable(getattr(batch, 'to', None)):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: _to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        moved = [_to_device(item, device) for item in batch]
        return tuple(moved) if isinstance(batch, tuple) else moved
    return batch


num_epochs = 10
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    for article_input, numerical_data, target in train_loader:
        # The model lives on `device`, so its inputs and targets must too.
        article_input = _to_device(article_input, device)
        numerical_data = _to_device(numerical_data, device)
        target = _to_device(target, device)

        # Clear gradients left over from the previous step before the new one.
        optimizer.zero_grad()

        # Forward pass
        outputs = model(article_input, numerical_data)

        # Compute loss
        loss = criterion(outputs, target)

        # Backward pass
        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    model.eval()
    total_correct = 0
    allpred, alltarg = [], []
    with torch.no_grad():
        for article_input, numerical_data, target in val_loader:
            article_input = _to_device(article_input, device)
            numerical_data = _to_device(numerical_data, device)
            target = _to_device(target, device)

            # Forward pass
            outputs = model(article_input, numerical_data)

            # Predicted class = index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)

            # Collect flat per-sample lists: predictions from the model and
            # ground truth from the loader (not re-derived from `outputs`).
            # assumes `target` is a tensor of class indices — TODO confirm
            allpred.extend(predicted.cpu().tolist())
            alltarg.extend(target.cpu().tolist())
            total_correct += (predicted == target).sum().item()

    # classification_report expects (y_true, y_pred) in that order.
    print(classification_report(alltarg, allpred))
    # Divide by the number of samples actually evaluated; guard the empty case.
    accuracy = total_correct / max(len(alltarg), 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.4f}')


KeyboardInterrupt: 