In [3]:
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import torch.nn.functional as F

# Load the model and tokenizer
saved_model_path = "./climate_fever_model"  # Path where the model was saved
model = AutoModelForSequenceClassification.from_pretrained(saved_model_path)
tokenizer = AutoTokenizer.from_pretrained(saved_model_path)

# Label mapping
label_map = {
    "SUPPORTS": "SUPPORTS",
    "REFUTES": "REFUTES",
    "NOT_ENOUGH_INFO": "UNDECIDED",
    "DISPUTED": "UNDECIDED"
}
label_mapping = ['SUPPORTS', 'REFUTES', 'UNDECIDED']  # Model's output labels
label_to_index = {label: idx for idx, label in enumerate(label_mapping)}  # Map labels to indices

# Function to predict the label for a claim
def predict_label_with_probs(claim):
    # Tokenize the claim
    features = tokenizer(
        [claim], 
        padding='max_length', 
        truncation=True, 
        return_tensors="pt", 
        max_length=512
    )
    
    # Get model predictions
    model.eval()
    with torch.no_grad():
        scores = model(**features).logits
        probs = F.softmax(scores, dim=-1)  # Convert logits to probabilities
        predicted_label_idx = probs.argmax(dim=1).item()
        predicted_label = label_mapping[predicted_label_idx]
        probabilities = probs[0].tolist()  # Convert tensor to list for easy printing
    
    print(f"Claim: {claim}")
    print(f"Probabilities: {probabilities}")
    print(f"Predicted Label: {predicted_label}")
    
    return predicted_label

# Load the dataset
dataset_path = "dataset\\fr_climate-fever-dataset-r1_period_maj_opus-mt-tc-big-en-fr_v2-unicode.jsonl"
true_labels = []
predicted_labels = []

with open(dataset_path, "r", encoding="utf-8") as file:
    for line in file:
        data = json.loads(line)
        claim = data["claim"]
        claim_label = data["claim_label"]
        
        # Skip DISPUTED claims
        if claim_label == 'DISPUTED':
            continue
        
        # Map the claim_label to the model's label space
        mapped_claim_label = label_map.get(claim_label, "neutral")  # Default to 'neutral' if label not found
        
        # Predict the label for the claim
        predicted_label = predict_label_with_probs(claim)
        
        print(f"predicted_label: {predicted_label}")
        # Store true and predicted labels
        true_labels.append(label_to_index[mapped_claim_label])
        predicted_labels.append(label_to_index[predicted_label])

# Calculate metrics
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')  # Weighted average for multiclass
recall = recall_score(true_labels, predicted_labels, average='weighted')  # Weighted average for multiclass
f1 = f1_score(true_labels, predicted_labels, average='weighted')  # Weighted average for multiclass

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision (weighted): {precision:.4f}")
print(f"Recall (weighted): {recall:.4f}")
print(f"F1 Score (weighted): {f1:.4f}")

predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: UNDECIDED
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: UNDECIDED
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: UNDECIDED
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: REFUTES
predicted_label: UNDECIDED
predicted_label: REFUTES
predicted_label: 

KeyboardInterrupt: 

In [2]:
print(model.config.num_labels)  # Should print 3 if you're predicting 3 classes


3
