In [None]:
import os
import random
import sys
import time

import pandas as pd
import swifter  # noqa: F401 -- imported for its side effect: registers the `.swifter` pandas accessor
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Make the parent directory importable so the project's `src` package resolves.
sys.path.append(os.path.abspath(".."))

from src.data_preprocessing import aggressive_clean_text, get_mlb_labels, preprocess_dataframe
from src.model import RoBERTaMultiLabelClassifier
from src.utils import RedditMentalHealthDataset, evaluate_model


In [None]:
# Config
BATCH_SIZE = 16   # posts per DataLoader batch during evaluation
MAX_LENGTH = 128  # max token length handed to the dataset/tokenizer (presumably longer posts are truncated)
SEED = 42         # seed for Python's `random` and torch

# Seed the RNGs the notebook touches so Restart & Run All gives the same results.
torch.manual_seed(SEED)
random.seed(SEED)


### Load and preprocess the data (this step takes roughly 3–5 minutes, depending on your machine)

In [None]:
# Load and preprocess data
def preprocess_dataframe_swifter(df):
    """Return a copy of ``df`` with ``clean_text`` and ``labels`` columns added.

    * ``clean_text`` -- each ``text`` value passed through ``aggressive_clean_text``,
      applied via the ``.swifter`` accessor registered by importing swifter.
    * ``labels`` -- the comma-separated ``subreddit`` string split into a list of
      stripped, non-empty label names (a stray or trailing comma no longer yields ``''``).

    Args:
        df: DataFrame with at least ``text`` and ``subreddit`` columns. Every
            ``subreddit`` value must be a string; a missing (NaN) value raises
            ``AttributeError`` on ``.split``.

    Returns:
        A new DataFrame; the caller's ``df`` is not modified.
    """
    out = df.copy()  # don't mutate the caller's frame; keeps the cell idempotent on re-run
    out['clean_text'] = out['text'].swifter.apply(aggressive_clean_text)
    out['labels'] = out['subreddit'].apply(
        lambda x: [lab.strip() for lab in x.split(',') if lab.strip()]
    )
    return out

# Label vocabulary the model predicts; entries are matched against the names
# parsed out of the `subreddit` column.
disorders = [
    "depression", "anxiety", "OCD", "PTSD", "autism",
    "eatingdisorders", "adhd", "bipolar", "schizophrenia",
]

# Read the CSV and add the `clean_text` / `labels` columns in one step.
data = preprocess_dataframe_swifter(pd.read_csv("../data/cleaned_paper.csv"))

# Build the label binarizer over `disorders` (it exposes `classes_`, used below);
# the second return value is not needed in this notebook.
mlb, _ = get_mlb_labels(data, disorders)


In [None]:
# Tokenizer and Dataset
# Tokenize every preprocessed post with the pretrained RoBERTa tokenizer and
# iterate over them in file order (no shuffling), BATCH_SIZE posts at a time.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
full_dataset = RedditMentalHealthDataset(
    data,
    mlb,
    tokenizer,
    MAX_LENGTH,
)
loader = DataLoader(
    full_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)


In [None]:
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One output per binarizer class.
model = RoBERTaMultiLabelClassifier(num_labels=len(mlb.classes_)).to(device)

# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load files you trust. Since a plain state_dict is expected here, passing
# `weights_only=True` (torch >= 1.13) would be safer -- confirm the installed torch version.
checkpoint_state = torch.load("../models/best_roberta_multilabel.pt", map_location=device)
model.load_state_dict(checkpoint_state)

# Switch layers such as dropout to inference behaviour before any evaluation runs;
# the trailing `;` suppresses the module repr in the cell output.
model.eval();

In [None]:
# Evaluate
# time.perf_counter() is a monotonic clock intended for measuring durations,
# unlike time.time(), which can jump if the system clock is adjusted.
print("\nEvaluating model on the full dataset...\n")
eval_start = time.perf_counter()
# Keep whatever evaluate_model returns (if anything) so it can be inspected in later cells.
eval_results = evaluate_model(model, loader, device, mlb)
eval_seconds = time.perf_counter() - eval_start
print(f"\nEvaluation completed in {eval_seconds:.2f} seconds.")



In [None]:
# Example Predictions
# Show a few randomly chosen posts with their true and predicted labels.
N_EXAMPLES = 3        # number of posts to show
EXAMPLE_SEED = 42     # fixed seed so Restart & Run All shows the same posts
PRED_THRESHOLD = 0.5  # sigmoid probability above which a label counts as predicted
PREVIEW_CHARS = 200   # characters of each post to print

print("\nExample Predictions:\n")
# A local RNG makes the sample reproducible without touching the global `random` state.
example_rng = random.Random(EXAMPLE_SEED)
# min() avoids random.sample's ValueError when the dataset has fewer rows than N_EXAMPLES.
sample_indices = example_rng.sample(range(len(data)), min(N_EXAMPLES, len(data)))

model.eval()
with torch.no_grad():
    for idx in sample_indices:
        # NOTE(review): assumes full_dataset[idx] and data.iloc[idx] refer to the same row
        # (i.e. the dataset indexes the frame positionally) -- confirm in RedditMentalHealthDataset.
        sample = full_dataset[idx]
        # unsqueeze(0) adds a leading batch dimension of size 1.
        input_ids = sample["input_ids"].unsqueeze(0).to(device)
        attention_mask = sample["attention_mask"].unsqueeze(0).to(device)
        logits = model(input_ids, attention_mask)
        probs = torch.sigmoid(logits).cpu().numpy()[0]
        preds = [label for label, p in zip(mlb.classes_, probs) if p > PRED_THRESHOLD]

        row = data.iloc[idx]
        post = row['clean_text']
        # Only append an ellipsis when the post was actually truncated.
        preview = post[:PREVIEW_CHARS] + ("..." if len(post) > PREVIEW_CHARS else "")

        print(f"Post: {preview}")
        print(f"True Labels: {row['labels']}")
        print(f"Predicted Labels: {preds}\n")