In [None]:
<!-- Evaluación de Modelos con H5 -->
<VSCode.Cell language="markdown">
# Evaluación de Modelos con H5

Este notebook carga un archivo `.h5` con resultados de test, calcula métricas de error del modelo (sobre el campo completo y sobre un recorte central de 32×32 píxeles) y grafica la secuencia de inputs, ground truth y predicciones de una muestra.
</VSCode.Cell>
<VSCode.Cell language="python">
import os
import sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from src.metrics.metrics import calculate_metrics

# Agregar raíz del proyecto al Python path
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Estilo de gráficas
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'figure.dpi': 100,
    'axes.grid': True,
    'grid.alpha': 0.3
})
%matplotlib inline
</VSCode.Cell>
<VSCode.Cell language="python">
def _read_array(f, name):
    """Read ``<name>/data`` if that dataset exists, otherwise the dataset ``<name>``.

    Raises KeyError (from h5py) if neither path exists in the file.
    """
    key = f'{name}/data' if f'{name}/data' in f else name
    return f[key][:]


def _read_timestamps(f, name):
    """Read ``<name>/timestamps`` as a list of str, or return [] if it is absent.

    The path is checked against the file root. Checking it inside the
    ``<name>`` group would resolve to ``<name>/<name>/timestamps``.
    """
    key = f'{name}/timestamps'
    if key not in f:
        return []
    # h5py may return variable-length strings as bytes or as str,
    # depending on version/dtype; normalise both to str.
    return [ts.decode('utf-8') if isinstance(ts, bytes) else str(ts)
            for ts in f[key][:]]


def load_results(h5_path):
    """Load test results from an H5 file.

    Each of 'inputs', 'targets' and 'predictions' is read from
    ``<name>/data`` when present, otherwise from a dataset named ``<name>``.

    Args:
        h5_path: path to the ``.h5`` file.

    Returns:
        Tuple ``(inputs, targets, predictions, metadata, timestamps)``:
        - inputs, targets, predictions: numpy arrays read in full.
        - metadata: dict of the file's root-level HDF5 attributes.
        - timestamps: dict with keys 'inputs', 'targets', 'predictions'.
          Each value is a list of str from ``<name>/timestamps``, or [] if
          that dataset is missing.
    """
    with h5py.File(h5_path, 'r') as f:
        inputs = _read_array(f, 'inputs')
        targets = _read_array(f, 'targets')
        predictions = _read_array(f, 'predictions')
        timestamps = {name: _read_timestamps(f, name)
                      for name in ('inputs', 'targets', 'predictions')}
        metadata = dict(f.attrs.items())
    return inputs, targets, predictions, metadata, timestamps

# Helper function

