# Arrow doesn't work

In [None]:
import json # Keep for Optuna best_params.json, but not for dataset loading
import torch
from torch.utils.data import IterableDataset
# from transformers.tokenization_utils_base import BatchEncoding # No longer directly used in dataset for loading
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union, List
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
# from transformers.tokenization_utils_base import BatchEncoding # For type checking, if needed elsewhere
import torch.optim as optim
import os
import shutil

import pyarrow.feather as feather # For reading arrow files
import pyarrow as pa # For pyarrow types, if needed for schema checks

# Module-wide logging: INFO and above, each record prefixed with time and level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Removed JSON-specific constants and helper functions for dataset loading

class ArrowIterableDataset(IterableDataset):
    """Stream model-ready samples, one per table row, from an Arrow file.

    Each row is turned into a tuple of tensors by ``transform_fn``, which
    defaults to ``_default_transform``. Rows for which the transform returns
    ``None`` are skipped. With a multi-worker ``DataLoader`` the rows are
    split into contiguous, disjoint shards, one per worker. Rows are yielded
    in table order; nothing is shuffled.

    Args:
        file_path: Path to the Arrow file. Both the IPC *file* format
            (Feather v2) and the IPC *stream* format are accepted.
        trait_names: Trait names in output order. Label columns are looked
            up via ``_get_label_col_name``.
        n_comments_to_process: Number of comment slots per sample. Extra
            comments are randomly subsampled and missing ones zero-padded.
        other_numerical_feature_names: Scalar columns collected into a float
            feature vector.
        num_q_features_per_comment: Width of each comment's q-score vector
            (truncated or zero-padded).
        tokenizer_max_length: Token length each comment is truncated or
            zero-padded to.
        is_test_set: If True, labels are not read and 5-tuples are produced
            instead of 6-tuples.
        transform_fn: Optional replacement for ``_default_transform``. It is
            called as ``transform_fn(row_dict, row_idx)``.
    """

    # Number of rows per RecordBatch converted to Python objects at a time in __iter__.
    _ARROW_CHUNK_SIZE = 1024

    def __init__(self, file_path: str, trait_names: List[str], n_comments_to_process: int,
                 other_numerical_feature_names: List[str], num_q_features_per_comment: int,
                 tokenizer_max_length: int,  # Token length each comment is truncated/padded to
                 is_test_set: bool = False, transform_fn: Optional[callable] = None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.tokenizer_max_length = tokenizer_max_length
        self.is_test_set = is_test_set
        # The transform may return None to drop a row.
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn

        self.table: Optional[pa.Table] = None
        self.num_samples: int = 0

        # The table is read here, in the process that constructs the dataset.
        # DataLoader worker processes receive a copy of this object; they do
        # not re-read the file.
        try:
            logger.info(f"{self._worker_tag()}: Loading Arrow table from {self.file_path}...")
            self.table = self._load_table()
            self.num_samples = len(self.table)
            logger.info(f"{self._worker_tag()}: Successfully loaded Arrow table with {self.num_samples} samples from {self.file_path}.")
            if self.num_samples > 0:
                logger.debug(f"{self._worker_tag()}: Arrow table schema: {self.table.schema}")
                # Missing columns are only reported here; _default_transform deals with them per row.
                expected_cols = ["input_ids", "attention_mask", "q_scores"] + self.other_numerical_feature_names
                if not self.is_test_set:
                    expected_cols += [self._get_label_col_name(t) for t in self.trait_names_ordered]
                present_cols = set(self.table.column_names)
                missing_cols = [col for col in expected_cols if col not in present_cols]
                if missing_cols:
                    logger.warning(f"{self._worker_tag()}: Missing expected columns in Arrow table '{self.file_path}': {missing_cols}")
        except FileNotFoundError:
            logger.error(f"{self._worker_tag()}: Arrow file not found: {self.file_path}. Dataset will be empty.")
        except Exception as e:
            logger.error(f"{self._worker_tag()}: Error loading Arrow table from {self.file_path}: {e}. Dataset will be empty.", exc_info=True)

        if self.num_samples == 0:
            logger.warning(f"{self._worker_tag()}: Initialized ArrowIterableDataset for {self.file_path} with 0 samples.")

    @staticmethod
    def _worker_tag() -> str:
        """Return the log prefix: 'Worker <pid>' inside a DataLoader worker, 'Worker main' otherwise."""
        return f"Worker {os.getpid() if torch.utils.data.get_worker_info() else 'main'}"

    def _load_table(self) -> "pa.Table":
        """Read ``self.file_path`` into a ``pa.Table``.

        ``feather.read_table`` only understands the Arrow IPC *file* format,
        which is what Feather v2 is. It rejects files written with the IPC
        *stream* writer with ``ArrowInvalid``, so those are re-read with
        ``pa.ipc.open_stream``. Any other pyarrow / OS error (for example a
        missing file) propagates to the caller.
        """
        try:
            return feather.read_table(self.file_path)
        except pa.ArrowInvalid:
            logger.info(f"{self._worker_tag()}: '{self.file_path}' is not in Arrow IPC file/Feather format; retrying as an IPC stream.")
            with pa.memory_map(self.file_path, 'r') as source:
                return pa.ipc.open_stream(source).read_all()

    def _get_label_col_name(self, trait_name: str) -> str:
        """Map a trait name to its label column: 'label_' + lower-cased name with spaces/hyphens as '_'."""
        return f"label_{trait_name.lower().replace(' ', '_').replace('-', '_')}"

    @staticmethod
    def _finite_float(value) -> Optional[float]:
        """Return ``value`` as a finite float, or None if it is missing, non-numeric, NaN or infinite."""
        if value is None:
            return None
        try:
            result = float(value)
        except (TypeError, ValueError):
            return None
        return result if np.isfinite(result) else None

    def _process_row_data(self, row_dict: Dict, row_idx: int) -> Optional[Tuple]:
        """Apply ``transform_fn`` to one row; log and return None if it raises."""
        try:
            return self.transform_fn(row_dict, row_idx)
        except Exception as e:
            worker_id_str = f"Worker {torch.utils.data.get_worker_info().id}" if torch.utils.data.get_worker_info() else "Main"
            logger.error(f"{worker_id_str}: Error in transform_fn for row {row_idx} from {self.file_path}: {e}", exc_info=True)
            return None

    def __len__(self) -> int:
        """Number of rows in the table. Rows dropped by the transform are still counted."""
        return self.num_samples

    def _default_transform(self, sample_dict: Dict, idx: int) -> Optional[Tuple]:
        """Turn one Arrow row (a column-name -> value dict) into tensors.

        Row layout read by this method:
            * ``input_ids`` / ``attention_mask``: one token list per comment.
            * ``q_scores``: one float list per comment (optional).
            * each name in ``other_numerical_feature_names``: a scalar.
            * the label columns (``_get_label_col_name``): scalars, unless
              ``is_test_set``.

        At most ``n_comments_to_process`` comments are used; if the row has
        more, a random subset is picked. Every comment is truncated or
        zero-padded to ``tokenizer_max_length`` tokens, with no special-token
        handling, so all samples share one shape and can be batched.

        Args:
            sample_dict: Column values for the current row.
            idx: The row's index in the full table, used in log messages.

        Returns:
            A 6-tuple ``(input_ids [n, L], attention_mask [n, L],
            q_scores [n, Q], comment_active_flags [n],
            other_numerical_features [F], labels [T])``, where
            ``n = n_comments_to_process`` and ``L = tokenizer_max_length``.
            ``labels`` is omitted when ``is_test_set``. Returns None when the
            token data is malformed.
        """
        input_ids_rows = sample_dict.get('input_ids')
        attention_mask_rows = sample_dict.get('attention_mask')

        if (not isinstance(input_ids_rows, list)
                or not isinstance(attention_mask_rows, list)
                or not input_ids_rows
                or not isinstance(input_ids_rows[0], list)):
            logger.warning(f"Sample {idx} from {self.file_path} has malformed 'input_ids' or 'attention_mask'. Expected list of lists. Skipping. "
                           f"Input IDs type: {type(input_ids_rows)}, Attention Mask type: {type(attention_mask_rows)}")
            return None
        if len(attention_mask_rows) != len(input_ids_rows):
            logger.warning(f"Sample {idx} from {self.file_path} has {len(input_ids_rows)} 'input_ids' entries but "
                           f"{len(attention_mask_rows)} 'attention_mask' entries. Skipping.")
            return None

        # --- Comment selection ---
        num_actual_comments = len(input_ids_rows)
        if num_actual_comments > self.n_comments_to_process:
            selected_comments = random.sample(range(num_actual_comments), self.n_comments_to_process)
        else:
            selected_comments = list(range(num_actual_comments))

        # --- Tokens: each comment individually truncated/zero-padded to a fixed length ---
        seq_len = self.tokenizer_max_length
        final_input_ids = torch.zeros((self.n_comments_to_process, seq_len), dtype=torch.long)
        final_attention_mask = torch.zeros((self.n_comments_to_process, seq_len), dtype=torch.long)
        comment_active_flags = torch.zeros(self.n_comments_to_process, dtype=torch.bool)
        try:
            for slot, comment_idx in enumerate(selected_comments):
                token_ids = input_ids_rows[comment_idx][:seq_len]
                token_mask = attention_mask_rows[comment_idx][:seq_len]
                final_input_ids[slot, :len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
                final_attention_mask[slot, :len(token_mask)] = torch.tensor(token_mask, dtype=torch.long)
                comment_active_flags[slot] = True
        except (TypeError, ValueError, RuntimeError) as e:
            logger.error(f"Error converting input_ids/attention_mask to tensor for sample {idx} from {self.file_path}. Error: {e}")
            return None

        # --- Q-scores: per selected comment, truncated/zero-padded; bad values become 0.0 ---
        n_q = self.num_q_features_per_comment
        final_q_scores = torch.zeros((self.n_comments_to_process, n_q), dtype=torch.float)
        raw_q_scores_list = sample_dict.get('q_scores')
        if not isinstance(raw_q_scores_list, list):
            raw_q_scores_list = []
        for slot, comment_idx in enumerate(selected_comments):
            if comment_idx >= len(raw_q_scores_list) or not isinstance(raw_q_scores_list[comment_idx], list):
                continue  # missing/malformed q-scores for this comment stay zero
            q_values = [self._finite_float(v) for v in raw_q_scores_list[comment_idx][:n_q]]
            final_q_scores[slot, :len(q_values)] = torch.tensor(
                [0.0 if v is None else v for v in q_values], dtype=torch.float)

        # --- Other numerical features ---
        other_numerical_features_list = []
        for fname in self.other_numerical_feature_names:
            val = sample_dict.get(fname)
            if val is None:  # missing value: silently 0.0
                other_numerical_features_list.append(0.0)
                continue
            num = self._finite_float(val)
            if num is None:
                logger.warning(f"Sample {idx}: Could not convert numerical feature '{fname}' value '{val}' to a finite float. Using 0.0.")
                num = 0.0
            other_numerical_features_list.append(num)
        other_numerical_features_tensor = torch.tensor(other_numerical_features_list, dtype=torch.float)

        if self.is_test_set:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

        # --- Labels (training/validation only) ---
        regression_labels = []
        for trait_key in self.trait_names_ordered:
            label_col_name = self._get_label_col_name(trait_key)
            label_val = sample_dict.get(label_col_name)
            if label_val is None:
                logger.warning(f"Sample {idx}: Missing label for trait '{trait_key}' (column '{label_col_name}'). Using 0.0. "
                               "Ensure your Arrow conversion script includes all label columns.")
                regression_labels.append(0.0)
                continue
            label_float = self._finite_float(label_val)
            if label_float is None:
                logger.warning(f"Sample {idx}: Could not convert label for trait '{trait_key}' value '{label_val}' to a finite float. Using 0.0.")
                label_float = 0.0
            # Labels are clipped into [0, 1].
            regression_labels.append(min(max(label_float, 0.0), 1.0))
        labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
        return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)

    def __iter__(self):
        """Yield transformed rows from this process's shard of the table.

        In single-process loading, the shard is the whole table. Under a
        multi-worker DataLoader, the shard is a contiguous block of
        ``ceil(num_samples / num_workers)`` rows. Rows whose transform
        returns None are skipped.
        """
        if self.table is None or self.num_samples == 0:
            return

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start, iter_end = 0, self.num_samples
        else:
            per_worker = -(-self.num_samples // worker_info.num_workers)  # integer ceil division
            iter_start = min(worker_info.id * per_worker, self.num_samples)
            iter_end = min(iter_start + per_worker, self.num_samples)
        if iter_start >= iter_end:
            return  # more workers than rows: nothing left for this worker

        # Slice first (Table.slice is zero-copy) so each worker converts only its own rows to Python objects.
        shard = self.table.slice(iter_start, iter_end - iter_start)
        row_idx = iter_start
        for record_batch in shard.to_batches(max_chunksize=self._ARROW_CHUNK_SIZE):
            columns = record_batch.to_pydict()
            for offset in range(record_batch.num_rows):
                row = {name: values[offset] for name, values in columns.items()}
                processed_item = self._process_row_data(row, row_idx + offset)
                if processed_item is not None:
                    yield processed_item
            row_idx += record_batch.num_rows


# --- PersonalityModelV3 (Your model code remains the same) ---
class PersonalityModelV3(nn.Module):
    """Predict one score in [0, 1] per personality trait from several comments.

    Pipeline (see ``forward``):
        1. Each comment slot is encoded with BERT and mean-pooled over its
           real tokens and over the last ``num_bert_layers_to_pool`` hidden
           layers.
        2. The per-comment q-score features are appended.
        3. Additive attention combines the comment slots, ignoring inactive
           ones.
        4. Optional sample-level numerical features are embedded and
           concatenated.
        5. An optional dense block (otherwise plain dropout) is applied.
        6. One linear head per trait is followed by a sigmoid.
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3,
                 num_other_numerical_features: int = 0,
                 numerical_embedding_dim: int = 64,
                 num_additional_dense_layers: int = 0,
                 additional_dense_hidden_dim: int = 256,
                 additional_layers_dropout_rate: float = 0.3
                ):
        """Build the BERT encoder, the attention scorer, the optional feature/dense blocks and the trait heads.

        Args:
            bert_model_name: Name or path passed to ``BertConfig`` / ``BertModel.from_pretrained``.
            num_traits: Number of regression heads (one output per trait).
            n_comments_to_process: Number of comment slots per sample the model is configured for.
            dropout_rate: Dropout after the numerical-feature embedding. Also
                used before the heads when there are no additional dense layers.
            attention_hidden_dim: Hidden size of the additive-attention scorer.
            num_bert_layers_to_pool: How many of the last BERT hidden layers are averaged.
            num_q_features_per_comment: Width of the per-comment q-score vector.
            num_other_numerical_features: Width of the sample-level numerical
                vector; 0 disables that branch.
            numerical_embedding_dim: Output width of the numerical-feature embedding.
            num_additional_dense_layers: Number of Linear-ReLU-Dropout layers before the heads.
            additional_dense_hidden_dim: Width of those layers.
            additional_layers_dropout_rate: Dropout inside those layers.
        """
        super().__init__()
        # output_hidden_states=True is needed because _pool_bert_layers averages several hidden layers.
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        self.final_dropout_layer = nn.Dropout(dropout_rate)

        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        aggregated_comment_feature_dim = comment_feature_dim
        combined_input_dim_for_block = aggregated_comment_feature_dim

        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
            combined_input_dim_for_block += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        self.num_additional_dense_layers = num_additional_dense_layers
        self.additional_dense_block = nn.Sequential()
        current_dim_for_dense_block = combined_input_dim_for_block

        if self.num_additional_dense_layers > 0:
            logger.info(f"Model using {self.num_additional_dense_layers} additional dense layers with hidden_dim {additional_dense_hidden_dim} and dropout {additional_layers_dropout_rate}")
            for i in range(self.num_additional_dense_layers):
                self.additional_dense_block.add_module(f"add_dense_{i}_linear", nn.Linear(current_dim_for_dense_block, additional_dense_hidden_dim))
                self.additional_dense_block.add_module(f"add_dense_{i}_relu", nn.ReLU())
                self.additional_dense_block.add_module(f"add_dense_{i}_dropout", nn.Dropout(additional_layers_dropout_rate))
                current_dim_for_dense_block = additional_dense_hidden_dim
            input_dim_for_regressors = current_dim_for_dense_block
        else:
            logger.info("Model not using additional dense layers. Will use final_dropout_layer if dropout_rate > 0.")
            input_dim_for_regressors = combined_input_dim_for_block

        self.trait_regressors = nn.ModuleList()
        for _ in range(num_traits):
            self.trait_regressors.append(
                nn.Linear(input_dim_for_regressors, 1)
            )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool the real tokens of each sequence, then average over the last N hidden layers.

        Args:
            all_hidden_states: Per-layer hidden states, each shaped (num_sequences, seq_len, hidden).
            attention_mask: (num_sequences, seq_len), 1 for real tokens and 0 for padding.

        Returns:
            A (num_sequences, hidden) tensor. Sequences with no real tokens
            (fully padded comment slots) pool to zeros rather than NaN.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # Float mask in the hidden-state dtype, shape (num_sequences, seq_len, 1).
        token_mask = attention_mask.unsqueeze(-1).to(layers_to_pool[0].dtype)
        # The real-token count is the same for every layer, so it is computed once.
        # The count is 0 or >= 1, so clamping to 1 only affects fully padded rows (0/1 = 0).
        # A tiny epsilon could instead underflow to 0 (fp16) or be truncated
        # on an integer tensor, giving 0/0 = NaN.
        token_counts = token_mask.sum(dim=1).clamp(min=1.0)
        per_layer_means = [(layer * token_mask).sum(dim=1) / token_counts for layer in layers_to_pool]
        return torch.stack(per_layer_means, dim=0).mean(dim=0)

    def _aggregate_comments(self, comment_features: torch.Tensor,
                            comment_active_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Combine the per-comment features (batch, n_comments, dim) into (batch, dim) with additive attention.

        Inactive slots (``comment_active_mask`` False) get zero weight. A
        sample with no active slot aggregates to zeros instead of NaN.
        ``None`` means every slot is active.
        """
        u = torch.tanh(self.attention_w(comment_features))
        scores = self.attention_v(u).squeeze(-1)
        if comment_active_mask is None:
            attention_weights = F.softmax(scores, dim=1)
        else:
            active = comment_active_mask.bool()
            # A finite fill value keeps softmax finite even when a whole row is masked.
            # Multiplying by the mask then zeroes the weights of inactive slots.
            scores = scores.masked_fill(~active, torch.finfo(scores.dtype).min)
            attention_weights = F.softmax(scores, dim=1) * active.to(scores.dtype)
        return torch.sum(attention_weights.unsqueeze(-1) * comment_features, dim=1)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                q_scores: torch.Tensor,
                comment_active_mask: torch.Tensor,
                other_numerical_features: Optional[torch.Tensor] = None
               ):
        """Return per-trait scores in [0, 1], shaped (batch, num_traits).

        Args:
            input_ids: (batch, n_comments, seq_len) token ids.
            attention_mask: (batch, n_comments, seq_len) token mask.
            q_scores: (batch, n_comments, num_q_features_per_comment).
            comment_active_mask: (batch, n_comments), True for real comment
                slots, or None to treat all slots as active.
            other_numerical_features: (batch, num_other_numerical_features).
                Required when the model was built with such features.

        Raises:
            ValueError: if numerical features are expected but missing or have the wrong width.
        """
        batch_size, num_comments = input_ids.shape[0], input_ids.shape[1]

        # Encode every comment of every sample in one BERT pass: (batch * n_comments, seq_len).
        input_ids_flat = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.reshape(-1, attention_mask.shape[-1])

        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings_flat = self._pool_bert_layers(bert_outputs.hidden_states, attention_mask_flat)
        comment_bert_embeddings = comment_bert_embeddings_flat.view(batch_size, num_comments, -1)

        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)
        aggregated_comment_features = self._aggregate_comments(comment_features_with_q, comment_active_mask)

        final_features_for_processing = aggregated_comment_features
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_processing = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        if self.num_additional_dense_layers > 0:
            features_for_trait_heads = self.additional_dense_block(final_features_for_processing)
        else:  # no dense block: apply plain dropout before the heads
            features_for_trait_heads = self.final_dropout_layer(final_features_for_processing)

        trait_regression_outputs = [regressor_head(features_for_trait_heads) for regressor_head in self.trait_regressors]
        all_trait_outputs_raw = torch.cat(trait_regression_outputs, dim=1)
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Identity: ``forward`` already returns sigmoid scores."""
        return outputs


# --- Optuna Objective Function (MODIFIED) ---
def objective(trial: optuna.trial.Trial,
              train_file_path: str, # Will be .arrow
              val_file_path: str,   # Will be .arrow
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int,
              num_dataloader_workers: int, # Added for flexibility
              overall_best_weights_filepath: str
             ):
    logger.info(f"Starting Optuna Trial {trial.number} with Arrow files and {num_dataloader_workers} workers.")

    # --- Hyperparameter suggestions (remain the same) ---
    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    n_comments_trial = trial.suggest_int("n_comments_to_process", 3, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16, 32]) # Keep batch size reasonable with more workers

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64, 128])

    num_additional_dense_layers_trial = trial.suggest_int("num_additional_dense_layers", 0, 3)
    additional_dense_hidden_dim_trial = 0
    additional_layers_dropout_rate_trial = 0.0
    if num_additional_dense_layers_trial > 0:
        additional_dense_hidden_dim_trial = trial.suggest_categorical("additional_dense_hidden_dim", [128, 256, 512])
        additional_layers_dropout_rate_trial = trial.suggest_float("additional_layers_dropout_rate", 0.1, 0.5)

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")
    
    # --- Dataset and DataLoader setup ---
    try:
        logger.info(f"Trial {trial.number}: Initializing ArrowIterableDataset for training...")
        train_dataset_trial = ArrowIterableDataset(
            file_path=train_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            tokenizer_max_length=global_config['TOKENIZER_MAX_LENGTH'],
            is_test_set=False
        )
        if train_dataset_trial.num_samples == 0:
            logger.error(f"Trial {trial.number} - Training dataset '{train_file_path}' is empty or failed to load. Skipping trial.")
            return float('inf') # Return high loss
        
        logger.info(f"Trial {trial.number}: Initializing ArrowIterableDataset for validation...")
        val_dataset_trial = ArrowIterableDataset(
            file_path=val_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            tokenizer_max_length=global_config['TOKENIZER_MAX_LENGTH'],
            is_test_set=False
        )
        if val_dataset_trial.num_samples == 0:
            logger.warning(f"Trial {trial.number} - Validation dataset '{val_file_path}' is empty or failed to load. Validation may not be effective.")
            # We might still proceed if training data is fine, but val loss will be inf.

        logger.info(f"Trial {trial.number}: Creating DataLoaders with num_workers={num_dataloader_workers}, persistent_workers=True")
        train_loader_trial = DataLoader(
            train_dataset_trial, 
            batch_size=batch_size_trial, 
            num_workers=num_dataloader_workers, 
            pin_memory=True if device.type == 'cuda' else False, 
            persistent_workers=True if num_dataloader_workers > 0 else False, # persistent_workers only if num_workers > 0
            worker_init_fn=None # Optional: can be used for worker-specific setup if needed
        )
        val_loader_trial = DataLoader(
            val_dataset_trial, 
            batch_size=batch_size_trial, 
            num_workers=num_dataloader_workers, # Can also use fewer workers for validation
            pin_memory=True if device.type == 'cuda' else False, 
            persistent_workers=True if num_dataloader_workers > 0 else False
        )

    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    # --- Model Initialization ---
    model = PersonalityModelV3(
        bert_model_name=global_config['BERT_MODEL_NAME'],
        num_traits=num_traits,
        n_comments_to_process=n_comments_trial,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim,
        num_bert_layers_to_pool=num_bert_layers_to_pool,
        num_q_features_per_comment=num_q_features_per_comment_trial,
        num_other_numerical_features=num_other_numerical_features_trial,
        numerical_embedding_dim=other_numerical_embedding_dim_trial,
        num_additional_dense_layers=num_additional_dense_layers_trial,
        additional_dense_hidden_dim=additional_dense_hidden_dim_trial,
        additional_layers_dropout_rate=additional_layers_dropout_rate_trial
    ).to(device)

    # --- Optimizer and Scheduler Setup ---
    for name, param in model.bert.named_parameters(): param.requires_grad = False
    if num_unfrozen_bert_layers > 0:
        if hasattr(model.bert, 'embeddings'):
            for param in model.bert.embeddings.parameters(): param.requires_grad = True
        actual_layers_to_unfreeze = min(num_unfrozen_bert_layers, model.bert.config.num_hidden_layers)
        for i in range(model.bert.config.num_hidden_layers - actual_layers_to_unfreeze, model.bert.config.num_hidden_layers):
            if i >= 0 and i < len(model.bert.encoder.layer) : # Check bounds
                for param in model.bert.encoder.layer[i].parameters(): param.requires_grad = True
        if hasattr(model.bert, 'pooler') and model.bert.pooler is not None:
            for param in model.bert.pooler.parameters(): param.requires_grad = True
    
    logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                 f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

    optimizer_grouped_parameters = []
    bert_params_to_tune = [p for p in model.bert.parameters() if p.requires_grad]
    if bert_params_to_tune and lr_bert > 0:
         optimizer_grouped_parameters.append({"params": bert_params_to_tune, "lr": lr_bert, "weight_decay": 0.01}) # Common to use different wd for BERT

    head_params = [] # Collect all non-BERT tunable parameters
    head_params.extend(list(model.attention_w.parameters()))
    head_params.extend(list(model.attention_v.parameters()))
    if model.uses_other_numerical_features and hasattr(model, 'other_numerical_processor'):
        head_params.extend(list(model.other_numerical_processor.parameters()))
    if model.num_additional_dense_layers > 0 and hasattr(model, 'additional_dense_block'):
        head_params.extend(list(model.additional_dense_block.parameters()))
    if hasattr(model, 'final_dropout_layer'): # Although dropout doesn't have params, good to be explicit if it were a learnable layer
        pass 
    for regressor_head in model.trait_regressors:
        head_params.extend(list(regressor_head.parameters()))
    
    if head_params:
        optimizer_grouped_parameters.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})
        
    if not any(pg.get('params') for pg in optimizer_grouped_parameters): # Ensure there are actually parameters to optimize
        logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
        if model: del model
        if train_loader_trial: del train_loader_trial
        if val_loader_trial: del val_loader_trial
        torch.cuda.empty_cache(); gc.collect()
        return float('inf')

    optimizer = optim.AdamW(optimizer_grouped_parameters)
    
    scheduler = None
    if scheduler_type == "linear_warmup":
        # Use len(train_dataset_trial) which is self.num_samples from Arrow table
        if train_dataset_trial.num_samples > 0:
            num_batches_per_epoch = (train_dataset_trial.num_samples + batch_size_trial - 1) // batch_size_trial
            num_training_steps = num_batches_per_epoch * num_epochs_per_trial
            num_warmup_steps = int(num_training_steps * warmup_ratio)
            if num_warmup_steps > 0 and num_training_steps > 0 and num_warmup_steps < num_training_steps :
                scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
                logger.info(f"Trial {trial.number}: Linear warmup scheduler created. Warmup steps: {num_warmup_steps}, Total steps: {num_training_steps}")
            else:
                logger.warning(f"Trial {trial.number}: Calculated num_warmup_steps ({num_warmup_steps}) or num_training_steps ({num_training_steps}) is invalid. Scheduler not created.")
        else:
            logger.warning(f"Trial {trial.number}: Training dataset has 0 samples. Cannot create linear_warmup scheduler.")

    # --- Training Loop ---
    loss_fn = nn.L1Loss().to(device) # Using L1Loss as in your original code
    best_val_loss_this_trial = float('inf')
    patience_counter = 0
                
    for epoch in range(num_epochs_per_trial):
        model.train()
        total_train_loss = 0
        train_batches_processed = 0
        logger.info(f"Trial {trial.number}, Epoch {epoch+1}: Starting training...")
        for batch_idx, batch_data in enumerate(train_loader_trial):
            if not batch_data or len(batch_data) < 6: # Check if batch is empty or malformed
                logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: Received empty or malformed batch. Skipping.")
                continue
            
            try:
                input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_data]
            except Exception as e:
                logger.error(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: Error moving batch to device or unpacking: {e}. Batch data: {batch_data}")
                continue

            optimizer.zero_grad()
            predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
            current_batch_loss = loss_fn(predicted_scores, labels_reg)
            
            if torch.isnan(current_batch_loss) or torch.isinf(current_batch_loss):
                logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected ({current_batch_loss.item()}). Skipping batch gradient update.")
                torch.cuda.empty_cache() # Try to clear memory if it's an OOM leading to NaN
                continue # Skip optimizer step and backward for this batch
                
            current_batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if scheduler: scheduler.step()
            
            total_train_loss += current_batch_loss.item()
            train_batches_processed += 1
            if (batch_idx + 1) % 100 == 0: # Log progress every 100 batches
                 logger.debug(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader_trial) if hasattr(train_loader_trial, '__len__') and len(train_loader_trial) > 0 else 'Unknown'}: Current Avg Train Loss: {total_train_loss/train_batches_processed:.4f}")

        avg_train_loss = total_train_loss / train_batches_processed if train_batches_processed > 0 else float('inf')
        logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

        # --- Validation Phase ---
        model.eval()
        current_epoch_val_loss = 0
        val_batches_processed = 0
        all_val_preds_epoch = []
        all_val_labels_epoch = []
        logger.info(f"Trial {trial.number}, Epoch {epoch+1}: Starting validation...")
        with torch.no_grad():
            for val_batch_idx, batch_data_val in enumerate(val_loader_trial):
                if not batch_data_val or len(batch_data_val) < 6:
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Val Batch {val_batch_idx}: Received empty or malformed validation batch. Skipping.")
                    continue
                try:
                    input_ids_v, attention_m_v, q_s_v, comment_active_m_v, other_num_feats_v, labels_reg_v = [b.to(device) for b in batch_data_val]
                except Exception as e:
                    logger.error(f"Trial {trial.number}, Epoch {epoch+1}, Val Batch {val_batch_idx}: Error moving val batch to device or unpacking: {e}")
                    continue

                if input_ids_v.numel() == 0: continue # Should not happen if dataset is not empty
                predicted_scores_v = model(input_ids_v, attention_m_v, q_s_v, comment_active_m_v, other_num_feats_v)
                if predicted_scores_v.numel() == 0: continue

                batch_val_loss = loss_fn(predicted_scores_v, labels_reg_v)
                if torch.isnan(batch_val_loss) or torch.isinf(batch_val_loss):
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Val Batch {val_batch_idx}: NaN or Inf validation loss ({batch_val_loss.item()}).")
                    current_epoch_val_loss += float('inf') # Penalize heavily
                else:
                    current_epoch_val_loss += batch_val_loss.item()
                
                all_val_preds_epoch.append(predicted_scores_v.cpu())
                all_val_labels_epoch.append(labels_reg_v.cpu())
                val_batches_processed += 1
        
        if val_dataset_trial.num_samples == 0 or val_batches_processed == 0: # Handle empty validation set
            avg_val_loss_epoch = float('inf')
            val_mae = float('inf')
            logger.warning(f"Trial {trial.number}, Epoch {epoch+1}: Validation dataset is empty or no validation batches processed. Setting val_loss to infinity.")
        else:
            avg_val_loss_epoch = current_epoch_val_loss / val_batches_processed
            val_mae = -1.0 # Default if calculation fails
            if all_val_labels_epoch and all_val_preds_epoch:
                try:
                    all_val_labels_cat = torch.cat(all_val_labels_epoch, dim=0)
                    all_val_preds_cat = torch.cat(all_val_preds_epoch, dim=0)
                    if all_val_labels_cat.numel() > 0 and all_val_preds_cat.numel() > 0:
                        val_mae = F.l1_loss(all_val_preds_cat, all_val_labels_cat).item()
                except Exception as e:
                    logger.error(f"Trial {trial.number}, Epoch {epoch+1}: Error calculating validation MAE: {e}")
                    val_mae = float('inf') # Indicate error in MAE calculation

        logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (Target Metric, e.g. L1): {avg_val_loss_epoch:.4f}, Val MAE (if L1 used): {val_mae:.4f}")


        # --- Early Stopping & Optuna Pruning/Reporting ---
        if avg_val_loss_epoch < best_val_loss_this_trial:
            best_val_loss_this_trial = avg_val_loss_epoch
            patience_counter = 0
            logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss for this trial: {best_val_loss_this_trial:.4f}")
        else:
            patience_counter += 1
        
        if hasattr(trial, 'study') and trial.study is not None:
            current_overall_best_loss = trial.study.user_attrs.get("overall_best_val_loss", float('inf'))
            if avg_val_loss_epoch < current_overall_best_loss:
                logger.info(f"Trial {trial.number}, Epoch {epoch+1}: New OVERALL best val_loss: {avg_val_loss_epoch:.4f} (Prev: {current_overall_best_loss:.4f}). Saving model.")
                trial.study.set_user_attr("overall_best_val_loss", avg_val_loss_epoch)
                trial.study.set_user_attr("overall_best_trial_number", trial.number)
                trial.study.set_user_attr("overall_best_epoch", epoch + 1)
                model_state_dict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(model_state_dict_cpu, overall_best_weights_filepath)
                logger.info(f"Trial {trial.number}: Saved new OVERALL best model weights to {overall_best_weights_filepath}")
        else:
            logger.warning(f"Trial {trial.number}: Cannot access study.user_attrs for overall best model tracking.")

        trial.report(avg_val_loss_epoch, epoch)
        if trial.should_prune():
            logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
            del model, train_loader_trial, val_loader_trial, optimizer, scheduler
            torch.cuda.empty_cache(); gc.collect()
            return best_val_loss_this_trial
        
        if patience_counter >= patience_early_stopping:
            logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
            break
        
    logger.info(f"Trial {trial.number} finished. Best Val Loss for this trial: {best_val_loss_this_trial:.4f}")
    del model, train_loader_trial, val_loader_trial, optimizer, scheduler
    torch.cuda.empty_cache(); gc.collect()
    return best_val_loss_this_trial

In [None]:
# Cell 2 (Modified): Optuna hyperparameter search driven by the Arrow-backed datasets.
# ArrowIterableDataset, PersonalityModelV3 and `objective` are expected to come from Cell 1.

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {DEVICE}")

# Dataset files (Arrow format; these used to be .jsonl files).
TRAIN_DATA_FILE, VAL_DATA_FILE, TEST_DATA_FILE = (
    "train_data.arrow",
    "val_data.arrow",
    "test_data.arrow",
)

# Passed to `objective` as `num_dataloader_workers`.
NUM_DATALOADER_WORKERS = 7
logger.info(f"Using {NUM_DATALOADER_WORKERS} DataLoader workers.")

# Order of the regression targets (one model output per entry).
_trait_names_ordered_config = [
    'Openness',
    'Conscientiousness',
    'Extraversion',
    'Agreeableness',
    'Emotional stability',
    'Humility',
]

# Per-sample numerical features, in the order they are packed into the feature vector.
_other_numerical_features_config = [
    # words / sentences per comment
    'mean_words_per_comment', 'mean_sents_per_comment', 'median_words_per_comment',
    'mean_words_per_sentence', 'median_words_per_sentence', 'sents_per_comment_skew',
    'words_per_sentence_skew', 'total_double_whitespace',
    # punctuation counts
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total',
    # readability / lexical
    'flesch_reading_ease_agg', 'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    # sentiment
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound',
]

GLOBAL_CONFIG = dict(
    BERT_MODEL_NAME="bert-base-uncased",
    TRAIT_NAMES_ORDERED=_trait_names_ordered_config,
    TRAIT_NAMES=_trait_names_ordered_config,  # alias of the same list; both keys are read below
    MAX_COMMENTS_TO_PROCESS_PHYSICAL=6,  # Max physical comments in data
    NUM_Q_FEATURES_PER_COMMENT=3,
    OTHER_NUMERICAL_FEATURE_NAMES=_other_numerical_features_config,
    TOKENIZER_MAX_LENGTH=256,  # Passed to ArrowIterableDataset
)

NUM_EPOCHS_PER_TRIAL_OPTUNA, N_OPTUNA_TRIALS = 15, 20

# Sample counts are no longer precomputed by scanning the files; when needed (e.g. for the
# LR scheduler) they come from the dataset instances' `num_samples`.

# --- Optuna Study Setup ---
logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

study_name = "personality_regression_v9_arrow_multiworker"
storage_name = f"sqlite:///{study_name}.db"  # SQLite storage so the study can be resumed
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"

study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    load_if_exists=True,  # re-running the cell resumes the SQLite-backed study
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1),
)

# Study-wide "best validation loss so far" marker, kept in the study's user attributes.
if "overall_best_val_loss" in study.user_attrs:
    logger.info(f"Resuming study. Current 'overall_best_val_loss' in study user_attrs: {study.user_attrs['overall_best_val_loss']:.4f}")
else:
    study.set_user_attr("overall_best_val_loss", float('inf'))
    logger.info(f"Initialized 'overall_best_val_loss' in study user_attrs to infinity.")

if study.trials:
    logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

# Run the search. Whether it ends normally, is interrupted with Ctrl-C, or fails, the
# `finally` clause still summarizes the study and persists the best parameters/results.
try:
    study.optimize(
        lambda trial: objective(
            trial, TRAIN_DATA_FILE, VAL_DATA_FILE,
            GLOBAL_CONFIG, DEVICE,
            num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA,
            num_dataloader_workers=NUM_DATALOADER_WORKERS,
            overall_best_weights_filepath=BEST_WEIGHTS_FILENAME,
        ),
        n_trials=N_OPTUNA_TRIALS,
        gc_after_trial=True,  # Good practice with large models/data
        # timeout=SOME_TIMEOUT_IN_SECONDS,  # Optional: wall-clock limit for the whole study
        # n_jobs=1  # Parallel *trials*, unrelated to DataLoader workers; keep at 1 unless running trials in parallel.
    )
except KeyboardInterrupt:
    logger.warning("Optuna study interrupted by user.")
except Exception:
    logger.exception("An error occurred during the Optuna study.")
finally:
    logger.info("\n--- Optuna Study Finished (or Interrupted) ---")
    logger.info(f"Number of trials in study: {len(study.trials)}")

    # --- Save Best Trial Info ---
    if not study.trials:
        logger.warning("No trials were completed in the study.")
    else:
        try:
            completed_trials = [
                t for t in study.trials
                if t.state == optuna.trial.TrialState.COMPLETE
                and t.value is not None
                and t.value != float('inf')
            ]
            if completed_trials:
                # Optuna ranks trials by the value each objective call *returned*. The weights
                # file, in contrast, is rewritten whenever any epoch of any trial beats the
                # study-wide best stored in study.user_attrs, so the two can refer to different trials.
                best_trial_optuna_reported = study.best_trial
                if best_trial_optuna_reported and best_trial_optuna_reported.value is not None:
                    logger.info(f"Optuna's Best Trial (based on reported values to Optuna):")
                    logger.info(f"  Number: {best_trial_optuna_reported.number}")
                    logger.info(f"  Value (Validation Loss): {best_trial_optuna_reported.value:.4f}")
                    logger.info("  Params (from this trial): ")
                    for key, value in best_trial_optuna_reported.params.items():
                        logger.info(f"    {key}: {value}")
                    # Save these params, as they led to Optuna's best *reported* value.
                    with open(BEST_PARAMS_FILENAME, 'w') as f:
                        json.dump(best_trial_optuna_reported.params, f, indent=4)
                    logger.info(f"Hyperparameters from Optuna's best reported trial ({best_trial_optuna_reported.number}) saved to {BEST_PARAMS_FILENAME}")
                else:
                    logger.warning("Optuna study has trials, but study.best_trial is None or has no value. Cannot save its parameters.")

                # Describe the model whose weights were saved during training.
                overall_best_val_loss_attr = study.user_attrs.get("overall_best_val_loss", float('inf'))
                overall_best_trial_attr = study.user_attrs.get("overall_best_trial_number", "N/A")
                overall_best_epoch_attr = study.user_attrs.get("overall_best_epoch", "N/A")

                logger.info(f"Overall best model weights (saved during training) are expected in: {BEST_WEIGHTS_FILENAME}")
                if os.path.exists(BEST_WEIGHTS_FILENAME) and overall_best_val_loss_attr != float('inf'):
                    logger.info(f"  This model achieved a validation loss of: {overall_best_val_loss_attr:.4f}")
                    logger.info(f"  It was saved from Trial: {overall_best_trial_attr}, Epoch: {overall_best_epoch_attr}")
                    # Always (re)write the params belonging to the saved weights, even when that
                    # trial is also Optuna's best. Writing the file only when the two trials differed
                    # left any copy from an earlier run in place; the test-prediction code prefers this
                    # file and would then pair the current weights with another trial's hyperparameters.
                    if overall_best_trial_attr != "N/A":
                        try:
                            # Look the trial up by its number instead of assuming list index == number.
                            trials_by_number = {t.number: t for t in study.trials}
                            params_of_best_saved_model = trials_by_number[overall_best_trial_attr].params
                            best_saved_model_params_filename = f"{study_name}_params_for_best_weights.json"
                            with open(best_saved_model_params_filename, 'w') as f:
                                json.dump(params_of_best_saved_model, f, indent=4)
                            logger.info(f"Hyperparameters for the model in '{BEST_WEIGHTS_FILENAME}' (Trial {overall_best_trial_attr}) saved to {best_saved_model_params_filename}")
                        except Exception as e_params:
                            logger.error(f"Could not retrieve or save params for trial {overall_best_trial_attr}: {e_params}")
                else:
                    logger.warning(f"  Expected overall best weights file {BEST_WEIGHTS_FILENAME} was NOT found, or no model improved initial loss.")
            else:
                logger.warning("No trials completed successfully to determine the best trial.")

            study_df = study.trials_dataframe(attrs=('number', 'value', 'params', 'state', 'user_attrs', 'datetime_start', 'datetime_complete', 'duration'))
            study_df.to_csv(f"{study_name}_results.csv", index=False)
            logger.info(f"Optuna study results saved to {study_name}_results.csv")

        except Exception as e:
            logger.error(f"Could not process or save Optuna study results: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
# Prefer the params recorded for the trial that produced BEST_WEIGHTS_FILENAME; when that
# file does not exist, fall back to the params of Optuna's best reported trial.
params_file_for_testing = f"{study_name}_params_for_best_weights.json"
if not os.path.exists(params_file_for_testing):
    params_file_for_testing = BEST_PARAMS_FILENAME  # Fallback to Optuna's best reported

if not os.path.exists(TEST_DATA_FILE):
    logger.info(f"Test data file '{TEST_DATA_FILE}' not found. Skipping test prediction example.")
elif not (os.path.exists(params_file_for_testing) and os.path.exists(BEST_WEIGHTS_FILENAME)):
    logger.warning(f"Best parameters file ('{params_file_for_testing}') or weights file ('{BEST_WEIGHTS_FILENAME}') not found. Skipping test prediction.")
else:
    logger.info(f"\n--- Predicting on Test Data using overall best saved model and params from '{params_file_for_testing}' ---")
    try:
        with open(params_file_for_testing, 'r') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {params_file_for_testing}")

        # Rebuild the architecture from the loaded hyperparameters; keys missing from the
        # params file fall back to the defaults below.
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL']),
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2),
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            numerical_embedding_dim=(
                loaded_best_params.get("other_numerical_embedding_dim", 0)
                if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0
            ),
            num_additional_dense_layers=loaded_best_params.get("num_additional_dense_layers", 0),
            additional_dense_hidden_dim=loaded_best_params.get("additional_dense_hidden_dim", 256),
            additional_layers_dropout_rate=loaded_best_params.get("additional_layers_dropout_rate", 0.3),
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # map_location=None is torch.load's default; remap to CPU only when CUDA is absent.
        loaded_state_dict = torch.load(
            BEST_WEIGHTS_FILENAME,
            map_location=None if torch.cuda.is_available() else torch.device('cpu'),
        )
        test_model.load_state_dict(loaded_state_dict)
        logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        test_model.eval()

        logger.info(f"Initializing ArrowIterableDataset for test data from {TEST_DATA_FILE}...")
        test_dataset = ArrowIterableDataset(
            file_path=TEST_DATA_FILE,
            trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL']),
            other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            tokenizer_max_length=GLOBAL_CONFIG['TOKENIZER_MAX_LENGTH'],
            is_test_set=True,
        )

        if test_dataset.num_samples == 0:
            logger.warning(f"Test file {TEST_DATA_FILE} is empty or failed to load. No test predictions will be made.")
        else:
            logger.info(f"Test dataset loaded with {test_dataset.num_samples} samples.")
            test_batch_size = loaded_best_params.get("batch_size", 16)  # batch size from the best params
            # A single in-process, unshuffled pass keeps prediction order equal to file order.
            test_loader = DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=False,
                num_workers=0,
                persistent_workers=False,
            )

            all_test_predictions = []
            with torch.no_grad():
                for batch_idx, batch_tuple_test in enumerate(test_loader):
                    if not batch_tuple_test or len(batch_tuple_test) < 5:
                        logger.warning(f"Test Batch {batch_idx}: Received empty or malformed batch. Skipping.")
                        continue
                    try:
                        input_ids_t, attention_m_t, q_s_t, comment_active_m_t, other_num_feats_t = [
                            b.to(DEVICE) for b in batch_tuple_test
                        ]
                    except Exception as e:
                        logger.error(f"Test Batch {batch_idx}: Error moving batch to device or unpacking: {e}")
                        continue

                    predicted_scores_t = test_model(input_ids_t, attention_m_t, q_s_t, comment_active_m_t, other_num_feats_t)
                    all_test_predictions.append(predicted_scores_t.cpu().numpy())

            if not all_test_predictions:
                logger.warning("No predictions generated for the test set (all_test_predictions list is empty).")
            else:
                final_test_predictions = np.concatenate(all_test_predictions, axis=0)
                logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
                # Show the first few rows as {trait: score} dictionaries.
                for i, row in enumerate(final_test_predictions[:5]):
                    pred_dict = {trait: round(score.item(), 4) for trait, score in zip(GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'], row)}
                    logger.info(f"Test Sample Index {i} Predictions: {pred_dict}")
                # np.save(f"{study_name}_test_predictions.npy", final_test_predictions) # Optionally save predictions

    except FileNotFoundError as e:
        logger.warning(f"Required file for test prediction not found: {e}. Skipping test prediction.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)

# Earlier version of the pipeline (before migrating the datasets from JSONL to Arrow)

In [None]:
import json
import torch
from torch.utils.data import IterableDataset
from transformers.tokenization_utils_base import BatchEncoding # For your decode_from_json
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation
import torch.optim as optim
import os
import shutil # Keep for now, might be useful for other file ops if needed later

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Marker keys used in the serialized .jsonl files; they must match the serializer exactly.
# Tensor objects carry _TENSOR_MARKER (optionally _TENSOR_DTYPE_MARKER) with their values under
# "data"; BatchEncoding objects carry _BATCH_ENCODING_MARKER with their payload under "data".
_TENSOR_MARKER, _TENSOR_DTYPE_MARKER = "__tensor__", "__tensor_dtype__"
_BATCH_ENCODING_MARKER, _BATCH_ENCODING_DATA_MARKER = "__batch_encoding__", "data"

def _convert_str_to_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name such as ``"float32"`` or ``"torch.float32"`` to a ``torch.dtype``.

    Aliases exposed by the ``torch`` module (``"long"``, ``"float"``, ...) are accepted too.

    Args:
        dtype_str: Dtype name, with or without a leading ``"torch."``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If the name does not resolve to a ``torch.dtype``. The previous
            fallback ``torch.dtype(dtype_str)`` always failed, because ``torch.dtype``
            cannot be instantiated, and names like ``"tensor"`` returned non-dtype
            torch attributes.
    """
    name = dtype_str[len("torch."):] if dtype_str.startswith("torch.") else dtype_str
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unrecognized torch dtype string: {dtype_str!r}")
    return dtype

def _json_object_hook_for_dataset(dct: dict) -> object:
    """``json.loads`` object hook that rebuilds tensors and ``BatchEncoding`` objects.

    ``json`` calls the hook for every decoded object, inner objects first. Nested
    markers are therefore already converted when an enclosing object is seen.

    * An object containing ``_TENSOR_MARKER`` becomes ``torch.tensor(dct["data"])``. Its
      dtype is named by ``_TENSOR_DTYPE_MARKER`` (default ``"float32"``).
    * An object containing ``_BATCH_ENCODING_MARKER`` becomes a ``BatchEncoding``
      built from its ``"data"`` payload. ``input_ids`` / ``token_type_ids`` /
      ``attention_mask`` lists are converted to ``torch.long`` tensors; if conversion
      fails, they are logged and left as lists.
    * Any other object is returned unchanged.

    Args:
        dct: A freshly decoded JSON object.

    Returns:
        A ``torch.Tensor``, a ``BatchEncoding``, or ``dct`` itself.
    """
    if _TENSOR_MARKER in dct:
        dtype = _convert_str_to_dtype(dct.get(_TENSOR_DTYPE_MARKER, 'float32'))
        # The tensor payload is stored under the same "data" key as the BatchEncoding payload.
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=dtype)

    if _BATCH_ENCODING_MARKER in dct:
        reconstructed_data_for_be = {}
        for k, v_data in dct.get(_BATCH_ENCODING_DATA_MARKER, {}).items():
            if isinstance(v_data, list) and k in ("input_ids", "token_type_ids", "attention_mask"):
                try:
                    # All three token-level fields are integer ids/flags.
                    reconstructed_data_for_be[k] = torch.tensor(v_data, dtype=torch.long)
                except Exception as e:
                    logger.error(f"Error converting field '{k}' in BatchEncoding to tensor: {e}. Keeping as list.")
                    reconstructed_data_for_be[k] = v_data
            else:
                reconstructed_data_for_be[k] = v_data
        return BatchEncoding(reconstructed_data_for_be)

    return dct

class JsonlIterableDataset(IterableDataset):
    """Stream samples from a JSON Lines file (one JSON object per line).

    Each line is decoded with ``_json_object_hook_for_dataset`` and passed through
    ``transform_fn`` (``_default_transform`` by default). Blank lines and records that
    fail to decode or transform are skipped.

    Inside a multi-worker ``DataLoader``, line ``i`` is handled by worker
    ``i % num_workers``, so every line is processed by exactly one worker per pass.

    Args:
        file_path: Path to the ``.jsonl`` file.
        trait_names: Label keys, in the order used for the label tensor.
        n_comments_to_process: Number of comment slots per sample.
        other_numerical_feature_names: Keys under ``sample['features']`` that are
            packed into the per-sample numerical feature vector.
        num_q_features_per_comment: Width of the per-comment q-score vector.
        is_test_set: If True, samples are read without labels and 5-tuples are
            produced instead of 6-tuples.
        transform_fn: Optional replacement for ``_default_transform``; it is called
            as ``transform_fn(sample, idx=None)``.
        num_samples: Value reported by ``__len__``. When omitted, the non-blank lines
            of the file are counted.
    """

    def __init__(self, file_path, trait_names, n_comments_to_process,
                 other_numerical_feature_names, num_q_features_per_comment,
                 is_test_set=False, transform_fn=None, num_samples = None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.is_test_set = is_test_set
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn
        if num_samples is None:
            logger.info(f'Counting samples in {file_path} for __len__ was not provided...')
            self.num_samples = self._count_samples_in_file()
            logger.info(f"Counted {self.num_samples} samples in {self.file_path}.")
        else:
            self.num_samples = num_samples
        if self.num_samples == 0:
            logger.warning(f"Initialized JsonlIterableDataset for {self.file_path} with 0 samples. DataLoader will be empty.")

    def _count_samples_in_file(self):
        """Return the number of non-blank lines in ``file_path``, or 0 if it cannot be read.

        Blank lines are not counted because ``__iter__`` never yields a sample for them.
        """
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                return sum(1 for line in f if line.strip())
        except FileNotFoundError:
            logger.error(f"File not found during initial sample count: {self.file_path}. Returning 0 samples.")
            return 0
        except Exception as e:
            logger.error(f"Error during initial sample count for {self.file_path}: {e}. Returning 0 samples.")
            return 0

    def _process_line(self, line):
        """Decode one JSON line and apply ``transform_fn``.

        Returns the transformed sample, or ``None`` in three cases: the line is not
        valid JSON, the object hook fails, or the transform raises. The reason is
        logged at DEBUG level rather than discarded, so bad records can be diagnosed
        without aborting the pass.
        """
        try:
            sample = json.loads(line, object_hook=_json_object_hook_for_dataset)
        except Exception as e:  # JSONDecodeError, or an error raised inside the object hook
            logger.debug(f"Could not decode line from {self.file_path}: {e!r}")
            return None
        try:
            return self.transform_fn(sample, idx=None)
        except Exception as e:
            logger.debug(f"Could not transform sample from {self.file_path}: {e!r}")
            return None

    def __len__(self):
        return self.num_samples

    def _default_transform(self, sample, idx):
        """Convert one decoded sample into fixed-size tensors.

        ``sample['features']['comments_tokenized']`` must hold ``input_ids`` and
        ``attention_mask`` indexable by ``.shape[0]`` (comments) and ``.shape[1]``
        (tokens); presumably the tensors produced by the JSON object hook.

        Comment selection:
            * More comments than ``n_comments_to_process``: a random subset is kept,
              in random order.
            * Otherwise: all comments are kept in order. Unused slots stay zero and
              their ``comment_active_flags`` entries are False.

        Returns:
            ``(input_ids, attention_mask, q_scores, comment_active_flags,
            other_numerical_features[, labels])``. Labels are omitted for the test set.
            Otherwise they follow ``trait_names`` order, are clipped into [0, 1], and
            unparsable values become 0.0.
        """
        tokenized_info = sample.get('features', {}).get('comments_tokenized', {})
        all_input_ids = tokenized_info['input_ids']
        all_attention_mask = tokenized_info['attention_mask']

        num_actual_comments = all_input_ids.shape[0]
        final_input_ids = torch.zeros((self.n_comments_to_process, all_input_ids.shape[1]), dtype=torch.long)
        final_attention_mask = torch.zeros((self.n_comments_to_process, all_attention_mask.shape[1]), dtype=torch.long)
        comment_active_flags = torch.zeros(self.n_comments_to_process, dtype=torch.bool)

        indices_to_select = list(range(num_actual_comments))
        if num_actual_comments > self.n_comments_to_process:
            indices_to_select = random.sample(indices_to_select, self.n_comments_to_process)
            comments_to_fill = self.n_comments_to_process
        else:
            comments_to_fill = num_actual_comments

        for i in range(comments_to_fill):
            original_idx = indices_to_select[i]
            final_input_ids[i] = all_input_ids[original_idx]
            final_attention_mask[i] = all_attention_mask[original_idx]
            comment_active_flags[i] = True

        raw_q_scores = sample['features'].get('q_scores', [])
        final_q_scores = torch.zeros((self.n_comments_to_process, self.num_q_features_per_comment), dtype=torch.float)

        # Per selected comment: truncate/zero-pad its q-scores to num_q_features_per_comment.
        selected_raw_q_scores = []
        for i in range(comments_to_fill):
            original_comment_idx = indices_to_select[i]
            if original_comment_idx < len(raw_q_scores):
                # list(...) so tuple-valued rows can be concatenated with the padding list.
                qs_for_comment = list(raw_q_scores[original_comment_idx][:self.num_q_features_per_comment])
                padded_qs = qs_for_comment + [0.0] * (self.num_q_features_per_comment - len(qs_for_comment))
                selected_raw_q_scores.append(padded_qs[:self.num_q_features_per_comment])
            else:
                selected_raw_q_scores.append([0.0] * self.num_q_features_per_comment)

        if comments_to_fill > 0 and selected_raw_q_scores:
            try:
                final_q_scores[:comments_to_fill] = torch.tensor(selected_raw_q_scores, dtype=torch.float)
            except Exception as e:
                logger.error(f"Error converting selected_raw_q_scores to tensor: {e}. Data: {selected_raw_q_scores}")

        other_numerical_features_list = []
        for fname in self.other_numerical_feature_names:
            val = sample['features'].get(fname, 0.0)
            try:
                other_numerical_features_list.append(float(val))
            except (ValueError, TypeError):
                other_numerical_features_list.append(0.0)
        other_numerical_features_tensor = torch.tensor(other_numerical_features_list, dtype=torch.float)

        if not self.is_test_set:
            labels_dict = sample['labels']
            regression_labels = []
            for trait_key in self.trait_names_ordered:
                # Accept both Title-Cased and as-configured label keys.
                label_val = labels_dict.get(trait_key.title(), labels_dict.get(trait_key, 0.0))
                try:
                    label_float = float(label_val)
                    if not (0.0 <= label_float <= 1.0): label_float = np.clip(label_float, 0.0, 1.0)
                    regression_labels.append(label_float)
                except (ValueError, TypeError): regression_labels.append(0.0)
            labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)
        else:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

    def __iter__(self):
        """Yield transformed samples, sharding lines across DataLoader workers.

        Line ``i`` goes to worker ``i % num_workers``; in the main process all lines
        are handled. Blank lines are ignored. Lines that fail to decode or transform
        are skipped, and the number skipped by this worker is logged once at the end
        of a complete pass.
        """
        worker_info = torch.utils.data.get_worker_info()
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_id = 0 if worker_info is None else worker_info.id
        try:
            file_iter = open(self.file_path, 'r', encoding='utf-8')
        except FileNotFoundError:
            logger.error(f"File not found in __iter__: {self.file_path}. Yielding nothing.")
            return

        skipped = 0
        # `with` closes the handle even if the consumer stops early (generator close)
        # or an exception propagates mid-iteration; the old manual close() did not.
        with file_iter:
            for i, line in enumerate(file_iter):
                if i % num_workers != worker_id or not line.strip():
                    continue
                processed_item = self._process_line(line)
                if processed_item:
                    yield processed_item
                else:
                    skipped += 1
        if skipped:
            logger.warning(f"Skipped {skipped} malformed line(s) in {self.file_path} (worker {worker_id}).")


class PersonalityModelV3(nn.Module):
    """Multi-comment BERT regressor producing one sigmoid score per trait.

    Pipeline (see ``forward``):
      1. Every comment is encoded by BERT. The last ``num_bert_layers_to_pool``
         hidden-state layers are mean-pooled over non-padding tokens and then
         averaged across layers.
      2. Each comment embedding is concatenated with its
         ``num_q_features_per_comment`` q-score features.
      3. Comments are combined with additive attention
         (``v^T tanh(W x)``), ignoring comments marked inactive.
      4. Optional "other" numerical features are embedded and appended.
      5. An optional stack of dense layers (or a single dropout) is applied.
      6. One linear head per trait yields a logit, passed through a sigmoid.
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3,
                 num_other_numerical_features: int = 0,
                 numerical_embedding_dim: int = 64,
                 num_additional_dense_layers: int = 0,
                 additional_dense_hidden_dim: int = 256,
                 additional_layers_dropout_rate: float = 0.3
                ):
        """Build the model.

        Args:
            bert_model_name: Name/path passed to ``BertModel.from_pretrained``.
            num_traits: Number of output scores (one linear head each).
            n_comments_to_process: Nominal number of comments per sample.
            dropout_rate: Dropout after the numerical-feature embedding, and the
                final dropout used when there are no additional dense layers.
            attention_hidden_dim: Hidden size of the additive attention.
            num_bert_layers_to_pool: How many of the last BERT hidden-state
                layers are averaged.
            num_q_features_per_comment: Width of the per-comment q-score vector.
            num_other_numerical_features: Width of the extra numerical features
                (0 disables that branch).
            numerical_embedding_dim: Output width of the numerical-feature MLP.
            num_additional_dense_layers: Number of Linear-ReLU-Dropout layers
                before the trait heads.
            additional_dense_hidden_dim: Width of those dense layers.
            additional_layers_dropout_rate: Dropout inside those dense layers.
        """
        super().__init__()
        # output_hidden_states=True is required: pooling uses all hidden layers.
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        # Only used by forward() when there are no additional dense layers.
        self.final_dropout_layer = nn.Dropout(dropout_rate)

        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        aggregated_comment_feature_dim = comment_feature_dim
        combined_input_dim_for_block = aggregated_comment_feature_dim

        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
            combined_input_dim_for_block += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        self.num_additional_dense_layers = num_additional_dense_layers
        self.additional_dense_block = nn.Sequential()
        current_dim_for_dense_block = combined_input_dim_for_block

        if self.num_additional_dense_layers > 0:
            logger.info(f"Model using {self.num_additional_dense_layers} additional dense layers with hidden_dim {additional_dense_hidden_dim} and dropout {additional_layers_dropout_rate}")
            for i in range(self.num_additional_dense_layers):
                self.additional_dense_block.add_module(f"add_dense_{i}_linear", nn.Linear(current_dim_for_dense_block, additional_dense_hidden_dim))
                self.additional_dense_block.add_module(f"add_dense_{i}_relu", nn.ReLU())
                self.additional_dense_block.add_module(f"add_dense_{i}_dropout", nn.Dropout(additional_layers_dropout_rate))
                current_dim_for_dense_block = additional_dense_hidden_dim
            input_dim_for_regressors = current_dim_for_dense_block
        else:
            logger.info("Model not using additional dense layers. Will use final_dropout_layer if dropout_rate > 0.")
            input_dim_for_regressors = combined_input_dim_for_block

        # One independent scalar head per trait.
        self.trait_regressors = nn.ModuleList(
            nn.Linear(input_dim_for_regressors, 1) for _ in range(num_traits)
        )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked-mean-pool each of the last N hidden layers, then average them.

        Args:
            all_hidden_states: BERT hidden states, each (batch, seq_len, hidden).
            attention_mask: (batch, seq_len); non-zero marks real tokens.

        Returns:
            (batch, hidden). Sequences with no real tokens pool to zeros.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # Cast the mask to the hidden-state dtype so the clamp below works on
        # floats; clamping an integer tensor with a 1e-9 floor depends on
        # PyTorch's type-promotion rules and could leave a zero divisor.
        mask = attention_mask.unsqueeze(-1).to(layers_to_pool[0].dtype)
        # The token count is identical for every layer, so compute it once.
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)

        pooled_outputs = [(layer_hidden_states * mask).sum(dim=1) / token_counts
                          for layer_hidden_states in layers_to_pool]
        return torch.stack(pooled_outputs, dim=0).mean(dim=0)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                q_scores: torch.Tensor,
                comment_active_mask: torch.Tensor,
                other_numerical_features: Optional[torch.Tensor] = None
               ):
        """Return trait scores in (0, 1) with shape (batch, num_traits).

        Args:
            input_ids, attention_mask: (batch, n_comments, seq_len).
            q_scores: (batch, n_comments, num_q_features_per_comment).
            comment_active_mask: (batch, n_comments); truthy entries mark real
                comments. bool, integer and float 0/1 masks are all accepted.
                ``None`` disables masking.
            other_numerical_features: (batch, num_other_numerical_features);
                required when the model was built with such features.

        Raises:
            ValueError: If numerical features are required but missing or have
                the wrong width.
        """
        batch_size = input_ids.shape[0]
        # Use the comment count of the actual input instead of
        # self.n_comments_to_process, so batches with a different number of
        # comments work instead of failing in the reshape/concat below.
        n_comments = input_ids.shape[1]

        # reshape (unlike view) also accepts non-contiguous inputs.
        input_ids_flat = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.reshape(-1, attention_mask.shape[-1])

        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings_flat = self._pool_bert_layers(bert_outputs.hidden_states, attention_mask_flat)
        comment_bert_embeddings = comment_bert_embeddings_flat.reshape(batch_size, n_comments, -1)

        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)

        # Additive attention score per comment: v^T tanh(W x).
        u = torch.tanh(self.attention_w(comment_features_with_q))
        scores = self.attention_v(u).squeeze(-1)

        if comment_active_mask is not None:
            # masked_fill requires a boolean mask, and `~` is bitwise NOT on
            # integer tensors (unsupported on floats): normalise to bool first.
            scores = scores.masked_fill(~comment_active_mask.bool(), -1e9)

        attention_weights = F.softmax(scores, dim=1)
        aggregated_comment_features = torch.sum(attention_weights.unsqueeze(-1) * comment_features_with_q, dim=1)

        final_features_for_processing = aggregated_comment_features
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_processing = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        if self.num_additional_dense_layers > 0:
            # The dense block already ends in its own dropout.
            features_for_trait_heads = self.additional_dense_block(final_features_for_processing)
        else:
            features_for_trait_heads = self.final_dropout_layer(final_features_for_processing)

        all_trait_outputs_raw = torch.cat(
            [regressor_head(features_for_trait_heads) for regressor_head in self.trait_regressors],
            dim=1,
        )
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Identity: ``forward`` already returns final sigmoid scores."""
        return outputs

# --- Optuna Objective Function (MODIFIED for overall best model saving) ---
def objective(trial: optuna.trial.Trial,
              train_file_path: str,
              val_file_path: str,
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int,
              ### Path where the study-wide best model weights are saved ###
              overall_best_weights_filepath: str
             ):
    """Train and validate PersonalityModelV3 for one Optuna trial.

    Samples hyperparameters from ``trial`` and builds the train and validation
    datasets and loaders. It then builds the model, freezes BERT except the
    embeddings, the top ``num_unfrozen_bert_layers`` encoder layers and the
    pooler, and trains for up to ``num_epochs_per_trial`` epochs with L1 loss.
    Training uses early stopping and Optuna pruning.

    Whenever an epoch's validation loss beats the study-wide record stored in
    ``trial.study.user_attrs["overall_best_val_loss"]``, two things happen:
    the model's state dict (moved to CPU) is saved to
    ``overall_best_weights_filepath``, and the record, trial number and epoch
    are updated in the study's user attributes.

    Args:
        trial: The Optuna trial being evaluated.
        train_file_path: Training data file passed to JsonlIterableDataset.
        val_file_path: Validation data file passed to JsonlIterableDataset.
        global_config: Shared settings (trait names, feature names, BERT model
            name, NUM_TRAIN_SAMPLES, NUM_VAL_SAMPLES, ...).
        device: Device to train on.
        num_epochs_per_trial: Maximum number of training epochs.
        overall_best_weights_filepath: Destination of the study-wide best
            weights.

    Returns:
        The lowest mean validation L1 loss seen in this trial. Returns
        ``float('inf')`` if the datasets could not be built or nothing is
        trainable.

    Raises:
        optuna.TrialPruned: When the pruner stops the trial, so that Optuna
            records it as PRUNED rather than COMPLETE.
    """
    logger.info(f"Starting Optuna Trial {trial.number}")

    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    # --- Hyperparameter search space ---
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    n_comments_trial = trial.suggest_int("n_comments_to_process", 3, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16, 32])

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64, 128])

    num_additional_dense_layers_trial = trial.suggest_int("num_additional_dense_layers", 0, 3)
    additional_dense_hidden_dim_trial = 0
    additional_layers_dropout_rate_trial = 0.0
    if num_additional_dense_layers_trial > 0:
        additional_dense_hidden_dim_trial = trial.suggest_categorical("additional_dense_hidden_dim", [128, 256, 512])
        additional_layers_dropout_rate_trial = trial.suggest_float("additional_layers_dropout_rate", 0.1, 0.5)

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")

    # --- Data ---
    try:
        train_dataset_trial = JsonlIterableDataset(
            file_path=train_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_TRAIN_SAMPLES')
        )
        val_dataset_trial = JsonlIterableDataset(
            file_path=val_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_VAL_SAMPLES')
        )
        pin_memory = device.type == 'cuda'
        train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=pin_memory, persistent_workers=False)
        val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=pin_memory, persistent_workers=False)
    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    # --- Model ---
    model = PersonalityModelV3(
        bert_model_name=global_config['BERT_MODEL_NAME'],
        num_traits=num_traits,
        n_comments_to_process=n_comments_trial,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim,
        num_bert_layers_to_pool=num_bert_layers_to_pool,
        num_q_features_per_comment=num_q_features_per_comment_trial,
        num_other_numerical_features=num_other_numerical_features_trial,
        numerical_embedding_dim=other_numerical_embedding_dim_trial,
        num_additional_dense_layers=num_additional_dense_layers_trial,
        additional_dense_hidden_dim=additional_dense_hidden_dim_trial,
        additional_layers_dropout_rate=additional_layers_dropout_rate_trial
    ).to(device)

    # Freeze all of BERT, then unfreeze the embeddings, the top
    # `num_unfrozen_bert_layers` encoder layers and the pooler. (forward() only
    # reads hidden_states, so the pooler parameters never receive gradients.)
    for param in model.bert.parameters():
        param.requires_grad = False
    if num_unfrozen_bert_layers > 0:
        if hasattr(model.bert, 'embeddings'):
            for param in model.bert.embeddings.parameters():
                param.requires_grad = True
        num_hidden_layers = model.bert.config.num_hidden_layers
        actual_layers_to_unfreeze = min(num_unfrozen_bert_layers, num_hidden_layers)
        for i in range(num_hidden_layers - actual_layers_to_unfreeze, num_hidden_layers):
            for param in model.bert.encoder.layer[i].parameters():
                param.requires_grad = True
        if getattr(model.bert, 'pooler', None) is not None:
            for param in model.bert.pooler.parameters():
                param.requires_grad = True

    logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                 f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

    # --- Optimizer: separate learning rates for BERT and the task head ---
    optimizer_grouped_parameters = []
    bert_params_to_tune = [p for p in model.bert.parameters() if p.requires_grad]
    if bert_params_to_tune and lr_bert > 0:
        optimizer_grouped_parameters.append({"params": bert_params_to_tune, "lr": lr_bert, "weight_decay": 0.01})

    head_params = list(model.attention_w.parameters()) + list(model.attention_v.parameters())
    if model.uses_other_numerical_features:
        head_params.extend(model.other_numerical_processor.parameters())
    if model.num_additional_dense_layers > 0:
        head_params.extend(model.additional_dense_block.parameters())
    for regressor_head in model.trait_regressors:
        head_params.extend(regressor_head.parameters())

    if head_params:
        optimizer_grouped_parameters.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})

    # Groups are only appended when non-empty, so an empty list means nothing to train.
    if not optimizer_grouped_parameters:
        logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
        return float('inf')

    optimizer = optim.AdamW(optimizer_grouped_parameters)

    scheduler = None
    if scheduler_type == "linear_warmup":
        if global_config.get('NUM_TRAIN_SAMPLES', 0) > 0:
            num_batches_per_epoch = (global_config['NUM_TRAIN_SAMPLES'] + batch_size_trial - 1) // batch_size_trial
            num_training_steps = num_batches_per_epoch * num_epochs_per_trial
            num_warmup_steps = int(num_training_steps * warmup_ratio)
            if num_warmup_steps > 0 and num_training_steps > 0:
                scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
            else:
                logger.warning(f"Trial {trial.number}: Calculated num_warmup_steps or num_training_steps is zero. Scheduler not created. Warmup: {num_warmup_steps}, Training: {num_training_steps}")
        else:
            logger.warning(f"Trial {trial.number}: NUM_TRAIN_SAMPLES not available or zero in global_config. Cannot create linear_warmup scheduler.")

    loss_fn = nn.L1Loss().to(device)
    best_val_loss_this_trial = float('inf')  # for early stopping within this trial
    patience_counter = 0

    try:
        for epoch in range(num_epochs_per_trial):
            # --- Training ---
            model.train()
            total_train_loss = 0.0
            train_batches_processed = 0
            for batch_idx, batch_tuple in enumerate(train_loader_trial):
                input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                optimizer.zero_grad()
                predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                current_batch_loss = loss_fn(predicted_scores, labels_reg)
                if not torch.isfinite(current_batch_loss):
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected. Skipping batch.")
                    torch.cuda.empty_cache()
                    continue
                current_batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler:
                    scheduler.step()
                total_train_loss += current_batch_loss.item()
                train_batches_processed += 1

            avg_train_loss = total_train_loss / train_batches_processed if train_batches_processed > 0 else float('inf')
            logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

            # --- Validation ---
            model.eval()
            current_epoch_val_loss = 0.0
            val_batches_processed = 0
            all_val_preds_epoch = []
            all_val_labels_epoch = []
            with torch.no_grad():
                for batch_tuple in val_loader_trial:
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                    if input_ids.numel() == 0:
                        continue
                    predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    if predicted_scores.numel() == 0:
                        continue
                    current_epoch_val_loss += loss_fn(predicted_scores, labels_reg).item()
                    all_val_preds_epoch.append(predicted_scores.cpu())
                    all_val_labels_epoch.append(labels_reg.cpu())
                    val_batches_processed += 1

            # Mean of per-batch L1 losses (the value reported to Optuna).
            avg_val_loss_epoch = current_epoch_val_loss / val_batches_processed if val_batches_processed > 0 else float('inf')

            # Exact per-sample MAE over the whole validation set.
            val_mae = -1.0
            if all_val_labels_epoch and all_val_preds_epoch:
                all_val_labels_cat = torch.cat(all_val_labels_epoch, dim=0)
                all_val_preds_cat = torch.cat(all_val_preds_epoch, dim=0)
                if all_val_labels_cat.numel() > 0 and all_val_preds_cat.numel() > 0:
                    val_mae = F.l1_loss(all_val_preds_cat, all_val_labels_cat).item()

            logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (L1): {avg_val_loss_epoch:.4f}, Val MAE: {val_mae:.4f}")

            # Early-stopping bookkeeping for this trial.
            if avg_val_loss_epoch < best_val_loss_this_trial:
                best_val_loss_this_trial = avg_val_loss_epoch
                patience_counter = 0
                logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss for this trial: {best_val_loss_this_trial:.4f}")
            else:
                patience_counter += 1

            # Study-wide best: persist weights whenever any trial/epoch beats it.
            if hasattr(trial, 'study') and trial.study is not None:
                current_overall_best_loss = trial.study.user_attrs.get("overall_best_val_loss", float('inf'))
                if avg_val_loss_epoch < current_overall_best_loss:
                    logger.info(f"Trial {trial.number}, Epoch {epoch+1}: New OVERALL best val_loss: {avg_val_loss_epoch:.4f} (Previous overall best: {current_overall_best_loss:.4f}). Saving model.")
                    trial.study.set_user_attr("overall_best_val_loss", avg_val_loss_epoch)
                    trial.study.set_user_attr("overall_best_trial_number", trial.number)
                    trial.study.set_user_attr("overall_best_epoch", epoch + 1)
                    # Save on CPU so the file loads regardless of the target device.
                    model_state_dict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
                    torch.save(model_state_dict_cpu, overall_best_weights_filepath)
                    logger.info(f"Trial {trial.number}: Saved new OVERALL best model weights to {overall_best_weights_filepath}")
            else:
                logger.warning(f"Trial {trial.number}: Cannot access study.user_attrs to check/update overall best model.")

            trial.report(avg_val_loss_epoch, epoch)
            if trial.should_prune():
                logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
                # Raising is how Optuna marks a trial PRUNED. Returning a value
                # would record it as COMPLETE, so a prematurely stopped trial
                # could be treated as a finished one (e.g. for study.best_trial).
                raise optuna.TrialPruned()

            if patience_counter >= patience_early_stopping:
                logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
                break

        logger.info(f"Trial {trial.number} finished. Best Val Loss (L1) for this trial: {best_val_loss_this_trial:.4f}")
    finally:
        # Free model/optimizer memory whether the trial completed, stopped
        # early, was pruned, or raised.
        del model, train_loader_trial, val_loader_trial, optimizer, scheduler
        torch.cuda.empty_cache()
        gc.collect()

    return best_val_loss_this_trial  # best validation loss of *this* trial

In [None]:
# Configuration for the Optuna search below. JsonlIterableDataset,
# PersonalityModelV3 and objective are expected to be defined (or imported)
# by an earlier cell.

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
logger.info(f"Using device: {DEVICE}")

# JSONL inputs, one record per line.
TRAIN_DATA_FILE, VAL_DATA_FILE, TEST_DATA_FILE = (
    "train_data.jsonl",
    "val_data.jsonl",
    "test_data.jsonl",
)

# Output order of the model's trait heads.
_trait_names_ordered_config = [
    'Openness',
    'Conscientiousness',
    'Extraversion',
    'Agreeableness',
    'Emotional stability',
    'Humility',
]

# Per-user aggregate features fed to the numerical branch (order matters).
_other_numerical_features_config = [
    # length / structure statistics
    'mean_words_per_comment', 'mean_sents_per_comment', 'median_words_per_comment',
    'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    # punctuation counts
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total',
    # readability / lexical diversity
    'flesch_reading_ease_agg', 'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    # sentiment
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound',
]

GLOBAL_CONFIG = dict(
    BERT_MODEL_NAME="bert-base-uncased",
    TRAIT_NAMES_ORDERED=_trait_names_ordered_config,
    TRAIT_NAMES=_trait_names_ordered_config,
    MAX_COMMENTS_TO_PROCESS_PHYSICAL=6,
    NUM_Q_FEATURES_PER_COMMENT=3,
    OTHER_NUMERICAL_FEATURE_NAMES=_other_numerical_features_config,
    TOKENIZER_MAX_LENGTH=256,
)

NUM_EPOCHS_PER_TRIAL_OPTUNA, N_OPTUNA_TRIALS = 15, 20

def count_lines_in_file(filepath, skip_blank_lines: bool = False):
    """Count the lines of a UTF-8 text file.

    Args:
        filepath: Path of the file to count.
        skip_blank_lines: If True, lines containing only whitespace are not
            counted. For a JSONL file this gives the number of records even
            when it has blank or trailing empty lines. Defaults to False, which
            counts every line.

    Returns:
        The number of lines. Returns 0 if the file is missing or cannot be
        read; the problem is logged rather than raised, since callers treat 0
        as "empty or unavailable".
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            if skip_blank_lines:
                return sum(1 for line in f if line.strip())
            return sum(1 for _ in f)
    except FileNotFoundError:
        logger.error(f"File not found for line counting: {filepath}. Returning 0.")
        return 0
    except Exception as e:
        logger.error(f"Error counting lines in {filepath}: {e}. Returning 0.")
        return 0

NUM_TRAIN_SAMPLES = count_lines_in_file(TRAIN_DATA_FILE)
if NUM_TRAIN_SAMPLES == 0:
    logger.error(f"Training file {TRAIN_DATA_FILE} is empty or not found. Exiting.")
    # `exit()` is injected by the `site` module for interactive use and is not
    # guaranteed to exist; SystemExit is the portable form and reports failure.
    raise SystemExit(1)
GLOBAL_CONFIG['NUM_TRAIN_SAMPLES'] = NUM_TRAIN_SAMPLES
logger.info(f"Number of training samples: {NUM_TRAIN_SAMPLES}")

NUM_VAL_SAMPLES = count_lines_in_file(VAL_DATA_FILE)
if NUM_VAL_SAMPLES == 0:
    logger.warning(f"Validation file {VAL_DATA_FILE} is empty or not found. Validation might not work as expected.")
GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = NUM_VAL_SAMPLES
logger.info(f"Number of validation samples: {NUM_VAL_SAMPLES}")

logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

study_name = "personality_regression_v8_overall_best"
storage_name = f"sqlite:///{study_name}.db"
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"  # single file holding the overall best model

study = optuna.create_study(study_name=study_name,
                            direction="minimize",
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1),
                            storage=storage_name,
                            load_if_exists=True)

# The objective compares every epoch against this study-wide record.
if "overall_best_val_loss" not in study.user_attrs:
    study.set_user_attr("overall_best_val_loss", float('inf'))
    logger.info(f"Initialized 'overall_best_val_loss' in study user_attrs to infinity.")
else:
    logger.info(f"Resuming study. Current 'overall_best_val_loss' in study user_attrs: {study.user_attrs['overall_best_val_loss']:.4f}")

if study.trials:
    logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

try:
    study.optimize(
        lambda trial: objective(
            trial, TRAIN_DATA_FILE, VAL_DATA_FILE,
            GLOBAL_CONFIG, DEVICE,
            num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA,
            overall_best_weights_filepath=BEST_WEIGHTS_FILENAME
        ),
        n_trials=N_OPTUNA_TRIALS,
        gc_after_trial=True,
    )
except Exception:
    logger.exception("An error occurred during the Optuna study.")

logger.info("\n--- Optuna Study Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

best_trial_overall_from_study_obj = None

if not study.trials:
    logger.warning("No trials were completed in the study.")
else:
    try:
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if completed_trials:
            best_trial_overall_from_study_obj = study.best_trial  # best value returned by a COMPLETE trial

            if best_trial_overall_from_study_obj:
                logger.info(f"Optuna's Best Trial (based on reported values):")
                logger.info(f"  Number: {best_trial_overall_from_study_obj.number}")
                logger.info(f"  Value (Validation Loss - L1): {best_trial_overall_from_study_obj.value:.4f}")
                logger.info("  Best Params (from this trial): ")
                for key, value in best_trial_overall_from_study_obj.params.items():
                    logger.info(f"    {key}: {value}")

                overall_best_val_loss_attr = study.user_attrs.get("overall_best_val_loss", float('inf'))
                overall_best_trial_attr = study.user_attrs.get("overall_best_trial_number", "N/A")
                overall_best_epoch_attr = study.user_attrs.get("overall_best_epoch", "N/A")

                # BEST_WEIGHTS_FILENAME holds the weights of the trial recorded in
                # "overall_best_trial_number". That trial is not necessarily
                # study.best_trial (e.g. it may have set the record and then
                # failed). The saved params must describe the architecture of
                # those weights, or load_state_dict in the prediction step fails.
                params_to_save = best_trial_overall_from_study_obj.params
                params_source_trial = best_trial_overall_from_study_obj.number
                if isinstance(overall_best_trial_attr, int) and overall_best_trial_attr != params_source_trial:
                    weights_trial = next((t for t in study.trials if t.number == overall_best_trial_attr), None)
                    if weights_trial is not None and weights_trial.params:
                        logger.warning(
                            f"Overall best weights come from trial {overall_best_trial_attr}, not from "
                            f"study.best_trial ({params_source_trial}). Saving trial {overall_best_trial_attr}'s "
                            f"hyperparameters so they match {BEST_WEIGHTS_FILENAME}."
                        )
                        params_to_save = weights_trial.params
                        params_source_trial = weights_trial.number

                with open(BEST_PARAMS_FILENAME, 'w') as f:
                    json.dump(params_to_save, f, indent=4)
                logger.info(f"Best hyperparameters (from trial {params_source_trial}) saved to {BEST_PARAMS_FILENAME}")

                logger.info(f"Overall best model weights are expected in: {BEST_WEIGHTS_FILENAME}")
                if os.path.exists(BEST_WEIGHTS_FILENAME):
                    logger.info(f"  This model achieved a validation loss of: {overall_best_val_loss_attr:.4f} (recorded in study.user_attrs)")
                    logger.info(f"  It was saved from Trial: {overall_best_trial_attr}, Epoch: {overall_best_epoch_attr}")
                else:
                    logger.warning(f"  Expected overall best weights file {BEST_WEIGHTS_FILENAME} was NOT found. "
                                   "This might happen if no trial improved upon the initial 'inf' loss, "
                                   "or if there was an issue during saving.")
            else:
                logger.warning("Study has completed trials, but study.best_trial is None. Cannot save parameters.")
        else:
            logger.warning("No trials completed successfully to determine the best trial. Cannot save parameters or confirm weights.")

        study_df = study.trials_dataframe(attrs=('number', 'value', 'params', 'state', 'user_attrs'))
        study_df.to_csv(f"{study_name}_results.csv", index=False)
        logger.info(f"Optuna study results saved to {study_name}_results.csv")

    except Exception as e:
        logger.error(f"Could not process or save Optuna study results: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
if os.path.exists(TEST_DATA_FILE) and os.path.exists(BEST_PARAMS_FILENAME) and os.path.exists(BEST_WEIGHTS_FILENAME):
    logger.info(f"\n--- Predicting on Test Data using overall best saved model and params ---")
    try:
        with open(BEST_PARAMS_FILENAME, 'r') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {BEST_PARAMS_FILENAME}")

        test_n_comments = loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL'])
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=test_n_comments,
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2),
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            numerical_embedding_dim=loaded_best_params.get("other_numerical_embedding_dim", 0) if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0,
            num_additional_dense_layers=loaded_best_params.get("num_additional_dense_layers", 0),
            additional_dense_hidden_dim=loaded_best_params.get("additional_dense_hidden_dim", 256),
            additional_layers_dropout_rate=loaded_best_params.get("additional_layers_dropout_rate", 0.3)
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # The objective saves the state dict as CPU tensors; load_state_dict
        # copies them onto the model's device, so always loading to CPU is safe.
        loaded_state_dict = torch.load(BEST_WEIGHTS_FILENAME, map_location=torch.device('cpu'))
        test_model.load_state_dict(loaded_state_dict)
        logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        test_model.eval()

        NUM_TEST_SAMPLES = count_lines_in_file(TEST_DATA_FILE)
        if NUM_TEST_SAMPLES == 0:
            logger.warning(f"Test file {TEST_DATA_FILE} is empty or not found. No test predictions will be made.")
        else:
            test_dataset = JsonlIterableDataset(
                file_path=TEST_DATA_FILE,
                trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
                n_comments_to_process=test_n_comments,
                other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
                num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
                is_test_set=True,
                num_samples=NUM_TEST_SAMPLES
            )
            test_batch_size = loaded_best_params.get("batch_size", 16)
            test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)

            all_test_predictions = []
            with torch.no_grad():
                for batch_tuple in test_loader:
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats = [b.to(DEVICE) for b in batch_tuple]
                    predicted_scores = test_model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    all_test_predictions.append(predicted_scores.cpu().numpy())

            if all_test_predictions:
                final_test_predictions = np.concatenate(all_test_predictions, axis=0)
                logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
                for i in range(min(5, len(final_test_predictions))):
                    pred_dict = {trait: round(score.item(), 4) for trait, score in zip(GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'], final_test_predictions[i])}
                    logger.info(f"Test Sample Index {i} Predictions: {pred_dict}")
            else:
                logger.warning("No predictions generated for the test set (all_test_predictions list is empty).")

    except FileNotFoundError as e:
        logger.warning(f"Required file for test prediction not found: {e}. Skipping test prediction.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)
elif not os.path.exists(TEST_DATA_FILE):
    logger.info(f"Test data file '{TEST_DATA_FILE}' not found. Skipping test prediction example.")
else:
    # Test data exists, so the params or weights file must be the missing one.
    logger.warning(f"Best parameters file ({BEST_PARAMS_FILENAME}) or weights file ({BEST_WEIGHTS_FILENAME}) not found. Skipping test prediction.")

# Unused code below.

In [None]:
# data processing to arrow (unused)
import string
import re
import numpy as np
from collections import Counter, defaultdict
import textstat
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import pandas as pd
# import json # No longer needed for saving the main data if using HF Datasets
import seaborn as sns
import matplotlib.pyplot as plt
import nltk

from transformers import BertTokenizerFast, pipeline, BatchEncoding
import torch
import accelerate
from typing import List, Dict, Union
from sklearn.preprocessing import MinMaxScaler
import pickle

# NEW: Import Hugging Face Datasets
from datasets import Dataset #, Features, Value, Sequence # Optional for explicit schema

# --- TextFeatureExtractor, show_corr, clean_train, get_q_score, append_q_score_train,
# to_input and pretokenize are defined below; all of them must be defined before
# df_preprocess is called. ---

# Target trait label columns, in a fixed order. The presence of traits[0] as a
# DataFrame column is what marks a labelled (train/val) split in this cell.
traits = [
    'Openness',
    'Conscientiousness',
    'Extraversion',
    'Agreeableness',
    'Emotional stability',
    'Humility',
]

class TextFeatureExtractor:
    """
    Extract linguistic and sentiment features from lists of comments.

    Designed for a pandas DataFrame in which each row holds a list of comment
    strings for one user/entry. ``extract_features`` appends per-row feature
    columns: word/sentence statistics, punctuation and double-whitespace
    counts, readability scores, type-token ratio and VADER sentiment.
    """

    # Column-name suffix for each known punctuation character, used to build
    # the ``punc_<suffix>_total`` columns in ``extract_features``. Keyed by the
    # character so a custom ``specific_punctuation_to_track`` (in any order)
    # still yields correctly named columns.
    _PUNC_COLUMN_NAMES = {
        '!': 'em',
        '?': 'qm',
        '.': 'period',
        ',': 'comma',
        ':': 'colon',
        ';': 'semicolon',
    }

    # NLTK tokenizer resources checked at construction time; older NLTK
    # releases use 'punkt', newer ones load 'punkt_tab' for sent_tokenize.
    _NLTK_TOKENIZER_RESOURCES = ('punkt', 'punkt_tab')

    def __init__(self,
                 specific_punctuation_to_track: list = None,
                 readability_agg_method: str = "concat",
                 ttr_agg_method: str = "concat"):
        """
        Initializes the TextFeatureExtractor.

        Args:
            specific_punctuation_to_track (list, optional):
                Punctuation characters to count.
                Defaults to ['!', '?', '.', ',', ':', ';'].
            readability_agg_method (str, optional):
                How readability scores are aggregated over a user's comments:
                "concat" scores the joined text, "mean" averages per-comment
                scores. Defaults to "concat".
            ttr_agg_method (str, optional):
                Same choice ("concat" or "mean") for the Type-Token Ratio.
                Defaults to "concat".
        """
        self.vader_analyzer = SentimentIntensityAnalyzer()

        if specific_punctuation_to_track is None:
            self.specific_punctuation_to_track = ['!', '?', '.', ',', ':', ';']
        else:
            self.specific_punctuation_to_track = specific_punctuation_to_track

        self.readability_agg_method = readability_agg_method
        self.ttr_agg_method = ttr_agg_method

        self._ensure_nltk_tokenizer()

    def _ensure_nltk_tokenizer(self) -> None:
        """Download the NLTK sentence-tokenizer data when it is not installed."""
        for resource in self._NLTK_TOKENIZER_RESOURCES:
            try:
                nltk.data.find(f'tokenizers/{resource}')
            except LookupError:
                # nltk.data.find reports a missing resource with LookupError.
                print(f"NLTK '{resource}' tokenizer not found. Downloading...", flush=True)
                nltk.download(resource, quiet=True)

    def _punc_column_name(self, punc: str) -> str:
        """Return the column-name suffix used for a tracked punctuation string."""
        known = self._PUNC_COLUMN_NAMES.get(punc)
        if known is not None:
            return known
        # Unknown characters get a stable, identifier-safe name from their code points.
        return "u" + "_".join(f"{ord(c):04x}" for c in punc)

    # --- I. Basic structural feature helpers (operating on lists from a single DataFrame row) ---

    def _sentence_split(self, comment_list: list) -> list:
        """Split each comment into sentences; returns one list of sentences per comment."""
        all_sentences_for_user = []
        if not isinstance(comment_list, list): return []
        for comment_text in comment_list:
            if isinstance(comment_text, str) and comment_text.strip():
                sentences = nltk.sent_tokenize(comment_text)
                all_sentences_for_user.append(sentences)
            else:
                all_sentences_for_user.append([])  # Empty or non-string comment
        return all_sentences_for_user  # e.g. [[s1, s2], [s3, s4, s5]]

    def _get_word_counts_per_comment(self, comment_list: list) -> list:
        """Whitespace-split word count for each comment (0 for non-strings)."""
        if not isinstance(comment_list, list): return []
        return [len(str(comment).split()) if isinstance(comment, str) else 0 for comment in comment_list]

    def _get_sentence_counts_per_comment(self, list_of_sentence_lists: list) -> list:
        """Number of sentences in each original comment (given pre-split sentences)."""
        if not isinstance(list_of_sentence_lists, list): return []
        return [len(sentences_in_one_comment) if isinstance(sentences_in_one_comment, list) else 0 for sentences_in_one_comment in list_of_sentence_lists]

    def _get_sentence_word_counts_per_comment(self, list_of_sentence_lists: list) -> list:
        """Word count of each sentence, grouped per original comment."""
        result_for_user = []
        if not isinstance(list_of_sentence_lists, list): return []
        for sentences_in_one_comment in list_of_sentence_lists:
            if isinstance(sentences_in_one_comment, list):
                sent_lens = [len(str(sent).split()) if isinstance(sent, str) else 0 for sent in sentences_in_one_comment]
                result_for_user.append(sent_lens)
            else:
                result_for_user.append([])
        return result_for_user

    def _aggregate_numeric_list_of_lists(self, list_of_lists_of_numbers: list, agg_func) -> float:
        """Flatten a list of lists of numbers (skipping NaN) and apply ``agg_func``; NaN if nothing remains."""
        if not isinstance(list_of_lists_of_numbers, list): return np.nan
        flat_list = []
        for sublist in list_of_lists_of_numbers:
            if isinstance(sublist, list):
                flat_list.extend(num for num in sublist if isinstance(num, (int, float)) and not np.isnan(num))
        return agg_func(flat_list) if flat_list else np.nan

    def _aggregate_numeric_list(self, list_of_numbers: list, agg_func) -> float:
        """Apply ``agg_func`` to the non-NaN numbers of a list; NaN if none remain."""
        if not isinstance(list_of_numbers, list): return np.nan
        valid_numbers = [num for num in list_of_numbers if isinstance(num, (int, float)) and not np.isnan(num)]
        return agg_func(valid_numbers) if valid_numbers else np.nan

    # --- II. Single-text processing helper methods (private) ---

    def _get_punctuation_counts_single(self, text: str) -> dict:
        """Count each tracked punctuation character in one text."""
        if not isinstance(text, str): return {}
        counts = Counter(char for char in text if char in self.specific_punctuation_to_track)
        return {punc: counts.get(punc, 0) for punc in self.specific_punctuation_to_track}

    def _get_double_whitespace_count_single(self, text: str) -> int:
        """Count runs of two or more consecutive whitespace characters (0 for blank text)."""
        if not isinstance(text, str) or not text.strip():
            return 0
        # Non-overlapping matches of 2+ whitespace characters.
        matches = re.findall(r"\s{2,}", text)
        return len(matches)

    def _get_readability_scores_single(self, text: str) -> dict:
        """Flesch reading ease and Gunning fog for one text; NaN on blank text or textstat errors."""
        if not isinstance(text, str) or not text.strip():
            return {'flesch_reading_ease': np.nan, 'gunning_fog': np.nan}
        try:
            return {
                'flesch_reading_ease': textstat.flesch_reading_ease(text),
                'gunning_fog': textstat.gunning_fog(text)
            }
        except Exception:
            return {'flesch_reading_ease': np.nan, 'gunning_fog': np.nan}

    def _get_mean_word_length_single(self, text: str) -> float:
        """Mean length of the regex ``\\w+`` words in one text; NaN if there are none."""
        if not isinstance(text, str) or not text.strip(): return np.nan
        words = re.findall(r'\b\w+\b', text.lower())
        if not words: return np.nan
        return sum(len(word) for word in words) / len(words)

    def _get_type_token_ratio_single(self, text: str) -> float:
        """Unique-word / total-word ratio (lower-cased) for one text; NaN if there are no words."""
        if not isinstance(text, str) or not text.strip(): return np.nan
        words = re.findall(r'\b\w+\b', text.lower())
        if not words: return np.nan
        return len(set(words)) / len(words)

    def _get_vader_sentiment_scores_single(self, text: str) -> dict:
        """VADER neg/neu/pos/compound scores for one text; NaN for non-strings."""
        if not isinstance(text, str):
            return {'sentiment_neg': np.nan, 'sentiment_neu': np.nan,
                    'sentiment_pos': np.nan, 'sentiment_compound': np.nan}
        scores = self.vader_analyzer.polarity_scores(text)
        return {
            'sentiment_neg': scores['neg'], 'sentiment_neu': scores['neu'],
            'sentiment_pos': scores['pos'], 'sentiment_compound': scores['compound']
        }

    # --- III. Methods for processing a LIST of comments from one user/row ---

    def _get_aggregated_punctuation_counts_from_list(self, comment_list: list) -> dict:
        """Sum tracked punctuation counts over all string comments.

        Every tracked character is always present as a key (0 when absent),
        including for empty lists, non-lists, and lists without strings.
        """
        totals = {punc: 0 for punc in self.specific_punctuation_to_track}
        if not isinstance(comment_list, list):
            return totals
        for comment_text in comment_list:
            if isinstance(comment_text, str):
                for punc, count in self._get_punctuation_counts_single(comment_text).items():
                    totals[punc] += count
        return totals

    def _get_aggregated_double_whitespace_from_list(self, comment_list: list) -> int:
        """Total double-whitespace runs over all string comments."""
        if not isinstance(comment_list, list) or not comment_list:
            return 0
        return sum(self._get_double_whitespace_count_single(comment_text)
                   for comment_text in comment_list if isinstance(comment_text, str))

    def _get_readability_scores_from_list(self, comment_list: list) -> dict:
        """Aggregate readability over non-blank comments per ``readability_agg_method``.

        Raises:
            ValueError: if the aggregation method is neither "concat" nor "mean"
                (only checked when there is at least one non-blank comment).
        """
        default_scores = {'flesch_reading_ease_agg': np.nan, 'gunning_fog_agg': np.nan}
        if not isinstance(comment_list, list) or not comment_list: return default_scores

        valid_comments = [c for c in comment_list if isinstance(c, str) and c.strip()]
        if not valid_comments: return default_scores

        if self.readability_agg_method == "concat":
            full_text = " ".join(valid_comments)
            scores = self._get_readability_scores_single(full_text)
            return {f"{k}_agg": v for k, v in scores.items()}
        elif self.readability_agg_method == "mean":
            flesch_s, gunning_s = [], []
            for ct in valid_comments:
                s = self._get_readability_scores_single(ct)
                if not np.isnan(s['flesch_reading_ease']): flesch_s.append(s['flesch_reading_ease'])
                if not np.isnan(s['gunning_fog']): gunning_s.append(s['gunning_fog'])
            return {
                'flesch_reading_ease_agg': np.nanmean(flesch_s) if flesch_s else np.nan,
                'gunning_fog_agg': np.nanmean(gunning_s) if gunning_s else np.nan
            }
        raise ValueError("Invalid readability_agg_method.")

    def _get_mean_word_length_from_list(self, comment_list: list) -> float:
        """Mean of the per-comment mean word lengths (NaN comments skipped)."""
        if not isinstance(comment_list, list) or not comment_list: return np.nan
        lengths = [self._get_mean_word_length_single(c) for c in comment_list if isinstance(c, str)]
        valid_lengths = [l for l in lengths if not np.isnan(l)]
        return np.nanmean(valid_lengths) if valid_lengths else np.nan

    def _get_ttr_from_list(self, comment_list: list) -> float:
        """Type-token ratio over non-blank comments per ``ttr_agg_method``.

        Raises:
            ValueError: if the aggregation method is neither "concat" nor "mean"
                (only checked when there is at least one non-blank comment).
        """
        if not isinstance(comment_list, list) or not comment_list: return np.nan
        valid_comments = [c for c in comment_list if isinstance(c, str) and c.strip()]
        if not valid_comments: return np.nan

        if self.ttr_agg_method == "concat":
            return self._get_type_token_ratio_single(" ".join(valid_comments))
        elif self.ttr_agg_method == "mean":
            ttrs = [self._get_type_token_ratio_single(c) for c in valid_comments]
            valid_ttrs = [ttr for ttr in ttrs if not np.isnan(ttr)]
            return np.nanmean(valid_ttrs) if valid_ttrs else np.nan
        raise ValueError("Invalid ttr_agg_method.")

    def _get_aggregated_sentiment_from_list(self, comment_list: list) -> dict:
        """Mean VADER scores over string comments, plus the std of the compound score.

        The std is NaN with no scored comments and 0.0 with exactly one.
        """
        default_scores = {'mean_sentiment_neg': np.nan, 'mean_sentiment_neu': np.nan,
                          'mean_sentiment_pos': np.nan, 'mean_sentiment_compound': np.nan,
                          'std_sentiment_compound': np.nan}
        if not isinstance(comment_list, list) or not comment_list: return default_scores

        scores_acc = {'neg': [], 'neu': [], 'pos': [], 'compound': []}
        for comment_text in comment_list:
            if isinstance(comment_text, str):
                single_s = self._get_vader_sentiment_scores_single(comment_text)
                for key_base in scores_acc:
                    val = single_s[f'sentiment_{key_base}']
                    if not np.isnan(val): scores_acc[key_base].append(val)

        results = {}
        for key_base, val_list in scores_acc.items():
            results[f'mean_sentiment_{key_base}'] = np.nanmean(val_list) if val_list else np.nan

        comp_list = scores_acc['compound']
        if not comp_list:
            results['std_sentiment_compound'] = np.nan
        elif len(comp_list) == 1:
            results['std_sentiment_compound'] = 0.0
        else:
            results['std_sentiment_compound'] = np.nanstd(comp_list)
        return results

    # --- IV. Main Public Method ---
    def extract_features(self, df: pd.DataFrame, comment_column: str = 'comments', output_prefix: str = "") -> pd.DataFrame:
        """
        Compute all features for each row of ``df[comment_column]``.

        Note: columns are first added to ``df`` itself (it is mutated), then
        readability and sentiment columns are attached via ``pd.concat``, so
        the returned frame is a new object. Pass a copy if the input must stay
        unchanged.

        Args:
            df: DataFrame whose ``comment_column`` holds lists of comment strings.
            comment_column: Name of the column with the comment lists.
            output_prefix: String prepended to every new column name.

        Returns:
            DataFrame with the original columns plus the feature columns.

        Raises:
            ValueError: if ``comment_column`` is not in ``df``.
        """
        if comment_column not in df.columns:
            raise ValueError(f"Column '{comment_column}' not found in DataFrame.")
        if not df[comment_column].apply(lambda x: isinstance(x, (list, tuple, np.ndarray))).all():
            print(f"Warning: Not all entries in '{comment_column}' are lists/tuples/np.ndarray. Ensure data format is correct.", flush=True)

        df[f'{output_prefix}comment_word_counts'] = df[comment_column].apply(self._get_word_counts_per_comment)
        df[f'{output_prefix}mean_words_per_comment'] = df[f'{output_prefix}comment_word_counts'].apply(lambda x: self._aggregate_numeric_list(x, np.mean))
        df[f'{output_prefix}median_words_per_comment'] = df[f'{output_prefix}comment_word_counts'].apply(lambda x: self._aggregate_numeric_list(x, np.median))
        df[f'{output_prefix}total_words'] = df[f'{output_prefix}comment_word_counts'].apply(lambda x: self._aggregate_numeric_list(x, np.sum))
        df_sent_col = df[comment_column].apply(self._sentence_split)
        df_sent_counts_per_comment_col = df_sent_col.apply(self._get_sentence_counts_per_comment)
        df[f'{output_prefix}mean_sents_per_comment'] = df_sent_counts_per_comment_col.apply(lambda x: self._aggregate_numeric_list(x, np.mean))
        df[f'{output_prefix}median_sents_per_comment'] = df_sent_counts_per_comment_col.apply(lambda x: self._aggregate_numeric_list(x, np.median))
        df[f'{output_prefix}total_sents'] = df_sent_counts_per_comment_col.apply(lambda x: self._aggregate_numeric_list(x, np.sum))
        df_sent_word_counts_col = df_sent_col.apply(self._get_sentence_word_counts_per_comment)
        df[f'{output_prefix}mean_words_per_sentence'] = df_sent_word_counts_col.apply(lambda x: self._aggregate_numeric_list_of_lists(x, np.mean))
        df[f'{output_prefix}median_words_per_sentence'] = df_sent_word_counts_col.apply(lambda x: self._aggregate_numeric_list_of_lists(x, np.median))
        # "Skew" here is the simple mean - median difference.
        df[f'{output_prefix}sents_per_comment_skew'] = df[f'{output_prefix}mean_sents_per_comment'] - df[f'{output_prefix}median_sents_per_comment']
        df[f'{output_prefix}words_per_sentence_skew'] = df[f'{output_prefix}mean_words_per_sentence'] - df[f'{output_prefix}median_words_per_sentence']
        df[f'{output_prefix}total_double_whitespace'] = df[comment_column].apply(self._get_aggregated_double_whitespace_from_list)

        punc_data_col = df[comment_column].apply(self._get_aggregated_punctuation_counts_from_list)
        # Name columns by character (not by position) so custom or reordered
        # punctuation lists are labelled correctly.
        for punc_char in self.specific_punctuation_to_track:
            col_name = f'{output_prefix}punc_{self._punc_column_name(punc_char)}_total'
            df[col_name] = punc_data_col.apply(lambda d, c=punc_char: d.get(c, 0))

        readability_df = df[comment_column].apply(self._get_readability_scores_from_list).apply(pd.Series)
        readability_df.columns = [f'{output_prefix}{col}' for col in readability_df.columns]
        df = pd.concat([df, readability_df], axis=1)
        df[f'{output_prefix}mean_word_len_overall'] = df[comment_column].apply(self._get_mean_word_length_from_list)
        df[f'{output_prefix}ttr_overall'] = df[comment_column].apply(self._get_ttr_from_list)
        sentiment_df = df[comment_column].apply(self._get_aggregated_sentiment_from_list).apply(pd.Series)
        sentiment_df.columns = [f'{output_prefix}{col}' for col in sentiment_df.columns]
        df = pd.concat([df, sentiment_df], axis=1)
        return df

def show_corr(df, cols_drop=None, size=(15, 7), save=False, save_name='UNKNOWN'):
    """
    Plot (and optionally save) a correlation heatmap of df's numeric columns.

    Args:
        df: DataFrame; only float64/int64/float32/int32 columns are used.
        cols_drop: Optional iterable of column names to exclude (e.g. ids).
            Names that are not among the numeric columns are reported and
            ignored. ``None`` or an empty iterable drops nothing, silently.
        size: Figure size passed to ``plt.figure``.
        save: If True, also write the figure to ``save_name`` (".png" is
            appended when missing) at 300 dpi. Save errors are printed, not raised.
        save_name: Used as the saved file name and, with "_corr" removed, in
            the plot title.

    Axis labels are mapped to readable names via ``feature_name_map``; other
    column names are shown unchanged. Returns None, printing a message and
    returning early when no numeric columns remain.
    """
    feature_name_map = {
        'mean_words_per_comment': "Avg. Words/Comment", 'median_words_per_comment': "Median Words/Comment",
        'total_words': "Total Words", 'mean_sents_per_comment': "Avg. Sents/Comment",
        'median_sents_per_comment': "Median Sents/Comment", 'total_sents': "Total Sentences",
        'mean_words_per_sentence': "Avg. Words/Sentence", 'median_words_per_sentence': "Median Words/Sentence",
        'sents_per_comment_skew': "Sentence Count Skew", 'words_per_sentence_skew': "Sentence Length Skew",
        'total_double_whitespace': "Total Double Whitespace", 'punc_em_total': "(!) Count",
        'punc_qm_total': "(?) Count", 'punc_period_total': "(.) Count", 'punc_comma_total': "(,) Count",
        'punc_colon_total': "(:) Count", 'punc_semicolon_total': "(;) Count",
        'flesch_reading_ease_agg': "Flesch Reading Ease", 'gunning_fog_agg': "Gunning Fog Index",
        'mean_word_len_overall': "Avg. Word Length", 'ttr_overall': "Type-Token Ratio (TTR)",
        'mean_sentiment_neg': "Avg. Negative Sentiment", 'mean_sentiment_neu': "Avg. Neutral Sentiment",
        'mean_sentiment_pos': "Avg. Positive Sentiment", 'mean_sentiment_compound': "Avg. Compound Sentiment",
        'std_sentiment_compound': "Std. Compound Sentiment", 'Openness': 'Openness',
        'Conscientiousness': 'Conscientiousness', 'Extraversion': 'Extraversion',
        'Agreeableness': 'Agreeableness', 'Emotional stability': 'Emotional Stability', 'Humility': 'Humility'
    }
    df_num_corr = df.select_dtypes(include=['float64', 'int64', 'float32', 'int32']).copy()

    if cols_drop is not None:
        cols_drop = list(cols_drop)  # materialise once; also allows generators
    # An empty cols_drop (e.g. no id column present) means "drop nothing", not a warning.
    if cols_drop:
        valid_cols_to_drop = [col for col in cols_drop if col in df_num_corr.columns]
        if len(valid_cols_to_drop) < len(cols_drop):
            print(f"Warning: Some columns in cols_drop were not found: {set(cols_drop) - set(valid_cols_to_drop)}")
        if valid_cols_to_drop:
            df_num_corr = df_num_corr.drop(columns=valid_cols_to_drop)
        else:
            print("Warning: No valid columns found to drop.")

    if df_num_corr.empty:
        print("No numerical data left for correlation matrix.")
        return

    # Unmapped names are left as-is by rename.
    corr = df_num_corr.corr().rename(index=feature_name_map, columns=feature_name_map)

    plt.figure(figsize=size)
    sns.heatmap(corr, annot=True, cmap='seismic', fmt='.2f', vmin=-1, vmax=1, annot_kws={"size": 8})
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(rotation=0, fontsize=10)
    plt.title(f"Feature Correlation Matrix ({save_name.replace('_corr','')})", fontsize=16)
    plt.tight_layout()
    if save:
        s_name = save_name if save_name.endswith('.png') else f'{save_name}.png'
        try:
            plt.savefig(s_name, dpi=300, bbox_inches='tight')
            print(f"Corr plot saved to {s_name}")
        except Exception as e:  # best-effort save: report and still show the plot
            print(f"Error saving plot: {e}")
    plt.show()

def clean_train(df):
    """
    Remove embedded web addresses from every comment string in ``df``.

    Each row's 'comments' value is looked up with ``row.get('comments')``.
    When that value is a list, each string item has http(s):// URLs removed,
    followed by bare ``www.`` addresses; non-string items stay as they are.
    Values that are not lists pass through unchanged. The 'comments' column
    is overwritten on ``df`` itself, and ``df`` is returned.
    """
    http_url = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
    # The leading group also consumes one non-alphanumeric char before "www.".
    www_url = re.compile(r'(?:^|[^a-zA-Z0-9])(www\.[a-zA-Z0-9][a-zA-Z0-9.-]+[a-zA-Z0-9]\.[a-zA-Z]{2,6}(?:/[^\s]*)?)')

    def strip_urls(text):
        without_http = http_url.sub('', text)
        return www_url.sub('', without_http)

    def clean_cell(cell):
        if not isinstance(cell, list):
            return cell
        return [strip_urls(item) if isinstance(item, str) else item for item in cell]

    raw_cells = (row.get('comments') for _, row in df.iterrows())
    df['comments'] = [clean_cell(cell) for cell in raw_cells]
    return df

def get_q_score(input_list_of_dicts): # Renamed for clarity
    """
    Score every comment against the three interview questions, in place.

    Runs the ``facebook/bart-large-mnli`` zero-shot-classification pipeline
    (on CUDA when available) with one candidate label per question and
    ``multi_label=True``. For each user dict, stores
    ``user['features']['q_scores']`` as a list with one ``[q1, q2, q3]``
    entry per comment, with scores rounded to 4 decimals.

    Users with no comments, or whose classification raises, get an empty
    ``q_scores`` list, so the key is always present. A missing ``'features'``
    dict is created.

    Args:
        input_list_of_dicts: List of user dicts with a ``'comments'`` list.

    Returns:
        The same list object, modified in place.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=device)
    labels = [
        'an answer to: "Please describe a situation where you were presented with a problem outside of your comfort zone and where you were able to come up with a creative solution."', #q1
        'an answer to: "Tell us about a time when you have failed or made a mistake. What happened? What did you learn from this experience?"', #q2
        'an answer to: "Describe a situation in which you got a group of people to work together as a team. Did you encounter any issues? What was the end result?"' #q3
    ]
    n_users = len(input_list_of_dicts)
    for i, user_data in enumerate(input_list_of_dicts):
        print(f'Q-Scoring user {i+1}/{n_users}')
        features = user_data.setdefault('features', {})
        user_comments = user_data.get('comments')
        if not user_comments or not isinstance(user_comments, list):
            features['q_scores'] = []  # Keep the key present even without comments
            continue
        try:
            results = classifier(user_comments, labels, multi_label=True)
            if isinstance(results, dict):
                # Defensive: treat a single bare result like a one-item list.
                results = [results]
            user_results_per_comment = []
            for result in results:  # one result per comment
                result_scores = {label: round(score, 4) for label, score in zip(result['labels'], result['scores'])}
                # Result labels are not in input order (the pipeline orders them by
                # score), so look each up to get a fixed [q1, q2, q3] layout.
                user_results_per_comment.append([result_scores[label] for label in labels])
            features['q_scores'] = user_results_per_comment
        except Exception as e:  # best-effort per user: report and continue
            print(f'Error in Q-scoring for user {i}: {e}')
            features['q_scores'] = []
    return input_list_of_dicts

def append_q_score_train(input_list_of_dicts, q_score_path): # Renamed for clarity
    """
    Attach precomputed question scores from a JSON file to each user, in place.

    ``q_score_path`` must be a JSON (not JSONL) file that holds a list aligned
    by position with ``input_list_of_dicts``. For each file entry, every dict
    in its ``'comment_classifications'`` list becomes one row of its values,
    rounded to 4 decimals. These rows are stored in
    ``user['features']['q_scores']``. Values are taken in the dict's stored
    order, which is assumed to be Q1, Q2, Q3 — verify against the file's
    producer.

    Users whose entry lacks ``'comment_classifications'``, or who have no
    entry because the file is shorter than the input, get ``q_scores = []``.
    A count mismatch is reported. A missing ``'features'`` dict is created.

    Returns:
        The same list object. If the file cannot be opened or parsed, it is
        returned unmodified after printing an error.
    """
    import json # Local import
    try:
        with open(q_score_path, 'r', encoding='utf-8') as f:
            data_from_file = json.load(f)  # Assumes a list of users
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print(f"Error loading Q-scores from {q_score_path}: {e}")
        return input_list_of_dicts

    n_file = len(data_from_file)
    if len(input_list_of_dicts) != n_file:
        print(f"Warning: Mismatch in user count between input data ({len(input_list_of_dicts)}) and Q-score file ({len(data_from_file)}). Q-scores may not be appended correctly.")

    for i, input_user in enumerate(input_list_of_dicts):
        features = input_user.setdefault('features', {})
        if i >= n_file:
            # No matching file entry: same default as a missing classification.
            features['q_scores'] = []
            continue
        data_user_from_file = data_from_file[i]
        if 'comment_classifications' not in data_user_from_file:
            print(f"Warning: 'comment_classifications' not found for user {i} in Q-score file.")
            features['q_scores'] = []  # Default if missing
            continue
        features['q_scores'] = [
            [round(v, 4) for v in comment_q_scores_dict.values()]
            for comment_q_scores_dict in data_user_from_file['comment_classifications']
        ]
    return input_list_of_dicts


def to_input(df):
    """
    Convert a feature DataFrame into a list of per-user input dicts.

    Each dict has:
      * ``'comments'``: a list of comments. Lists are kept as they are;
        tuples and NumPy arrays are converted to lists; a missing/NaN value
        becomes ``[]``; any other scalar becomes a one-element list of its ``str``.
      * ``'labels'``: ``{trait: value}`` for every name in ``traits``, present
        only when ``traits[0]`` is a column (train/val data, not test).
      * ``'features'``: every numeric column except the trait columns and any
        column named 'id' (case-insensitive), as a float rounded to 4
        decimals. NaN becomes 0.0.
    """
    has_labels = traits[0] in df.columns
    # The numeric feature columns are the same for every row; compute them once.
    feature_cols = [
        col for col in df.select_dtypes(include=[np.number]).columns
        if col not in traits and str(col).lower() != 'id'
    ]

    input_list = []
    for _, row in df.iterrows():
        comments = row['comments']
        if isinstance(comments, (tuple, np.ndarray)):
            # pd.isna() on these returns an array, whose truth value is ambiguous.
            comments = list(comments)
        elif not isinstance(comments, list):  # Ensure comments is always a list
            comments = [] if pd.isna(comments) else [str(comments)]

        input_user = {'comments': comments}
        if has_labels:  # train and val
            input_user['labels'] = {trait: row[trait] for trait in traits}
        input_user['features'] = {
            col: round(float(row[col]), 4) if pd.notna(row[col]) else 0.0
            for col in feature_cols
        }
        input_list.append(input_user)
    return input_list


def pretokenize(
    comments: List[str],
    tokenizer: BertTokenizerFast, # Pass tokenizer instance
    max_length: int = 256,
    padding_strategy: Union[str, bool] = "max_length",
    truncation_strategy: bool = True,
    return_tensors_type: str = 'pt',
) -> BatchEncoding:
    """
    Tokenize one user's comments into a single batch.

    Args:
        comments: Comment strings; each becomes one row of the batch.
        tokenizer: Tokenizer instance to call.
        max_length: Maximum tokens per comment; also the width of the
            empty-batch placeholder tensors.
        padding_strategy: Forwarded to the tokenizer as ``padding=``.
        truncation_strategy: Forwarded to the tokenizer as ``truncation=``.
        return_tensors_type: Forwarded to the tokenizer as ``return_tensors=``.

    Returns:
        BatchEncoding with ``input_ids``, ``token_type_ids`` and
        ``attention_mask``. For an empty ``comments`` list with 'pt', each
        field is a separate empty ``torch.long`` tensor of shape
        (0, max_length). For other tensor types, each field is an empty list,
        converted according to ``return_tensors_type``.
    """
    if not comments:  # Handle empty list of comments
        def _empty_field():
            # A fresh object per field, so the entries never alias each other.
            # torch.long for every field, including attention_mask, keeps empty
            # batches dtype-consistent with the tokenizer's 'pt' output.
            if return_tensors_type == 'pt':
                return torch.empty((0, max_length), dtype=torch.long)
            return []

        return BatchEncoding(
            {
                "input_ids": _empty_field(),
                "token_type_ids": _empty_field(),
                "attention_mask": _empty_field(),
            },
            tensor_type=return_tensors_type if return_tensors_type else None,
        )

    return tokenizer(
        comments,
        add_special_tokens=True,
        max_length=max_length,
        padding=padding_strategy,
        truncation=truncation_strategy,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors=return_tensors_type
    )


# --- df_preprocess: Main function with modifications ---
def _qa_columns_to_comments(df):
    """Return each row's Q1, Q2 and Q3 answers as a list of three strings.

    Missing answers become ``''``. A bare ``astype(str)`` would turn NaN cells
    into the literal word ``'nan'``, which would then be feature-extracted and
    tokenized as if the user had written it.
    """
    return df[['Q1', 'Q2', 'Q3']].fillna('').astype(str).values.tolist()


def _apply_saved_scaler(df_feat, numerical_feature_cols, split_label, missing_hint="", scaler_path='scaler.pkl'):
    """Scale ``numerical_feature_cols`` of ``df_feat`` in place with the pickled scaler.

    Best effort: if ``scaler_path`` is missing or loading/applying it fails,
    a message is printed and ``df_feat`` is left unscaled.

    Args:
        df_feat: Feature DataFrame, modified in place.
        numerical_feature_cols: Columns to transform (NaNs are filled with 0 first).
        split_label: Split name used in the messages (e.g. ``'VAL'``).
        missing_hint: Extra text appended to the "file not found" message.
        scaler_path: Pickle file written by the TRAIN branch of ``df_preprocess``.
    """
    try:
        # NOTE(review): pickle.load can execute arbitrary code; only load a
        # scaler file produced by this pipeline.
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
        numerical_data = df_feat[numerical_feature_cols].fillna(0).values
        numerical_scaled = scaler.transform(numerical_data)
        df_feat[numerical_feature_cols] = pd.DataFrame(numerical_scaled, index=df_feat.index, columns=numerical_feature_cols)
    except FileNotFoundError:
        print(f"Scaler file {scaler_path} not found. Skipping scaling for {split_label} data.{missing_hint}")
    except Exception as e:
        print(f"Error loading or applying scaler for {split_label} data: {e}. Skipping scaling.")


def _tokenize_user_comments(comments, tokenizer, chunk_size):
    """Tokenize one user's comments, at most ``chunk_size`` comments per tokenizer call.

    Returns:
        ``(input_ids, token_type_ids, attention_mask)`` as nested Python lists
        (one row per kept comment), or three empty lists if no comment survives
        the type/NaN filter.
    """
    # Keep only str/int/float/np.str_ values that are not NaN, as strings.
    string_comments = [str(c) for c in comments if isinstance(c, (str, int, float, np.str_)) and pd.notna(c)]
    if not string_comments:
        return [], [], []

    chunks = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for start in range(0, len(string_comments), chunk_size):
        encoded = pretokenize(string_comments[start:start + chunk_size], tokenizer=tokenizer, return_tensors_type='pt')
        for key, parts in chunks.items():
            parts.append(encoded[key])

    input_ids = torch.cat(chunks["input_ids"], dim=0)
    if input_ids.nelement() == 0:
        return [], [], []
    return (
        input_ids.tolist(),
        torch.cat(chunks["token_type_ids"], dim=0).tolist(),
        torch.cat(chunks["attention_mask"], dim=0).tolist(),
    )


def _build_hf_item(user_item, tokenizer, chunk_size):
    """Convert one entry of ``user_data_list`` into a flat Hugging Face Dataset row.

    The row always has the keys ``numerical_features``, ``q_scores``,
    ``input_ids``, ``token_type_ids`` and ``attention_mask`` (plus ``labels``
    when present), so every row shares the same schema.
    """
    item = {}
    if 'labels' in user_item:
        item['labels'] = user_item['labels']

    # Everything under 'features' except q_scores (kept as its own list column)
    # is collected into a single 'numerical_features' dict.
    numerical_features = {}
    if 'features' in user_item:
        for key, value in user_item['features'].items():
            if key == 'q_scores':
                item['q_scores'] = value if isinstance(value, list) else []
            elif key != 'comments_tokenized':  # never store stray tokenized data here
                numerical_features[key] = value
    item['numerical_features'] = numerical_features
    item.setdefault('q_scores', [])

    input_ids, token_type_ids, attention_mask = _tokenize_user_comments(
        user_item.get('comments', []), tokenizer, chunk_size
    )
    item['input_ids'] = input_ids
    item['token_type_ids'] = token_type_ids
    item['attention_mask'] = attention_mask
    return item


def df_preprocess(df_input_path_or_df): # Renamed for clarity
    """Extract features, tokenize comments and save the result as a Hugging Face Dataset.

    Args:
        df_input_path_or_df: Either a CSV path (str) or a DataFrame.
            * CSV containing the ``traits[0]`` column -> VAL split: trait labels
              are divided by 100 and a correlation plot is saved.
            * CSV without it -> TEST split.
            Both CSV splits build ``comments`` from Q1-Q3 and scale features
            with the ``scaler.pkl`` written by the TRAIN split.
            * DataFrame -> TRAIN split: must already have a ``comments`` column.
              A MinMaxScaler(-1, 1) is fitted, applied and pickled to
              ``scaler.pkl``; Q-scores come from ``q_scored.json``.

    Returns:
        The directory the dataset was saved to (``'train_hf_dataset'``,
        ``'val_hf_dataset'`` or ``'test_hf_dataset'``), or ``None`` if no data
        was produced or the Dataset could not be built.
    """
    extractor = TextFeatureExtractor()
    numerical_feature_cols = [
    'mean_words_per_comment', 'median_words_per_comment', 'mean_sents_per_comment',
    'median_sents_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
    'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound'
    ]
    # Initialize tokenizer once for all users.
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    dataset_save_path = None
    df_feat = None

    if isinstance(df_input_path_or_df, str):  # Path to CSV (VAL or TEST)
        df = pd.read_csv(df_input_path_or_df)

        if traits[0] in df.columns:  # Labelled CSV -> VAL data
            print('Processing VAL data from CSV...')
            dataset_save_path = 'val_hf_dataset'
            df['comments'] = _qa_columns_to_comments(df)
            for trait in traits:
                df[trait] = df[trait]/100

            print('Extracting features for VAL...')
            df_feat = extractor.extract_features(df.copy(), comment_column='comments')

            # The id column may be missing or named differently.
            show_corr(df_feat, save=True, cols_drop=[col for col in ['id', 'Id', 'ID'] if col in df_feat.columns], save_name='val_corr')

            # The scaler must have been fitted on TRAIN data.
            _apply_saved_scaler(df_feat, numerical_feature_cols, 'VAL', missing_hint=" Ensure scaler is trained and saved.")

            print('Converting VAL data to input structure...')
            user_data_list = to_input(df_feat)

            print('Getting Q-scores for VAL...')
            user_data_list = get_q_score(user_data_list)

        else:  # Unlabelled CSV -> TEST data
            print('Processing TEST data from CSV...')
            dataset_save_path = 'test_hf_dataset'
            df['comments'] = _qa_columns_to_comments(df)

            print('Extracting features for TEST...')
            df_feat = extractor.extract_features(df.copy(), comment_column='comments')

            _apply_saved_scaler(df_feat, numerical_feature_cols, 'TEST')

            print('Converting TEST data to input structure...')
            user_data_list = to_input(df_feat)

            print('Getting Q-scores for TEST...')
            user_data_list = get_q_score(user_data_list)

    else:  # TRAIN data (df_input_path_or_df is a DataFrame)
        print('Processing TRAIN data (DataFrame input)...')
        dataset_save_path = 'train_hf_dataset'
        # Expects an already-cleaned DataFrame (e.g. via clean_train) with a
        # 'comments' column of comment lists.
        df = df_input_path_or_df

        print('Extracting features for TRAIN...')
        df_feat = extractor.extract_features(df.copy(), comment_column='comments')

        # Fit the scaler on TRAIN only; VAL/TEST reuse it via scaler.pkl.
        scaler = MinMaxScaler(feature_range=(-1, 1))
        numerical_data = df_feat[numerical_feature_cols].fillna(0).values
        numerical_scaled = scaler.fit_transform(numerical_data)
        df_feat[numerical_feature_cols] = pd.DataFrame(numerical_scaled, index=df_feat.index, columns=numerical_feature_cols)

        scaler_path = 'scaler.pkl'
        try:
            with open(scaler_path,'wb') as f:
                pickle.dump(scaler,f)
            print(f"Scaler saved to {scaler_path}")
        except Exception as e:
            print(f"Error saving scaler: {e}")

        print('Converting TRAIN data to input structure...')
        user_data_list = to_input(df_feat)

        q_score_file_path = r'q_scored.json'  # Pre-computed Q-scores for training
        print(f'Appending pre-computed Q-scores for TRAIN from {q_score_file_path}...')
        user_data_list = append_q_score_train(user_data_list, q_score_file_path)

    # --- Common processing for all splits: tokenize and build the HF Dataset ---
    print('Tokenizing comments and preparing data for Hugging Face Dataset...')
    MAX_COMMENTS_PER_TOKENIZE_CHUNK = 1024
    processed_for_hf_dataset = [
        _build_hf_item(user_item, tokenizer, MAX_COMMENTS_PER_TOKENIZE_CHUNK)
        for user_item in user_data_list
    ]

    if not processed_for_hf_dataset:
        print("No data processed. Skipping dataset creation.")
        return None

    # Let `datasets` infer the schema; _build_hf_item keeps the keys consistent.
    try:
        hf_dataset = Dataset.from_list(processed_for_hf_dataset)
    except Exception as e:
        print(f"Error creating Hugging Face Dataset: {e}")
        print("Ensure all items in 'processed_for_hf_dataset' have a consistent structure (same keys).")
        return None

    if dataset_save_path:
        hf_dataset.save_to_disk(dataset_save_path)
        print(f"Finished processing. Hugging Face Dataset saved to {dataset_save_path}")
    else:
        print("Error: dataset_save_path was not set. Dataset not saved.")

    return dataset_save_path


# LOAD ------------------------------------------------------------------------------------------------------------------------------------
from datasets import load_from_disk # Already imported Dataset above

def load_hf_dataset_from_disk(dataset_path, set_torch_format=True, torch_columns=None, output_all_columns=False):
    """
    Loads a Hugging Face Dataset from disk and optionally sets a PyTorch format.

    Args:
        dataset_path (str): Path to the saved dataset directory.
        set_torch_format (bool): Whether to set PyTorch tensor format.
        torch_columns (list, optional): Column names to convert to PyTorch tensors.
            Names not present in the dataset are ignored.
            Defaults to ['input_ids', 'token_type_ids', 'attention_mask'].
        output_all_columns (bool): Forwarded to ``set_format``. With the default
            ``False``, indexing the formatted dataset returns ONLY the formatted
            columns; any other column (e.g. the ``labels``, ``numerical_features``
            and ``q_scores`` columns written by ``df_preprocess``) is hidden.
            Pass ``True`` to also return those columns, unconverted.

    Returns:
        datasets.Dataset: The loaded dataset, or ``None`` if loading failed.
        Errors while setting the format are printed and the dataset is
        returned unformatted.
    """
    if torch_columns is None:
        torch_columns = ['input_ids', 'token_type_ids', 'attention_mask']

    try:
        loaded_dataset = load_from_disk(dataset_path)
    except Exception as e:
        print(f"Error loading dataset from {dataset_path}: {e}")
        return None
    print(f"Dataset loaded successfully from {dataset_path}")

    if not set_torch_format:
        return loaded_dataset

    columns_to_format = [col for col in torch_columns if col in loaded_dataset.column_names]
    if columns_to_format:
        try:
            loaded_dataset.set_format(
                type='torch',
                columns=columns_to_format,
                output_all_columns=output_all_columns,
            )
            print(f"Dataset format set to PyTorch for columns: {columns_to_format}")
        except Exception as e:
            print(f"Error setting PyTorch format for dataset: {e}")
    elif torch_columns:  # Caller asked for columns, but none exist in the dataset.
        print(f"Warning: Specified torch_columns ({torch_columns}) not found in the dataset. No format set.")

    return loaded_dataset

# before switching filetype

In [None]:
# preprocess
import string
import re
import numpy as np
from collections import Counter, defaultdict
import textstat
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import pandas as pd
import json
import seaborn as sns
import matplotlib.pyplot as plt
import nltk

from transformers import BertTokenizerFast, pipeline
import torch
import accelerate
from typing import List, Dict, Union
import pandas as pd
import numpy as np
import nltk
import re # Make sure re is imported
from collections import Counter
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import textstat # Assuming textstat is available
from sklearn.preprocessing import MinMaxScaler
import pickle


# Trait label column names. df_preprocess treats a CSV containing traits[0] as
# labelled VAL data (and divides every trait column by 100); otherwise it is TEST.
traits = ['Openness','Conscientiousness','Extraversion','Agreeableness','Emotional stability','Humility']

# separate functions

class TextFeatureExtractor:
    """
    Extract linguistic and sentiment features from text data.

    Designed for a Pandas DataFrame in which each row holds a list of comment
    strings for one user/entry; ``extract_features`` adds one column per feature.
    """

    # Column-name suffix for each default punctuation mark; the resulting column
    # is ``punc_<suffix>_total``. Marks not in this map get a code-point based
    # suffix (see _punc_column_name).
    _PUNC_COLUMN_NAMES = {'!': 'em', '?': 'qm', '.': 'period', ',': 'comma', ':': 'colon', ';': 'semicolon'}

    # (resource path, download package) pairs used by nltk.sent_tokenize.
    # Older NLTK releases load 'punkt'; NLTK 3.9+ loads 'punkt_tab' instead.
    _NLTK_SENTENCE_RESOURCES = (
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
    )

    def __init__(self,
                 specific_punctuation_to_track: list = None,
                 readability_agg_method: str = "concat",
                 ttr_agg_method: str = "concat"):
        """
        Initializes the TextFeatureExtractor.

        Args:
            specific_punctuation_to_track (list, optional):
                Punctuation marks to count, one column per mark.
                Defaults to ['!', '?', '.', ',', ':', ';'].
            readability_agg_method (str, optional):
                How to aggregate readability scores over a row's comments:
                "concat" (score the joined text) or "mean" (average per-comment
                scores). Defaults to "concat". Other values raise ValueError
                during extraction.
            ttr_agg_method (str, optional):
                Same choice ("concat" or "mean") for the Type-Token Ratio.
                Defaults to "concat".
        """
        self.vader_analyzer = SentimentIntensityAnalyzer()

        self.specific_punctuation_to_track = (
            ['!', '?', '.', ',', ':', ';']
            if specific_punctuation_to_track is None
            else specific_punctuation_to_track
        )

        self.readability_agg_method = readability_agg_method
        self.ttr_agg_method = ttr_agg_method

        self._ensure_nltk_sentence_tokenizer()

    def _ensure_nltk_sentence_tokenizer(self) -> None:
        """Download the NLTK sentence-tokenizer data if it is not installed.

        Only LookupError is caught, since that is what ``nltk.data.find``
        raises for a missing resource.
        """
        for resource_path, package in self._NLTK_SENTENCE_RESOURCES:
            try:
                nltk.data.find(resource_path)
            except LookupError:
                print(f"NLTK '{package}' tokenizer not found. Downloading...", flush=True)
                nltk.download(package, quiet=True)

    @classmethod
    def _punc_column_name(cls, punc_char: str) -> str:
        """Return the column-name suffix for ``punc_char`` (e.g. '!' -> 'em', '-' -> 'u002d')."""
        return cls._PUNC_COLUMN_NAMES.get(punc_char, "u" + "".join(f"{ord(ch):04x}" for ch in punc_char))

    @staticmethod
    def _expand_dict_column(dict_series: pd.Series, columns: list, prefix: str) -> pd.DataFrame:
        """Expand a Series of dicts into a DataFrame with one (prefixed) column per key.

        Unlike ``series.apply(pd.Series)``, this also works for an empty Series:
        the result is an empty frame that still carries the expected columns.
        """
        expanded = pd.DataFrame(list(dict_series), index=dict_series.index, columns=columns)
        expanded.columns = [f'{prefix}{col}' for col in columns]
        return expanded

    # --- I. Basic structural feature helpers (operating on lists from a single DataFrame row) ---

    def _sentence_split(self, comment_list: list) -> list:
        """Split each comment into sentences; returns e.g. [[s1, s2], [s3, s4, s5]].

        Empty or non-string comments contribute an empty list.
        """
        if not isinstance(comment_list, list): return []
        all_sentences_for_user = []
        for comment_text in comment_list:
            if isinstance(comment_text, str) and comment_text.strip():
                all_sentences_for_user.append(nltk.sent_tokenize(comment_text))
            else:
                all_sentences_for_user.append([])
        return all_sentences_for_user

    def _get_word_counts_per_comment(self, comment_list: list) -> list:
        """Whitespace-split word count for each comment (0 for non-strings)."""
        if not isinstance(comment_list, list): return []
        return [len(str(comment).split()) if isinstance(comment, str) else 0 for comment in comment_list]

    def _get_sentence_counts_per_comment(self, list_of_sentence_lists: list) -> list:
        """Number of sentences in each original comment (given pre-split sentences)."""
        if not isinstance(list_of_sentence_lists, list): return []
        return [len(sentences) if isinstance(sentences, list) else 0 for sentences in list_of_sentence_lists]

    def _get_sentence_word_counts_per_comment(self, list_of_sentence_lists: list) -> list:
        """Word count of every sentence, grouped by original comment."""
        if not isinstance(list_of_sentence_lists, list): return []
        result_for_user = []
        for sentences_in_one_comment in list_of_sentence_lists:
            if isinstance(sentences_in_one_comment, list):
                result_for_user.append(
                    [len(str(sent).split()) if isinstance(sent, str) else 0 for sent in sentences_in_one_comment]
                )
            else:
                result_for_user.append([])
        return result_for_user

    def _aggregate_numeric_list_of_lists(self, list_of_lists_of_numbers: list, agg_func) -> float:
        """Flatten a list of lists, drop non-numbers/NaN, apply ``agg_func`` (NaN if nothing is left)."""
        if not isinstance(list_of_lists_of_numbers, list): return np.nan
        flat_list = []
        for sublist in list_of_lists_of_numbers:
            if isinstance(sublist, list):
                flat_list.extend(num for num in sublist if isinstance(num, (int, float)) and not np.isnan(num))
        return agg_func(flat_list) if flat_list else np.nan

    def _aggregate_numeric_list(self, list_of_numbers: list, agg_func) -> float:
        """Drop non-numbers/NaN and apply ``agg_func`` (NaN if nothing is left)."""
        if not isinstance(list_of_numbers, list): return np.nan
        valid_numbers = [num for num in list_of_numbers if isinstance(num, (int, float)) and not np.isnan(num)]
        return agg_func(valid_numbers) if valid_numbers else np.nan

    # --- II. Single-text processing helper methods (private) ---

    def _get_punctuation_counts_single(self, text: str) -> dict:
        """Count each tracked punctuation mark in ``text`` ({} for non-strings)."""
        if not isinstance(text, str): return {}
        counts = Counter(char for char in text if char in self.specific_punctuation_to_track)
        return {punc: counts.get(punc, 0) for punc in self.specific_punctuation_to_track}

    def _get_double_whitespace_count_single(self, text: str) -> int:
        """Count non-overlapping runs of two or more whitespace characters."""
        if not isinstance(text, str) or not text.strip():
            return 0
        return len(re.findall(r"\s{2,}", text))

    def _get_readability_scores_single(self, text: str) -> dict:
        """Flesch reading ease and Gunning fog for ``text`` (NaN on empty text or textstat errors)."""
        if not isinstance(text, str) or not text.strip():
            return {'flesch_reading_ease': np.nan, 'gunning_fog': np.nan}
        try:
            return {
                'flesch_reading_ease': textstat.flesch_reading_ease(text),
                'gunning_fog': textstat.gunning_fog(text)
            }
        except Exception:
            return {'flesch_reading_ease': np.nan, 'gunning_fog': np.nan}

    def _get_mean_word_length_single(self, text: str) -> float:
        """Mean length of the regex words (\\b\\w+\\b) in ``text``; NaN if none."""
        if not isinstance(text, str) or not text.strip(): return np.nan
        words = re.findall(r'\b\w+\b', text.lower())
        if not words: return np.nan
        return sum(len(word) for word in words) / len(words)

    def _get_type_token_ratio_single(self, text: str) -> float:
        """Unique lower-cased words / total words in ``text``; NaN if there are no words."""
        if not isinstance(text, str) or not text.strip(): return np.nan
        words = re.findall(r'\b\w+\b', text.lower())
        if not words: return np.nan
        return len(set(words)) / len(words)

    def _get_vader_sentiment_scores_single(self, text: str) -> dict:
        """VADER neg/neu/pos/compound scores for ``text`` (all NaN for non-strings)."""
        if not isinstance(text, str):
            return {'sentiment_neg': np.nan, 'sentiment_neu': np.nan,
                    'sentiment_pos': np.nan, 'sentiment_compound': np.nan}
        scores = self.vader_analyzer.polarity_scores(text)
        return {
            'sentiment_neg': scores['neg'], 'sentiment_neu': scores['neu'],
            'sentiment_pos': scores['pos'], 'sentiment_compound': scores['compound']
        }

    # --- III. Methods for processing a LIST of comments from one user/row ---

    def _get_aggregated_punctuation_counts_from_list(self, comment_list: list) -> dict:
        """Sum the tracked punctuation counts over all comments of one row."""
        if not isinstance(comment_list, list) or not comment_list:
            return {punc: 0 for punc in self.specific_punctuation_to_track}

        total_counts = Counter()
        for comment_text in comment_list:
            if isinstance(comment_text, str):
                total_counts.update(self._get_punctuation_counts_single(comment_text))
        return dict(total_counts)

    def _get_aggregated_double_whitespace_from_list(self, comment_list: list) -> int:
        """Sum double-whitespace runs over all comments of one row."""
        if not isinstance(comment_list, list) or not comment_list:
            return 0
        return sum(
            self._get_double_whitespace_count_single(comment_text)
            for comment_text in comment_list
            if isinstance(comment_text, str)
        )

    def _get_readability_scores_from_list(self, comment_list: list) -> dict:
        """Readability of one row, aggregated per ``self.readability_agg_method``.

        Raises:
            ValueError: if the aggregation method is neither "concat" nor "mean"
                (only checked when the row has non-empty comments).
        """
        default_scores = {'flesch_reading_ease_agg': np.nan, 'gunning_fog_agg': np.nan}
        if not isinstance(comment_list, list) or not comment_list: return default_scores

        valid_comments = [c for c in comment_list if isinstance(c, str) and c.strip()]
        if not valid_comments: return default_scores

        if self.readability_agg_method == "concat":
            scores = self._get_readability_scores_single(" ".join(valid_comments))
            return {f"{k}_agg": v for k, v in scores.items()}
        if self.readability_agg_method == "mean":
            flesch_s, gunning_s = [], []
            for ct in valid_comments:
                s = self._get_readability_scores_single(ct)
                if not np.isnan(s['flesch_reading_ease']): flesch_s.append(s['flesch_reading_ease'])
                if not np.isnan(s['gunning_fog']): gunning_s.append(s['gunning_fog'])
            return {
                'flesch_reading_ease_agg': np.nanmean(flesch_s) if flesch_s else np.nan,
                'gunning_fog_agg': np.nanmean(gunning_s) if gunning_s else np.nan
            }
        raise ValueError("Invalid readability_agg_method.")

    def _get_mean_word_length_from_list(self, comment_list: list) -> float:
        """Mean over comments of each comment's mean word length."""
        if not isinstance(comment_list, list) or not comment_list: return np.nan
        lengths = [self._get_mean_word_length_single(c) for c in comment_list if isinstance(c, str)]
        valid_lengths = [length for length in lengths if not np.isnan(length)]
        return np.nanmean(valid_lengths) if valid_lengths else np.nan

    def _get_ttr_from_list(self, comment_list: list) -> float:
        """Type-Token Ratio of one row, aggregated per ``self.ttr_agg_method``.

        Raises:
            ValueError: if the aggregation method is neither "concat" nor "mean"
                (only checked when the row has non-empty comments).
        """
        if not isinstance(comment_list, list) or not comment_list: return np.nan
        valid_comments = [c for c in comment_list if isinstance(c, str) and c.strip()]
        if not valid_comments: return np.nan

        if self.ttr_agg_method == "concat":
            return self._get_type_token_ratio_single(" ".join(valid_comments))
        if self.ttr_agg_method == "mean":
            ttrs = [self._get_type_token_ratio_single(c) for c in valid_comments]
            valid_ttrs = [ttr for ttr in ttrs if not np.isnan(ttr)]
            return np.nanmean(valid_ttrs) if valid_ttrs else np.nan
        raise ValueError("Invalid ttr_agg_method.")

    def _get_aggregated_sentiment_from_list(self, comment_list: list) -> dict:
        """Mean VADER scores over a row's comments plus the std of the compound score.

        The std is 0.0 for a single scored comment and NaN when none were scored.
        """
        default_scores = {'mean_sentiment_neg': np.nan, 'mean_sentiment_neu': np.nan,
                          'mean_sentiment_pos': np.nan, 'mean_sentiment_compound': np.nan,
                          'std_sentiment_compound': np.nan}
        if not isinstance(comment_list, list) or not comment_list: return default_scores

        scores_acc = {'neg': [], 'neu': [], 'pos': [], 'compound': []}
        for comment_text in comment_list:
            if isinstance(comment_text, str):
                single_s = self._get_vader_sentiment_scores_single(comment_text)
                for key_base, acc in scores_acc.items():
                    val = single_s[f'sentiment_{key_base}']
                    if not np.isnan(val): acc.append(val)

        results = {
            f'mean_sentiment_{key_base}': np.nanmean(val_list) if val_list else np.nan
            for key_base, val_list in scores_acc.items()
        }

        comp_list = scores_acc['compound']
        if not comp_list:
            results['std_sentiment_compound'] = np.nan
        elif len(comp_list) == 1:
            results['std_sentiment_compound'] = 0.0
        else:
            results['std_sentiment_compound'] = np.nanstd(comp_list)
        return results

    # --- IV. Main Public Method ---
    def extract_features(self, df: pd.DataFrame, comment_column: str = 'comments', output_prefix: str = "") -> pd.DataFrame:
        """
        Extracts all defined text features and adds them to the DataFrame.

        Note that intermediate and structural feature columns are assigned onto
        ``df`` itself (callers in this file pass a copy); the readability and
        sentiment columns are added via ``pd.concat`` and the combined frame
        is returned.

        Args:
            df (pd.DataFrame): The input DataFrame.
            comment_column (str): Column holding, for each row, a list (or
                tuple / np.ndarray) of comment strings.
            output_prefix (str, optional): Prefix for every new feature column.
                Defaults to "".

        Returns:
            pd.DataFrame: The DataFrame with added feature columns.

        Raises:
            ValueError: if ``comment_column`` is missing, or an aggregation
                method configured in ``__init__`` is invalid.
        """
        if comment_column not in df.columns:
            raise ValueError(f"Column '{comment_column}' not found in DataFrame.")

        if not df[comment_column].apply(lambda x: isinstance(x, (list, tuple, np.ndarray))).all():
            print(f"Warning: Not all entries in '{comment_column}' are lists/tuples/np.ndarray. Ensure data format is correct.", flush=True)

        # The per-row helpers only accept `list`; convert tuples/arrays (which the
        # check above accepts) so they are not silently turned into NaN/0 features.
        comments = df[comment_column].apply(lambda x: list(x) if isinstance(x, (tuple, np.ndarray)) else x)

        # --- 1. Basic structural features ---
        word_counts = comments.apply(self._get_word_counts_per_comment)
        df[f'{output_prefix}comment_word_counts'] = word_counts
        df[f'{output_prefix}mean_words_per_comment'] = word_counts.apply(lambda x: self._aggregate_numeric_list(x, np.mean))
        df[f'{output_prefix}median_words_per_comment'] = word_counts.apply(lambda x: self._aggregate_numeric_list(x, np.median))
        df[f'{output_prefix}total_words'] = word_counts.apply(lambda x: self._aggregate_numeric_list(x, np.sum))

        sentences = comments.apply(self._sentence_split)
        sent_counts = sentences.apply(self._get_sentence_counts_per_comment)
        df[f'{output_prefix}mean_sents_per_comment'] = sent_counts.apply(lambda x: self._aggregate_numeric_list(x, np.mean))
        df[f'{output_prefix}median_sents_per_comment'] = sent_counts.apply(lambda x: self._aggregate_numeric_list(x, np.median))
        df[f'{output_prefix}total_sents'] = sent_counts.apply(lambda x: self._aggregate_numeric_list(x, np.sum))

        sent_word_counts = sentences.apply(self._get_sentence_word_counts_per_comment)
        df[f'{output_prefix}mean_words_per_sentence'] = sent_word_counts.apply(lambda x: self._aggregate_numeric_list_of_lists(x, np.mean))
        df[f'{output_prefix}median_words_per_sentence'] = sent_word_counts.apply(lambda x: self._aggregate_numeric_list_of_lists(x, np.median))

        # Mean minus median as a simple skew indicator.
        df[f'{output_prefix}sents_per_comment_skew'] = df[f'{output_prefix}mean_sents_per_comment'] - df[f'{output_prefix}median_sents_per_comment']
        df[f'{output_prefix}words_per_sentence_skew'] = df[f'{output_prefix}mean_words_per_sentence'] - df[f'{output_prefix}median_words_per_sentence']

        # --- 1b. Double whitespace count ---
        df[f'{output_prefix}total_double_whitespace'] = comments.apply(self._get_aggregated_double_whitespace_from_list)

        # --- 2. Punctuation features ---
        # Columns are named by character (see _PUNC_COLUMN_NAMES), so a custom
        # specific_punctuation_to_track gets correctly-labelled columns.
        punc_data_col = comments.apply(self._get_aggregated_punctuation_counts_from_list)
        for punc_char in self.specific_punctuation_to_track:
            punc_name = self._punc_column_name(punc_char)
            df[f'{output_prefix}punc_{punc_name}_total'] = punc_data_col.apply(lambda counts, ch=punc_char: counts.get(ch, 0))

        # --- 3. Readability features ---
        readability_df = self._expand_dict_column(
            comments.apply(self._get_readability_scores_from_list),
            ['flesch_reading_ease_agg', 'gunning_fog_agg'],
            output_prefix,
        )
        df = pd.concat([df, readability_df], axis=1)

        # --- 4. Mean word length ---
        df[f'{output_prefix}mean_word_len_overall'] = comments.apply(self._get_mean_word_length_from_list)

        # --- 5. Type-Token Ratio (lexical richness) ---
        df[f'{output_prefix}ttr_overall'] = comments.apply(self._get_ttr_from_list)

        # --- 6. Sentiment features ---
        sentiment_df = self._expand_dict_column(
            comments.apply(self._get_aggregated_sentiment_from_list),
            ['mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
             'mean_sentiment_compound', 'std_sentiment_compound'],
            output_prefix,
        )
        df = pd.concat([df, sentiment_df], axis=1)

        return df
    
def show_corr(df, cols_drop=None, size=(15, 7), save=False, save_name='UNKNOWN'):
    """Plot (and optionally save) a correlation heatmap of the numeric columns of *df*.

    Args:
        df: DataFrame; only its float64/int64/float32/int32 columns are correlated.
        cols_drop: a column name, or an iterable of column names, to exclude before
            computing the correlation. A single string is treated as ONE column name.
        size: matplotlib figure size.
        save: when True, also write the figure to ``save_name`` (``.png`` appended
            if missing).
        save_name: output file name; also used (minus ``'_corr'``) in the title.

    Returns:
        None. Prints a message and returns early when no numeric column remains.
    """
    feature_name_map = {
        # Basic Structural & Length Features
        'mean_words_per_comment': "Avg. Words/Comment",
        'median_words_per_comment': "Median Words/Comment",
        'total_words': "Total Words",
        'mean_sents_per_comment': "Avg. Sents/Comment",
        'median_sents_per_comment': "Median Sents/Comment",
        'total_sents': "Total Sentences",
        'mean_words_per_sentence': "Avg. Words/Sentence",
        'median_words_per_sentence': "Median Words/Sentence",
        'sents_per_comment_skew': "Sentence Count Skew",
        'words_per_sentence_skew': "Sentence Length Skew",
        'total_double_whitespace': "Total Double Whitespace",

        # Punctuation Usage Features
        'punc_em_total': "(!) Count",
        'punc_qm_total': "(?) Count",
        'punc_period_total': "(.) Count",
        'punc_comma_total': "(,) Count",
        'punc_colon_total': "(:) Count",
        'punc_semicolon_total': "(;) Count",

        # Readability Features
        'flesch_reading_ease_agg': "Flesch Reading Ease",
        'gunning_fog_agg': "Gunning Fog Index",

        # Lexical Features
        'mean_word_len_overall': "Avg. Word Length",
        'ttr_overall': "Type-Token Ratio (TTR)",

        # Sentiment Features
        'mean_sentiment_neg': "Avg. Negative Sentiment",
        'mean_sentiment_neu': "Avg. Neutral Sentiment",
        'mean_sentiment_pos': "Avg. Positive Sentiment",
        'mean_sentiment_compound': "Avg. Compound Sentiment",
        'std_sentiment_compound': "Std. Compound Sentiment",

        # Trait columns, if present in the numerical df
        'Openness': 'Openness',
        'Conscientiousness': 'Conscientiousness',
        'Extraversion': 'Extraversion',
        'Agreeableness': 'Agreeableness',
        'Emotional stability': 'Emotional Stability',  # key must match the DataFrame column
        'Humility': 'Humility'
    }
    df_num_corr = df.select_dtypes(include=['float64', 'int64', 'float32', 'int32']).copy()

    if cols_drop is not None:
        # A bare string is one column name; iterating it would yield single characters.
        if isinstance(cols_drop, str):
            cols_drop = [cols_drop]
        cols_drop = list(cols_drop)
        valid_cols_to_drop = [col for col in cols_drop if col in df_num_corr.columns]
        if len(valid_cols_to_drop) < len(cols_drop):
            print(f"Warning: Some columns in cols_drop were not found in the numerical DataFrame: {set(cols_drop) - set(valid_cols_to_drop)}")
        if valid_cols_to_drop:
            df_num_corr = df_num_corr.drop(columns=valid_cols_to_drop)
        else:
            print("Warning: No valid columns found to drop from cols_drop list.")

    if df_num_corr.empty:
        print("No numerical data left to compute correlation matrix after dropping columns.")
        return

    corr = df_num_corr.corr()
    # Show readable labels where a mapping exists; unmapped columns keep their raw name.
    display_names = [feature_name_map.get(name, name) for name in corr.columns]
    corr.columns = display_names
    corr.index = display_names

    plt.figure(figsize=size)
    sns.heatmap(corr, annot=True, cmap='seismic', fmt='.2f', vmin=-1, vmax=1, annot_kws={"size": 8})
    plt.xticks(rotation=45, ha='right', fontsize=10)  # rotated so long labels do not overlap
    plt.yticks(rotation=0, fontsize=10)
    plt.title(f"Feature Correlation Matrix ({save_name.replace('_corr','')})", fontsize=16)
    plt.tight_layout()

    if save:
        s_name = save_name if save_name.endswith('.png') else f'{save_name}.png'
        try:
            plt.savefig(s_name, dpi=300, bbox_inches='tight')
            print(f"Correlation plot saved to {s_name}")
        except Exception as e:
            # Saving is best-effort: report the failure but still show the plot.
            print(f"Error saving plot: {e}")
    plt.show()

    return

import re
import pandas as pd

def clean_train(df):
    """Remove embedded web links from every comment in ``df['comments']``.

    Each cell of ``df['comments']`` is expected to be a list of comments. String
    items have http(s) links stripped first, then bare ``www.`` links; non-string
    items are kept unchanged. Cells that are not lists are kept as-is and a
    notice is printed for each one.

    The DataFrame is modified in place (its ``'comments'`` column is replaced)
    and is also returned.
    """
    # General pattern for http(s) links embedded anywhere in the text.
    http_link = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
    # Bare "www." links; the single delimiter character in front of the link is
    # part of the match and is removed together with it.
    bare_www_link = re.compile(r'(?:^|[^a-zA-Z0-9])(www\.[a-zA-Z0-9][a-zA-Z0-9.-]+[a-zA-Z0-9]\.[a-zA-Z]{2,6}(?:/[^\s]*)?)')

    def _strip_links(item):
        # Only strings are cleaned; anything else (e.g. None) passes through.
        if not isinstance(item, str):
            return item
        without_http = http_link.sub('', item)
        return bare_www_link.sub('', without_http)

    def _clean_cell(cell):
        if isinstance(cell, list):
            return [_strip_links(item) for item in cell]
        print('non list still found')
        return cell

    df['comments'] = [_clean_cell(row.get('comments')) for _, row in df.iterrows()]
    return df

# Example of how to use it:
# Assuming df_train is your DataFrame
# df_train = clean_train(df_train.copy()) # Use .copy() if you want to avoid modifying the original df_train in place


def get_q_score(input, classifier=None):
    """Attach zero-shot "which interview question does this answer" scores to each user.

    For every user dict in ``input`` whose ``'comments'`` entry is a non-empty list,
    all comments are scored in one call against the three question labels with
    ``multi_label=True``. Each comment contributes one row ``[Q1, Q2, Q3]`` of scores
    rounded to 4 decimals, and the rows are stored in ``user['features']['q_scores']``.

    Args:
        input: list of user dicts (as built by ``to_input``). It is mutated in place;
            each dict must already have a ``'features'`` dict.
        classifier: optional callable with the zero-shot pipeline calling convention
            ``classifier(sequences, candidate_labels, multi_label=True)``. When None,
            a ``facebook/bart-large-mnli`` zero-shot pipeline is built on CUDA if
            available, else on CPU. Pass one in to reuse a loaded model across calls.

    Returns:
        The same ``input`` list.
    """
    if classifier is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        classifier = pipeline(
            'zero-shot-classification',
            model='facebook/bart-large-mnli',
            device=device,
        )

    # Output slot -> candidate label. Defined once so the labels sent to the
    # classifier and the labels looked up in its output cannot drift apart.
    question_labels = {
        'Q1_score': 'an answer to: "Please describe a situation where you were presented with a problem outside of your comfort zone and where you were able to come up with a creative solution."',
        'Q2_score': 'an answer to: "Tell us about a time when you have failed or made a mistake. What happened? What did you learn from this experience?"',
        'Q3_score': 'an answer to: "Describe a situation in which you got a group of people to work together as a team. Did you encounter any issues? What was the end result?"',
    }
    labels = list(question_labels.values())

    for i, user in enumerate(input):
        print(f'{i+1}/{len(input)}')
        user_comments = user.get('comments')
        # Type check first so a non-list value (e.g. an array) never has its truthiness evaluated.
        if not (isinstance(user_comments, list) and user_comments):
            print('OOPY')
            continue
        try:
            results = classifier(user_comments, labels, multi_label=True)
            # Guard in case the classifier returns a single dict instead of a list
            # (possible for a single sequence, depending on the pipeline version).
            if isinstance(results, dict):
                results = [results]

            user_results = []
            for result in results:
                try:
                    scores = {label: round(score, 4) for label, score in zip(result['labels'], result['scores'])}
                    user_results.append([scores[label] for label in labels])
                except Exception as e:
                    print(f'{e}')
            user['features']['q_scores'] = user_results
        except Exception as e:
            # Best-effort: a failing user is reported and left without q_scores.
            print(f'{e}')

    return input

def append_q_score_train(input, q_score_path):
    """Attach precomputed per-comment q-scores from a JSON file to each user.

    ``q_score_path`` must hold a JSON list paired with ``input`` by position. Each
    entry has a ``'comment_classifications'`` list of per-comment dicts; the values
    of each dict, in file order, become one row of ``features['q_scores']``.

    Args:
        input: list of user dicts (as built by ``to_input``); mutated in place.
        q_score_path: path to the UTF-8 JSON file described above.

    Returns:
        The same ``input`` list.
    """
    with open(q_score_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if len(data) != len(input):
        # zip() stops at the shorter list, so trailing users would silently get
        # no q_scores; surface the mismatch instead.
        print(f'Warning: {len(input)} users but {len(data)} q-score entries in '
              f'{q_score_path}; only the first {min(len(input), len(data))} are paired.')

    for input_user, data_user in zip(input, data):
        input_user.setdefault('features', {})['q_scores'] = [
            list(comment_q_scores.values())
            for comment_q_scores in data_user['comment_classifications']
        ]
    return input


def to_input(df, trait_names=None):
    """Convert a feature DataFrame into a list of per-user dicts.

    Each dict has ``'comments'`` (the row's ``comments`` value). It also has
    ``'labels'`` ({trait: value}), but only when the first trait column exists
    in ``df`` (train/val). Finally it has ``'features'``: every float64/int64
    column other than the traits and ``'id'``, rounded to 4 decimals.

    Args:
        df: DataFrame with a ``comments`` column plus numeric feature columns.
        trait_names: trait column names; defaults to the module-level ``traits``.

    Returns:
        List of user dicts, one per row, in row order.
    """
    if trait_names is None:
        trait_names = traits
    has_labels = trait_names[0] in df.columns  # test frames carry no trait columns

    # The feature columns are the same for every row, so compute them once.
    excluded = set(trait_names) | {'id'}
    feature_cols = [
        col for col in df.select_dtypes(include=['float64', 'int64']).columns
        if col not in excluded
    ]

    users = []
    for _, row in df.iterrows():
        user = {'comments': row['comments']}
        if has_labels:
            user['labels'] = {trait: row[trait] for trait in trait_names}
        user['features'] = {col: round(row[col], 4) for col in feature_cols}
        users.append(user)
    return users



# Tokenizers loaded by pretokenize() when no tokenizer is passed, keyed by model name,
# so repeated calls do not reload from disk/hub.
_TOKENIZER_CACHE = {}


def pretokenize(
    comments: List[str],
    model_name: str = "bert-base-uncased",
    max_length: int = 256,
    padding_strategy: Union[str, bool] = "max_length",  # or True / 'longest'
    truncation_strategy: bool = True,
    return_tensors_type: Optional[str] = 'pt',  # 'pt', 'tf', or None for Python lists
    tokenizer=None,
) -> "BatchEncoding":
    """Tokenize each comment separately with a BERT tokenizer.

    Args:
        comments: list of comment strings.
        model_name: pretrained tokenizer to load when ``tokenizer`` is None
            (loaded once per name and cached).
        max_length: maximum sequence length; longer comments are truncated when
            ``truncation_strategy`` is True.
        padding_strategy: ``'max_length'`` pads to ``max_length``; True/``'longest'``
            pads to the longest comment; False/``'do_not_pad'`` disables padding.
        truncation_strategy: whether to truncate to ``max_length``.
        return_tensors_type: ``'pt'`` (default) for PyTorch tensors, ``'tf'`` for
            TensorFlow, or None for plain Python lists.
        tokenizer: an already-loaded tokenizer to use instead of ``model_name``.

    Returns:
        Whatever the tokenizer returns: with a Hugging Face tokenizer, a
        ``BatchEncoding`` with ``input_ids``, ``token_type_ids`` and
        ``attention_mask``, one entry per comment.
    """
    if tokenizer is None:
        # Load lazily: a default-argument tokenizer would be built at definition
        # time and would ignore ``model_name``.
        tokenizer = _TOKENIZER_CACHE.get(model_name)
        if tokenizer is None:
            tokenizer = BertTokenizerFast.from_pretrained(model_name)
            _TOKENIZER_CACHE[model_name] = tokenizer

    return tokenizer(
        comments,
        add_special_tokens=True,        # adds [CLS] and [SEP]
        max_length=max_length,
        padding=padding_strategy,
        truncation=truncation_strategy,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors=return_tensors_type,
    )

import json
import torch # Assuming BatchEncoding contains PyTorch tensors
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation

# --- Constants for marking types in JSON ---
_TENSOR_MARKER = "__tensor__"
_TENSOR_DTYPE_MARKER = "__tensor_dtype__"
_BATCH_ENCODING_MARKER = "__batch_encoding__"
_BATCH_ENCODING_DATA_MARKER = "data"

def _convert_dtype_to_str(dtype):
    """Converts a torch.dtype to a string representation."""
    # This mapping might need to be expanded for other dtypes
    # Alternatively, just use str(dtype) e.g. "torch.int64"
    # and torch.__getattribute__(str_dtype.split('.')[1]) for loading
    return str(dtype)

def _convert_str_to_dtype(dtype_str):
    """Converts a string representation back to a torch.dtype."""
    if not dtype_str.startswith("torch."):
        # Fallback for simple dtype names if str(dtype) was used without "torch." prefix
        # Or if it's a very basic type like 'float32' that torch.dtype can parse
        try:
            return torch.__getattribute__(dtype_str)
        except AttributeError:
            return torch.dtype(dtype_str) # Try direct parsing

    # For strings like "torch.int64"
    dtype_name = dtype_str.split('.')[1]
    return torch.__getattribute__(dtype_name)


def _recursive_encode(obj):
    """Return a copy of *obj* with tensors and BatchEncodings replaced by marker dicts.

    - ``torch.Tensor`` -> ``{_TENSOR_MARKER: True, "data": nested list, _TENSOR_DTYPE_MARKER: dtype string}``
    - ``BatchEncoding`` -> ``{_BATCH_ENCODING_MARKER: True, "data": {key: encoded value}}``
    - dicts are rebuilt with encoded values; lists and tuples become encoded lists.
    - anything else is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        encoded = {_TENSOR_MARKER: True}
        encoded[_BATCH_ENCODING_DATA_MARKER] = obj.tolist()
        encoded[_TENSOR_DTYPE_MARKER] = _convert_dtype_to_str(obj.dtype)
        return encoded
    if isinstance(obj, BatchEncoding):
        # Checked before plain dict; the wrapper marks the original type for decoding.
        payload = {key: _recursive_encode(value) for key, value in obj.items()}
        return {_BATCH_ENCODING_MARKER: True, _BATCH_ENCODING_DATA_MARKER: payload}
    if isinstance(obj, dict):
        return {key: _recursive_encode(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_recursive_encode, obj))
    return obj

def encode_for_json(data_to_encode):
    """Prepare *data_to_encode* for ``json.dump``.

    Every ``torch.Tensor`` and ``BatchEncoding`` reachable through dicts, lists
    and tuples is replaced by a marker dict (see ``_recursive_encode``). The result
    is JSON-serialisable provided the remaining leaves are JSON-native. The reverse
    conversion is ``decode_from_json``.
    """
    json_ready = _recursive_encode(data_to_encode)
    return json_ready

def _json_object_hook(dct):
    """``object_hook`` for ``json.load``/``json.loads`` that reverses ``_recursive_encode``.

    Tensor marker dicts become ``torch.tensor`` (dtype defaults to ``float32`` when
    the dtype key is missing). Batch-encoding marker dicts become ``BatchEncoding``.
    Any other dict is returned unchanged. The json module calls the hook on inner
    objects first, so tensors nested in a BatchEncoding payload are already rebuilt
    when the outer marker is seen.
    """
    if _TENSOR_MARKER in dct:
        target_dtype = _convert_str_to_dtype(dct.get(_TENSOR_DTYPE_MARKER, 'float32'))
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=target_dtype)
    if _BATCH_ENCODING_MARKER in dct:
        return BatchEncoding(dct[_BATCH_ENCODING_DATA_MARKER])
    return dct

def decode_from_json(json_data):
    """Parse JSON produced from ``encode_for_json`` output, rebuilding tensors and BatchEncodings.

    :param json_data: a JSON string, or any other object is treated as a readable
        file-like object (e.g. an opened file).
    """
    loader = json.loads if isinstance(json_data, str) else json.load
    return loader(json_data, object_hook=_json_object_hook)



# COMBINED FUNCTION
def df_preprocess(df_path):
    """Extract features, scale them, add q-scores, tokenize comments and stream users to JSONL.

    Args:
        df_path: either a CSV path or an already-loaded train DataFrame.
            - A CSV that contains the first trait column is treated as VAL. Its
              trait labels are divided by 100.
            - Any other CSV is treated as TEST.
            - A DataFrame is treated as TRAIN.

    Side effects:
        - TRAIN fits a MinMaxScaler(-1, 1) and writes ``scaler.pkl``. It reads
          ``q_scored.json`` and writes ``df_train_feat.csv`` and ``train_data.jsonl``.
        - VAL reads ``scaler.pkl`` and writes ``val_corr.png``,
          ``df_val_feat.csv`` and ``val_data.jsonl``.
        - TEST reads ``scaler.pkl`` and writes ``test_data.jsonl``.

    Returns:
        None; the processed data is only available through the JSONL file.
    """
    extractor = TextFeatureExtractor()
    numerical_feature_cols = [
        'mean_words_per_comment', 'median_words_per_comment', 'mean_sents_per_comment',
        'median_sents_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
        'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
        'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
        'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
        'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
        'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
        'mean_sentiment_compound', 'std_sentiment_compound',
    ]
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    if isinstance(df_path, str):
        df = pd.read_csv(df_path)

        if traits[0] in df.columns:  # VAL
            print('processing val')
            df['comments'] = df[['Q1', 'Q2', 'Q3']].astype(str).values.tolist()
            for trait in traits:
                df[trait] = df[trait] / 100  # presumably 0-100 -> 0-1; confirm against train labels
            print('extracting features')
            df_feat = extractor.extract_features(df.copy())

            # cols_drop is a list so 'id' is treated as one column name.
            show_corr(df_feat, save=True, cols_drop=['id'], save_name='val_corr')
            df_feat = _scale_with_saved_scaler(df_feat, numerical_feature_cols)
            df_feat.to_csv('df_val_feat.csv')
            output_file_path = 'val_data.jsonl'
        else:  # TEST
            print('processing test')
            df['comments'] = df[['Q1', 'Q2', 'Q3']].astype(str).values.tolist()
            print('extracting features')
            df_feat = extractor.extract_features(df.copy())
            df_feat = _scale_with_saved_scaler(df_feat, numerical_feature_cols)
            output_file_path = 'test_data.jsonl'

        print('turning into dict')
        input_dict = to_input(df_feat)
        print('getting q_scores')
        users = get_q_score(input_dict)

    else:  # TRAIN
        scaler = MinMaxScaler(feature_range=(-1, 1))
        print('processing train')
        # URL cleaning (clean_train) and the train correlation plot are currently disabled.
        print('extracting features')
        df_feat = extractor.extract_features(df_path.copy())

        # Fit on train only and persist the scaler for the VAL/TEST branches.
        numerical_data = df_feat[numerical_feature_cols].fillna(0).values
        scaler.fit(numerical_data)
        numerical_scaled = scaler.transform(numerical_data)
        df_feat[numerical_feature_cols] = pd.DataFrame(numerical_scaled, index=df_feat.index, columns=numerical_feature_cols)
        with open('scaler.pkl', 'wb') as f:
            pickle.dump(scaler, f)
        df_feat.to_csv('df_train_feat.csv')

        print('turning intoo dict')
        users = to_input(df_feat)
        print('appending q_scores')
        users = append_q_score_train(users, r'q_scored.json')
        output_file_path = 'train_data.jsonl'

    _write_users_jsonl(users, output_file_path, tokenizer)
    return None


def _scale_with_saved_scaler(df_feat, numerical_feature_cols, scaler_path='scaler.pkl'):
    """Apply the scaler fitted by the TRAIN branch to ``numerical_feature_cols`` (NaN -> 0 first)."""
    # scaler.pkl is written locally by the TRAIN branch; never point this at an untrusted pickle.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    numerical_scaled = scaler.transform(df_feat[numerical_feature_cols].fillna(0).values)
    df_feat[numerical_feature_cols] = pd.DataFrame(numerical_scaled, index=df_feat.index, columns=numerical_feature_cols)
    return df_feat


def _tokenize_user_comments(string_comments, tokenizer, chunk_size=1024):
    """Tokenize one user's comments in chunks of ``chunk_size`` and stack them into one BatchEncoding."""
    parts = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for start in range(0, len(string_comments), chunk_size):
        encoded = pretokenize(string_comments[start:start + chunk_size], return_tensors_type='pt', tokenizer=tokenizer)
        for key in parts:
            parts[key].append(encoded[key])
    # All chunks are padded to the same max_length, so they concatenate along dim 0.
    return BatchEncoding({
        key: torch.cat(chunks, dim=0) if chunks else torch.empty(0, 0, dtype=torch.long)
        for key, chunks in parts.items()
    })


def _write_users_jsonl(users, output_file_path, tokenizer):
    """Tokenize each user's comments and write one encoded JSON object per line to ``output_file_path``."""
    print('Tokenizing and writing user by user to JSONL file...')
    with open(output_file_path, 'w', encoding='utf-8') as outfile:
        for i, user_data_item in enumerate(users):
            print(f"Processing user {i+1}/{len(users)}")

            user_comments_list = user_data_item.get('comments')
            if user_comments_list is None:
                user_comments_list = []
            string_comments = [str(c) for c in user_comments_list if isinstance(c, (str, int, float))]

            if string_comments:
                tokenized = _tokenize_user_comments(string_comments, tokenizer)
            else:
                tokenized = pretokenize([], tokenizer=tokenizer)  # empty tokenization
            user_data_item['features']['comments_tokenized'] = tokenized

            # Raw comments are not stored; only their tokenization is.
            user_data_item.pop('comments', None)
            json.dump(encode_for_json(user_data_item), outfile)
            outfile.write('\n')

    print(f"Finished processing. Data streamed to {output_file_path}")






In [None]:
import json
import torch
from torch.utils.data import IterableDataset
from transformers.tokenization_utils_base import BatchEncoding # For your decode_from_json
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation
import torch.optim as optim
import os
import shutil # Keep for now, might be useful for other file ops if needed later

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# --- JSON serialization markers ---
# These keys must be identical to the ones the encoder wrote into the .jsonl
# files; `_json_object_hook_for_dataset` returns any dict that carries none of
# them unchanged, so a mismatch leaves tensors as plain dicts.
(
    _TENSOR_MARKER,
    _TENSOR_DTYPE_MARKER,
    _BATCH_ENCODING_MARKER,
    _BATCH_ENCODING_DATA_MARKER,
) = ("__tensor__", "__tensor_dtype__", "__batch_encoding__", "data")

def _convert_str_to_dtype(dtype_str: str) -> torch.dtype:
    """Map a serialized dtype name back to the corresponding ``torch.dtype``.

    Both the bare form (``"float32"``, ``"long"``) and the ``str(dtype)`` form
    (``"torch.float32"``) are accepted.

    Args:
        dtype_str: Dtype name as stored in the JSON payload.

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        ValueError: If ``dtype_str`` does not name a torch dtype.
    """
    prefix = "torch."
    dtype_name = dtype_str[len(prefix):] if dtype_str.startswith(prefix) else dtype_str
    candidate = getattr(torch, dtype_name, None)
    # `torch` exposes many non-dtype attributes (functions, submodules), so the
    # lookup result must be type-checked; `torch.dtype` itself cannot be
    # constructed from a string, so there is no other fallback to try.
    if not isinstance(candidate, torch.dtype):
        raise ValueError(f"Unknown torch dtype string: {dtype_str!r}")
    return candidate

def _json_object_hook_for_dataset(dct: dict) -> object:
    """Rebuild tensors and ``BatchEncoding`` objects while ``json.loads`` decodes a sample.

    ``json`` invokes this hook for every decoded JSON object, innermost first,
    so nested payloads may already have been converted when an outer object
    is seen.

    * Objects containing ``_TENSOR_MARKER`` become a ``torch.Tensor`` built from
      the value under ``_BATCH_ENCODING_DATA_MARKER``, with the dtype named by
      ``_TENSOR_DTYPE_MARKER`` (``float32`` when absent).
    * Objects containing ``_BATCH_ENCODING_MARKER`` become a ``BatchEncoding``;
      list-valued ``input_ids`` / ``token_type_ids`` / ``attention_mask``
      fields are converted to ``torch.long`` tensors. If a conversion fails
      (e.g. ragged rows) the field is kept as a list and an error is logged.
    * Any other object is returned unchanged.

    Args:
        dct: A freshly decoded JSON object.

    Returns:
        A tensor, a ``BatchEncoding``, or ``dct`` itself.
    """
    tokenizer_tensor_fields = ("input_ids", "token_type_ids", "attention_mask")

    if _TENSOR_MARKER in dct:
        dtype = _convert_str_to_dtype(dct.get(_TENSOR_DTYPE_MARKER, 'float32'))
        # NOTE(review): tensor values are read from the shared "data" key
        # (_BATCH_ENCODING_DATA_MARKER); confirm the encoder writes them there.
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=dtype)

    if _BATCH_ENCODING_MARKER in dct:
        reconstructed_data_for_be = {}
        for k, v_data in dct.get(_BATCH_ENCODING_DATA_MARKER, {}).items():
            if isinstance(v_data, list) and k in tokenizer_tensor_fields:
                try:
                    # All three tokenizer fields hold integer ids / flags.
                    reconstructed_data_for_be[k] = torch.tensor(v_data, dtype=torch.long)
                except Exception as e:
                    logger.error(f"Error converting field '{k}' in BatchEncoding to tensor: {e}. Keeping as list.")
                    reconstructed_data_for_be[k] = v_data
            else:
                reconstructed_data_for_be[k] = v_data
        return BatchEncoding(reconstructed_data_for_be)

    return dct

class JsonlIterableDataset(IterableDataset):
    """Stream training/evaluation samples from a JSON-Lines file.

    Each non-blank line is decoded with ``_json_object_hook_for_dataset`` (which
    rebuilds tensors / ``BatchEncoding`` objects) and passed to ``transform_fn``
    (``_default_transform`` unless overridden). Lines that fail to decode or
    transform are skipped with a logged warning instead of aborting the epoch.
    With several DataLoader workers, line ``i`` is handled by worker
    ``i % num_workers`` so each line is processed by exactly one worker.

    Args:
        file_path: Path to the ``.jsonl`` file.
        trait_names: Trait keys, in the order used for the label vector.
        n_comments_to_process: Number of comment slots per sample. Users with
            more comments are randomly subsampled; fewer are zero-padded.
        other_numerical_feature_names: Keys of ``sample['features']`` read as
            scalar floats.
        num_q_features_per_comment: Width of each comment's q-score vector.
        is_test_set: If True, samples are yielded without a label tensor.
        transform_fn: Optional replacement transform, called as
            ``transform_fn(sample, idx=None)``.
        num_samples: Value reported by ``__len__``. When None, the file's lines
            are counted (blank or malformed lines are included in that count).
    """

    def __init__(self, file_path, trait_names, n_comments_to_process,
                 other_numerical_feature_names, num_q_features_per_comment,
                 is_test_set=False, transform_fn=None, num_samples=None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.is_test_set = is_test_set
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn
        if num_samples is None:
            logger.info(f"num_samples for __len__ was not provided; counting samples in {file_path}...")
            self.num_samples = self._count_samples_in_file()
            logger.info(f"Counted {self.num_samples} samples in {self.file_path}.")
        else:
            self.num_samples = num_samples
        if self.num_samples == 0:
            logger.warning(f"Initialized JsonlIterableDataset for {self.file_path} with 0 samples. DataLoader will be empty.")

    def _count_samples_in_file(self):
        """Return the number of lines in ``self.file_path`` (0, with an error logged, if unreadable)."""
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                return sum(1 for _ in f)
        except FileNotFoundError:
            logger.error(f"File not found during initial sample count: {self.file_path}. Returning 0 samples.")
            return 0
        except Exception as e:
            logger.error(f"Error during initial sample count for {self.file_path}: {e}. Returning 0 samples.")
            return 0

    @staticmethod
    def _safe_float(value, default=0.0):
        """Convert ``value`` to a finite float, returning ``default`` if that is impossible.

        NaN/Inf are treated as invalid because they would propagate through the
        model into the loss.
        """
        try:
            result = float(value)
        except (ValueError, TypeError, OverflowError):
            return default
        return result if np.isfinite(result) else default

    def _process_line(self, line):
        """Decode and transform one JSONL line; return None if it cannot be used."""
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no sample.
            return None
        try:
            sample = json.loads(line, object_hook=_json_object_hook_for_dataset)
        except Exception as e:  # malformed JSON or an error raised by the object hook
            logger.warning("Skipping undecodable line in %s: %s", self.file_path, e)
            return None
        try:
            return self.transform_fn(sample, idx=None)
        except Exception as e:
            logger.warning("Skipping sample in %s: transform failed: %s", self.file_path, e)
            return None

    def __len__(self):
        return self.num_samples

    def _default_transform(self, sample, idx):
        """Turn one decoded sample into fixed-size model inputs.

        Expects ``sample['features']['comments_tokenized']`` to hold 2-D
        ``input_ids`` / ``attention_mask`` tensors with one row per comment.

        Returns:
            ``(input_ids, attention_mask, q_scores, comment_active_flags,
            other_numerical_features)``, followed by a ``labels`` tensor when
            ``is_test_set`` is False. Shapes: ``(n_comments, seq_len)`` long,
            ``(n_comments, seq_len)`` long, ``(n_comments, num_q)`` float,
            ``(n_comments,)`` bool, ``(num_other,)`` float, and
            ``(num_traits,)`` float clipped to ``[0, 1]``.
        """
        features = sample.get('features', {})
        tokenized_info = features.get('comments_tokenized', {})
        all_input_ids = tokenized_info['input_ids']
        all_attention_mask = tokenized_info['attention_mask']

        n_slots = self.n_comments_to_process
        num_actual_comments = all_input_ids.shape[0]
        final_input_ids = torch.zeros((n_slots, all_input_ids.shape[1]), dtype=torch.long)
        final_attention_mask = torch.zeros((n_slots, all_attention_mask.shape[1]), dtype=torch.long)
        comment_active_flags = torch.zeros(n_slots, dtype=torch.bool)

        # Randomly subsample when there are more comments than slots; otherwise
        # keep every comment in its original order.
        all_indices = list(range(num_actual_comments))
        if num_actual_comments > n_slots:
            selected = random.sample(all_indices, n_slots)
        else:
            selected = all_indices
        comments_to_fill = len(selected)

        if comments_to_fill:
            index = torch.tensor(selected, dtype=torch.long)
            final_input_ids[:comments_to_fill] = all_input_ids[index]
            final_attention_mask[:comments_to_fill] = all_attention_mask[index]
            comment_active_flags[:comments_to_fill] = True

        # Per-comment q-scores: truncate/zero-pad to a fixed width; comments
        # without q-scores get an all-zero row.
        raw_q_scores = features.get('q_scores', [])
        width = self.num_q_features_per_comment
        final_q_scores = torch.zeros((n_slots, width), dtype=torch.float)
        selected_raw_q_scores = []
        for original_comment_idx in selected:
            if original_comment_idx < len(raw_q_scores):
                qs_for_comment = list(raw_q_scores[original_comment_idx][:width])
                selected_raw_q_scores.append(qs_for_comment + [0.0] * (width - len(qs_for_comment)))
            else:
                selected_raw_q_scores.append([0.0] * width)

        if selected_raw_q_scores:
            try:
                final_q_scores[:comments_to_fill] = torch.tensor(selected_raw_q_scores, dtype=torch.float)
            except Exception as e:
                logger.error(f"Error converting selected_raw_q_scores to tensor: {e}. Data: {selected_raw_q_scores}")

        other_numerical_features_tensor = torch.tensor(
            [self._safe_float(features.get(fname, 0.0)) for fname in self.other_numerical_feature_names],
            dtype=torch.float,
        )

        if self.is_test_set:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

        labels_dict = sample['labels']
        regression_labels = []
        for trait_key in self.trait_names_ordered:
            # Prefer the title-cased key, fall back to the key as given.
            label_val = labels_dict.get(trait_key.title(), labels_dict.get(trait_key, 0.0))
            regression_labels.append(min(max(self._safe_float(label_val), 0.0), 1.0))
        labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
        return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        try:
            file_handle = open(self.file_path, 'r', encoding='utf-8')
        except FileNotFoundError:
            logger.error(f"File not found in __iter__: {self.file_path}. Yielding nothing.")
            return

        # `with` closes the file even when the consumer stops iterating early
        # (e.g. the generator is closed or garbage-collected mid-epoch).
        with file_handle:
            for line_idx, line in enumerate(file_handle):
                if line_idx % num_workers != worker_id:
                    continue
                processed_item = self._process_line(line)
                if processed_item is not None:
                    yield processed_item


class PersonalityModelV3(nn.Module):
    """Multi-comment BERT regressor producing one score per personality trait.

    Pipeline:

    1. Each of the ``n_comments_to_process`` comments is encoded by BERT. Its
       token embeddings are mean-pooled under the attention mask, and then
       averaged over the last ``num_bert_layers_to_pool`` hidden layers.
    2. The pooled embedding is concatenated with that comment's q-scores.
    3. Additive attention (``tanh`` projection + linear scorer) combines the
       comments into one vector; inactive (padded) slots are masked out.
    4. Optionally, an embedding of extra numerical features is concatenated.
    5. Optional dense layers are applied, followed by one linear head per
       trait and a sigmoid, giving outputs in ``(0, 1)``.
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3,
                 num_other_numerical_features: int = 0,
                 numerical_embedding_dim: int = 64,
                 num_additional_dense_layers: int = 0,
                 additional_dense_hidden_dim: int = 256,
                 additional_layers_dropout_rate: float = 0.3
                ):
        super().__init__()
        # Hidden states of every layer are needed for multi-layer pooling.
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        # Per-comment feature vector = pooled BERT embedding ++ q-scores.
        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        # Applied before the trait heads only when there are no additional dense layers.
        self.final_dropout_layer = nn.Dropout(dropout_rate)

        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        # Attention aggregation keeps the per-comment feature width.
        combined_input_dim_for_block = comment_feature_dim
        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
            combined_input_dim_for_block += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        self.num_additional_dense_layers = num_additional_dense_layers
        self.additional_dense_block = nn.Sequential()
        input_dim_for_regressors = combined_input_dim_for_block
        if self.num_additional_dense_layers > 0:
            logger.info(f"Model using {self.num_additional_dense_layers} additional dense layers with hidden_dim {additional_dense_hidden_dim} and dropout {additional_layers_dropout_rate}")
            for i in range(self.num_additional_dense_layers):
                # Submodule names are part of the state_dict keys; keep them stable.
                self.additional_dense_block.add_module(f"add_dense_{i}_linear", nn.Linear(input_dim_for_regressors, additional_dense_hidden_dim))
                self.additional_dense_block.add_module(f"add_dense_{i}_relu", nn.ReLU())
                self.additional_dense_block.add_module(f"add_dense_{i}_dropout", nn.Dropout(additional_layers_dropout_rate))
                input_dim_for_regressors = additional_dense_hidden_dim
        else:
            logger.info("Model not using additional dense layers. Will use final_dropout_layer if dropout_rate > 0.")

        # One independent linear head per trait.
        self.trait_regressors = nn.ModuleList(
            nn.Linear(input_dim_for_regressors, 1) for _ in range(num_traits)
        )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked mean over tokens, averaged over the last ``num_bert_layers_to_pool`` layers.

        Args:
            all_hidden_states: Per-layer hidden states, each ``(N, seq_len, hidden)``.
            attention_mask: ``(N, seq_len)`` mask; 1 marks a real token.

        Returns:
            ``(N, hidden)`` embeddings. Rows whose mask is all zero (padded
            comment slots) pool to zeros instead of NaN.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # Use the hidden-state dtype so the token count is floating point: an
        # integer count clamped with a float minimum could stay 0 and turn an
        # all-padding row into 0/0 = NaN.
        mask = attention_mask.unsqueeze(-1).to(layers_to_pool[0].dtype)
        token_counts = mask.sum(dim=1).clamp(min=1e-9)  # (N, 1), same for every layer
        per_layer_means = [(layer_hidden_states * mask).sum(dim=1) / token_counts
                           for layer_hidden_states in layers_to_pool]
        return torch.stack(per_layer_means, dim=0).mean(dim=0)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                q_scores: torch.Tensor,
                comment_active_mask: torch.Tensor,
                other_numerical_features: Optional[torch.Tensor] = None
               ):
        """Predict sigmoid-squashed trait scores.

        Args:
            input_ids: ``(batch, n_comments, seq_len)`` token ids.
            attention_mask: ``(batch, n_comments, seq_len)`` token mask.
            q_scores: ``(batch, n_comments, num_q_features_per_comment)``.
            comment_active_mask: ``(batch, n_comments)`` bool; False marks a
                padded slot. May be None to attend over every slot.
            other_numerical_features: ``(batch, num_other_numerical_features)``;
                required when the model was built to use them.

        Returns:
            ``(batch, num_traits)`` tensor with values in ``(0, 1)``.

        Raises:
            ValueError: If numerical features are required but missing or
                have the wrong width.
        """
        batch_size = input_ids.shape[0]

        # Encode all comments of all users in one BERT pass.
        input_ids_flat = input_ids.view(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.view(-1, attention_mask.shape[-1])
        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings = self._pool_bert_layers(
            bert_outputs.hidden_states, attention_mask_flat
        ).view(batch_size, self.n_comments_to_process, -1)

        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)

        # Additive attention scores, one per comment slot.
        scores = self.attention_v(torch.tanh(self.attention_w(comment_features_with_q))).squeeze(-1)
        if comment_active_mask is not None:
            scores = scores.masked_fill(~comment_active_mask, -1e9)
            # Padded slots get zero attention weight, but 0 * NaN is still NaN;
            # zero their features so nothing non-finite can leak into the sum.
            comment_features_with_q = comment_features_with_q.masked_fill(
                ~comment_active_mask.unsqueeze(-1), 0.0
            )

        attention_weights = F.softmax(scores, dim=1).unsqueeze(-1)
        aggregated_comment_features = torch.sum(attention_weights * comment_features_with_q, dim=1)

        final_features_for_processing = aggregated_comment_features
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_processing = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        if self.num_additional_dense_layers > 0:
            features_for_trait_heads = self.additional_dense_block(final_features_for_processing)
        else:
            features_for_trait_heads = self.final_dropout_layer(final_features_for_processing)

        all_trait_outputs_raw = torch.cat(
            [regressor_head(features_for_trait_heads) for regressor_head in self.trait_regressors], dim=1
        )
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Return ``outputs`` unchanged; ``forward`` already applies the sigmoid."""
        return outputs

# --- Optuna objective function (also tracks and saves the study-wide best model) ---
def objective(trial: optuna.trial.Trial,
              train_file_path: str,
              val_file_path: str,
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int,
              overall_best_weights_filepath: str
             ):
    """Train one sampled ``PersonalityModelV3`` configuration and return its best validation loss.

    Hyperparameters are drawn from ``trial``. The model is trained for up to
    ``num_epochs_per_trial`` epochs with per-trial early stopping and Optuna
    pruning. After every epoch the mean validation L1 loss is compared with the
    study-wide best kept in ``trial.study.user_attrs['overall_best_val_loss']``.
    When it improves, the model's state dict (moved to CPU) is written to
    ``overall_best_weights_filepath`` and the study user attrs are updated.

    Args:
        trial: The Optuna trial being evaluated.
        train_file_path: JSONL file with training samples.
        val_file_path: JSONL file with validation samples.
        global_config: Must contain ``'TRAIT_NAMES'``, ``'TRAIT_NAMES_ORDERED'``
            and ``'BERT_MODEL_NAME'``. Optionally contains
            ``'OTHER_NUMERICAL_FEATURE_NAMES'``, ``'NUM_Q_FEATURES_PER_COMMENT'``,
            ``'MAX_COMMENTS_TO_PROCESS_PHYSICAL'``, ``'NUM_TRAIN_SAMPLES'`` and
            ``'NUM_VAL_SAMPLES'``.
        device: Device to train on.
        num_epochs_per_trial: Maximum number of epochs.
        overall_best_weights_filepath: Where to save the study-wide best weights.

    Returns:
        The lowest mean validation L1 loss reached in this trial. Returns
        ``inf`` if the data loaders could not be built or nothing is trainable.

    Raises:
        optuna.TrialPruned: When Optuna's pruner stops the trial early.
    """
    logger.info(f"Starting Optuna Trial {trial.number}")

    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    # --- Search space (suggestion order kept stable across runs) ---
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    n_comments_trial = trial.suggest_int("n_comments_to_process", 3, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16, 32])

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64, 128])

    num_additional_dense_layers_trial = trial.suggest_int("num_additional_dense_layers", 0, 3)
    additional_dense_hidden_dim_trial = 0
    additional_layers_dropout_rate_trial = 0.0
    if num_additional_dense_layers_trial > 0:
        additional_dense_hidden_dim_trial = trial.suggest_categorical("additional_dense_hidden_dim", [128, 256, 512])
        additional_layers_dropout_rate_trial = trial.suggest_float("additional_layers_dropout_rate", 0.1, 0.5)

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")

    pin_memory = device.type == 'cuda'

    def _make_loader(file_path, num_samples):
        """Build a streaming JSONL dataset and its DataLoader for this trial."""
        dataset = JsonlIterableDataset(
            file_path=file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False,
            num_samples=num_samples,
        )
        return DataLoader(dataset, batch_size=batch_size_trial, num_workers=0,
                          pin_memory=pin_memory, persistent_workers=False)

    try:
        train_loader_trial = _make_loader(train_file_path, global_config.get('NUM_TRAIN_SAMPLES'))
        val_loader_trial = _make_loader(val_file_path, global_config.get('NUM_VAL_SAMPLES'))
    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    model = PersonalityModelV3(
        bert_model_name=global_config['BERT_MODEL_NAME'],
        num_traits=num_traits,
        n_comments_to_process=n_comments_trial,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim,
        num_bert_layers_to_pool=num_bert_layers_to_pool,
        num_q_features_per_comment=num_q_features_per_comment_trial,
        num_other_numerical_features=num_other_numerical_features_trial,
        numerical_embedding_dim=other_numerical_embedding_dim_trial,
        num_additional_dense_layers=num_additional_dense_layers_trial,
        additional_dense_hidden_dim=additional_dense_hidden_dim_trial,
        additional_layers_dropout_rate=additional_layers_dropout_rate_trial
    ).to(device)
    optimizer = None
    scheduler = None

    try:
        # --- Freeze BERT. If any layers are to be tuned, also unfreeze the
        # embeddings, the top N encoder layers and the pooler. ---
        for param in model.bert.parameters():
            param.requires_grad = False
        if num_unfrozen_bert_layers > 0:
            if hasattr(model.bert, 'embeddings'):
                for param in model.bert.embeddings.parameters():
                    param.requires_grad = True
            num_hidden_layers = model.bert.config.num_hidden_layers
            first_unfrozen_layer = num_hidden_layers - min(num_unfrozen_bert_layers, num_hidden_layers)
            for encoder_layer in model.bert.encoder.layer[first_unfrozen_layer:]:
                for param in encoder_layer.parameters():
                    param.requires_grad = True
            if getattr(model.bert, 'pooler', None) is not None:
                for param in model.bert.pooler.parameters():
                    param.requires_grad = True

        logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                     f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

        # --- Optimizer: trainable BERT params use lr_bert with a fixed 0.01
        # weight decay; the head uses lr_head and the tuned weight decay. ---
        optimizer_grouped_parameters = []
        bert_params_to_tune = [p for p in model.bert.parameters() if p.requires_grad]
        if bert_params_to_tune and lr_bert > 0:
            optimizer_grouped_parameters.append({"params": bert_params_to_tune, "lr": lr_bert, "weight_decay": 0.01})

        head_modules = [model.attention_w, model.attention_v]
        if model.uses_other_numerical_features:
            head_modules.append(model.other_numerical_processor)
        if model.num_additional_dense_layers > 0:
            head_modules.append(model.additional_dense_block)
        head_modules.extend(model.trait_regressors)
        head_params = [p for module in head_modules for p in module.parameters()]
        if head_params:
            optimizer_grouped_parameters.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})

        # Groups are only appended when non-empty, so an empty list means nothing is trainable.
        if not optimizer_grouped_parameters:
            logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
            return float('inf')

        optimizer = optim.AdamW(optimizer_grouped_parameters)

        if scheduler_type == "linear_warmup":
            num_train_samples = global_config.get('NUM_TRAIN_SAMPLES') or 0
            if num_train_samples > 0:
                num_batches_per_epoch = (num_train_samples + batch_size_trial - 1) // batch_size_trial
                num_training_steps = num_batches_per_epoch * num_epochs_per_trial
                num_warmup_steps = int(num_training_steps * warmup_ratio)
                if num_warmup_steps > 0 and num_training_steps > 0:
                    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
                else:
                    logger.warning(f"Trial {trial.number}: Calculated num_warmup_steps or num_training_steps is zero. Scheduler not created. Warmup: {num_warmup_steps}, Training: {num_training_steps}")
            else:
                logger.warning(f"Trial {trial.number}: NUM_TRAIN_SAMPLES not available or zero in global_config. Cannot create linear_warmup scheduler.")

        loss_fn = nn.L1Loss().to(device)
        best_val_loss_this_trial = float('inf')  # for early stopping within this trial
        patience_counter = 0

        for epoch in range(num_epochs_per_trial):
            # --- Training ---
            model.train()
            total_train_loss = 0.0
            train_batches_processed = 0
            for batch_idx, batch_tuple in enumerate(train_loader_trial):
                input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                optimizer.zero_grad()
                predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                current_batch_loss = loss_fn(predicted_scores, labels_reg)
                if not torch.isfinite(current_batch_loss):
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected. Skipping batch.")
                    torch.cuda.empty_cache()
                    continue
                current_batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                total_train_loss += current_batch_loss.item()
                train_batches_processed += 1

            avg_train_loss = total_train_loss / train_batches_processed if train_batches_processed > 0 else float('inf')
            logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

            # --- Validation ---
            model.eval()
            current_epoch_val_loss = 0.0
            val_batches_processed = 0
            all_val_preds_epoch = []
            all_val_labels_epoch = []
            with torch.no_grad():
                for batch_tuple in val_loader_trial:
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                    if input_ids.numel() == 0:
                        continue
                    predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    if predicted_scores.numel() == 0:
                        continue
                    current_epoch_val_loss += loss_fn(predicted_scores, labels_reg).item()
                    all_val_preds_epoch.append(predicted_scores.cpu())
                    all_val_labels_epoch.append(labels_reg.cpu())
                    val_batches_processed += 1

            # Mean of per-batch L1 losses; val_mae below is the L1 over all samples at once.
            avg_val_loss_epoch = current_epoch_val_loss / val_batches_processed if val_batches_processed > 0 else float('inf')

            val_mae = -1.0
            if all_val_labels_epoch and all_val_preds_epoch:
                all_val_labels_cat = torch.cat(all_val_labels_epoch, dim=0)
                all_val_preds_cat = torch.cat(all_val_preds_epoch, dim=0)
                if all_val_labels_cat.numel() > 0 and all_val_preds_cat.numel() > 0:
                    val_mae = F.l1_loss(all_val_preds_cat, all_val_labels_cat).item()

            logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (L1): {avg_val_loss_epoch:.4f}, Val MAE: {val_mae:.4f}")

            # --- Early-stopping bookkeeping for this trial ---
            if avg_val_loss_epoch < best_val_loss_this_trial:
                best_val_loss_this_trial = avg_val_loss_epoch
                patience_counter = 0
                logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss for this trial: {best_val_loss_this_trial:.4f}")
            else:
                patience_counter += 1

            # --- Study-wide best: persist weights whenever any epoch of any trial improves on it ---
            if hasattr(trial, 'study') and trial.study is not None:
                current_overall_best_loss = trial.study.user_attrs.get("overall_best_val_loss", float('inf'))
                if avg_val_loss_epoch < current_overall_best_loss:
                    logger.info(f"Trial {trial.number}, Epoch {epoch+1}: New OVERALL best val_loss: {avg_val_loss_epoch:.4f} (Previous overall best: {current_overall_best_loss:.4f}). Saving model.")
                    trial.study.set_user_attr("overall_best_val_loss", avg_val_loss_epoch)
                    trial.study.set_user_attr("overall_best_trial_number", trial.number)
                    trial.study.set_user_attr("overall_best_epoch", epoch + 1)
                    # Save on CPU so the file loads on machines without a GPU.
                    model_state_dict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
                    torch.save(model_state_dict_cpu, overall_best_weights_filepath)
                    logger.info(f"Trial {trial.number}: Saved new OVERALL best model weights to {overall_best_weights_filepath}")
            else:
                logger.warning(f"Trial {trial.number}: Cannot access study.user_attrs to check/update overall best model.")

            trial.report(avg_val_loss_epoch, epoch)
            if trial.should_prune():
                logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
                # Raising TrialPruned records the trial as PRUNED; returning a
                # value would record it as COMPLETE and skew the study.
                raise optuna.TrialPruned()

            if patience_counter >= patience_early_stopping:
                logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
                break

        logger.info(f"Trial {trial.number} finished. Best Val Loss (L1) for this trial: {best_val_loss_this_trial:.4f}")
        return best_val_loss_this_trial
    finally:
        # Free model/optimizer memory before the next trial, including when the
        # trial is pruned, skipped, or raises.
        del model, train_loader_trial, val_loader_trial, optimizer, scheduler
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Assumes JsonlIterableDataset, PersonalityModelV3 and objective from the previous cell are already defined.

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# JSON-Lines data files (one sample per line).
TRAIN_DATA_FILE = "train_data.jsonl"
VAL_DATA_FILE = "val_data.jsonl"
TEST_DATA_FILE = "test_data.jsonl"

# Trait keys, in the order used for the label vector and the per-trait heads.
_trait_names_ordered_config = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Emotional stability', 'Humility']
# Keys in sample['features'] that are read as scalar numerical features
# (the dataset substitutes 0.0 for missing or non-numeric values).
_other_numerical_features_config = [
    'mean_words_per_comment', 'mean_sents_per_comment',
    'median_words_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
    'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound'
]

GLOBAL_CONFIG = {
    'BERT_MODEL_NAME': "bert-base-uncased",
    'TRAIT_NAMES_ORDERED': _trait_names_ordered_config,  # label order used by the datasets
    'TRAIT_NAMES': _trait_names_ordered_config,  # its length sets the number of regression heads
    'MAX_COMMENTS_TO_PROCESS_PHYSICAL': 6,  # upper bound of the n_comments_to_process search range
    'NUM_Q_FEATURES_PER_COMMENT': 3,  # width of each comment's q-score vector
    'OTHER_NUMERICAL_FEATURE_NAMES': _other_numerical_features_config,
    # NOTE(review): not read by objective(); presumably the max length used when tokenizing — verify.
    'TOKENIZER_MAX_LENGTH': 256
}

# Upper limit per trial; early stopping or pruning may end a trial sooner.
NUM_EPOCHS_PER_TRIAL_OPTUNA = 15
N_OPTUNA_TRIALS = 20

def count_lines_in_file(filepath):
    """Count the non-blank lines (i.e. samples) in a JSON-Lines file.

    Blank or whitespace-only lines are not counted. The dataset's line parser
    (``json.loads``) cannot decode them, so they never produce a sample, and
    counting them would inflate any length/step bookkeeping derived from this
    number. The file is read in binary mode, so counting does no text decoding
    and cannot fail on undecodable bytes.

    Args:
        filepath: Path of the file to count.

    Returns:
        The number of non-blank lines. Returns 0 (after logging an error) if
        the file is missing or cannot be read.
    """
    try:
        with open(filepath, 'rb') as f:
            return sum(1 for line in f if line.strip())
    except FileNotFoundError:
        logger.error(f"File not found for line counting: {filepath}. Returning 0.")
        return 0
    except OSError as e:
        logger.error(f"Error counting lines in {filepath}: {e}. Returning 0.")
        return 0

# Sample counts are stored in GLOBAL_CONFIG for use by the training code.
NUM_TRAIN_SAMPLES = count_lines_in_file(TRAIN_DATA_FILE)
if NUM_TRAIN_SAMPLES == 0:
    logger.error(f"Training file {TRAIN_DATA_FILE} is empty or not found. Exiting.")
    # Raise SystemExit with a non-zero status instead of calling `exit()`:
    # `exit` is a convenience injected by `site`/IPython (not guaranteed to
    # exist), and `exit()` with no argument reports success (status 0).
    raise SystemExit(1)
GLOBAL_CONFIG['NUM_TRAIN_SAMPLES'] = NUM_TRAIN_SAMPLES
logger.info(f"Number of training samples: {NUM_TRAIN_SAMPLES}")

# A missing/empty validation file is tolerated (warning only).
NUM_VAL_SAMPLES = count_lines_in_file(VAL_DATA_FILE)
if NUM_VAL_SAMPLES == 0:
    logger.warning(f"Validation file {VAL_DATA_FILE} is empty or not found. Validation might not work as expected.")
GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = NUM_VAL_SAMPLES
logger.info(f"Number of validation samples: {NUM_VAL_SAMPLES}")

logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

# The study is persisted in SQLite and reopened if it exists, so re-running
# this cell resumes the same study instead of starting over.
study_name = "personality_regression_v8_overall_best"
storage_name = f"sqlite:///{study_name}.db"

# Output artifacts named after the study.
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"  # single file holding the overall best model's weights

study = optuna.create_study(
    study_name=study_name,
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1),
    storage=storage_name,
    load_if_exists=True,
)

# The best validation loss across *all* trials lives in the study's user
# attrs; it is seeded with +inf only when the study does not have it yet.
if "overall_best_val_loss" in study.user_attrs:
    logger.info(f"Resuming study. Current 'overall_best_val_loss' in study user_attrs: {study.user_attrs['overall_best_val_loss']:.4f}")
else:
    study.set_user_attr("overall_best_val_loss", float('inf'))
    logger.info("Initialized 'overall_best_val_loss' in study user_attrs to infinity.")


if study.trials:
    logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

def _objective_for_trial(trial):
    """Run ``objective`` for one Optuna trial with this notebook's data paths and config."""
    # NOTE(review): the `objective` defined in a later cell of this notebook
    # (signature ending in `num_epochs_per_trial: int = 10`) has no
    # `overall_best_weights_filepath` parameter; make sure the definition in
    # scope when this cell runs accepts it.
    return objective(
        trial,
        TRAIN_DATA_FILE,
        VAL_DATA_FILE,
        GLOBAL_CONFIG,
        DEVICE,
        num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA,
        overall_best_weights_filepath=BEST_WEIGHTS_FILENAME,  # where the overall-best weights are written
    )


try:
    study.optimize(_objective_for_trial, n_trials=N_OPTUNA_TRIALS, gc_after_trial=True)
except Exception:
    # Log with traceback, but keep going so the result reporting below still runs.
    logger.exception("An error occurred during the Optuna study.")

logger.info("\n--- Optuna Study Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

# Optuna's record of the best trial; stays None unless at least one trial completed.
best_trial_overall_from_study_obj = None

if not study.trials:
    logger.warning("No trials were completed in the study.")
else:
    try:
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if completed_trials:
            # study.best_trial raises (it never returns None) when no trial has
            # completed; the check above rules that out, so no None-check is needed.
            best_trial_overall_from_study_obj = study.best_trial

            logger.info("Optuna's Best Trial (based on reported values):")
            logger.info(f"  Number: {best_trial_overall_from_study_obj.number}")
            # This is the value *returned* by the objective for that trial.
            logger.info(f"  Value (Validation Loss - MSE): {best_trial_overall_from_study_obj.value:.4f}")
            logger.info("  Best Params (from this trial): ")
            for key, value in best_trial_overall_from_study_obj.params.items():
                logger.info(f"    {key}: {value}")

            # Save the hyperparameters of Optuna's identified best trial.
            with open(BEST_PARAMS_FILENAME, 'w') as f:
                json.dump(best_trial_overall_from_study_obj.params, f, indent=4)
            logger.info(f"Best hyperparameters (from trial {best_trial_overall_from_study_obj.number}) saved to {BEST_PARAMS_FILENAME}")

            # The weights file is expected to have been written during the study
            # (its path is handed to the objective); here we only report the
            # provenance recorded in study.user_attrs, if any.
            overall_best_val_loss_attr = study.user_attrs.get("overall_best_val_loss", float('inf'))
            overall_best_trial_attr = study.user_attrs.get("overall_best_trial_number", "N/A")
            overall_best_epoch_attr = study.user_attrs.get("overall_best_epoch", "N/A")

            logger.info(f"Overall best model weights are expected in: {BEST_WEIGHTS_FILENAME}")
            if os.path.exists(BEST_WEIGHTS_FILENAME):
                logger.info(f"  This model achieved a validation loss of: {overall_best_val_loss_attr:.4f} (recorded in study.user_attrs)")
                logger.info(f"  It was saved from Trial: {overall_best_trial_attr}, Epoch: {overall_best_epoch_attr}")
            else:
                logger.warning(f"  Expected overall best weights file {BEST_WEIGHTS_FILENAME} was NOT found. "
                               "This might happen if no trial improved upon the initial 'inf' loss, "
                               "or if there was an issue during saving.")
        else:  # No completed trials
            logger.warning("No trials completed successfully to determine the best trial. Cannot save parameters or confirm weights.")

        # Full per-trial table, written regardless of whether any trial completed.
        study_df = study.trials_dataframe(attrs=('number', 'value', 'params', 'state', 'user_attrs'))
        study_df.to_csv(f"{study_name}_results.csv", index=False)
        logger.info(f"Optuna study results saved to {study_name}_results.csv")

    except Exception as e:
        logger.error(f"Could not process or save Optuna study results: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
# Runs only when the test file AND both best-model artifacts (params JSON and
# weights) exist; the elif branches below just report what is missing.
if os.path.exists(TEST_DATA_FILE) and os.path.exists(BEST_PARAMS_FILENAME) and os.path.exists(BEST_WEIGHTS_FILENAME):
    logger.info("\n--- Predicting on Test Data using overall best saved model and params ---")
    try:
        with open(BEST_PARAMS_FILENAME, 'r') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {BEST_PARAMS_FILENAME}")

        # Rebuild the architecture from the saved hyperparameters; keys missing
        # from the params file fall back to the literal defaults given here.
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL']),
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2),
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            numerical_embedding_dim=loaded_best_params.get("other_numerical_embedding_dim", 0) if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0,
            num_additional_dense_layers=loaded_best_params.get("num_additional_dense_layers", 0),
            additional_dense_hidden_dim=loaded_best_params.get("additional_dense_hidden_dim", 256),
            additional_layers_dropout_rate=loaded_best_params.get("additional_layers_dropout_rate", 0.3)
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # map_location=DEVICE loads every tensor straight onto the device in use,
        # whichever device (CPU or any GPU) the checkpoint was saved from, so no
        # separate CUDA/CPU branch is needed.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        loaded_state_dict = torch.load(BEST_WEIGHTS_FILENAME, map_location=DEVICE)

        test_model.load_state_dict(loaded_state_dict)
        logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        test_model.eval()  # inference mode for dropout layers

        NUM_TEST_SAMPLES = count_lines_in_file(TEST_DATA_FILE)
        if NUM_TEST_SAMPLES == 0:
            logger.warning(f"Test file {TEST_DATA_FILE} is empty or not found. No test predictions will be made.")
        else:
            test_dataset = JsonlIterableDataset(
                file_path=TEST_DATA_FILE,
                trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
                n_comments_to_process=loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL']),
                other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
                num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
                is_test_set=True,
                num_samples=NUM_TEST_SAMPLES
            )
            test_batch_size = loaded_best_params.get("batch_size", 16)
            test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)

            all_test_predictions = []
            with torch.no_grad():
                for batch_tuple in test_loader:
                    # Test batches carry no labels: exactly the five model inputs.
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats = [b.to(DEVICE) for b in batch_tuple]
                    predicted_scores = test_model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    all_test_predictions.append(predicted_scores.cpu().numpy())

            if all_test_predictions:
                final_test_predictions = np.concatenate(all_test_predictions, axis=0)
                logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
                for i in range(min(5, len(final_test_predictions))):
                    pred_dict = {trait: round(score.item(), 4) for trait, score in zip(GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'], final_test_predictions[i])}
                    logger.info(f"Test Sample Index {i} Predictions: {pred_dict}")
            else:
                logger.warning("No predictions generated for the test set (all_test_predictions list is empty).")

    except FileNotFoundError as e:
        logger.warning(f"Required file for test prediction not found: {e}. Skipping test prediction.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)
elif not os.path.exists(TEST_DATA_FILE):
    logger.info(f"Test data file '{TEST_DATA_FILE}' not found. Skipping test prediction example.")
elif not os.path.exists(BEST_PARAMS_FILENAME) or not os.path.exists(BEST_WEIGHTS_FILENAME):
    logger.warning(f"Best parameters file ({BEST_PARAMS_FILENAME}) or weights file ({BEST_WEIGHTS_FILENAME}) not found. Skipping test prediction.")

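# TODO(review): save the test predictions — final_test_predictions is currently only logged, never written to disk.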

In [None]:
import json
import torch
from torch.utils.data import IterableDataset
from transformers.tokenization_utils_base import BatchEncoding # For your decode_from_json
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation
import torch.optim as optim
import os
import shutil

# --- Constants for JSON (ensure these match what you used when saving) ---
# Marker keys that the JSON object hook looks for when rebuilding objects.
_TENSOR_MARKER = "__tensor__"
_TENSOR_DTYPE_MARKER = "__tensor_dtype__"
_BATCH_ENCODING_MARKER = "__batch_encoding__"
_BATCH_ENCODING_DATA_MARKER = "data"  # payload key for both tensor and BatchEncoding dicts; must match the writer

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def _convert_str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a dtype name such as ``"torch.int64"`` or ``"float32"`` to a ``torch.dtype``.

    Args:
        dtype_str: Name of a dtype attribute of the ``torch`` module, with or
            without a leading ``"torch."`` prefix (``str(torch.int64)`` gives
            the prefixed form).

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If the name does not resolve to a ``torch.dtype``.
    """
    prefix = "torch."
    dtype_name = dtype_str[len(prefix):] if dtype_str.startswith(prefix) else dtype_str
    dtype = getattr(torch, dtype_name, None)
    # getattr alone would accept any torch attribute (e.g. "tensor"). torch.dtype
    # cannot be constructed from a string either, so a plain lookup plus an
    # isinstance check is the only correct path.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unrecognized torch dtype string: {dtype_str!r}")
    return dtype

def _json_object_hook_for_dataset(dct: dict) -> object:
    """Object hook for ``json.loads`` that rebuilds tensors and ``BatchEncoding`` objects.

    Handled inputs:

    * ``dct`` contains ``_TENSOR_MARKER``: returns
      ``torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER])`` with the dtype named
      by ``dct[_TENSOR_DTYPE_MARKER]`` (default ``'float32'``).
    * ``dct`` contains ``_BATCH_ENCODING_MARKER``: returns a ``BatchEncoding``
      built from ``dct[_BATCH_ENCODING_DATA_MARKER]``. List-valued
      ``input_ids`` / ``token_type_ids`` / ``attention_mask`` entries become
      ``torch.long`` tensors; if that conversion fails, the error is logged
      and the list is kept.
    * Any other dict is returned unchanged.
    """
    if _TENSOR_MARKER in dct:
        dtype = _convert_str_to_dtype(dct.get(_TENSOR_DTYPE_MARKER, 'float32'))
        # Tensor payloads use the same "data" key as BatchEncoding payloads.
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=dtype)

    if _BATCH_ENCODING_MARKER in dct:
        reconstructed_data_for_be = {}
        batch_encoding_payload = dct.get(_BATCH_ENCODING_DATA_MARKER, {})
        for k, v_data in batch_encoding_payload.items():
            if isinstance(v_data, list) and k in ("input_ids", "token_type_ids", "attention_mask"):
                try:
                    # All three token-level fields are stored as long tensors.
                    reconstructed_data_for_be[k] = torch.tensor(v_data, dtype=torch.long)
                except Exception as e:
                    logger.error(f"Error converting field '{k}' in BatchEncoding to tensor: {e}. Keeping as list.")
                    reconstructed_data_for_be[k] = v_data
            else:
                reconstructed_data_for_be[k] = v_data
        return BatchEncoding(reconstructed_data_for_be)

    return dct

class JsonlIterableDataset(IterableDataset):
    """Stream samples from a JSON-Lines file, one JSON object per line.

    Each line is decoded with ``_json_object_hook_for_dataset`` and then passed
    to ``transform_fn`` (``_default_transform`` by default). Lines that are
    blank, fail to decode, or fail the transform are skipped; decode and
    transform failures are logged.

    Inside a multi-worker ``DataLoader``, worker ``k`` of ``N`` handles the
    lines whose 0-based line index ``i`` satisfies ``i % N == k``.

    Args:
        file_path: Path to the ``.jsonl`` file.
        trait_names: Label keys, in the order the label tensor is built.
        n_comments_to_process: Number of comment slots per sample. Extra
            comments are randomly subsampled; missing ones are zero-padded.
        other_numerical_feature_names: Keys under ``sample['features']`` read
            as scalar numeric features.
        num_q_features_per_comment: Number of q-score values kept per comment
            (truncated or zero-padded).
        is_test_set: If True, no labels are read or returned.
        transform_fn: Optional replacement for ``_default_transform``, called as
            ``transform_fn(sample, idx=None)``. Items it returns as None are
            dropped.
        num_samples: Value reported by ``__len__``. If None, it is counted from
            the file (non-blank lines).
    """

    def __init__(self, file_path, trait_names, n_comments_to_process,
                 other_numerical_feature_names, num_q_features_per_comment,
                 is_test_set=False, transform_fn=None, num_samples=None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.is_test_set = is_test_set
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn
        if num_samples is None:
            logger.info(f'num_samples was not provided; counting samples in {file_path} for __len__...')
            self.num_samples = self._count_samples_in_file()
            logger.info(f"Counted {self.num_samples} samples in {self.file_path}.")
        else:
            self.num_samples = num_samples
        if self.num_samples == 0:
            logger.warning(f"Initialized JsonlIterableDataset for {self.file_path} with 0 samples. DataLoader will be empty.")

    def _count_samples_in_file(self):
        """Return the number of non-blank lines in ``self.file_path``, or 0 if it cannot be read.

        Blank lines are excluded because ``_process_line`` skips them, so they
        never yield a sample. The file is read in binary mode, so counting
        cannot fail on undecodable bytes.
        """
        try:
            with open(self.file_path, 'rb') as f:
                return sum(1 for line in f if line.strip())
        except FileNotFoundError:
            logger.error(f"File not found during initial sample count: {self.file_path}. Returning 0 samples.")
            return 0
        except OSError as e:
            logger.error(f"Error during initial sample count for {self.file_path}: {e}. Returning 0 samples.")
            return 0

    def _process_line(self, line):
        """Decode and transform one raw line; return None if the line yields no sample.

        Failures are logged rather than raised, so one bad record does not
        abort a whole epoch, but it no longer disappears silently either.
        """
        if not line.strip():
            return None  # blank lines (e.g. a trailing newline) carry no sample
        try:
            sample = json.loads(line, object_hook=_json_object_hook_for_dataset)
        except Exception as e:  # JSONDecodeError, or an error raised inside the object hook
            logger.warning(f"Skipping undecodable line in {self.file_path}: {e!r}")
            return None
        try:
            return self.transform_fn(sample, idx=None)
        except Exception as e:
            logger.warning(f"Skipping sample in {self.file_path} that failed to transform: {e!r}")
            return None

    def __len__(self):
        return self.num_samples

    def _default_transform(self, sample, idx):
        """Convert one decoded sample dict into fixed-size model inputs.

        Reads ``sample['features']['comments_tokenized']`` (``input_ids`` and
        ``attention_mask``, indexed per comment along dim 0 with the sequence
        length on dim 1), ``sample['features']['q_scores']`` (one list of
        scores per comment), the scalar features named in
        ``other_numerical_feature_names`` and, unless ``is_test_set``,
        ``sample['labels']``.

        Returns:
            ``(input_ids, attention_mask, q_scores, comment_active_flags,
            other_numerical_features)``, plus a trailing ``labels`` tensor when
            not a test set. The first four have ``n_comments_to_process`` rows.
            Unfilled rows are zeros and their active flag is False. Non-numeric
            features and labels become 0.0, and labels are clipped to [0, 1].

        ``idx`` is accepted for interface compatibility and is unused.
        """
        tokenized_info = sample.get('features', {}).get('comments_tokenized', {})
        all_input_ids = tokenized_info['input_ids']
        all_attention_mask = tokenized_info['attention_mask']

        num_actual_comments = all_input_ids.shape[0]
        final_input_ids = torch.zeros((self.n_comments_to_process, all_input_ids.shape[1]), dtype=torch.long)
        final_attention_mask = torch.zeros((self.n_comments_to_process, all_attention_mask.shape[1]), dtype=torch.long)
        comment_active_flags = torch.zeros(self.n_comments_to_process, dtype=torch.bool)

        indices_to_select = list(range(num_actual_comments))
        if num_actual_comments > self.n_comments_to_process:
            # More comments than slots: take a random subset (unseeded, so this
            # is nondeterministic -- also when is_test_set is True).
            indices_to_select = random.sample(indices_to_select, self.n_comments_to_process)
            comments_to_fill = self.n_comments_to_process
        else:
            comments_to_fill = num_actual_comments

        for i in range(comments_to_fill):
            original_idx = indices_to_select[i]
            final_input_ids[i] = all_input_ids[original_idx]
            final_attention_mask[i] = all_attention_mask[original_idx]
            comment_active_flags[i] = True

        # q-scores follow the same comment selection; each row is truncated or
        # zero-padded to num_q_features_per_comment.
        raw_q_scores = sample['features'].get('q_scores', [])
        final_q_scores = torch.zeros((self.n_comments_to_process, self.num_q_features_per_comment), dtype=torch.float)

        selected_raw_q_scores = []
        for i in range(comments_to_fill):
            original_comment_idx = indices_to_select[i]
            if original_comment_idx < len(raw_q_scores):
                qs_for_comment = raw_q_scores[original_comment_idx][:self.num_q_features_per_comment]
                padded_qs = qs_for_comment + [0.0] * (self.num_q_features_per_comment - len(qs_for_comment))
                selected_raw_q_scores.append(padded_qs[:self.num_q_features_per_comment])
            else:
                selected_raw_q_scores.append([0.0] * self.num_q_features_per_comment)

        if comments_to_fill > 0 and selected_raw_q_scores:  # ensure not empty before tensor conversion
            try:
                final_q_scores[:comments_to_fill] = torch.tensor(selected_raw_q_scores, dtype=torch.float)
            except Exception as e:
                logger.error(f"Error converting selected_raw_q_scores to tensor: {e}. Data: {selected_raw_q_scores}")

        other_numerical_features_list = []
        for fname in self.other_numerical_feature_names:
            val = sample['features'].get(fname, 0.0)
            try:
                other_numerical_features_list.append(float(val))
            except (ValueError, TypeError):
                other_numerical_features_list.append(0.0)
        other_numerical_features_tensor = torch.tensor(other_numerical_features_list, dtype=torch.float)

        if not self.is_test_set:
            labels_dict = sample['labels']
            regression_labels = []
            for trait_key in self.trait_names_ordered:
                # Prefer the title-cased key, then the key as configured.
                label_val = labels_dict.get(trait_key.title(), labels_dict.get(trait_key, 0.0))
                try:
                    label_float = float(label_val)
                    if not (0.0 <= label_float <= 1.0): label_float = np.clip(label_float, 0.0, 1.0)
                    regression_labels.append(label_float)
                except (ValueError, TypeError): regression_labels.append(0.0)
            labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)
        else:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

    def __iter__(self):
        """Yield processed samples, sharding lines across DataLoader workers."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_workers, worker_id = 1, 0  # single-process loading: take every line
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        try:
            file_iter = open(self.file_path, 'r', encoding='utf-8')
        except FileNotFoundError:
            logger.error(f"File not found in __iter__: {self.file_path}. Yielding nothing.")
            return  # Stop iteration

        # `with` closes the handle even when the consumer stops early (e.g. a
        # `break` out of the DataLoader loop) and the generator is closed.
        with file_iter:
            for line_idx, line in enumerate(file_iter):
                if line_idx % num_workers != worker_id:
                    continue
                processed_item = self._process_line(line)
                # `is not None` rather than truthiness: a custom transform may
                # return a tensor, whose truth value is ambiguous.
                if processed_item is not None:
                    yield processed_item

# --- PersonalityModelV3 (MODIFIED for additional dense layers) ---
class PersonalityModelV3(nn.Module):
    """BERT-based multi-trait regressor over a fixed number of comment slots per sample.

    Pipeline:
      1. Each comment is encoded by BERT. Token embeddings are mean-pooled
         under the attention mask, then averaged over the last
         ``num_bert_layers_to_pool`` hidden-state layers.
      2. The per-comment q-scores are appended.
      3. Comments are combined with additive attention
         (``v^T tanh(W x)``); inactive comment slots are masked out.
      4. Optionally, a Linear-ReLU-Dropout projection of the other numerical
         features is concatenated.
      5. The result goes through either ``num_additional_dense_layers``
         Linear-ReLU-Dropout blocks or a single dropout layer.
      6. One ``Linear(., 1)`` head per trait is applied, followed by a sigmoid,
         so each output lies in (0, 1).
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2, # Dropout for the final layer if no additional dense layers
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3,
                 num_other_numerical_features: int = 0,
                 numerical_embedding_dim: int = 64,
                 num_additional_dense_layers: int = 0,
                 additional_dense_hidden_dim: int = 256,
                 additional_layers_dropout_rate: float = 0.3
                ):
        super().__init__()
        # output_hidden_states=True is required: pooling reads several hidden layers.
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        # Attention scorer over [BERT embedding ; q-scores] per comment.
        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        # Only used when num_additional_dense_layers == 0.
        self.final_dropout_layer = nn.Dropout(dropout_rate)

        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        aggregated_comment_feature_dim = comment_feature_dim
        combined_input_dim_for_block = aggregated_comment_feature_dim  # input to dense block or final dropout

        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)  # reuses the main dropout_rate
            )
            combined_input_dim_for_block += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        # Optional stack of Linear-ReLU-Dropout blocks before the trait heads.
        self.num_additional_dense_layers = num_additional_dense_layers
        self.additional_dense_block = nn.Sequential()
        current_dim_for_dense_block = combined_input_dim_for_block

        if self.num_additional_dense_layers > 0:
            logger.info(f"Model using {self.num_additional_dense_layers} additional dense layers with hidden_dim {additional_dense_hidden_dim} and dropout {additional_layers_dropout_rate}")
            for i in range(self.num_additional_dense_layers):
                self.additional_dense_block.add_module(f"add_dense_{i}_linear", nn.Linear(current_dim_for_dense_block, additional_dense_hidden_dim))
                self.additional_dense_block.add_module(f"add_dense_{i}_relu", nn.ReLU())
                self.additional_dense_block.add_module(f"add_dense_{i}_dropout", nn.Dropout(additional_layers_dropout_rate))
                current_dim_for_dense_block = additional_dense_hidden_dim
            input_dim_for_regressors = current_dim_for_dense_block  # output of last additional layer
        else:
            logger.info("Model not using additional dense layers. Will use final_dropout_layer if dropout_rate > 0.")
            input_dim_for_regressors = combined_input_dim_for_block  # final_dropout_layer output goes straight to the heads

        # One independent scalar regression head per trait.
        self.trait_regressors = nn.ModuleList()
        for _ in range(num_traits):
            self.trait_regressors.append(
                nn.Linear(input_dim_for_regressors, 1)
            )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked mean over tokens, averaged over the last ``num_bert_layers_to_pool`` layers.

        Args:
            all_hidden_states: BERT hidden states, each ``(N, seq_len, hidden)``.
            attention_mask: ``(N, seq_len)`` mask (1 = real token).

        Returns:
            ``(N, hidden)`` embeddings. Rows whose mask is all zero (e.g. padded
            comment slots) come out as zeros rather than NaN.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # Cast the (possibly integer) mask to the hidden-state dtype so the token
        # counts are floating point. Clamping an integer count with min=1e-9 is
        # not guaranteed to avoid a zero denominator on every PyTorch version.
        mask = attention_mask.unsqueeze(-1).to(layers_to_pool[0].dtype)
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # identical for every layer
        pooled_outputs = [(layer_hidden_states * mask).sum(dim=1) / token_counts
                          for layer_hidden_states in layers_to_pool]
        return torch.stack(pooled_outputs, dim=0).mean(dim=0)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                q_scores: torch.Tensor,
                comment_active_mask: torch.Tensor,
                other_numerical_features: Optional[torch.Tensor] = None
               ):
        """Predict trait scores.

        Args:
            input_ids, attention_mask: ``(batch, n_comments_to_process, seq_len)``.
            q_scores: ``(batch, n_comments_to_process, num_q_features_per_comment)``.
            comment_active_mask: ``(batch, n_comments_to_process)`` bool; False
                slots get zero attention weight.
            other_numerical_features: ``(batch, num_other_numerical_features)``.
                Required when the model was built with numerical features.

        Returns:
            ``(batch, num_traits)`` sigmoid outputs in (0, 1).

        Raises:
            ValueError: If numerical features are required but missing or have
                the wrong width.
        """
        batch_size = input_ids.shape[0]

        # Encode all comments of all samples in a single BERT pass.
        input_ids_flat = input_ids.view(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.view(-1, attention_mask.shape[-1])

        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings_flat = self._pool_bert_layers(bert_outputs.hidden_states, attention_mask_flat)
        comment_bert_embeddings = comment_bert_embeddings_flat.view(batch_size, self.n_comments_to_process, -1)

        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)

        u = torch.tanh(self.attention_w(comment_features_with_q))
        scores = self.attention_v(u).squeeze(-1)

        if comment_active_mask is not None:
            # Use the dtype's most negative finite value instead of a literal
            # -1e9: -1e9 is not representable in float16, so masked_fill would
            # fail for half-precision scores (e.g. under autocast). In float32
            # the softmax result is the same.
            scores = scores.masked_fill(~comment_active_mask, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1)
        attention_weights_expanded = attention_weights.unsqueeze(-1)

        aggregated_comment_features = torch.sum(attention_weights_expanded * comment_features_with_q, dim=1)

        final_features_for_processing = aggregated_comment_features  # input to dense block or final dropout
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_processing = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        if self.num_additional_dense_layers > 0:
            # The dense block carries its own activations and dropouts.
            features_for_trait_heads = self.additional_dense_block(final_features_for_processing)
        else:
            features_for_trait_heads = self.final_dropout_layer(final_features_for_processing)

        trait_regression_outputs = [regressor_head(features_for_trait_heads)
                                    for regressor_head in self.trait_regressors]

        all_trait_outputs_raw = torch.cat(trait_regression_outputs, dim=1)
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Identity: ``forward`` already returns final (sigmoid) scores."""
        return outputs

# --- Optuna Objective Function (and its private helpers) ---

def _freeze_bert_for_trial(model: nn.Module, num_unfrozen_bert_layers: int) -> None:
    """Freeze all BERT weights, then re-enable gradients for the top of the backbone.

    With ``num_unfrozen_bert_layers`` > 0, the embeddings, the last
    ``num_unfrozen_bert_layers`` encoder layers (capped at the model's layer count)
    and the pooler (when present) become trainable. With 0, BERT stays fully frozen.
    """
    for param in model.bert.parameters():
        param.requires_grad = False
    if num_unfrozen_bert_layers <= 0:
        return

    if hasattr(model.bert, 'embeddings'):
        for param in model.bert.embeddings.parameters():
            param.requires_grad = True

    total_layers = model.bert.config.num_hidden_layers
    n_to_unfreeze = min(num_unfrozen_bert_layers, total_layers)
    for layer_idx in range(total_layers - n_to_unfreeze, total_layers):
        for param in model.bert.encoder.layer[layer_idx].parameters():
            param.requires_grad = True

    if getattr(model.bert, 'pooler', None) is not None:
        for param in model.bert.pooler.parameters():
            param.requires_grad = True


def _build_trial_param_groups(model: nn.Module, lr_bert: float, lr_head: float,
                              weight_decay: float) -> List[Dict]:
    """Return AdamW parameter groups for the trainable BERT weights and the task head.

    The BERT group is only added when some BERT weights require grad and ``lr_bert`` > 0.
    It uses a fixed weight decay of 0.01. The tuned ``weight_decay`` applies to the head group only.
    An empty list means there is nothing to optimize.
    """
    param_groups = []
    trainable_bert = [p for p in model.bert.parameters() if p.requires_grad]
    if trainable_bert and lr_bert > 0:
        param_groups.append({"params": trainable_bert, "lr": lr_bert, "weight_decay": 0.01})

    # final_dropout_layer (used when there is no additional dense block) is presumably a plain
    # nn.Dropout without learnable parameters, so it is not listed here.
    head_modules = [model.attention_w, model.attention_v]
    if model.uses_other_numerical_features:
        head_modules.append(model.other_numerical_processor)
    if model.num_additional_dense_layers > 0:
        head_modules.append(model.additional_dense_block)
    head_modules.extend(model.trait_regressors)
    head_params = [p for module in head_modules for p in module.parameters()]
    if head_params:
        param_groups.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})
    return param_groups


def _build_trial_scheduler(optimizer: optim.Optimizer, trial_number: int, scheduler_type: str,
                           warmup_ratio: float, num_train_samples: Optional[int],
                           batch_size: int, num_epochs: int):
    """Return a linear-warmup LR scheduler, or ``None`` if disabled or not computable.

    The step budget assumes ``ceil(num_train_samples / batch_size)`` batches per epoch for
    ``num_epochs`` epochs. Early stopping may end training before the schedule does.
    """
    if scheduler_type != "linear_warmup":
        return None
    # Missing/None/0 sample counts all mean the step budget is unknown.
    if not num_train_samples or num_train_samples <= 0:
        logger.warning(f"Trial {trial_number}: NUM_TRAIN_SAMPLES not available or zero in global_config. Cannot create linear_warmup scheduler.")
        return None
    num_batches_per_epoch = (num_train_samples + batch_size - 1) // batch_size
    num_training_steps = num_batches_per_epoch * num_epochs
    num_warmup_steps = int(num_training_steps * warmup_ratio)
    if num_warmup_steps <= 0 or num_training_steps <= 0:
        logger.warning(f"Trial {trial_number}: Calculated num_warmup_steps or num_training_steps is zero. Scheduler not created. Warmup: {num_warmup_steps}, Training: {num_training_steps}")
        return None
    return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                           num_training_steps=num_training_steps)


def _train_trial_epoch(model: nn.Module, loader, loss_fn, optimizer: optim.Optimizer,
                       scheduler, device: torch.device, trial_number: int, epoch: int) -> float:
    """Run one training epoch and return the mean loss over batches with a finite loss.

    Batches whose loss is NaN/Inf are skipped (no backward pass, no optimizer step).
    Returns ``inf`` when no batch produced a usable loss.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    for batch_idx, batch_tuple in enumerate(loader):
        input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]

        optimizer.zero_grad()
        predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
        batch_loss = loss_fn(predicted_scores, labels_reg)

        if torch.isnan(batch_loss) or torch.isinf(batch_loss):
            logger.warning(f"Trial {trial_number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected. Skipping batch.")
            torch.cuda.empty_cache()
            continue

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loss_sum += batch_loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches > 0 else float('inf')


def _evaluate_trial_model(model: nn.Module, loader, loss_fn,
                          device: torch.device) -> Tuple[float, float]:
    """Return ``(mse, mae)`` of ``model`` over ``loader``.

    Each batch's loss is weighted by its batch size. A smaller final batch therefore does not
    count as much as a full one, and the MSE is computed on the same per-sample basis as the MAE
    (which is computed over the concatenated predictions). This assumes ``loss_fn`` uses mean
    reduction, as ``nn.MSELoss()`` does by default. Returns ``(inf, -1.0)`` when no batch was evaluated.
    """
    model.eval()
    weighted_loss_sum = 0.0
    n_samples = 0
    preds, targets = [], []
    with torch.no_grad():
        for batch_tuple in loader:
            input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
            if input_ids.numel() == 0:
                continue

            predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
            if predicted_scores.numel() == 0:
                continue

            batch_n = labels_reg.size(0)
            weighted_loss_sum += loss_fn(predicted_scores, labels_reg).item() * batch_n
            n_samples += batch_n
            preds.append(predicted_scores.cpu())
            targets.append(labels_reg.cpu())

    mse = weighted_loss_sum / n_samples if n_samples > 0 else float('inf')
    mae = -1.0
    if preds and targets:
        all_preds = torch.cat(preds, dim=0)
        all_targets = torch.cat(targets, dim=0)
        if all_preds.numel() > 0 and all_targets.numel() > 0:
            mae = F.l1_loss(all_preds, all_targets).item()
    return mse, mae


def objective(trial: optuna.trial.Trial,
              train_file_path: str,
              val_file_path: str,
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int = 10):
    """Train one sampled hyperparameter configuration and return its best validation MSE.

    Hyperparameters are sampled from ``trial``. The model is trained for up to
    ``num_epochs_per_trial`` epochs with early stopping, and the validation MSE is reported to
    Optuna after every epoch. The best epoch's weights (CPU tensors) are saved to
    ``<TEMP_MODEL_DIR>/trial_<number>_best_model.pth``, where TEMP_MODEL_DIR is read from
    ``global_config`` and defaults to ``"optuna_trial_models"``. That path is stored in the
    trial's ``best_model_path_this_trial`` user attribute.

    Args:
        trial: The Optuna trial used to sample hyperparameters and report progress.
        train_file_path: JSONL file with the training samples.
        val_file_path: JSONL file with the validation samples.
        global_config: Shared configuration. Keys read here: TRAIT_NAMES, TRAIT_NAMES_ORDERED,
            OTHER_NUMERICAL_FEATURE_NAMES, NUM_Q_FEATURES_PER_COMMENT,
            MAX_COMMENTS_TO_PROCESS_PHYSICAL, NUM_TRAIN_SAMPLES, NUM_VAL_SAMPLES,
            BERT_MODEL_NAME and (optional) TEMP_MODEL_DIR.
        device: Device to train on.
        num_epochs_per_trial: Maximum number of epochs for this trial.

    Returns:
        The lowest validation MSE observed. Returns ``inf`` if the datasets could not be built
        or there were no parameters to optimize.

    Raises:
        optuna.TrialPruned: When the study's pruner stops the trial. Optuna then records the
            trial as PRUNED (not COMPLETE), so it cannot be selected as ``study.best_trial``.
    """
    logger.info(f"Starting Optuna Trial {trial.number}")

    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    # --- Suggest Hyperparameters (order kept stable so resumed studies see the same space) ---
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)  # final_dropout_layer, used when there is no dense block
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    n_comments_trial = trial.suggest_int("n_comments_to_process", 1, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16, 32])

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64, 128])

    # Additional dense layers between the aggregated features and the trait heads.
    num_additional_dense_layers_trial = trial.suggest_int("num_additional_dense_layers", 0, 3)
    additional_dense_hidden_dim_trial = 0
    additional_layers_dropout_rate_trial = 0.0  # used when there are no additional layers
    if num_additional_dense_layers_trial > 0:
        additional_dense_hidden_dim_trial = trial.suggest_categorical("additional_dense_hidden_dim", [128, 256, 512])
        additional_layers_dropout_rate_trial = trial.suggest_float("additional_layers_dropout_rate", 0.1, 0.5)

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")
    try:
        train_dataset_trial = JsonlIterableDataset(
            file_path=train_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_TRAIN_SAMPLES')
        )
        val_dataset_trial = JsonlIterableDataset(
            file_path=val_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_VAL_SAMPLES')
        )
        use_pin_memory = device.type == 'cuda'
        train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=use_pin_memory, persistent_workers=False)
        val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=use_pin_memory, persistent_workers=False)
    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    model = PersonalityModelV3(
        bert_model_name=global_config['BERT_MODEL_NAME'],
        num_traits=num_traits,
        n_comments_to_process=n_comments_trial,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim,
        num_bert_layers_to_pool=num_bert_layers_to_pool,
        num_q_features_per_comment=num_q_features_per_comment_trial,
        num_other_numerical_features=num_other_numerical_features_trial,
        numerical_embedding_dim=other_numerical_embedding_dim_trial,
        num_additional_dense_layers=num_additional_dense_layers_trial,
        additional_dense_hidden_dim=additional_dense_hidden_dim_trial,
        additional_layers_dropout_rate=additional_layers_dropout_rate_trial
    ).to(device)

    optimizer = None
    scheduler = None
    try:
        _freeze_bert_for_trial(model, num_unfrozen_bert_layers)
        logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                     f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

        optimizer_grouped_parameters = _build_trial_param_groups(model, lr_bert, lr_head, weight_decay)
        if not optimizer_grouped_parameters:
            logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
            return float('inf')

        optimizer = optim.AdamW(optimizer_grouped_parameters)
        scheduler = _build_trial_scheduler(optimizer, trial.number, scheduler_type, warmup_ratio,
                                           global_config.get('NUM_TRAIN_SAMPLES', 0),
                                           batch_size_trial, num_epochs_per_trial)

        loss_fn = nn.MSELoss().to(device)
        best_trial_val_loss = float('inf')
        patience_counter = 0

        # One checkpoint file per trial, overwritten whenever a later epoch of the same trial improves.
        temp_model_dir = global_config.get('TEMP_MODEL_DIR', "optuna_trial_models")
        os.makedirs(temp_model_dir, exist_ok=True)
        temp_model_path_for_this_trial = os.path.join(temp_model_dir, f"trial_{trial.number}_best_model.pth")

        for epoch in range(num_epochs_per_trial):
            avg_train_loss = _train_trial_epoch(model, train_loader_trial, loss_fn, optimizer,
                                                scheduler, device, trial.number, epoch)
            logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

            avg_val_loss_epoch, val_mae = _evaluate_trial_model(model, val_loader_trial, loss_fn, device)
            logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (MSE): {avg_val_loss_epoch:.4f}, Val MAE: {val_mae:.4f}")

            if avg_val_loss_epoch < best_trial_val_loss:
                best_trial_val_loss = avg_val_loss_epoch
                patience_counter = 0
                logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss for this trial: {best_trial_val_loss:.4f}")

                current_best_state_dict_for_trial = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(current_best_state_dict_for_trial, temp_model_path_for_this_trial)
                logger.info(f"Trial {trial.number}: Saved new best model FOR THIS TRIAL to {temp_model_path_for_this_trial}")
                trial.set_user_attr("best_model_path_this_trial", temp_model_path_for_this_trial)
            else:
                patience_counter += 1

            trial.report(avg_val_loss_epoch, epoch)
            if trial.should_prune():
                logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
                # Raising (instead of returning a value) lets Optuna mark the trial PRUNED;
                # Optuna stores the last reported value for it.
                raise optuna.TrialPruned()

            if patience_counter >= patience_early_stopping:
                logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
                break

        logger.info(f"Trial {trial.number} finished. Best Val Loss (MSE) for this trial: {best_trial_val_loss:.4f}")
        return best_trial_val_loss
    finally:
        # Release model/optimizer memory on every exit path: normal return, pruning or an exception.
        del model, optimizer, scheduler, train_loader_trial, val_loader_trial
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
# Notebook cell: device selection, data paths and the global configuration for the Optuna study.
# PersonalityModelV3, JsonlIterableDataset (defined in a later cell of this notebook) and
# objective() must already exist in the kernel before the study below is started.

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# --- Data File Paths (JSON Lines: each line is parsed as one sample) ---
TRAIN_DATA_FILE = "train_data.jsonl"
VAL_DATA_FILE = "val_data.jsonl"
TEST_DATA_FILE = "test_data.jsonl"

# --- Global Configuration ---
_trait_names_ordered_config = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Emotional stability', 'Humility']
_other_numerical_features_config = [
    'mean_words_per_comment', 'mean_sents_per_comment',
    'median_words_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
    'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound'
]

GLOBAL_CONFIG = {
    'BERT_MODEL_NAME': "bert-base-uncased",
    'TRAIT_NAMES_ORDERED': _trait_names_ordered_config,  # order used to pair traits with prediction columns
    'TRAIT_NAMES': _trait_names_ordered_config,  # alias; the code here only takes its len()
    'MAX_COMMENTS_TO_PROCESS_PHYSICAL': 3,  # upper bound of the n_comments_to_process search range
    'NUM_Q_FEATURES_PER_COMMENT': 3,
    'OTHER_NUMERICAL_FEATURE_NAMES': _other_numerical_features_config,
    'TOKENIZER_MAX_LENGTH': 256,  # not read by the training/prediction code in this cell
}

NUM_EPOCHS_PER_TRIAL_OPTUNA = 15  # max epochs per trial (early stopping / pruning can end sooner)
N_OPTUNA_TRIALS = 20              # trials added by each run of study.optimize

def count_lines_in_file(filepath):
    """Return the number of lines in a UTF-8 text file.

    Every line counts, including blank ones. A missing or unreadable file is logged
    and reported as 0 lines instead of raising.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            return sum(1 for _line in handle)
    except FileNotFoundError:
        logger.error(f"File not found for line counting: {filepath}. Returning 0.")
        return 0
    except Exception as e:
        logger.error(f"Error counting lines in {filepath}: {e}. Returning 0.")
        return 0

# Pre-count samples: the linear-warmup scheduler needs the total number of training steps,
# and the counts are passed to the datasets as num_samples.
NUM_TRAIN_SAMPLES = count_lines_in_file(TRAIN_DATA_FILE)
if NUM_TRAIN_SAMPLES == 0:
    logger.error(f"Training file {TRAIN_DATA_FILE} is empty or not found. Exiting.")
    # exit() is the site-module helper meant for the interactive shell, not for programs.
    # Called without an argument it also exits with status 0. Raise SystemExit with a non-zero code instead.
    raise SystemExit(1)
GLOBAL_CONFIG['NUM_TRAIN_SAMPLES'] = NUM_TRAIN_SAMPLES
logger.info(f"Number of training samples: {NUM_TRAIN_SAMPLES}")

NUM_VAL_SAMPLES = count_lines_in_file(VAL_DATA_FILE)
if NUM_VAL_SAMPLES == 0:
    logger.warning(f"Validation file {VAL_DATA_FILE} is empty or not found. Validation might not work as expected.")
GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = NUM_VAL_SAMPLES
logger.info(f"Number of validation samples: {NUM_VAL_SAMPLES}")


# --- Run the Optuna study (resumes from the SQLite storage when it already exists) ---
logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

study_name = "personality_regression_v5_more_layers"  # change when the search space changes
storage_name = f"sqlite:///{study_name}.db"
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"
TEMP_MODEL_DIR = "optuna_trial_models"  # same directory objective() writes per-trial checkpoints to by default

median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1)
study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    direction="minimize",
    pruner=median_pruner,
    load_if_exists=True,
)
if study.trials:
    logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

try:
    study.optimize(
        lambda trial: objective(
            trial,
            TRAIN_DATA_FILE,
            VAL_DATA_FILE,
            GLOBAL_CONFIG,
            DEVICE,
            num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA,
        ),
        n_trials=N_OPTUNA_TRIALS,
        gc_after_trial=True,
    )
except Exception as e:
    # Log and continue so that whatever trials did finish are still summarised below.
    logger.exception("An error occurred during the Optuna study.")

logger.info("\n--- Optuna Study Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

best_trial_overall = None

if not study.trials:
    logger.warning("No trials were completed in the study.")
else:
    try:
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if not completed_trials:
            logger.warning("No trials completed successfully to determine the best trial.")
        else:
            # study.best_trial raises ValueError when there is no COMPLETE trial (guarded above).
            # Otherwise it always returns a FrozenTrial, so no None check is needed.
            best_trial_overall = study.best_trial
            logger.info(f"Overall Best Trial Number: {best_trial_overall.number}")
            logger.info(f"  Value (Validation Loss - MSE): {best_trial_overall.value:.4f}")
            logger.info("  Best Params: ")
            for key, value in best_trial_overall.params.items():
                logger.info(f"    {key}: {value}")

            with open(BEST_PARAMS_FILENAME, 'w', encoding='utf-8') as f:
                json.dump(best_trial_overall.params, f, indent=4)
            logger.info(f"Best hyperparameters saved to {BEST_PARAMS_FILENAME}")

            best_weights_saved = False
            if "best_model_path_this_trial" not in best_trial_overall.user_attrs:
                logger.warning(f"Key 'best_model_path_this_trial' not found in user_attrs of best trial {best_trial_overall.number}. Weights not saved.")
            else:
                path_to_best_model_from_trial = best_trial_overall.user_attrs["best_model_path_this_trial"]
                if not path_to_best_model_from_trial:
                    logger.warning(f"Overall best trial {best_trial_overall.number} has 'best_model_path_this_trial' but its value is None. Weights not saved.")
                elif not os.path.exists(path_to_best_model_from_trial):
                    logger.warning(f"Model file '{path_to_best_model_from_trial}' from best trial {best_trial_overall.number} not found. Weights not saved.")
                else:
                    try:
                        shutil.copyfile(path_to_best_model_from_trial, BEST_WEIGHTS_FILENAME)
                        best_weights_saved = True
                        logger.info(f"Best model weights from trial {best_trial_overall.number} (path: {path_to_best_model_from_trial}) copied to {BEST_WEIGHTS_FILENAME}")
                    except Exception as e:
                        logger.error(f"Error copying best model weights from {path_to_best_model_from_trial} to {BEST_WEIGHTS_FILENAME}: {e}")

            if not best_weights_saved and os.path.exists(BEST_WEIGHTS_FILENAME):
                # The params file was just rewritten for this best trial, but the weights file on
                # disk comes from an earlier run and may belong to a different configuration.
                logger.warning(f"{BEST_WEIGHTS_FILENAME} was not updated in this run and may not match the hyperparameters in {BEST_PARAMS_FILENAME}.")

            # Per-trial checkpoints in TEMP_MODEL_DIR are intentionally kept. Remove that directory
            # manually (e.g. shutil.rmtree(TEMP_MODEL_DIR)) once it is no longer needed; be careful if
            # several studies share it.

        study_df = study.trials_dataframe(attrs=('number', 'value', 'params', 'state', 'user_attrs'))
        study_df.to_csv(f"{study_name}_results.csv", index=False)
        logger.info(f"Optuna study results saved to {study_name}_results.csv")

    except Exception as e:
        logger.error(f"Could not process or save Optuna study results: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
if os.path.exists(TEST_DATA_FILE) and os.path.exists(BEST_PARAMS_FILENAME) and os.path.exists(BEST_WEIGHTS_FILENAME):
    logger.info(f"\n--- Predicting on Test Data using overall best saved model ---")
    try:
        with open(BEST_PARAMS_FILENAME, 'r', encoding='utf-8') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {BEST_PARAMS_FILENAME}")

        best_n_comments = loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL'])

        # Construct the model with the same arguments objective() used for the best trial.
        # additional_dense_hidden_dim / additional_layers_dropout_rate are only suggested (and so
        # only present in the params file) when num_additional_dense_layers > 0. In the other case
        # objective() passes 0 / 0.0, so those are the fall-backs here as well.
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=best_n_comments,
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2),
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            numerical_embedding_dim=loaded_best_params.get("other_numerical_embedding_dim", 0) if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0,
            num_additional_dense_layers=loaded_best_params.get("num_additional_dense_layers", 0),
            additional_dense_hidden_dim=loaded_best_params.get("additional_dense_hidden_dim", 0),
            additional_layers_dropout_rate=loaded_best_params.get("additional_layers_dropout_rate", 0.0),
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # objective() saves the state_dict as CPU tensors. Load onto the CPU and let
        # load_state_dict copy the weights to the model's device.
        loaded_state_dict = torch.load(BEST_WEIGHTS_FILENAME, map_location=torch.device('cpu'))
        test_model.load_state_dict(loaded_state_dict)
        logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        test_model.eval()

        NUM_TEST_SAMPLES = count_lines_in_file(TEST_DATA_FILE)
        if NUM_TEST_SAMPLES == 0:
            logger.warning(f"Test file {TEST_DATA_FILE} is empty or not found. No test predictions will be made.")
        else:
            test_dataset = JsonlIterableDataset(
                file_path=TEST_DATA_FILE,
                trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
                n_comments_to_process=best_n_comments,
                other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
                num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
                is_test_set=True,
                num_samples=NUM_TEST_SAMPLES,
            )
            # Reuse the tuned batch size when available.
            test_batch_size = loaded_best_params.get("batch_size", 16)
            test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)

            all_test_predictions = []
            with torch.no_grad():
                for batch_tuple in test_loader:
                    # Test batches carry no labels: 5 tensors instead of the 6 used in training.
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats = [b.to(DEVICE) for b in batch_tuple]
                    predicted_scores = test_model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    all_test_predictions.append(predicted_scores.cpu().numpy())

            if all_test_predictions:
                final_test_predictions = np.concatenate(all_test_predictions, axis=0)
                logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
                for i in range(min(5, len(final_test_predictions))):
                    pred_dict = {trait: round(score.item(), 4) for trait, score in zip(GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'], final_test_predictions[i])}
                    logger.info(f"Test Sample Index {i} Predictions: {pred_dict}")
            else:
                logger.warning("No predictions generated for the test set (all_test_predictions list is empty).")

    except FileNotFoundError as e:
        logger.warning(f"Required file for test prediction not found: {e}. Skipping test prediction.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)
elif not os.path.exists(TEST_DATA_FILE):
    logger.info(f"Test data file '{TEST_DATA_FILE}' not found. Skipping test prediction example.")
else:
    # Test data exists, so the params or weights file must be the missing one.
    logger.warning(f"Best parameters file ({BEST_PARAMS_FILENAME}) or weights file ({BEST_WEIGHTS_FILENAME}) not found. Skipping test prediction.")

# temp


In [None]:
import json
import torch
from torch.utils.data import IterableDataset
from transformers.tokenization_utils_base import BatchEncoding # For your decode_from_json
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation
import torch.optim as optim
import os
import shutil

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Markers written by the JSON serializer that produced the .jsonl files ---
# They must match the serializer exactly; otherwise _json_object_hook_for_dataset
# will not recognise the marked objects.
_TENSOR_MARKER = "__tensor__"                  # key present on a serialized tensor
_TENSOR_DTYPE_MARKER = "__tensor_dtype__"      # key holding that tensor's dtype string
_BATCH_ENCODING_MARKER = "__batch_encoding__"  # key present on a serialized BatchEncoding
_BATCH_ENCODING_DATA_MARKER = "data"           # payload key, used for tensors and BatchEncodings alike

def _convert_str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a dtype name back to the corresponding ``torch.dtype``.

    Accepts both the ``str(tensor.dtype)`` form (``"torch.int64"``) and the bare
    attribute name (``"int64"``).

    Args:
        dtype_str: Name of the dtype, with or without the ``"torch."`` prefix.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If the name does not refer to a ``torch.dtype``. Examples are unknown
            names, and names of non-dtype attributes such as ``"torch.tensor"``. The previous
            fallback ``torch.dtype(dtype_str)`` could never succeed, because ``torch.dtype``
            cannot be instantiated.
    """
    prefix = "torch."
    dtype_name = dtype_str[len(prefix):] if dtype_str.startswith(prefix) else dtype_str
    candidate = getattr(torch, dtype_name, None)
    if not isinstance(candidate, torch.dtype):
        raise ValueError(f"Unknown torch dtype string: {dtype_str!r}")
    return candidate

def _json_object_hook_for_dataset(dct: dict) -> any:
    """
    Object hook for json.loads to reconstruct tensors and BatchEncoding objects.
    """
    if _TENSOR_MARKER in dct:
        dtype_str = dct.get(_TENSOR_DTYPE_MARKER, 'float32') # Default dtype
        dtype = _convert_str_to_dtype(dtype_str)
        # Data from tensor.tolist() is a list of lists (or list for 1D)
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=dtype)
    elif _BATCH_ENCODING_MARKER in dct:
        # The 'data' part of BatchEncoding should be a dictionary.
        # Its values (like input_ids) should have been converted to tensors
        # by this hook if they were marked as tensors.
        reconstructed_data_for_be = {}
        batch_encoding_payload = dct.get(_BATCH_ENCODING_DATA_MARKER, {})
        for k, v_data in batch_encoding_payload.items():
            # If v_data is a list (e.g., input_ids was list of lists from tolist())
            # and wasn't explicitly marked as a __tensor__ itself, convert it now.
            # This typically happens if the BatchEncoding's internal tensors were directly converted to lists.
            if isinstance(v_data, list) and k in ["input_ids", "token_type_ids", "attention_mask"]:
                try:
                    # Determine dtype (input_ids, token_type_ids are usually long)
                    tensor_dtype = torch.long if k in ["input_ids", "token_type_ids"] else torch.long # attention_mask can be long or bool
                    reconstructed_data_for_be[k] = torch.tensor(v_data, dtype=tensor_dtype)
                except Exception as e:
                    logger.error(f"Error converting field '{k}' in BatchEncoding to tensor: {e}. Keeping as list.")
                    reconstructed_data_for_be[k] = v_data # Fallback
            else:
                reconstructed_data_for_be[k] = v_data # Already a tensor or primitive
        return BatchEncoding(reconstructed_data_for_be)
    return dct

class JsonlIterableDataset(IterableDataset):
    """Stream samples from a JSON-Lines file, one JSON object per line.

    Each non-blank line is decoded with ``_json_object_hook_for_dataset``
    (which rebuilds tensors / ``BatchEncoding`` payloads) and passed to
    ``transform_fn`` (``_default_transform`` unless overridden). Lines that
    cannot be decoded or transformed are skipped; the number skipped is logged
    once per pass.

    Args:
        file_path: Path to the ``.jsonl`` file.
        trait_names: Trait names, in the order the label vector is built.
        n_comments_to_process: Number of comment slots per sample. Samples with
            more comments are randomly subsampled; fewer are zero-padded.
        other_numerical_feature_names: Keys of ``sample['features']`` packed
            into the flat numeric feature vector (missing/invalid -> 0.0).
        num_q_features_per_comment: Width of each comment's q-score vector.
        is_test_set: When True no labels are read and the sample tuple has five
            elements instead of six.
        transform_fn: Optional replacement for ``_default_transform``; called as
            ``transform_fn(sample, idx=None)``. Returning None drops the sample.
        num_samples: Value reported by ``__len__``. When None the file's lines
            are counted once during construction.
    """

    def __init__(self, file_path, trait_names, n_comments_to_process,
                 other_numerical_feature_names, num_q_features_per_comment,
                 is_test_set=False, transform_fn=None, num_samples = None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.is_test_set = is_test_set
        # transform_fn is what PersonalityDatasetV3.__getitem__ does
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn
        if num_samples is None:
            logger.info(f'Counting samples in {file_path} for __len__ was not provided...')
            self.num_samples = self._count_samples_in_file()
            logger.info(f"Counted {self.num_samples} samples in {self.file_path}.")
        else:
            self.num_samples = num_samples
        if self.num_samples == 0:
            logger.warning(f"Initialized JsonlIterableDataset for {self.file_path} with 0 samples. DataLoader will be empty.")

    def _count_samples_in_file(self):
        """Return the number of lines in ``file_path``, or 0 if it cannot be read.

        Blank and malformed lines are counted too, so this can exceed the number
        of samples ``__iter__`` actually yields.
        """
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                return sum(1 for _ in f)
        except FileNotFoundError:
            logger.error(f"File not found during initial sample count: {self.file_path}. Returning 0 samples.")
            return 0
        except Exception as e:
            logger.error(f"Error during initial sample count for {self.file_path}: {e}. Returning 0 samples.")
            return 0

    def _process_line(self, line):
        """Decode one JSONL line and apply ``transform_fn``.

        Returns:
            The transformed sample, or None when the line is blank, is not valid
            JSON, or raises inside the object hook / transform. Failures are
            logged at DEBUG level so a noisy file does not flood the log.
        """
        if not line.strip():
            return None
        try:
            sample = json.loads(line, object_hook=_json_object_hook_for_dataset)
        except json.JSONDecodeError as e:
            logger.debug(f"Invalid JSON in {self.file_path}: {e}")
            return None
        except Exception as e:
            # Raised from inside _json_object_hook_for_dataset.
            logger.debug(f"Object hook failed in {self.file_path}: {e}")
            return None
        try:
            return self.transform_fn(sample, idx=None)  # idx is unused by the default transform
        except Exception as e:
            logger.debug(f"transform_fn failed for a sample in {self.file_path}: {e}")
            return None

    def __len__(self):
        """Return the sample count given at construction (or counted then)."""
        return self.num_samples

    def _default_transform(self, sample, idx):
        """Convert one decoded sample into fixed-size tensors (PersonalityDatasetV3.__getitem__ logic).

        Args:
            sample: Decoded JSON object with a ``'features'`` dict holding
                ``'comments_tokenized'`` (``input_ids``/``attention_mask``, as
                tensors or nested lists), optional ``'q_scores'`` (one list per
                comment) and the numeric feature keys; plus a ``'labels'`` dict
                unless ``is_test_set``.
            idx: Unused; kept for signature compatibility.

        Returns:
            ``(input_ids, attention_mask, q_scores, comment_active_flags,
            other_numerical_features[, labels])``. The first four have a leading
            dimension of ``n_comments_to_process``; labels are clipped to [0, 1].

        Raises:
            KeyError: if the tokenized fields or (for labelled data) ``'labels'``
                are missing. ``_process_line`` turns this into a skipped sample.
        """
        features = sample.get('features', {})
        tokenized_info = features.get('comments_tokenized', {})
        # as_tensor is a no-op for tensors rebuilt by the JSON hook and also
        # accepts plain nested lists, which have no .shape.
        all_input_ids = torch.as_tensor(tokenized_info['input_ids'], dtype=torch.long)
        all_attention_mask = torch.as_tensor(tokenized_info['attention_mask'], dtype=torch.long)

        n_slots = self.n_comments_to_process
        num_actual_comments = all_input_ids.shape[0]

        final_input_ids = torch.zeros((n_slots, all_input_ids.shape[1]), dtype=torch.long)
        final_attention_mask = torch.zeros((n_slots, all_attention_mask.shape[1]), dtype=torch.long)
        comment_active_flags = torch.zeros(n_slots, dtype=torch.bool)

        selected = list(range(num_actual_comments))
        if num_actual_comments > n_slots:
            # A fresh random subset is drawn every time the sample is read.
            selected = random.sample(selected, n_slots)

        for slot, original_idx in enumerate(selected):
            final_input_ids[slot] = all_input_ids[original_idx]
            final_attention_mask[slot] = all_attention_mask[original_idx]
            comment_active_flags[slot] = True

        # q-scores follow the same comment selection; each row is truncated or
        # zero-padded to num_q_features_per_comment.
        q_width = self.num_q_features_per_comment
        raw_q_scores = features.get('q_scores', [])
        final_q_scores = torch.zeros((n_slots, q_width), dtype=torch.float)
        selected_raw_q_scores = []
        for original_idx in selected:
            if original_idx < len(raw_q_scores):
                qs_for_comment = raw_q_scores[original_idx][:q_width]
                selected_raw_q_scores.append(qs_for_comment + [0.0] * (q_width - len(qs_for_comment)))
            else:
                selected_raw_q_scores.append([0.0] * q_width)

        if selected:
            try:
                final_q_scores[:len(selected)] = torch.tensor(selected_raw_q_scores, dtype=torch.float)
            except Exception as e:  # ragged or non-numeric q-score rows
                logger.error(f"Error converting selected_raw_q_scores to tensor: {e}. Data: {selected_raw_q_scores}")
                # final_q_scores stays all zeros for this sample

        other_numerical_features_list = []
        for fname in self.other_numerical_feature_names:
            try:
                other_numerical_features_list.append(float(features.get(fname, 0.0)))
            except (ValueError, TypeError):
                other_numerical_features_list.append(0.0)
        other_numerical_features_tensor = torch.tensor(other_numerical_features_list, dtype=torch.float)

        if self.is_test_set:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

        labels_dict = sample['labels']
        regression_labels = []
        for trait_key in self.trait_names_ordered:
            # Prefer the Title-Cased key, fall back to the configured spelling.
            label_val = labels_dict.get(trait_key.title(), labels_dict.get(trait_key, 0.0))
            try:
                label_float = float(label_val)
                if not (0.0 <= label_float <= 1.0):
                    label_float = np.clip(label_float, 0.0, 1.0)
                regression_labels.append(label_float)
            except (ValueError, TypeError):
                regression_labels.append(0.0)
        labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
        return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)

    def __iter__(self):
        """Yield transformed samples, sharding lines across DataLoader workers.

        With ``num_workers > 0`` every worker reads the whole file but only
        processes lines whose index ``i`` satisfies ``i % num_workers == worker_id``,
        so each line is handled by exactly one worker.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        skipped = 0
        # The with-block closes the file even if the consumer stops early
        # (generator.close()) or an exception escapes the loop.
        with open(self.file_path, 'r', encoding='utf-8') as file_iter:
            for i, line in enumerate(file_iter):
                if i % num_workers != worker_id:
                    continue
                processed_item = self._process_line(line)
                if processed_item is None:
                    if line.strip():
                        skipped += 1
                    continue
                yield processed_item
        if skipped:
            logger.warning(f"Skipped {skipped} unusable line(s) in {self.file_path} (worker {worker_id}).")


# --- Regression Loss Function (NEW) ---
# We'll use nn.MSELoss directly in the training loop.

# --- PersonalityModelV3 (Regression and q_scores integration) ---
class PersonalityModelV3(nn.Module):
    """BERT-based multi-trait regressor over a set of comments per sample.

    Each comment is encoded by BERT (masked mean over tokens, averaged over the
    last ``num_bert_layers_to_pool`` hidden layers) and concatenated with its
    q-score vector. The comments are pooled with additive attention that ignores
    inactive (padding) comment slots. Optional per-sample numeric features are
    projected and appended before one linear head per trait. Outputs are passed
    through a sigmoid, so every trait score lies in (0, 1).
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3, # For Q1, Q2, Q3 scores per comment
                 num_other_numerical_features: int = 0, # From sample['features'] excluding q_scores
                 numerical_embedding_dim: int = 64
                ):
        super().__init__()
        # output_hidden_states=True is required: pooling reads hidden_states.
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        # Per-comment feature = pooled BERT embedding ++ q-scores.
        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        # Additive attention: score = v^T tanh(W x).
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        self.dropout = nn.Dropout(dropout_rate)

        # Optional projection of the per-sample numeric features.
        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        # Attention pooling keeps the per-comment feature width.
        aggregated_comment_feature_dim = comment_feature_dim
        combined_input_dim_for_heads = aggregated_comment_feature_dim

        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
            combined_input_dim_for_heads += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        # One scalar regression head per trait.
        self.trait_regressors = nn.ModuleList()
        for _ in range(num_traits):
            self.trait_regressors.append(
                nn.Linear(combined_input_dim_for_heads, 1)
            )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked-mean-pool tokens in each of the last N layers, then average the layers.

        Args:
            all_hidden_states: Per-layer hidden states, each ``(rows, seq_len, hidden)``.
            attention_mask: ``(rows, seq_len)``; nonzero marks real tokens.

        Returns:
            ``(rows, hidden)`` embeddings. Rows with no real tokens (padding
            comment slots) pool to zeros instead of NaN.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # (rows, seq_len, 1) float mask that broadcasts over the hidden dimension.
        mask = attention_mask.unsqueeze(-1).to(layers_to_pool[0].dtype)
        # The token count is identical for every layer, so compute it once; the
        # clamp keeps all-padding rows at 0/eps = 0 rather than 0/0.
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_per_layer = [torch.sum(h * mask, dim=1) / token_counts for h in layers_to_pool]
        return torch.stack(pooled_per_layer, dim=0).mean(dim=0)

    def forward(self,
                input_ids: torch.Tensor,      # (batch_size, n_comments, seq_len)
                attention_mask: torch.Tensor, # (batch_size, n_comments, seq_len)
                q_scores: torch.Tensor,       # (batch_size, n_comments, num_q_features)
                comment_active_mask: torch.Tensor, # (batch_size, n_comments)
                other_numerical_features: Optional[torch.Tensor] = None # (batch_size, num_other_num_features)
               ):
        """Return sigmoid trait scores of shape ``(batch_size, num_traits)``.

        The number of comment slots is taken from ``input_ids.shape[1]``, so it
        does not have to equal ``n_comments_to_process``.

        Raises:
            ValueError: if the model was built with numeric features but
                ``other_numerical_features`` is missing or has the wrong width.
        """
        batch_size, n_comments = input_ids.shape[0], input_ids.shape[1]

        # BERT sees every comment as an independent sequence.
        input_ids_flat = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.reshape(-1, attention_mask.shape[-1])

        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings_flat = self._pool_bert_layers(bert_outputs.hidden_states, attention_mask_flat)
        comment_bert_embeddings = comment_bert_embeddings_flat.reshape(batch_size, n_comments, -1)

        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)

        u = torch.tanh(self.attention_w(comment_features_with_q))  # (batch, n_comments, attention_hidden_dim)
        scores = self.attention_v(u).squeeze(-1)                    # (batch, n_comments)

        if comment_active_mask is not None:
            # Most negative finite value of the score dtype: exp() underflows to
            # 0 after softmax, and unlike -1e9 it is representable in fp16.
            scores = scores.masked_fill(~comment_active_mask.bool(), torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1)
        aggregated_comment_features = torch.sum(attention_weights.unsqueeze(-1) * comment_features_with_q, dim=1)

        final_features_for_heads = aggregated_comment_features
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_heads = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        combined_features_dropped = self.dropout(final_features_for_heads)

        all_trait_outputs_raw = torch.cat(
            [regressor_head(combined_features_dropped) for regressor_head in self.trait_regressors], dim=1
        )
        # Sigmoid keeps predictions inside the [0, 1] label range.
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Return ``outputs`` unchanged; ``forward`` already applies the sigmoid."""
        return outputs



# --- Optuna Objective Function (MODIFIED for Regression) ---
def objective(trial: optuna.trial.Trial,
              train_file_path: str,
              val_file_path: str,
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int = 10):
    """Run one Optuna trial: sample hyperparameters, train, return the best validation MSE.

    The model is trained for up to ``num_epochs_per_trial`` epochs with MSE loss.
    After each epoch the mean per-batch validation MSE is reported to Optuna so
    the trial can be pruned. Whenever validation loss improves, the weights are
    saved to ``optuna_trial_models/trial_<n>_epoch_<e>_best_model.pth``, the
    previous best file of the same trial is deleted, and the new path is stored
    in the ``best_model_path_this_trial`` user attribute.

    Args:
        trial: The Optuna trial being evaluated.
        train_file_path: JSONL file with training samples.
        val_file_path: JSONL file with validation samples.
        global_config: Must contain ``TRAIT_NAMES``, ``TRAIT_NAMES_ORDERED`` and
            ``BERT_MODEL_NAME``. May contain ``OTHER_NUMERICAL_FEATURE_NAMES``,
            ``NUM_Q_FEATURES_PER_COMMENT``, ``MAX_COMMENTS_TO_PROCESS_PHYSICAL``,
            ``NUM_TRAIN_SAMPLES`` and ``NUM_VAL_SAMPLES``.
        device: Device to train on.
        num_epochs_per_trial: Maximum number of training epochs.

    Returns:
        The lowest validation MSE seen in this trial, or ``inf`` if the datasets
        could not be built.
    """
    logger.info(f"Starting Optuna Trial {trial.number}")

    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    # --- Suggest Hyperparameters ---
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.4)
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    # Upper bound: number of comment slots available in the data.
    n_comments_trial = trial.suggest_int("n_comments_to_process", 1, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    # warmup_ratio is only sampled when a scheduler is used (conditional search space).
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16])  # kept small: every comment is a BERT pass

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64])

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")
    try:
        logger.info(f"Trial {trial.number} - Loading data from: {train_file_path}, {val_file_path}")
        train_dataset_trial = JsonlIterableDataset(
            file_path=train_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_TRAIN_SAMPLES')
        )
        val_dataset_trial = JsonlIterableDataset(
            file_path=val_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_VAL_SAMPLES')
        )
        # IterableDataset: no shuffle argument; samples arrive in file order.
        pin_memory = device.type == 'cuda'
        train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=pin_memory, persistent_workers=False)
        val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0, pin_memory=pin_memory, persistent_workers=False)
    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    model = optimizer = scheduler = None
    try:
        model = PersonalityModelV3(
            bert_model_name=global_config['BERT_MODEL_NAME'],
            num_traits=num_traits,
            n_comments_to_process=n_comments_trial,
            dropout_rate=dropout_rate,
            attention_hidden_dim=attention_hidden_dim,
            num_bert_layers_to_pool=num_bert_layers_to_pool,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            num_other_numerical_features=num_other_numerical_features_trial,
            numerical_embedding_dim=other_numerical_embedding_dim_trial
        ).to(device)

        # BERT layer freezing: freeze everything, then unfreeze the embeddings,
        # the top `num_unfrozen_bert_layers` encoder layers and the pooler.
        for param in model.bert.parameters():
            param.requires_grad = False
        if num_unfrozen_bert_layers > 0:
            if hasattr(model.bert, 'embeddings'):
                for param in model.bert.embeddings.parameters():
                    param.requires_grad = True
            num_layers = model.bert.config.num_hidden_layers
            actual_layers_to_unfreeze = min(num_unfrozen_bert_layers, num_layers)
            for i in range(num_layers - actual_layers_to_unfreeze, num_layers):
                for param in model.bert.encoder.layer[i].parameters():
                    param.requires_grad = True
            if hasattr(model.bert, 'pooler') and model.bert.pooler is not None:
                for param in model.bert.pooler.parameters():
                    param.requires_grad = True

        logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                     f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

        # Optimizer: separate LR / weight decay for BERT and for the heads.
        optimizer_grouped_parameters = []
        bert_params_to_tune = [p for p in model.bert.parameters() if p.requires_grad]
        if bert_params_to_tune and lr_bert > 0:
            optimizer_grouped_parameters.append({"params": bert_params_to_tune, "lr": lr_bert, "weight_decay": 0.01})

        head_params = list(model.attention_w.parameters()) + list(model.attention_v.parameters())
        for regressor_head in model.trait_regressors:
            head_params.extend(regressor_head.parameters())
        if model.uses_other_numerical_features:
            head_params.extend(model.other_numerical_processor.parameters())
        optimizer_grouped_parameters.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})

        if not any(pg['params'] for pg in optimizer_grouped_parameters):
            logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
            return float('inf')

        optimizer = optim.AdamW(optimizer_grouped_parameters)

        # Linear warmup needs the total step count, derived from the pre-counted samples.
        if scheduler_type == "linear_warmup":
            if global_config.get('NUM_TRAIN_SAMPLES', 0) > 0:
                num_batches_per_epoch = (global_config['NUM_TRAIN_SAMPLES'] + batch_size_trial - 1) // batch_size_trial
                num_training_steps = num_batches_per_epoch * num_epochs_per_trial
                num_warmup_steps = int(num_training_steps * warmup_ratio)
                if num_warmup_steps > 0 and num_training_steps > 0:
                    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
                else:
                    logger.warning(f"Trial {trial.number}: Calculated num_warmup_steps or num_training_steps is zero. Scheduler not created. Warmup: {num_warmup_steps}, Training: {num_training_steps}")
            else:
                logger.warning(f"Trial {trial.number}: NUM_TRAIN_SAMPLES not available or zero in global_config. Cannot create linear_warmup scheduler.")

        loss_fn = nn.MSELoss().to(device)
        best_trial_val_loss = float('inf')
        best_checkpoint_path = None
        patience_counter = 0
        for epoch in range(num_epochs_per_trial):
            model.train()
            total_train_loss = 0
            train_batches_processed = 0

            for batch_idx, batch_tuple in enumerate(train_loader_trial):
                input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]

                optimizer.zero_grad()
                predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                current_batch_loss = loss_fn(predicted_scores, labels_reg)

                if not torch.isfinite(current_batch_loss):
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected. Skipping batch.")
                    torch.cuda.empty_cache()
                    continue

                current_batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler:
                    scheduler.step()
                total_train_loss += current_batch_loss.item()
                train_batches_processed += 1

            avg_train_loss = total_train_loss / train_batches_processed if train_batches_processed > 0 else float('inf')
            logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

            # Validation
            model.eval()
            current_epoch_val_loss = 0
            val_batches_processed = 0
            all_val_preds_epoch = []
            all_val_labels_epoch = []
            with torch.no_grad():
                for batch_tuple in val_loader_trial:
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                    if input_ids.numel() == 0:
                        continue
                    predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    if predicted_scores.numel() == 0:
                        continue
                    batch_val_loss = loss_fn(predicted_scores, labels_reg)
                    current_epoch_val_loss += batch_val_loss.item()
                    all_val_preds_epoch.append(predicted_scores.cpu())
                    all_val_labels_epoch.append(labels_reg.cpu())
                    val_batches_processed += 1

            # Mean of per-batch MSEs (a smaller final batch weighs as much as a full one).
            avg_val_loss_epoch = current_epoch_val_loss / val_batches_processed if val_batches_processed > 0 else float('inf')

            # MAE over all validation samples, for interpretability only.
            val_mae = -1.0
            if all_val_labels_epoch:
                all_val_labels_cat = torch.cat(all_val_labels_epoch, dim=0)
                all_val_preds_cat = torch.cat(all_val_preds_epoch, dim=0)
                if all_val_labels_cat.numel() > 0:
                    val_mae = F.l1_loss(all_val_preds_cat, all_val_labels_cat).item()

            logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (MSE): {avg_val_loss_epoch:.4f}, Val MAE: {val_mae:.4f}")

            if avg_val_loss_epoch < best_trial_val_loss:
                best_trial_val_loss = avg_val_loss_epoch
                patience_counter = 0
                logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss: {best_trial_val_loss:.4f}")

                temp_model_dir = "optuna_trial_models"
                os.makedirs(temp_model_dir, exist_ok=True)
                temp_model_path = os.path.join(temp_model_dir, f"trial_{trial.number}_epoch_{epoch+1}_best_model.pth")

                # Tensors are moved to CPU before saving so the file does not
                # record CUDA device locations.
                current_best_state_dict_for_trial = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(current_best_state_dict_for_trial, temp_model_path)
                logger.info(f"Trial {trial.number}: Saved new best model for this trial to {temp_model_path}")
                trial.set_user_attr("best_model_path_this_trial", temp_model_path)

                # Only the newest checkpoint is referenced by the user attr; remove
                # the superseded one so each trial keeps a single file on disk.
                if best_checkpoint_path is not None:
                    try:
                        os.remove(best_checkpoint_path)
                    except OSError as e:
                        logger.warning(f"Trial {trial.number}: could not remove old checkpoint {best_checkpoint_path}: {e}")
                best_checkpoint_path = temp_model_path
            else:
                patience_counter += 1

            trial.report(avg_val_loss_epoch, epoch)
            if trial.should_prune():
                # NOTE: returning a value (rather than raising optuna.TrialPruned)
                # makes Optuna record this trial as COMPLETE with this value.
                logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
                return best_trial_val_loss

            if patience_counter >= patience_early_stopping:
                logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
                break

        logger.info(f"Trial {trial.number} finished. Best Val Loss (MSE) for this trial: {best_trial_val_loss:.4f}")
        return best_trial_val_loss
    finally:
        # Runs on every exit path, including exceptions: drop the references so
        # the GPU memory held by this trial can be reclaimed before the next one.
        model = optimizer = scheduler = None
        train_loader_trial = val_loader_trial = None
        gc.collect()
        torch.cuda.empty_cache()

# In your objective function:
# train_dataset_trial = JsonlIterableDataset(
#     file_path="train_data_streamed.jsonl", # Path to your train JSONL
#     trait_names=global_config['TRAIT_NAMES_ORDERED'],
#     n_comments_to_process=n_comments_trial,
#     other_numerical_feature_names=other_numerical_feature_names_trial,
#     num_q_features_per_comment=num_q_features_per_comment_trial,
#     is_test_set=False
# )
# val_dataset_trial = JsonlIterableDataset(...) # For validation
# train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0) # shuffle=True not for IterableDataset
# val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0)

In [None]:


# --- Runtime device ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# --- Data files ---
# Plain path strings. Assigning them cannot raise, so no error handling is needed here.
TRAIN_DATA_FILE = "train_data.jsonl" # Adjust if your filename is different
VAL_DATA_FILE = "val_data.jsonl"   # Adjust
TEST_DATA_FILE = "test_data.jsonl" # Adjust


# Trait names; the dataset builds each label vector in this order.
_trait_names_ordered_config = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Emotional stability', 'Humility']
# Per-sample numeric feature keys read from sample['features'].
_other_numerical_features_config = [
    'mean_words_per_comment', 'mean_sents_per_comment',
    'median_words_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
    'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound'
]

# --- Global Configuration ---
GLOBAL_CONFIG = {
    'BERT_MODEL_NAME': "bert-base-uncased",
    'TRAIT_NAMES_ORDERED': _trait_names_ordered_config,
    'TRAIT_NAMES': _trait_names_ordered_config,  # same list; its length sets the number of trait heads
    'MAX_COMMENTS_TO_PROCESS_PHYSICAL': 3,       # upper bound of the n_comments_to_process search
    'NUM_Q_FEATURES_PER_COMMENT': 3,
    'OTHER_NUMERICAL_FEATURE_NAMES': _other_numerical_features_config,
    'TOKENIZER_MAX_LENGTH': 256
}

NUM_EPOCHS_PER_TRIAL_OPTUNA = 15 # Or your desired value
N_OPTUNA_TRIALS = 20             # Or your desired value


def count_lines_in_file(filepath):
    """Return how many lines a UTF-8 text file contains.

    Blank lines are counted, and a final line without a trailing newline counts
    as a line. Raises FileNotFoundError when ``filepath`` does not exist.
    """
    with open(filepath, encoding='utf-8') as handle:
        return sum(1 for _line in handle)

# --- Size of each split (one JSON record per line) ---
# Training data is mandatory: stop if it is missing or empty.
try:
    NUM_TRAIN_SAMPLES = count_lines_in_file(TRAIN_DATA_FILE)
except FileNotFoundError:
    logger.error(f"Training file {TRAIN_DATA_FILE} not found for line counting. Exiting.")
    exit()
else:
    logger.info(f"Number of training samples in {TRAIN_DATA_FILE}: {NUM_TRAIN_SAMPLES}")
    if not NUM_TRAIN_SAMPLES:
        logger.error(f"Training file {TRAIN_DATA_FILE} is empty or not found. Exiting.")
        exit()
    GLOBAL_CONFIG['NUM_TRAIN_SAMPLES'] = NUM_TRAIN_SAMPLES

# Validation data is optional: on any failure record 0 samples and carry on.
try:
    NUM_VAL_SAMPLES = count_lines_in_file(VAL_DATA_FILE)
    GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = NUM_VAL_SAMPLES
    logger.info(f"Number of validation samples in {VAL_DATA_FILE}: {NUM_VAL_SAMPLES}")
except Exception as e:
    if isinstance(e, FileNotFoundError):
        logger.error(f"Validation data file '{VAL_DATA_FILE}' not found for line counting. Validation length will be 0.")
    else:
        logger.error(f"Error counting validation samples: {e}")
    GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = 0



# START STUDY
# Create (or resume) the Optuna study and run the hyperparameter search.
# NOTE: `objective` is defined in the older cell further down this notebook; run it first.
logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

study_name = "personality_regression_v4" # Updated name for clarity
# Trials are persisted in a local SQLite file; with load_if_exists=True below, re-running
# this cell continues the same study instead of starting over.
storage_name = f"sqlite:///{study_name}.db"
# Output files written after the search (best hyperparameters / copied best weights).
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"

# direction="minimize": the objective's return value is treated as a loss (the results
# section logs it as the validation MSE).
study = optuna.create_study(study_name=study_name,
                            direction="minimize",
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1), # Median pruning; per Optuna docs, not applied until a trial's step exceeds 3
                            storage=storage_name,
                            load_if_exists=True)
if study.trials: logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

# Any exception escaping study.optimize (e.g. one raised by objective) stops the remaining
# trials; it is logged with its traceback and execution continues with the summary below.
try:
    study.optimize(
        lambda trial: objective( # Assuming objective is defined above or imported
            trial, TRAIN_DATA_FILE, VAL_DATA_FILE,
            GLOBAL_CONFIG, DEVICE, num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA
        ),
        n_trials=N_OPTUNA_TRIALS,
        gc_after_trial=True, # Good for memory management with large models
        # n_jobs=1 # If using CUDA, often best to keep n_jobs=1 for Optuna unless objective is very CPU bound before GPU
    )
except Exception as e:
    logger.exception("An error occurred during the Optuna study.")

logger.info("\n--- Optuna Study Finished ---")
# NOTE: study.trials holds every stored trial (including resumed, pruned and failed ones),
# not only the ones that finished successfully in this run.
logger.info(f"Number of finished trials: {len(study.trials)}")

# Best completed trial, or None; read later by the test-prediction section.
best_trial_overall = None 

if not study.trials:
    logger.warning("No trials were completed in the study.")
else:
    # Report the best trial, persist its hyperparameters and weights, and export all trials to CSV.
    # Any failure here is logged; it does not stop the notebook.
    try:
        # Filter for successfully completed trials that have a value
        # (study.best_trial raises when there are none, so it is only queried when this is non-empty).
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if completed_trials:
            best_trial_overall = study.best_trial # Optuna finds the best trial based on reported values

            # NOTE(review): Optuna raises rather than returning None from best_trial, so the
            # `else` branch of this check is presumably unreachable.
            if best_trial_overall:
                logger.info(f"Overall Best Trial Number: {best_trial_overall.number}")
                logger.info(f"  Value (Validation Loss - MSE): {best_trial_overall.value:.4f}")
                logger.info("  Best Params: ")
                for key, value in best_trial_overall.params.items():
                    logger.info(f"    {key}: {value}")

                # ---- SAVING BEST HYPERPARAMETERS (from overall best trial) ----
                # Read back by the test-prediction section to rebuild the model.
                with open(BEST_PARAMS_FILENAME, 'w') as f:
                    json.dump(best_trial_overall.params, f, indent=4)
                logger.info(f"Best hyperparameters saved to {BEST_PARAMS_FILENAME}")

                # ---- SAVING BEST MODEL WEIGHTS (from overall best trial's saved path) ----
                # The objective is expected to store the path of the weights it saved under the
                # user attribute "best_model_path_this_trial"; that file is copied to
                # BEST_WEIGHTS_FILENAME.
                if "best_model_path_this_trial" in best_trial_overall.user_attrs:
                    path_to_best_model_from_trial = best_trial_overall.user_attrs["best_model_path_this_trial"]
                    
                    if path_to_best_model_from_trial and os.path.exists(path_to_best_model_from_trial):
                        try:
                            shutil.copyfile(path_to_best_model_from_trial, BEST_WEIGHTS_FILENAME)
                            logger.info(f"Best model weights from trial {best_trial_overall.number} (path: {path_to_best_model_from_trial}) copied to {BEST_WEIGHTS_FILENAME}")
                        except Exception as e:
                            logger.error(f"Error copying best model weights from {path_to_best_model_from_trial} to {BEST_WEIGHTS_FILENAME}: {e}")
                    elif not path_to_best_model_from_trial:
                         logger.warning(
                            f"Overall best trial {best_trial_overall.number} has 'best_model_path_this_trial' "
                            "but its value is None (empty path). Weights not saved."
                        )
                    else: # path_to_best_model_from_trial is a non-empty string, but file doesn't exist
                        logger.warning(
                            f"Model file '{path_to_best_model_from_trial}' from best trial {best_trial_overall.number} "
                            "not found on disk. Weights not saved."
                        )
                else:
                    logger.warning(
                        f"Key 'best_model_path_this_trial' not found in user_attrs of the overall best trial ({best_trial_overall.number}). "
                        "Ensure your Optuna objective function correctly saves the model path to this attribute."
                    )
            else:
                logger.warning("Study has completed trials, but study.best_trial is None. Cannot save parameters or weights.")
        else:
            logger.warning("No trials completed successfully to determine the best trial. Cannot save parameters or weights.")

        # Export every trial (params, values, states, user attrs) to CSV.
        study_df = study.trials_dataframe()
        # Add user attributes to dataframe if they exist and are simple types
        if completed_trials and best_trial_overall and "best_model_state_dict" in best_trial_overall.user_attrs :
            # NOTE(review): this only adds a boolean flag column; the
            # 'user_attrs_best_model_state_dict' column itself is NOT dropped and is still
            # written to the CSV below.
            study_df['has_best_model_state'] = study_df['user_attrs_best_model_state_dict'].notna()
        study_df.to_csv(f"{study_name}_results.csv", index=False)
        logger.info(f"Optuna study results saved to {study_name}_results.csv")

    except Exception as e:
        logger.error(f"Could not process or save Optuna study results, parameters, or weights: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
# Rebuilds the model from the saved best hyperparameters, loads the copied best weights and
# runs inference over TEST_DATA_FILE. Runs only when the test file exists and the study
# produced a best trial with parameters.
if os.path.exists(TEST_DATA_FILE) and best_trial_overall and best_trial_overall.params:
    logger.info(f"\n--- Predicting on Test Data using saved model from Trial {best_trial_overall.number} ---")
    try:
        # 1. Load best hyperparameters (written by the study-results section).
        with open(BEST_PARAMS_FILENAME, 'r') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {BEST_PARAMS_FILENAME}")

        # The model and the dataset must agree on the number of comment slots.
        n_comments_best = loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL'])

        # 2. Initialize model with best HPs (architecture must match the saved state dict).
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=n_comments_best,
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2), # Default if not in params
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            # NOTE(review): falls back to 0 when the study did not tune this parameter; confirm it
            # matches what objective() trained with, otherwise load_state_dict will fail.
            numerical_embedding_dim=loaded_best_params.get("other_numerical_embedding_dim", 0) if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # 3. Load saved weights straight onto DEVICE (works on CUDA and CPU-only hosts alike).
        try:
            loaded_state_dict = torch.load(BEST_WEIGHTS_FILENAME, map_location=DEVICE)
            test_model.load_state_dict(loaded_state_dict)
            logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        except FileNotFoundError:
            logger.error(f"Model weights file {BEST_WEIGHTS_FILENAME} not found. Cannot perform test prediction with loaded weights.")
            raise  # Handled by the outer FileNotFoundError clause.
        except Exception as e:
            # Deliberately best-effort: continue with the freshly initialized model.
            logger.error(f"Error loading model weights: {e}. Predictions will be from a re-initialized model (likely untrained).")

        test_model.eval() # Set to evaluation mode

        # 4. Create Test DataLoader (is_test_set=True: the dataset yields no labels).
        test_dataset = JsonlIterableDataset(
            file_path=TEST_DATA_FILE,
            trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_best,
            other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            is_test_set=True
        )
        test_loader = DataLoader(test_dataset, batch_size=loaded_best_params.get("batch_size", 8), shuffle=False)

        all_test_predictions = []
        with torch.no_grad():
            for batch_tuple in test_loader:
                # For is_test_set=True each batch holds 5 tensors:
                # (input_ids, attention_mask, q_scores, comment_active_flags, other_numerical_features)
                input_ids, attention_m, q_s, comment_active_m, other_num_feats = [b.to(DEVICE) for b in batch_tuple]
                predicted_scores = test_model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                all_test_predictions.append(predicted_scores.cpu().numpy())

        if all_test_predictions:
            final_test_predictions = np.concatenate(all_test_predictions, axis=0)
            logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
            trait_names = GLOBAL_CONFIG['TRAIT_NAMES_ORDERED']
            # The dataset yields tensors only (no ids), so samples are identified by their position
            # in the prediction array; lines the dataset could not parse are skipped and not counted.
            for i, row in enumerate(final_test_predictions[:5]): # Print first 5
                pred_dict = {trait: round(float(score), 4) for trait, score in zip(trait_names, row)}
                logger.info(f"Test Sample #{i} Predictions: {pred_dict}")
            # np.save(f"{study_name}_test_predictions.npy", final_test_predictions)
            # logger.info(f"Test predictions saved to {study_name}_test_predictions.npy")
        else:
            logger.warning("No predictions generated for the test set.")

    except FileNotFoundError as e:
        # Either the best-params file or the weights file (re-raised above) is missing.
        logger.warning(f"Required file not found ({e}). Skipping test prediction with loaded model.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)
elif not os.path.exists(TEST_DATA_FILE):
    logger.info("No test data provided. Skipping test prediction example.")
elif not best_trial_overall or not best_trial_overall.params:
    logger.warning("No successful best trial found or best trial has no params. Skipping test prediction example.")

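# OLDER CELL: JSONL-based dataset, model and Optuna objective (JsonlIterableDataset,
# PersonalityModelV3, objective). The cell above calls these names, so run this cell first.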


In [None]:
import json
import torch
from torch.utils.data import IterableDataset
from transformers.tokenization_utils_base import BatchEncoding # For your decode_from_json
import logging
import random
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig, get_linear_schedule_with_warmup
from typing import Optional, Tuple, Dict, Union
from torch import nn
import optuna
from torch.utils.data import DataLoader
import gc
from transformers.tokenization_utils_base import BatchEncoding # For type checking and instantiation
import torch.optim as optim
import os

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Marker keys of the custom JSON encoding decoded by _json_object_hook_for_dataset ---
# They must match the keys written when the dataset files were serialized.
_TENSOR_MARKER = "__tensor__"                  # flags a dict that encodes a tensor
_TENSOR_DTYPE_MARKER = "__tensor_dtype__"      # holds that tensor's dtype name
_BATCH_ENCODING_MARKER = "__batch_encoding__"  # flags a dict that encodes a BatchEncoding
# Payload key. The hook reads both tensor data and BatchEncoding data from this key.
_BATCH_ENCODING_DATA_MARKER = "data"

def _convert_str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a dtype name such as ``"torch.int64"`` or ``"float32"`` to a ``torch.dtype``.

    Args:
        dtype_str: dtype name as written by the JSON encoder, with or without the
            ``"torch."`` prefix (``str(torch.float32) == "torch.float32"``).

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: if the name does not refer to a ``torch.dtype``.
    """
    prefix = "torch."
    name = dtype_str[len(prefix):] if dtype_str.startswith(prefix) else dtype_str
    # getattr can also return non-dtype attributes (e.g. "tensor"), so check the type explicitly.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unrecognized torch dtype string: {dtype_str!r}")
    return dtype

def _json_object_hook_for_dataset(dct: dict) -> object:
    """``json.loads`` object hook that rebuilds tensors and ``BatchEncoding`` objects.

    ``json`` invokes the hook on every decoded object, innermost first, so nested
    tensor-marked dicts are already tensors when an enclosing BatchEncoding dict is seen.

    Recognized encodings:
      * ``{_TENSOR_MARKER: ..., _TENSOR_DTYPE_MARKER: "<dtype>", "data": <nested list>}``
        becomes ``torch.tensor(data, dtype=dtype)`` (dtype defaults to float32).
      * ``{_BATCH_ENCODING_MARKER: ..., "data": {field: value, ...}}`` becomes a
        ``BatchEncoding``. ``input_ids`` / ``token_type_ids`` / ``attention_mask`` given as
        plain lists are converted to ``torch.long`` tensors, and other values are kept as-is.

    Any other dict is returned unchanged.
    """
    if _TENSOR_MARKER in dct:
        dtype = _convert_str_to_dtype(dct.get(_TENSOR_DTYPE_MARKER, 'float32'))
        # The tensor payload lives under the same "data" key as the BatchEncoding payload.
        return torch.tensor(dct[_BATCH_ENCODING_DATA_MARKER], dtype=dtype)

    if _BATCH_ENCODING_MARKER in dct:
        # Fields that tokenizers emit as integer index/mask arrays.
        token_fields = ("input_ids", "token_type_ids", "attention_mask")
        reconstructed_data_for_be = {}
        for k, v_data in dct.get(_BATCH_ENCODING_DATA_MARKER, {}).items():
            # Lists appear here when the BatchEncoding's tensors were serialized via tolist()
            # without their own tensor marker.
            if isinstance(v_data, list) and k in token_fields:
                try:
                    reconstructed_data_for_be[k] = torch.tensor(v_data, dtype=torch.long)
                except Exception as e:  # e.g. ragged lists; keep the raw data (best effort)
                    logger.error(f"Error converting field '{k}' in BatchEncoding to tensor: {e}. Keeping as list.")
                    reconstructed_data_for_be[k] = v_data
            else:
                reconstructed_data_for_be[k] = v_data  # Already a tensor or primitive
        return BatchEncoding(reconstructed_data_for_be)

    return dct

class JsonlIterableDataset(IterableDataset):
    """Stream samples from a JSON-lines file (one JSON object per line).

    Each non-blank line is decoded with ``_json_object_hook_for_dataset`` (rebuilding
    tensors / BatchEncoding objects) and passed to ``transform_fn`` (default:
    ``_default_transform``), which yields

        (input_ids, attention_mask, q_scores, comment_active_flags,
         other_numerical_features[, labels])

    with labels omitted when ``is_test_set`` is True. Lines that fail to decode or
    transform are skipped (logged at DEBUG level).

    Args:
        file_path: path of the ``.jsonl`` file.
        trait_names: trait names, in the order the label vector is built.
        n_comments_to_process: comment slots per sample. Extra comments are randomly
            subsampled, and missing ones are zero-padded and flagged inactive.
        other_numerical_feature_names: scalar keys read from ``sample['features']``.
        num_q_features_per_comment: q-score values kept per comment (truncated / zero-padded).
        is_test_set: when True, no labels are read or returned.
        transform_fn: optional replacement for ``_default_transform``, called as
            ``transform_fn(sample, idx=None)``.
        num_samples: value reported by ``__len__``. When None, the non-blank lines of the
            file are counted.
    """

    def __init__(self, file_path, trait_names, n_comments_to_process,
                 other_numerical_feature_names, num_q_features_per_comment,
                 is_test_set=False, transform_fn=None, num_samples=None):
        super().__init__()
        self.file_path = file_path
        self.trait_names_ordered = trait_names
        self.n_comments_to_process = n_comments_to_process
        self.other_numerical_feature_names = other_numerical_feature_names
        self.num_q_features_per_comment = num_q_features_per_comment
        self.is_test_set = is_test_set
        # transform_fn is what PersonalityDatasetV3.__getitem__ does
        self.transform_fn = self._default_transform if transform_fn is None else transform_fn
        if num_samples is None:
            logger.info(f'Counting samples in {file_path} for __len__ was not provided...')
            self.num_samples = self._count_samples_in_file()
            logger.info(f"Counted {self.num_samples} samples in {self.file_path}.")
        else:
            self.num_samples = num_samples
        if self.num_samples == 0:
            logger.warning(f"Initialized JsonlIterableDataset for {self.file_path} with 0 samples. DataLoader will be empty.")

    def _count_samples_in_file(self):
        """Return the number of non-blank lines in the file, or 0 if it cannot be read.

        Blank lines are excluded because ``__iter__`` never yields a sample for them.
        """
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                return sum(1 for line in f if line.strip())
        except FileNotFoundError:
            logger.error(f"File not found during initial sample count: {self.file_path}. Returning 0 samples.")
            return 0
        except Exception as e:
            logger.error(f"Error during initial sample count for {self.file_path}: {e}. Returning 0 samples.")
            return 0

    def _process_line(self, line):
        """Decode one JSON line and transform it; return None if either step fails."""
        try:
            sample = json.loads(line, object_hook=_json_object_hook_for_dataset)
            return self.transform_fn(sample, idx=None)  # idx is unused; the sample carries all info
        except json.JSONDecodeError as e:
            logger.debug("Skipping undecodable JSON line in %s: %s", self.file_path, e)
            return None
        except Exception as e:
            logger.debug("Skipping line in %s (object_hook/transform error): %s", self.file_path, e)
            return None

    def __len__(self):
        """Return the configured/counted sample total.

        NOTE: this counts lines, not successfully parsed samples, and is the total across
        all DataLoader workers.
        """
        return self.num_samples

    @staticmethod
    def _to_finite_float(value, default=0.0):
        """Return ``float(value)``; ``default`` if it cannot be parsed or is NaN/infinite."""
        try:
            result = float(value)
        except (ValueError, TypeError):
            return default
        # NaN/inf would propagate through the network and turn the MSE loss into NaN.
        return result if np.isfinite(result) else default

    def _default_transform(self, sample, idx):
        """Convert one decoded sample into fixed-size tensors (PersonalityDatasetV3.__getitem__ logic).

        Args:
            sample: decoded JSON object. Reads ``sample['features']`` (``comments_tokenized``
                with ``input_ids``/``attention_mask``, ``q_scores``, and the other numerical
                feature keys) and, unless ``is_test_set``, ``sample['labels']``.
            idx: unused; kept for interface compatibility.

        Returns:
            ``input_ids`` and ``attention_mask`` (long, ``(n_comments_to_process, seq_len)``),
            ``q_scores`` (float, ``(n_comments_to_process, num_q_features_per_comment)``),
            ``comment_active_flags`` (bool, ``(n_comments_to_process,)``), the other numerical
            features (float, one per name), and, for non-test sets, labels (float, one per
            trait, clipped to [0, 1]).
        """
        features = sample.get('features', {})
        tokenized_info = features.get('comments_tokenized', {})
        # as_tensor accepts tensors (no copy) as well as plain nested lists.
        all_input_ids = torch.as_tensor(tokenized_info['input_ids'], dtype=torch.long)
        all_attention_mask = torch.as_tensor(tokenized_info['attention_mask'], dtype=torch.long)

        num_actual_comments = all_input_ids.shape[0]
        n_slots = self.n_comments_to_process

        final_input_ids = torch.zeros((n_slots, all_input_ids.shape[1]), dtype=torch.long)
        final_attention_mask = torch.zeros((n_slots, all_attention_mask.shape[1]), dtype=torch.long)
        comment_active_flags = torch.zeros(n_slots, dtype=torch.bool)

        # Randomly subsample when there are more comments than slots; otherwise keep them in order.
        if num_actual_comments > n_slots:
            indices_to_select = random.sample(range(num_actual_comments), n_slots)
        else:
            indices_to_select = list(range(num_actual_comments))
        comments_to_fill = len(indices_to_select)

        for slot, original_idx in enumerate(indices_to_select):
            final_input_ids[slot] = all_input_ids[original_idx]
            final_attention_mask[slot] = all_attention_mask[original_idx]
            comment_active_flags[slot] = True

        # Per-comment q-scores, aligned with the selected comments.
        n_q = self.num_q_features_per_comment
        raw_q_scores = features.get('q_scores', [])
        if raw_q_scores is None:
            raw_q_scores = []
        final_q_scores = torch.zeros((n_slots, n_q), dtype=torch.float)
        selected_raw_q_scores = []
        for original_idx in indices_to_select:
            if original_idx < len(raw_q_scores):
                qs_for_comment = raw_q_scores[original_idx]
                # q_scores may arrive as a tensor (via the JSON hook); list padding needs a list.
                if isinstance(qs_for_comment, torch.Tensor):
                    qs_for_comment = qs_for_comment.tolist()
                qs_for_comment = list(qs_for_comment[:n_q])
                selected_raw_q_scores.append(qs_for_comment + [0.0] * (n_q - len(qs_for_comment)))
            else:
                selected_raw_q_scores.append([0.0] * n_q)

        if comments_to_fill > 0:
            try:
                final_q_scores[:comments_to_fill] = torch.tensor(selected_raw_q_scores, dtype=torch.float)
            except Exception as e:  # ragged or non-numeric q-scores: keep zeros for this sample
                logger.error(f"Error converting selected_raw_q_scores to tensor: {e}. Data: {selected_raw_q_scores}")

        # Sample-level scalar features; missing/invalid/non-finite values become 0.0.
        other_numerical_features_tensor = torch.tensor(
            [self._to_finite_float(features.get(fname, 0.0)) for fname in self.other_numerical_feature_names],
            dtype=torch.float,
        )

        if self.is_test_set:
            return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor)

        labels_dict = sample['labels']
        regression_labels = []
        for trait_key in self.trait_names_ordered:
            # Prefer the title-cased key (e.g. "Emotional Stability"), then the exact name.
            label_val = labels_dict.get(trait_key.title(), labels_dict.get(trait_key, 0.0))
            regression_labels.append(float(np.clip(self._to_finite_float(label_val), 0.0, 1.0)))
        labels_tensor = torch.tensor(regression_labels, dtype=torch.float)
        return (final_input_ids, final_attention_mask, final_q_scores, comment_active_flags, other_numerical_features_tensor, labels_tensor)

    def __iter__(self):
        """Yield transformed samples, skipping blank and unparseable lines.

        With multiple DataLoader workers, lines are sharded round-robin by line index
        (each worker reads the whole file but processes only its own lines).
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        # `with` closes the file even if iteration stops early or a transform raises.
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f):
                if line_no % num_workers != worker_id or not line.strip():
                    continue
                processed_item = self._process_line(line)
                if processed_item is not None:
                    yield processed_item


# --- Regression Loss Function (NEW) ---
# We'll use nn.MSELoss directly in the training loop.

# --- PersonalityModelV3 (Regression and q_scores integration) ---
class PersonalityModelV3(nn.Module):
    """Multi-comment BERT regressor that predicts one score in (0, 1) per trait.

    For each sample, every comment slot is encoded by BERT, using masked mean pooling
    averaged over the last ``num_bert_layers_to_pool`` hidden layers. The result is
    concatenated with that comment's q-scores and aggregated across comments with additive
    attention that ignores inactive slots. Optionally, sample-level numerical features are
    projected and concatenated before the per-trait linear heads, and a sigmoid bounds
    each output.
    """

    def __init__(self,
                 bert_model_name: str,
                 num_traits: int,
                 n_comments_to_process: int = 3,
                 dropout_rate: float = 0.2,
                 attention_hidden_dim: int = 128,
                 num_bert_layers_to_pool: int = 4,
                 num_q_features_per_comment: int = 3, # For Q1, Q2, Q3 scores per comment
                 num_other_numerical_features: int = 0, # From sample['features'] excluding q_scores
                 numerical_embedding_dim: int = 64
                ):
        """Build the model.

        Args:
            bert_model_name: name/path passed to ``BertConfig``/``BertModel.from_pretrained``.
            num_traits: number of regression heads (one output per trait).
            n_comments_to_process: nominal number of comment slots. Stored for reference;
                ``forward`` uses the comment dimension of its inputs.
            dropout_rate: dropout before the heads and inside the numerical-feature projector.
            attention_hidden_dim: hidden size of the comment-attention scorer.
            num_bert_layers_to_pool: how many final BERT hidden layers are averaged.
            num_q_features_per_comment: q-score values concatenated to each comment embedding.
            num_other_numerical_features: size of the sample-level feature vector (0 disables it).
            numerical_embedding_dim: output size of the numerical-feature projector.
        """
        super().__init__()
        self.bert_config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
        self.bert = BertModel.from_pretrained(bert_model_name, config=self.bert_config)
        self.n_comments_to_process = n_comments_to_process
        self.num_bert_layers_to_pool = num_bert_layers_to_pool
        bert_hidden_size = self.bert.config.hidden_size
        self.num_q_features_per_comment = num_q_features_per_comment

        # Comment processing part (BERT embedding + q_scores)
        comment_feature_dim = bert_hidden_size + self.num_q_features_per_comment
        self.attention_w = nn.Linear(comment_feature_dim, attention_hidden_dim)
        self.attention_v = nn.Linear(attention_hidden_dim, 1, bias=False)

        self.dropout = nn.Dropout(dropout_rate)

        # Other numerical features processing part (from sample['features'])
        self.num_other_numerical_features = num_other_numerical_features
        self.uses_other_numerical_features = self.num_other_numerical_features > 0
        self.other_numerical_processor_output_dim = 0

        # Attention output keeps the per-comment feature width.
        aggregated_comment_feature_dim = comment_feature_dim
        combined_input_dim_for_heads = aggregated_comment_feature_dim

        if self.uses_other_numerical_features:
            self.other_numerical_processor_output_dim = numerical_embedding_dim
            self.other_numerical_processor = nn.Sequential(
                nn.Linear(self.num_other_numerical_features, self.other_numerical_processor_output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
            combined_input_dim_for_heads += self.other_numerical_processor_output_dim
            logger.info(f"Model will use {self.num_other_numerical_features} other numerical features, processed to dim {self.other_numerical_processor_output_dim}.")
        else:
            logger.info("Model will NOT use other numerical features.")

        # One linear regression head per trait (kept as a ModuleList for state-dict compatibility).
        self.trait_regressors = nn.ModuleList()
        for _ in range(num_traits):
            self.trait_regressors.append(
                nn.Linear(combined_input_dim_for_heads, 1) # Output one value per trait
            )

    def _pool_bert_layers(self, all_hidden_states: Tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked mean over tokens, averaged over the last ``num_bert_layers_to_pool`` layers.

        Args:
            all_hidden_states: per-layer hidden states, each ``(rows, seq_len, hidden)``.
            attention_mask: ``(rows, seq_len)``; nonzero marks real tokens.

        Returns:
            ``(rows, hidden)`` pooled embeddings. Fully masked rows pool to zeros.
        """
        layers_to_pool = all_hidden_states[-self.num_bert_layers_to_pool:]
        # Match the hidden-state dtype so the product does not rely on integer/float promotion.
        expanded_attention_mask = attention_mask.unsqueeze(-1).expand_as(layers_to_pool[0]).to(layers_to_pool[0].dtype)
        # Token count per row; the clamp avoids division by zero for empty (padding) comments.
        sum_mask = torch.clamp(expanded_attention_mask.sum(dim=1), min=1e-9)

        pooled_outputs = [
            torch.sum(layer_hidden_states * expanded_attention_mask, dim=1) / sum_mask
            for layer_hidden_states in layers_to_pool
        ]
        return torch.stack(pooled_outputs, dim=0).mean(dim=0)

    def forward(self,
                input_ids: torch.Tensor,      # (batch_size, n_comments, seq_len)
                attention_mask: torch.Tensor, # (batch_size, n_comments, seq_len)
                q_scores: torch.Tensor,       # (batch_size, n_comments, num_q_features)
                comment_active_mask: torch.Tensor, # (batch_size, n_comments)
                other_numerical_features: Optional[torch.Tensor] = None # (batch_size, num_other_num_features)
               ):
        """Predict trait scores.

        The comment dimension is taken from ``input_ids``, so any number of comment slots
        works as long as all inputs agree. ``comment_active_mask`` may be bool or 0/1
        integers; inactive comments receive zero attention weight.

        Returns:
            ``(batch_size, num_traits)`` sigmoid-activated scores.

        Raises:
            ValueError: if numerical features are configured but missing or mis-sized.
        """
        batch_size, n_comments = input_ids.shape[0], input_ids.shape[1]

        # Flatten for BERT: (batch_size * n_comments, seq_len)
        input_ids_flat = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask_flat = attention_mask.reshape(-1, attention_mask.shape[-1])

        bert_outputs = self.bert(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        comment_bert_embeddings_flat = self._pool_bert_layers(bert_outputs.hidden_states, attention_mask_flat)

        # Back to (batch_size, n_comments, hidden), using this batch's comment count.
        comment_bert_embeddings = comment_bert_embeddings_flat.view(batch_size, n_comments, -1)

        # Per-comment features: (batch_size, n_comments, hidden + num_q_features)
        comment_features_with_q = torch.cat((comment_bert_embeddings, q_scores), dim=2)

        # Additive attention over comments.
        u = torch.tanh(self.attention_w(comment_features_with_q)) # (batch_size, n_comments, attention_hidden_dim)
        scores = self.attention_v(u).squeeze(-1) # (batch_size, n_comments)

        if comment_active_mask is not None:
            active = comment_active_mask.to(torch.bool)  # `~` on integer masks would be bitwise NOT
            # dtype-aware large negative value: -1e9 overflows float16 under mixed precision.
            scores = scores.masked_fill(~active, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1) # (batch_size, n_comments)
        aggregated_comment_features = torch.sum(attention_weights.unsqueeze(-1) * comment_features_with_q, dim=1)
        # aggregated_comment_features: (batch_size, hidden + num_q_features)

        final_features_for_heads = aggregated_comment_features
        if self.uses_other_numerical_features:
            if other_numerical_features is None or other_numerical_features.shape[1] != self.num_other_numerical_features:
                raise ValueError(
                    f"Other numerical features expected but not provided correctly. "
                    f"Expected {self.num_other_numerical_features}, got shape {other_numerical_features.shape if other_numerical_features is not None else 'None'}"
                )
            processed_other_numerical_features = self.other_numerical_processor(other_numerical_features)
            final_features_for_heads = torch.cat((aggregated_comment_features, processed_other_numerical_features), dim=1)

        combined_features_dropped = self.dropout(final_features_for_heads)

        # (batch_size, num_traits) raw outputs, one column per head.
        all_trait_outputs_raw = torch.cat(
            [regressor_head(combined_features_dropped) for regressor_head in self.trait_regressors], dim=1
        )
        # Sigmoid constrains each regression output to (0, 1).
        return torch.sigmoid(all_trait_outputs_raw)

    def predict_scores(self, outputs: torch.Tensor) -> torch.Tensor:
        """Return ``outputs`` unchanged; ``forward`` already applies the sigmoid."""
        return outputs



# --- Optuna Objective Function (MODIFIED for Regression) ---
def objective(trial: optuna.trial.Trial,
              train_file_path: str,
              val_file_path: str,
              global_config: Dict,
              device: torch.device,
              num_epochs_per_trial: int = 10):
    """Optuna objective: train one regression model and return its best validation MSE.

    Streams train/validation samples from ``train_file_path`` / ``val_file_path``
    through ``JsonlIterableDataset``, builds ``PersonalityModelV3`` from the
    hyperparameters suggested by ``trial`` and trains it for up to
    ``num_epochs_per_trial`` epochs with an MSE loss, early stopping and Optuna
    pruning.

    Whenever the epoch validation loss improves, the model's state_dict (moved
    to CPU) is written to ``<TRIAL_CHECKPOINT_DIR>/trial_<number>_best_state.pth``.
    ``TRIAL_CHECKPOINT_DIR`` is read from ``global_config`` and defaults to
    ``optuna_trial_checkpoints``. The file path is stored in the trial user
    attribute ``"best_model_state_path_for_trial"``. The tensors themselves are
    not stored in ``user_attrs``, because Optuna storages only accept
    JSON-serializable values.

    Args:
        trial: The Optuna trial providing hyperparameter suggestions.
        train_file_path: Path of the training data file.
        val_file_path: Path of the validation data file.
        global_config: Keys read: 'TRAIT_NAMES', 'TRAIT_NAMES_ORDERED',
            'BERT_MODEL_NAME'. Optional: 'OTHER_NUMERICAL_FEATURE_NAMES',
            'NUM_Q_FEATURES_PER_COMMENT', 'MAX_COMMENTS_TO_PROCESS_PHYSICAL',
            'NUM_TRAIN_SAMPLES', 'NUM_VAL_SAMPLES', 'TRIAL_CHECKPOINT_DIR'.
        device: Device on which the model is trained.
        num_epochs_per_trial: Maximum number of training epochs.

    Returns:
        The lowest mean validation MSE over the epochs that were run.
        ``float('inf')`` is returned if the data could not be loaded, nothing
        is trainable, or no validation batch was ever processed.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop the trial, so that
            Optuna records it as PRUNED instead of COMPLETE.
    """
    logger.info(f"Starting Optuna Trial {trial.number}")

    num_traits = len(global_config['TRAIT_NAMES'])
    other_numerical_feature_names_trial = global_config.get('OTHER_NUMERICAL_FEATURE_NAMES', [])
    num_other_numerical_features_trial = len(other_numerical_feature_names_trial)
    num_q_features_per_comment_trial = global_config.get('NUM_Q_FEATURES_PER_COMMENT', 3)

    # --- Suggest Hyperparameters ---
    # Keep names/choices stable when resuming an existing study, otherwise old
    # and new trials are no longer comparable.
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.4)
    attention_hidden_dim = trial.suggest_categorical("attention_hidden_dim", [128, 256, 512])
    lr_bert = trial.suggest_float("lr_bert", 5e-6, 1e-4, log=True)
    lr_head = trial.suggest_float("lr_head", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    num_bert_layers_to_pool = trial.suggest_int("num_bert_layers_to_pool", 1, 4)
    n_comments_trial = trial.suggest_int("n_comments_to_process", 1, global_config.get('MAX_COMMENTS_TO_PROCESS_PHYSICAL', 3))
    num_unfrozen_bert_layers = trial.suggest_int("num_unfrozen_bert_layers", 0, 6)
    patience_early_stopping = trial.suggest_int("patience_early_stopping", 3, 5)
    scheduler_type = trial.suggest_categorical("scheduler_type", ["none", "linear_warmup"])
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.05, 0.2) if scheduler_type != "none" else 0.0
    batch_size_trial = trial.suggest_categorical("batch_size", [8, 16])

    other_numerical_embedding_dim_trial = 0
    if num_other_numerical_features_trial > 0:
        other_numerical_embedding_dim_trial = trial.suggest_categorical("other_numerical_embedding_dim", [32, 64])

    logger.info(f"Trial {trial.number} - Suggested Parameters: {trial.params}")

    # --- Data ---
    try:
        logger.info(f"Trial {trial.number} - Loading data from: {train_file_path}, {val_file_path}")
        train_dataset_trial = JsonlIterableDataset(
            file_path=train_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_TRAIN_SAMPLES')
        )
        val_dataset_trial = JsonlIterableDataset(
            file_path=val_file_path,
            trait_names=global_config['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=n_comments_trial,
            other_numerical_feature_names=other_numerical_feature_names_trial,
            num_q_features_per_comment=num_q_features_per_comment_trial,
            is_test_set=False, num_samples=global_config.get('NUM_VAL_SAMPLES')
        )
        # The DataLoader cannot shuffle an IterableDataset. num_workers stays 0
        # because each worker would replay the full stream unless the dataset
        # shards by worker (not verified here).
        use_pin_memory = device.type == 'cuda'
        train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0,
                                        pin_memory=use_pin_memory, persistent_workers=False)
        val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0,
                                      pin_memory=use_pin_memory, persistent_workers=False)
    except Exception as e:
        logger.error(f"Trial {trial.number} - Error creating dataset/dataloader: {e}", exc_info=True)
        return float('inf')

    # --- Model ---
    model = PersonalityModelV3(
        bert_model_name=global_config['BERT_MODEL_NAME'],
        num_traits=num_traits,
        n_comments_to_process=n_comments_trial,
        dropout_rate=dropout_rate,
        attention_hidden_dim=attention_hidden_dim,
        num_bert_layers_to_pool=num_bert_layers_to_pool,
        num_q_features_per_comment=num_q_features_per_comment_trial,
        num_other_numerical_features=num_other_numerical_features_trial,
        numerical_embedding_dim=other_numerical_embedding_dim_trial
    ).to(device)

    # BERT layer freezing: freeze everything, then unfreeze the embeddings, the
    # top `num_unfrozen_bert_layers` encoder layers and the pooler (if any).
    for param in model.bert.parameters():
        param.requires_grad = False
    if num_unfrozen_bert_layers > 0:
        if hasattr(model.bert, 'embeddings'):
            for param in model.bert.embeddings.parameters():
                param.requires_grad = True

        total_layers = model.bert.config.num_hidden_layers
        first_unfrozen = max(0, total_layers - min(num_unfrozen_bert_layers, total_layers))
        for i in range(first_unfrozen, total_layers):
            for param in model.bert.encoder.layer[i].parameters():
                param.requires_grad = True

        if hasattr(model.bert, 'pooler') and model.bert.pooler is not None:
            for param in model.bert.pooler.parameters():
                param.requires_grad = True

    logger.debug(f"Trial {trial.number} - BERT params requiring grad: "
                 f"{sum(p.numel() for p in model.bert.parameters() if p.requires_grad)}")

    # --- Optimizer: separate LR / weight decay for BERT and for the heads ---
    optimizer_grouped_parameters = []
    bert_params_to_tune = [p for p in model.bert.parameters() if p.requires_grad]
    if bert_params_to_tune and lr_bert > 0:
        optimizer_grouped_parameters.append({"params": bert_params_to_tune, "lr": lr_bert, "weight_decay": 0.01})

    head_params = list(model.attention_w.parameters()) + list(model.attention_v.parameters())
    for regressor_head in model.trait_regressors:
        head_params.extend(regressor_head.parameters())
    if model.uses_other_numerical_features:
        head_params.extend(model.other_numerical_processor.parameters())

    optimizer_grouped_parameters.append({"params": head_params, "lr": lr_head, "weight_decay": weight_decay})

    if not any(group['params'] for group in optimizer_grouped_parameters):
        logger.warning(f"Trial {trial.number} - No parameters to optimize. Skipping training.")
        return float('inf')

    optimizer = optim.AdamW(optimizer_grouped_parameters)  # weight decay applied per group

    # --- LR schedule (needs the pre-counted number of training samples) ---
    scheduler = None
    if scheduler_type == "linear_warmup":
        if global_config.get('NUM_TRAIN_SAMPLES', 0) > 0:
            num_batches_per_epoch = (global_config['NUM_TRAIN_SAMPLES'] + batch_size_trial - 1) // batch_size_trial  # ceil
            num_training_steps = num_batches_per_epoch * num_epochs_per_trial
            num_warmup_steps = int(num_training_steps * warmup_ratio)
            if num_warmup_steps > 0 and num_training_steps > 0:
                scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
            else:
                logger.warning(f"Trial {trial.number}: Calculated num_warmup_steps or num_training_steps is zero. Scheduler not created. Warmup: {num_warmup_steps}, Training: {num_training_steps}")
        else:
            logger.warning(f"Trial {trial.number}: NUM_TRAIN_SAMPLES not available or zero in global_config. Cannot create linear_warmup scheduler.")

    checkpoint_dir = global_config.get('TRIAL_CHECKPOINT_DIR', 'optuna_trial_checkpoints')
    checkpoint_path = os.path.join(checkpoint_dir, f"trial_{trial.number}_best_state.pth")

    loss_fn = nn.MSELoss().to(device)
    best_trial_val_loss = float('inf')
    patience_counter = 0
    try:
        for epoch in range(num_epochs_per_trial):
            # --- Training ---
            model.train()
            total_train_loss = 0.0
            train_batches_processed = 0
            for batch_idx, batch_tuple in enumerate(train_loader_trial):
                input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]

                optimizer.zero_grad()
                predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                current_batch_loss = loss_fn(predicted_scores, labels_reg)

                if not torch.isfinite(current_batch_loss):
                    logger.warning(f"Trial {trial.number}, Epoch {epoch+1}, Batch {batch_idx}: NaN or Inf loss detected. Skipping batch.")
                    torch.cuda.empty_cache()
                    continue

                current_batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                total_train_loss += current_batch_loss.item()
                train_batches_processed += 1

            avg_train_loss = total_train_loss / train_batches_processed if train_batches_processed > 0 else float('inf')
            logger.info(f"Trial {trial.number}, Epoch {epoch+1}/{num_epochs_per_trial} completed. Avg Train Loss: {avg_train_loss:.4f}")

            # --- Validation ---
            model.eval()
            current_epoch_val_loss = 0.0
            val_batches_processed = 0
            all_val_preds_epoch = []
            all_val_labels_epoch = []
            with torch.no_grad():
                for batch_tuple in val_loader_trial:
                    input_ids, attention_m, q_s, comment_active_m, other_num_feats, labels_reg = [b.to(device) for b in batch_tuple]
                    if input_ids.numel() == 0:
                        continue

                    predicted_scores = model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                    if predicted_scores.numel() == 0:
                        continue

                    batch_val_loss = loss_fn(predicted_scores, labels_reg)
                    current_epoch_val_loss += batch_val_loss.item()
                    all_val_preds_epoch.append(predicted_scores.cpu())
                    all_val_labels_epoch.append(labels_reg.cpu())
                    val_batches_processed += 1

            # Mean of per-batch MSEs (the last partial batch weighs as much as a full one).
            avg_val_loss_epoch = current_epoch_val_loss / val_batches_processed if val_batches_processed > 0 else float('inf')

            # MAE is logged for interpretability only; it does not drive selection.
            val_mae = -1.0
            if all_val_labels_epoch:
                all_val_labels_cat = torch.cat(all_val_labels_epoch, dim=0)
                all_val_preds_cat = torch.cat(all_val_preds_epoch, dim=0)
                if all_val_labels_cat.numel() > 0:
                    val_mae = F.l1_loss(all_val_preds_cat, all_val_labels_cat).item()

            logger.info(f"Trial {trial.number}, Epoch {epoch+1} Val Loss (MSE): {avg_val_loss_epoch:.4f}, Val MAE: {val_mae:.4f}")

            if avg_val_loss_epoch < best_trial_val_loss:
                best_trial_val_loss = avg_val_loss_epoch
                patience_counter = 0
                logger.debug(f"Trial {trial.number}, Epoch {epoch+1}: New best val_loss: {best_trial_val_loss:.4f}")

                # Persist weights to disk and record only the path: Optuna
                # user attrs must be JSON-serializable (tensors are not).
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save({k: v.detach().cpu() for k, v in model.state_dict().items()}, checkpoint_path)
                trial.set_user_attr("best_model_state_path_for_trial", checkpoint_path)
            else:
                patience_counter += 1

            trial.report(avg_val_loss_epoch, epoch)
            if trial.should_prune():
                logger.info(f"Trial {trial.number} pruned by Optuna at epoch {epoch+1}.")
                # Raising (not returning) marks the trial PRUNED, so it cannot
                # be picked as study.best_trial.
                raise optuna.TrialPruned()

            if patience_counter >= patience_early_stopping:
                logger.info(f"Trial {trial.number} - Early stopping at epoch {epoch+1} (Patience: {patience_early_stopping}).")
                break

        # Reached after all epochs or early stopping -- NOT inside the epoch loop.
        logger.info(f"Trial {trial.number} finished. Best Val Loss (MSE) for this trial: {best_trial_val_loss:.4f}")
        return best_trial_val_loss
    finally:
        # Release model/optimizer memory on every exit path (return, prune, error).
        del model, train_loader_trial, val_loader_trial, optimizer, scheduler
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

# In your objective function:
# train_dataset_trial = JsonlIterableDataset(
#     file_path="train_data_streamed.jsonl", # Path to your train JSONL
#     trait_names=global_config['TRAIT_NAMES_ORDERED'],
#     n_comments_to_process=n_comments_trial,
#     other_numerical_feature_names=other_numerical_feature_names_trial,
#     num_q_features_per_comment=num_q_features_per_comment_trial,
#     is_test_set=False
# )
# val_dataset_trial = JsonlIterableDataset(...) # For validation
# train_loader_trial = DataLoader(train_dataset_trial, batch_size=batch_size_trial, num_workers=0) # shuffle=True not for IterableDataset
# val_loader_trial = DataLoader(val_dataset_trial, batch_size=batch_size_trial, num_workers=0)

In [None]:


# This cell relies on names defined in earlier cells: `objective`,
# `PersonalityModelV3`, `JsonlIterableDataset` and `logger`.

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# --- Data files ---
# These are plain path constants; nothing is opened here. A missing file
# therefore surfaces later, when the files are counted or streamed.
TRAIN_DATA_FILE = "train_data.jsonl"  # Adjust if your filename is different
VAL_DATA_FILE = "val_data.jsonl"      # Adjust
TEST_DATA_FILE = "test_data.jsonl"    # Adjust

# Order used to label the model's output columns; also passed to the dataset
# as `trait_names`.
_trait_names_ordered_config = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Emotional stability', 'Humility']

# Names of the extra per-sample numerical features fed to the model alongside
# the comment features.
_other_numerical_features_config = [
    'mean_words_per_comment', 'mean_sents_per_comment',
    'median_words_per_comment', 'mean_words_per_sentence', 'median_words_per_sentence',
    'sents_per_comment_skew', 'words_per_sentence_skew', 'total_double_whitespace',
    'punc_em_total', 'punc_qm_total', 'punc_period_total', 'punc_comma_total',
    'punc_colon_total', 'punc_semicolon_total', 'flesch_reading_ease_agg',
    'gunning_fog_agg', 'mean_word_len_overall', 'ttr_overall',
    'mean_sentiment_neg', 'mean_sentiment_neu', 'mean_sentiment_pos',
    'mean_sentiment_compound', 'std_sentiment_compound'
]

# --- Global Configuration (read by `objective` and the test-prediction code) ---
GLOBAL_CONFIG = {
    'BERT_MODEL_NAME': "bert-base-uncased",
    'TRAIT_NAMES_ORDERED': _trait_names_ordered_config,
    'TRAIT_NAMES': _trait_names_ordered_config,
    'MAX_COMMENTS_TO_PROCESS_PHYSICAL': 3,  # upper bound of the n_comments_to_process search
    'NUM_Q_FEATURES_PER_COMMENT': 3,
    'OTHER_NUMERICAL_FEATURE_NAMES': _other_numerical_features_config,
    'TOKENIZER_MAX_LENGTH': 256
}

NUM_EPOCHS_PER_TRIAL_OPTUNA = 15  # Or your desired value
N_OPTUNA_TRIALS = 20              # Or your desired value


def count_lines_in_file(filepath):
    """Return how many lines the UTF-8 text file at ``filepath`` contains.

    Every line is counted, blank ones included. A final line without a
    trailing newline still counts. A missing file raises FileNotFoundError.
    """
    with open(filepath, encoding='utf-8') as handle:
        return sum(1 for _line in handle)

# --- Sample counts (one JSON Lines record per line) ---
try:
    NUM_TRAIN_SAMPLES = count_lines_in_file(TRAIN_DATA_FILE)
    logger.info(f"Number of training samples in {TRAIN_DATA_FILE}: {NUM_TRAIN_SAMPLES}")
    if NUM_TRAIN_SAMPLES == 0:
        logger.error(f"Training file {TRAIN_DATA_FILE} is empty or not found. Exiting.")
        exit()
    # `objective` uses this to size the linear warmup schedule.
    GLOBAL_CONFIG['NUM_TRAIN_SAMPLES'] = NUM_TRAIN_SAMPLES
except FileNotFoundError:
    logger.error(f"Training file {TRAIN_DATA_FILE} not found for line counting. Exiting.")
    exit()

try:
    NUM_VAL_SAMPLES = count_lines_in_file(VAL_DATA_FILE)
    GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = NUM_VAL_SAMPLES
    logger.info(f"Number of validation samples in {VAL_DATA_FILE}: {NUM_VAL_SAMPLES}")
except FileNotFoundError:
    logger.error(f"Validation data file '{VAL_DATA_FILE}' not found for line counting. Validation length will be 0.")
    GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = 0
except Exception as e:
    logger.error(f"Error counting validation samples: {e}")
    GLOBAL_CONFIG['NUM_VAL_SAMPLES'] = 0


# --- Optuna study (persisted in SQLite, resumed if it already exists) ---
logger.info(f"Starting Optuna study: {N_OPTUNA_TRIALS} trials, up to {NUM_EPOCHS_PER_TRIAL_OPTUNA} epochs/trial.")

study_name = "personality_regression_v4"
storage_name = f"sqlite:///{study_name}.db"
BEST_PARAMS_FILENAME = f"{study_name}_best_params.json"
BEST_WEIGHTS_FILENAME = f"{study_name}_best_weights.pth"

study = optuna.create_study(study_name=study_name,
                            direction="minimize",
                            pruner=optuna.pruners.MedianPruner(n_warmup_steps=3, n_min_trials=5, interval_steps=1),
                            storage=storage_name,
                            load_if_exists=True)
if study.trials:
    logger.info(f"Resuming existing study {study.study_name} with {len(study.trials)} previous trials.")

try:
    study.optimize(
        lambda trial: objective(
            trial, TRAIN_DATA_FILE, VAL_DATA_FILE,
            GLOBAL_CONFIG, DEVICE, num_epochs_per_trial=NUM_EPOCHS_PER_TRIAL_OPTUNA
        ),
        n_trials=N_OPTUNA_TRIALS,
        gc_after_trial=True,  # free memory between trials
    )
except Exception:
    logger.exception("An error occurred during the Optuna study.")

logger.info("\n--- Optuna Study Finished ---")
logger.info(f"Number of finished trials: {len(study.trials)}")

best_trial_overall = None  # the study's best COMPLETE trial, if any

if not study.trials:
    logger.warning("No trials were completed in the study.")
else:
    try:
        completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None]
        if completed_trials:
            best_trial_overall = study.best_trial

            if best_trial_overall:
                logger.info(f"Overall Best Trial Number: {best_trial_overall.number}")
                logger.info(f"  Value (Validation Loss - MSE): {best_trial_overall.value:.4f}")
                logger.info("  Best Params: ")
                for key, value in best_trial_overall.params.items():
                    logger.info(f"    {key}: {value}")

                # ---- SAVING BEST HYPERPARAMETERS (from overall best trial) ----
                with open(BEST_PARAMS_FILENAME, 'w') as f:
                    json.dump(best_trial_overall.params, f, indent=4)
                logger.info(f"Best hyperparameters saved to {BEST_PARAMS_FILENAME}")

                # ---- SAVING BEST MODEL WEIGHTS (from overall best trial) ----
                # The objective records a checkpoint *path*. An in-memory
                # state_dict under the legacy key is still accepted.
                user_attrs = best_trial_overall.user_attrs
                checkpoint_path = user_attrs.get("best_model_state_path_for_trial")
                legacy_state = user_attrs.get("best_model_state_dict_for_trial")
                if checkpoint_path and os.path.exists(checkpoint_path):
                    shutil.copyfile(checkpoint_path, BEST_WEIGHTS_FILENAME)
                    logger.info(f"Best model weights from trial {best_trial_overall.number} saved to {BEST_WEIGHTS_FILENAME}")
                elif legacy_state:
                    torch.save(legacy_state, BEST_WEIGHTS_FILENAME)
                    logger.info(f"Best model weights from trial {best_trial_overall.number} saved to {BEST_WEIGHTS_FILENAME}")
                elif checkpoint_path:
                    logger.warning(
                        f"Checkpoint {checkpoint_path} recorded by overall best trial {best_trial_overall.number} "
                        "no longer exists. Weights not saved."
                    )
                else:
                    logger.warning(
                        f"Overall best trial {best_trial_overall.number} recorded no model checkpoint "
                        "('best_model_state_path_for_trial' not in user_attrs). "
                        "Ensure your Optuna objective function stores the model's checkpoint."
                    )
            else:
                logger.warning("Study has completed trials, but study.best_trial is None. Cannot save parameters or weights.")
        else:
            logger.warning("No trials completed successfully to determine the best trial. Cannot save parameters or weights.")

        study_df = study.trials_dataframe()
        # Flag trials that recorded a checkpoint. Never write raw state_dicts
        # (legacy user attr) into the CSV.
        checkpoint_cols = [c for c in ("user_attrs_best_model_state_path_for_trial",
                                       "user_attrs_best_model_state_dict_for_trial")
                           if c in study_df.columns]
        if checkpoint_cols:
            study_df['has_best_model_state'] = study_df[checkpoint_cols].notna().any(axis=1)
            study_df = study_df.drop(columns=["user_attrs_best_model_state_dict_for_trial"], errors="ignore")
        study_df.to_csv(f"{study_name}_results.csv", index=False)
        logger.info(f"Optuna study results saved to {study_name}_results.csv")

    except Exception as e:
        logger.error(f"Could not process or save Optuna study results, parameters, or weights: {e}", exc_info=True)


# --- Example: Predicting on Test Data using saved best model and params ---
if os.path.exists(TEST_DATA_FILE) and best_trial_overall and best_trial_overall.params:
    logger.info(f"\n--- Predicting on Test Data using saved model from Trial {best_trial_overall.number} ---")
    try:
        # 1. Load best hyperparameters
        with open(BEST_PARAMS_FILENAME, 'r') as f:
            loaded_best_params = json.load(f)
        logger.info(f"Loaded best hyperparameters from {BEST_PARAMS_FILENAME}")

        best_n_comments = loaded_best_params.get("n_comments_to_process", GLOBAL_CONFIG['MAX_COMMENTS_TO_PROCESS_PHYSICAL'])

        # 2. Initialize model with best HPs (fallback defaults if a key is missing)
        test_model = PersonalityModelV3(
            bert_model_name=GLOBAL_CONFIG['BERT_MODEL_NAME'],
            num_traits=len(GLOBAL_CONFIG['TRAIT_NAMES']),
            n_comments_to_process=best_n_comments,
            dropout_rate=loaded_best_params.get("dropout_rate", 0.2),
            attention_hidden_dim=loaded_best_params.get("attention_hidden_dim", 128),
            num_bert_layers_to_pool=loaded_best_params.get("num_bert_layers_to_pool", 2),
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            num_other_numerical_features=len(GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES']),
            numerical_embedding_dim=loaded_best_params.get("other_numerical_embedding_dim", 0) if GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'] else 0
        ).to(DEVICE)
        logger.info("Test model initialized with loaded best hyperparameters.")

        # 3. Load saved weights (map_location makes CPU-only machines work too)
        try:
            loaded_state_dict = torch.load(BEST_WEIGHTS_FILENAME, map_location=DEVICE)
            test_model.load_state_dict(loaded_state_dict)
            logger.info(f"Successfully loaded model weights from {BEST_WEIGHTS_FILENAME}")
        except FileNotFoundError:
            logger.error(f"Model weights file {BEST_WEIGHTS_FILENAME} not found. Cannot perform test prediction with loaded weights.")
            raise
        except Exception as e:
            logger.error(f"Error loading model weights: {e}. Predictions will be from a re-initialized model (likely untrained).")

        test_model.eval()

        # 4. Create Test DataLoader (same constructor keywords as in `objective`)
        test_dataset = JsonlIterableDataset(
            file_path=TEST_DATA_FILE,
            trait_names=GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'],
            n_comments_to_process=best_n_comments,
            other_numerical_feature_names=GLOBAL_CONFIG['OTHER_NUMERICAL_FEATURE_NAMES'],
            num_q_features_per_comment=GLOBAL_CONFIG['NUM_Q_FEATURES_PER_COMMENT'],
            is_test_set=True
        )
        test_loader = DataLoader(test_dataset, batch_size=loaded_best_params.get("batch_size", 8), shuffle=False)

        all_test_predictions = []
        with torch.no_grad():
            for batch_tuple in test_loader:
                # NOTE(review): assumes the dataset yields 5 tensors (no labels)
                # when is_test_set=True -- confirm against JsonlIterableDataset.
                input_ids, attention_m, q_s, comment_active_m, other_num_feats = [b.to(DEVICE) for b in batch_tuple]
                predicted_scores = test_model(input_ids, attention_m, q_s, comment_active_m, other_num_feats)
                all_test_predictions.append(predicted_scores.cpu().numpy())

        if all_test_predictions:
            final_test_predictions = np.concatenate(all_test_predictions, axis=0)
            logger.info(f"Shape of final test predictions: {final_test_predictions.shape}")
            # The streamed dataset exposes no sample ids, so rows are labelled by position.
            for i, row in enumerate(final_test_predictions[:5]):
                pred_dict = {trait: round(score.item(), 4) for trait, score in zip(GLOBAL_CONFIG['TRAIT_NAMES_ORDERED'], row)}
                logger.info(f"Test Sample #{i} Predictions: {pred_dict}")
        else:
            logger.warning("No predictions generated for the test set.")

    except FileNotFoundError as e:
        logger.warning(f"Required file not found ({e}). Skipping test prediction with loaded model.")
    except Exception as e:
        logger.error(f"An error occurred during test prediction: {e}", exc_info=True)
elif not os.path.exists(TEST_DATA_FILE):
    logger.info("No test data provided. Skipping test prediction example.")
elif not best_trial_overall or not best_trial_overall.params:
    logger.warning("No successful best trial found or best trial has no params. Skipping test prediction example.")