In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup
import numpy as np
from sklearn.model_selection import train_test_split
import random
from typing import List, Dict, Optional, Tuple
import copy
import itertools
import json
from collections import Counter
import math
import pandas as pd
import json

# Configuration
BERT_MODEL_NAME = 'bert-base-uncased'
TRAIT_NAMES = ['openness', 'conscientiousness', 'extraversion', 'agreeableness', 'neuroticism', 'humility']
NUM_TRAITS = len(TRAIT_NAMES)
NUM_LEVELS = 3  # low, medium, high
ORDINAL_OUTPUTS_PER_TRAIT = NUM_LEVELS - 1

MAX_SEQ_LENGTH = 128
# N_COMMENTS_TO_PROCESS will be set by grid search
# ATTENTION_HIDDEN_DIM will be set by grid search
# NUM_NUMERICAL_FEATURES will be determined from data

# Early Stopping Configuration
EARLY_STOPPING_PATIENCE = 3 # Number of epochs to wait for improvement before stopping
MIN_DELTA = 0.001 # Minimum change in validation loss to be considered an improvement

# --- 1. Dataset Class (Unchanged) ---
class PersonalityDataset(Dataset):
    def __init__(self,
                 data: List[Dict],
                 tokenizer: BertTokenizerFast,
                 max_seq_length: int,
                 trait_names: List[str],
                 num_comments_to_process: int):
        self.data = data
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.trait_names = trait_names
        self.num_comments_to_process = num_comments_to_process

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        user_comments_all = sample['comments']
        numerical_features = sample.get('numerical_features', [])

        if len(user_comments_all) > self.num_comments_to_process:
            comments_to_process_or_pad = random.sample(user_comments_all, self.num_comments_to_process)
        else:
            comments_to_process_or_pad = user_comments_all

        processed_comments_input_ids = []
        processed_comments_attention_mask = []
        active_comment_flags = []

        num_actual_comments = len(comments_to_process_or_pad)

        for i in range(self.num_comments_to_process):
            if i < num_actual_comments:
                comment_text = comments_to_process_or_pad[i]
                active_comment_flags.append(True)
            else:
                comment_text = ""
                active_comment_flags.append(False)

            encoding = self.tokenizer.encode_plus(
                comment_text, add_special_tokens=True, max_length=self.max_seq_length,
                padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
            )
            processed_comments_input_ids.append(encoding['input_ids'].squeeze(0))
            processed_comments_attention_mask.append(encoding['attention_mask'].squeeze(0))

        input_ids_tensor = torch.stack(processed_comments_input_ids)
        attention_mask_tensor = torch.stack(processed_comments_attention_mask)
        comment_active_mask_tensor = torch.tensor(active_comment_flags, dtype=torch.bool)

        integer_labels = []
        for trait_name in self.trait_names:
            label = sample['labels'][trait_name]
            integer_labels.append(label)

        return {
            'input_ids': input_ids_tensor,
            'attention_mask': attention_mask_tensor,
            'comment_active_mask': comment_active_mask_tensor,
            'numerical_features': torch.tensor(numerical_features, dtype=torch.float),
            'labels': torch.tensor(integer_labels, dtype=torch.long)
        }

