In [5]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Release unoccupied memory held by PyTorch's CUDA caching allocator before
# (re)loading the model, e.g. when this cell is re-run.
torch.cuda.empty_cache()

# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned MentalBERT tokenizer and classifier from the local
# directory, then move the model onto the selected device.
model_path = "fine_tuned_mentalbert_cause_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

# Class index -> human-readable cause label (indices 0..5, in list order).
# NOTE(review): this order must match the label encoding used during
# fine-tuning; consider cross-checking against model.config.id2label.
id2label = dict(enumerate([
    "Health Issues",
    "Relationship Issues",
    "Financial Stress",
    "Workplace Stress",
    "Social Isolation",
    "No Issues",
]))

def predict_cause(statement) -> str:
    """Classify a single statement into one of the cause labels in ``id2label``.

    The text is tokenized with truncation enabled and passed to the
    module-level ``model`` on ``device``, with gradient tracking disabled. The
    index of the highest logit is then mapped to its label.

    Args:
        statement: The text to classify (one statement; ``.item()`` below
            requires a single prediction).

    Returns:
        The label string from ``id2label`` for the arg-max class.
    """
    encoded = tokenizer(statement, return_tensors="pt", truncation=True, padding=True)
    # The model was moved to `device`, so every input tensor must live there too.
    batch = dict((name, tensor.to(device)) for name, tensor in encoded.items())

    with torch.no_grad():
        logits = model(**batch).logits

    best_index = logits.argmax(dim=1).item()
    return id2label[best_index]

# Demo: classify a few hand-written statements and print each prediction.
def _show_prediction(statement):
    """Predict the cause of *statement*, print it, and return the label."""
    cause = predict_cause(statement)
    print(f"\nFor the statement:\n\"{statement}\"\nThe predicted cause is: {cause}")
    return cause


example_statement1 = "Strange. I don't have work today and I have a bit of free time so I can read shrill novels but... it feels weird. I've been nervous about checking the google calendar just in case if I read it wrong. But it's still weird, how about this, like you should be looking for a job to get rid of this feeling."
predicted_cause1 = _show_prediction(example_statement1)

example_statement2 = "I dont like going out these days. No regrets or grudges/angry at things that have passed, and not worrying too much about the future, that's true serenity."
predicted_cause2 = _show_prediction(example_statement2)

example_statement3 = "Some days I'm very restless when I want to sleep. I often black out and faint"
predicted_cause3 = _show_prediction(example_statement3)


For the statement:
"Strange. I don't have work today and I have a bit of free time so I can read shrill novels but... it feels weird. I've been nervous about checking the google calendar just in case if I read it wrong. But it's still weird, how about this, like you should be looking for a job to get rid of this feeling."
The predicted cause is: Workplace Stress

For the statement:
"I dont like going out these days. No regrets or grudges/angry at things that have passed, and not worrying too much about the future, that's true serenity."
The predicted cause is: Social Isolation

For the statement:
"Some days I'm very restless when I want to sleep. I often black out and faint"
The predicted cause is: Health Issues
