In [1]:
import os
import prettytable
import numpy as np
import pandas as pd
from tqdm.notebook  import tqdm
from functions      import readSet, dirs
from gensim.models  import TfidfModel
from gensim.corpora import Dictionary
from prettytable    import PrettyTable
from wordEmbedders  import Word2Vec, WESCScore, AverageClassifier
#from gensim.models import Word2Vec
# Register tqdm's pandas integration (adds `progress_apply` & co. to pandas
# objects) so pandas operations can report progress with a notebook bar.
tqdm.pandas()

In [2]:
# --- Run configuration -------------------------------------------------------

# Embedding backend and the classifier built on top of it. `embedder.name` is
# used by the evaluation loop to build the model / prediction file names.
embedder   = Word2Vec
classifier = AverageClassifier

# Datasets to evaluate, as returned by dirs('./data').
datasets = dirs('./data')
#datasets = ['AirlineTweets']   # alternative: restrict the run to one dataset

# Cache policy: with False, an existing prediction file is reused instead of
# re-running the classifier.
ignoreCache = False
#ignoreCache = True             # alternative: always re-predict

# When true, the evaluation loop reads positive/negative word lists from each
# dataset's own directory instead of the shared lists loaded below.
datasetSpecificClusters = True

# Shared (dataset-independent) sentiment word lists.
negativeWords = readSet('./wordlists/negativeWords.txt')
positiveWords = readSet('./wordlists/positiveWords.txt')

In [3]:
# Evaluate every dataset: reuse cached predictions when they exist (and the
# cache is not being ignored), otherwise run the classifier and cache its
# output. `out` collects one (dataset name, result) pair per dataset.
out = []
for dataset in tqdm(datasets, desc="Datasets"):
    dataFile   = f'./data/{dataset}/Data-Cleaned.csv'
    outputFile = f'./data/{dataset}/{embedder.name}-Prediction.csv'
    modelFile  = f'./models/{dataset}/{embedder.name}.model'
    tfidfFile  = f'./models/{dataset}/TF-IDF.model'
    dictFile   = f'./models/{dataset}/Dictionary.model'

    # Validate prerequisites before loading anything, so a missing file is
    # reported with a clear message instead of a loader error.
    if not os.path.exists(dataFile):
        raise ValueError(f'Dataset {dataset} has not been cleaned')
    if not os.path.exists(modelFile):
        raise ValueError(f'Dataset {dataset} has no {embedder.name} trained')

    if os.path.exists(outputFile) and not ignoreCache:
        print(f'{dataset}: using cached data')
        result = WESCScore.load(outputFile)
    else:
        print(f'{dataset}: predicting')

        # Pick the word lists for this dataset. Loop-local names are used on
        # purpose: rebinding the global positiveWords/negativeWords here would
        # leak the last dataset's lists into later runs of this cell.
        if datasetSpecificClusters:
            posWords = readSet(f'./data/{dataset}/positiveWords.txt')
            negWords = readSet(f'./data/{dataset}/negativeWords.txt')
        else:
            posWords, negWords = positiveWords, negativeWords

        # The TF-IDF model and dictionary are only needed by the classifier,
        # so they are loaded only when a prediction is actually run.
        tfidf = TfidfModel.load(tfidfFile)
        dct   = Dictionary.load(dictFile)

        df     = pd.read_csv(dataFile)
        model  = embedder.load(modelFile)
        # Alternative kept from the original for gensim's own Word2Vec:
        #model  = Word2Vec.load(modelFile).wv
        clas   = classifier(model, posWords, negWords, tfidf, dct)
        result = clas.predict(df)
        result.save(outputFile)

    out.append((dataset, result))

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

AirlineTweets2: using cached data
IMDB: using cached data
Sentiment140: using cached data


In [4]:
# Render one table row per evaluated dataset: its balanced accuracy and its
# confusion matrix, with horizontal rules between all rows.
print("Baseline dataset evaluation")
table = PrettyTable(['Dataset', 'Balanced Accuracy', 'Confusion Matrix'])
table.hrules = prettytable.ALL
for datasetName, score in out:
    row = [datasetName, score.balancedAccuracy, score.confusionMatrix]
    table.add_row(row)
print(table)

Baseline dataset evaluation
+----------------+--------------------+------------------+
|    Dataset     | Balanced Accuracy  | Confusion Matrix |
+----------------+--------------------+------------------+
| AirlineTweets2 | 0.6258187116462933 |   1165 | 3819    |
|                |                    |   -----+-----    |
|                |                    |    350 | 3563    |
+----------------+--------------------+------------------+
|      IMDB      |      0.62762       |   7206 | 825     |
|                |                    |  ------+------   |
|                |                    |  17794 | 24175   |
+----------------+--------------------+------------------+
|  Sentiment140  |    0.597770625     | 677373 | 520940  |
|                |                    | -------+-------  |
|                |                    | 122627 | 279060  |
+----------------+--------------------+------------------+
