
# FinBERT Example Notebook

This notebook shows how to train and use the FinBERT pre-trained language model for financial sentiment analysis. Here FinBERT is fine-tuned further on r/wallstreetbets Reddit post titles so that it performs sentiment analysis in that domain.



## Modules 

In [None]:
%%time
!git init
!git remote add origin https://github.com/ProsusAI/finBERT
!git pull origin master
!pip install pandas numpy matplotlib transformers textblob tqdm joblib scikit-learn spacy nltk torch==1.1.0

Initialized empty Git repository in /content/.git/
remote: Enumerating objects: 123, done.
remote: Total 123 (delta 0), reused 0 (delta 0), pack-reused 123
Receiving objects: 100% (123/123), 60.64 KiB | 4.33 MiB/s, done.
Resolving deltas: 100% (53/53), done.
From https://github.com/ProsusAI/finBERT
 * branch            master     -> FETCH_HEAD
 * [new branch]      master     -> origin/master
Collecting transformers
  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
Collecting torch==1.1.0
  Downloading https://files.pythonhosted.org/packages/ac/23/a4b5c189dd624411ec84613b717594a00480282b949e3448d189c4aa4e47/torch-1.1.0-cp37-cp37m-manylinux1_x86_64.whl (676.9MB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhost

In [None]:
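%%time
%load_ext autoreload
%autoreload 2
from pathlib import Path
import shutil
import os
import logging
import sys
import re
from pprint import pprint
sys.path.append('..')

import numpy as np
import pandas as pd
import torch
from textblob import TextBlob
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# finbert.finbert provides FinBert and Config; its star import also exposes the helpers
# used later in this notebook (InputExample, convert_examples_to_features, chunks, softmax)
from finbert.finbert import *
import finbert.utils as tools
from functools import partial
from tqdm.auto import tqdm

run_debug = False    # True: work on a small subset of the data
use_sampling = True  # True: oversample minority classes in the training set

project_dir = Path.cwd()
pd.set_option('display.max_colwidth', None)  # show full text (-1 is deprecated in pandas >= 1.0)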

CPU times: user 2.15 s, sys: 390 ms, total: 2.54 s
Wall time: 3.19 s




In [None]:
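# Note: basicConfig() does nothing if the root logger already has handlers
# (e.g. set up by an imported module), so INFO messages may still be printed.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.ERROR)
print(project_dir)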

/content
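If an imported module has already configured the root logger, `logging.basicConfig()` is a no-op (Python 3.8+ accepts `force=True` to override this), which would explain why INFO lines still appear below despite `level=logging.ERROR`. A minimal sketch of silencing the noisy loggers one by one instead; the logger names are taken from the log lines printed in this notebook, and the helper `quiet_loggers` is hypothetical.

```python
import logging

def quiet_loggers(names, level=logging.ERROR):
    """Raise the threshold of each named logger (hypothetical helper).

    Unlike logging.basicConfig(), this takes effect even when the root
    logger already has handlers attached.
    """
    for name in names:
        logging.getLogger(name).setLevel(level)

# Logger names as they appear in the INFO lines of this notebook's output.
quiet_loggers(["finbert.finbert", "finbert.utils", "filelock"])
```

Passing `logging.INFO` instead restores the messages.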


## Prepare the model

### Setting path variables
1. `lm_path`: the path of the further pre-trained language model (not needed if vanilla BERT is used).
2. `cl_path`: the directory where the classification model is saved.
3. `cl_data_path`: the directory containing the data files `train.csv`, `validation.csv` and `test.csv`.
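Before training, it can help to confirm that `cl_data_path` really contains the three files listed above. A minimal sketch; the helper name `missing_data_files` is hypothetical and not part of finBERT.

```python
import tempfile
from pathlib import Path

REQUIRED_FILES = ("train.csv", "validation.csv", "test.csv")

def missing_data_files(data_dir):
    """Return the required data files that are absent from data_dir (hypothetical helper)."""
    data_dir = Path(data_dir)
    return [name for name in REQUIRED_FILES if not (data_dir / name).is_file()]

# Demo on a throwaway directory that only contains train.csv.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "train.csv").write_text("text\tlabel\n")
    demo_missing = missing_data_files(tmp)

print(demo_missing)  # ['validation.csv', 'test.csv']
```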
---

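When initializing `bertmodel`, we can either use Google's original pre-trained weights (`bm = 'bert-base-uncased'`) or our further pre-trained language model (`bm = lm_path`). The training cell below instead loads the `ProsusAI/finbert` checkpoint from the Hugging Face Hub, so `lm_path` is not used.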


---
All model and training settings are controlled through the `config` variable.

In [None]:
lm_path = project_dir/'models'/'language_model'/'finbertTRC2'  # not used in this notebook
cl_path = project_dir/'models'/'classifier_model_1'/'finbert-sentiment'
cl_data_path = project_dir/'data'/'sentiment_data'

### Configuring training parameters

In [None]:
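%%time
# Start from an empty output directory (no error if it does not exist yet)
shutil.rmtree(cl_path, ignore_errors=True)

# Start from the FinBERT checkpoint that is already fine-tuned for 3-class financial sentiment
bertmodel = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", cache_dir=None, num_labels=3)


config = Config(data_dir=cl_data_path,
                bert_model=bertmodel,
                num_train_epochs=5,
                model_dir=cl_path,
                max_seq_length=48,        # ids per example, including [CLS] and [SEP]
                train_batch_size=64,
                learning_rate=2e-5,
                output_mode='classification',
                warm_up_proportion=0.2,   # proportion of training steps used for LR warm-up
                local_rank=-1,            # -1: no distributed training
                discriminate=True,        # layer-wise (discriminative) learning rates
                gradual_unfreeze=False)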


CPU times: user 10.1 s, sys: 1.48 s, total: 11.6 s
Wall time: 20.3 s
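The number of optimisation steps implied by these settings can be worked out by hand. After oversampling (see the data section) the training set has 51,705 rows, so a batch size of 64 gives ceil(51705 / 64) = 808 steps per epoch and 4,040 steps over 5 epochs. The sketch below assumes that `warm_up_proportion` is a fraction of the total number of steps and that the learning rate follows a linear warm-up then linear decay (the shape of `get_linear_schedule_with_warmup` in `transformers`); both are assumptions about how finBERT uses the value, not something this notebook confirms.

```python
import math

def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step` for linear warm-up followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

num_examples, batch_size, epochs = 51705, 64, 5
steps_per_epoch = math.ceil(num_examples / batch_size)  # 808
total_steps = steps_per_epoch * epochs                   # 4040
warmup_steps = round(0.2 * total_steps)                  # 808, i.e. the whole first epoch

print(steps_per_epoch, total_steps, warmup_steps)
print(linear_warmup_lr(404, 2e-5, warmup_steps, total_steps))  # half of the peak rate
```

With these numbers the warm-up covers exactly the first epoch.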


`FinBert` is the main class that encapsulates all the functionality; the instance created below is named `finbert`. The list of class labels is passed to the `prepare_model` method through the `label_list` parameter.
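The order of `label_list` fixes the integer id of each class: the training log reports `label: positive (id = 0)`, consistent with ids being assigned by position. A minimal sketch of that mapping in plain Python (an assumption about the internal representation, not finBERT's code). The same order must be used when decoding predictions, as `label_dict` in `predict_from_model` does.

```python
label_list = ["positive", "negative", "neutral"]

# Position in label_list -> integer class id, plus the reverse lookup.
label_to_id = {label: i for i, label in enumerate(label_list)}
id_to_label = {i: label for label, i in label_to_id.items()}

print(label_to_id)  # {'positive': 0, 'negative': 1, 'neutral': 2}
```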

In [None]:
finbert = FinBert(config)
finbert.base_model = "ProsusAI/finbert"
finbert.config.discriminate=True
finbert.config.output_mode='classification'

In [None]:
finbert.prepare_model(label_list=['positive','negative','neutral'])

05/09/2021 11:59:07 - INFO - finbert.finbert -   device: cuda n_gpu: 1, distributed training: False, 16-bits training: False

## Loading data

In [None]:
data_loc = "https://raw.githubusercontent.com/santi-buch/NLP/main/train%20(1)%20(2).csv"
test_loc = "https://raw.githubusercontent.com/santi-buch/NLP/main/test%20(1).csv"
df = pd.read_csv(data_loc)
testdf = pd.read_csv(test_loc)
if run_debug:
    df = df.head(2000)
    testdf = testdf.head(200)

def process_dataframe(df):
    # Map the comp_score codes ('pos', 'neg', anything else) to label names
    # df['label'] = df['compound'].apply(lambda x: "positive" if x > 0 else "negative" if x < 0 else "neutral")
    df['label'] = df['comp_score'].apply(lambda x: "positive" if x == 'pos' else "negative" if x == "neg" else "neutral")

    def deEmojify(text):
        # Remove emoji in the ranges below and lower-case the text
        regex_pattern = re.compile(pattern="["
                                   u"\U0001F600-\U0001F64F"  # emoticons
                                   u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                                   u"\U0001F680-\U0001F6FF"  # transport & map symbols
                                   u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                                   "]+", flags=re.UNICODE)
        return regex_pattern.sub(r'', text).lower()

    df['text'] = df['Title'].apply(deEmojify)
    return df[['text', 'label']]

df = process_dataframe(df)
test = process_dataframe(testdf)
n_classes = df['label'].nunique()
# Label distribution (%) in the training and test data
print(df['label'].value_counts() * 100 / df.shape[0])
print(test['label'].value_counts() * 100 / test.shape[0])
df.head()

05/09/2021 11:59:11 - INFO - numexpr.utils -   NumExpr defaulting to 2 threads.


positive    53.995376
neutral     31.672588
negative    14.332036
Name: label, dtype: float64
positive    52.566894
neutral     33.560749
negative    13.872357
Name: label, dtype: float64


|   | text | label |
|---|------|-------|
| 0 | $sndl to the moon | positive |
| 1 | money $gme | neutral |
| 2 | $gtt - shady, solid business with a potential squeeze. take a look if interested! | positive |
| 3 | we will get back up soon hodl $gme | positive |
| 4 | i'm gonna be the devils advocate. convince me why $gme is still a thing. | negative |
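`process_dataframe` maps the `comp_score` codes to label names and strips emoji from the title. The regex only covers four Unicode blocks, so some emoji common on r/wallstreetbets survive: for example the gorilla emoji (U+1F98D) lies in the Supplemental Symbols and Pictographs block, outside every listed range, while the gem stone (U+1F48E) and raised hands (U+1F64C) are removed. A stdlib-only sketch of the same two steps on single values (the function names are illustrative):

```python
import re

# Same four ranges as deEmojify above.
EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)

def strip_emoji(text):
    """Remove emoji in the ranges above and lower-case the result."""
    return EMOJI_PATTERN.sub("", text).lower()

def map_label(comp_score):
    """Map a comp_score code to a label name; anything other than 'pos'/'neg' is neutral."""
    return {"pos": "positive", "neg": "negative"}.get(comp_score, "neutral")

# Gem stone + raised hands are stripped, the gorilla emoji is kept.
print(repr(strip_emoji("HODL \U0001F48E\U0001F64C \U0001F98D")))
print(map_label("pos"), map_label("neg"), map_label("anything else"))  # positive negative neutral
```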


In [None]:
# Stratified train/validation split; the test set comes from a separate file
train, valid = train_test_split(df, test_size=0.1, stratify=df['label'])
print(train.shape, valid.shape, test.shape)
print(train['label'].value_counts()*100/train.shape[0], valid['label'].value_counts()*100/valid.shape[0], test['label'].value_counts()*100/test.shape[0])
if use_sampling:
    # Random oversampling: draw extra rows (with replacement) from each class
    # until every class is as large as the biggest one
    print(">>Label distribution before:", df['label'].value_counts()*100/df.shape[0])
    counts = train['label'].value_counts()
    max_sample_val = counts.max()
    for l in counts.index:
        sampled = train[train['label'] == l].sample(max_sample_val - counts[l], replace=True)
        train = pd.concat([train, sampled], axis=0)
    print(">>Label distribution after:", train['label'].value_counts()*100/train.shape[0], train.shape)

(31919, 2) (3547, 2) (20667, 2)
positive    53.996053
neutral     31.670792
negative    14.333156
Name: label, dtype: float64 positive    53.989287
neutral     31.688751
negative    14.321962
Name: label, dtype: float64 positive    52.566894
neutral     33.560749
negative    13.872357
Name: label, dtype: float64
>>Label distribution before: positive    53.995376
neutral     31.672588
negative    14.332036
Name: label, dtype: float64
>>Label distribution after: neutral     33.333333
negative    33.333333
positive    33.333333
Name: label, dtype: float64 (51705, 2)
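The printed percentages let you check the final size by hand: positive is 53.996% of 31,919 training rows, i.e. 17,235 rows (neutral 10,109, negative 4,575). Once every class is raised to that count, the balanced set has 3 × 17,235 = 51,705 rows, exactly the shape printed above. A stdlib sketch of the same random-oversampling idea on plain `(text, label)` tuples; the function name and the fixed seed are illustrative choices, not part of the notebook.

```python
import random
from collections import Counter

def oversample(rows, seed=0):
    """Duplicate randomly chosen minority-class rows until every label is as frequent as the largest.

    rows: list of (text, label) tuples. Returns a new list; the input is not modified.
    """
    rng = random.Random(seed)
    counts = Counter(label for _, label in rows)
    target = max(counts.values())
    balanced = list(rows)
    for label, count in counts.items():
        pool = [row for row in rows if row[1] == label]
        balanced.extend(rng.choices(pool, k=target - count))  # sampling with replacement
    return balanced

demo = [("a", "positive")] * 5 + [("b", "neutral")] * 3 + [("c", "negative")]
print(Counter(label for _, label in oversample(demo)))

# Class sizes from the run above: 17235 positive, 10109 neutral, 4575 negative.
print(3 * 17235)  # 51705
```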


In [None]:
# Write the splits as tab-separated files into cl_data_path, where FinBert.get_data() reads them
folder = str(cl_data_path) + '/'
print(">> ", folder)
Path(folder).mkdir(parents=True, exist_ok=True)
train.to_csv(f'{folder}train.csv', index=True, sep='\t')
valid.to_csv(f'{folder}validation.csv', index=True, sep='\t')
test.to_csv(f'{folder}test.csv', index=True, sep='\t')

>>  /content/data/sentiment_data/
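The files keep the `.csv` extension but are written tab-separated (`sep='\t'`). Titles often contain commas (for example `$gtt - shady, solid business...` in the preview above), and with a tab delimiter such text stays in a single field without any quoting. A small stdlib round trip illustrating this, using a temporary file rather than the notebook's data directory:

```python
import csv
import tempfile
from pathlib import Path

rows = [
    ("0", "$gtt - shady, solid business with a potential squeeze.", "positive"),
    ("1", "money $gme", "neutral"),
]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "train.csv"  # .csv name, tab-separated content
    with path.open("w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        # Empty first header cell, as DataFrame.to_csv(index=True) writes for an unnamed index
        writer.writerow(["", "text", "label"])
        writer.writerows(rows)
    raw_first_row = path.read_text().splitlines()[1]
    with path.open(newline="") as f:
        read_back = [tuple(r) for r in csv.reader(f, delimiter="\t")][1:]

print(raw_first_row.split("\t"))
```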


## Fine-tune the model

In [None]:
# Load the training examples from cl_data_path/train.csv
train_data = finbert.get_data('train')
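During training each example becomes fixed-length features of `max_seq_length = 48` ids, as in the first training example in the log: 9 real ids (`[CLS]` ... `[SEP]`) followed by zero padding, an `attention_mask` of 1 for real tokens and 0 for padding, and `token_type_ids` that are all 0 for single-sentence input. A minimal sketch reproducing that layout from the logged ids; the helper `pad_features` is hypothetical and is not finBERT's `convert_examples_to_features`.

```python
def pad_features(input_ids, max_seq_length):
    """Pad token ids with 0 up to max_seq_length and build the matching masks (hypothetical helper)."""
    if len(input_ids) > max_seq_length:
        raise ValueError("sequence is longer than max_seq_length; truncate it first")
    padding = [0] * (max_seq_length - len(input_ids))
    attention_mask = [1] * len(input_ids) + padding
    token_type_ids = [0] * max_seq_length  # single-segment input
    return input_ids + padding, attention_mask, token_type_ids

# ids for "[CLS] $ bb strong company with great potential [SEP]" from the training log
ids = [101, 1002, 22861, 2844, 2194, 2007, 2307, 4022, 102]
input_ids, attention_mask, token_type_ids = pad_features(ids, 48)
print(len(input_ids), sum(attention_mask))  # 48 9
```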

In [None]:
model = finbert.create_the_model()

### Training

In [None]:
%%time
trained_model = finbert.train(train_examples = train_data, model = model)

05/09/2021 11:59:14 - INFO - finbert.utils -   *** Example ***
05/09/2021 11:59:14 - INFO - finbert.utils -   guid: train-1
05/09/2021 11:59:14 - INFO - finbert.utils -   tokens: [CLS] $ bb strong company with great potential [SEP]
05/09/2021 11:59:14 - INFO - finbert.utils -   input_ids: 101 1002 22861 2844 2194 2007 2307 4022 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 11:59:14 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 11:59:14 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 11:59:14 - INFO - finbert.utils -   label: positive (id = 0)
05/09/2021 11:59:22 - INFO - finbert.finbert -   ***** Loading data *****
05/09/2021 11:59:22 - INFO - finbert.finbert -     Num examples = 51705
05/09/2021 11:59:22 - INFO - finbert.finbert -     B

05/09/2021 12:03:47 - INFO - finbert.utils -   *** Example ***
05/09/2021 12:03:47 - INFO - finbert.utils -   guid: validation-1
05/09/2021 12:03:47 - INFO - finbert.utils -   tokens: [CLS] can someone explain why $ ad ##be had a monster run today ? i didn ’ t see any news about it [SEP]
05/09/2021 12:03:47 - INFO - finbert.utils -   input_ids: 101 2064 2619 4863 2339 1002 4748 4783 2018 1037 6071 2448 2651 1029 1045 2134 1521 1056 2156 2151 2739 2055 2009 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 12:03:47 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 12:03:47 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 12:03:47 - INFO - finbert.utils -   label: positive (id = 0)





05/09/2021 12:03:47 - INFO - finbert.finbert -   ***** Loading data *****
05/09/2021 12:03:47 - INFO - finbert.finbert -     Num examples = 3547
05/09/2021 12:03:47 - INFO - finbert.finbert -     Batch size = 64
05/09/2021 12:03:47 - INFO - finbert.finbert -     Num steps = 275



Validation losses: [0.5659088077289718]
No best model found


Epoch:  20%|██        | 1/5 [04:32<18:08, 272.22s/it]

Validation losses: [0.5659088077289718, 0.44579126353242565]


Epoch:  40%|████      | 2/5 [09:04<13:36, 272.17s/it]

Validation losses: [0.5659088077289718, 0.44579126353242565, 0.44325055793992113]


Epoch:  60%|██████    | 3/5 [13:36<09:04, 272.20s/it]

Epoch:  80%|████████  | 4/5 [18:07<04:31, 271.83s/it]


Validation losses: [0.5659088077289718, 0.44579126353242565, 0.44325055793992113, 0.46883722354790996]


Epoch: 100%|██████████| 5/5 [22:38<00:00, 271.71s/it]


Validation losses: [0.5659088077289718, 0.44579126353242565, 0.44325055793992113, 0.46883722354790996, 0.47820951816226753]





CPU times: user 14min 16s, sys: 8min 26s, total: 22min 43s
Wall time: 22min 47s
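The validation loss is lowest after epoch 3 (about 0.443) and rises in epochs 4 and 5, which suggests the model begins to overfit and that three epochs would likely be enough. Whether `finbert.train` keeps the best checkpoint or the last one is not shown here. Picking the best epoch from the printed losses:

```python
validation_losses = [0.5659088077289718, 0.44579126353242565, 0.44325055793992113,
                     0.46883722354790996, 0.47820951816226753]

# Epochs are numbered from 1 to match the progress bar.
best_epoch = min(range(len(validation_losses)), key=validation_losses.__getitem__) + 1
print(best_epoch, round(validation_losses[best_epoch - 1], 3))  # 3 0.443
```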


## Test the model


`finbert.evaluate` returns a DataFrame with the true label id (`labels`) and the logits (`predictions`) for each example.
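Each entry of the `predictions` column is a vector of three logits ordered like `label_list`, so the predicted class is the index of the largest value, which is what `np.argmax` computes in a later cell. A dependency-free sketch; the logits are made-up numbers, not output from this run.

```python
label_list = ["positive", "negative", "neutral"]

def predicted_label(logits):
    """Return the label whose logit is largest (ties go to the first one)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return label_list[best]

example_logits = [-1.2, 0.4, 2.3]  # illustrative values
print(predicted_label(example_logits))  # neutral
```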

In [None]:
%%time
test_data = finbert.get_data('test')

CPU times: user 228 ms, sys: 4.48 ms, total: 232 ms
Wall time: 232 ms


In [None]:
%%time
results = finbert.evaluate(examples=test_data, model=trained_model)

05/09/2021 12:22:02 - INFO - finbert.utils -   *** Example ***
05/09/2021 12:22:02 - INFO - finbert.utils -   guid: test-1
05/09/2021 12:22:02 - INFO - finbert.utils -   tokens: [CLS] energy oil & amp ; ex ##p was the sector to made today on paper lo ##l . . . $ cp ##g $ cv ##e $ wc ##p $ v ##lo . . . live long and pro ##sper my fellow red ##dit ##ors ! [SEP]
05/09/2021 12:22:02 - INFO - finbert.utils -   input_ids: 101 2943 3514 1004 23713 1025 4654 2361 2001 1996 4753 2000 2081 2651 2006 3259 8840 2140 1012 1012 1012 1002 18133 2290 1002 26226 2063 1002 15868 2361 1002 1058 4135 1012 1012 1012 2444 2146 1998 4013 17668 2026 3507 2417 23194 5668 999 102
05/09/2021 12:22:02 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
05/09/2021 12:22:02 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
05/09/2021 12:22:02


CPU times: user 24.9 s, sys: 12.6 s, total: 37.5 s
Wall time: 37.4 s
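The logged test example is longer than `max_seq_length = 48`, and its tokens show which part was dropped. Compared with the full 62-token title that `get_predictions` logs later (with a limit of 64), the first 11 and the last 35 tokens are kept and the middle (`be in today ! used to work ... what i`) is removed. The sketch below reproduces that split with the rule "keep `max_seq_length // 4 - 1` head tokens and `3 * max_seq_length // 4 - 1` tail tokens"; the rule is inferred from these two log lines, not taken from finBERT's documentation.

```python
def truncate_head_tail(tokens, max_seq_length):
    """Keep the start and the end of an over-long token list, leaving room for [CLS] and [SEP]."""
    if len(tokens) <= max_seq_length - 2:
        return list(tokens)
    head = max_seq_length // 4 - 1
    tail = 3 * max_seq_length // 4 - 1
    return list(tokens[:head]) + list(tokens[len(tokens) - tail:])

full = ("energy oil & amp ; ex ##p was the sector to be in today ! used to work all year "
        "and couldn ' t make what i made today on paper lo ##l . . . $ cp ##g $ cv ##e "
        "$ wc ##p $ v ##lo . . . live long and pro ##sper my fellow red ##dit ##ors !").split()

kept = ["[CLS]"] + truncate_head_tail(full, 48) + ["[SEP]"]
print(len(full), len(kept))  # 62 48
print(" ".join(kept))
```

Keeping both ends preserves the opening of a title as well as its closing remarks, which on r/wallstreetbets often carry the sentiment.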


### Prepare the classification report

In [None]:
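def report(df, cols=['labels', 'prediction', 'predictions']):
    """Print loss, accuracy and a classification report.

    cols: [true label ids, predicted label ids, logits]; the defaults match the
    DataFrame returned by finbert.evaluate plus the 'prediction' column added below.
    """
    cs = CrossEntropyLoss(weight=finbert.class_weights)
    logits = torch.tensor(np.stack(df[cols[2]].values), dtype=torch.float)
    labels = torch.tensor(df[cols[0]].values, dtype=torch.long)
    loss = cs(logits, labels)
    print("Loss:{0:.2f}".format(loss.item()))
    print("Accuracy:{0:.2f}".format((df[cols[0]] == df[cols[1]]).sum() / df.shape[0]))
    print("\nClassification Report:")
    print(classification_report(df[cols[0]], df[cols[1]]))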
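`report` uses `CrossEntropyLoss(weight=finbert.class_weights)`. With class weights and the default mean reduction, PyTorch divides the weighted sum of per-example losses by the sum of the weights of the true classes, not by the number of examples. A stdlib sketch of that formula on made-up logits; the inverse-frequency weights (`N / count`) are an illustrative assumption, not necessarily how finBERT computes `class_weights`.

```python
import math

def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]

def weighted_cross_entropy(batch_logits, labels, weights):
    """Cross-entropy where each example counts with the weight of its true class."""
    total, norm = 0.0, 0.0
    for logits, y in zip(batch_logits, labels):
        total += -weights[y] * log_softmax(logits)[y]
        norm += weights[y]
    return total / norm

# Hypothetical inverse-frequency weights N / count for label counts 6, 3, 1 (N = 10).
counts = [6, 3, 1]
weights = [sum(counts) / c for c in counts]

logits = [[2.0, 0.1, -1.0], [0.3, 1.5, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 1, 2]
print(round(weighted_cross_entropy(logits, labels, weights), 4))
```

With equal weights the result reduces to the ordinary mean cross-entropy.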

In [None]:
results['prediction'] = results.predictions.apply(lambda x: np.argmax(x,axis=0))

In [None]:
#report(results,cols=['labels','prediction','predictions'])

In [None]:
%%time
def predict_from_model(text, model, batch_size=5, device='cuda'):
    """Predict the sentiment of one or more sentences with a fine-tuned model.

    Returns a DataFrame with the sentence, the class probabilities ('logit'),
    the predicted label and sentiment_score = P(positive) - P(negative).
    """
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    sentences = [text] if isinstance(text, str) else list(text)
    print(f">> Using batch size {batch_size}")
    label_list = ['positive', 'negative', 'neutral']
    label_dict = {i: label for i, label in enumerate(label_list)}
    result = pd.DataFrame(columns=['sentence', 'logit', 'prediction', 'sentiment_score'])
    for batch in tqdm(list(chunks(sentences, batch_size))):
        examples = [InputExample(str(i), sentence) for i, sentence in enumerate(batch)]

        features = convert_examples_to_features(examples, label_list, 64, tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)

        with torch.no_grad():
            logits = model(all_input_ids.to(device), all_attention_mask.to(device), all_token_type_ids.to(device))[0]
            # Convert the logits into class probabilities
            probs = softmax(np.array(logits.cpu()))
            sentiment_score = pd.Series(probs[:, 0] - probs[:, 1])
            # argmax over axis 1 keeps a 1-D array even for a batch of one sentence
            predictions = np.argmax(probs, axis=1)

        batch_result = pd.DataFrame({'sentence': batch,
                                     'logit': list(probs),
                                     'prediction': predictions,
                                     'sentiment_score': sentiment_score})
        result = pd.concat([result, batch_result], ignore_index=True)

    result['prediction'] = result.prediction.apply(lambda x: label_dict[x])

    return result

def get_predictions(input_file="", model_dir="", output_file="output.csv", text_column_name="Title",
                    batch_size=32, debug=False, device='cuda'):
    """Run the fine-tuned model over one text column of a CSV and save the results to output_file."""
    def deEmojify(text):
        # Accepts a single string or a list of strings
        if isinstance(text, list):
            return [deEmojify(t) for t in text]
        regex_pattern = re.compile(pattern="["
                                   u"\U0001F600-\U0001F64F"  # emoticons
                                   u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                                   u"\U0001F680-\U0001F6FF"  # transport & map symbols
                                   u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                                   "]+", flags=re.UNICODE)
        return regex_pattern.sub(r'', text).lower()

    model = AutoModelForSequenceClassification.from_pretrained(model_dir, cache_dir=None, num_labels=3)
    model.to(device)

    df = pd.read_csv(input_file)
    if debug:
        df = df.head(100)
    # Missing titles become empty strings so the regex does not fail on NaN
    texts = deEmojify(df[text_column_name].fillna('').astype(str).tolist())
    res = predict_from_model(texts, model, batch_size=batch_size, device=device)
    df = pd.concat([df, res], axis=1)
    df.to_csv(output_file)

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 7.39 µs
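`predict_from_model` turns the logits into probabilities with a softmax (so the `logit` column actually holds probabilities) and defines `sentiment_score` as P(positive) - P(negative). The score therefore lies in [-1, 1] and ignores how much probability goes to neutral. A dependency-free sketch of that calculation for one sentence; the logit values are made up.

```python
import math

LABELS = ["positive", "negative", "neutral"]  # same order as label_dict in predict_from_model

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score_sentence(logits):
    """Return (predicted label, P(positive) - P(negative)) for one vector of logits."""
    probs = softmax(logits)
    prediction = LABELS[probs.index(max(probs))]
    return prediction, probs[0] - probs[1]

prediction, score = score_sentence([1.0, 1.0, 3.0])
print(prediction, round(score, 6))  # neutral 0.0
```

A score of 0 does not imply a neutral prediction: it only means positive and negative are equally likely.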


In [None]:
%%time
torch.cuda.empty_cache()
get_predictions(
    input_file="https://raw.githubusercontent.com/santi-buch/NLP/main/test%20(1).csv", 
    model_dir=str(cl_path),
    output_file="output.csv", 
    text_column_name="Title",
    batch_size=256
    )

>> Using batch size 256


05/09/2021 12:26:15 - INFO - finbert.utils -   *** Example ***
05/09/2021 12:26:15 - INFO - finbert.utils -   guid: 0
05/09/2021 12:26:15 - INFO - finbert.utils -   tokens: [CLS] energy oil & amp ; ex ##p was the sector to be in today ! used to work all year and couldn ' t make what i made today on paper lo ##l . . . $ cp ##g $ cv ##e $ wc ##p $ v ##lo . . . live long and pro ##sper my fellow red ##dit ##ors ! [SEP]
05/09/2021 12:26:15 - INFO - finbert.utils -   input_ids: 101 2943 3514 1004 23713 1025 4654 2361 2001 1996 4753 2000 2022 1999 2651 999 2109 2000 2147 2035 2095 1998 2481 1005 1056 2191 2054 1045 2081 2651 2006 3259 8840 2140 1012 1012 1012 1002 18133 2290 1002 26226 2063 1002 15868 2361 1002 1058 4135 1012 1012 1012 2444 2146 1998 4013 17668 2026 3507 2417 23194 5668 999 102
05/09/2021 12:26:15 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
05/09/202


CPU times: user 33.7 s, sys: 17.9 s, total: 51.6 s
Wall time: 55.7 s
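Predictions are computed in batches: `chunks(sentences, batch_size)` splits the input list into consecutive slices. For the 20,667 test titles (see `test.shape` above), `batch_size=256` gives ceil(20667 / 256) = 81 batches, whereas 64 sentences per batch would need 323. A stdlib sketch of such a chunking helper, named `iter_chunks` here so as not to imply it is finBERT's own `chunks`:

```python
import math

def iter_chunks(items, size):
    """Yield consecutive slices of `items` with at most `size` elements each."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

print([list(c) for c in iter_chunks(list(range(7)), 3)])  # [[0, 1, 2], [3, 4, 5], [6]]
print(math.ceil(20667 / 256), math.ceil(20667 / 64))     # 81 323
```

Larger batches mean fewer forward passes, at the cost of more GPU memory per pass.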


In [None]:
%%time
import shutil
shutil.make_archive('fin-bert-models', 'zip', 'models')

CPU times: user 20.5 s, sys: 483 ms, total: 21 s
Wall time: 20.9 s
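`shutil.make_archive('fin-bert-models', 'zip', 'models')` writes `fin-bert-models.zip` with paths stored relative to `models/`. A sketch that checks what ends up in such an archive, using a throwaway directory instead of the real model folder; `config.json` here is only a placeholder file.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    model_dir = root / "models" / "classifier_model_1" / "finbert-sentiment"
    model_dir.mkdir(parents=True)
    (model_dir / "config.json").write_text("{}")  # placeholder file

    archive = shutil.make_archive(str(root / "fin-bert-models"), "zip", str(root / "models"))
    with zipfile.ZipFile(archive) as zf:
        names = sorted(zf.namelist())

print(Path(archive).name)  # fin-bert-models.zip
print(names)
```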
