In [1]:
import torch
import yaml
import sys
import os
from pathlib import Path

sys.path.insert(0, os.path.abspath('.'))

# Import module of  model
from src.models.observation_encoder import ObservationEncoder
from src.models.action_encoder import ActionEncoder
from src.models.cross_attention import CrossAttention
from src.models.model import PowerGridModel

# Import DataModule to get batch data
from src.data import QTransformerDataModule

# Load config
import json
# JSON documents are UTF-8; pass the encoding explicitly so the config is
# decoded the same way on every platform instead of with the locale default.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Initialize the DataModule from the config and run its setup step
data_module = QTransformerDataModule(config)
data_module.setup()

# Pull a single batch from the train loader to use as a smoke-test input
train_loader = data_module.get_train_dataloader()
batch = next(iter(train_loader))

# Print the batch keys and the shape of each tensor used below
print("Batch keys:", list(batch.keys()))
print("Observations shape:", batch['observation'].shape)
print("Action vectors shape:", batch['action_vectors'].shape)
print("Rho values shape:", batch['rho_values'].shape)
print("Soft labels shape:", batch['soft_labels'].shape)
# Note: this prints the tensor values themselves, not just the shape
print("action_weights:",batch["action_weights"])

# Initialize model
model = PowerGridModel(config)

# Print model architecture
print("\nModel structure:")
print(model)

# Move the model to the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Inputs passed to the model's forward call
input_batch = {
    'observation': batch['observation'].to(device),
    'action_vectors': batch['action_vectors'].to(device)
}

# Ground truth used to compute the loss and metrics.
# 'best_action' is optional; when it is present as a tensor it is moved to the
# same device as the other targets so downstream comparisons do not mix devices.
# NOTE(review): assumes compute_all_metrics works on device tensors — confirm.
targets = {
    'rho_values': batch['rho_values'].to(device),
    'soft_labels': batch['soft_labels'].to(device),
    'best_action': (
        batch['best_action'].to(device)
        if torch.is_tensor(batch.get('best_action', None))
        else batch.get('best_action', None)
    ),
    'action_weights': batch['action_weights'].to(device)
}

# Forward pass
print("\nRunning forward pass...")
# torch.no_grad() only disables gradient tracking; dropout layers stay active
# unless the model is switched to evaluation mode, which would make the
# printed loss/metrics vary from run to run.
model.eval()
with torch.no_grad():
    predictions = model(input_batch)

print("\nOutput shapes:")
for key, tensor in predictions.items():
    print(f"{key} shape: {tensor.shape}")

# Import the loss function and the metric helper
from src.utils.loss import PowerGridLoss
from src.utils.metrics import compute_all_metrics

# Build the loss; both weights fall back to 0.5 when absent from the config
model_params = config['model_params']
loss_fn = PowerGridLoss(
    rho_weight=model_params.get('rho_weight', 0.5),
    soft_label_weight=model_params.get('soft_label_weight', 0.5),
)

# Evaluate the loss on the forward-pass outputs
loss_dict = loss_fn(predictions, targets)

print("\nLoss values:")
for loss_name, loss_value in loss_dict.items():
    print(f"{loss_name}: {loss_value.item():.4f}")

# Metrics are computed only when the batch provides a non-None 'best_action'
if batch.get('best_action') is not None:
    metrics = compute_all_metrics(predictions, targets)
    print("\nMetrics:")
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")

2025-04-18 13:26:53,602 - src.data - INFO - Loaded Q-Transformer data module with 23 components
2025-04-18 13:26:53,604 - src.data.qtransformer_data - INFO - Using action weights: False
2025-04-18 13:26:53,604 - src.data.qtransformer_data - INFO - Loading training data from 1 files...
2025-04-18 13:26:53,605 - src.data.qtransformer_data - INFO - Loading file: train_val.npz
2025-04-18 13:27:09,397 - src.data.enhanced_splitter - INFO - Splitting 50376 samples with strategy: stratified_temporal
2025-04-18 13:27:14,279 - src.data.enhanced_splitter - INFO - === Split Summary ===
2025-04-18 13:27:14,282 - src.data.enhanced_splitter - INFO - Train samples: 42840
2025-04-18 13:27:14,283 - src.data.enhanced_splitter - INFO - Val samples:   7536
2025-04-18 13:27:14,283 - src.data.enhanced_splitter - INFO - Test samples:  0
2025-04-18 13:27:14,284 - src.data.enhanced_splitter - INFO - Train time: 2021-01-04T07:45:00.000000 → 2023-12-07T12:55:00.000000
2025-04-18 13:27:14,285 - src.data.enhanced_s

Batch keys: ['observation', 'action_vectors', 'rho_values', 'soft_labels', 'action_weights', 'timestep', 'idx', 'best_action']
Observations shape: torch.Size([128, 3819])
Action vectors shape: torch.Size([128, 50, 1152])
Rho values shape: torch.Size([128, 50])
Soft labels shape: torch.Size([128, 50])
action_weights: tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

Model structure:
PowerGridModel(
  (observation_encoder): ObservationEncoder(
    (input_projection): Linear(in_features=3819, out_features=1024, bias=True)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (encoder_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (wq): Linear(in_features=1024, out_feature

<Figure size 1400x800 with 0 Axes>