To run this notebook, the `sentence-transformers` package must be installed.

In [1]:
#%pip install sentence-transformers

## Import Libraries

In [2]:
'''basics'''
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join('../..', 'src')))
sys.setrecursionlimit(20500)
import vectorize_embed as em
import pandas as pd
#import pickle5 as pickle
import pickle
import numpy as np

'''Plotting'''
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')

'''features'''
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import label_binarize

'''Classifiers'''
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn import svm
from sklearn.naive_bayes import MultinomialNB


'''Metrics/Evaluation'''
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix
from scipy import interp
from itertools import cycle
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn_hierarchical_classification.classifier import HierarchicalClassifier
from sklearn_hierarchical_classification.constants import ROOT
from sklearn_hierarchical_classification.metrics import h_fbeta_score, multi_labeled
from sklearn.pipeline import Pipeline


import warnings
warnings.filterwarnings('ignore')

import joblib



## Import data

In [3]:
# Processed dataset with the cleaned text and the encoded label columns.
DATA_CSV = '../../data/processed/encoded_labels/_types_of_private_sector.csv'
df = pd.read_csv(DATA_CSV)
df.columns

Index(['PIMS_ID', 'all_text_clean', 'all_text_clean_spacy',
       '_types_of_private_sector', 'capital_providers',
       'entrepreneurs_manufacturing_firms_investors',
       'financial_intermediaries_and_market_facilitators',
       'individuals_entrepeneurs', 'individuals_entrepreneurs',
       'iron_and_steel', 'large_corporations',
       'manufacturers_building_professionals', 'no tag', 'non_grant_pilot',
       'poultry_fisheries_dairy_horticulture_khadi_homespun_cloth_and_silk_weaving_bamboo_and_commercial_cooking',
       'retailers_manufacturers', 'small_and_medium_sized_enterprises',
       'sugarcane_mills'],
      dtype='object')

# Compare the performance of different embeddings

In [4]:
# Label columns to evaluate: every column of the CSV except PIMS_ID and the two
# text columns (compare with the df.columns output above). Each label column is
# scored on its own further down.
categories = [
    '_types_of_private_sector',
    'capital_providers',
    'entrepreneurs_manufacturing_firms_investors',
    'financial_intermediaries_and_market_facilitators',
    'individuals_entrepeneurs',
    'individuals_entrepreneurs',
    'iron_and_steel',
    'large_corporations',
    'manufacturers_building_professionals',
    'no tag',
    'non_grant_pilot',
    'poultry_fisheries_dairy_horticulture_khadi_homespun_cloth_and_silk_weaving_bamboo_and_commercial_cooking',
    'retailers_manufacturers',
    'small_and_medium_sized_enterprises',
    'sugarcane_mills',
]

# Select just the label columns (already encoded in the CSV). reindex yields an
# all-NaN column for any name missing from df instead of raising.
y = df.reindex(columns=categories)

# Model input: the cleaned text of each row, cast to str.
X = list(df['all_text_clean'].astype('str'))


## Train an SGD classifier (wrapped in OneVsRest) on each embedding, one label column at a time

In [5]:
# Display name -> embedding model identifier passed to em.get_embeddings.
embedding_dict = {'Glove' : 'average_word_embeddings_glove.6B.300d', 
                  'Distilbert':'distilbert-base-nli-mean-tokens', 
                  'Roberta' : 'roberta-base-nli-stsb-mean-tokens', 
                  'Bert' : 'bert-base-nli-stsb-mean-tokens'}

# Logistic-loss SGD with an L1 penalty, wrapped one-vs-rest so that
# multi-class label columns are also supported.
# NOTE(review): loss='log' was renamed 'log_loss' in scikit-learn 1.1 and
# removed in 1.3; switch the name once the environment is upgraded.
sgd_classifier = SGDClassifier(alpha=1e-06,
                               loss='log',
                               max_iter=1000,
                               penalty='l1',
                               random_state=3,
                               tol=0.001)
model = OneVsRestClassifier(sgd_classifier)

# Random 70/30 train/test split with a fixed seed. Note: this split is NOT
# stratified (no `stratify=` argument is passed).
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=.3,
                                                    shuffle=True,
                                                    random_state=3)

# Drop label columns that contain missing values in the training split, and
# keep the test split and the category list in sync with it. Without this,
# the evaluation loop would raise a KeyError on y_train[category] for any
# dropped column.
y_train = y_train.dropna(axis=1)
dropped_labels = [c for c in y_test.columns if c not in y_train.columns]
if dropped_labels:
    print('Dropped label columns with missing values:', dropped_labels)
