In [6]:
import logging
import sys
import os

import numpy as np
import tensorflow as tf
from data_loaders.pain_ds_config import PainDatasetConfig
from learner.few_shot_pain_learner import FewShotPainLearner
from utils.logger import setup_logger

logger = setup_logger("FewShotPainLearner", level=logging.INFO)


"""Example usage of the few-shot pain learner."""
config = PainDatasetConfig()
print(tf.config.list_physical_devices())
print("Num GPUs Available: ", len(tf.config.list_physical_devices('MPS')))

logger.info("=" * 60)
logger.info("Multimodal Few-Shot Learning for Personalized Pain Assessment")
logger.info("=" * 60)

# Try different fusion methods
fusion_methods = ["attention"]

for fusion_method in fusion_methods:
    logger.info(f"\nTraining with fusion method: {fusion_method}")
# --- Diagnostic code for ValueError ---
# Get the data_dir from the learner's expected path
data_dir_path = "./data"
# Assuming config.data_path is 'X_pre.npy' based on common project structure and traceback
x_pre_file_name = config.data_path # This would be 'X_pre.npy'
x_pre_full_path = os.path.join(data_dir_path, x_pre_file_name)

print(f"\n--- Diagnosing data shape mismatch for '{x_pre_file_name}' ---")
try:
    # Load the numpy array without immediately reshaping
    loaded_data = np.load(x_pre_full_path)
    print(f"Actual shape of data in '{x_pre_file_name}': {loaded_data.shape}")
    print(f"Actual total elements in '{x_pre_file_name}': {loaded_data.size}")

    # The error message implied a target shape of (2495, 2500, 6, 1)
    expected_target_shape = (2495, 2500, 6, 1)
    expected_total_elements = np.prod(expected_target_shape)
    print(f"Expected target shape for reshape: {expected_target_shape}")
    print(f"Expected total elements for target shape: {expected_total_elements}")

    if loaded_data.size != expected_total_elements:
        print("Mismatched elements count detected!")
        print("This is the likely cause of the ValueError: cannot reshape array.")
    else:
        print("Element count matches, but reshape still failed. There might be an issue with order or specific reshape constraints.")

except FileNotFoundError:
    print(f"Error: Data file not found at '{x_pre_full_path}'. Please ensure the data file exists.")
except Exception as e:
    print(f"An unexpected error occurred while trying to load or inspect '{x_pre_file_name}': {e}")
    learner = FewShotPainLearner(
        config=config,
        data_dir="./data",
        learning_rate=1e-3,
        fusion_method=fusion_method,
    )
    sys.exit(1)
    cv_results = learner.train(
        num_epochs=100, episodes_per_epoch=50, val_episodes=10
    )
    logger.info(cv_results)
    logger.info(f"Training with {fusion_method} complete!")



[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
Num GPUs Available:  0
2026-02-14 17:50:28 │ INFO     │ FewShotPainLearner:20	 │ Multimodal Few-Shot Learning for Personalized Pain Assessment
2026-02-14 17:50:28 │ INFO     │ FewShotPainLearner:27	 │ 
Training with fusion method: attention

--- Diagnosing data shape mismatch for 'X_pre.npy' ---
Actual shape of data in 'X_pre.npy': (2495, 2500, 6, 1)
Actual total elements in 'X_pre.npy': 37425000
Expected target shape for reshape: (2495, 2500, 6, 1)
Expected total elements for target shape: 37425000
Element count matches, but reshape still failed. There might be an issue with order or specific reshape constraints.