# --- 2. Model Class (Unchanged) ---
class PersonalityModel(nn.Module):
    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 ordinal_outputs_per_trait: int,
                 num_numerical_features: int = 0,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.n_comments_to_process = n_comments_to_process
        self.ordinal_outputs_per_trait = ordinal_outputs_per_trait
        self.num_numerical_features = num_numerical_features # Store for clarity

        bert_hidden_size = self.bert.config.hidden_size

        self.attention_w = nn.Linear(bert_hidden_size, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        self.feature_combiner_input_size = bert_hidden_size + self.num_numerical_features
        self.dropout = nn.Dropout(dropout_rate)

        self.trait_classifiers = nn.ModuleList()
        for _ in range(num_traits):
            self.trait_classifiers.append(
                nn.Linear(self.feature_combiner_input_size, ordinal_outputs_per_trait)
            )

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                comment_active_mask: torch.Tensor,
                numerical_features: Optional[torch.Tensor] = None):

        batch_size = input_ids.shape[0]
        input_ids_flat = input_ids.view(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.view(-1, attention_mask.shape[-1])

        outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_embeddings_flat = outputs.pooler_output
        comment_embeddings = comment_embeddings_flat.view(batch_size, self.n_comments_to_process, -1)

        u = torch.tanh(self.attention_w(comment_embeddings))
        scores = self.attention_v(u).squeeze(-1)
        if comment_active_mask is not None:
            scores = scores.masked_fill(~comment_active_mask, -1e9)
        attention_weights = F.softmax(scores, dim=1)
        attention_weights_expanded = attention_weights.unsqueeze(-1)
        aggregated_comment_embedding = torch.sum(attention_weights_expanded * comment_embeddings, dim=1)

        if numerical_features is not None and numerical_features.numel() > 0 and numerical_features.shape[1] > 0:
            if numerical_features.shape[1] != self.num_numerical_features:
                 print(f"Warning: numerical_features.shape[1] ({numerical_features.shape[1]}) "
                       f"does not match self.num_numerical_features ({self.num_numerical_features})")
            combined_features = torch.cat((aggregated_comment_embedding, numerical_features), dim=1)
        else:
            combined_features = aggregated_comment_embedding

        combined_features_dropped = self.dropout(combined_features)

        trait_specific_logits = []
        for classifier_head in self.trait_classifiers:
            trait_specific_logits.append(classifier_head(combined_features_dropped))

        all_logits = torch.cat(trait_specific_logits, dim=1)
        return all_logits


# --- 3. CORAL Loss Function (Unchanged) ---
class MultiTaskCORALLoss(nn.Module):
    def __init__(self, num_traits: int, num_levels: int, device: torch.device, trait_importance_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_traits = num_traits
        self.num_levels = num_levels
        self.ordinal_outputs_per_trait = num_levels - 1
        self.device = device

        if trait_importance_weights is not None:
            self.trait_importance_weights = torch.tensor(trait_importance_weights, dtype=torch.float, device=self.device)
            if len(self.trait_importance_weights) != num_traits:
                raise ValueError("Length of trait_importance_weights must match num_traits.")
        else:
            self.trait_importance_weights = None

    def forward(self, all_logits: torch.Tensor, true_labels_int: torch.Tensor) -> torch.Tensor:
        batch_size = all_logits.shape[0]
        total_loss = torch.tensor(0.0, device=self.device)
        logits_per_trait_view = all_logits.view(batch_size, self.num_traits, self.ordinal_outputs_per_trait)

        for i in range(self.num_traits):
            trait_logits = logits_per_trait_view[:, i, :]
            trait_labels_int = true_labels_int[:, i]

            levels_binary_targets = torch.zeros_like(trait_logits, device=self.device)
            for k in range(self.ordinal_outputs_per_trait):
                levels_binary_targets[:, k] = (trait_labels_int > k).float()

            loss_trait = F.binary_cross_entropy_with_logits(
                trait_logits, levels_binary_targets, reduction='mean'
            )

            if self.trait_importance_weights is not None:
                total_loss += loss_trait * self.trait_importance_weights[i]
            else:
                total_loss += loss_trait

        return total_loss / self.num_traits if self.num_traits > 0 else torch.tensor(0.0, device=self.device)

# --- Prediction Conversion (Unchanged) ---
def convert_ordinal_logits_to_predictions(logits: torch.Tensor, num_traits: int, ordinal_outputs_per_trait: int, threshold: float = 0.5) -> torch.Tensor:
    batch_size = logits.shape[0]
    logits_per_trait = logits.view(batch_size, num_traits, ordinal_outputs_per_trait)
    probs = torch.sigmoid(logits_per_trait)
    predictions = (probs > threshold).long().sum(dim=2)
    return predictions

# --- 4. Training Loop (Unchanged) ---
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, verbose=True):
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(data_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        comment_active_mask = batch['comment_active_mask'].to(device)
        numerical_features = batch['numerical_features'].to(device)
        labels_int = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, comment_active_mask, numerical_features)
        loss = loss_fn(logits, labels_int)
        total_loss += loss.item()

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler:
            scheduler.step()

        if verbose and (batch_idx + 1) % 5 == 0:
             print(f"  Batch {batch_idx + 1}/{len(data_loader)}, Train Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else 0
    if verbose:
        print(f"Training Epoch Summary: Avg Loss: {avg_loss:.4f}")
    return avg_loss

# --- 5. Evaluation Loop (Unchanged) ---
def evaluate_epoch(model, data_loader, loss_fn, device, verbose=True):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            comment_active_mask = batch['comment_active_mask'].to(device)
            numerical_features = batch['numerical_features'].to(device)
            labels_int = batch['labels'].to(device)

            logits = model(input_ids, attention_mask, comment_active_mask, numerical_features)
            loss = loss_fn(logits, labels_int)
            total_loss += loss.item()

            if verbose and (batch_idx + 1) % 5 == 0:
                 print(f"  Batch {batch_idx + 1}/{len(data_loader)}, Val Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else float('inf')
    if verbose:
        print(f"Validation Epoch Summary: Avg Loss: {avg_loss:.4f}")
    return avg_loss

# --- Test Set Evaluation Function ---
def evaluate_on_test_set(model: PersonalityModel,
                         data_loader: DataLoader,
                         loss_fn: MultiTaskCORALLoss,
                         device: torch.device,
                         num_traits: int,
                         ordinal_outputs_per_trait: int,
                         trait_names: List[str]) -> Tuple[Optional[float], Optional[Dict[str, float]], Optional[float], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """
    Evaluates the model on a test dataset and computes various metrics.
    """
    model.eval()
    total_loss = 0
    all_predictions_list = []
    all_true_labels_list = []

    if not data_loader or len(data_loader) == 0:
        print("Test data_loader is empty. Cannot evaluate on test set.")
        return None, None, None, None

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            comment_active_mask = batch['comment_active_mask'].to(device)
            numerical_features = batch['numerical_features'].to(device)
            labels_int = batch['labels'].to(device)

            logits = model(input_ids, attention_mask, comment_active_mask, numerical_features)
            loss = loss_fn(logits, labels_int)
            total_loss += loss.item()

            # Convert logits to class predictions
            predictions_batch = convert_ordinal_logits_to_predictions(
                logits.cpu(), num_traits, ordinal_outputs_per_trait
            )
            all_predictions_list.append(predictions_batch)
            all_true_labels_list.append(labels_int.cpu())

    avg_loss = total_loss / len(data_loader)

    all_predictions_tensor = torch.cat(all_predictions_list, dim=0)
    all_true_labels_tensor = torch.cat(all_true_labels_list, dim=0)

    # Calculate Trait-wise Accuracy
    trait_accuracies = {}
    for i in range(num_traits):
        correct_predictions = (all_predictions_tensor[:, i] == all_true_labels_tensor[:, i]).sum().item()
        total_samples_for_trait = all_true_labels_tensor.shape[0]
        accuracy = correct_predictions / total_samples_for_trait if total_samples_for_trait > 0 else 0
        trait_accuracies[trait_names[i]] = accuracy

    # Calculate Overall Exact Match Accuracy
    # A sample is an exact match if all its traits are predicted correctly
    correct_sample_matches = (all_predictions_tensor == all_true_labels_tensor).all(dim=1).sum().item()
    total_samples = all_true_labels_tensor.shape[0]
    overall_exact_match_accuracy = correct_sample_matches / total_samples if total_samples > 0 else 0

    return avg_loss, trait_accuracies, overall_exact_match_accuracy, (all_predictions_tensor, all_true_labels_tensor)


# --- Function to run training for a set of hyperparameters (Unchanged) ---
def run_training_for_hyperparams(params: Dict,
                                 trial_train_data: List[Dict],
                                 trial_val_data: List[Dict],
                                 device: torch.device) -> float:
    # Extract params
    learning_rate = params['learning_rate']
    dropout_rate = params['dropout_rate']
    batch_size = params['batch_size']
    attention_hidden_dim = params['attention_hidden_dim']
    n_comments_to_process = params['n_comments_to_process']
    num_epochs_trial = params['num_epochs_trial']

    print(f"\n--- Starting Trial with Params: {params} ---")

    # --- Initialize Tokenizer, Datasets, DataLoaders ---
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    current_num_numerical_features = 0
    if trial_train_data and len(trial_train_data) > 0 and trial_train_data[0].get('numerical_features'):
        current_num_numerical_features = len(trial_train_data[0]['numerical_features'])

    train_dataset = PersonalityDataset(
        data=trial_train_data, tokenizer=tokenizer, max_seq_length=MAX_SEQ_LENGTH,
        trait_names=TRAIT_NAMES, num_comments_to_process=n_comments_to_process
    )
    val_dataset = PersonalityDataset(
        data=trial_val_data, tokenizer=tokenizer, max_seq_length=MAX_SEQ_LENGTH,
        trait_names=TRAIT_NAMES, num_comments_to_process=n_comments_to_process
    )

    use_pin_memory = True if device.type == 'cuda' else False

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=max(1,math.floor((3/4)*os.cpu_count())), pin_memory=use_pin_memory, persistent_workers=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=max(1,math.floor((3/4)*os.cpu_count())), pin_memory=use_pin_memory, persistent_workers=True)

    # --- Initialize Model, Loss, Optimizer ---
    model = PersonalityModel(
        bert_model_name=BERT_MODEL_NAME,
        num_traits=NUM_TRAITS,
        ordinal_outputs_per_trait=ORDINAL_OUTPUTS_PER_TRAIT,
        num_numerical_features=current_num_numerical_features,
        n_comments_to_process=n_comments_to_process,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim
    ).to(device)

    model_trait_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 0.4]
    loss_fn = MultiTaskCORALLoss(
        num_traits=NUM_TRAITS, num_levels=NUM_LEVELS, device=device,
        trait_importance_weights=model_trait_weights
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    total_steps = len(train_dataloader) * num_epochs_trial
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps) if total_steps > 0 else 0,
        num_training_steps=total_steps
    )

    best_val_loss_trial = float('inf')
    epochs_no_improve_trial = 0
    trial_verbose_epoch = True
    trial_verbose_batch = False

    for epoch in range(num_epochs_trial):
        if trial_verbose_epoch:
            print(f"  Epoch {epoch + 1}/{num_epochs_trial}")

        if not train_dataloader:
            print("  Skipping training epoch: train_dataloader is empty.")
            best_val_loss_trial = float('inf')
            break

        train_loss = train_epoch(
            model, train_dataloader, loss_fn, optimizer, device, scheduler, verbose=trial_verbose_batch
        )

        if val_dataloader and len(val_dataloader) > 0:
            current_val_loss = evaluate_epoch(
                model, val_dataloader, loss_fn, device, verbose=trial_verbose_batch
            )
            if trial_verbose_epoch:
                 print(f"  Params: {params}, Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Val Loss: {current_val_loss:.4f}")

            if current_val_loss < best_val_loss_trial - MIN_DELTA:
                best_val_loss_trial = current_val_loss
                epochs_no_improve_trial = 0
            else:
                epochs_no_improve_trial += 1

            if epochs_no_improve_trial >= EARLY_STOPPING_PATIENCE:
                if trial_verbose_epoch:
                    print(f"  Early stopping triggered for params {params} at epoch {epoch + 1}.")
                break
        else:
            best_val_loss_trial = train_loss
            if trial_verbose_epoch:
                print(f"  Params: {params}, Epoch {epoch + 1}, Train Loss: {train_loss:.4f}. No validation data for this trial.")

    print(f"--- Trial Finished for Params: {params}. Best Val Loss for this trial: {best_val_loss_trial:.4f} ---")
    return best_val_loss_trial

def test_data_transform(path):
    df = pd.read_csv(path)
    cols = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Humility']
    conversion = {
        'low': 0,
        'medium': 1,
        'high': 2
    }
    df[cols] = df[cols].apply(lambda col: col.map(conversion))

    #swap es to neuro
    conversion_es_neuro = {
        'low': 2,
        'medium': 1,
        'high': 0
    }
    df['Emotional stability'] = df['Emotional stability'].map(conversion_es_neuro)

    data = []
    for idx, row in df.iterrows():
        comments = [row[col] for col in ['Q1','Q2','Q3']]
        labels = {
            'openness': row['Openness'],
            'conscientiousness': row['Conscientiousness'],
            'extraversion': row['Extraversion'],
            'agreeableness': row['Agreeableness'],
            'neuroticism': row['Emotional stability'],
            'humility': row['Humility']
        }
        new_dict ={
            'id': row['id'],
            'comments': comments,
            'labels': labels
        }
        data.append(new_dict)
    return data

class TestPersonalityDataset(Dataset):
    def __init__(self,
                 data: List[Dict],
                 tokenizer, # Should be your actual tokenizer instance e.g., BertTokenizerFast
                 max_seq_length: int,
                 num_comments_to_process: int):
        self.data = data
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_comments_to_process = num_comments_to_process

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        sample = self.data[idx]
        # Use .get() with a default empty list for robustness if 'comments' key might be missing
        user_comments_all = sample.get('comments', [])
        # Use .get() for numerical_features as well
        numerical_features_list = sample.get('numerical_features', [])

        # Deterministic selection of comments for testing: take the first N
        if len(user_comments_all) > self.num_comments_to_process:
            comments_to_process_or_pad = user_comments_all[:self.num_comments_to_process]
        else:
            comments_to_process_or_pad = user_comments_all

        processed_comments_input_ids = []
        processed_comments_attention_mask = []
        active_comment_flags = []

        num_actual_comments = len(comments_to_process_or_pad)

        for i in range(self.num_comments_to_process):
            if i < num_actual_comments:
                comment_text = comments_to_process_or_pad[i]
                active_comment_flags.append(True)
            else:
                # Pad with empty strings if there are fewer actual comments
                # The tokenizer should handle empty strings (e.g., by producing [CLS], [SEP] and padding)
                comment_text = ""
                active_comment_flags.append(False)

            encoding = self.tokenizer.encode_plus(
                comment_text,
                add_special_tokens=True,
                max_length=self.max_seq_length,
                padding='max_length', # Crucial to ensure all sequences have the same length
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'  # Returns PyTorch tensors
            )
            # .squeeze(0) because encode_plus by default adds a batch dimension (e.g., [1, seq_len])
            # and we want to stack them along a new "number of comments" dimension.
            processed_comments_input_ids.append(encoding['input_ids'].squeeze(0))
            processed_comments_attention_mask.append(encoding['attention_mask'].squeeze(0))

        # Stack the list of tensors for each comment into a single tensor.
        # input_ids will have shape: [num_comments_to_process, max_seq_length]
        input_ids_tensor = torch.stack(processed_comments_input_ids)
        # attention_mask will have shape: [num_comments_to_process, max_seq_length]
        attention_mask_tensor = torch.stack(processed_comments_attention_mask)

        # comment_active_mask indicates which of the 'num_comments_to_process' slots
        # correspond to actual comments vs. padding comments.
        # Shape: [num_comments_to_process]
        comment_active_mask_tensor = torch.tensor(active_comment_flags, dtype=torch.bool)
        # If your model expects float (0.0, 1.0) for masks, use dtype=torch.float

        # Convert numerical features to a tensor.
        # Ensure numerical_features_list is a list of numbers.
        numerical_features_tensor = torch.tensor(numerical_features_list, dtype=torch.float)
        print('Test dataset finished')
        return {
            'input_ids': input_ids_tensor,
            'attention_mask': attention_mask_tensor,
            'comment_active_mask': comment_active_mask_tensor,
            'numerical_features': numerical_features_tensor,
        }


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#uncomment to run local
#train_mode = True
train_mode = False

if train_mode == True:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data Loading (Train and Validation) (train) ---
    try:
        assert 'val_data_holdout' in globals(),'pass' #CHANGE?
        assert 'full_train_data' in globals(),'pass'
    except:
        try:
            with open(f'{folder}/data/intermediate/val_data.json', 'r') as f:
                test_data = json.load(f)
            with open(f'{folder}/data/intermediate/train_data.json', "r") as f:
                full_train_data = json.load(f)
        except:
            try:
                with open('val_data.json', 'r') as f:
                    test_data = json.load(f)
                with open('train_data.json', "r") as f:
                    full_train_data = json.load(f)
            except:
                exit()

    if not full_train_data:
        print("Error: full_train_data is empty or failed to load properly. Exiting.")
        exit()

    #val_data_holdout
    val_holdout_test_size = 0.1
    remaining_for_grid_search, val_data_holdout = train_test_split(
    full_train_data,
    test_size=val_holdout_test_size,
    random_state=42
    )

    NUM_NUMERICAL_FEATURES = 0
    if remaining_for_grid_search[0].get('numerical_features') is not None:
        NUM_NUMERICAL_FEATURES = len(remaining_for_grid_search[0]['numerical_features'])
        print(f"Number of numerical features detected in full_train_data: {NUM_NUMERICAL_FEATURES}")
    else:
        print("No 'numerical_features' key found in the first sample of full_train_data. Assuming 0 numerical features.")

    print(f"Sample full_train_data point (user_id: {remaining_for_grid_search[0].get('user_id', 'N/A')}, num_comments: {len(remaining_for_grid_search[0].get('comments', []))})")

    # --- Stratified Split for Grid Search ---
    traits_for_stratification = ['extraversion','openness']
    composite_labels = []
    valid_samples_for_strat = [] # To ensure full_train_data and composite_labels align

    for sample in remaining_for_grid_search:
        try:
            label_str = "-".join([str(sample['labels'][trait]) for trait in sorted(traits_for_stratification)])
            composite_labels.append(label_str)
            valid_samples_for_strat.append(sample)
        except KeyError as e:
            print(f"Warning: Missing trait {e} in labels for sample: {sample.get('user_id', 'Unknown User')}. Skipping this sample for stratification.")

    full_train_data_for_split = valid_samples_for_strat # Use only samples with valid labels for stratification
    composite_labels_np = np.array(composite_labels)

    if not full_train_data_for_split:
        print("Error: No valid samples remaining after checking for stratification labels. Exiting.")
        exit()

    print(f"Distribution of composite labels for traits {traits_for_stratification} (used for splitting {len(full_train_data_for_split)} samples):")
    label_counts = Counter(composite_labels_np)
    print(label_counts)

    min_samples_per_stratum = 2
    small_strata = {k: v for k, v in label_counts.items() if v < min_samples_per_stratum}
    if small_strata:
        print(f"WARNING: The following strata have fewer than {min_samples_per_stratum} samples: {small_strata}")
        print("Stratified splitting might fail or be unreliable if test_size is too large for these small strata.")

    grid_search_train_data, grid_search_val_data = [], []
    try:
        if len(composite_labels_np) > 0 and len(full_train_data_for_split) == len(composite_labels_np):
            grid_search_train_data, grid_search_val_data = train_test_split(
                full_train_data_for_split,
                test_size=0.20,
                random_state=42,
                stratify=composite_labels_np
            )
            print(f"\nSuccessfully performed stratified split.")
            print(f"Size of data for grid search training: {len(grid_search_train_data)}")
            print(f"Size of data for grid search validation: {len(grid_search_val_data)}")
        else:
             raise ValueError("Composite labels and data mismatch after filtering, or data is empty. Cannot stratify.")
    except ValueError as e:
        print(f"\nError during stratified split with composite labels: {e}")
        print("Falling back to stratifying by the first trait ('extraversion') or random split.")
        try:
            first_trait_labels = np.array([sample['labels'][traits_for_stratification[0]] for sample in full_train_data_for_split])
            grid_search_train_data, grid_search_val_data = train_test_split(
                full_train_data_for_split, test_size=0.20, random_state=42, stratify=first_trait_labels
            )
            print(f"Fallback (single trait stratification): Training size: {len(grid_search_train_data)}, Val size: {len(grid_search_val_data)}")
        except Exception as fallback_e:
            print(f"Fallback stratification also failed: {fallback_e}. Using random split.")
            grid_search_train_data, grid_search_val_data = train_test_split(
                full_train_data_for_split, test_size=0.20, random_state=42
            )
            print(f"Fallback (random split): Training size: {len(grid_search_train_data)}, Val size: {len(grid_search_val_data)}")

    if not grid_search_train_data or not grid_search_val_data:
        print("Error: Grid search training or validation data is empty after splits. Exiting.")
        exit()

    # --- Hyperparameter Grid Definition ---
    param_grid = {
        'learning_rate': [2e-5, 5e-5],
        'dropout_rate': [0.1, 0.2],
        'batch_size': [16],
        'attention_hidden_dim': [64, 128],
        'n_comments_to_process': [3, 4],
        'num_epochs_trial': [5]
    }

    keys, values = zip(*param_grid.items())
    hyperparameter_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    print(f"\nStarting Grid Search. Total combinations: {len(hyperparameter_combinations)}")

    best_overall_val_loss = float('inf')
    best_hyperparams = None
    results = []

    for i, params_combo in enumerate(hyperparameter_combinations):
        print(f"\nGRID SEARCH TRIAL {i+1}/{len(hyperparameter_combinations)}")
        current_trial_val_loss = run_training_for_hyperparams(
            params_combo, grid_search_train_data, grid_search_val_data, device
        )
        results.append({'params': params_combo, 'val_loss': current_trial_val_loss})

        if current_trial_val_loss < best_overall_val_loss:
            best_overall_val_loss = current_trial_val_loss
            best_hyperparams = params_combo
            print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(f"!!! New best overall validation loss: {best_overall_val_loss:.4f} with params: {best_hyperparams} !!!")
            print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

    print("\n--- Grid Search Finished ---")
    print(f"Best overall validation loss: {best_overall_val_loss:.4f}")
    print(f"Best hyperparameters: {best_hyperparams}")

    print("\nAll Grid Search Results (sorted by validation loss):")
    for res in sorted(results, key=lambda x: x['val_loss']):
        print(f"  Params: {res['params']}, Val Loss: {res['val_loss']:.4f}")

    # --- Train final model with best hyperparameters ---
    if best_hyperparams:
        with open('best_hyperparams.json', 'w') as f:
            json.dump(best_hyperparams, f)
        print("\n--- Training a final model with the best hyperparameters ---")

        FINAL_MODEL_EPOCHS = best_hyperparams.get('num_epochs_trial', 10)
        print(f"Using best_hyperparams: {best_hyperparams} for up to {FINAL_MODEL_EPOCHS} epochs for the final model.")

        final_n_comments = best_hyperparams['n_comments_to_process']
        final_batch_size = best_hyperparams['batch_size']

        tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

        final_train_dataset = PersonalityDataset(
            data=full_train_data, # Train final model on ALL original training data
            tokenizer=tokenizer, max_seq_length=MAX_SEQ_LENGTH,
            trait_names=TRAIT_NAMES, num_comments_to_process=final_n_comments
        )
        final_val_dataset = PersonalityDataset(
            data=val_data_holdout,
            tokenizer=tokenizer, max_seq_length=MAX_SEQ_LENGTH,
            trait_names=TRAIT_NAMES, num_comments_to_process=final_n_comments
        )
        use_pin_memory = True if device.type == 'cuda' else False
        final_train_dataloader = DataLoader(final_train_dataset, batch_size=final_batch_size, shuffle=True, num_workers=max(1,math.floor((3/4)*os.cpu_count())), pin_memory=use_pin_memory, persistent_workers=True)
        final_val_dataloader = DataLoader(final_val_dataset, batch_size=final_batch_size, shuffle=False, num_workers=max(1,math.floor((3/4)*os.cpu_count())), pin_memory=use_pin_memory, persistent_workers=True)

        final_model = PersonalityModel(
            bert_model_name=BERT_MODEL_NAME,
            num_traits=NUM_TRAITS,
            ordinal_outputs_per_trait=ORDINAL_OUTPUTS_PER_TRAIT,
            num_numerical_features=NUM_NUMERICAL_FEATURES,
            n_comments_to_process=best_hyperparams['n_comments_to_process'],
            dropout_rate=best_hyperparams['dropout_rate'],
            attention_hidden_dim=best_hyperparams['attention_hidden_dim']
        ).to(device)

        final_loss_fn = MultiTaskCORALLoss(
            num_traits=NUM_TRAITS, num_levels=NUM_LEVELS, device=device,
            trait_importance_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 0.4]
        ).to(device)

        final_optimizer = AdamW(final_model.parameters(), lr=best_hyperparams['learning_rate'], eps=1e-8)

        final_total_steps = len(final_train_dataloader) * FINAL_MODEL_EPOCHS
        final_scheduler = get_linear_schedule_with_warmup(
            final_optimizer,
            num_warmup_steps=int(0.1 * final_total_steps) if final_total_steps > 0 else 0,
            num_training_steps=final_total_steps
        )

        best_final_model_val_loss = float('inf')
        epochs_no_improve_final = 0
        best_final_model_state_dict = None

        for epoch in range(FINAL_MODEL_EPOCHS):
            print(f"\n--- Final Model Training Epoch {epoch + 1}/{FINAL_MODEL_EPOCHS} ---")
            train_loss = train_epoch(final_model, final_train_dataloader, final_loss_fn, final_optimizer, device, final_scheduler, verbose=True)

            if final_val_dataloader and len(final_val_dataloader) > 0:
                current_val_loss = evaluate_epoch(final_model, final_val_dataloader, final_loss_fn, device, verbose=True)
                print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Val Loss: {current_val_loss:.4f}")

                if current_val_loss < best_final_model_val_loss - MIN_DELTA:
                    best_final_model_val_loss = current_val_loss
                    epochs_no_improve_final = 0
                    best_final_model_state_dict = copy.deepcopy(final_model.state_dict())
                    print(f"Validation loss improved to {best_final_model_val_loss:.4f}. Saving model state.")
                else:
                    epochs_no_improve_final += 1

                if epochs_no_improve_final >= EARLY_STOPPING_PATIENCE:
                    print(f"\nEarly stopping triggered for final model after {epoch + 1} epochs.")
                    break
            else:
                print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}. No validation data for early stopping final model.")
                best_final_model_state_dict = copy.deepcopy(final_model.state_dict())


        if best_final_model_state_dict:
            print("\nLoading best weights for the final model.")
            final_model.load_state_dict(best_final_model_state_dict)
            torch.save(best_final_model_state_dict, "best_final_model.pth")
            print("Saved best final model state_dict to best_final_model.pth")

        print("\nFinal model training finished or early stopped.")



        # --- NEW: Evaluate Final Model on Test Set ---
        print("\n--- Evaluating Final Model on Test Set ---")

        if test_data:
            # Check consistency of numerical features in test_data
            # This is a simple check on the first item. More robust checks might be needed.
            if test_data[0].get('numerical_features') is not None:
                num_feat_test = len(test_data[0]['numerical_features'])
                if num_feat_test != NUM_NUMERICAL_FEATURES:
                    print(f"WARNING: Test data has {num_feat_test} numerical features, but model was trained with {NUM_NUMERICAL_FEATURES}.")
                    print("This might lead to errors or unexpected behavior during evaluation.")
            elif NUM_NUMERICAL_FEATURES > 0:
                 print(f"WARNING: Model was trained with {NUM_NUMERICAL_FEATURES} numerical features, but test data sample seems to have none.")


            test_dataset = PersonalityDataset(
                data=test_data,
                tokenizer=tokenizer,
                max_seq_length=MAX_SEQ_LENGTH,
                trait_names=TRAIT_NAMES,
                num_comments_to_process=best_hyperparams['n_comments_to_process'] # Use best n_comments
            )
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=best_hyperparams['batch_size'], # Use best batch_size
                shuffle=False, # No need to shuffle for testing
                num_workers=max(1,math.floor((3/4)*os.cpu_count())),
                pin_memory=use_pin_memory,
                persistent_workers=True if use_pin_memory else False
            )

            if test_dataloader and len(test_dataloader) > 0:
                test_loss, test_trait_accuracies, test_overall_accuracy, (test_preds, test_true) = evaluate_on_test_set(
                    model=final_model,
                    data_loader=test_dataloader,
                    loss_fn=final_loss_fn, # Re-use loss_fn from final model training
                    device=device,
                    num_traits=NUM_TRAITS,
                    ordinal_outputs_per_trait=ORDINAL_OUTPUTS_PER_TRAIT,
                    trait_names=TRAIT_NAMES
                )

                if test_loss is not None:
                    print(f"\nTest Set Evaluation Results:")
                    print(f"  Average Test Loss: {test_loss:.4f}")
                    print(f"  Overall Exact Match Accuracy: {test_overall_accuracy:.4f}")
                    print(f"  Trait-wise Accuracies:")
                    for trait, acc in test_trait_accuracies.items():
                        print(f"    {trait}: {acc:.4f}")

                    # You can also save or further analyze test_preds and test_true if needed
                    # print(f"Test Predictions (first 5): \n{test_preds[:5]}")
                    # print(f"Test True Labels (first 5): \n{test_true[:5]}")
                else:
                    print("Test evaluation could not be completed (e.g., dataloader was empty after all).")
            else:
                print("Test DataLoader is empty. Skipping test evaluation.")
        else:
            print("Test data is empty or not loaded. Skipping test set evaluation.")
        # --- END NEW TEST SET EVALUATION ---

    else:
        print("\nNo best hyperparameters found. Final model not trained. Test set evaluation skipped.")

    print("\n--- Script Finished ---")