y_test = y_test[y_train.columns]
categories = [c for c in categories if c in y_train.columns]

# Memo of computed embeddings, keyed by (embedding model name, tuple of texts).
# model_score_df is called once per label column with the same X_train/X_test,
# so without this every text would be re-encoded by every embedding model for
# every category.
_EMBEDDING_CACHE = {}


def _embed(model_name, texts, use_cache=True):
    """Return ``em.get_embeddings(model_name, texts)``, memoized per (model, texts).

    NOTE(review): caching assumes em.get_embeddings is deterministic for a given
    model name and list of texts (no randomness at inference) -- confirm.
    """
    if not use_cache:
        return em.get_embeddings(model_name, texts)
    key = (model_name, tuple(texts))
    if key not in _EMBEDDING_CACHE:
        _EMBEDDING_CACHE[key] = em.get_embeddings(model_name, texts)
    return _EMBEDDING_CACHE[key]


def model_score_df(embedding_dict, X_train, X_test, y_train, y_test, category,
                   estimator=None, use_cache=True):
    """Fit one classifier per embedding and return their test-set scores.

    For every ``(name, model_name)`` in ``embedding_dict`` the texts are embedded
    with ``em.get_embeddings(model_name, ...)``, the estimator is fitted on the
    training embeddings, saved with joblib to ``'../<category>_<name>model.sav'``,
    and scored on the test embeddings.

    Parameters
    ----------
    embedding_dict : dict
        Display name -> embedding model identifier passed to ``em.get_embeddings``.
    X_train, X_test : list of str
        Raw texts of the train / test split.
    y_train, y_test : array-like
        Labels of a single category column for the train / test split.
    category : str
        Label column name; only used to build the saved-model file name.
    estimator : sklearn classifier, optional
        Classifier to fit. Defaults to the notebook-level ``model``. The same
        object is refitted (mutated) for each embedding.
    use_cache : bool, default True
        Reuse embeddings already computed for the same model name and texts.

    Returns
    -------
    pandas.DataFrame
        Columns ``embedding_name, accuracy_score, precision_score,
        recall_score, f1_score`` (precision/recall/F1 are macro averages), one
        row per embedding, sorted by ``f1_score`` descending. Ties keep the
        ``embedding_dict`` order. Empty if ``embedding_dict`` is empty.
    """
    clf = model if estimator is None else estimator
    records = []
    for name, model_name in embedding_dict.items():
        clf.fit(_embed(model_name, X_train, use_cache), y_train)

        # save the fitted model to disk
        joblib.dump(clf, f'../{category}_{name}model.sav')

        y_pred = clf.predict(_embed(model_name, X_test, use_cache))
        records.append({
            'embedding_name': name,
            'accuracy_score': accuracy_score(y_test, y_pred),
            'precision_score': precision_score(y_test, y_pred, average='macro'),
            'recall_score': recall_score(y_test, y_pred, average='macro'),
            'f1_score': f1_score(y_test, y_pred, average='macro'),
        })

    columns = ['embedding_name', 'accuracy_score', 'precision_score',
               'recall_score', 'f1_score']
    # Build and sort once, after every embedding has been scored; mergesort is
    # stable, so tied F1 scores keep the embedding_dict order.
    return (pd.DataFrame(records, columns=columns)
            .sort_values(by='f1_score', ascending=False, kind='mergesort'))

      
# Score every embedding on every label column and collect one summary row per
# category for the HTML table below. Each cell stacks the per-embedding values
# (ordered as returned by model_score_df, best F1 first) separated by newlines.
CELL_SEP = '    \n '
METRIC_COLUMNS = ('accuracy_score', 'precision_score', 'recall_score', 'f1_score')

lis = []
for category in categories:
    dff = model_score_df(embedding_dict, X_train, X_test,
                         y_train[category], y_test[category], category)
    summary_row = {
        'Category': category,
        'Classifiers': CELL_SEP.join(dff['embedding_name'].apply(str).tolist()),
    }
    # Use the same separator for every metric column so all cells line up.
    for metric in METRIC_COLUMNS:
        summary_row[metric] = CELL_SEP.join(dff[metric].apply(str).tolist())
    lis.append(summary_row)
    


In [7]:
from tabulate import tabulate

# All summary rows share the same keys, so the first row supplies the header;
# each row's values become one line of the HTML table.
first_row = lis[0]
header = first_row.keys()
rows = list(map(dict.values, lis))
html_table = tabulate(rows, header, tablefmt='html')
print(html_table)

