In [30]:
import os
import pandas as pd
import json
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from dynasent_models import DynaSentModel

In [5]:
# The `dynasent_model0` weights file has to be downloaded from the Google Drive
# link above and saved inside the `models` directory; `dynasent_model1` is
# loaded the same way.
model = DynaSentModel(
    os.path.join('models', 'dynasent_model0.bin')
)


In [20]:
# Evaluate zero-shot examples: read the JSON file and wrap it in a DataFrame.
# Later cells read the 'sentence' and 'label' columns from this frame.

# JSON text is UTF-8, so the encoding is passed explicitly instead of relying
# on the platform's locale default (which is not UTF-8 on every system).
with open('zero-shot-examples.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

df_zeroshot = pd.DataFrame(data)

In [22]:
# Run the model over the zero-shot sentences.
predictions = model.predict(df_zeroshot['sentence'])

# Macro-averaged F1 over the zero-shot set. scikit-learn metrics take
# (y_true, y_pred) in that order, so the gold labels go first; keywords make
# the roles explicit and keep the call correct if `average` is later changed
# to a mode that depends on the true-label support (e.g. 'weighted').
f1 = f1_score(y_true=df_zeroshot['label'], y_pred=predictions, average='macro')

In [31]:
# Per-class precision/recall/F1 for the zero-shot set.
# Observation: the model seems to be confusing neutral and negative.
report = classification_report(
    y_true=df_zeroshot['label'],
    y_pred=predictions,
)
print(report)

              precision    recall  f1-score   support

    negative       0.77      1.00      0.87        37
     neutral       1.00      0.67      0.80        36
    positive       0.96      1.00      0.98        27

    accuracy                           0.88       100
   macro avg       0.91      0.89      0.88       100
weighted avg       0.91      0.88      0.88       100



In [32]:
# Evaluate few-shot examples: read the JSON file and wrap it in a DataFrame.
# Later cells read the 'sentence' and 'label' columns from this frame.

# JSON text is UTF-8, so the encoding is passed explicitly instead of relying
# on the platform's locale default (which is not UTF-8 on every system).
with open('few-shot-examples.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

df_fewshot = pd.DataFrame(data)


In [33]:
# Run the model over the few-shot sentences.
predictions_fewshot = model.predict(df_fewshot['sentence'])

# Macro-averaged F1 over the few-shot set. scikit-learn metrics take
# (y_true, y_pred) in that order, so the gold labels go first; keywords make
# the roles explicit and keep the call correct if `average` is later changed
# to a mode that depends on the true-label support (e.g. 'weighted').
f1 = f1_score(y_true=df_fewshot['label'], y_pred=predictions_fewshot, average='macro')

In [34]:
# Per-class precision/recall/F1 for the few-shot set.
# Observation: the model seems to be under-predicting neutral.
report = classification_report(
    y_true=df_fewshot['label'],
    y_pred=predictions_fewshot,
)
print(report)

              precision    recall  f1-score   support

    negative       0.85      0.95      0.90        37
     neutral       0.91      0.65      0.75        31
    positive       0.84      0.97      0.90        32

    accuracy                           0.86       100
   macro avg       0.87      0.85      0.85       100
weighted avg       0.87      0.86      0.85       100