elif train_mode == False: #----------------------------------------------------------------------------------------------------------------------------------------- EVAL MODE
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_pin_memory = True if device.type == 'cuda' else False
    print(f"Using device: {device}")
    #test data
    def real_test_data_transform(path):
        df = pd.read_csv(path)
        output = []
        for idx, row in df.iterrows():
                id = row['id']
                comments = [row['Q1'], row['Q2'], row['Q3']]
                new_entry = {
                    'id': id,
                    'comments': comments
                }
                output.append(new_entry)
        return output

    test_data = real_test_data_transform(r'..\..\data\test_data_a.csv')
    print('Data loaded')
    NUM_NUMERICAL_FEATURES = 0
    #LOAD BEST HYPERPARAMETERS
    best_hyperparams = {"learning_rate": 5e-05, "dropout_rate": 0.1, "batch_size": 16, "attention_hidden_dim": 64, "n_comments_to_process": 4, "num_epochs_trial": 5}

    #LOAD model
    final_model = PersonalityModel(
        bert_model_name=BERT_MODEL_NAME,
        num_traits=NUM_TRAITS,
        ordinal_outputs_per_trait=ORDINAL_OUTPUTS_PER_TRAIT,
        num_numerical_features=NUM_NUMERICAL_FEATURES,
        n_comments_to_process=3,
        dropout_rate=best_hyperparams['dropout_rate'],
        attention_hidden_dim=best_hyperparams['attention_hidden_dim']
    ).to(device)

    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    #load state dict
    print("\nLoading best weights for the final model.")
    final_model.load_state_dict(torch.load(r"..\..\best_final_model.pth", map_location=device))

    # --- Evaluate Final Model on Test Set ---
    print("\n--- Evaluating Final Model on Test Set ---")
    if test_data:
        # Check consistency of numerical features in test_data
        if test_data[0].get('numerical_features') is not None:
            num_feat_test = len(test_data[0]['numerical_features'])
            if num_feat_test != NUM_NUMERICAL_FEATURES:
                print(f"WARNING: Test data has {num_feat_test} numerical features, but model was trained with {NUM_NUMERICAL_FEATURES}.")
                print("This might lead to errors or unexpected behavior during evaluation.")
        elif NUM_NUMERICAL_FEATURES > 0:
                print(f"WARNING: Model was trained with {NUM_NUMERICAL_FEATURES} numerical features, but test data sample seems to have none.")


        test_dataset = TestPersonalityDataset(
            data=test_data,
            tokenizer=tokenizer,
            max_seq_length=MAX_SEQ_LENGTH,
            num_comments_to_process=3
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=16,
            shuffle=False, # No need to shuffle for testing
            num_workers=0,#max(1,math.floor((3/4)*os.cpu_count())),
            pin_memory=False,#use_pin_memory,
            persistent_workers=False#True if use_pin_memory else False
        )
        #START TRY 3
        # --- Prediction Loop ---
        final_model.eval()
        all_predictions_list = []
        all_original_indices = [] # To map predictions back to original data if needed

        print("\nStarting predictions one by one:")
        with torch.no_grad():
            for i, batch in enumerate(test_dataloader):
                print(f"\nProcessing sample {i + 1}/{len(test_dataloader)//16}...")
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                comment_active_mask = batch['comment_active_mask'].to(device)
                numerical_features = batch['numerical_features'].to(device)

                # --- Model Forward Pass ---
                logits_batch = final_model(input_ids, attention_mask, comment_active_mask, numerical_features)
                # logits shape will be [16, NUM_TRAITS * ORDINAL_OUTPUTS_PER_TRAIT] because batch_size=1

                # --- Convert Logits to Predictions ---
                # predictions_batch will be for this single sample.
                # Expected shape: [1, NUM_TRAITS]
                predictions_batch_tensor = convert_ordinal_logits_to_predictions(
                    logits_batch.cpu(), NUM_TRAITS, ORDINAL_OUTPUTS_PER_TRAIT
                )
                # predictions_for_current_sample_tensor shape is [16, NUM_TRAITS]

                # If you want to work with the prediction for this specific sample immediately:
                # .squeeze(0) removes the batch dimension of 1.
                # .tolist() converts the tensor to a Python list.
                all_predictions_list.extend(predictions_batch_tensor.tolist())

    if all_predictions_list:
        if len(all_predictions_list) == len(test_data):
            test_df = pd.read_csv(r'..\..\data\test_data_a.csv')
            preds_df = test_df.copy()
            clean_traits = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Emotional stability', 'Humility']
            preds = pd.DataFrame(all_predictions_list, columns=TRAIT_NAMES)
            for c_trait, trait in zip(clean_traits, TRAIT_NAMES):
                if trait == 'neuroticism':
                    preds_df[c_trait] = preds[trait].apply(lambda x: 'low' if x == 2 else ('medium' if x == 1 else 'high'))
                else:
                    preds_df[c_trait] = preds[trait].apply(lambda x: 'low' if x == 0 else ('medium' if x == 1 else 'high'))
            with open('test_predictions.csv', 'w') as f:
                preds_df.to_csv(f, index=False)
            print("\nTest predictions saved to 'test_predictions.csv'.")
            

    else:
        print("Test data is empty or not loaded. Skipping test set evaluation.")
else:
  print('ERROR: Neither train or test mode.')

Using device: cuda
Data loaded

Loading best weights for the final model.

--- Evaluating Final Model on Test Set ---

Starting predictions one by one:
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished

Processing sample 1/3...
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished

Processing sample 2/3...
Test dataset finished
Test dataset finished
Test dataset finished
Test dataset finished
Test