In [1]:
import numpy as np 
import pandas as pd 
import sys 
from tqdm import tqdm 
import os 
from random import sample 
from annoy import AnnoyIndex 
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay 
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score 
from sklearn import metrics 

In [2]:
# Input CSV locations, all relative to the notebook's working directory.
# NOTE(review): train_file is not referenced by any visible cell — confirm it is still needed.
train_file = './Data/train.csv'
train_enriched_file = './Data/train_enriched.csv'
test_file = './Data/test.csv'
test_enriched = './Data/test_enriched.csv'
# NOTE(review): label_file is not referenced by any visible cell; the ground-truth
# labels are read from a hard-coded './Data/submission.csv' instead — confirm which
# file is intended.
label_file = './Data/sample_submission.csv'
# Length of every embedding vector. The Annoy index is created with this dimension
# and each vector is checked against it before being added.
embeddings_dim = 384

In [3]:
# Enriched training set; the preview further down shows it carries 'id', 'keyword',
# 'text' and 'target' among other feature columns.
df = pd.read_csv(train_enriched_file, encoding='utf-8')
# NOTE(review): df_test is not referenced by any visible cell — only the enriched
# test frame is used downstream.
df_test = pd.read_csv(test_file,encoding='utf-8')
df_test_enr = pd.read_csv(test_enriched,encoding='utf-8')
# Labels that the metrics section treats as ground truth ('target' column).
# NOTE(review): the path is hard-coded here rather than taken from the config cell.
df_true = pd.read_csv('./Data/submission.csv', encoding='utf-8')

In [4]:
# Sentence encoder used by encode() to turn tweet text into embedding vectors.
# NOTE(review): this import lives outside the notebook's import cell; moving it
# there would make the dependency visible at the top.
# NOTE(review): embeddings_dim in the config cell is presumably chosen to match
# this model's output size — confirm if the model name is changed.
from sentence_transformers import SentenceTransformer 
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def encode(text):
    """Embed `text` with the notebook-level `model` and return it as a plain list.

    The encoder is asked for a tensor (`convert_to_tensor=True`), which is then
    converted with `.tolist()` so the result can be stored in a DataFrame cell
    or handed to the Annoy index.
    """
    return model.encode(text, convert_to_tensor=True).tolist()

In [6]:
def label_tweet(testdf, targetdf, index=None):
    """Label each test tweet with the target of its nearest indexed neighbour.

    Parameters
    ----------
    testdf : pandas.DataFrame
        Must provide the columns 'id', 'text' and 'emb' (one embedding vector
        per row).
    targetdf : pandas.DataFrame
        Must provide the columns 'id' and 'target'. The item numbers returned
        by the index are matched against 'id'.
    index : object, optional
        Anything exposing ``get_nns_by_vector(vector, n=..., search_k=...,
        include_distances=...)``. When omitted, the notebook-level
        ``annoy_index`` is used, which keeps existing two-argument calls working.

    Returns
    -------
    pandas.DataFrame
        One row per test tweet, in the order of `testdf`, with the columns
        'test_id', 'text', 'target_id', 'id' and 'target'. Because the merge is
        a left join, 'id' and 'target' are NaN when the neighbour's item number
        has no matching 'id' in `targetdf`.
    """
    if index is None:
        # Backward-compatible default: fall back to the global built earlier
        # in the notebook.
        index = annoy_index

    tweets = testdf['emb'].to_list()
    text = testdf['text'].to_list()
    test_ids = testdf['id'].to_list()

    # Collect every row first and build the frame once; concatenating inside
    # the loop copies the accumulated frame on each iteration.
    records = [
        (
            t_id,
            tx,
            index.get_nns_by_vector(tweet, n=1, search_k=-1, include_distances=False)[0],
        )
        for tweet, tx, t_id in zip(tweets, text, test_ids)
    ]
    df_pred = pd.DataFrame(records, columns=['test_id', 'text', 'target_id'])

    # Attach the label of the matched training tweet.
    return df_pred.merge(
        targetdf[['id', 'target']], right_on='id', left_on='target_id', how='left'
    )

In [7]:
# Tweet text followed by its keyword; the 'no_keyword' placeholder is blanked out
# so it contributes only the separating space.
# NOTE(review): 'tx_key' is not read by any visible cell — the embeddings below
# are computed from 'text'. Confirm which column was meant to be encoded.
df['tx_key'] = df['text'] + ' ' + df['keyword'].mask(df['keyword'] == 'no_keyword', '')
df.head(2)

