In [11]:
import torch
import torch.nn as nn

class SentimentClassifier(nn.Module):
    """Feed-forward classifier: three Linear -> BatchNorm1d -> ReLU stages, then Linear -> Sigmoid.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim1: Width of the first hidden stage.
        hidden_dim2: Width of the second hidden stage.
        hidden_dim3: Width of the third hidden stage.
        output_dim: Number of sigmoid outputs per sample.
    """

    def __init__(self, input_dim=384, hidden_dim1=256, hidden_dim2=128, hidden_dim3=64,  output_dim=1):
        super().__init__()

        stage_widths = (input_dim, hidden_dim1, hidden_dim2, hidden_dim3)
        layers = []
        for fan_in, fan_out in zip(stage_widths[:-1], stage_widths[1:]):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ])
        layers.extend([nn.Linear(hidden_dim3, output_dim), nn.Sigmoid()])

        # The attribute name `model` and the layer order determine the state_dict
        # keys (model.0.weight, ..., model.9.bias), so both are kept fixed.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the sequential stack and return the sigmoid outputs."""
        return self.model(x)

### Load Model Weights

In [12]:
# Load the saved parameters onto the CPU so the checkpoint opens regardless of
# the device it was saved from. `weights_only=True` restricts unpickling to
# tensors and plain containers, so a tampered file cannot run arbitrary code.
state_dict = torch.load(
    "sentiment_classifier.bin",
    map_location="cpu",
    weights_only=True,
)

In [13]:
import pandas as pd 
# Read the test set from a Feather file; the evaluation cell reads its
# `embeddings` and `label` columns.
test_df = pd.read_feather("test_df.bin")

In [15]:
# Build the classifier with its default layer sizes; the weights are random
# until the saved parameters are copied in below.
model = SentimentClassifier()
# Left as the cell's last expression so the notebook displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

## Evaluation on test data

In [18]:
import numpy as np  # used below; no visible cell imports it
from torchmetrics import Accuracy
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, roc_auc_score

EVAL_BATCH_SIZE = 256      # rows per forward pass
DECISION_THRESHOLD = 0.5   # probability at or above which a row is predicted positive

assert len(test_df) > 0, "test_df is empty"

# Inference mode: BatchNorm layers use their running statistics, so results
# do not depend on how rows are grouped into batches.
model.eval()
accuracy_metric_test = Accuracy(task='binary')

# Stack the per-row embedding vectors into one (rows, features) float tensor
# instead of converting and forwarding one row at a time.
test_embeddings = torch.as_tensor(
    np.stack(test_df["embeddings"].to_numpy()), dtype=torch.float32
)
y_true_test = test_df["label"].tolist()
test_labels = torch.tensor(y_true_test, dtype=torch.long)

y_pred_test = []
y_pred_prob_test = []

with torch.no_grad():
    batch_starts = range(0, len(test_embeddings), EVAL_BATCH_SIZE)
    for start in tqdm(batch_starts, desc='Evaluating on Test Data'):
        batch = slice(start, start + EVAL_BATCH_SIZE)
        probabilities = model(test_embeddings[batch]).reshape(-1)
        hard_predictions = (probabilities >= DECISION_THRESHOLD).long()

        y_pred_prob_test.extend(probabilities.tolist())
        y_pred_test.extend(hard_predictions.tolist())

        # Feed the already-thresholded labels to the metric so accuracy and F1
        # are computed from exactly the same predictions.
        accuracy_metric_test.update(hard_predictions, test_labels[batch])

accuracy_value_test = accuracy_metric_test.compute().item()
f1_value_test = f1_score(y_true_test, y_pred_test)
roc_auc_value_test = roc_auc_score(y_true_test, y_pred_prob_test)

print(f'Test Accuracy: {accuracy_value_test}')
print(f'Test F1 Score: {f1_value_test}')
print(f'Test ROC AUC Score: {roc_auc_value_test}')

Evaluating on Test Data:   0%|          | 0/818 [00:00<?, ?it/s]

Test Accuracy: 0.8973104953765869
Test F1 Score: 0.9334389857369256
Test ROC AUC Score: 0.9530156277307137
