In [None]:
import torch
from transformers import AutoTokenizer
import sys, os
sys.path.append(os.path.abspath(".."))
from src.model import RoBERTaMultiLabelClassifier
from src.data_preprocessing import aggressive_clean_text
from src.data_preprocessing import get_mlb_labels
import pandas as pd
import swifter

# Pick the compute device (GPU when available) and load the pretrained
# "roberta-base" tokenizer. The fine-tuned classifier weights are loaded
# later, once the label set is known.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Load labels and preprocess data
def preprocess_dataframe_swifter(df):
    """Return a copy of ``df`` with ``clean_text`` and ``labels`` columns added.

    ``clean_text`` is ``aggressive_clean_text`` applied to ``df['text']``
    (parallelised via the ``swifter`` accessor). ``labels`` is the
    comma-separated ``df['subreddit']`` string split into a list of stripped,
    non-empty label names.

    Args:
        df: DataFrame with at least ``text`` and ``subreddit`` columns.
            It is not modified.

    Returns:
        A new DataFrame with the two extra columns.
    """
    def _split_labels(raw):
        # Missing subreddit values are read from CSV as NaN (a float), and
        # calling .split on them would raise. Treat them as "no labels".
        if not isinstance(raw, str):
            return []
        # Drop empty pieces produced by trailing or doubled commas so '' never
        # becomes a label.
        return [lab.strip() for lab in raw.split(',') if lab.strip()]

    # Work on a copy so the caller's frame is left untouched (idempotent re-runs).
    out = df.copy()
    out['clean_text'] = out['text'].swifter.apply(aggressive_clean_text)
    out['labels'] = out['subreddit'].apply(_split_labels)
    return out

# Read the labelled posts and derive the `clean_text` / `labels` columns.
# NOTE(review): in this notebook `data` is only passed to get_mlb_labels below,
# so cleaning every post's text may be unnecessary work — confirm whether
# get_mlb_labels actually reads `clean_text`.
data = pd.read_csv("../data/cleaned_paper.csv").pipe(preprocess_dataframe_swifter)

# Target conditions handed to get_mlb_labels. Only the first element of its
# returned pair is kept; presumably a fitted MultiLabelBinarizer whose
# `classes_` defines the model's output columns.
disorders = [
    "depression",
    "anxiety",
    "OCD",
    "PTSD",
    "autism",
    "eatingdisorders",
    "adhd",
    "bipolar",
    "schizophrenia",
]
mlb, _ = get_mlb_labels(data, disorders)

# Build the classifier with one output per label class, then load the fine-tuned weights.
model = RoBERTaMultiLabelClassifier(num_labels=len(mlb.classes_)).to(device)
# The checkpoint goes straight into load_state_dict, so it only needs to contain
# tensors. weights_only=True makes torch.load (which is pickle-based) refuse
# arbitrary pickled objects and silences PyTorch's FutureWarning about the
# default. Requires torch >= 1.13.
model.load_state_dict(
    torch.load(
        "../models/best_roberta_multilabel.pt",
        map_location=device,
        weights_only=True,
    )
)
# Switch to evaluation mode so training-only behavior (e.g. dropout) is disabled.
model.eval()

# Inference Function
def predict_labels(text, model, tokenizer, mlb, device, max_length=128, threshold=0.5):
    """Clean one raw post and return the labels the classifier assigns to it.

    Args:
        text: Raw input string; cleaned with ``aggressive_clean_text`` before
            tokenization.
        model: Classifier called as ``model(input_ids, attention_mask)``; its
            output is expected to be a single-row batch of per-label logits.
        tokenizer: Hugging Face tokenizer used to encode the cleaned text.
        mlb: Fitted label binarizer; ``mlb.classes_`` names each logit column.
        device: torch device the encoded inputs are moved to (should match
            ``model``).
        max_length: Token length the input is truncated / padded to.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value (default 0.5, the previous
            hard-coded cut-off).

    Returns:
        ``(clean, preds)``: the cleaned text and the list of predicted label
        names, in ``mlb.classes_`` order.

    Raises:
        ValueError: if the model emits a different number of logits than
            ``mlb.classes_`` has entries.
    """
    clean = aggressive_clean_text(text)
    encoded = tokenizer(clean, max_length=max_length, truncation=True,
                        padding='max_length', return_tensors='pt')
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    # Batch of one: take row 0 after moving the probabilities to the CPU.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    if len(probs) != len(mlb.classes_):
        raise ValueError(
            f"model produced {len(probs)} logits but mlb has "
            f"{len(mlb.classes_)} classes"
        )

    preds = [label for label, prob in zip(mlb.classes_, probs) if prob > threshold]
    return clean, preds


### Run the cell above first: it loads the tokenizer, label set, and trained model that the demo below uses.

In [None]:

# Demo Examples
examples = [
    "I can't stop worrying about everything and it's making it hard to sleep.",
    "I've been feeling hopeless and empty for the past few weeks.",
    "I get easily distracted and have trouble finishing tasks."
]

for i, text in enumerate(examples):
    cleaned, prediction = predict_labels(text, model, tokenizer, mlb, device)
    print(f"\nExample {i+1}:")
    print(f"Original Text: {text}")
    print(f"Cleaned Text: {cleaned}")
    print(f"Predicted Labels: {prediction}")