# Transformers and Attention 

**Goal 1: Implement a self-attention class that uses bi-directional attention**

In [1]:
# import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
class Attention(nn.Module):
    """
    Single-head scaled dot-product attention. Use it for self-attention by passing the same
    sequence as `primary_seq` and `context_seq` (with inp_embd_size == context_embd_size),
    or for cross-attention by passing two different sequences. Bi-directional vs.
    uni-directional attention is controlled by the mask passed to `forward`.
    
    Args (columns are tokens): 
        - X: sequence to find a contextual representation of, shape (d_x, l_x)
        - Z: sequence to attend to, shape (d_z, l_z)
    
    Returns: 
        - X_contextual: updated embeddings of the tokens in X, shape (out_embd_size, l_x),
          with information from the tokens in Z incorporated
    """
    
    def __init__(self, inp_embd_size, context_embd_size, attn_embd_size, out_embd_size, **kwargs):
        super().__init__()
        self.attn_embd_size = attn_embd_size
        self.out_embd_size = out_embd_size
        
        self.W_q = nn.Linear(in_features=inp_embd_size, out_features=attn_embd_size)      # query projection
        self.W_k = nn.Linear(in_features=context_embd_size, out_features=attn_embd_size)  # key projection
        self.W_v = nn.Linear(in_features=context_embd_size, out_features=out_embd_size)   # value projection
        
    def forward(self, primary_seq, context_seq, attn_mask, **kwargs):
        """
        Args: 
            - primary_seq: sequence to find contextual embeddings of, shape (inp_embd_size, inp_len)
            - context_seq: context sequence, shape (context_embd_size, context_len)
            - attn_mask: 1 = attend, 0 = mask. A (context_len,) mask hides the same context
              tokens from every query (bi-directional, e.g. padding); an (inp_len, context_len)
              mask is applied per query (e.g. torch.tril(...) for uni-directional attention)
        
        With rows as tokens:
        Q = X^T W_q,  K = Z^T W_k,  V = Z^T W_v
        X_contextual = (softmax(Q K^T / sqrt(attn_embd_size)) V)^T, softmax taken over each row
        """
        Q = self.W_q(primary_seq.T)  # (inp_len, attn_embd_size)
        K = self.W_k(context_seq.T)  # (context_len, attn_embd_size)
        V = self.W_v(context_seq.T)  # (context_len, out_embd_size)
        
        scores = Q @ K.T / self.attn_embd_size ** 0.5        # (inp_len, context_len)
        scores = scores.masked_fill(attn_mask == 0, -1e10)    # masked tokens get ~zero weight
        
        attn_weights = F.softmax(scores, dim=-1)  # normalize over context tokens: each row sums to 1
        contextual_embd = attn_weights @ V        # (inp_len, out_embd_size)
        
        return contextual_embd.T                  # (out_embd_size, inp_len)

In [7]:
inp_embd_size = 10
context_embd_size = 13
max_len = 5

sa = Attention(inp_embd_size=inp_embd_size,
               context_embd_size=context_embd_size,
               attn_embd_size=15,
               out_embd_size=25)

X = torch.rand((inp_embd_size, max_len))
Z = torch.rand((context_embd_size, 7))

In [8]:
sa(X, Z, attn_mask=torch.tensor([1, 1, 1, 1, 1, 0, 0]))

tensor([[-0.2675, -0.2833, -0.2504, -0.2697, -0.2771],
        [ 0.8851,  0.9592,  0.8076,  0.8960,  0.9251],
        [-0.1002, -0.1057, -0.0918, -0.1008, -0.1032],
        [ 0.2821,  0.3092,  0.2548,  0.2873,  0.2978],
        [-0.1855, -0.1995, -0.1699, -0.1888, -0.1943],
        [-0.0884, -0.0966, -0.0830, -0.0892, -0.0929],
        [ 0.8463,  0.9158,  0.7716,  0.8572,  0.8848],
        [-0.0792, -0.0911, -0.0689, -0.0811, -0.0864],
        [ 0.5071,  0.5524,  0.4594,  0.5125,  0.5315],
        [-0.4519, -0.4914, -0.4097, -0.4590, -0.4738],
        [-0.1769, -0.1870, -0.1658, -0.1780, -0.1817],
        [-0.7086, -0.7620, -0.6555, -0.7165, -0.7361],
        [ 0.3986,  0.4333,  0.3629,  0.4051,  0.4170],
        [-0.0143, -0.0163, -0.0122, -0.0135, -0.0138],
        [-0.1992, -0.2188, -0.1764, -0.2036, -0.2107],
        [ 0.4346,  0.4655,  0.4013,  0.4380,  0.4525],
        [-0.6479, -0.7024, -0.5956, -0.6568, -0.6778],
        [-0.1916, -0.2033, -0.1793, -0.1907, -0.1982],
        ...])
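The `Attention` class switches between bi-directional and uni-directional attention only through the mask. A minimal standalone sketch of the uni-directional (causal) case, assuming rows-as-tokens shapes: a lower-triangular mask from `torch.tril` lets token *i* attend only to tokens *0..i*. The function name `masked_attention` is illustrative and not part of the notebook.

```python
import torch
import torch.nn.functional as F


def masked_attention(Q, K, V, mask):
    """Scaled dot-product attention with rows as tokens.

    Q: (l_q, d), K: (l_k, d), V: (l_k, d_v); mask: (l_q, l_k) or (l_k,), 1 = attend, 0 = block.
    Returns the contextual embeddings (l_q, d_v) and the attention weights (l_q, l_k).
    """
    scores = Q @ K.T / Q.shape[-1] ** 0.5
    scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # each row (one query) sums to 1
    return weights @ V, weights


torch.manual_seed(0)
l, d = 5, 8
X = torch.rand(l, d)

# Bi-directional: every token may attend to every token
bi_mask = torch.ones(l)
# Uni-directional (causal): token i may attend to tokens 0..i only
causal_mask = torch.tril(torch.ones(l, l))

_, bi_weights = masked_attention(X, X, X, bi_mask)
_, causal_weights = masked_attention(X, X, X, causal_mask)

print(causal_weights)  # the upper triangle is exactly 0
```

The diagonal is never masked, so no row is entirely `-inf` and the softmax stays well defined. In the cross-attention setting the same idea applies to an `(l_x, l_z)` mask.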

# Adding Context to Short-Text Inputs

### 1: Sentiment Analysis with Contextual Information

Traditional sentiment analysis starts with labeled text data and fits a model to predict sentiment. Current state-of-the-art (SoTA) NLP models capture local contextual information by attending to different parts of the input. In a dynamic setting, however, the same input can flip from positive to negative sentiment depending on public perception: "Netflix releases earnings numbers of $1" reads as bad news if recent tweets expected $2, and as good news if they expected less. The goal of this post is to build a model that captures global contextual information and incorporates it into sentiment prediction.

- Input: BERT-encoded vector of a tweet
- Context: the $n$ most recent tweets before the input
- Cross-attention between the input and the context to find a contextual representation of the input sequence
- Classifier head on top of the contextual representation

**Goal 1:** Using cross-attention, create a contextual representation of a tweet whose context is the previous $n$ tweets.
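One way to assemble the (input, context) pairs described above, sketched with pandas under a few assumptions: tweets are grouped by ticker (`symbols`), ordered by their parsed `timestamp`, and the context of each tweet is the `n` tweets for the same ticker that immediately precede it, joined into one string. The helper name `build_context_pairs` and the toy rows are hypothetical.

```python
import pandas as pd


def build_context_pairs(df, n=3, sep=". "):
    """Pair each tweet with the n most recent earlier tweets about the same ticker.

    Assumes columns `text`, `symbols`, and a Twitter-style `timestamp` string such as
    "Wed Jul 18 21:33:26 +0000 2018".
    """
    df = df.copy()
    # Parse the timestamp so sorting is chronological rather than alphabetical
    df["ts"] = pd.to_datetime(df["timestamp"], format="%a %b %d %H:%M:%S %z %Y")
    df = df.sort_values("ts").reset_index(drop=True)

    contexts = []
    for _, group in df.groupby("symbols", sort=False):
        texts = group["text"].tolist()
        for i, idx in enumerate(group.index):
            previous = texts[max(0, i - n):i]  # up to n earlier tweets, oldest first
            contexts.append((idx, sep.join(previous)))

    context = pd.Series(dict(contexts), name="context")
    return df.join(context)[["ts", "symbols", "text", "context"]]


toy = pd.DataFrame({
    "text": ["NFLX beats estimates", "NFLX subscriber growth slows",
             "AAPL unveils new phone", "NFLX shares fall after hours"],
    "timestamp": ["Mon Jul 16 20:00:00 +0000 2018", "Tue Jul 17 09:00:00 +0000 2018",
                  "Tue Jul 17 10:00:00 +0000 2018", "Wed Jul 18 21:00:00 +0000 2018"],
    "symbols": ["NFLX", "NFLX", "AAPL", "NFLX"],
})
pairs = build_context_pairs(toy, n=2)
print(pairs[["text", "context"]])
```

The first tweet for each ticker gets an empty context string; whether to drop such rows or keep them is a design choice left open here.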


### 2: Contextual Company Embeddings using News

Can cross-attention be used to create contextual embeddings for company names using the news? 

**Goal 1:** Get a contextual representation of one company, where the context is the news on a given topic. 
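A minimal sketch of the idea, assuming the company name and each topic's news have already been encoded (random tensors stand in for BERT's 768-dimensional `last_hidden_state` here): the company's first token is the query, the news tokens for one topic are the keys and values, and the attended output is one contextual embedding of the company per topic. The topic names mirror the dataset's labels, but the token counts are made up; an untrained attention layer gives arbitrary embeddings, so the similarity printed at the end only shows the mechanics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 768

# Stand-ins for BERT outputs of shape (batch, seq_len, hidden); in practice these
# would come from bert_base(...).last_hidden_state
company = torch.randn(1, 3, hidden)    # e.g. "[CLS] apple [SEP]"
news_by_topic = {                      # news for each topic, different lengths
    "Earnings": torch.randn(1, 40, hidden),
    "Legal | Regulation": torch.randn(1, 25, hidden),
}

cross_attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=12, batch_first=True)
query = company[:, :1, :]  # first ([CLS]) token as a single query, shape (1, 1, hidden)

company_embd = {}
with torch.no_grad():
    for topic, news in news_by_topic.items():
        out, weights = cross_attention(query, news, news)  # out: (1, 1, hidden)
        company_embd[topic] = out.squeeze(0).squeeze(0)     # (hidden,)

sim = torch.cosine_similarity(company_embd["Earnings"], company_embd["Legal | Regulation"], dim=0)
print({t: e.shape for t, e in company_embd.items()}, float(sim))
```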


### Questions

- Can I train a model with an attention head on top of an LLM to make contextual classifications?
- How can an attention head be used to supplement short inputs? 
- How is adding context related to simply adding an additional input? Is there a difference?
- Should self-attention or cross-attention be used here? Intuitively, what would be the difference in the embeddings? My assumption is that cross-attention puts the weight on the context sequence rather than on the primary sequence, since the primary tokens only supply the queries while the keys and values come from the context. My ultimate goal is to add context to short-text inputs. 
- Compare self-attention vs. cross-attention for adding context to short-text inputs. How should the embeddings be evaluated? 
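One mechanical difference behind the self- vs. cross-attention question can be checked directly. In cross-attention the queries come from the input and the keys/values only from the context, so all of the attention weight lands on context tokens; in self-attention over the concatenated sequence `[input; context]`, part of the weight stays on the input's own tokens. The sketch below uses `nn.MultiheadAttention` with random stand-in embeddings (an assumption; the notebook would use BERT outputs) and measures how much of each input token's attention goes to the context.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, l_in, l_ctx = 64, 6, 20
inp = torch.randn(1, l_in, hidden)   # stand-in for the short input's token embeddings
ctx = torch.randn(1, l_ctx, hidden)  # stand-in for the context's token embeddings
attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)

with torch.no_grad():
    # Cross-attention: queries from the input, keys/values from the context only
    _, w_cross = attn(inp, ctx, ctx)   # (1, l_in, l_ctx)
    # Self-attention over the concatenated sequence [input; context]
    seq = torch.cat([inp, ctx], dim=1)
    _, w_self = attn(seq, seq, seq)    # (1, l_in + l_ctx, l_in + l_ctx)

# Share of each input token's attention that goes to context tokens
cross_share = w_cross.sum(dim=-1)                  # all ones by construction
self_share = w_self[:, :l_in, l_in:].sum(dim=-1)   # strictly below 1
print(cross_share, self_share)
```

In a standard Transformer block, a residual connection adds the query's input back to the attention output, which is one route by which the primary tokens still shape the result of cross-attention.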

In [3]:
import pandas as pd
from datasets import load_dataset
import re
import numpy as np

pd.set_option("display.max_rows", 5000)
pd.set_option("display.max_columns", 5000)
pd.set_option("max_colwidth", 5000)

In [33]:
news = pd.read_csv("/Users/jramkissoon/Documents/data/blog/financial-tweets/stockerbot-export.csv", 
                   on_bad_lines="warn")  # report and skip malformed rows (error_bad_lines was removed in pandas 2.0)
news = news.drop(columns=["id"])
print(news.shape)
news.head()

(28264, 7)




Skipping line 731: expected 8 fields, saw 13
Skipping line 2836: expected 8 fields, saw 15
Skipping line 3058: expected 8 fields, saw 12
Skipping line 3113: expected 8 fields, saw 12
Skipping line 3194: expected 8 fields, saw 17
Skipping line 3205: expected 8 fields, saw 17
Skipping line 3255: expected 8 fields, saw 17
Skipping line 3520: expected 8 fields, saw 17
Skipping line 4078: expected 8 fields, saw 17
Skipping line 4087: expected 8 fields, saw 17
Skipping line 4088: expected 8 fields, saw 17
Skipping line 4499: expected 8 fields, saw 12



Unnamed: 0,text,timestamp,source,symbols,company_names,url,verified
0,VIDEO: “I was in my office. I was minding my own business...” –David Solomon tells $GS interns how he learned he wa… https://t.co/QClAITywXV,Wed Jul 18 21:33:26 +0000 2018,GoldmanSachs,GS,The Goldman Sachs,https://twitter.com/i/web/status/1019696670777503745,True
1,The price of lumber $LB_F is down 22% since hitting its YTD highs. The Macy's $M turnaround is still happening.… https://t.co/XnKsV4De39,Wed Jul 18 22:22:47 +0000 2018,StockTwits,M,Macy's,https://twitter.com/i/web/status/1019709091038547968,True
2,Who says the American Dream is dead? https://t.co/CRgx19x7sA,Wed Jul 18 22:32:01 +0000 2018,TheStreet,AIG,American,https://buff.ly/2L3kmc4,True
3,Barry Silbert is extremely optimistic on bitcoin -- but predicts that 99% of new crypto entrants are “going to zero… https://t.co/mGMVo2cZgY,Wed Jul 18 22:52:52 +0000 2018,MarketWatch,BTC,Bitcoin,https://twitter.com/i/web/status/1019716662587740160,True
4,How satellites avoid attacks and space junk while circling the Earth https://t.co/aHzIV3Lqp5 #paid @Oracle https://t.co/kacpqZWiDJ,Wed Jul 18 23:00:01 +0000 2018,Forbes,ORCL,Oracle,http://on.forbes.com/6013DqDDU,True


In [45]:
news["text"] = news.text.apply(lambda x: re.sub(r"https://\S+", "", x))
news["RT"] = news.text.str.match(r"RT @.*:")  # True for retweets ("RT @user: ...")
print(news.RT.sum())
news.sample(10)

4500


Unnamed: 0,text,timestamp,source,symbols,company_names,url,verified,RT
10541,RT @Optionsonar1: $260000.00 of bearish unusual option activity detected for $SYF,Mon Jul 16 16:18:34 +0000 2018,teebizy,SYF,Synchrony Financial,http://optionsonar.com/unusual-option-activity/syf/latest-trades,False,True
287,$WLTW Earnings volatility concerns drive reinsurance buying Willis Towers Watson survey finds Nasdaq:WLTW,Mon Jul 09 13:11:54 +0000 2018,StockTexts,WLTW,Willis Towers Watson Public Limited Company,http://www.globenewswire.com/news-release/2018/07/09/1534686/0/en/Earnings-volatility-concerns-drive-reinsurance-buying-Willis-Towers-Watson-survey-finds.html,False,False
5443,New UOA detected for $MPC: 1356 $75.0 PUT options expiring 2018-08-17 traded for $4.9 bought on the ask,Fri Jul 13 19:44:01 +0000 2018,liveoptiondata,MPC,Marathon Petroleum Corporation,https://www.optionsonar.com/unusual-option-activity/MPC/latest-trades,False,False
2003,7/11 50D MA Watch List: $ADSK $MU $OIH $OKTA $BG $KEY $BA $JD $LNG $XME $FXE $GPS $GM $CME $C $EMR $LBTYA $SOGO…,Wed Jul 11 13:00:20 +0000 2018,TradeAcademyCo,JCI,Johnson Controls International plc,https://twitter.com/i/web/status/1017030833352241153,False,False
26406,$DWCH $EBAY $AXP $AA 5 Stocks Moving In Wednesday's After-Hours Session -,Wed Jul 18 21:21:04 +0000 2018,BenzingaMedia,AXP,American Express Company,http://tinyurl.com/ydaqpgxp,False,False
15161,$GG Goldcorp Inc. Option Order Flow Sentiment is 81.9% Bullish.,Tue Jul 17 17:45:42 +0000 2018,MC_OptionTrades,GG,Goldcorp Inc.,https://marketchameleon.com/Overview/GG/OptionOrderSentiment/,False,False
10450,Fluor Co. $NEW $FLR Expected to Announce Earnings of $0.69 Per Share,Mon Jul 16 15:57:55 +0000 2018,WatchlistN,FLR,Fluor Corporation,http://zpr.io/6XyTv,False,False
13769,$CERN $HSIC $PAYX Bearish MACD crossover,Tue Jul 17 13:01:26 +0000 2018,themaxpain,HSIC,Henry Schein,,False,False
1029,Stericycle Inc $SRCL Given Consensus Recommendation of “Hold” by Brokerages,Tue Jul 10 12:33:59 +0000 2018,TickerReport,SRCL,Stericycle,http://tickerreport.com/?p=3626453,False,False
6695,Credit Suisse Group Boosts Nordstrom $JWN Price Target to $52.00,Sat Jul 14 23:26:49 +0000 2018,MareaInformativ,JWN,Nordstrom,http://www.mareainformativa.com/?p=442513,False,False


In [53]:
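# keep original tweets only (retweets dropped); unrelated to the `verified` column
verified = news[~news.RT].reset_index(drop=True)
# note: `timestamp` is still a string ("Wed Jul 18 ..."), so this sort is alphabetical, not chronological
verified[verified.symbols == "NFLX"].sort_values("timestamp")[["text"]].head(10)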

Unnamed: 0,text
22253,Netflix $NFLX just released quarterly 10-Q. Quarterly net income increased from 127M to 462 Million!!! 400% ⬆️
22269,#NASDAQ MOST VOLUME $HMNY -9.09% [Volume: +82.41% ~ 81012800] $AMD -0.53% [Volume: -1.8% ~ 40776700] $MU +0.35…
22273,New article on $NFLX at Today I went over my bullish and bearish trade ideas from 6/18/18 a…
22346,NOW OFFERING 7 Day FREE Trial to options day trading team Room or $TWTR feed $FB $AAPL $NFLX $TSLA $AMZN $GOOGL
22415,When I try to get in on some hot $NFLX calls
22429,$MS expecting a $BAC like move when it reported earnings. $CPAH HOD didn't come have to exit before the rugpull.…
22432,Netflix To Bring Comedy To New Sirius XM Channel $NFLX $SIRI $SPOT $T $AMZN $DIS
22434,With past performance like this how can you not sign up for a Free 7-day trial to Winning…
22460,{VIDEO} Stock Analysis + Trade Ideas: $TSLA $NFLX $BIDU $BABA $GOOGL $DIS - click link to watch &gt;&gt;…
22502,Tech Sector Buzz: Netflix Earnings Amazon and the App Store $NFLX


In [55]:
input_ = "Netflix releases earnings numbers of $1"
context_ = '. '.join([
    "Netflix earnings expected to be $2",
    "Netflix not having a good quarter",
    "Netflix sees quarterly earnings at $1.2"
])

print(context_)

Netflix earnings expected to be $2. Netflix not having a good quarter. Netflix sees quarterly earnings at $1.2


#### Model

In [13]:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch optimizer
import pytorch_lightning as pl
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [14]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_base = BertModel.from_pretrained("bert-base-uncased")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class ContextualSentimentModel(nn.Module):
    def __init__(self, bert_base, num_labels):
        super().__init__()
        self.bert_base = bert_base
        hidden_size = bert_base.config.hidden_size  # 768 for bert-base-uncased
        # BERT returns (batch, seq_len, hidden), so the attention layer must use batch_first=True
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=12, batch_first=True)
        self.classifier = nn.Linear(hidden_size, num_labels)
        
    def forward(self, inputs, context):
        """
        Args: 
            - inputs: dictionary with keys corresponding to outputs of tokenizer(...)
            - context: dictionary with keys corresponding to outputs of tokenizer(...)
        Returns:
            - logits of shape (batch, num_labels)
        """
        # input embedding: (batch, input_len, hidden)
        input_embd = self.bert_base(input_ids=inputs["input_ids"], 
                                    attention_mask=inputs["attention_mask"]).last_hidden_state
        
        # context embedding: (batch, context_len, hidden)
        context_embd = self.bert_base(input_ids=context["input_ids"], 
                                      attention_mask=context["attention_mask"]).last_hidden_state
        
        # Cross-attention: the input's [CLS] token (position 0) is the single query.
        # (The last position is [SEP], or a padding token in a padded batch.)
        query = input_embd[:, 0, :].unsqueeze(1)  # (batch, 1, hidden)
        # key_padding_mask comes from the *context* and is True where a token should be ignored
        attn_output, _ = self.cross_attention(query, context_embd, context_embd,
                                              key_padding_mask=context["attention_mask"] == 0)
        
        # Pass the attended output through the classification head
        logits = self.classifier(attn_output.squeeze(1))  # (batch, num_labels)
        return logits

In [68]:
inputs = tokenizer(input_, return_tensors="pt")
context = tokenizer(context_, return_tensors="pt")

input_encoding = bert_base(**inputs)
context_encoding = bert_base(**context)

input_embd = input_encoding.last_hidden_state
context_embd = context_encoding.last_hidden_state

cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12)

In [71]:
query = input_embd[:, -1, :].unsqueeze(1)  # Use the last token as the query
key = context_embd

cross_attention(query, key, key) # key_padding_mask=(1 - inputs["attention_mask"]))

RuntimeError: shape '[1, 12, 64]' is invalid for input of size 19968

In [73]:
print(context_embd.shape)
print(query.shape)

torch.Size([1, 26, 768])
torch.Size([1, 1, 768])
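The error comes from the default layout of `nn.MultiheadAttention`, which expects `(seq_len, batch, embed_dim)` unless `batch_first=True`. With the shapes printed above, the query `(1, 1, 768)` is read as sequence length 1 and batch size 1, while the key `(1, 26, 768)` is read as sequence length 1 with batch size 26; reshaping the key's 26 × 768 = 19968 values into `[1, 12, 64]` (one position, 12 heads, 64 dimensions per head) then fails. A sketch of the fix, with random tensors of the same shapes standing in for the BERT outputs (the input length of 9 is arbitrary). It also uses the `[CLS]` token as the query, an assumption in place of the last position, which is `[SEP]` or padding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_embd = torch.randn(1, 9, 768)     # stand-in for BERT output on the input (batch, seq, hidden)
context_embd = torch.randn(1, 26, 768)  # stand-in for BERT output on the context
context_mask = torch.ones(1, 26, dtype=torch.long)  # tokenizer-style attention_mask (1 = real token)

cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

query = input_embd[:, 0, :].unsqueeze(1)  # [CLS] token as the query, shape (1, 1, 768)
attn_output, attn_weights = cross_attention(
    query, context_embd, context_embd,
    key_padding_mask=context_mask == 0,   # True marks context positions to ignore
)
print(attn_output.shape, attn_weights.shape)
# torch.Size([1, 1, 768]) torch.Size([1, 1, 26])
```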

In [None]:
class SentimentModelLightning(pl.LightningModule):
    def __init__(self, model, learning_rate=2e-5, num_training_steps=None, num_warmup_steps=0):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.num_training_steps = num_training_steps  # must be set before training (used by the scheduler)
        self.num_warmup_steps = num_warmup_steps
        
    def forward(self, inputs, context):
        return self.model(inputs, context)
    
    # Assumes each batch is (inputs, context, labels), where inputs/context are
    # tokenizer-output dicts, matching ContextualSentimentModel.forward
    def training_step(self, batch, batch_idx):
        inputs, context, labels = batch
        logits = self(inputs, context)
        loss = F.cross_entropy(logits, labels)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, context, labels = batch
        logits = self(inputs, context)
        loss = F.cross_entropy(logits, labels)
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        # on_epoch=True averages over the validation epoch
        # (replaces validation_epoch_end, which Lightning 2.x removed)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_accuracy", accuracy, on_epoch=True)
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps,
        )
        # Step the linear warmup/decay schedule after every optimizer step
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

In [86]:
# pd.Series(dataset["train"]["label"]).value_counts()
# `dataset` and `topics` (label id -> topic name) come from a cell not shown here;
# `dataset` is presumably loaded with `load_dataset`
data = dataset["train"].to_pandas()
data["topic"] = data.label.apply(lambda x: topics[x])

# clean text: drop the $AAPL ticker, links, hashtags, and @-mentions
data["text"] = data.text.apply(lambda x: x.replace("$AAPL", ""))
data["text"] = data.text.apply(lambda x: re.sub(r"https://\S+", "", x))
data["text"] = data.text.apply(lambda x: re.sub(r"#\w+", "", x))
data["text"] = data.text.apply(lambda x: re.sub(r"@\w+", "", x))

data.head()

Unnamed: 0,text,label,topic
0,"Here are Thursday's biggest analyst calls: Apple, Amazon, Tesla, Palantir, DocuSign, Exxon &amp; more",0,Analyst Update
1,"Buy Las Vegas Sands as travel to Singapore builds, Wells Fargo says",0,Analyst Update
2,"Piper Sandler downgrades DocuSign to sell, citing elevated risks amid CEO transition",0,Analyst Update
3,"Analysts react to Tesla's latest earnings, break down what's next for electric car maker",0,Analyst Update
4,"Netflix and its peers are set for a ‘return to growth,’ analysts say, giving one stock 120% upside",0,Analyst Update
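The four cleaning steps above can also be written with pandas' vectorized `Series.str.replace`. A small sketch on a made-up tweet, using the same patterns (the ticker `$AAPL` is removed literally, then links, hashtags, and mentions), plus one extra step that is not in the original: collapsing the leftover whitespace. The helper name `clean_tweets` is hypothetical.

```python
import pandas as pd


def clean_tweets(text: pd.Series) -> pd.Series:
    return (
        text.str.replace("$AAPL", "", regex=False)        # literal ticker ($ is not a regex here)
            .str.replace(r"https://\S+", "", regex=True)   # links
            .str.replace(r"#\w+", "", regex=True)          # hashtags
            .str.replace(r"@\w+", "", regex=True)          # mentions
            .str.replace(r"\s+", " ", regex=True)          # extra step: collapse whitespace
            .str.strip()
    )


sample = pd.Series(["@trader $AAPL hits a record high #stocks https://t.co/abc123"])
print(clean_tweets(sample).tolist())  # ['hits a record high']
```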


In [87]:
raw_apple = data[data.text.str.contains("apple", flags=re.IGNORECASE)]
raw_apple = raw_apple[~raw_apple.topic.isin([
    "Fed | Central Banks", "Dividend", "Earnings", "Treasuries | Corporate Debt"
])]
raw_apple.head()

Unnamed: 0,text,label,topic
0,"Here are Thursday's biggest analyst calls: Apple, Amazon, Tesla, Palantir, DocuSign, Exxon &amp; more",0,Analyst Update
13,"Here are Tuesday's biggest analyst calls: Meta, Chipotle, Apple, Tesla, Exxon, Netflix, Sunrun &amp; more",0,Analyst Update
20,"Apple's near-term future looks murky as consumer spending slows, Bernstein says",0,Analyst Update
22,"Here are Monday's biggest analyst calls of the day: Tesla, Apple, Yum, Delta, Fox, Netflix &amp; more",0,Analyst Update
35,"Here are Thursday's biggest analyst calls: Tesla, Amazon, Twitter, Qualcomm, Costco, Apple &amp; more",0,Analyst Update
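`str.contains("apple", flags=re.IGNORECASE)` is a substring match, so it also keeps tweets containing words such as "grappled" (the Macro row in `apple_train` below is one). A word-boundary pattern restricts the match to the whole word; a quick check on made-up strings:

```python
import re
import pandas as pd

texts = pd.Series([
    "Apple unveils a new iPhone",
    "Americans grappled with rising prices",
    "Analysts cut their apple price targets",
    "Pineapple futures rally",
])

substring = texts.str.contains("apple", flags=re.IGNORECASE)
whole_word = texts.str.contains(r"\bapple\b", flags=re.IGNORECASE)
print(texts[substring].tolist())   # all four strings
print(texts[whole_word].tolist())  # only the two that mention Apple as a word
```

Possessives such as "Apple's" still match, since the apostrophe is a word boundary. Switching `raw_apple` to the whole-word pattern would likely change the topic counts shown below.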


In [88]:
raw_apple.topic.value_counts()

Company | Product News    66
General News | Opinion    23
Stock Commentary          15
Analyst Update            12
Markets                   11
Macro                      9
Legal | Regulation         8
Stock Movement             6
Name: topic, dtype: int64

In [94]:
apple_train = pd.concat([
    raw_apple.groupby("topic")["text"].apply(lambda x: ". ".join(np.random.choice(x, size=2))).reset_index()
    for _ in range(100)
]).reset_index(drop=True)

apple_train["input"] = "Apple"
apple_train.head()

Unnamed: 0,topic,text,input
0,Analyst Update,"- Apple lowering trade-in prices, implies 'strong demand,' BofA says . More analysts covering Apple are cutting their share-price forecasts, signaling growing concerns about an economic slowdown that could hurt the sales of its products",Apple
1,Company | Product News,Apple plans to slow hiring and spending growth next year in some divisions to cope with a potential economic downturn . Apple isn't planning to backfill roles or add new staff on certain teams,Apple
2,General News | Opinion,"🎧Calorie counts are everywhere. But calories aren’t all that they appear to be. In the series premiere of Losing it, dives into how we got the calorie so wrong ▶️ Apple: ▶️ Spotify: . Good point but the Apple ecosystem is self reinforcing. I don't think that's the case with Tesla. Each EV is a stand alone product, and some are better than Tesla. Mercedes has better range, the EV Mustang is a much better value, etc. Buyers switch automakers all the time.",Apple
3,Legal | Regulation,": Apple sued by payment card issuers, alleging antitrust competition issues over Apple Pay policies . : Apple sued by payment card issuers, alleging antitrust competition issues over Apple Pay policies",Apple
4,Macro,The White House expects June’s consumer price index figures to be “highly elevated” as Americans grappled with substantial increases in the cost of gas and food . The White House expects June’s consumer price index figures to be “highly elevated” as Americans grappled with substantial increases in the cost of gas and food,Apple
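`np.random.choice` samples with replacement by default, which is why the same headline appears twice in one row (the Legal | Regulation and Macro rows above). A sketch of the same per-topic construction using `Series.sample` without replacement and a per-round seed for reproducibility; the helper name `sample_topic_contexts`, the toy data, and the seeding scheme are assumptions.

```python
import pandas as pd


def sample_topic_contexts(df, n_per_topic=2, n_rounds=100, seed=0, sep=". "):
    """For each round and topic, join n_per_topic distinct texts from that topic."""
    rows = []
    for r in range(n_rounds):
        for topic, texts in df.groupby("topic")["text"]:
            k = min(n_per_topic, len(texts))  # topics with fewer texts use all of them
            picked = texts.sample(n=k, replace=False, random_state=seed + r)
            rows.append({"topic": topic, "text": sep.join(picked)})
    return pd.DataFrame(rows)


toy = pd.DataFrame({
    "topic": ["Macro", "Macro", "Macro", "Legal | Regulation", "Legal | Regulation"],
    "text": ["CPI rises", "Jobs report beats", "Fed holds rates", "Apple sued", "EU fines Apple"],
})
contexts = sample_topic_contexts(toy, n_rounds=3)
print(contexts)
```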
