In [1]:
import os
import sys
import numpy as np
import torch
import logging
import json
from pathlib import Path

# Put the current working directory at the front of sys.path so that the
# project-local `src` package imported further down can be resolved.
# The insert is skipped when that directory is already the first entry, so
# re-running this notebook cell does not pile up duplicate sys.path entries.
project_root = os.path.abspath('.')
if sys.path[:1] != [project_root]:
    sys.path.insert(0, project_root)

# Configure logging: INFO level, with timestamp, logger name and level on
# every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("test_data_module")

# Load the run configuration from the current working directory.
# The encoding is given explicitly: without it, open() falls back to the
# locale's preferred encoding, which can mis-decode non-ASCII values
# (e.g. paths) in a UTF-8 JSON file on some platforms.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Import module data
from src.data import (
    QTransformerDataModule, 
    load_npz_data, 
    analyze_dataset_statistics,
    create_data_visualizations
)

# Resolve the first training file and the first testing file listed in the
# config, both relative to the configured dataset directory.
dataset_dir = Path(config['dataset_dir'])
train_val_file = dataset_dir / config['training_files'][0]
test_file = dataset_dir / config['testing_files'][0]

# Report whether each file is present. This is informational only: nothing
# below is skipped or aborted when a file is reported as missing.
logger.info(f"Checking training file: {train_val_file} - exists: {train_val_file.exists()}")
logger.info(f"Checking test file: {test_file} - exists: {test_file.exists()}")

# Load the training archive and report how many entries its 'obs' array holds.
logger.info("Analyzing training data structure...")
train_data = load_npz_data(os.fspath(train_val_file))
logger.info(f"Training data has {len(train_data['obs'])} samples")

# Make sure the output folder exists, then hand the loaded data and the folder
# name to the visualization helper (presumably it writes its plots there --
# confirm against src.data).
viz_dir = "data_visualizations"
Path(viz_dir).mkdir(parents=True, exist_ok=True)
logger.info(f"Creating visualizations in {viz_dir}")
create_data_visualizations(train_data, viz_dir)

# Build the data module from the loaded config and run its setup step.
logger.info("Creating QTransformerDataModule...")
data_module = QTransformerDataModule(config)
data_module.setup()

# Fetch the train / validation / test loaders and report their batch counts.
train_loader = data_module.get_train_dataloader()
val_loader = data_module.get_val_dataloader()
test_loader = data_module.get_test_dataloader()

logger.info(f"Train loader: {len(train_loader)} batches")
logger.info(f"Validation loader: {len(val_loader)} batches")
# Compare against None explicitly. Truth-testing an object that defines
# __len__ but not __bool__ goes through __len__, so a test loader that exists
# but yields zero batches would otherwise be misreported as "No test loader".
if test_loader is not None:
    logger.info(f"Test loader: {len(test_loader)} batches")
else:
    logger.info("No test loader")

# Pull a single batch from the train loader and log the shape of each field
# this check expects. The batch is indexed by string keys below; a missing
# key raises KeyError, and an empty train loader makes next() raise
# StopIteration -- both are left to surface as failures of this check.
logger.info("Getting a batch from train loader...")
batch = next(iter(train_loader))
logger.info(f"Batch keys: {list(batch.keys())}")
logger.info(f"Observations shape: {batch['observation'].shape}")
logger.info(f"Action vectors shape: {batch['action_vectors'].shape}")
logger.info(f"Rho values shape: {batch['rho_values'].shape}")
logger.info(f"Soft labels shape: {batch['soft_labels'].shape}")

logger.info("QTransformerDataModule test completed successfully!")

2025-04-16 23:01:38,884 - src.data - INFO - Loaded Q-Transformer data module with 23 components
2025-04-16 23:01:38,885 - test_data_module - INFO - Checking training file: train_val.npz - exists: True
2025-04-16 23:01:38,886 - test_data_module - INFO - Checking test file: test.npz - exists: True
2025-04-16 23:01:38,887 - test_data_module - INFO - Analyzing training data structure...
2025-04-16 23:01:54,417 - src.data.utils - INFO - Loaded train_val.npz in 15.53 seconds
2025-04-16 23:01:54,419 - src.data.utils - INFO - Found keys: ['obs', 'best_action', 'timestep', 'act_rho', 'soft_labels', 'act_vect']
2025-04-16 23:01:54,420 - src.data.utils - INFO -   obs: (50376, 3819), float32
2025-04-16 23:01:54,421 - src.data.utils - INFO -   best_action: (50376,), <U72
2025-04-16 23:01:54,422 - src.data.utils - INFO -   timestep: (50376,), datetime64[us]
2025-04-16 23:01:54,423 - src.data.utils - INFO -   act_rho: (50376, 50), [('action', '<U73'), ('rho_max', '<f4')]
2025-04-16 23:01:54,424 - src

<Figure size 1400x800 with 0 Axes>