# plot results

In [6]:
# prediction library imports
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt

# Get true labels and predictions
true_labels = df['yo2']
predicted_labels = df['predicted_yo2']

# Pin the class order explicitly: the sorted label ids from the model config and
# their matching names. Without `labels=`, sklearn uses only the classes that occur
# in this data (sorted), while `id2label.values()` follows dict insertion order and
# always has one entry per configured class, so the names could be misaligned with
# the rows or differ in count (which raises ValueError in the plot / report).
# NOTE(review): assumes 'yo2' / 'predicted_yo2' hold the integer ids used as keys
# of `model.config.id2label` — TODO confirm.
label_ids = sorted(model.config.id2label)
label_names = [model.config.id2label[i] for i in label_ids]

# A confusion matrix provides a summary of prediction results on a classification problem. It shows the number of correct and incorrect predictions broken down by each class.
# Compute confusion matrix (rows = true class, columns = predicted class, in `label_ids` order)
cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)

# Display confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.show()

# A classification report provides precision, recall, F1-score, and support for each class.
report = classification_report(
    true_labels,
    predicted_labels,
    labels=label_ids,
    target_names=label_names,
)
print("Classification Report:\n", report)


# The ROC curve plots the true positive rate against the false positive rate at various threshold settings. The AUC score summarizes the ROC curve.
# Assuming binary classification and that you have access to the probabilities
# Get model predictions probabilities
def predict_proba(texts, batch_size=32):
    """Return the model's positive-class probability for each input text.

    Texts are tokenized and scored in chunks of ``batch_size``. This way, scoring
    a whole DataFrame column does not put every 512-token padded input (and
    its activations) on ``device`` in a single forward pass.

    Args:
        texts: A single string, or an iterable of strings.
        batch_size: Maximum number of texts per forward pass; must be >= 1.

    Returns:
        A 1-D NumPy array with one value per input text. Each value is the softmax
        probability at logit index 1, which this notebook treats as the
        positive class. NOTE(review): assumes a binary classifier — TODO confirm.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                padding='max_length',
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            # Copy each batch's results to the CPU immediately so device tensors
            # from this batch are not kept alive by the result list.
            batch_probs.append(probs[:, 1].cpu())

    if not batch_probs:
        # Empty input: nothing to score, so skip the tokenizer/model entirely.
        return torch.empty(0).numpy()
    return torch.cat(batch_probs).numpy()  # Probability of the positive class

# Score every row and store the positive-class probability next to the labels.
df['predicted_proba'] = predict_proba(df['full_text'].tolist())
positive_scores = df['predicted_proba']

# ROC analysis: sweep the decision threshold across the scores, then summarise
# the resulting curve by the area underneath it.
fpr, tpr, thresholds = roc_curve(true_labels, positive_scores)
roc_auc = auc(fpr, tpr)

roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (area = {roc_auc:0.2f})')
# Dashed diagonal from (0, 0) to (1, 1): the chance-level reference line.
roc_ax.plot([0,1], [0,1], color='navy', lw=2, linestyle='--')
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic')
roc_ax.legend(loc='lower right')
plt.show()


# Precision-recall analysis: more informative than ROC when the positive class is rare.
# Note: this rebinds `thresholds` to the precision-recall thresholds.
precision, recall, thresholds = precision_recall_curve(true_labels, positive_scores)
average_precision = average_precision_score(true_labels, positive_scores)

pr_fig, pr_ax = plt.subplots()
pr_ax.plot(recall, precision, color='b', lw=2, label=f'AP = {average_precision:0.2f}')
pr_ax.set_xlabel('Recall')
pr_ax.set_ylabel('Precision')
pr_ax.set_title('Precision-Recall Curve')
pr_ax.legend(loc='lower left')
plt.show()