Unnamed: 0,id,keyword,text,target,word_count,unique_words_count,Tweet_len,special_chars_count,hash_count,@_count,URL_count,sentiment,subjectivity,dis%,text_clean,keyword_clean,newtext,tx_key
0,1,no_keyword,Our Deeds are the Reason of this #earthquake M...,1,13,13,69,1,1,0,0,0.0,0.0,no_keyword,deed reason #earthquak may allah forgiv,nokeyword,deed reason #earthquak may allah forgiv,Our Deeds are the Reason of this #earthquake M...
1,4,no_keyword,Forest fire near La Ronge Sask. Canada,1,7,7,38,1,0,0,0,0.1,0.4,no_keyword,forest fire near la rong sask. canada,nokeyword,forest fire near la rong sask. canada,Forest fire near La Ronge Sask. Canada


In [None]:
# Embed every training tweet, one encode() call per row; the result is stored as
# a list of floats in the 'emb' column.
df['emb'] = df['text'].apply(encode)

In [None]:
#df.to_csv('./Data/train_emb.csv', encoding='utf-8')

In [None]:
# Same text-plus-keyword column as for the training frame; the 'no_keyword'
# placeholder is blanked out so it contributes only the separating space.
# NOTE(review): 'tx_key' is not read by any visible cell — the embeddings below
# are computed from 'text'.
df_test_enr['tx_key'] = (
    df_test_enr['text']
    + ' '
    + df_test_enr['keyword'].mask(df_test_enr['keyword'] == 'no_keyword', '')
)
df_test_enr.head(2)

In [None]:
# Embed every test tweet, one encode() call per row; the result is stored as a
# list of floats in the 'emb' column.
df_test_enr['emb'] = df_test_enr['text'].apply(encode)

## Build the Annoy index

In [None]:
# Parallel lists taken from the training frame: embeddings[i] is the vector for
# the tweet whose id is tweets_ids[i].
embeddings = df['emb'].to_list()
tweets_ids = df['id'].to_list()

In [None]:
# Empty index for vectors of length embeddings_dim, using Annoy's 'angular' metric.
annoy_index = AnnoyIndex(embeddings_dim, 'angular')

In [None]:
# Register each training embedding in the index under its tweet id.
# NOTE(review): tweet ids are used directly as Annoy item numbers. The Annoy
# documentation states that add_item allocates memory for max(i)+1 items, so
# sparse ids leave unassigned slots — confirm these can never come back as a
# nearest neighbour (label_tweet would then yield a NaN target).
for tweet_id, embedding in zip(tweets_ids, embeddings):
    if len(embedding) == embeddings_dim:
        annoy_index.add_item(tweet_id, embedding)
    else:
        # Vectors of the wrong length are reported and left out of the index.
        print('wrong dim lenght')

In [None]:
# Build the index with 100 trees and persist it so it can be reloaded without
# re-encoding the training set.
# NOTE(review): no seed is set before build(), so the forest may differ between
# runs — confirm whether reproducible neighbours are required.
annoy_index.build(100)  # n_jobs is left at the library default
annoy_index.save('./Data/annoy_index.ann')

## Load a previously built Annoy index

In [None]:
# annoy_index = AnnoyIndex(embeddings_dim, 'angular')
# annoy_index.load('./Data/annoy_index.ann')

## Search for the most similar training tweet

In [None]:
# Give every enriched test tweet the target of its nearest training tweet.
# Requires annoy_index to have been built (or loaded) by an earlier cell.
df_pred = label_tweet(df_test_enr,df)

In [None]:
# Preview the predictions: test tweet, matched training id and the copied target.
df_pred.head()

## <font color='darkgreen'>Metrics</font>

In [None]:
# Predicted label per test tweet, in the row order produced by label_tweet.
# NOTE(review): label_tweet uses a left join, so 'target' is NaN wherever the
# matched id was missing from the training frame — confirm there are none
# before scoring.
y_pred = df_pred['target']
# Reference labels.
# NOTE(review): y and y_pred are compared position by position; this assumes
# df_true lists the tweets in the same order as df_test_enr — confirm, or align
# the two frames on the tweet id first.
y = df_true['target']

In [None]:
# Overall accuracy plus F1, precision and recall, each computed with
# average='macro' (per-class scores averaged without class weighting).
accuracy = accuracy_score(y,y_pred)
f1 = f1_score(y,y_pred,average='macro')
precision = precision_score(y,y_pred,average='macro')
recall = recall_score(y,y_pred,average='macro')

# Labels are padded by hand so the values line up in one column.
print(f'Accuracy  {accuracy}')
print(f'F1        {f1}')
print(f'Precision {precision}')
print(f'Recall    {recall}')

In [None]:
# Per-class breakdown; the report is a preformatted string, so it is printed
# rather than displayed.
print(metrics.classification_report(y, y_pred))

In [None]:
# Confusion matrix as raw counts (true labels first, predictions second).
cm = confusion_matrix(y, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()

In [None]:
# Same matrix with normalize='all'; this overwrites the `cm` and `disp` produced
# by the previous cell.
cm = confusion_matrix(y, y_pred, normalize='all')
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()