def get_center_crop(arr, size=32):
    """Return the central ``size x size`` crop over the last two axes of ``arr``.

    Works for any array with ``ndim >= 2``, e.g. (H, W), (T, H, W) or
    (N, T, H, W). All leading axes are kept unchanged. If ``size`` is larger
    than a spatial dimension, that dimension is returned in full rather than
    as a wrapped, negative-index slice.

    Args:
        arr: numpy array (or tensor supporting ``ndim``/``shape``/``...``
            slicing) whose last two axes are treated as height and width.
        size: side length of the square crop, in pixels.

    Returns:
        ``arr[..., h0:h0 + size, w0:w0 + size]``. For numpy arrays this is a
        view, not a copy.

    Raises:
        ValueError: if ``arr`` has fewer than two dimensions.
    """
    if arr.ndim < 2:
        raise ValueError(f"get_center_crop expects ndim >= 2, got shape {tuple(arr.shape)}")
    H, W = arr.shape[-2:]
    # Clamp at 0: when size > H (or W), H//2 - size//2 is negative, and a
    # negative slice start means "count from the end", which would silently
    # return a smaller, off-centre window.
    h0 = max(H // 2 - size // 2, 0)
    w0 = max(W // 2 - size // 2, 0)
    return arr[..., h0:h0 + size, w0:w0 + size]
</VSCode.Cell>
<VSCode.Cell language="python">
# Path of the H5 file holding the test-set results to evaluate.
h5_path = os.path.join(PROJECT_ROOT, 'logs/unet4/test_results.h5')
print(f"Cargando archivo: {h5_path}")
(inputs, targets, predictions,
 metadata, timestamps) = load_results(h5_path)

# File-level metadata attributes, one per line.
print("\nMetadata:")
for attr_name, attr_value in metadata.items():
    print(f"- {attr_name}: {attr_value}")

print(
    f"\nShapes: inputs={inputs.shape}, "
    f"targets={targets.shape}, predictions={predictions.shape}"
)
# Timestamps are optional; show the first input one only if present.
if len(timestamps['inputs']) > 0:
    print(f"Primera timestamp input: {timestamps['inputs'][0]}")
</VSCode.Cell>
<VSCode.Cell language="python">
# Metrics on the full field, and on a central crop of the last two
# (spatial) axes.
CENTER_CROP_SIZE = 32  # side length, in pixels, of the central evaluation window

metrics_global = calculate_metrics(targets, predictions)
tgt_center = get_center_crop(targets, size=CENTER_CROP_SIZE)
pred_center = get_center_crop(predictions, size=CENTER_CROP_SIZE)
metrics_center = calculate_metrics(tgt_center, pred_center)

print("\nMétricas Globales:", metrics_global)
# Build the label from the constant so it cannot drift from the crop actually used.
print(f"Métricas Centro {CENTER_CROP_SIZE}x{CENTER_CROP_SIZE}:", metrics_center)
</VSCode.Cell>
<VSCode.Cell language="python">
# Visualize one sample: input sequence, ground truth and predictions
from matplotlib.gridspec import GridSpec


def visualize_sequence(inp, tgt, prd, idx=0, timestamps=None, share_scale=True):
    """Plot the input, ground-truth and predicted frames of one sample.

    Layout: input frames fill grid rows 1-2, ground truth row 3 and
    predictions row 4. The grid has 6 columns, widened when there are more
    than 6 target/prediction frames or more than 12 input frames. The old
    fixed 6-column grid raised IndexError or overlapped panels in those cases.

    Args:
        inp, tgt, prd: arrays where ``arr[idx, t]`` is one 2-D frame.
            Presumably shaped (N, T, H, W); TODO confirm against the H5 writer.
        idx: sample index along the first axis.
        timestamps: optional dict whose 'inputs' / 'targets' / 'predictions'
            entries are lists of str. When ``timestamps[key][t]`` exists, it
            is appended to frame ``t``'s title. If None, the notebook-global
            ``timestamps`` is used if defined (the previous behavior);
            otherwise no timestamps are shown.
        share_scale: if True, ground truth and predictions are drawn with the
            same colour limits, so their colours are directly comparable.
            Independent auto-scaling would hide amplitude errors.
    """
    if timestamps is None:
        # Backward compatibility: the function used to read the global directly.
        timestamps = globals().get('timestamps') or {}

    n_in = inp.shape[1]
    # Two input rows must hold all inputs, and one row must hold all
    # targets/predictions. With <=12 inputs and <=6 outputs this is the
    # original 6-column grid.
    n_cols = max(6, tgt.shape[1], prd.shape[1], (n_in + 1) // 2)

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(4, n_cols, figure=fig, wspace=0.1, hspace=0.3)
    cmap = 'viridis'

    out_limits = {}
    if share_scale:
        out_limits = {
            'vmin': float(min(np.nanmin(tgt[idx]), np.nanmin(prd[idx]))),
            'vmax': float(max(np.nanmax(tgt[idx]), np.nanmax(prd[idx]))),
        }

    panels = [
        # (frames, title prefix, timestamps key, grid cell for frame t, imshow limits)
        (inp, 'In', 'inputs', lambda t: gs[t // n_cols, t % n_cols], {}),
        (tgt, 'GT', 'targets', lambda t: gs[2, t], out_limits),
        (prd, 'Pr', 'predictions', lambda t: gs[3, t], out_limits),
    ]
    for frames, prefix, ts_key, cell_for, limits in panels:
        labels = timestamps.get(ts_key, [])
        for t in range(frames.shape[1]):
            ax = fig.add_subplot(cell_for(t))
            ax.imshow(frames[idx, t], cmap=cmap, **limits)
            title = f'{prefix} {t+1}' + (f"\n{labels[t]}" if t < len(labels) else '')
            ax.set_title(title, fontsize=8)
            ax.axis('off')

    fig.suptitle(f'Muestra {idx}')
    plt.show()

visualize_sequence(inputs, targets, predictions, idx=0, timestamps=timestamps)
