In [1]:
import warnings
import string
import joblib
import multiprocessing
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from collections import defaultdict
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer
from transformers import BertModel
from torch.nn import functional as F

warnings.filterwarnings("ignore")

I0708 14:37:59.328766  4416 file_utils.py:39] PyTorch version 1.2.0 available.


In [2]:
class SentenceBert():
    """
    A common approach to zero shot learning using Sentence-BERT.
    Reference from https://joeddav.github.io/blog/2020/05/29/ZSL.html
    """
    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
        self.model = AutoModel.from_pretrained('deepset/sentence_bert')
        self.model = self.model.to(self.device)
        
    def get_similarity(self, sentence, labels):
        """
        Parameters:
            sentence: str
            label: list
        """
        # Run inputs through model and mean-pool over the sequence dimension to get sequence-level representations
        inputs = self.tokenizer.batch_encode_plus(
            [sentence] + labels,
            return_tensors='pt',
            pad_to_max_length=True)
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        with torch.no_grad():
            output = self.model(input_ids, attention_mask=attention_mask)[0]
        sentence_rep = output[:1].mean(dim=1)
        label_reps = output[1:].mean(dim=1)
    
        # Now find the labels with the highest cosine similarities to the sentence
        similarities = F.cosine_similarity(sentence_rep, label_reps)
        closest = similarities.argsort(descending=True)
        
        sim_dict = defaultdict()
        for ind in closest:
            sim_dict[labels[ind]] = (similarities[ind].item())
            
        return sim_dict

In [3]:
df = joblib.load("reuters_news.joblib")

In [4]:
labels = ['forex', 'finance', 'stocks']
SB = SentenceBert()
for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    sim_dict = SB.get_similarity(row["title"], labels)
    for i in range(len(labels)):   
        df.loc[index, labels[i]] = sim_dict[labels[i]]

I0708 14:38:02.867355  4416 configuration_utils.py:285] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/deepset/sentence_bert/config.json from cache at C:\Users\YangWang/.cache\torch\transformers\1a89cc2d2dfc0e9beb7d49442e942849b45bbbf49d0b004f6be414b44c4e01fa.f9ed8a8332fc340fb779f9e83f1745369bbe51a3ac3e1c0d0dc5c3cf72ef4626
I0708 14:38:02.871345  4416 configuration_utils.py:321] Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

I0708 14:38:02.873347  4416 tokenization_utils.py:929] Model name 'deepset/sentence_bert' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [5]:
df = df.sort_values(by="finance", axis=0, ascending=False)
df = df.reset_index(drop=True)
df.head()

Unnamed: 0,title,date,query,forex,finance,stocks
0,"Beware of debt costs and an inflationary bite,...",2020-07-05 19:00:00-08:00,google,0.064136,0.498809,0.287072
1,"Beware of debt costs and an inflationary bite,...",2020-07-05 19:17:00-08:00,google,0.064136,0.498809,0.287072
2,Deals of the day-Mergers and acquisitions,2020-07-07 16:00:00-08:00,google,0.157465,0.358413,0.32864
3,Zoom rolls out hardware subscription service,2020-07-07 10:48:00-08:00,google,0.101178,0.283011,0.171321
4,Zoom rolls out hardware subscription service,2020-07-07 09:58:00-08:00,google,0.101178,0.283011,0.171321