<table>
<thead>
<tr><th>Category                                                                                                </th><th>Classifiers  </th><th>accuracy_score  </th><th>precision_score  </th><th>recall_score  </th><th>f1_score  </th></tr>
</thead>
<tbody>
<tr><td>_types_of_private_sector                                                                                </td><td>Bert    
 Glove    
 Roberta    
 Distilbert              </td><td>0.56353591160221    
 0.5469613259668509    
 0.6022099447513812    
 0.6243093922651933                 </td><td>0.07673071861016056   
 0.05947347611659734   
 0.06498550724637682   
 0.08240119313944817                  </td><td>0.09089032164589975    
 0.09874369747899159    
 0.06461344537815125    
 0.05166900093370681               </td><td>0.07863837312113174    
 0.06475394875394874    
 0.06388500290972991    
 0.05220764862555907           </td></tr>
<tr><td>capital_providers                                                  

### Copy-paste the output of the last cell here (into a Markdown cell, so the HTML renders as a table instead of being executed as code)


In [None]:
<table>
<thead>
<tr><th>Category                                                                                                </th><th>Classifiers  </th><th>accuracy_score  </th><th>precision_score  </th><th>recall_score  </th><th>f1_score  </th></tr>
</thead>
<tbody>
<tr><td>_types_of_private_sector                                                                                </td><td>Bert    
 Glove    
 Roberta    
 Distilbert              </td><td>0.56353591160221    
 0.5469613259668509    
 0.6022099447513812    
 0.6243093922651933                 </td><td>0.07673071861016056   
 0.05947347611659734   
 0.06498550724637682   
 0.08240119313944817                  </td><td>0.09089032164589975    
 0.09874369747899159    
 0.06461344537815125    
 0.05166900093370681               </td><td>0.07863837312113174    
 0.06475394875394874    
 0.06388500290972991    
 0.05220764862555907           </td></tr>
<tr><td>capital_providers                                                                                       </td><td>Glove    
 Roberta    
 Bert    
 Distilbert              </td><td>0.8839779005524862    
 0.8729281767955801    
 0.8729281767955801    
 0.7845303867403315                 </td><td>0.7381360777587193   
 0.7122641509433962   
 0.7070552147239264   
 0.6109929078014185                  </td><td>0.7292239955971381    
 0.7043203082003302    
 0.6671711612548157    
 0.6722619702806825               </td><td>0.7335436382754994    
 0.708166841920785    
 0.6836866499506117    
 0.6252587991718426           </td></tr>
<tr><td>entrepreneurs_manufacturing_firms_investors                                                             </td><td>Glove    
 Distilbert    
 Bert    
 Roberta              </td><td>1.0    
 1.0    
 1.0    
 0.988950276243094                 </td><td>1.0   
 1.0   
 1.0   
 0.5                  </td><td>1.0    
 1.0    
 1.0    
 0.494475138121547               </td><td>1.0    
 1.0    
 1.0    
 0.49722222222222223           </td></tr>
<tr><td>financial_intermediaries_and_market_facilitators                                                        </td><td>Distilbert    
 Glove    
 Roberta    
 Bert              </td><td>0.9337016574585635    
 0.9226519337016574    
 0.9502762430939227    
 0.9392265193370166                 </td><td>0.5633333333333334   
 0.5422687861271676   
 0.4777777777777778   
 0.47752808988764045                  </td><td>0.5480491329479769    
 0.5422687861271676    
 0.49710982658959535    
 0.4913294797687861               </td><td>0.5541871921182266    
 0.5422687861271676    
 0.48725212464589235    
 0.4843304843304843           </td></tr>
<tr><td>individuals_entrepeneurs                                                                                </td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>0.994475138121547    
 0.994475138121547    
 0.994475138121547    
 0.994475138121547                 </td><td>0.4972375690607735   
 0.4972375690607735   
 0.4972375690607735   
 0.4972375690607735                  </td><td>0.5    
 0.5    
 0.5    
 0.5               </td><td>0.4986149584487534    
 0.4986149584487534    
 0.4986149584487534    
 0.4986149584487534           </td></tr>
<tr><td>individuals_entrepreneurs                                                                               </td><td>Roberta    
 Distilbert    
 Bert    
 Glove              </td><td>0.8011049723756906    
 0.8011049723756906    
 0.8121546961325967    
 0.8397790055248618                 </td><td>0.6003311258278146   
 0.5905695611577965   
 0.5301136363636363   
 0.5053366174055829                  </td><td>0.6206210191082803    
 0.6029723991507431    
 0.5210987261146497    
 0.5017250530785563               </td><td>0.6082251082251082    
 0.5957816377171217    
 0.5222049689440993    
 0.4884514179904492           </td></tr>
