# TT Reconstruction Evaluation
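Evaluate model-predicted tensor-train (TT) cores against the ground-truth cores by computing the mean squared error between the tensor entries they reconstruct.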

In [None]:
import torch
import torch.nn.functional as F
from itertools import product


In [None]:
def preprocess_pred_cores(pred_cores):
    """Drop the leading size-1 axis from every predicted core.

    ``Tensor.squeeze(0)`` removes axis 0 only when its size is 1; a core
    without such an axis comes back with its shape unchanged.

    Args:
        pred_cores: Iterable of core tensors (the notebook passes the raw
            model output here).

    Returns:
        A new list holding the squeezed cores, in the original order.
    """
    cleaned = []
    for predicted in pred_cores:
        cleaned.append(predicted.squeeze(0))
    return cleaned


In [None]:
def generate_valid_input_indices_safe(*cores_lists):
    """Enumerate every multi-index that is valid for all given TT core lists.

    For each dimension ``i`` the usable mode size is the smallest
    ``cores[i].shape[1]`` over all lists, so every returned index can be
    used to slice the corresponding core of each list. The result is the
    full Cartesian product of those ranges. Its length is the product of
    the mode sizes, so it grows exponentially with the number of dimensions.

    Args:
        *cores_lists: One or more sequences of cores, all of the same
            length. Only ``shape[1]`` of each core is read here.

    Returns:
        ``torch.long`` tensor of shape ``(prod(mode_sizes), d)``, rows in
        lexicographic order (last column varying fastest). It is always
        2-D, even when the grid is empty.

    Raises:
        ValueError: if no core lists are given or their lengths differ.
    """
    if not cores_lists:
        raise ValueError("at least one list of cores is required")
    d = len(cores_lists[0])
    if any(len(cores) != d for cores in cores_lists):
        raise ValueError(
            "all core lists must have the same number of cores; got "
            f"{[len(cores) for cores in cores_lists]}"
        )
    mode_sizes = [
        min(cores[i].shape[1] for cores in cores_lists) for i in range(d)
    ]
    grid = list(product(*(range(n) for n in mode_sizes)))
    # torch.tensor([]) is 1-D; reshape keeps the (num_points, d) layout that
    # contract_tt_cores unpacks, even when some mode size is 0.
    return torch.tensor(grid, dtype=torch.long).reshape(len(grid), d)


In [None]:
def contract_tt_cores(cores, input_indices):
    """Evaluate a tensor-train at a batch of multi-indices.

    For each row ``x = input_indices[b]`` this computes the matrix product
    ``cores[0][:, x[0], :] @ cores[1][:, x[1], :] @ ... @ cores[d-1][:, x[d-1], :]``
    and returns it as a scalar.

    Args:
        cores: Sequence of ``d`` 3-D tensors; core ``i`` is indexed as
            ``(r_i, n_i, r_{i+1})``. The outer ranks ``r_0`` and ``r_d``
            must multiply to 1 so each product is a scalar. Cores are moved
            to ``input_indices.device``.
        input_indices: Integer tensor of shape ``(batch, d)``; column ``i``
            must lie in ``[0, n_i)``.

    Returns:
        1-D tensor of shape ``(batch,)`` holding the TT values.

    Raises:
        ValueError: if ``d == 0``, ``len(cores) != d``, a core is not 3-D,
            or the outer ranks do not yield a scalar.
        IndexError: if any index lies outside ``[0, n_i)``.
    """
    batch_size, d = input_indices.shape
    if d == 0:
        raise ValueError("input_indices must have at least one column")
    if len(cores) != d:
        raise ValueError(f"got {len(cores)} cores for {d}-dimensional indices")
    device = input_indices.device
    result = None
    for i, raw_core in enumerate(cores):
        core = raw_core.to(device)
        x_i = input_indices[:, i]
        _, n_i, _ = core.shape
        # Explicit check (not assert): also rejects negative indices, which
        # would otherwise silently wrap around. Skipped for an empty batch,
        # where min()/max() are undefined.
        if x_i.numel() and (x_i.min().item() < 0 or x_i.max().item() >= n_i):
            raise IndexError(f"x_i contains value outside [0, {n_i}) at dim {i}")
        # (r_i, batch, r_ip1) -> (batch, r_i, r_ip1): one slice per sample,
        # the same layout as stacking core[:, x, :] over the batch.
        selected = core.index_select(1, x_i).permute(1, 0, 2)
        result = selected if result is None else torch.bmm(result, selected)
    if result.shape[1] * result.shape[2] != 1:
        raise ValueError(
            "outer TT ranks must multiply to 1, got "
            f"r_0={result.shape[1]}, r_d={result.shape[2]}"
        )
    # reshape, not view: with d == 1 the permuted slice is non-contiguous.
    return result.reshape(batch_size)


In [None]:
def evaluate_function_approx(pred_cores, true_cores, input_indices):
    """Compare predicted and ground-truth TT values on a set of indices.

    The predicted cores are first passed through ``preprocess_pred_cores``;
    the ground-truth cores are used as given. Both are evaluated with
    ``contract_tt_cores`` at ``input_indices``.

    Returns:
        Tuple ``(mse, f_pred, f_true)``. ``mse`` is the mean squared error
        as a Python float; ``f_pred`` and ``f_true`` are the predicted and
        true values at each index.
    """
    f_pred = contract_tt_cores(preprocess_pred_cores(pred_cores), input_indices)
    f_true = contract_tt_cores(true_cores, input_indices)
    return F.mse_loss(f_pred, f_true).item(), f_pred, f_true


In [None]:
# Requires `mu_tensor`, `tt_true` and `model` to be defined by earlier cells.
with torch.no_grad():
    # unsqueeze(0) adds a leading axis of size 1 (presumably the batch axis
    # the model expects — confirm against the model definition).
    pred_cores = model(mu_tensor.unsqueeze(0))
    true_cores = list(tt_true.cores)
    pred_cores_clean = preprocess_pred_cores(pred_cores)
    # Index grid limited to mode sizes that both core lists support.
    input_indices = generate_valid_input_indices_safe(
        pred_cores_clean, true_cores
    )
    mse, f_pred, f_true = evaluate_function_approx(
        pred_cores, true_cores, input_indices
    )
    print("Reconstruction MSE:", mse)