<tr><td>iron_and_steel                                                                                          </td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>0.994475138121547    
 0.994475138121547    
 0.994475138121547    
 0.994475138121547                 </td><td>0.4972375690607735   
 0.4972375690607735   
 0.4972375690607735   
 0.4972375690607735                  </td><td>0.5    
 0.5    
 0.5    
 0.5               </td><td>0.4986149584487534    
 0.4986149584487534    
 0.4986149584487534    
 0.4986149584487534           </td></tr>
<tr><td>large_corporations                                                                                      </td><td>Glove    
 Roberta    
 Distilbert    
 Bert              </td><td>0.8784530386740331    
 0.9171270718232044    
 0.8287292817679558    
 0.9281767955801105                 </td><td>0.5960374243258117   
 0.539378612716763   
 0.4953556263269639   
 0.4745762711864407                  </td><td>0.725452196382429    
 0.5352067183462532    
 0.4886950904392765    
 0.4883720930232558               </td><td>0.6229166666666666    
 0.5370843989769821    
 0.4831905682969513    
 0.48137535816618904           </td></tr>
<tr><td>manufacturers_building_professionals                                                                    </td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>1.0    
 1.0    
 1.0    
 1.0                 </td><td>1.0   
 1.0   
 1.0   
 1.0                  </td><td>1.0    
 1.0    
 1.0    
 1.0               </td><td>1.0    
 1.0    
 1.0    
 1.0           </td></tr>
<tr><td>no tag                                                                                                  </td><td>Glove    
 Bert    
 Roberta    
 Distilbert              </td><td>0.7071823204419889    
 0.6629834254143646    
 0.6574585635359116    
 0.6740331491712708                 </td><td>0.6927777777777777   
 0.6230897471519867   
 0.6122047244094488   
 0.6254355400696865                  </td><td>0.7116427216047709    
 0.6200867443751694    
 0.6042965573326105    
 0.5975874220656004               </td><td>0.6936807484752691    
 0.6213778677000102    
 0.6067423605270537    
 0.5996926191100949           </td></tr>
<tr><td>non_grant_pilot                                                                                         </td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>0.988950276243094    
 0.988950276243094    
 0.988950276243094    
 0.988950276243094                 </td><td>0.494475138121547   
 0.494475138121547   
 0.494475138121547   
 0.494475138121547                  </td><td>0.5    
 0.5    
 0.5    
 0.5               </td><td>0.49722222222222223    
 0.49722222222222223    
 0.49722222222222223    
 0.49722222222222223           </td></tr>
<tr><td>poultry_fisheries_dairy_horticulture_khadi_homespun_cloth_and_silk_weaving_bamboo_and_commercial_cooking</td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>1.0    
 1.0    
 1.0    
 1.0                 </td><td>1.0   
 1.0   
 1.0   
 1.0                  </td><td>1.0    
 1.0    
 1.0    
 1.0               </td><td>1.0    
 1.0    
 1.0    
 1.0           </td></tr>
<tr><td>retailers_manufacturers                                                                                 </td><td>Glove    
 Roberta    
 Bert    
 Distilbert              </td><td>1.0    
 1.0    
 1.0    
 0.994475138121547                 </td><td>1.0   
 1.0   
 1.0   
 0.5                  </td><td>1.0    
 1.0    
 1.0    
 0.4972375690607735               </td><td>1.0    
 1.0    
 1.0    
 0.4986149584487534           </td></tr>
<tr><td>small_and_medium_sized_enterprises                                                                      </td><td>Bert    
 Glove    
 Distilbert    
 Roberta              </td><td>0.7458563535911602    
 0.8342541436464088    
 0.7624309392265194    
 0.7900552486187845                 </td><td>0.5769955464200068   
 0.6015779092702169   
 0.5376588021778584   
 0.5316770186335403                  </td><td>0.6115384615384616    
 0.5511166253101737    
 0.5411910669975186    
 0.5253101736972705               </td><td>0.5835334133653461    
 0.5589668615984404    
 0.5390583358010068    
 0.5268299394606495           </td></tr>
<tr><td>sugarcane_mills                                                                                         </td><td>Glove    
 Distilbert    
 Roberta    
 Bert              </td><td>1.0    
 1.0    
 1.0    
 0.988950276243094                 </td><td>1.0   
 1.0   
 1.0   
 0.5                  </td><td>1.0    
 1.0    
 1.0    
 0.494475138121547               </td><td>1.0    
 1.0    
 1.0    
 0.49722222222222223           </td></tr>
</tbody>
</table>